"""Neural ★_G in PyTorch.
A feed-forward network whose linear layers are ★_G products with weight
tensors W^(l) ∈ R^(n_{l+1} × n_l × |G|), with ReLU activations on hidden
layers and an invariant pooling on the output. Equivariance is exact by
construction (Algorithm 6, SI Section 12 of the manuscript).
"""
from __future__ import annotations
from typing import List
import torch
import torch.nn as nn
from .algebra import GroupAlgebra
from .product import starg_product
class StarGLinear(nn.Module):
    """Affine layer in the ★_G algebra: ``y = W ★_G x + b``.

    Parameters
    ----------
    in_features:
        Number of input channels.
    out_features:
        Number of output channels.
    G:
        Group algebra object; only ``G.n`` (size of the trailing group axis)
        is read here, and ``G`` itself is forwarded to ``starg_product``.
    bias:
        When true, a zero-initialised bias of shape ``(out_features, 1, G.n)``
        is learned; otherwise ``self.bias`` is registered as ``None``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        G: GroupAlgebra,
        bias: bool = True,
    ):
        super().__init__()
        self.G = G
        # Gaussian init with std sqrt(2 / (fan_in + fan_out)), i.e. a
        # Glorot-style scale computed from the channel counts only.
        # NOTE(review): G.n does not enter the scale even though the weight
        # has in_features * G.n entries per output channel -- confirm this is
        # the intended normalisation.
        fan_sum = in_features + out_features
        std = (2.0 / fan_sum) ** 0.5
        noise = torch.randn(out_features, in_features, G.n)
        self.weight = nn.Parameter(std * noise)
        if not bias:
            self.register_parameter("bias", None)
        else:
            # NOTE(review): the bias is free to vary along the group axis;
            # confirm this is compatible with the equivariance claim.
            self.bias = nn.Parameter(torch.zeros(out_features, 1, G.n))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer to a batch of shape ``(batch, in_features, n_g)``.

        Returns a tensor intended to have shape ``(batch, out_features, n_g)``.

        Raises
        ------
        ValueError
            If ``x`` is not 3-dimensional.
        """
        if x.dim() != 3:
            raise ValueError(f"expected (batch, features, n_g) input, got {x.shape}")
        batch_size = x.shape[0]
        # Each sample becomes a single-column ★_G "matrix": (B, in, 1, n_g).
        columns = x.unsqueeze(2)
        # Share one weight tensor across the batch as a view: (B, out, in, n_g).
        weights = self.weight.unsqueeze(0).expand(batch_size, -1, -1, -1)
        # starg_product is expected to yield (B, out, 1, n_g); drop the
        # singleton column axis afterwards.
        out = starg_product(weights, columns, self.G).squeeze(2)
        if self.bias is None:
            return out
        # (out, n_g) bias broadcasts over the batch axis.
        return out + self.bias.squeeze(1)
class NeuralStarG(nn.Module):
    """Feed-forward Neural ★_G with ReLU hidden activations and invariant pooling.

    Forward: input shape (batch, n_in, |G|) → output shape (batch, output_dim).

    Parameters
    ----------
    layer_sizes:
        Channel widths ``[n_0, n_1, ..., n_L]``; one ``StarGLinear`` is built
        for every consecutive pair. A single-entry list yields no ★_G layers
        (pooling only). Must not be empty.
    G:
        Group algebra object handed to every ``StarGLinear``.
    output_dim:
        Size of the output vector. With ``output_dim == 1`` the output is the
        mean over channel and group axes; otherwise per-channel group means
        are passed through a linear head.

    Raises
    ------
    ValueError
        If ``layer_sizes`` is empty.
    """

    def __init__(
        self,
        layer_sizes: List[int],
        G: GroupAlgebra,
        output_dim: int = 1,
    ):
        super().__init__()
        # Fail with a clear message instead of an IndexError on layer_sizes[-1].
        if len(layer_sizes) == 0:
            raise ValueError("layer_sizes must contain at least one feature width")
        self.G = G
        self.output_dim = output_dim
        # One ★_G linear layer per consecutive (n_in, n_out) pair, in order.
        self.layers = nn.ModuleList(
            [
                StarGLinear(n_in, n_out, G)
                for n_in, n_out in zip(layer_sizes, layer_sizes[1:])
            ]
        )
        # Always constructed (even though forward only uses it when
        # output_dim != 1) so parameter names and checkpoints stay unchanged.
        self.head = nn.Linear(layer_sizes[-1], output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the ★_G stack, then pool to a ``(batch, output_dim)`` tensor."""
        last_index = len(self.layers) - 1
        for index, layer in enumerate(self.layers):
            x = layer(x)
            # ReLU on hidden layers only; the final ★_G layer stays linear.
            if index < last_index:
                x = torch.relu(x)
        if self.output_dim == 1:
            # Scalar output: average over the feature axis and the group axis.
            return x.mean(dim=(1, 2)).unsqueeze(-1)
        # Vector output: per-channel means over the group axis, shape
        # (batch, n_features), mixed by the linear head.
        return self.head(x.mean(dim=2))