Source code for models.pinn.piratenet

# This architecture comes from the PirateNet paper (https://arxiv.org/pdf/2402.00326.pdf)

import torch
import torch.nn as nn
import torch.nn.functional as F

class PirateNetBlock(nn.Module):
    """Gated residual block with a learnable skip weight ``alpha``.

    The block owns three square linear layers (``hidden_dim -> hidden_dim``)
    and a single scalar parameter ``alpha`` that starts at zero, so a freshly
    constructed block returns its input ``x`` unchanged.
    """

    def __init__(self, hidden_dim):
        """Create the three dense layers and the zero-initialised ``alpha``.

        Args:
            hidden_dim: Input and output width of every dense layer.
        """
        super().__init__()
        # Layers are created in this order on purpose: it fixes both the RNG
        # consumption during default init and the registration order.
        self.dense1 = nn.Linear(hidden_dim, hidden_dim)
        self.dense2 = nn.Linear(hidden_dim, hidden_dim)
        self.dense3 = nn.Linear(hidden_dim, hidden_dim)
        self.alpha = nn.Parameter(torch.zeros(1))

    def forward(self, x, u, v):
        """Run the two gating stages and blend the result with ``x``.

        Args:
            x: Block input; its last dimension must be ``hidden_dim``.
            u: First tensor mixed by the gates (must broadcast with ``x``).
            v: Second tensor mixed by the gates (must broadcast with ``x``).

        Returns:
            ``alpha * h + (1 - alpha) * x`` where ``h`` is the output of the
            third dense layer followed by ``tanh``.
        """
        state = x
        # Two identical gating stages: a tanh gate picks, element-wise,
        # between ``u`` (gate -> 1) and ``v`` (gate -> 0).
        for layer in (self.dense1, self.dense2):
            gate = F.tanh(layer(state))
            state = gate * u + (1 - gate) * v

        candidate = F.tanh(self.dense3(state))
        # Adaptive residual connection controlled by the learnable alpha.
        return self.alpha * candidate + (1 - self.alpha) * x
class PirateNet(nn.Module):
    """PirateNet: Fourier-feature embedding followed by gated residual blocks.

    The input is embedded with learnable random Fourier features, two gate
    tensors ``u`` and ``v`` are derived from the embedding, the embedding is
    passed through ``num_blocks`` :class:`PirateNetBlock` instances, and a
    bias-free linear layer maps the result to ``output_dim``.
    """

    def __init__(self, input_dim, output_dim, num_blocks, hidden_dim=256, s=1.0, activation=F.tanh):
        """Build the network and apply Xavier initialisation.

        Args:
            input_dim: Size of the last dimension of the network input.
            output_dim: Size of the last dimension of the network output.
            num_blocks: Number of residual blocks.
            hidden_dim: Width of the embedding and of every hidden layer.
                Must be even, because the embedding concatenates
                ``hidden_dim // 2`` cosine and ``hidden_dim // 2`` sine
                features.
            s: Scale applied to the standard-normal samples that initialise
                the Fourier projection matrix ``B``.
            activation: Callable applied to the outputs of ``U`` and ``V``.

        Raises:
            ValueError: If ``hidden_dim`` is odd (the embedding width would
                not match the hidden layers and ``forward`` could never run).
        """
        super().__init__()
        if hidden_dim % 2 != 0:
            raise ValueError(
                f"hidden_dim must be even so that the cos/sin embedding has "
                f"width hidden_dim, got {hidden_dim}"
            )

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_blocks = num_blocks
        self.hidden_dim = hidden_dim
        self.s = s
        self.activation = activation

        # Learnable Fourier projection matrix, shape (input_dim, hidden_dim // 2).
        self.B = nn.Parameter(torch.randn(input_dim, hidden_dim // 2) * s)

        self.blocks = nn.ModuleList([PirateNetBlock(hidden_dim) for _ in range(num_blocks)])

        self.U = nn.Linear(hidden_dim, hidden_dim)
        self.V = nn.Linear(hidden_dim, hidden_dim)

        # Final layer
        self.final_layer = nn.Linear(hidden_dim, output_dim, bias=False)

        self.initialize_weights()

    def embedding(self, x):
        """Return the Fourier-feature embedding ``[cos(x @ B), sin(x @ B)]``.

        This is a method rather than a lambda stored on the instance: a
        lambda would capture the original ``self``, so ``copy.deepcopy`` would
        yield a clone that still reads the original model's ``B``, and the
        module could not be pickled.

        Args:
            x: Tensor whose last dimension is ``input_dim``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``hidden_dim``.
        """
        projected = torch.matmul(x, self.B)  # (..., input_dim) @ (input_dim, hidden_dim // 2)
        return torch.cat([torch.cos(projected), torch.sin(projected)], dim=-1)

    def initialize_weights(self):
        """Xavier-uniform every ``nn.Linear`` weight and zero every bias."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x):
        """Evaluate the network.

        Args:
            x: Tensor whose last dimension is ``input_dim``.

        Returns:
            Tensor whose last dimension is ``output_dim``.
        """
        x = self.embedding(x)
        # The gates are computed once from the embedding and shared by all blocks.
        u = self.activation(self.U(x))
        v = self.activation(self.V(x))
        for block in self.blocks:
            x = block(x, u, v)
        return self.final_layer(x)

    def initialize_last_layer(self, Y, input_data):
        """Fit the final layer to ``Y`` by least squares on the embedding.

        Solves ``embedding(input_data) @ W ~= Y`` with the pseudo-inverse and
        writes ``W.T`` into ``final_layer.weight`` in place.

        Args:
            Y: Target values, shape ``(N, output_dim)``; a 1-D tensor of
                shape ``(N,)`` is treated as ``(N, 1)``.
            input_data: Input points, shape ``(N, input_dim)``.
        """
        # no_grad: this is an initialisation step, so no autograd graph through
        # ``B`` should be built, and the in-place write to a leaf parameter
        # must not be recorded.
        with torch.no_grad():
            phi = self.embedding(input_data)
            if Y.dim() == 1:
                Y = Y.unsqueeze(-1)
            # W = torch.linalg.lstsq(phi, Y).solution
            W = torch.linalg.pinv(phi) @ Y  # (hidden_dim, output_dim)
            # copy_ keeps the parameter's shape and contiguous layout, and
            # raises on a shape mismatch instead of silently replacing the
            # tensor with a transposed view.
            self.final_layer.weight.copy_(W.T)