Source code for variational.datasets

"""
The datasets module includes a number of standard VAE datasets in the style of the torchvision.datasets module.
"""
import os
import shutil
import zipfile

import numpy as np
from PIL import Image
from torch.utils.data import Dataset


def make_dataset(dir, extensions):
    """Collect the paths of all files under ``dir`` whose name has an allowed suffix.

    The directory tree is walked recursively with ``os.walk``. Results are
    grouped by directory in sorted path order, and files inside each
    directory appear in sorted name order, so the output is deterministic.

    Args:
        dir (str): Root directory to search. A missing directory yields ``[]``.
        extensions (str or iterable of str): Allowed filename suffix(es), e.g.
            ``'.png'`` or ``('.jpg', '.png')``. Matching is case-insensitive.

    Returns:
        list of str: ``os.path.join(root, fname)`` for every matching file.
    """
    # Normalise to a lower-cased tuple. ``str.endswith`` only accepts a str or
    # a tuple, and some torchvision releases of ``has_file_allowed_extension``
    # forwarded the argument unchanged, so a list of extensions raised
    # ``TypeError`` there. Doing the check here removes that version coupling.
    if isinstance(extensions, str):
        extensions = (extensions,)
    allowed = tuple(ext.lower() for ext in extensions)

    images = []
    for root, _, fnames in sorted(os.walk(dir)):
        images.extend(
            os.path.join(root, fname)
            for fname in sorted(fnames)
            if fname.lower().endswith(allowed)
        )
    return images
class SimpleImageFolder(Dataset):
    """Auto-encoding dataset over every image file found under a folder.

    Each item is an ``(input, target)`` pair built from the same loaded image,
    which suits reconstruction objectives.

    Args:
        root (str): Directory searched recursively for image files.
        loader (callable, optional): Maps a file path to a loaded image.
            Defaults to ``torchvision.datasets.folder.default_loader``.
        extensions (:obj:`list` of :obj:`str`, optional): Accepted file
            extensions. Defaults to ``torchvision.datasets.folder.IMG_EXTENSIONS``.
        transform (``Transform``, optional): Applied to the loaded image to
            produce the input element, e.g. ``transforms.RandomCrop``.
        target_transform (``Transform``, optional): Applied separately to the
            loaded image to produce the target element.
    """

    def __init__(self, root, loader=None, extensions=None, transform=None,
                 target_transform=None):
        from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS

        if loader is None:
            loader = default_loader
        if extensions is None:
            extensions = IMG_EXTENSIONS

        found = make_dataset(root, extensions)

        self.root = root
        self.loader = loader
        self.extensions = extensions
        self.samples = found
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return the ``(input, target)`` pair for the image at ``index``.

        Args:
            index (int): Position of the image in ``self.samples``.

        Returns:
            tuple: ``(input, target)``. Each element is the loaded image passed
            through ``transform`` / ``target_transform`` respectively, or the
            loaded image itself when that transform is ``None``.
        """
        image = self.loader(self.samples[index])
        source = image if self.transform is None else self.transform(image)
        target = image if self.target_transform is None else self.target_transform(image)
        return source, target

    def __len__(self):
        """Return the number of image files discovered under ``root``."""
        return len(self.samples)
class CelebA(SimpleImageFolder):
    """`CelebA <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ auto-encoding dataset.

    Uses the default image loader and image extensions of
    :class:`SimpleImageFolder`. Item access is inherited unchanged.

    Args:
        root (str): Root directory of the dataset containing all aligned images.
        transform (``Transform``, optional): A function/transform that takes a
            PIL image and returns a transformed version used as the input,
            e.g. ``transforms.RandomCrop``.
        target_transform (``Transform``, optional): A function/transform that
            takes a PIL image and returns a transformed version used as the
            target, e.g. ``transforms.RandomCrop``.
    """

    def __init__(self, root, transform=None, target_transform=None):
        super().__init__(
            root,
            transform=transform,
            target_transform=target_transform,
        )
class CelebA_HQ(SimpleImageFolder):
    """CelebA-HQ auto-encoding dataset.

    CelebA-HQ is the high-quality version of
    `CelebA <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ introduced by
    `Progressive GAN <https://arxiv.org/abs/1710.10196>`_.

    Args:
        root (str): Root directory of the dataset containing all HQ images.
        as_npy (bool, optional): If True, load ``.npy`` files with
            :meth:`npy_loader`. Otherwise load standard image files with
            torchvision's default loader.
        transform (``Transform``, optional): A function/transform that takes a
            PIL image and returns a transformed version used as the input,
            e.g. ``transforms.RandomCrop``.
        target_transform (``Transform``, optional): A function/transform
            applied to the loaded image to produce the target. Defaults to
            ``None``, in which case the target is the untransformed image.
    """

    def __init__(self, root, as_npy=False, transform=None, target_transform=None):
        from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS

        if as_npy:
            loader = self.npy_loader
            # A tuple rather than a list: some torchvision releases hand the
            # extensions straight to ``str.endswith``, which rejects lists.
            # The leading dot stops names that merely end in "npy" matching.
            extensions = ('.npy',)
        else:
            loader = default_loader
            extensions = IMG_EXTENSIONS
        super(CelebA_HQ, self).__init__(
            root,
            loader=loader,
            extensions=extensions,
            transform=transform,
            target_transform=target_transform,
        )

    @staticmethod
    def npy_loader(path):
        """Load ``path`` with ``np.load`` and convert its first entry to a PIL image.

        The element at index 0 of the stored array is taken. Its leading axis
        is then moved to the end (``transpose([1, 2, 0])``) before conversion
        with ``PIL.Image.fromarray``.
        """
        # NOTE(review): presumably the file holds a (1, C, H, W) channels-first
        # array whose dtype/channel count PIL accepts (e.g. uint8, C == 3);
        # confirm against the prepared CelebA-HQ files.
        img = np.load(path)[0].transpose([1, 2, 0])
        return Image.fromarray(img)
class dSprites(Dataset):
    """`dSprites <https://github.com/deepmind/dsprites-dataset>`_ auto-encoding dataset.

    Data is read from three NumPy files in ``root``: ``imgs.npy``,
    ``latents_values.npy`` and ``latents_classes.npy``. :meth:`download`
    fetches ``dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz`` and unpacks it
    into ``root``; that unpacking is expected to produce those files.

    Args:
        root (str): Directory holding the extracted ``.npy`` files, or where
            the archive should be downloaded and unpacked.
        download (bool, optional): If True, download and unpack the archive
            into ``root`` unless ``imgs.npy`` is already present there.
        transform (``Transform``, optional): A function/transform that takes a
            PIL image and returns a transformed version, e.g.
            ``transforms.RandomCrop``. The same result is returned as both
            input and target.
    """

    def __init__(self, root, download=False, transform=None):
        super(dSprites, self).__init__()
        self.file = root
        self.transform = transform
        if download:
            self.download()
        self.data = self.load_data()
        # Number of values of each of the six latent factors; these match the
        # co1/sh3/sc6/or40/x32/y32 counts in the archive's file name.
        self.latents_sizes = np.array([1, 3, 6, 40, 32, 32])
        # Mixed-radix place values, so that
        # flat_index = dot(latent_code, latents_bases); the last factor has weight 1.
        self.latents_bases = np.concatenate(
            (self.latents_sizes[::-1].cumprod()[::-1][1:], np.array([1, ]))
        )
        self.latents_values = np.load(os.path.join(self.file, "latents_values.npy"))
        self.latents_classes = np.load(os.path.join(self.file, "latents_classes.npy"))

    def download(self):
        """Download the dSprites ``.npz`` archive and extract it into ``root``.

        Does nothing when ``imgs.npy`` already exists in ``root``. An ``.npz``
        file is a zip archive of ``.npy`` members, so it is unpacked with
        :mod:`zipfile`.
        """
        if os.path.exists(os.path.join(self.file, "imgs.npy")):
            return

        import urllib.request

        data_url = 'https://github.com/deepmind/dsprites-dataset/blob/master/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz?raw=true'
        archive = os.path.join(self.file, "dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz")
        os.makedirs(self.file, exist_ok=True)
        with urllib.request.urlopen(data_url) as response, open(archive, 'wb+') as out_file:
            shutil.copyfileobj(response, out_file)
        # The context manager closes the archive even if extraction fails.
        with zipfile.ZipFile(archive, 'r') as zip_ref:
            zip_ref.extractall(self.file)

    def _latent_to_index(self, latent_code):
        """Map a latent code (one class index per factor) to a flat dataset index."""
        return np.dot(latent_code, self.latents_bases).astype(int)

    def get_img_by_latent(self, latent_code):
        """Return the image defined by the latent code.

        Args:
            latent_code (:obj:`list` of :obj:`int`): Latent code of length 6
                giving the class index of each generative factor.

        Returns:
            The (possibly transformed) image at the corresponding index.
        """
        return self.__getitem__(self._latent_to_index(latent_code))[0]

    def load_data(self):
        """Load and return the image array stored in ``root/imgs.npy``."""
        return np.load(os.path.join(self.file, "imgs.npy"))

    def __getitem__(self, index):
        """Return ``(image, image)`` for sample ``index``.

        The stored array is multiplied by 255 and wrapped as a PIL ``'L'``
        (8-bit greyscale) image. ``transform``, if given, is applied once and
        the same object is returned as both input and target.
        """
        # NOTE(review): scaling by 255 presumes stored pixels are 0/1 - confirm.
        image = Image.fromarray(self.data[index] * 255, mode='L')
        if self.transform is not None:
            image = self.transform(image)
        return image, image

    def __len__(self):
        """Return the number of images (size of the leading axis of ``data``)."""
        return self.data.shape[0]