# This architecture comes from the PirateNets paper (https://arxiv.org/pdf/2402.00326.pdf)
import torch
import torch.nn as nn


class PirateNetBlock(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.dense1 = nn.Linear(hidden_dim, hidden_dim)
        self.dense2 = nn.Linear(hidden_dim, hidden_dim)
        self.dense3 = nn.Linear(hidden_dim, hidden_dim)
        # Adaptive skip weight: alpha = 0 at init makes the block an identity map.
        self.alpha = nn.Parameter(torch.zeros(1))

    def forward(self, x, u, v):
        # Gated updates that mix in the shared embeddings u and v.
        f = torch.tanh(self.dense1(x))
        z1 = f * u + (1 - f) * v
        g = torch.tanh(self.dense2(z1))
        z2 = g * u + (1 - g) * v
        h = torch.tanh(self.dense3(z2))
        # Adaptive residual connection controlled by alpha.
        return self.alpha * h + (1 - self.alpha) * x
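
# Illustrative sanity check (not part of the original file): since alpha = 0
# at initialization, a freshly constructed block is exactly the identity map:
#   blk = PirateNetBlock(8)
#   x, u, v = torch.randn(3, 4, 8)
#   assert torch.allclose(blk(x, u, v), x)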


class PirateNet(nn.Module):
    def __init__(self, input_dim, output_dim, num_blocks, hidden_dim=256, s=1.0, activation=torch.tanh):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_blocks = num_blocks
        self.hidden_dim = hidden_dim
        self.s = s
        self.activation = activation
        # Random Fourier feature embedding: x -> [cos(xB), sin(xB)], scaled by s.
        self.B = nn.Parameter(torch.randn(input_dim, hidden_dim // 2) * s)
        self.embedding = lambda x: torch.cat([torch.cos(torch.matmul(x, self.B)),
                                              torch.sin(torch.matmul(x, self.B))], dim=-1)
        # (N, input_dim) x (input_dim, hidden_dim // 2) --> (N, hidden_dim // 2) --cat--> (N, hidden_dim)
        self.blocks = nn.ModuleList([PirateNetBlock(hidden_dim) for _ in range(num_blocks)])
        # U and V produce the two gating embeddings shared by every block.
        self.U = nn.Linear(hidden_dim, hidden_dim)
        self.V = nn.Linear(hidden_dim, hidden_dim)
        # Final layer
        self.final_layer = nn.Linear(hidden_dim, output_dim, bias=False)
        self.initialize_weights()

    def initialize_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    module.bias.data.zero_()

    def forward(self, x):
        x = self.embedding(x)
        u = self.activation(self.U(x))
        v = self.activation(self.V(x))
        for block in self.blocks:
            x = block(x, u, v)
        return self.final_layer(x)

    def initialize_last_layer(self, Y, input_data):
        # With alpha = 0 every block is the identity at init, so the network
        # output reduces to final_layer(embedding(x)); fit the last layer to Y
        # by least squares on the embedded inputs.
        with torch.no_grad():
            phi = self.embedding(input_data)
            W = torch.linalg.pinv(phi) @ Y  # or: torch.linalg.lstsq(phi, Y).solution
            self.final_layer.weight.copy_(W.T)
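

# Minimal usage sketch (illustrative assumptions: the dimensions and random
# data below are placeholders, not from the original source or the paper).
if __name__ == "__main__":
    model = PirateNet(input_dim=2, output_dim=1, num_blocks=3)
    x = torch.randn(128, 2)            # batch of 2-D inputs
    y = torch.randn(128, 1)            # target values for the least-squares init
    model.initialize_last_layer(y, x)  # initialize the final layer from data
    print(model(x).shape)              # torch.Size([128, 1])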