Variational Autoencoder (VAE) — PyTorch Tutorial

By Reza Kalantar.

In a standard autoencoder, each input is mapped to a single, fixed point in latent space. This is useful when we want the model to produce deterministic predictions. In contrast, a variational autoencoder (VAE) maps each input to a probability distribution over the latent space, described by a mean (μ) and a standard deviation (σ), and decodes samples drawn from that distribution. This probabilistic property is what makes the VAE a generative model: after training, we can sample new latent vectors and decode them into new data. To find out more about the intuition behind VAEs, I recommend reading Understanding Variational Autoencoders (VAEs) and What is a variational autoencoder?

In this article, we focus on building a simple VAE in PyTorch and visualizing its latent representation after training on the MNIST dataset. Let’s begin by importing some libraries:

import torch
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from mpl_toolkits.axes_grid1 import ImageGrid
from torchvision.utils import save_image, make_grid

## and download the MNIST dataset and make dataloaders:

# create a transform to apply to each datapoint
transform = transforms.Compose([transforms.ToTensor()])

# download the MNIST datasets
path = '~/datasets'
train_dataset = MNIST(path, train=True, transform=transform, download=True)
test_dataset  = MNIST(path, train=False, transform=transform, download=True)

# create train and test dataloaders
batch_size = 100
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Let’s visualize some sample training data:

# get 25 sample training images for visualization
dataiter = iter(train_loader)
image = next(dataiter)  # (images, labels) for one batch

num_samples = 25
sample_images = [image[0][i,0] for i in range(num_samples)] 

fig = plt.figure(figsize=(5, 5))
grid = ImageGrid(fig, 111, nrows_ncols=(5, 5), axes_pad=0.1)

for ax, im in zip(grid, sample_images):
    ax.imshow(im, cmap='gray')
    ax.axis('off')

plt.show()

25 sample training images

Now, we create a simple VAE with a fully-connected encoder and decoder. The input dimension is 784, which is the flattened size of an MNIST image (28×28). The encoder narrows the input through two hidden layers (400 and 200 units; note that the code calls the second width latent_dim) and then predicts two vectors: the mean (μ) and the log-variance (log σ²) of the latent distribution. Notice that we multiply the standard deviation σ = exp(0.5 · log σ²) by a random noise term epsilon (ε) and add the mean before decoding. This reparameterization moves the randomness into ε, so that gradients can be backpropagated through the sampling step. Read more on reparameterization here.

The mean and log-variance layers each output a vector of dimension 2, so the latent space itself is two-dimensional. These continuous vectors define the latent distribution from which we sample when generating images with the VAE.
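The reparameterization step can be tried in isolation. The following is a minimal sketch rather than the model's own code: the tensor names and the batch size of 4 are assumptions made for the illustration. It shows that, because the randomness lives entirely in ε, gradients still reach the mean and the log-variance.

```python
import torch

torch.manual_seed(0)

# Hypothetical encoder outputs for a batch of 4 inputs and a 2-D latent space.
mean = torch.zeros(4, 2, requires_grad=True)
log_var = torch.zeros(4, 2, requires_grad=True)

# Reparameterization: z = mean + std * epsilon, with epsilon ~ N(0, I).
std = torch.exp(0.5 * log_var)   # log-variance -> standard deviation
epsilon = torch.randn_like(std)  # the only source of randomness
z = mean + std * epsilon

# Backpropagate through the sample.
z.sum().backward()

print(mean.grad)     # dz/dmean = 1 for every element
print(log_var.grad)  # dz/dlog_var = 0.5 * std * epsilon
```

Sampling z directly from a normal distribution would cut this gradient path, which is why the noise is drawn separately and then scaled and shifted.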

class VAE(nn.Module):

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=200, device=device):
        super(VAE, self).__init__()

        # encoder
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, latent_dim),
            nn.LeakyReLU(0.2)
            )
        
        # latent mean and log-variance (each of dimension 2)
        self.mean_layer = nn.Linear(latent_dim, 2)
        self.logvar_layer = nn.Linear(latent_dim, 2)
        
        # decoder
        self.decoder = nn.Sequential(
            nn.Linear(2, latent_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(latent_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()
            )
     
    def encode(self, x):
        x = self.encoder(x)
        mean, logvar = self.mean_layer(x), self.logvar_layer(x)
        return mean, logvar

    def reparameterization(self, mean, std):
        # sample epsilon ~ N(0, I) with the same shape, dtype and device as std
        epsilon = torch.randn_like(std)
        z = mean + std * epsilon
        return z

    def decode(self, x):
        return self.decoder(x)

    def forward(self, x):
        mean, log_var = self.encode(x)
        # convert the log-variance to a standard deviation before sampling
        z = self.reparameterization(mean, torch.exp(0.5 * log_var))
        x_hat = self.decode(z)
        return x_hat, mean, log_var

Now, we can define our model and optimizer:

model = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

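The loss function of a VAE consists of a reconstruction loss (named reproduction_loss in the code) and the Kullback–Leibler (KL) divergence. The reconstruction loss measures how well the decoder reproduces the input; here it is the binary cross-entropy summed over all pixels. The KL divergence measures how much one probability distribution differs from another; it is not symmetric, so it is not a distance metric in the strict sense. Here it penalizes the encoder's distribution for drifting away from a standard normal prior. KL divergence is an important concept in generative modelling, but we won’t go into more detail in this tutorial. Read more: Intuitive Guide to Understanding KL Divergence.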
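For reference, the objective can be written out as follows. This assumes, as the code does, a diagonal Gaussian encoder distribution q(z|x) = N(μ, σ²) and a standard normal prior p(z) = N(0, I), with d the latent dimension (d = 2 in this tutorial) and log σ² corresponding to log_var.

```latex
\mathcal{L}(x)
  = \underbrace{-\,\mathbb{E}_{q(z \mid x)}\big[\log p(x \mid z)\big]}_{\text{reconstruction loss}}
  + \underbrace{D_{\mathrm{KL}}\big(q(z \mid x)\,\|\,p(z)\big)}_{\text{KL divergence}},
\qquad
D_{\mathrm{KL}}\big(q(z \mid x)\,\|\,p(z)\big)
  = -\frac{1}{2} \sum_{j=1}^{d} \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)
```

With a Bernoulli likelihood for the pixels, the first term becomes the binary cross-entropy between the input and its reconstruction.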

def loss_function(x, x_hat, mean, log_var):
    reproduction_loss = nn.functional.binary_cross_entropy(x_hat, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())

    return reproduction_loss + KLD
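As a sanity check on the KLD line, the closed-form expression can be compared with the KL divergence computed by torch.distributions. This is a small sketch on made-up values: the random mean and log_var tensors stand in for encoder outputs and are assumptions for the example only.

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)

# Hypothetical encoder outputs: batch of 3, latent dimension 2.
mean = torch.randn(3, 2)
log_var = torch.randn(3, 2)

# Closed-form KL term, written as in loss_function.
kld_closed_form = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())

# The same quantity from torch.distributions, summed over all elements.
q = Normal(mean, torch.exp(0.5 * log_var))                  # N(mean, std)
p = Normal(torch.zeros_like(mean), torch.ones_like(mean))   # N(0, 1)
kld_reference = kl_divergence(q, p).sum()

print(kld_closed_form.item(), kld_reference.item())
```

The two printed values should agree up to floating-point error, which confirms that the one-line formula is the KL divergence from the encoder's Gaussian to a standard normal prior.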

Finally, we can train our model:

def train(model, optimizer, epochs, device):
    model.train()
    for epoch in range(epochs):
        overall_loss = 0
        for batch_idx, (x, _) in enumerate(train_loader):
            # flatten each 28x28 image into a 784-dimensional vector
            x = x.view(x.size(0), -1).to(device)

            optimizer.zero_grad()

            x_hat, mean, log_var = model(x)
            loss = loss_function(x, x_hat, mean, log_var)

            overall_loss += loss.item()

            loss.backward()
            optimizer.step()

        # the loss is summed within each batch, so divide by the number of images
        print("\tEpoch", epoch + 1, "\tAverage Loss: ", overall_loss / len(train_loader.dataset))
    return overall_loss

And train:

train(model, optimizer, epochs=50, device=device)
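The test_loader created earlier can be used to check how the trained model does on images it has not seen. The helper below is a hypothetical addition, not part of the original tutorial: the name evaluate and its signature are assumptions, and it expects a model that returns (x_hat, mean, log_var) and a loss function with the same arguments as loss_function.

```python
import torch


def evaluate(model, loader, loss_fn, device):
    """Return the average per-image loss over a data loader (hypothetical helper)."""
    model.eval()
    total_loss, num_images = 0.0, 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for x, _ in loader:
            x = x.view(x.size(0), -1).to(device)  # flatten each image
            x_hat, mean, log_var = model(x)
            total_loss += loss_fn(x, x_hat, mean, log_var).item()
            num_images += x.size(0)
    return total_loss / num_images
```

With the objects defined in this tutorial, the call would be evaluate(model, test_loader, loss_function, device). Comparing the result with the training loss gives a rough indication of overfitting.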

The decoder needs only two float values to generate an image: the two coordinates of a point z in the 2-dimensional latent space. (These are the coordinates of a latent sample, not a mean and a variance.) Let’s generate some images from the latent space:

def generate_digit(z0, z1):
    # decode a single latent point z = (z0, z1) into a 28x28 image
    z_sample = torch.tensor([[z0, z1]], dtype=torch.float).to(device)
    with torch.no_grad():
        x_decoded = model.decode(z_sample)
    digit = x_decoded.cpu().reshape(28, 28) # reshape vector to 2d array
    plt.imshow(digit, cmap='gray')
    plt.axis('off')
    plt.show()

generate_digit(0.0, 1.0)
generate_digit(1.0, 0.0)
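Instead of picking latent points by hand, we can also draw them from the standard normal prior, which is how a VAE is typically used as a generator. The sketch below uses an untrained stand-in decoder so that it runs on its own; the stand-in and the sample count are assumptions for the example, and with the trained model you would call model.decode(z) instead.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in decoder with the same input and output sizes as the tutorial's decoder.
# It is untrained and exists only to make this sketch self-contained.
decoder = nn.Sequential(nn.Linear(2, 784), nn.Sigmoid())

num_samples = 16
with torch.no_grad():
    z = torch.randn(num_samples, 2)          # z ~ N(0, I), the prior
    images = decoder(z).reshape(-1, 28, 28)  # one 28x28 image per sample

print(images.shape)  # torch.Size([16, 28, 28])
```

Because the KL term pulls the encoder's distributions toward N(0, I) during training, points sampled this way tend to land in regions the decoder has learned to turn into digits.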

Interesting! A more impressive view of the latent space:

def plot_latent_space(model, scale=1.0, n=25, digit_size=28, figsize=15):
    # display an n*n 2D manifold of digits
    figure = np.zeros((digit_size * n, digit_size * n))

    # construct a grid 
    grid_x = np.linspace(-scale, scale, n)
    grid_y = np.linspace(-scale, scale, n)[::-1]

    for i, yi in enumerate(grid_y):
        for j, xi in enumerate(grid_x):
            z_sample = torch.tensor([[xi, yi]], dtype=torch.float).to(device)
            x_decoded = model.decode(z_sample)
            digit = x_decoded[0].detach().cpu().reshape(digit_size, digit_size)
            figure[i * digit_size : (i + 1) * digit_size, j * digit_size : (j + 1) * digit_size] = digit

    plt.figure(figsize=(figsize, figsize))
    plt.title('VAE Latent Space Visualization')
    start_range = digit_size // 2
    end_range = n * digit_size + start_range
    pixel_range = np.arange(start_range, end_range, digit_size)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.imshow(figure, cmap="Greys_r")
    plt.show()


plot_latent_space(model)

Latent space visualization, range: [-1.0, 1.0]
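To see which latent point each tile of the visualization corresponds to, the grid built inside plot_latent_space can be reproduced on its own. This is a small sketch using a 3×3 grid (an assumption chosen to keep the output readable) instead of the 25×25 grid used for the figure.

```python
import numpy as np

scale, n = 1.0, 3  # a 3x3 grid keeps the output readable

grid_x = np.linspace(-scale, scale, n)        # columns, left to right
grid_y = np.linspace(-scale, scale, n)[::-1]  # rows, top to bottom (+scale first)

# Latent coordinate (z[0], z[1]) decoded for the tile at row i, column j.
coords = [[(float(xi), float(yi)) for xi in grid_x] for yi in grid_y]
for row in coords:
    print(row)
```

The y values are reversed so that larger z[1] values appear at the top of the image, matching the usual orientation of a plot.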

The first visualization shows the digits decoded from latent coordinates between -1.0 and 1.0 on each axis. What happens if we widen this range to [-5.0, 5.0] by calling plot_latent_space(model, scale=5.0)?

Latent space visualization, range: [-5.0, 5.0]

Again, interesting! We can now see the range of latent values within which most digit representations lie. We now know how to build a simple VAE from scratch, sample images and visualize the latent space. But VAEs do not end here: there are more advanced techniques that make representation learning even more fascinating, and I will explore them in future articles. Find me on GitHub, LinkedIn and Google Scholar.

 
