Source code for models.VAE

import torch
from torch import nn
from torch.nn import functional as F

class VAE(nn.Module):
    """
    Variational autoencoder designed to reduce the dimensionality of mesh data
    by encoding its coordinates into a latent space, from which they can be
    reconstructed. Adapted from the base VAE in
    https://github.com/AntixK/PyTorch-VAE.

    Args:
        in_channels (int): Number of coordinate points in the airfoil data.
        latent_dim (int): Dimensionality of the latent space.
        hidden_dims (List[int]): List of hidden dimensions for the encoder
            and decoder. Assumed symmetrical.
        act_function (Callable): Activation function to use for the encoder
            and decoder.
    """

    def __init__(self, in_channels, latent_dim, hidden_dims=None,
                 act_function=nn.ELU(), **kwargs):
        super(VAE, self).__init__()

        self.latent_dim = latent_dim
        self.in_channels = in_channels

        if hidden_dims is None:
            hidden_dims = [128, 64, 32]
        self.name = f'VAE_MLP{hidden_dims[0]}_{in_channels}_{latent_dim}'

        # Build encoder
        modules = []
        modules.append(nn.Linear(in_channels, hidden_dims[0]))
        modules.append(act_function)
        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.Linear(hidden_dims[i], hidden_dims[i + 1]),
                    act_function)
            )
        self.encoder = nn.Sequential(*modules)

        # Latent variable distributions
        self.fc_mu = nn.Linear(hidden_dims[-1], latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1], latent_dim)

        # Build decoder as a mirror of the encoder. Note that reverse()
        # mutates the hidden_dims list in place.
        modules = []
        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1])
        hidden_dims.reverse()
        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.Linear(hidden_dims[i], hidden_dims[i + 1]),
                    act_function)
            )
        modules.append(nn.Linear(hidden_dims[-1], in_channels))
        self.decoder = nn.Sequential(*modules)
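As a quick orientation, here is a minimal instantiation sketch; the input size of 400 coordinates and latent dimension of 8 are illustrative values, not taken from the module:

# Hypothetical sizes for illustration only: 400 flattened coordinates
# encoded into an 8-dimensional latent space.
model = VAE(in_channels=400, latent_dim=8)
print(model.name)  # VAE_MLP128_400_8

# hidden_dims is reversed in place during construction, so pass a copy
# if you need to reuse the list afterwards.
dims = [256, 128, 64]
model_wide = VAE(in_channels=400, latent_dim=8, hidden_dims=list(dims))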
    def encode(self, x):
        """
        Encodes the input by passing it through the encoder network and
        returns the latent codes.

        Args:
            x (Tensor): Input tensor to encoder [N x D_in]

        Returns:
            List[Tensor]: Mean and log variance of the latent Gaussian,
            each of shape [N x D_latent]
        """
        encoded = self.encoder(x)
        mu = self.fc_mu(encoded)
        log_var = self.fc_var(encoded)
        return [mu, log_var]
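Continuing the hypothetical model above, note that encode returns a two-element list rather than a tuple, though unpacking works the same way:

x = torch.randn(16, 400)          # hypothetical batch of 16 inputs
mu, log_var = model.encode(x)
print(mu.shape, log_var.shape)    # torch.Size([16, 8]) torch.Size([16, 8])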
    def reparameterize(self, mu, logvar):
        """
        Reparameterization trick: sample from N(mu, var) using an auxiliary
        sample from N(0, 1).

        Args:
            mu (Tensor): Mean of the latent Gaussian [B x D_latent]
            logvar (Tensor): Log variance of the latent Gaussian [B x D_latent]

        Returns:
            Tensor: Sampled latent code [B x D_latent]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu
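A minimal sketch of why the trick matters: the sample stays differentiable with respect to mu and logvar, since the randomness is confined to eps (the shapes here are again illustrative):

mu = torch.zeros(4, 8, requires_grad=True)     # hypothetical latent mean
logvar = torch.zeros(4, 8, requires_grad=True)
z = model.reparameterize(mu, logvar)           # z = mu + exp(0.5*logvar)*eps
z.sum().backward()                             # gradients flow to mu/logvar
print(mu.grad.shape)                           # torch.Size([4, 8])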
    def decode(self, z):
        """
        Maps the given latent codes onto the coordinate space.

        Args:
            z (Tensor): Latent code [B x D_latent]

        Returns:
            Tensor: Reconstructed input [B x D_out]
        """
        decoded = self.decoder_input(z)
        decoded = self.decoder(decoded)
        return decoded
    def forward(self, input, **kwargs):
        """
        Forward pass through the network.

        Args:
            input (Tensor): Input tensor to the VAE [N x D_in]

        Returns:
            List[Tensor]: List containing the reconstructed input, original
            input, mean, and log variance
        """
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return [self.decode(z), input, mu, log_var]
    def loss_function(self, pred, *args, **kwargs):
        """
        Computes the VAE loss function using the Kullback-Leibler divergence.

        Args:
            pred (List[Tensor]): List containing the reconstructed input,
                original input, mean, and log variance
            *args: Additional arguments
            **kwargs: Additional keyword arguments; must include 'weight',
                the KL weight (beta in a Beta-VAE)

        Returns:
            dict: Dictionary containing the total loss, the reconstruction
            loss, and the KL divergence loss (the latter reported with a
            negative sign, following the PyTorch-VAE convention)
        """
        recons = pred[0]
        input = pred[1]
        mu = pred[2]
        log_var = pred[3]

        recon_loss = F.mse_loss(recons, input)
        kl_loss = torch.mean(
            -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1))

        weight = kwargs['weight']
        loss = recon_loss + weight * kl_loss
        return {'loss': loss,
                'recon_loss': recon_loss.detach(),
                'kl_loss': -kl_loss.detach()}
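Putting forward and loss_function together, a single training step might look like the following sketch; the optimizer, batch, and KL weight are placeholders, not values prescribed by the module:

# Hypothetical training step; the KL weight (beta) is a free hyperparameter.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
batch = torch.randn(16, 400)

pred = model(batch)                                # [recons, input, mu, log_var]
losses = model.loss_function(pred, weight=0.001)   # 'weight' is required
optimizer.zero_grad()
losses['loss'].backward()
optimizer.step()
print(losses['recon_loss'].item(), losses['kl_loss'].item())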
    def sample(self, num_samples, current_device, std_coef=1.0, **kwargs):
        """
        Samples from the latent space and returns the corresponding
        reconstructed input.

        Args:
            num_samples (int): Number of samples
            current_device (int): Device to run the model
            std_coef (float, optional): Standard deviation coefficient for
                sampling. Default is 1.0.

        Returns:
            Tensor: Sampled and decoded tensor
        """
        mean = torch.zeros(num_samples, self.latent_dim)
        std = torch.ones(num_samples, self.latent_dim) * std_coef
        z = torch.normal(mean, std)
        z = z.to(current_device)

        samples = self.decode(z)
        return samples
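Finally, a usage sketch for sampling. Only the latent codes are moved to current_device, so the model itself is assumed to already live on that device:

# Draw 5 hypothetical samples with a narrower-than-prior latent distribution.
samples = model.sample(num_samples=5, current_device='cpu', std_coef=0.5)
print(samples.shape)   # torch.Size([5, 400]) for the 400-coordinate example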