# tensor-group-sym/python/large_scale/starg_torch/neural.py
"""Neural ★_G in PyTorch.

A feed-forward network whose linear layers are ★_G products with weight
tensors W^(l) ∈ R^(n_{l+1} × n_l × |G|), with ReLU activations on hidden
layers and an invariant pooling on the output. Equivariance is exact by
construction (Algorithm 6, SI Section 12 of the manuscript).
"""

from __future__ import annotations

from typing import List

import torch
import torch.nn as nn

from .algebra import GroupAlgebra
from .product import starg_product


class StarGLinear(nn.Module):
    """Linear layer in the ★_G algebra: y = W ★_G x + b.

    Holds a weight tensor of shape ``(out_features, in_features, G.n)`` and,
    optionally, a zero-initialized bias of shape ``(out_features, 1, G.n)``.

    Args:
        in_features: Number of input feature channels.
        out_features: Number of output feature channels.
        G: Group algebra forwarded to ``starg_product``; ``G.n`` is the length
            of the trailing group axis of the weight, bias, input and output.
        bias: If ``False``, no bias parameter is registered (``self.bias`` is
            ``None``).
    """

    def __init__(self, in_features: int, out_features: int, G: GroupAlgebra, bias: bool = True):
        super().__init__()
        self.G = G
        self.in_features = in_features
        self.out_features = out_features
        # Glorot/Xavier-normal style std, sqrt(2 / (fan_in + fan_out)), computed
        # from the channel counts only: the group axis (G.n) does not enter the
        # scale even though it multiplies the number of summed terms.
        scale = (2.0 / (in_features + out_features)) ** 0.5
        self.weight = nn.Parameter(scale * torch.randn(out_features, in_features, G.n))
        if bias:
            # NOTE(review): the bias is a full (out, 1, G.n) tensor, so after
            # training it can vary along the group axis. If the group acts by
            # permuting that axis, such a bias is not equivariant; confirm
            # whether it should be constrained to be constant over the group axis.
            self.bias = nn.Parameter(torch.zeros(out_features, 1, G.n))
        else:
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``W ★_G x + b`` independently to every sample in the batch.

        Args:
            x: Tensor of shape ``(batch, in_features, G.n)``.

        Returns:
            Tensor of shape ``(batch, out_features, G.n)``.

        Raises:
            ValueError: if ``x`` is not 3-D, or its feature / group axes do not
                match the weight.
        """
        if x.dim() != 3:
            raise ValueError(f"expected (batch, features, n_g) input, got {x.shape}")
        # Read expected sizes from the weight itself so the check stays valid
        # even for instances that predate the in_features attribute.
        _, in_features, n_g = self.weight.shape
        if x.shape[1] != in_features or x.shape[2] != n_g:
            raise ValueError(
                f"expected input with {in_features} features and group size {n_g}, "
                f"got shape {tuple(x.shape)}"
            )

        batch = x.shape[0]
        # Matrix-mimetic ★_G: each sample is an (in, 1, n_g) column multiplied
        # by the (out, in, n_g) weight, giving an (out, 1, n_g) result.
        x_col = x.unsqueeze(2)  # (B, in, 1, n_g)
        w_batched = self.weight.unsqueeze(0).expand(batch, -1, -1, -1)  # (B, out, in, n_g)
        y = starg_product(w_batched, x_col, self.G).squeeze(2)  # (B, out, n_g)
        if self.bias is not None:
            y = y + self.bias.squeeze(1)  # (out, n_g), broadcast over the batch
        return y

    def extra_repr(self) -> str:
        """Summarize layer sizes for ``repr(module)``."""
        out_f, in_f, n_g = self.weight.shape
        return f"in_features={in_f}, out_features={out_f}, n_g={n_g}, bias={self.bias is not None}"


class NeuralStarG(nn.Module):
    """Feed-forward Neural ★_G with ReLU hidden activations and invariant pooling.

    Forward: input shape (batch, n_in, |G|) → output shape (batch, output_dim).

    Args:
        layer_sizes: Feature widths ``[n_in, n_1, ..., n_L]``. Each adjacent
            pair defines one ``StarGLinear`` layer, so ``len(layer_sizes) - 1``
            layers are built; a single entry builds no ★_G layers.
        G: Group algebra shared by every ★_G layer.
        output_dim: Number of outputs. With ``1`` the output is the mean over
            the feature and group axes; otherwise ``self.head`` is applied to
            the per-feature means over the group axis.

    Raises:
        ValueError: if ``layer_sizes`` is empty or ``output_dim < 1``.
    """

    def __init__(
        self,
        layer_sizes: List[int],
        G: GroupAlgebra,
        output_dim: int = 1,
    ):
        super().__init__()
        # len() rather than truthiness so array-like layer_sizes keep working.
        if len(layer_sizes) == 0:
            raise ValueError("layer_sizes must contain at least one entry")
        if output_dim < 1:
            raise ValueError(f"output_dim must be >= 1, got {output_dim}")
        self.G = G
        self.output_dim = output_dim
        self.layers = nn.ModuleList(
            StarGLinear(n_in, n_out, G)
            for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])
        )
        # Built even when output_dim == 1 (where forward does not use it) so the
        # module's parameter names stay the same as before and existing
        # state_dicts containing head.* still load.
        self.head = nn.Linear(layer_sizes[-1], output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the ★_G stack and pool to a (batch, output_dim) output."""
        last = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < last:  # ReLU on hidden layers only
                x = torch.relu(x)
        if self.output_dim == 1:
            # Invariant pooling: average over the feature axis and the group axis.
            return x.mean(dim=(1, 2)).unsqueeze(-1)  # (batch, 1)
        # Per-channel means over the group axis feed the small linear head.
        return self.head(x.mean(dim=2))  # (batch, output_dim)