from abc import ABC, abstractmethod
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.utils
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
from cetaceo.data import BaseDataset
class PINN(ABC):
    """
    This class represents a Physics-Informed Neural Network (PINN) model.

    Subclasses must implement :meth:`pde_loss`.

    Args:
        neural_net (torch.nn.Module): The neural network model.
        device (str): The device to run the model on (e.g., 'cpu', 'cuda').

    Attributes:
        device (str): The device the model is running on.
        model (torch.nn.Module): The neural network model, moved to ``device``.
    """

    def __init__(self, neural_net, device):
        self.device = device
        self.model = neural_net.to(device)

    def __call__(self, x):
        """
        Forward pass of the PINN model.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor.
        """
        return self.model(x)

    def _prepare_input_variables(self, x_batch):
        """
        Splits the input batch into one gradient-tracking column per input variable.

        Each element is the slice ``x_batch[:, i:i+1]`` with ``requires_grad``
        enabled, so :meth:`pde_loss` can differentiate the prediction with respect
        to every input variable separately.

        Args:
            x_batch (torch.Tensor): The input batch tensor.

        Returns:
            List[torch.Tensor]: One tensor per column of ``x_batch``.
        """
        input_variables = []
        for input_variable in range(x_batch.shape[1]):
            flow_variable = x_batch[:, input_variable : input_variable + 1]
            flow_variable.requires_grad_(True)
            input_variables.append(flow_variable)
        return input_variables

    def _create_dataset(self, x, y=None):
        """
        Creates a PyTorch dataset.

        Args:
            x (torch.Tensor): The input tensor.
            y (torch.Tensor, optional): The target tensor. Defaults to None.

        Returns:
            torch.utils.data.TensorDataset: ``TensorDataset(x, y)``, or ``TensorDataset(x)``
            when ``y`` is ``None``.
        """
        if y is not None:
            return TensorDataset(x, y)
        return TensorDataset(x)

    def _get_dataloader(self, dataset, batch_size=None):
        """
        Creates the batch iterable used for training or evaluation.

        Args:
            dataset: The dataset to iterate over. When ``batch_size`` is ``None`` it must
                support ``dataset[:]``; otherwise it is handed to ``DataLoader``.
            batch_size (int, optional): The batch size. If ``None``, the whole dataset is
                used as a single batch. Defaults to None.

        Returns:
            Union[torch.utils.data.DataLoader, list]: A shuffling ``DataLoader`` when
            ``batch_size`` is given, otherwise a one-element list holding ``dataset[:]``.
        """
        if batch_size is None:
            return [dataset[:]]
        return DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=1, persistent_workers=True)

    def bc_data_loss(self, pred, y, boundary_conditions, use_bfloat16=False):
        """
        Computes the losses from the boundary conditions and, optionally, from data.

        Args:
            pred (torch.Tensor): The predicted output tensor for the collocation points.
            y (torch.Tensor or None): The target tensor. If ``None``, no data loss is added.
            boundary_conditions (List[BoundaryCondition]): The list of boundary conditions.
                Each one must expose ``points`` and ``loss(prediction)``.
            use_bfloat16 (bool, optional): Whether to evaluate the boundary conditions under
                CUDA bfloat16 autocast. Defaults to False.

        Returns:
            List[torch.Tensor]: One loss per boundary condition, followed by the MSE
            between ``pred`` and ``y`` when ``y`` is given.
        """
        # NOTE(review): autocast is hard-coded to "cuda"; on other devices the
        # bfloat16 flag presumably has no effect — confirm whether that matters.
        if use_bfloat16:
            with torch.autocast("cuda", dtype=torch.bfloat16):
                bc_losses = [bc.loss(self.model(bc.points.to(self.device))) for bc in boundary_conditions]
        else:
            bc_losses = [bc.loss(self.model(bc.points.to(self.device))) for bc in boundary_conditions]
        if y is not None:
            bc_losses.append(torch.nn.functional.mse_loss(pred, y.to(self.device)))
        return bc_losses

    def compute_loss(self, x, y, boundary_conditions, use_bfloat16=False):
        """
        Computes every loss term for one batch.

        Args:
            x (torch.Tensor): The input tensor (collocation points).
            y (torch.Tensor or None): The target tensor, or ``None`` if there is no data.
            boundary_conditions (List[BoundaryCondition]): The list of boundary conditions.
            use_bfloat16 (bool, optional): Whether to run the forward pass under CUDA
                bfloat16 autocast. Defaults to False.

        Returns:
            List[torch.Tensor]: ``[pde_loss, *bc_losses, (data_loss)]``. The PDE loss is
            always the first element.
        """
        input_variables = self._prepare_input_variables(x)
        input_tensor = torch.cat(input_variables, dim=1).to(self.device)
        if use_bfloat16:
            with torch.autocast("cuda", dtype=torch.bfloat16):
                pred = self.model(input_tensor)
        else:
            pred = self.model(input_tensor)
        return [self.pde_loss(pred, *input_variables)] + self.bc_data_loss(pred, y, boundary_conditions, use_bfloat16)

    def _evaluate_loss(self, data_loader, boundary_conditions):
        """
        Returns the mean total loss over the batches of ``data_loader``.

        The model is switched to eval mode for the evaluation and back to train mode
        afterwards. Batches of length 2 are treated as ``(x, y)``; otherwise only
        ``batch[0]`` is used.
        """
        self.model.eval()
        # Not wrapped in torch.no_grad(): pde_loss presumably needs autograd with
        # respect to the inputs prepared by _prepare_input_variables.
        total = 0
        for batch in data_loader:
            x_batch = batch[0].to(self.device)
            y_batch = batch[1].to(self.device) if len(batch) == 2 else None
            total += sum(self.compute_loss(x_batch, y_batch, boundary_conditions)).item()
        self.model.train()
        return total / len(data_loader)

    def fit(
        self,
        train_dataset: "BaseDataset",
        optimizer_class=torch.optim.Adam,
        optimizer_params=None,
        lr_scheduler_class=None,
        lr_scheduler_params=None,
        epochs=1000,
        boundary_conditions=None,
        update_logs_steps=1,
        loaded_logs=None,
        batch_size=None,
        eval_dataset: "BaseDataset" = None,
        use_bfloat16=False,
        **kwargs,
    ):
        """
        Trains the PINN model.

        Args:
            train_dataset (BaseDataset): The training dataset.
            optimizer_class (torch.optim.Optimizer, optional): The optimizer class. Defaults to ``torch.optim.Adam``.
            optimizer_params (dict, optional): Keyword arguments for the optimizer. ``None`` means ``{}``. Defaults to ``None``.
            lr_scheduler_class (torch.optim.lr_scheduler._LRScheduler, optional): The learning rate scheduler class,
                stepped once per epoch. Defaults to ``None``.
            lr_scheduler_params (dict, optional): Keyword arguments for the scheduler. ``None`` means ``{}``. Defaults to ``None``.
            epochs (int, optional): The number of epochs to train for. Defaults to ``1000``.
            boundary_conditions (List[BoundaryCondition], optional): The list of boundary conditions.
                ``None`` means no boundary conditions. Defaults to ``None``.
            update_logs_steps (int, optional): Epoch interval at which the progress-bar description is refreshed;
                ``0`` disables the refresh. Defaults to ``1``.
            loaded_logs (dict, optional): Loaded training logs to be used as initial logs. They are extended in place.
                Defaults to ``None``.
            batch_size (int, optional): The batch size. If ``None``, the whole `train_dataset` is used as a single batch.
                Defaults to ``None``.
            eval_dataset (BaseDataset, optional): The evaluation dataset. Its mean loss is appended to
                ``logs["test_loss"]`` after every epoch. Defaults to ``None``.
            use_bfloat16 (bool, optional): Whether to use bfloat16 precision during training. Defaults to ``False``.
            **kwargs: Additional keyword arguments (currently unused).

        Returns:
            dict: The training logs, with keys ``"loss_from_pde"``, ``"loss_from_data_and_bc"`` and
            ``"total_loss"`` (one entry per closure call), plus ``"test_loss"`` (one entry per epoch)
            when ``eval_dataset`` is given.
        """
        # None defaults avoid sharing one mutable object across calls.
        optimizer_params = {} if optimizer_params is None else optimizer_params
        lr_scheduler_params = {} if lr_scheduler_params is None else lr_scheduler_params
        boundary_conditions = [] if boundary_conditions is None else boundary_conditions

        logs = (
            loaded_logs
            if loaded_logs is not None
            else {
                "loss_from_pde": [],
                "loss_from_data_and_bc": [],
                "total_loss": [],
            }
        )
        train_data_loader = self._get_dataloader(train_dataset, batch_size)
        test_data_loader = None
        if eval_dataset is not None:
            test_data_loader = self._get_dataloader(eval_dataset, batch_size)
            # Keep test losses from loaded_logs so they stay aligned with the training losses.
            logs.setdefault("test_loss", [])

        optimizer = optimizer_class(self.model.parameters(), **optimizer_params)
        lr_scheduler = None
        if lr_scheduler_class is not None:
            lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_params)

        def closure():
            # Reads ``batch``, ``epoch`` and ``pbar`` from the enclosing scope at call
            # time; it is only invoked from the training loop below.
            x_batch = batch[0].to(self.device)
            y_batch = batch[1].to(self.device) if len(batch) == 2 else None
            optimizer.zero_grad()
            losses = self.compute_loss(x_batch, y_batch, boundary_conditions, use_bfloat16)
            loss = sum(losses)
            loss.backward()
            loss_from_pde = losses[0].item()
            total_loss = loss.item()
            logs["loss_from_pde"].append(loss_from_pde)
            logs["loss_from_data_and_bc"].append(total_loss - loss_from_pde)
            logs["total_loss"].append(total_loss)
            if update_logs_steps != 0 and (epoch % update_logs_steps == 0):
                extended_desc = ''
                if len(losses) > 1:
                    extended_desc = f", data/bc losses: [{', '.join(f'{x:.4e}' for x in losses[1:])}]"
                if logs.get('test_loss'):
                    extended_desc += f", test loss: {logs['test_loss'][-1]:.4e}"
                desc = f"Epoch {epoch+1}/{epochs} Iteration {closure.iteration}. Pde loss: {loss_from_pde:.4e}" + extended_desc
                pbar.set_description(desc)
            closure.iteration += 1
            return loss

        self.model.train()
        pbar = tqdm(range(epochs), desc=f"Epoch 1/{epochs} Iteration 0.")
        try:
            for epoch in pbar:
                closure.iteration = 0
                for batch in train_data_loader:
                    optimizer.step(closure=closure)
                if lr_scheduler is not None:
                    lr_scheduler.step()
                if test_data_loader is not None:
                    logs["test_loss"].append(self._evaluate_loss(test_data_loader, boundary_conditions))
        except KeyboardInterrupt:
            print("Training stopped manually")
        finally:
            # An interrupted loop does not exhaust the bar, so close it explicitly.
            pbar.close()
        return logs

    def predict(self, X: "BaseDataset", **kwargs) -> np.ndarray:
        """
        Predicts for the input dataset.

        The model is left in eval mode. No autograd graph is recorded.

        Args:
            X (BaseDataset): The input dataset. ``X[:][0]`` must be the input tensor.
            **kwargs: Unused.

        Returns:
            np.ndarray: The predictions of the model.
        """
        self.model.eval()
        input_data = X[:][0]  # keep only the input data
        with torch.no_grad():
            output = self.model(input_data.to(self.device))
        return output.cpu().numpy()

    def __repr__(self):
        """
        Returns the representation of the wrapped model.

        Also prints the total number of model parameters to stdout.

        Returns:
            str: ``repr`` of ``self.model``.
        """
        print(f"Number of parameters: {sum(p.numel() for p in self.model.parameters())}")
        return self.model.__repr__()

    @staticmethod
    def _test_loss_iterations(total_iters, total_epochs):
        """
        Returns the iteration index at which each per-epoch test loss is plotted.

        Assumes a constant number of iterations per epoch. Any remainder of
        ``total_iters / total_epochs`` is placed before the first test loss, so the
        last point lands on ``total_iters``. The result has ``total_epochs`` entries
        whenever ``total_iters >= total_epochs``.
        """
        iters_per_epoch = total_iters // total_epochs
        start = iters_per_epoch + total_iters % total_epochs
        return np.arange(start, total_iters + 1, step=iters_per_epoch)

    def plot_training_logs(self, logs):
        """
        Plots the training logs on a logarithmic scale.

        Args:
            logs (dict): The training logs, as returned by :meth:`fit`.
        """
        plt.figure(figsize=(10, 6))
        plt.plot(logs["loss_from_pde"], label="PDE Loss")
        plt.plot(logs["loss_from_data_and_bc"], label="Data Conditions and BC Loss")
        plt.plot(logs["total_loss"], label="Total Loss")
        if logs.get("test_loss"):
            test_iterations = self._test_loss_iterations(len(logs["total_loss"]), len(logs["test_loss"]))
            plt.plot(test_iterations, logs["test_loss"], label="Test Loss")
        plt.xlabel("Iteration")
        plt.ylabel("Loss")
        plt.yscale("log")
        plt.title("Training Losses")
        plt.legend()
        plt.show()

    @abstractmethod
    def pde_loss(self, pred, *input_variables):
        """
        Computes the loss from the partial differential equation (PDE).

        Args:
            pred (torch.Tensor): The predicted output tensor.
            *input_variables (torch.Tensor): The input variables for the PDE, e.g. x, y, t,
                one gradient-tracking column each.

        Returns:
            torch.Tensor: The loss tensor.
        """
        pass

    def save(self, path):
        """
        Saves the model to a file using TorchScript (``torch.jit.script``).

        Args:
            path (str or os.PathLike): The path to save the model.
        """
        scripted_model = torch.jit.script(self.model)
        # str() because some torch versions' ScriptModule.save only accept str.
        scripted_model.save(str(Path(path)))

    @classmethod
    def load(cls, path, device='cpu'):
        """
        Loads a TorchScript model from a file and wraps it in this class.

        Args:
            path (str): The path to load the model.
            device (str, optional): The device to run the model on. Defaults to 'cpu'.

        Returns:
            PINN: The loaded PINN model.
        """
        model = torch.jit.load(path, map_location=device)
        return cls(neural_net=model, device=device)
class OldPINN(ABC):
    """
    This class represents a Physics-Informed Neural Network (PINN) model.

    Unlike ``PINN.fit``, :meth:`train_model` takes raw tensors and an
    already-constructed optimizer. Subclasses must implement :meth:`pde_loss`.

    Args:
        neural_net (torch.nn.Module): The neural network model.
        device (str): The device to run the model on (e.g., 'cpu', 'cuda').

    Attributes:
        device (str): The device the model is running on.
        model (torch.nn.Module): The neural network model, moved to ``device``.
    """

    def __init__(self, neural_net, device):
        """
        Initializes the PINN model.

        Args:
            neural_net (torch.nn.Module): The neural network model.
            device (str): The device to run the model on (e.g., 'cpu', 'cuda').
        """
        self.device = device
        self.model = neural_net.to(device)

    def __call__(self, x):
        """
        Forward pass of the PINN model.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor.
        """
        return self.model(x)

    def _prepare_input_variables(self, x_batch):
        """
        Splits the input batch into one gradient-tracking column per input variable.

        Each element is the slice ``x_batch[:, i:i+1]`` with ``requires_grad``
        enabled, so :meth:`pde_loss` can differentiate with respect to it.

        Args:
            x_batch (torch.Tensor): The input batch tensor.

        Returns:
            List[torch.Tensor]: One tensor per column of ``x_batch``.
        """
        input_variables = []
        for input_variable in range(x_batch.shape[1]):
            flow_variable = x_batch[:, input_variable : input_variable + 1]
            flow_variable.requires_grad_(True)
            input_variables.append(flow_variable)
        return input_variables

    def _create_dataset(self, x, y=None):
        """
        Creates a PyTorch dataset.

        Args:
            x (torch.Tensor): The input tensor.
            y (torch.Tensor, optional): The target tensor. Defaults to None.

        Returns:
            torch.utils.data.TensorDataset: ``TensorDataset(x, y)``, or ``TensorDataset(x)``
            when ``y`` is ``None``.
        """
        if y is not None:
            return TensorDataset(x, y)
        return TensorDataset(x)

    def _get_dataloaders(self, x, y=None, batch_size=None):
        """
        Creates the batch iterable used for training.

        Args:
            x (torch.Tensor): The input tensor.
            y (torch.Tensor, optional): The target tensor. Defaults to None.
            batch_size (int, optional): The batch size. If ``None``, all the data is a
                single batch. Defaults to None.

        Returns:
            Union[torch.utils.data.DataLoader, list]: A non-shuffling ``DataLoader`` when
            ``batch_size`` is given, otherwise a one-element list holding the dataset's
            tensors.
        """
        dataset = self._create_dataset(x, y)
        if batch_size is None:
            return [dataset.tensors]
        return DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=1, persistent_workers=True)

    def bc_data_loss(self, pred, y, boundary_conditions, use_bfloat16=False):
        """
        Computes the losses from the boundary conditions and, optionally, from data.

        Args:
            pred (torch.Tensor): The predicted output tensor for the collocation points.
            y (torch.Tensor or None): The target tensor. If ``None``, no data loss is added.
            boundary_conditions (List[BoundaryCondition]): The list of boundary conditions.
                Each one must expose ``points`` and ``loss(prediction)``.
            use_bfloat16 (bool, optional): Whether to evaluate the boundary conditions under
                CUDA bfloat16 autocast. Defaults to False.

        Returns:
            List[torch.Tensor]: One loss per boundary condition, followed by the MSE
            between ``pred`` and ``y`` when ``y`` is given.
        """
        # NOTE(review): autocast is hard-coded to "cuda"; on other devices the
        # bfloat16 flag presumably has no effect — confirm whether that matters.
        if use_bfloat16:
            with torch.autocast("cuda", dtype=torch.bfloat16):
                bc_losses = [bc.loss(self.model(bc.points.to(self.device))) for bc in boundary_conditions]
        else:
            bc_losses = [bc.loss(self.model(bc.points.to(self.device))) for bc in boundary_conditions]
        if y is not None:
            bc_losses.append(torch.nn.functional.mse_loss(pred, y.to(self.device)))
        return bc_losses

    def compute_loss(self, x, y, boundary_conditions, use_bfloat16=False):
        """
        Computes every loss term for one batch.

        Args:
            x (torch.Tensor): The input tensor (collocation points).
            y (torch.Tensor or None): The target tensor, or ``None`` if there is no data.
            boundary_conditions (List[BoundaryCondition]): The list of boundary conditions.
            use_bfloat16 (bool, optional): Whether to run the forward pass under CUDA
                bfloat16 autocast. Defaults to False.

        Returns:
            List[torch.Tensor]: ``[pde_loss, *bc_losses, (data_loss)]``. The PDE loss is
            always the first element.
        """
        input_variables = self._prepare_input_variables(x)
        input_tensor = torch.cat(input_variables, dim=1).to(self.device)
        if use_bfloat16:
            with torch.autocast("cuda", dtype=torch.bfloat16):
                pred = self.model(input_tensor)
        else:
            pred = self.model(input_tensor)
        return [self.pde_loss(pred, *input_variables)] + self.bc_data_loss(pred, y, boundary_conditions, use_bfloat16)

    def train_model(
        self,
        optimizer,
        epochs,
        x_train,
        y_train=None,
        boundary_conditions=None,
        print_every=100,
        loaded_logs=None,
        batch_size=None,
        x_test=None,
        y_test=None,
        use_bfloat16=False,
    ):
        """
        Trains the PINN model.

        Args:
            optimizer (torch.optim.Optimizer): The optimizer used for the training.
            epochs (int): The number of epochs to train for.
            x_train (torch.Tensor): The input training tensor containing the collocation points.
            y_train (torch.Tensor, optional): The target training tensor, if simulation data is provided. Defaults to ``None``.
            boundary_conditions (List[BoundaryCondition], optional): The list of boundary conditions.
                ``None`` means no boundary conditions. Defaults to ``None``.
            print_every (int, optional): Interval (in epochs for new progress bars, in iterations for
                description updates) for printing progress; ``0`` disables it. Defaults to ``100``.
            loaded_logs (dict, optional): Loaded training logs to be used as initial logs. They are extended in place,
                including any existing ``"test_loss"`` entries. Defaults to ``None``.
            batch_size (int, optional): The batch size. If ``None``, the batch size will be equal to the number of
                collocation points given on `x_train`. Defaults to ``None``.
            x_test (torch.Tensor, optional): The input test tensor containing the test points. Defaults to ``None``.
            y_test (torch.Tensor, optional): The target test tensor with simulation data on the test points. If no value
                is given, then the test loss is the loss of the PDE + the loss of the boundary conditions. Defaults to ``None``.
            use_bfloat16 (bool, optional): Whether to use bfloat16 precision during training. Defaults to ``False``.

        Returns:
            dict: The training logs.
        """
        # None default avoids sharing one mutable list across calls.
        boundary_conditions = [] if boundary_conditions is None else boundary_conditions
        logs = (
            loaded_logs
            if loaded_logs is not None
            else {
                "loss_from_pde": [],
                "loss_from_data_and_bc": [],
                "total_loss": [],
            }
        )
        train_data_loader = self._get_dataloaders(x_train, y_train, batch_size)
        test_data_loader = None
        if x_test is not None:
            dataset_test = self._create_dataset(x_test, y_test)
            test_data_loader = [dataset_test.tensors]
            # Keep test losses from loaded_logs so they stay aligned with the
            # accumulated training losses when plotting.
            logs.setdefault("test_loss", [])

        def closure(batch):
            # Reads ``epoch`` and ``pbar`` from the enclosing scope at call time.
            x_batch = batch[0].to(self.device)
            y_batch = batch[1].to(self.device) if y_train is not None else None
            optimizer.zero_grad()
            losses = self.compute_loss(x_batch, y_batch, boundary_conditions, use_bfloat16)
            loss = sum(losses)
            loss.backward()
            loss_from_pde = losses[0].item()
            total_loss = loss.item()
            logs["loss_from_pde"].append(loss_from_pde)
            logs["loss_from_data_and_bc"].append(total_loss - loss_from_pde)
            logs["total_loss"].append(total_loss)
            if print_every != 0 and (closure.iteration % print_every == 0):
                extended_desc = ''
                if len(losses) > 1:
                    extended_desc = f", data/bc losses: [{', '.join(f'{x:.4e}' for x in losses[1:])}]"
                if logs.get('test_loss'):
                    extended_desc += f", test loss: {logs['test_loss'][-1]:.4e}"
                desc = f"Epoch {epoch+1}/{epochs} Iteration {closure.iteration}. Pde loss: {loss_from_pde:.4e}" + extended_desc
                pbar.set_description(desc)
            closure.iteration += 1
            return loss

        try:
            for epoch in range(epochs):
                closure.iteration = 0
                desc = f"Epoch {epoch+1}/{epochs} Iteration {closure.iteration}. Pde loss: , data loss: "
                # A new progress bar is only created every ``print_every`` epochs; in the
                # other epochs the previously created bar is iterated again.
                if print_every != 0 and epoch % print_every == 0:
                    pbar = tqdm(train_data_loader, desc=desc)
                data_iterable = pbar if print_every != 0 else train_data_loader
                for batch in data_iterable:
                    optimizer.step(lambda batch=batch: closure(batch))
                if test_data_loader is not None:
                    self.model.eval()
                    test_loss = 0
                    for batch in test_data_loader:
                        x_batch = batch[0].to(self.device)
                        y_batch = batch[1].to(self.device) if y_test is not None else None
                        losses = self.compute_loss(x_batch, y_batch, boundary_conditions)
                        test_loss += sum(losses).item()
                    logs["test_loss"].append(test_loss / len(test_data_loader))
                    self.model.train()
        except KeyboardInterrupt:
            print("Training stopped manually")
        return logs

    def __repr__(self):
        """
        Returns the representation of the wrapped model.

        Also prints the total number of model parameters to stdout.

        Returns:
            str: ``repr`` of ``self.model``.
        """
        print(f"Number of parameters: {sum(p.numel() for p in self.model.parameters())}")
        return self.model.__repr__()

    @staticmethod
    def _test_loss_iterations(total_iters, total_epochs):
        """
        Returns the iteration index at which each per-epoch test loss is plotted.

        Assumes a constant number of iterations per epoch. Any remainder of
        ``total_iters / total_epochs`` is placed before the first test loss, so the
        result always has ``total_epochs`` entries (for ``total_iters >= total_epochs``)
        and the last one is ``total_iters``.
        """
        iters_per_epoch = total_iters // total_epochs
        start = iters_per_epoch + total_iters % total_epochs
        return np.arange(start, total_iters + 1, step=iters_per_epoch)

    def plot_training_logs(self, logs):
        """
        Plots the training logs on a logarithmic scale.

        Args:
            logs (dict): The training logs, as returned by :meth:`train_model`.
        """
        plt.figure(figsize=(10, 6))
        plt.plot(logs["loss_from_pde"], label="PDE Loss")
        plt.plot(logs["loss_from_data_and_bc"], label="Data Conditions and BC Loss")
        plt.plot(logs["total_loss"], label="Total Loss")
        if logs.get("test_loss"):
            test_iterations = self._test_loss_iterations(len(logs["total_loss"]), len(logs["test_loss"]))
            plt.plot(test_iterations, logs["test_loss"], label="Test Loss")
        plt.xlabel("Iteration")
        plt.ylabel("Loss")
        plt.yscale("log")
        plt.title("Training Losses")
        plt.legend()
        plt.show()

    @abstractmethod
    def pde_loss(self, pred, *input_variables):
        """
        Computes the loss from the partial differential equation (PDE).

        Args:
            pred (torch.Tensor): The predicted output tensor.
            *input_variables (torch.Tensor): The input variables for the PDE, e.g. x, y, t,
                one gradient-tracking column each.

        Returns:
            torch.Tensor: The loss tensor.
        """
        pass