"""
The datasets module includes a number of standard VAE datasets in the style of the torchvision.datasets module.
"""
import os
import shutil
import zipfile
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
def make_dataset(dir, extensions):
    """
    Collect the paths of every file under ``dir`` whose extension is allowed.

    The directory tree is walked recursively with ``os.walk``. The walk results
    are sorted by directory path, and files inside each directory are sorted by
    name, so the returned order is deterministic.

    Args:
        dir (str): Directory to search.
        extensions: Allowed extensions, in whatever form
            ``torchvision.datasets.folder.has_file_allowed_extension`` accepts.

    Returns:
        list of str: ``os.path.join(directory, filename)`` for every matching file.
    """
    from torchvision.datasets.folder import has_file_allowed_extension

    return [
        os.path.join(folder, name)
        for folder, _subdirs, names in sorted(os.walk(dir))
        for name in sorted(names)
        if has_file_allowed_extension(name, extensions)
    ]
class SimpleImageFolder(Dataset):
    """
    Auto-encoding dataset over every image file found under a folder.

    Files are collected recursively by :func:`make_dataset`. Each item is an
    ``(input, target)`` tuple derived from the same loaded image: ``transform``
    produces the input and ``target_transform`` produces the target. Each
    transform is called separately on the loaded image.

    Args:
        root (str): Root directory of dataset containing all aligned images
        loader (callable, optional): Function that takes a path and returns the
            loaded image. Defaults to ``torchvision.datasets.folder.default_loader``.
        extensions (:obj:`list` of :obj:`str`, optional): File extensions to
            include. Defaults to ``torchvision.datasets.folder.IMG_EXTENSIONS``.
        transform (``Transform``, optional): Applied to the loaded image to build
            the input, e.g. ``transforms.RandomCrop``. If None, the raw image is used.
        target_transform (``Transform``, optional): Applied to the loaded image to
            build the target. If None, the raw image is used.
    """

    def __init__(self, root, loader=None, extensions=None, transform=None, target_transform=None):
        from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS

        self.root = root
        self.loader = default_loader if loader is None else loader
        self.extensions = IMG_EXTENSIONS if extensions is None else extensions
        self.samples = make_dataset(root, self.extensions)
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """
        Args:
            index (int): Index of image

        Returns:
            tuple: ``(input, target)``, both built from the image at ``index``.
        """
        image = self.loader(self.samples[index])
        source = image if self.transform is None else self.transform(image)
        target = image if self.target_transform is None else self.target_transform(image)
        return source, target

    def __len__(self):
        """Number of image files discovered under ``root``."""
        return len(self.samples)
class CelebA(SimpleImageFolder):
    """
    `CelebA <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ auto-encoding dataset.

    Every image found under ``root`` is loaded with the default image loader and
    returned as an ``(input, target)`` pair.

    Args:
        root (str): Root directory of dataset containing all aligned images in 'root'
        transform (``Transform``, optional): Applied to the loaded PIL image to build
            the input, e.g. ``transforms.RandomCrop``.
        target_transform (``Transform``, optional): Applied to the loaded PIL image
            to build the target, e.g. ``transforms.RandomCrop``.
    """

    def __init__(self, root, transform=None, target_transform=None):
        super().__init__(root, transform=transform, target_transform=target_transform)

    def __getitem__(self, index):
        # Pure delegation: items are produced exactly as SimpleImageFolder builds them.
        return super().__getitem__(index)
class CelebA_HQ(SimpleImageFolder):
    """
    CelebA_HQ, high quality version of `celebA <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_
    auto-encoding dataset as introduced by `Progressive GAN <https://arxiv.org/abs/1710.10196>`_

    Args:
        root (str): Root directory of dataset containing all hq images in 'root'
        as_npy (bool, optional): If True, only ``.npy`` files are collected and loaded
            with :meth:`npy_loader`. Otherwise standard image files are loaded with
            the default image loader.
        transform (``Transform``, optional): Applied to the loaded PIL image to build
            the input, e.g. ``transforms.RandomCrop``.
        target_transform (``Transform``, optional): Applied to the loaded PIL image to
            build the target. If None, the raw image is used.
    """

    def __init__(self, root, as_npy=False, transform=None, target_transform=None):
        if as_npy:
            # Match the full ".npy" suffix. A bare "npy" would also accept names
            # such as "foonpy", because extensions are matched with endswith.
            loader, extensions = self.npy_loader, ['.npy']
        else:
            # None makes SimpleImageFolder fall back to torchvision's
            # default_loader and IMG_EXTENSIONS.
            loader, extensions = None, None
        super(CelebA_HQ, self).__init__(root, loader, extensions, transform, target_transform)

    @staticmethod
    def npy_loader(path):
        """
        Load an image stored as a numpy array file and return it as a PIL image.

        The first entry of the loaded array is taken, and its axes are reordered
        with ``transpose([1, 2, 0])`` so that the leading axis moves last.
        Presumably each file holds one image laid out as (1, C, H, W) with a
        PIL-compatible dtype such as uint8. TODO: confirm against the data files.

        Args:
            path (str): Path of the ``.npy`` file.

        Returns:
            PIL.Image.Image: The decoded image.
        """
        img = np.load(path)[0].transpose([1, 2, 0])
        return Image.fromarray(img)

    def __getitem__(self, index):
        # Pure delegation: items are produced exactly as SimpleImageFolder builds them.
        return super(CelebA_HQ, self).__getitem__(index)
class dSprites(Dataset):
    """
    `dSprites <https://github.com/deepmind/dsprites-dataset>`_ Dataset

    Each item is an ``(image, image)`` pair: the same (optionally transformed)
    grayscale PIL image is used as both input and target.

    Args:
        root (str): Directory holding the extracted dataset arrays (``imgs.npy``,
            ``latents_values.npy``, ``latents_classes.npy``), or the directory that
            'dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz' is downloaded to and
            extracted into.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If ``imgs.npy`` is already present in root,
            nothing is downloaded.
        transform (``Transform``, optional): A function/transform that takes in a PIL
            image and returns a transformed version. E.g, ``transforms.RandomCrop``
    """

    def __init__(self, root, download=False, transform=None):
        super(dSprites, self).__init__()
        # Directory that holds (or will hold) the extracted .npy arrays.
        self.file = root
        self.transform = transform
        if download:
            self.download()
        self.data = self.load_data()
        # Number of values per generative factor, in the order of the
        # "co1sh3sc6or40x32y32" counts in the archive name.
        self.latents_sizes = np.array([1, 3, 6, 40, 32, 32])
        # Mixed-radix place values: factor i's base is the product of the sizes of
        # all later factors, and the last factor's base is 1.
        self.latents_bases = np.concatenate((self.latents_sizes[::-1].cumprod()[::-1][1:], np.array([1, ])))
        self.latents_values = np.load(os.path.join(self.file, "latents_values.npy"))
        self.latents_classes = np.load(os.path.join(self.file, "latents_classes.npy"))

    def download(self):
        """
        Download the dSprites ``.npz`` archive into ``self.file`` and extract it.

        Does nothing if ``imgs.npy`` already exists in ``self.file``. The archive is
        opened with :class:`zipfile.ZipFile` and extracted in place, which yields
        the ``.npy`` arrays that :meth:`load_data` and ``__init__`` read. The
        downloaded archive itself is left in ``self.file``.
        """
        if os.path.exists(os.path.join(self.file, "imgs.npy")):
            return
        import urllib.request

        data_url = 'https://github.com/deepmind/dsprites-dataset/blob/master/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz?raw=true'
        archive_path = os.path.join(self.file, "dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz")
        os.makedirs(self.file, exist_ok=True)
        with urllib.request.urlopen(data_url) as response, open(archive_path, 'wb') as out_file:
            shutil.copyfileobj(response, out_file)
        # The context manager closes the archive even if extraction fails.
        with zipfile.ZipFile(archive_path, 'r') as archive:
            archive.extractall(self.file)

    def latent_to_index(self, latent_code):
        """
        Convert per-factor latent class indices into the flat dataset index.

        Args:
            latent_code (:obj:`list` of :obj:`int`): One class index per generative
                factor (length 6).

        Returns:
            int: ``sum(latent_code[i] * latents_bases[i])``, as a numpy integer.
        """
        return np.dot(latent_code, self.latents_bases).astype(int)

    def get_img_by_latent(self, latent_code):
        """
        Returns the image defined by the latent code

        Args:
            latent_code (:obj:`list` of :obj:`int`): Latent code of length 6 defining each generative factor

        Returns:
            Image defined by given code (after ``transform``, if one is set)
        """
        return self.__getitem__(self.latent_to_index(latent_code))[0]

    def load_data(self):
        """Load and return the image array stored in ``imgs.npy`` under ``self.file``."""
        root = os.path.join(self.file, "imgs.npy")
        data = np.load(root)
        return data

    def __getitem__(self, index):
        # Stored pixels are presumably 0/1. Scale them to 0/255 and force uint8 so
        # Pillow infers an 8-bit grayscale ("L") image whatever the stored dtype.
        data = Image.fromarray((self.data[index] * 255).astype(np.uint8))
        if self.transform is not None:
            data = self.transform(data)
        return data, data

    def __len__(self):
        return self.data.shape[0]