variational¶
Distributions¶
The distributions module extends the torch.distributions package with the SimpleDistribution class, which is intended
to make it easier to implement the distributions required by specific variational approaches. In general, prefer a
torch.distributions.Distribution object over a SimpleDistribution, since it offers better
argument validation and a more complete implementation. However, if you need to implement something new for a specific
variational approach, a SimpleDistribution may be more forgiving. You may also find the implementations here easier
to read and understand.
-
class
variational.distributions.SimpleDistribution(batch_shape=<sphinx.ext.autodoc.importer._MockObject object>, event_shape=<sphinx.ext.autodoc.importer._MockObject object>)[source]¶ Abstract base class for a simple distribution which only implements rsample and log_prob. If the log_prob function is not differentiable with respect to the distribution parameters or the given value, then this should be mentioned in the documentation.
-
arg_constraints¶
-
has_rsample= True¶
-
log_prob(value)[source]¶ Returns the log of the probability density/mass function evaluated at value.
Parameters: value (torch.Tensor, Number) – Value at which to evaluate the log probability
-
mean¶
-
rsample(sample_shape=<sphinx.ext.autodoc.importer._MockObject object>)[source]¶ Returns a reparameterized sample or batch of reparameterized samples if the distribution parameters are batched.
-
support¶
-
variance¶
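As a rough illustration of the interface described above, the sketch below defines a base class with only rsample and log_prob and one concrete subclass. It works on plain Python floats rather than tensors, and the names SketchDistribution and SketchLaplace are hypothetical; they are not part of this package.

```python
import math
import random
from abc import ABC, abstractmethod


class SketchDistribution(ABC):
    """Hypothetical stand-in for a SimpleDistribution-style base class."""

    has_rsample = True

    @abstractmethod
    def rsample(self):
        """Return one reparameterized sample."""

    @abstractmethod
    def log_prob(self, value):
        """Return the log density evaluated at value."""


class SketchLaplace(SketchDistribution):
    """Laplace distribution on plain floats (illustration only)."""

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale

    def rsample(self):
        # Inverse-CDF sampling: the noise u does not depend on loc or scale,
        # so the sample is a deterministic function of the parameters and u.
        u = random.random() - 0.5
        a = min(abs(u), 0.5 - 1e-12)  # keep the log argument strictly positive
        return self.loc - self.scale * math.copysign(1.0, u) * math.log(1.0 - 2.0 * a)

    def log_prob(self, value):
        # log p(x) = -log(2 * scale) - |x - loc| / scale
        return -math.log(2.0 * self.scale) - abs(value - self.loc) / self.scale
```

A tensor-based version would additionally need to handle sample_shape and batched parameters, which this sketch leaves out.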
-
-
class
variational.distributions.SimpleExponential(lograte)[source]¶ The SimpleExponential class is a
SimpleDistribution which implements a straightforward Exponential distribution with the given lograte. This performs significantly fewer checks than torch.distributions.Exponential, but should be sufficient for the purpose of implementing a VAE. By using a lograte, the log_prob can be computed in a stable fashion, without taking a logarithm.
Parameters: lograte (torch.Tensor, Number) – The natural log of the rate of the distribution, numbers will be cast to tensors
-
log_prob(value)[source]¶ Calculates the log probability that the given value was drawn from this distribution. The log_prob for this distribution is fully differentiable and has stable gradient since we use the lograte here.
Parameters: value (torch.Tensor, Number) – The sampled value Returns: The log probability that the given value was drawn from this distribution
-
rsample(sample_shape=<sphinx.ext.autodoc.importer._MockObject object>)[source]¶ Simple rsample for an Exponential distribution.
Parameters: sample_shape (torch.Size, tuple) – Shape of the sample (per lograte given) Returns: A reparameterized sample with gradient with respect to the distribution parameters
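The following is a minimal sketch of why a lograte avoids a logarithm in log_prob, using plain floats instead of tensors. The function names are hypothetical and the package's actual implementation may differ; support checks for negative values are deliberately omitted.

```python
import math
import random


def exponential_log_prob(value, lograte):
    # log p(x) = log(rate) - rate * x, and log(rate) is simply lograte,
    # so no logarithm is taken. Assumes value >= 0.
    return lograte - math.exp(lograte) * value


def exponential_rsample(lograte, rng=random):
    # Inverse-CDF reparameterization: x = -log(1 - u) / rate, u ~ U[0, 1)
    u = rng.random()
    return -math.log(1.0 - u) / math.exp(lograte)
```

Because the uniform noise u is drawn independently of lograte, the sample is a differentiable function of the parameter when the same expression is written with tensors.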
-
-
class
variational.distributions.SimpleNormal(mu, logvar)[source]¶ The SimpleNormal class is a
SimpleDistribution which implements a straightforward Normal / Gaussian distribution. This performs significantly fewer checks than torch.distributions.Normal, but should be sufficient for the purpose of implementing a VAE.
Parameters: - mu (torch.Tensor, Number) – The mean of the distribution, numbers will be cast to tensors
- logvar (torch.Tensor, Number) – The log variance of the distribution, numbers will be cast to tensors
-
log_prob(value)[source]¶ Calculates the log probability that the given value was drawn from this distribution. Since the density of a Gaussian is differentiable, this function is differentiable.
Parameters: value (torch.Tensor, Number) – The sampled value Returns: The log probability that the given value was drawn from this distribution
-
rsample(sample_shape=<sphinx.ext.autodoc.importer._MockObject object>)[source]¶ Simple rsample for a Normal distribution.
Parameters: sample_shape (torch.Size, tuple) – Shape of the sample (per mean / variance given) Returns: A reparameterized sample with gradient with respect to the distribution parameters
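As a sketch of the mu / logvar parameterisation, the functions below evaluate the Gaussian log density and draw a reparameterized sample on plain floats. The function names are hypothetical, and the tensor implementation in this package may be organised differently.

```python
import math
import random


def normal_log_prob(value, mu, logvar):
    # log N(x; mu, var) = -0.5 * (log(2*pi) + log(var) + (x - mu)^2 / var)
    return -0.5 * (math.log(2.0 * math.pi) + logvar
                   + (value - mu) ** 2 * math.exp(-logvar))


def normal_rsample(mu, logvar, rng=random):
    # Reparameterization trick: x = mu + std * eps with eps ~ N(0, 1),
    # where std = exp(0.5 * logvar)
    eps = rng.gauss(0.0, 1.0)
    return mu + math.exp(0.5 * logvar) * eps
```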
-
class
variational.distributions.SimpleUniform(low, high)[source]¶ The SimpleUniform class is a
SimpleDistribution which implements a straightforward Uniform distribution in the interval [low, high). This performs significantly fewer checks than torch.distributions.Uniform, but should be sufficient for the purpose of implementing a VAE.
Parameters: - low (torch.Tensor, Number) – The lower range of the distribution (inclusive), numbers will be cast to tensors
- high (torch.Tensor, Number) – The upper range of the distribution (exclusive), numbers will be cast to tensors
-
log_prob(value)[source]¶ Calculates the log probability that the given value was drawn from this distribution. Since this distribution is uniform, the log probability is
-log(high - low) for all values in the range [low, high) and -inf elsewhere. This function is therefore only piecewise differentiable.
Parameters: value (torch.Tensor, Number) – The sampled value Returns: The log probability that the given value was drawn from this distribution
-
rsample(sample_shape=<sphinx.ext.autodoc.importer._MockObject object>)[source]¶ Simple rsample for a Uniform distribution.
Parameters: sample_shape (torch.Size, tuple) – Shape of the sample (per low / high given) Returns: A reparameterized sample with gradient with respect to the distribution parameters
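The piecewise log probability and the reparameterized sample described above can be sketched on plain floats as follows. The function names are hypothetical and this is not the package's own code.

```python
import math
import random


def uniform_log_prob(value, low, high):
    # Constant -log(high - low) inside [low, high), -inf outside
    if low <= value < high:
        return -math.log(high - low)
    return -math.inf


def uniform_rsample(low, high, rng=random):
    # Reparameterization: x = low + (high - low) * u with u ~ U[0, 1)
    return low + (high - low) * rng.random()
```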
-
class
variational.distributions.SimpleWeibull(l, k)[source]¶ The SimpleWeibull class is a
SimpleDistribution which implements a straightforward Weibull distribution. This performs significantly fewer checks than torch.distributions.Weibull, but should be sufficient for the purpose of implementing a VAE.
@article{squires2019a, title={A Variational Autoencoder for Probabilistic Non-Negative Matrix Factorisation}, author={Steven Squires and Adam Prugel-Bennett and Mahesan Niranjan}, year={2019} }
Parameters: - l (torch.Tensor, Number) – The scale parameter of the distribution, numbers will be cast to tensors
- k (torch.Tensor, Number) – The shape parameter of the distribution, numbers will be cast to tensors
-
log_prob(value)[source]¶ Calculates the log probability that the given value was drawn from this distribution. This function is differentiable and its log probability is -inf for values less than 0.
Parameters: value (torch.Tensor, Number) – The sampled value Returns: The log probability that the given value was drawn from this distribution
-
rsample(sample_shape=<sphinx.ext.autodoc.importer._MockObject object>)[source]¶ Simple rsample for a Weibull distribution.
Parameters: sample_shape (torch.Size, tuple) – Shape of the sample (per k / lambda given) Returns: A reparameterized sample with gradient with respect to the distribution parameters
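A minimal sketch of the Weibull log density and inverse-CDF sample, with scale l and shape k as plain floats, is given below. The function names are hypothetical, and the handling of value == 0 is an assumption of this sketch rather than documented behaviour.

```python
import math
import random


def weibull_log_prob(value, l, k):
    # -inf for values less than 0, as described above
    if value < 0:
        return -math.inf
    if value == 0:
        # Limit of the density at 0 depends on the shape (sketch assumption)
        if k < 1:
            return math.inf
        return -math.log(l) if k == 1 else -math.inf
    z = value / l
    # log p(x) = log(k / l) + (k - 1) * log(x / l) - (x / l)^k
    return math.log(k) - math.log(l) + (k - 1.0) * math.log(z) - z ** k


def weibull_rsample(l, k, rng=random):
    # Inverse-CDF reparameterization: x = l * (-log(1 - u))^(1 / k)
    u = rng.random()
    return l * (-math.log(1.0 - u)) ** (1.0 / k)
```

With k = 1 the density reduces to an Exponential distribution with rate 1 / l, which gives a quick sanity check.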
Divergences¶
The divergence module includes abstractions around and implementations of divergences intended to simplify the construction and usage of multiple divergences and divergences on different parts of latent spaces.
-
class
variational.divergence.DivergenceBase(keys, state_key=None)[source]¶ The
DivergenceBase class is an abstract base class which defines a series of useful methods for dealing with divergences. The keys dict given on init is used to map objects in state to kwargs in the compute function.
Parameters: -
compute(**kwargs)[source]¶ Compute the loss with the given kwargs defined in the constructor.
Parameters: kwargs – The bound kwargs, taken from state with the keys given in the constructor Returns: The calculated divergence as a two dimensional tensor (batch, distribution dimensions)
-
with_beta(beta)[source]¶ Multiply the divergence by the given beta, as introduced by beta-vae.
@article{higgins2016beta, title={beta-vae: Learning basic visual concepts with a constrained variational framework}, author={Higgins, Irina and Matthey, Loic and Pal, Arka and Burgess, Christopher and Glorot, Xavier and Botvinick, Matthew and Mohamed, Shakir and Lerchner, Alexander}, year={2016} }
Parameters: beta (float) – The beta (> 1) to multiply by. Returns: self Return type: Divergence
-
with_linear_capacity(min_c=0, max_c=25, steps=100000, gamma=1000)[source]¶ Limit divergence by capacity, linearly increased from min_c to max_c for steps, as introduced in Understanding disentangling in beta-VAE.
@article{burgess2018understanding, title={Understanding disentangling in beta-vae}, author={Burgess, Christopher P and Higgins, Irina and Pal, Arka and Matthey, Loic and Watters, Nick and Desjardins, Guillaume and Lerchner, Alexander}, journal={arXiv preprint arXiv:1804.03599}, year={2018} }
Parameters: - min_c (float) – Minimum capacity
- max_c (float) – Maximum capacity
- steps (int) – Number of steps to increase over
- gamma (float) – Multiplicative gamma, usually a high number
Returns: self
Return type: Divergence
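The capacity schedule can be sketched as below, following the objective in Burgess et al., where the divergence term becomes gamma * |KL - C| and C grows linearly from min_c to max_c. The function names are hypothetical, and the exact form used inside with_linear_capacity is an assumption here.

```python
def capacity_at(step, min_c=0.0, max_c=25.0, steps=100000):
    # Fraction of the schedule completed, clamped to [0, 1]
    fraction = min(max(step / steps, 0.0), 1.0)
    return min_c + (max_c - min_c) * fraction


def capacity_loss(kl, step, min_c=0.0, max_c=25.0, steps=100000, gamma=1000.0):
    # Penalise the distance between the divergence and the current capacity
    return gamma * abs(kl - capacity_at(step, min_c, max_c, steps))
```

Early in training the capacity is close to min_c, so the divergence is pushed towards a small value; as the capacity rises the latent code is allowed to carry more information.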
-
with_post_function(post_fcn)[source]¶ Register the given post function, to be applied to the loss after reduction.
Parameters: post_fcn – A function of loss which applies some operation (e.g. multiplying by beta) Returns: self Return type: Divergence
-
with_reduction(reduction_fcn)[source]¶ Override the reduction operation with the given function. Use this if your divergence doesn’t output a two dimensional tensor.
Parameters: reduction_fcn – The function applied to the divergence output, which should return a single value Returns: self Return type: Divergence
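To make the order of operations concrete, here is a sketch of a reduction followed by post functions, using nested lists in place of a (batch, distribution dimensions) tensor. All names are hypothetical, and the choice of summing over dimensions and averaging over the batch is an assumption, not a statement about the default reduction in this package.

```python
def reduce_sum_then_mean(divergence):
    # divergence: one row per batch item, one entry per distribution dimension
    per_item = [sum(row) for row in divergence]
    return sum(per_item) / len(per_item)


def make_beta_post_function(beta):
    # A post function takes the reduced loss and returns a modified loss
    return lambda loss: beta * loss


def apply_pipeline(divergence, reduction_fcn, post_fcns):
    # Reduce first, then apply each registered post function in order
    loss = reduction_fcn(divergence)
    for fcn in post_fcns:
        loss = fcn(loss)
    return loss
```

For example, `apply_pipeline([[1.0, 2.0], [3.0, 4.0]], reduce_sum_then_mean, [make_beta_post_function(4.0)])` reduces the divergence to 5.0 and then scales it by beta.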
-
-
class
variational.divergence.SimpleExponentialSimpleExponentialKL(input_key, target_key, state_key=None)[source]¶ A KL divergence between two SimpleExponential (or similar) distributions.
Note
The distribution objects must have a lograte attribute
Parameters: - input_key – StateKey instance which will be mapped to the input distribution object.
- target_key – StateKey instance which will be mapped to the target distribution object.
- state_key – If not None, the value output by compute() is stored in state with the given key.
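For reference, the closed-form KL divergence between two Exponential distributions can be written directly in terms of log rates, as in the sketch below. The function name is hypothetical, and treating the input distribution as the first argument of the KL is an assumption.

```python
import math


def exponential_kl(lograte_p, lograte_q):
    # KL(p || q) = log(rate_p / rate_q) + rate_q / rate_p - 1
    diff = lograte_p - lograte_q
    return diff + math.exp(-diff) - 1.0
```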
-
class
variational.divergence.SimpleNormalSimpleNormalKL(input_key, target_key, state_key=None)[source]¶ A KL divergence between two SimpleNormal (or similar) distributions.
Note
The distribution objects must have mu and logvar attributes
Parameters: - input_key – StateKey instance which will be mapped to the input distribution object.
- target_key – StateKey instance which will be mapped to the target distribution object.
- state_key – If not None, the value output by compute() is stored in state with the given key.
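The standard closed-form KL divergence between two univariate Gaussians, expressed with mu and logvar, is sketched below on plain floats. The function name is hypothetical, and the direction KL(input || target) is an assumption.

```python
import math


def normal_kl(mu_p, logvar_p, mu_q, logvar_q):
    # KL(p || q) = 0.5 * (log(var_q / var_p)
    #                     + (var_p + (mu_p - mu_q)^2) / var_q - 1)
    return 0.5 * (logvar_q - logvar_p
                  + (math.exp(logvar_p) + (mu_p - mu_q) ** 2) * math.exp(-logvar_q)
                  - 1.0)
```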
-
class
variational.divergence.SimpleNormalUnitNormalKL(input_key, state_key=None)[source]¶ A KL divergence between a SimpleNormal (or similar) instance and a fixed unit normal (N[0, 1]) target.
Note
The distribution object must have mu and logvar attributes
Parameters: - input_key – StateKey instance which will be mapped to the distribution object.
- state_key – If not None, the value output by compute() is stored in state with the given key.
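With a unit normal target the Gaussian KL simplifies to the familiar VAE term. The sketch below shows it per dimension on plain floats; the function name is hypothetical.

```python
import math


def unit_normal_kl(mu, logvar):
    # KL(N(mu, var) || N(0, 1)) = -0.5 * (1 + log(var) - mu^2 - var)
    return -0.5 * (1.0 + logvar - mu ** 2 - math.exp(logvar))
```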
-
class
variational.divergence.SimpleWeibullSimpleWeibullKL(input_key, target_key, state_key=None)[source]¶ A KL divergence between two SimpleWeibull (or similar) distributions. The distribution object must have lambda (scale) and k (shape) attributes.
@article{DBLP:journals/corr/Bauckhage14, author={Christian Bauckhage}, title={Computing the Kullback-Leibler Divergence between two Generalized Gamma Distributions}, journal={CoRR}, volume={abs/1401.6853}, year={2014} }
Parameters: - input_key – StateKey instance which will be mapped to the input distribution object.
- target_key – StateKey instance which will be mapped to the target distribution object.
- state_key – If not None, the value output by compute() is stored in state with the given key.
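The closed form from Bauckhage (2014) for the KL divergence between two Weibull distributions can be sketched as follows, with scale l and shape k as plain floats. The function name is hypothetical, and the direction KL(input || target) is an assumption.

```python
import math

# Euler-Mascheroni constant
EULER_GAMMA = 0.5772156649015329


def weibull_kl(l1, k1, l2, k2):
    # KL = log(k1 / l1^k1) - log(k2 / l2^k2)
    #      + (k1 - k2) * (log(l1) - gamma / k1)
    #      + (l1 / l2)^k2 * Gamma(k2 / k1 + 1) - 1
    return (math.log(k1) - k1 * math.log(l1)
            - math.log(k2) + k2 * math.log(l2)
            + (k1 - k2) * (math.log(l1) - EULER_GAMMA / k1)
            + (l1 / l2) ** k2 * math.gamma(k2 / k1 + 1.0)
            - 1.0)
```

Setting both shapes to 1 recovers the Exponential KL with rates 1 / l1 and 1 / l2, which is a useful consistency check.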
Auto-Encoding¶
The auto encoder module includes an abstraction around the standard VAE architecture which allows for simple construction without loss of flexibility.
-
class
variational.auto_encoder.AutoEncoderBase(latent_dims)[source]¶ -
decode(sample, state=None)[source]¶ Decode the given latent space sample batch to images.
Parameters: - sample – The latent space samples
- state – The trial state
Returns: Decoded images
-
Datasets¶
The datasets module includes a number of standard VAE datasets in the style of the torchvision.datasets module.
-
class
variational.datasets.CelebA(root, transform=None, target_transform=None)[source]¶ CelebA auto-encoding dataset
Parameters: - root (str) – Root directory of dataset containing all aligned images in ‘root’
- transform (Transform, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
- target_transform (Transform, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
-
class
variational.datasets.CelebA_HQ(root, as_npy=False, transform=None)[source]¶ CelebA_HQ, a high quality version of the CelebA auto-encoding dataset, as introduced by Progressive GAN
Parameters: - root (str) – Root directory of dataset containing all hq images in ‘root’
- as_npy (bool, optional) – If True, assume images are stored in numpy arrays. Else assume a standard image format
- transform (Transform, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
-
class
variational.datasets.SimpleImageFolder(root, loader=None, extensions=None, transform=None, target_transform=None)[source]¶ Simple image folder dataset that loads all images from inside a folder and returns each item as an (image, image) tuple
Parameters: - root (str) – Root directory of dataset containing all aligned images
- loader (function, optional) – Image loader function that takes a file or path and returns the loaded image (see torchvision.datasets.folder)
- extensions (list of str, optional) – List of file extensions that can be loaded
- transform (Transform, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
- target_transform (Transform, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
-
class
variational.datasets.dSprites(root, download=False, transform=None)[source]¶ dSprites Dataset
Parameters: - root (str) – Root directory of dataset containing ‘dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz’ or to download it to
- download (bool, optional) – If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again.
- transform (Transform, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
Visualisation¶
The visualisation module contains a number of latent space visualisation techniques designed to work with the base variational.auto_encoder, alongside an abstraction that allows for creating custom visualisations.
-
class
variational.visualisation.CodePathWalker(num_steps, p1, p2)[source]¶ Latent space walker that walks between two specified codes p1 and p2
Parameters: - num_steps (int) – Number of steps to take between points
- p1 (torch.Tensor) – Batch of codes
- p2 (torch.Tensor) – Batch of codes
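One plausible way to walk between two codes is straight-line interpolation, sketched below with lists of floats in place of tensors. The function name is hypothetical, and whether CodePathWalker interpolates linearly or includes both endpoints is an assumption of this sketch.

```python
def interpolate_codes(p1, p2, num_steps):
    # Return num_steps codes on the straight line from p1 to p2, inclusive
    if num_steps < 2:
        raise ValueError("num_steps must be at least 2")
    path = []
    for i in range(num_steps):
        t = i / (num_steps - 1)
        path.append([(1.0 - t) * a + t * b for a, b in zip(p1, p2)])
    return path
```

Each interpolated code would then be passed through the decoder to produce one image of the walk.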
-
class
variational.visualisation.ImagePathWalker(num_steps, im1, im2)[source]¶ Latent space walker that walks between two specified images im1 and im2
Parameters: - num_steps (int) – Number of steps to take between points
- im1 (torch.Tensor) – Batch of images
- im2 (torch.Tensor) – Batch of images
-
class
variational.visualisation.LatentWalker(same_image, row_size)[source]¶ Parameters: - same_image (bool) – If True, use the same image for all latent dimension walks. Else each dimension has different image
- row_size (int) – Number of images displayed in each row of the grid.
-
for_data(data_key)[source]¶ Parameters: data_key (StateKey) – State key which will contain data to act on Returns: self Return type: LatentWalker
-
for_space(space_id)[source]¶ Sets the ID for which latent space to vary when model outputs [latent_space_0, latent_space_1, …]
Parameters: space_id (int) – ID of the latent space to vary Returns: self Return type: LatentWalker
-
on_train()[source]¶ Sets the walker to run during training
Returns: self Return type: LatentWalker
-
on_val()[source]¶ Sets the walker to run during validation
Returns: self Return type: LatentWalker
-
to_file(file)[source]¶ Parameters: file (string, pathlib.Path object or file object) – File in which result is saved Returns: self Return type: LatentWalker
-
to_key(state_key)[source]¶ Parameters: state_key (StateKey) – State key under which to store result Returns: self Return type: LatentWalker
-
class
variational.visualisation.LinSpaceWalker(lin_start=-1, lin_end=1, lin_steps=8, dims_to_walk=[0], zero_init=False, same_image=False)[source]¶ Latent space walker that explores each dimension linearly from start to end points
Parameters: - lin_start (float) – Starting point of linspace
- lin_end (float) – End point of linspace
- lin_steps (int) – Number of steps to take in linspace
- dims_to_walk (list of int) – List of dimensions to walk
- zero_init (bool) – If True, dimensions not being walked are 0. Else, they are obtained from encoder
- same_image (bool) – If True, use same image for each dimension walked. Else, use different images
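The parameters above suggest a walk of the following shape, sketched with lists of floats in place of tensors. The helper names are hypothetical, and the way the real walker combines encoder output with the walked values is an assumption.

```python
def linspace(start, end, steps):
    # Evenly spaced values from start to end, inclusive
    if steps == 1:
        return [float(start)]
    return [start + (end - start) * i / (steps - 1) for i in range(steps)]


def walk_dimension(code, dim, lin_start=-1.0, lin_end=1.0, lin_steps=8, zero_init=False):
    # With zero_init, dimensions not being walked are 0;
    # otherwise they keep the values of the given (encoded) code
    base = [0.0] * len(code) if zero_init else list(code)
    walked = []
    for value in linspace(lin_start, lin_end, lin_steps):
        point = list(base)
        point[dim] = value
        walked.append(point)
    return walked
```

Calling walk_dimension once per entry of dims_to_walk yields one row of codes per walked dimension.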
-
class
variational.visualisation.RandomWalker(var=1, num_images=32, uniform=False, row_size=8)[source]¶ Latent space walker that shows random samples from latent space
Parameters: - var (float or torch.Tensor) – Variance of random sample
- num_images (int) – Number of random images to sample
- uniform (bool) – If True, sample uniform distribution [-v, v). If False, sample normal distribution with var v
- row_size (int) – Number of images displayed in each row of the grid.
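The two sampling modes can be sketched as below, with lists of floats in place of tensors. The function name and the latent_dims argument are hypothetical, and reading [-v, v) as the interval from -var to var is an assumption.

```python
import math
import random


def sample_codes(num_images, latent_dims, var=1.0, uniform=False, rng=random):
    codes = []
    for _ in range(num_images):
        if uniform:
            # Uniform on [-var, var)
            code = [-var + 2.0 * var * rng.random() for _ in range(latent_dims)]
        else:
            # Normal with variance var, so the standard deviation is sqrt(var)
            code = [rng.gauss(0.0, math.sqrt(var)) for _ in range(latent_dims)]
        codes.append(code)
    return codes
```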
-
class
variational.visualisation.ReconstructionViewer(row_size=8, recon_key=<sphinx.ext.autodoc.importer._MockObject object>)[source]¶ Latent space walker that just returns the reconstructed images for the batch
Parameters: - row_size (int) – Number of images displayed in each row of the grid.
- recon_key (StateKey) – StateKey of the reconstructed images