"""Neural ★_G in PyTorch. A feed-forward network whose linear layers are ★_G products with weight tensors W^(l) ∈ R^(n_{l+1} × n_l × |G|), with ReLU activations on hidden layers and an invariant pooling on the output. Equivariance is exact by construction (Algorithm 6, SI Section 12 of the manuscript). """ from __future__ import annotations from typing import List import torch import torch.nn as nn from .algebra import GroupAlgebra from .product import starg_product class StarGLinear(nn.Module): """Linear layer in the ★_G algebra: y = W ★_G x + b.""" def __init__(self, in_features: int, out_features: int, G: GroupAlgebra, bias: bool = True): super().__init__() self.G = G # He initialization scaled for the in-features × n_g parameter footprint scale = (2.0 / (in_features + out_features)) ** 0.5 self.weight = nn.Parameter(scale * torch.randn(out_features, in_features, G.n)) if bias: self.bias = nn.Parameter(torch.zeros(out_features, 1, G.n)) else: self.register_parameter("bias", None) def forward(self, x: torch.Tensor) -> torch.Tensor: # x: (batch, in_features, n_g) → expand to (batch, in_features, 1, n_g)? # We operate in the matrix-mimetic ★_G sense: for each sample, # treat W as (out, in, n) and x as (in, 1, n), product gives (out, 1, n). # Vectorize across batch. if x.dim() == 3: B = x.shape[0] x_b = x.unsqueeze(2) # (B, in, 1, n_g) W_b = self.weight.unsqueeze(0).expand(B, -1, -1, -1) # (B, out, in, n_g) y = starg_product(W_b, x_b, self.G).squeeze(2) # (B, out, n_g) else: raise ValueError(f"expected (batch, features, n_g) input, got {x.shape}") if self.bias is not None: y = y + self.bias.squeeze(1) # (out, n_g) broadcast return y class NeuralStarG(nn.Module): """Feed-forward Neural ★_G with ReLU hidden activations and invariant pooling. Forward: input shape (batch, n_in, |G|) → output shape (batch, output_dim). """ def __init__( self, layer_sizes: List[int], G: GroupAlgebra, output_dim: int = 1, ): super().__init__() self.G = G self.output_dim = output_dim layers: List[nn.Module] = [] for i in range(len(layer_sizes) - 1): layers.append(StarGLinear(layer_sizes[i], layer_sizes[i + 1], G)) self.layers = nn.ModuleList(layers) self.head = nn.Linear(layer_sizes[-1], output_dim) def forward(self, x: torch.Tensor) -> torch.Tensor: for i, layer in enumerate(self.layers): x = layer(x) if i < len(self.layers) - 1: x = torch.relu(x) # Invariant pooling: average over the group axis and the feature axis invariant = x.mean(dim=(1, 2)) # (batch,) # If output_dim > 1, use a small linear head on invariant features # constructed from per-channel means. if self.output_dim == 1: return invariant.unsqueeze(-1) chan = x.mean(dim=2) # (batch, n_features) return self.head(chan)