import torch
from torch import nn
from torch.nn import functional as F
class VAE(nn.Module):
"""
Variational autoencoder designed to reduce dimensionality of a mesh data by encoding their
coordinates into a latent space, being able to reconstruct them from that latent space.
Adapted from the base VAE in https://github.com/AntixK/PyTorch-VAE.
Args:
in_channels (int): Number of coordinate points in the airfoil data.
latent_dim (int): Dimensionality of the latent space.
hidden_dims List[int]: List of hidden dimensions for the encoder and decoder. Assumed symmetrical.
act_function (Callable): Activation function to use for the encoder and decoder.
"""
def __init__(self,
in_channels,
latent_dim,
hidden_dims = None,
act_function = nn.ELU(),
**kwargs):
super(VAE, self).__init__()
self.latent_dim = latent_dim
self.in_channels = in_channels
if hidden_dims is None:
hidden_dims = [128, 64, 32]
self.name = f'VAE_MLP{hidden_dims[0]}_{in_channels}_{latent_dim}'
# Build Encoder
modules = []
modules.append(nn.Linear(in_channels, hidden_dims[0]))
modules.append(act_function)
for i in range(len(hidden_dims) - 1):
modules.append(
nn.Sequential(
nn.Linear(hidden_dims[i], hidden_dims[i+1]),
act_function)
)
self.encoder = nn.Sequential(*modules)
# Latent variable distributions
self.fc_mu = nn.Linear(hidden_dims[-1], latent_dim)
self.fc_var = nn.Linear(hidden_dims[-1], latent_dim)
# Build Decoder
modules = []
self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1])
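        # Mirror the encoder: hidden_dims is reversed in place (note that this mutates a caller-supplied list)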
hidden_dims.reverse()
for i in range(len(hidden_dims) - 1):
modules.append(
nn.Sequential(
nn.Linear(hidden_dims[i], hidden_dims[i+1]),
act_function)
)
modules.append(nn.Linear(hidden_dims[-1], in_channels))
self.decoder = nn.Sequential(*modules)
def encode(self, x):
"""
        Encodes the input by passing it through the encoder network
        and returns the parameters of the latent distribution.
        Args:
            x (Tensor): Input tensor to the encoder [N x D_in]
        Returns:
            List[Tensor]: Latent Gaussian parameters [mu, log_var]
"""
encoded = self.encoder(x)
mu = self.fc_mu(encoded)
log_var = self.fc_var(encoded)
return [mu, log_var]
def reparameterize(self, mu, logvar):
"""
        Reparameterization trick: sample from N(mu, var) using an auxiliary sample from N(0, 1).
Args:
mu (Tensor): Mean of the latent Gaussian [B x D_latent]
            logvar (Tensor): Log variance of the latent Gaussian [B x D_latent]
Returns:
Tensor: Sampled latent code [B x D_latent]
"""
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return eps * std + mu
def decode(self, z):
"""
Maps the given latent codes onto the coordinate space.
Args:
z (Tensor): Latent code [B x D_latent]
Returns:
Tensor: Reconstructed input [B x D_out]
"""
decoded = self.decoder_input(z)
decoded = self.decoder(decoded)
return decoded
def forward(self, input, **kwargs):
"""
Forward pass through the network.
Args:
input (Tensor): Input tensor to the VAE [N x D_in]
Returns:
List[Tensor]: List containing the reconstructed input, original input, mean, and log variance
"""
mu, log_var = self.encode(input)
z = self.reparameterize(mu, log_var)
return [self.decode(z), input, mu, log_var]
def loss_function(self,
pred,
*args,
**kwargs):
"""
Computes the VAE loss function using the Kullback-Leibler divergence.
Args:
pred (List[Tensor]): List containing the reconstructed input, original input, mean, and log variance
*args: Additional arguments
            **kwargs: Additional keyword arguments; must include 'weight', the KL weighting factor (beta in a Beta-VAE)
Returns:
dict: Dictionary containing the total loss, reconstruction loss, and KL divergence loss
"""
recons = pred[0]
input = pred[1]
mu = pred[2]
log_var = pred[3]
recon_loss = F.mse_loss(recons, input)
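        # Closed-form KL divergence between the approximate posterior N(mu, sigma^2) and the standard
        # normal prior, KL = -0.5 * sum(1 + log_var - mu^2 - exp(log_var)), averaged over the batch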
kl_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1))
weight = kwargs['weight']
loss = recon_loss + weight * kl_loss
return {'loss': loss, 'recon_loss': recon_loss.detach(), 'kl_loss': -kl_loss.detach()}
def sample(self,
num_samples,
current_device,
std_coef=1.0,
**kwargs):
"""
Samples from the latent space and returns the corresponding reconstructed input.
Args:
num_samples (int): Number of samples
            current_device (int or torch.device): Device on which to place the sampled latent codes
std_coef (float, optional): Standard deviation coefficient for sampling. Default is 1.0.
Returns:
Tensor: Sampled and decoded tensor
"""
mean = torch.zeros(num_samples, self.latent_dim)
std = torch.ones(num_samples, self.latent_dim) * std_coef
z = torch.normal(mean, std)
z = z.to(current_device)
samples = self.decode(z)
return samples
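

# Illustrative usage sketch (not part of the original module): the input width, latent size,
# KL weight, and learning rate below are assumptions chosen only to demonstrate the API.
if __name__ == "__main__":
    model = VAE(in_channels=400, latent_dim=8)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    x = torch.randn(16, 400)                 # batch of flattened coordinates [N x D_in]
    recon, original, mu, log_var = model(x)  # forward returns [recon, input, mu, log_var]

    losses = model.loss_function([recon, original, mu, log_var], weight=1e-3)
    losses['loss'].backward()
    optimizer.step()

    # Generate new shapes by decoding samples drawn from the prior
    generated = model.sample(num_samples=4, current_device='cpu')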