variational

Distributions

The distributions module is an extension of the torch.distributions package intended to facilitate implementations required for specific variational approaches through the SimpleDistribution class. Generally, using a torch.distributions.Distribution object should be preferred over a SimpleDistribution, for better argument validation and more complete implementations. However, if you need to implement something new for a specific variational approach, then a SimpleDistribution may be more forgiving. Furthermore, you may find it easier to understand the function of the implementations here.

class variational.distributions.SimpleDistribution(batch_shape=torch.Size(), event_shape=torch.Size())[source]

Abstract base class for a simple distribution which only implements rsample and log_prob. If the log_prob function is not differentiable with respect to the distribution parameters or the given value, then this should be mentioned in the documentation.

arg_constraints
cdf(value)[source]
entropy()[source]
enumerate_support(expand=True)[source]
expand(batch_shape, _instance=None)[source]
has_rsample = True
icdf(value)[source]
log_prob(value)[source]

Returns the log of the probability density/mass function evaluated at value.

Parameters:value (torch.Tensor, Number) – Value at which to evaluate the log probability

mean
rsample(sample_shape=torch.Size())[source]

Returns a reparameterized sample or batch of reparameterized samples if the distribution parameters are batched.

support
variance
class variational.distributions.SimpleExponential(lograte)[source]

The SimpleExponential class is a SimpleDistribution which implements a straightforward Exponential distribution with the given lograte. It performs significantly fewer checks than torch.distributions.Exponential, but should be sufficient for the purpose of implementing a VAE. Because the distribution is parameterised by the lograte, the log_prob can be computed in a stable fashion, without taking a logarithm.

Parameters:lograte (torch.Tensor, Number) – The natural log of the rate of the distribution, numbers will be cast to tensors
log_prob(value)[source]

Calculates the log probability that the given value was drawn from this distribution. The log_prob for this distribution is fully differentiable and has a stable gradient, since the lograte is used directly.

Parameters:value (torch.Tensor, Number) – The sampled value
Returns:The log probability that the given value was drawn from this distribution
rsample(sample_shape=torch.Size())[source]

Simple rsample for an Exponential distribution.

Parameters:sample_shape (torch.Size, tuple) – Shape of the sample (per lograte given)
Returns:A reparameterized sample with gradient with respect to the distribution parameters
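
To see why the lograte parameterisation avoids a logarithm: for rate r the log density is log(r) - r * x, so storing log(r) directly leaves only an exponential to evaluate. The following is a minimal sketch in plain Python (standard library only, not the library's tensor implementation). The function names are hypothetical, and the inverse-CDF sampler is an assumption about how a reparameterized sample could be drawn.

```python
import math
import random


def exponential_log_prob(value, lograte):
    """log p(x) = lograte - exp(lograte) * x, valid for x >= 0 (no log() call)."""
    return lograte - math.exp(lograte) * value


def exponential_rsample(lograte, rng=random):
    """Inverse-CDF sample: x = -log(1 - u) / rate, with u drawn from U[0, 1).

    The noise u does not depend on lograte, which is what makes the sample
    reparameterized in a framework with automatic differentiation.
    """
    u = rng.random()
    return -math.log1p(-u) / math.exp(lograte)


lograte = math.log(2.0)  # rate = 2
print(exponential_log_prob(0.5, lograte))
print(exponential_rsample(lograte, random.Random(0)))
```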
class variational.distributions.SimpleNormal(mu, logvar)[source]

The SimpleNormal class is a SimpleDistribution which implements a straightforward Normal / Gaussian distribution. It performs significantly fewer checks than torch.distributions.Normal, but should be sufficient for the purpose of implementing a VAE.

Parameters:
  • mu (torch.Tensor, Number) – The mean of the distribution, numbers will be cast to tensors
  • logvar (torch.Tensor, Number) – The log variance of the distribution, numbers will be cast to tensors
log_prob(value)[source]

Calculates the log probability that the given value was drawn from this distribution. Since the density of a Gaussian is differentiable, this function is differentiable.

Parameters:value (torch.Tensor, Number) – The sampled value
Returns:The log probability that the given value was drawn from this distribution
rsample(sample_shape=torch.Size())[source]

Simple rsample for a Normal distribution.

Parameters:sample_shape (torch.Size, tuple) – Shape of the sample (per mean / variance given)
Returns:A reparameterized sample with gradient with respect to the distribution parameters
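
As an illustration of the mu / logvar parameterisation, the sketch below evaluates the Gaussian log density and draws a reparameterized sample as mu + std * eps. It is plain Python using only the standard library, with hypothetical function names, and is not the library's tensor implementation.

```python
import math
import random


def normal_log_prob(value, mu, logvar):
    """log N(value; mu, exp(logvar)) written in terms of the log variance."""
    return -0.5 * (math.log(2.0 * math.pi) + logvar
                   + (value - mu) ** 2 / math.exp(logvar))


def normal_rsample(mu, logvar, rng=random):
    """Reparameterization trick: x = mu + exp(0.5 * logvar) * eps, eps ~ N(0, 1)."""
    eps = rng.gauss(0.0, 1.0)
    return mu + math.exp(0.5 * logvar) * eps


print(normal_log_prob(0.0, 0.0, 0.0))
print(normal_rsample(1.0, math.log(4.0), random.Random(0)))
```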
class variational.distributions.SimpleUniform(low, high)[source]

The SimpleUniform class is a SimpleDistribution which implements a straightforward Uniform distribution in the interval [low, high). It performs significantly fewer checks than torch.distributions.Uniform, but should be sufficient for the purpose of implementing a VAE.

Parameters:
  • low (torch.Tensor, Number) – The lower range of the distribution (inclusive), numbers will be cast to tensors
  • high (torch.Tensor, Number) – The upper range of the distribution (exclusive), numbers will be cast to tensors
log_prob(value)[source]

Calculates the log probability that the given value was drawn from this distribution. Since this distribution is uniform, the log probability is -log(high - low) for all values in the range [low, high) and -inf elsewhere. This function is therefore only piecewise differentiable.

Parameters:value (torch.Tensor, Number) – The sampled value
Returns:The log probability that the given value was drawn from this distribution
rsample(sample_shape=torch.Size())[source]

Simple rsample for a Uniform distribution.

Parameters:sample_shape (torch.Size, tuple) – Shape of the sample (per low / high given)
Returns:A reparameterized sample with gradient with respect to the distribution parameters
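
The piecewise behaviour described for log_prob can be sketched in plain Python as follows. The function names are hypothetical and the code uses only the standard library rather than tensors; the sampler scales and shifts a U[0, 1) draw, which is one common way to obtain a reparameterized uniform sample.

```python
import math
import random


def uniform_log_prob(value, low, high):
    """-log(high - low) inside [low, high), -inf everywhere else."""
    if low <= value < high:
        return -math.log(high - low)
    return -math.inf


def uniform_rsample(low, high, rng=random):
    """x = low + u * (high - low), with u drawn from U[0, 1)."""
    return low + rng.random() * (high - low)


print(uniform_log_prob(0.5, 0.0, 2.0))
print(uniform_log_prob(2.0, 0.0, 2.0))  # upper bound is exclusive
print(uniform_rsample(0.0, 2.0, random.Random(0)))
```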
class variational.distributions.SimpleWeibull(l, k)[source]

The SimpleWeibull class is a SimpleDistribution which implements a straightforward Weibull distribution. It performs significantly fewer checks than torch.distributions.Weibull, but should be sufficient for the purpose of implementing a VAE.

@article{squires2019a,
  title={A Variational Autoencoder for Probabilistic Non-Negative Matrix Factorisation},
  author={Steven Squires and Adam Prugel-Bennett and Mahesan Niranjan},
  year={2019}
}
Parameters:
  • l (torch.Tensor, Number) – The scale parameter of the distribution, numbers will be cast to tensors
  • k (torch.Tensor, Number) – The shape parameter of the distribution, numbers will be cast to tensors
log_prob(value)[source]

Calculates the log probability that the given value was drawn from this distribution. This function is differentiable and its log probability is -inf for values less than 0.

Parameters:value (torch.Tensor, Number) – The sampled value
Returns:The log probability that the given value was drawn from this distribution
rsample(sample_shape=torch.Size())[source]

Simple rsample for a Weibull distribution.

Parameters:sample_shape (torch.Size, tuple) – Shape of the sample (per l / k given)
Returns:A reparameterized sample with gradient with respect to the distribution parameters
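
For reference, the Weibull log density with scale l and shape k is log(k / l) + (k - 1) * log(x / l) - (x / l)^k. The sketch below writes this out in plain Python (standard library only, hypothetical function names). Treating the boundary value 0 as outside the support is a simplification made here, and the inverse-CDF sampler is an assumption about how a reparameterized sample could be drawn.

```python
import math
import random


def weibull_log_prob(value, l, k):
    """Weibull log density with scale l and shape k; -inf outside the support."""
    if value <= 0.0:  # simplification: x = 0 is treated like x < 0
        return -math.inf
    z = value / l
    return math.log(k) - math.log(l) + (k - 1.0) * math.log(z) - z ** k


def weibull_rsample(l, k, rng=random):
    """Inverse-CDF sample: x = l * (-log(1 - u)) ** (1 / k), u ~ U[0, 1)."""
    u = rng.random()
    return l * (-math.log1p(-u)) ** (1.0 / k)


# With k = 1 the Weibull reduces to an Exponential with rate 1 / l.
print(weibull_log_prob(1.0, 2.0, 1.0))
print(weibull_rsample(2.0, 1.5, random.Random(0)))
```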

Divergences

The divergence module includes abstractions around and implementations of divergences intended to simplify the construction and usage of multiple divergences and divergences on different parts of latent spaces.

class variational.divergence.DivergenceBase(keys, state_key=None)[source]

The DivergenceBase class is an abstract base class which defines a series of useful methods for dealing with divergences. The keys dict given on init is used to map objects in state to kwargs in the compute function.

Parameters:
  • keys (dict) – Dictionary which maps kwarg names to StateKey objects. When compute() is called, the given kwargs are mapped to their associated values in state.
  • state_key – If not None, the value outputted by compute() is stored in state with the given key.
compute(**kwargs)[source]

Compute the loss with the given kwargs defined in the constructor.

Parameters:kwargs – The bound kwargs, taken from state with the keys given in the constructor
Returns:The calculated divergence as a two dimensional tensor (batch, distribution dimensions)
loss(state)[source]
on_criterion(state)[source]
on_criterion_validation(state)[source]
with_beta(beta)[source]

Multiply the divergence by the given beta, as introduced by beta-VAE.

@article{higgins2016beta,
  title={beta-vae: Learning basic visual concepts with a constrained variational framework},
  author={Higgins, Irina and Matthey, Loic and Pal, Arka and Burgess, Christopher and Glorot, Xavier and Botvinick, Matthew and Mohamed, Shakir and Lerchner, Alexander},
  year={2016}
}
Parameters:beta (float) – The beta (> 1) to multiply by.
Returns:self
Return type:Divergence
with_linear_capacity(min_c=0, max_c=25, steps=100000, gamma=1000)[source]

Limit divergence by capacity, linearly increased from min_c to max_c for steps, as introduced in Understanding disentangling in beta-VAE.

@article{burgess2018understanding,
  title={Understanding disentangling in beta-vae},
  author={Burgess, Christopher P and Higgins, Irina and Pal, Arka and Matthey, Loic and Watters, Nick and Desjardins, Guillaume and Lerchner, Alexander},
  journal={arXiv preprint arXiv:1804.03599},
  year={2018}
}
Parameters:
  • min_c (float) – Minimum capacity
  • max_c (float) – Maximum capacity
  • steps (int) – Number of steps to increase over
  • gamma (float) – Multiplicative gamma, usually a high number
Returns:self
Return type:Divergence
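
The capacity objective from the cited paper replaces the raw divergence with gamma * |KL - C|, where C grows linearly with the training step. A minimal sketch of that schedule in plain Python follows. The function names are hypothetical, and holding the capacity at max_c once steps is reached is an assumption rather than confirmed library behaviour.

```python
def linear_capacity(step, min_c=0.0, max_c=25.0, steps=100000):
    """Capacity C at a given step, rising linearly from min_c to max_c."""
    fraction = min(step / steps, 1.0)  # assumed: hold at max_c afterwards
    return min_c + (max_c - min_c) * fraction


def capacity_limited(divergence, step, min_c=0.0, max_c=25.0,
                     steps=100000, gamma=1000.0):
    """gamma * |KL - C|: penalise the divergence for straying from the capacity."""
    c = linear_capacity(step, min_c, max_c, steps)
    return gamma * abs(divergence - c)


print(linear_capacity(50000))         # halfway through the schedule
print(capacity_limited(13.0, 50000))  # KL slightly above the capacity
```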

with_post_function(post_fcn)[source]

Register the given post function, to be applied to the loss after reduction.

Parameters:post_fcn – A function of loss which applies some operation (e.g. multiplying by beta)
Returns:self
Return type:Divergence
with_reduction(reduction_fcn)[source]

Override the reduction operation with the given function. Use this if your divergence doesn’t output a two dimensional tensor.

Parameters:reduction_fcn – A function which is applied to the divergence output and returns a single value
Returns:self
Return type:Divergence
with_sum_mean_reduction()[source]

Override the reduction function to take a sum over dimension one and a mean over dimension zero. (default)

Returns:self
Return type:Divergence
with_sum_sum_reduction()[source]

Override the reduction function to take a sum over all dimensions.

Returns:self
Return type:Divergence
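
The two built-in reductions, and the way a post function is applied afterwards, can be illustrated on a small (batch, distribution dimensions) array. This sketch uses nested Python lists instead of tensors and hypothetical function names; it only mirrors the behaviour described above.

```python
def sum_mean_reduction(divergence):
    """Sum over dimension one (per sample), then mean over dimension zero."""
    per_sample = [sum(row) for row in divergence]
    return sum(per_sample) / len(per_sample)


def sum_sum_reduction(divergence):
    """Sum over every element."""
    return sum(sum(row) for row in divergence)


def beta_post_function(beta):
    """A post function, applied to the already reduced loss (e.g. beta scaling)."""
    return lambda loss: beta * loss


kl = [[0.5, 1.5], [1.0, 3.0]]  # two samples, two latent dimensions
print(sum_mean_reduction(kl))
print(sum_sum_reduction(kl))
print(beta_post_function(4.0)(sum_mean_reduction(kl)))
```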
class variational.divergence.SimpleExponentialSimpleExponentialKL(input_key, target_key, state_key=None)[source]

A KL divergence between two SimpleExponential (or similar) distributions.

Note

The distribution objects must have a lograte attribute

Parameters:
  • input_key – StateKey instance which will be mapped to the input distribution object.
  • target_key – StateKey instance which will be mapped to the target distribution object.
  • state_key – If not None, the value outputted by compute() is stored in state with the given key.
compute(input, target)[source]

Compute the loss with the given kwargs defined in the constructor.

Parameters:kwargs – The bound kwargs, taken from state with the keys given in the constructor
Returns:The calculated divergence as a two dimensional tensor (batch, distribution dimensions)
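
For two exponential distributions with rates r1 and r2, the closed form is KL = log(r1) - log(r2) + r2 / r1 - 1, which can be written entirely in terms of the lograte values. The sketch below is plain Python with a hypothetical function name. Taking the divergence in the direction KL(input || target) is an assumption, as the text above does not state the direction.

```python
import math


def exponential_kl(input_lograte, target_lograte):
    """KL(Exp(r1) || Exp(r2)) using only the log rates."""
    diff = input_lograte - target_lograte
    return diff + math.exp(-diff) - 1.0


print(exponential_kl(0.0, 0.0))            # identical distributions
print(exponential_kl(0.0, math.log(2.0)))  # r1 = 1, r2 = 2
```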
class variational.divergence.SimpleNormalSimpleNormalKL(input_key, target_key, state_key=None)[source]

A KL divergence between two SimpleNormal (or similar) distributions.

Note

The distribution objects must have mu and logvar attributes

Parameters:
  • input_key – StateKey instance which will be mapped to the input distribution object.
  • target_key – StateKey instance which will be mapped to the target distribution object.
  • state_key – If not None, the value outputted by compute() is stored in state with the given key.
compute(input, target)[source]

Compute the loss with the given kwargs defined in the constructor.

Parameters:kwargs – The bound kwargs, taken from state with the keys given in the constructor
Returns:The calculated divergence as a two dimensional tensor (batch, distribution dimensions)
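
The closed-form KL divergence between two univariate Gaussians, expressed with mu and logvar, is sketched below per dimension. This is plain Python with a hypothetical function name, and the direction KL(input || target) is an assumption.

```python
import math


def normal_kl(mu_1, logvar_1, mu_2, logvar_2):
    """KL(N(mu_1, exp(logvar_1)) || N(mu_2, exp(logvar_2)))."""
    return 0.5 * (logvar_2 - logvar_1
                  + (math.exp(logvar_1) + (mu_1 - mu_2) ** 2) / math.exp(logvar_2)
                  - 1.0)


print(normal_kl(0.0, 0.0, 0.0, 0.0))  # identical distributions
print(normal_kl(1.0, 0.0, 0.0, 0.0))  # unit variance, means one apart
```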
class variational.divergence.SimpleNormalUnitNormalKL(input_key, state_key=None)[source]

A KL divergence between a SimpleNormal (or similar) instance and a fixed unit normal, N(0, 1), target.

Note

The distribution object must have mu and logvar attributes

Parameters:
  • input_key – StateKey instance which will be mapped to the distribution object.
  • state_key – If not None, the value outputted by compute() is stored in state with the given key.
compute(input)[source]

Compute the loss with the given kwargs defined in the constructor.

Parameters:kwargs – The bound kwargs, taken from state with the keys given in the constructor
Returns:The calculated divergence as a two dimensional tensor (batch, distribution dimensions)
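
Against a unit normal target the divergence simplifies to 0.5 * (exp(logvar) + mu^2 - 1 - logvar) per latent dimension, which is the familiar VAE regulariser. The sketch below computes it for a batch held in nested Python lists, returning a (batch, distribution dimensions) result. The function name is hypothetical and no tensors are used.

```python
import math


def unit_normal_kl(mu, logvar):
    """Element-wise KL(N(mu, exp(logvar)) || N(0, 1)) for a 2D batch."""
    return [
        [0.5 * (math.exp(lv) + m ** 2 - 1.0 - lv) for m, lv in zip(mu_row, lv_row)]
        for mu_row, lv_row in zip(mu, logvar)
    ]


mu = [[0.0, 1.0], [2.0, 0.0]]
logvar = [[0.0, 0.0], [0.0, 0.0]]
print(unit_normal_kl(mu, logvar))
```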
class variational.divergence.SimpleWeibullSimpleWeibullKL(input_key, target_key, state_key=None)[source]

A KL divergence between two SimpleWeibull (or similar) distributions. The distribution object must have lambda (scale) and k (shape) attributes.

@article{DBLP:journals/corr/Bauckhage14,
  author={Christian Bauckhage},
  title={Computing the Kullback-Leibler Divergence between two Generalized Gamma Distributions},
  journal={CoRR},
  volume={abs/1401.6853},
  year={2014}
}
Parameters:
  • input_key – StateKey instance which will be mapped to the input distribution object.
  • target_key – StateKey instance which will be mapped to the target distribution object.
  • state_key – If not None, the value outputted by compute() is stored in state with the given key.
compute(input, target)[source]

Compute the loss with the given kwargs defined in the constructor.

Parameters:kwargs – The bound kwargs, taken from state with the keys given in the constructor
Returns:The calculated divergence as a two dimensional tensor (batch, distribution dimensions)
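
The cited paper gives a closed form for the KL divergence between two Weibull distributions with scales l1, l2 and shapes k1, k2. The sketch below writes that expression out in plain Python. The function name is hypothetical, the direction KL(input || target) is an assumption, and this is not claimed to match the library's code line for line.

```python
import math

EULER_GAMMA = 0.5772156649015329  # Euler-Mascheroni constant


def weibull_kl(l_1, k_1, l_2, k_2):
    """KL(Weibull(l_1, k_1) || Weibull(l_2, k_2)), scale l and shape k."""
    return (math.log(k_1) - k_1 * math.log(l_1)
            - math.log(k_2) + k_2 * math.log(l_2)
            + (k_1 - k_2) * (math.log(l_1) - EULER_GAMMA / k_1)
            + (l_1 / l_2) ** k_2 * math.gamma(k_2 / k_1 + 1.0)
            - 1.0)


print(weibull_kl(2.0, 1.5, 2.0, 1.5))  # identical distributions
print(weibull_kl(1.0, 1.0, 2.0, 1.0))  # k = 1: two exponentials
```

With both shapes equal to 1 the expression reduces to the exponential case, log(l2 / l1) + l1 / l2 - 1, which gives a quick sanity check.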

Auto-Encoding

The auto encoder module includes an abstraction around the standard VAE architecture which allows for simple construction without loss of flexibility.

class variational.auto_encoder.AutoEncoderBase(latent_dims)[source]
decode(sample, state=None)[source]

Decode the given latent space sample batch to images.

Parameters:
  • sample – The latent space samples
  • state – The trial state
Returns:Decoded images

encode(x, state=None)[source]

Encode the given batch of images and return a latent space sample for each.

Parameters:
  • x – Batch of images to encode
  • state – The trial state
Returns:Encoded samples / tuple of samples for different spaces

forward(x, state=None)[source]

Encode then decode the inputs, returning the result. Also binds the target as the input images in state.

Parameters:
  • x – Model input batch
  • state – The trial state
Returns:Auto-encoded images

Datasets

The datasets module includes a number of standard VAE datasets in the style of the torchvision.datasets module.

class variational.datasets.CelebA(root, transform=None, target_transform=None)[source]

CelebA auto-encoding dataset

Parameters:
  • root (str) – Root directory of dataset containing all aligned images in ‘root’
  • transform (Transform, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
  • target_transform (Transform, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
class variational.datasets.CelebA_HQ(root, as_npy=False, transform=None)[source]

CelebA_HQ, a high quality version of the CelebA auto-encoding dataset, as introduced by Progressive GAN

Parameters:
  • root (str) – Root directory of dataset containing all hq images in ‘root’
  • as_npy (bool, optional) – If True, assume images are stored in numpy arrays. Else assume a standard image format
  • transform (Transform, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
static npy_loader(path)[source]
class variational.datasets.SimpleImageFolder(root, loader=None, extensions=None, transform=None, target_transform=None)[source]

Simple image folder dataset that loads all images from inside a folder and returns each item as an (image, image) tuple

Parameters:
  • root (str) – Root directory of dataset containing all aligned images
  • loader (function, optional) – Image loader function that takes a file or path and returns the loaded image (see torchvision.datasets.folder)
  • extensions (list of str, optional) – List of file extensions that can be loaded
  • transform (Transform, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
  • target_transform (Transform, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
class variational.datasets.dSprites(root, download=False, transform=None)[source]

dSprites Dataset

Parameters:
  • root (str) – Root directory of dataset containing ‘dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz’ or to download it to
  • download (bool, optional) – If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again.
  • transform (Transform, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
download()[source]
get_img_by_latent(latent_code)[source]

Returns the image defined by the latent code

Parameters:latent_code (list of int) – Latent code of length 6 defining each generative factor
Returns:Image defined by given code
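
The dataset file name encodes the number of values per generative factor (colour 1, shape 3, scale 6, orientation 40, x 32, y 32), so a latent code of length 6 can be turned into a flat image index by treating it as a mixed-radix number. The sketch below shows that conversion in plain Python. The function names are hypothetical, and whether the library computes the index in exactly this way is an assumption.

```python
LATENT_SIZES = (1, 3, 6, 40, 32, 32)  # read off the .npz file name


def latent_bases(sizes=LATENT_SIZES):
    """Stride of each factor when images are stored in row-major factor order."""
    bases = []
    stride = 1
    for size in reversed(sizes):
        bases.append(stride)
        stride *= size
    return list(reversed(bases))


def latent_to_index(latent_code, sizes=LATENT_SIZES):
    """Flat dataset index for a latent code such as [0, 2, 5, 39, 31, 31]."""
    if len(latent_code) != len(sizes):
        raise ValueError("latent code must have one entry per generative factor")
    for value, size in zip(latent_code, sizes):
        if not 0 <= value < size:
            raise ValueError("latent value out of range")
    return sum(value * base for value, base in zip(latent_code, latent_bases(sizes)))


print(latent_bases())
print(latent_to_index([0, 2, 5, 39, 31, 31]))
```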
load_data()[source]
variational.datasets.make_dataset(dir, extensions)[source]

Visualisation

The visualisation module contains a number of latent space visualisation techniques designed to work with the base variational.auto_encoder, alongside an abstraction that allows for creating custom visualisations.

class variational.visualisation.CodePathWalker(num_steps, p1, p2)[source]

Latent space walker that walks between two specified codes p1 and p2

Parameters:
  • num_steps (int) – Number of steps to take between points
  • p1 (torch.Tensor) – Batch of codes
  • p2 (torch.Tensor) – Batch of codes
vis(state)[source]

Create the tensor of images to be displayed
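
A walk between two codes can be pictured as a sequence of intermediate codes which are then decoded to images. The sketch below builds such a sequence by linear interpolation in plain Python. The function name is hypothetical, and both the use of linear interpolation and the inclusion of the two endpoints are assumptions about how the path is formed.

```python
def interpolate_codes(p1, p2, num_steps):
    """Return num_steps codes moving linearly from p1 to p2 (endpoints included)."""
    path = []
    for i in range(num_steps):
        t = i / (num_steps - 1) if num_steps > 1 else 0.0
        path.append([(1.0 - t) * a + t * b for a, b in zip(p1, p2)])
    return path


for code in interpolate_codes([0.0, 2.0], [1.0, 0.0], 3):
    print(code)
```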

class variational.visualisation.ImagePathWalker(num_steps, im1, im2)[source]

Latent space walker that walks between two specified images im1 and im2

Parameters:
  • num_steps (int) – Number of steps to take between points
  • im1 (torch.Tensor) – Batch of images
  • im2 (torch.Tensor) – Batch of images
vis(state)[source]

Create the tensor of images to be displayed

class variational.visualisation.LatentWalker(same_image, row_size)[source]
Parameters:
  • same_image (bool) – If True, use the same image for all latent dimension walks. Else, each dimension uses a different image
  • row_size (int) – Number of images displayed in each row of the grid.
for_data(data_key)[source]
Parameters:data_key (StateKey) – State key which will contain data to act on
Returns:self
Return type:LatentWalker
for_space(space_id)[source]

Sets the ID for which latent space to vary when model outputs [latent_space_0, latent_space_1, …]

Parameters:space_id (int) – ID of the latent space to vary
Returns:self
Return type:LatentWalker
on_train()[source]

Sets the walker to run during training

Returns:self
Return type:LatentWalker
on_val()[source]

Sets the walker to run during validation

Returns:self
Return type:LatentWalker
to_file(file)[source]
Parameters:file (string, pathlib.Path object or file object) – File in which result is saved
Returns:self
Return type:LatentWalker
to_key(state_key)[source]
Parameters:state_key (StateKey) – State key under which to store result
Returns:self
Return type:LatentWalker
vis(state)[source]

Create the tensor of images to be displayed

class variational.visualisation.LinSpaceWalker(lin_start=-1, lin_end=1, lin_steps=8, dims_to_walk=[0], zero_init=False, same_image=False)[source]

Latent space walker that explores each dimension linearly from start to end points

Parameters:
  • lin_start (float) – Starting point of linspace
  • lin_end (float) – End point of linspace
  • lin_steps (int) – Number of steps to take in linspace
  • dims_to_walk (list of int) – List of dimensions to walk
  • zero_init (bool) – If True, dimensions not being walked are 0. Else, they are obtained from encoder
  • same_image (bool) – If True, use same image for each dimension walked. Else, use different images
vis(state)[source]

Create the tensor of images to be displayed
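
The parameters above suggest the following procedure: for each dimension in dims_to_walk, copy a base code and overwrite that dimension with each value of a linspace from lin_start to lin_end. The sketch below reproduces this on a single code held in a Python list. The function names are hypothetical and the decoding step is omitted, so it illustrates the idea rather than the library's implementation.

```python
def linspace(start, end, steps):
    """Evenly spaced values from start to end inclusive."""
    if steps == 1:
        return [start]
    return [start + (end - start) * i / (steps - 1) for i in range(steps)]


def walk_dimensions(code, lin_start=-1.0, lin_end=1.0, lin_steps=8,
                    dims_to_walk=(0,), zero_init=False):
    """One row of codes per walked dimension; other dimensions stay fixed."""
    base = [0.0] * len(code) if zero_init else list(code)
    rows = []
    for dim in dims_to_walk:
        row = []
        for value in linspace(lin_start, lin_end, lin_steps):
            walked = list(base)
            walked[dim] = value
            row.append(walked)
        rows.append(row)
    return rows


print(walk_dimensions([0.3, 0.7], lin_steps=3, dims_to_walk=(1,)))
print(walk_dimensions([0.3, 0.7], lin_steps=3, dims_to_walk=(0,), zero_init=True))
```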

class variational.visualisation.RandomWalker(var=1, num_images=32, uniform=False, row_size=8)[source]

Latent space walker that shows random samples from latent space

Parameters:
  • var (float or torch.Tensor) – Variance of random sample
  • num_images (int) – Number of random images to sample
  • uniform (bool) – If True, sample from a uniform distribution on [-var, var). If False, sample from a normal distribution with variance var
  • row_size (int) – Number of images displayed in each row of the grid.
vis(state)[source]

Create the tensor of images to be displayed
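
The two sampling modes can be sketched as below, producing num_images random latent codes which would then be decoded. This is plain Python with a hypothetical function name; the latent_dims argument is also hypothetical (here it stands in for the size of the model's latent space), and a zero mean for the normal case is an assumption.

```python
import math
import random


def random_codes(num_images, latent_dims, var=1.0, uniform=False, rng=random):
    """Sample codes from U[-var, var) or from N(0, var), one row per image."""
    if uniform:
        return [[-var + 2.0 * var * rng.random() for _ in range(latent_dims)]
                for _ in range(num_images)]
    std = math.sqrt(var)  # gauss() takes a standard deviation, not a variance
    return [[rng.gauss(0.0, std) for _ in range(latent_dims)]
            for _ in range(num_images)]


codes = random_codes(4, 2, var=0.5, uniform=True, rng=random.Random(0))
print(codes)
```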

class variational.visualisation.ReconstructionViewer(row_size=8, recon_key=<sphinx.ext.autodoc.importer._MockObject object>)[source]

Latent space walker that just returns the reconstructed images for the batch

Parameters:
  • row_size (int) – Number of images displayed in each row of the grid.
  • recon_key (StateKey) – StateKey of the reconstructed images
vis(state)[source]

Create the tensor of images to be displayed