# tensor-group-sym / python / NeuralStarGFramework.py
"""
NeuralStarGFramework.py - Neural Network Framework Built on ★_G Algebra
LH & SU & Claude 2026

A neural network implementation that uses the Star-G algebra for
equivariant tensor operations.
"""

import numpy as np
from typing import List, Tuple, Optional, Dict, Callable, Any
import time
from dataclasses import dataclass, field


@dataclass
class TrainingHistory:
    """Per-epoch metrics collected while a network trains.

    Every list starts empty (a fresh list per instance, never shared) and
    grows by one value per finished epoch during training.
    """

    # Average of the per-mini-batch MSE values within one epoch.
    train_loss: List[float] = field(default_factory=list)
    # MSE on the validation set, measured after the epoch.
    val_loss: List[float] = field(default_factory=list)
    # R^2 on the training set, measured after the epoch.
    train_r2: List[float] = field(default_factory=list)
    # R^2 on the validation set, measured after the epoch.
    val_r2: List[float] = field(default_factory=list)


@dataclass
class Gradients:
    """Per-layer gradient arrays, index-aligned with a network's parameters.

    ``weights[l]`` is meant to match the shape of the layer-``l`` weight
    tensor and ``biases[l]`` that of its bias tensor. Both lists default to
    a fresh empty list per instance.
    """

    # One array per layer, shaped like the matching weight tensor.
    weights: List[np.ndarray] = field(default_factory=list)
    # One array per layer, shaped like the matching bias tensor.
    biases: List[np.ndarray] = field(default_factory=list)


class NeuralStarGFramework:
    """
    Neural Network Framework Built on ★_G Algebra.

    Implements a multi-layer neural network where weight matrices are
    3D tensors and matrix multiplication is replaced by the ★_G product
    (``G.star_g``), providing group equivariance.

    Training estimates gradients with randomly sampled central finite
    differences (see ``backward``). It then applies an Adam step to weights
    and biases, plus multiplicative weight decay on the weights
    (see ``adam_update``).

    Parameters
    ----------
    G : StarGAlgebra
        The group algebra instance defining the ★_G product. This class
        reads ``G.n`` (group order) and calls ``G.star_g(A, B)``;
        ``compress_weights`` additionally calls ``G.truncate(W, rank)``.
    layer_sizes : list of int
        Sizes of each layer including input and output (at least two
        entries).
    learning_rate : float, optional
        Learning rate for the Adam optimizer (default: 0.001)
    use_gpu : bool, optional
        Stored as ``self.use_gpu``; no computation in this class reads it
        yet (default: False)

    Attributes
    ----------
    weights : list of ndarray
        Weight tensors for each layer, shape (out_dim, in_dim, n_group)
    biases : list of ndarray
        Bias tensors for each layer, shape (out_dim, 1, n_group)
    activations : list of callable
        ReLU for every hidden layer, identity for the output layer
    history : TrainingHistory
        Training metrics history

    Raises
    ------
    ValueError
        If ``layer_sizes`` has fewer than two entries.

    Example
    -------
    >>> from star_g_algebra import StarGAlgebra
    >>> G = StarGAlgebra('cyclic', 8)
    >>> net = NeuralStarGFramework(G, [16, 32, 16, 1], learning_rate=0.001)
    >>> net.train(X_train, Y_train, X_val, Y_val, epochs=100)
    >>> predictions, _ = net.predict(X_test)
    """

    def __init__(
        self,
        G: Any,  # StarGAlgebra
        layer_sizes: List[int],
        learning_rate: float = 0.001,
        use_gpu: bool = False
    ):
        if len(layer_sizes) < 2:
            raise ValueError(
                "layer_sizes needs at least an input and an output size, "
                f"got {list(layer_sizes)!r}"
            )

        self.G = G
        # Copy so later mutation of the caller's list cannot desync the net.
        self.layers = list(layer_sizes)
        self.learning_rate = learning_rate
        self.use_gpu = use_gpu  # currently unused; kept for API compatibility
        self.weight_decay = 1e-4
        self.t = 0  # Adam timestep (number of adam_update calls so far)

        n_layers = len(self.layers) - 1

        # Weight and bias tensors, one of each per layer
        self.weights: List[np.ndarray] = []
        self.biases: List[np.ndarray] = []

        # Adam first (m) and second (v) moment estimates, shaped like params
        self.m_weights: List[np.ndarray] = []
        self.v_weights: List[np.ndarray] = []
        self.m_biases: List[np.ndarray] = []
        self.v_biases: List[np.ndarray] = []

        # Activation functions: ReLU for hidden layers, linear for output
        self.activations: List[Callable] = self._build_activations(n_layers)

        # Xavier/Glorot (normal) initialization
        for fan_in, fan_out in zip(self.layers[:-1], self.layers[1:]):
            scale = np.sqrt(2.0 / (fan_in + fan_out))

            # Weight tensor: (out_dim, in_dim, n_group)
            W = scale * np.random.randn(fan_out, fan_in, G.n)
            # Bias tensor: (out_dim, 1, n_group)
            b = np.zeros((fan_out, 1, G.n))

            self.weights.append(W)
            self.biases.append(b)
            self.m_weights.append(np.zeros_like(W))
            self.v_weights.append(np.zeros_like(W))
            self.m_biases.append(np.zeros_like(b))
            self.v_biases.append(np.zeros_like(b))

        # Training history
        self.history = TrainingHistory()

    # ------------------------------------------------------------------
    # Activations
    # ------------------------------------------------------------------

    @staticmethod
    def _relu(x: np.ndarray) -> np.ndarray:
        """Element-wise ReLU used for hidden layers."""
        return np.maximum(0, x)

    @staticmethod
    def _identity(x: np.ndarray) -> np.ndarray:
        """Linear (identity) activation used for the output layer."""
        return x

    @classmethod
    def _build_activations(cls, n_layers: int) -> List[Callable]:
        """Return ReLU for the ``n_layers - 1`` hidden layers plus a linear output."""
        return [cls._relu] * max(n_layers - 1, 0) + [cls._identity]

    # ------------------------------------------------------------------
    # Inference
    # ------------------------------------------------------------------

    def forward(self, X: np.ndarray) -> Tuple[np.ndarray, List[np.ndarray]]:
        """
        Forward pass through the network.

        Parameters
        ----------
        X : ndarray
            Input tensor of shape (batch_size, n_features, n_group) or
            (n_features, n_group) for single sample. A sample whose first
            axis does not match the layer's input size is transposed
            before use.

        Returns
        -------
        output : ndarray
            Network output of shape (batch_size, out_dim, n_group)
        cache : list of ndarray
            Input followed by the post-activation output of every layer
        """
        # Promote a single sample (n_features, n_group) to a batch of one
        if X.ndim == 2:
            X = X.reshape(1, X.shape[0], self.G.n)
        batch_size = X.shape[0]

        # Cache stores activations for backpropagation / inspection
        cache: List[np.ndarray] = [X]
        A = X

        for l, (W, b) in enumerate(zip(self.weights, self.biases)):
            out_dim, in_dim, n_g = W.shape
            Z = np.zeros((batch_size, out_dim, n_g))

            for i in range(batch_size):
                A_i = A[i, :, :]
                # Accept samples laid out as (n_group, in_dim) by transposing
                if A_i.shape[0] != in_dim:
                    A_i = A_i.T

                # Star-G product: W (out_dim, in_dim, n_g) ★ A_i (in_dim, 1, n_g)
                Z_i = self.G.star_g(W, A_i.reshape(in_dim, 1, n_g)) + b
                Z[i, :, :] = Z_i.squeeze(axis=1)  # (out_dim, n_g)

            A = self.activations[l](Z)
            cache.append(A)

        return A, cache

    def invariant_pool(self, X: np.ndarray) -> np.ndarray:
        """
        Pool across output and group dimensions to get one value per sample.

        Averages over the group axis (axis 2) and then over the output axis
        (axis 1).

        Parameters
        ----------
        X : ndarray of shape (batch_size, out_dim, n_group)

        Returns
        -------
        y : ndarray of shape (batch_size,)
        """
        return np.mean(np.mean(X, axis=2), axis=1)

    def predict(self, X: np.ndarray) -> Tuple[np.ndarray, List[np.ndarray]]:
        """
        Make predictions on input data.

        Parameters
        ----------
        X : ndarray
            Input tensor (see ``forward`` for accepted shapes)

        Returns
        -------
        y_pred : ndarray
            Predictions of shape (batch_size,); squeezed, so a batch of one
            yields a 0-d array
        cache : list of ndarray
            Cached activations from ``forward``
        """
        output, cache = self.forward(X)
        y_pred = self.invariant_pool(output).squeeze()
        return y_pred, cache

    def compute_loss(self, y_pred: np.ndarray, y_true: np.ndarray) -> float:
        """
        Compute mean squared error loss.

        Both arrays are flattened, so any shapes with matching sizes work.

        Returns
        -------
        loss : float
            MSE loss value
        """
        return float(np.mean((y_pred.flatten() - y_true.flatten()) ** 2))

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------

    def backward(
        self,
        X: np.ndarray,
        y_true: np.ndarray,
        epsilon: float = 1e-5,
        max_samples: int = 30,
    ) -> "Gradients":
        """
        Estimate loss gradients with sampled central finite differences.

        For every weight tensor and every bias tensor:

        1. Up to ``max_samples`` entries are drawn uniformly without
           replacement.
        2. Each drawn entry gets (L(p + eps) - L(p - eps)) / (2 eps); the
           other entries get 0.
        3. The tensor is scaled by ``size / n_drawn``. This makes the
           estimate unbiased for the full gradient.

        Each drawn entry costs two forward passes.

        Parameters
        ----------
        X : ndarray
            Input batch
        y_true : ndarray
            True target values
        epsilon : float, optional
            Finite-difference step (default: 1e-5)
        max_samples : int, optional
            Maximum number of entries probed per tensor (default: 30)

        Returns
        -------
        grads : Gradients
            Container with weight and bias gradients
        """
        grads = Gradients()
        grads.weights = [
            self._sampled_fd_gradient(W, X, y_true, epsilon, max_samples)
            for W in self.weights
        ]
        grads.biases = [
            self._sampled_fd_gradient(b, X, y_true, epsilon, max_samples)
            for b in self.biases
        ]
        return grads

    def _sampled_fd_gradient(
        self,
        param: np.ndarray,
        X: np.ndarray,
        y_true: np.ndarray,
        epsilon: float,
        max_samples: int,
    ) -> np.ndarray:
        """Finite-difference gradient of the loss w.r.t. a random subset of ``param``.

        ``param`` must be one of the arrays held in ``self.weights`` or
        ``self.biases``; it is perturbed in place so that ``predict`` sees the
        change. Every perturbed entry is restored, even if ``predict`` raises.
        """
        grad = np.zeros_like(param)
        total = param.size
        n_sample = min(max_samples, total)
        if n_sample <= 0:
            return grad

        for flat_idx in np.random.choice(total, n_sample, replace=False):
            idx = np.unravel_index(flat_idx, param.shape)
            original = param[idx]
            try:
                param[idx] = original + epsilon
                loss_plus = self.compute_loss(self.predict(X)[0], y_true)

                param[idx] = original - epsilon
                loss_minus = self.compute_loss(self.predict(X)[0], y_true)
            finally:
                param[idx] = original

            grad[idx] = (loss_plus - loss_minus) / (2 * epsilon)

        # Each entry is probed with probability n_sample/total; rescaling makes
        # the sparse estimate unbiased for the full gradient.
        return grad * (total / n_sample)

    def _adam_step(
        self,
        param: np.ndarray,
        grad: np.ndarray,
        m: np.ndarray,
        v: np.ndarray,
        beta1: float,
        beta2: float,
        eps: float,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """One bias-corrected Adam step; returns (new_param, new_m, new_v)."""
        m = beta1 * m + (1 - beta1) * grad
        v = beta2 * v + (1 - beta2) * (grad ** 2)
        m_hat = m / (1 - beta1 ** self.t)
        v_hat = v / (1 - beta2 ** self.t)
        new_param = param - self.learning_rate * m_hat / (np.sqrt(v_hat) + eps)
        return new_param, m, v

    def adam_update(self, grads: "Gradients"):
        """
        Update weights and biases using the Adam optimizer.

        After each Adam step, the weights are also multiplied by
        ``1 - self.weight_decay``. This decay is decoupled from the gradient
        and is not scaled by the learning rate. Biases are not decayed.

        Parameters
        ----------
        grads : Gradients
            Computed gradients. Bias tensors without a matching entry in
            ``grads.biases`` (e.g. an empty list) are left unchanged.
        """
        beta1 = 0.9
        beta2 = 0.999
        eps = 1e-8

        self.t += 1
        bias_grads = grads.biases

        for l in range(len(self.weights)):
            self.weights[l], self.m_weights[l], self.v_weights[l] = self._adam_step(
                self.weights[l], grads.weights[l],
                self.m_weights[l], self.v_weights[l],
                beta1, beta2, eps,
            )
            # Apply weight decay
            self.weights[l] = self.weights[l] * (1 - self.weight_decay)

            if l < len(bias_grads):
                self.biases[l], self.m_biases[l], self.v_biases[l] = self._adam_step(
                    self.biases[l], bias_grads[l],
                    self.m_biases[l], self.v_biases[l],
                    beta1, beta2, eps,
                )

    def train(
        self,
        X_train: np.ndarray,
        Y_train: np.ndarray,
        X_val: np.ndarray,
        Y_val: np.ndarray,
        epochs: int = 100,
        batch_size: int = 32,
        verbose: bool = True,
        patience: int = 20
    ) -> 'NeuralStarGFramework':
        """
        Train the neural network with mini-batch Adam and early stopping.

        The history is reset at the start. When training ends, the
        parameters from the epoch with the lowest validation loss are
        restored. Adam moment estimates are not restored.

        Parameters
        ----------
        X_train : ndarray of shape (n_train, n_features, n_group)
            Training input data
        Y_train : ndarray of shape (n_train,)
            Training targets
        X_val : ndarray of shape (n_val, n_features, n_group)
            Validation input data
        Y_val : ndarray of shape (n_val,)
            Validation targets
        epochs : int, optional
            Maximum number of training epochs (default: 100)
        batch_size : int, optional
            Mini-batch size (default: 32)
        verbose : bool, optional
            Whether to print training progress (default: True)
        patience : int, optional
            Early stopping patience (default: 20)

        Returns
        -------
        self : NeuralStarGFramework
            The trained network

        Raises
        ------
        ValueError
            If ``X_train`` is empty or ``batch_size`` is smaller than 1.
        """
        n_train = X_train.shape[0]
        if n_train == 0:
            raise ValueError("X_train must contain at least one sample")
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        n_batches = int(np.ceil(n_train / batch_size))

        # Reset history
        self.history = TrainingHistory()

        best_val_loss = np.inf
        patience_counter = 0
        best_weights = [W.copy() for W in self.weights]
        best_biases = [b.copy() for b in self.biases]

        if verbose:
            print("\nTraining Neural Star_G Network")
            print(f"Epochs: {epochs}, Batch: {batch_size}, LR: {self.learning_rate:.4f}\n")

        for epoch in range(1, epochs + 1):
            start_time = time.time()

            # Shuffle training data
            perm = np.random.permutation(n_train)
            epoch_loss = 0.0

            for batch in range(n_batches):
                batch_idx = perm[batch * batch_size:(batch + 1) * batch_size]
                X_batch = X_train[batch_idx]
                Y_batch = Y_train[batch_idx]

                # Loss before this batch's update (reported as training loss)
                y_pred, _ = self.predict(X_batch)
                epoch_loss += self.compute_loss(y_pred, Y_batch)

                # Gradient estimate and parameter update
                grads = self.backward(X_batch, Y_batch)
                self.adam_update(grads)

            epoch_time = time.time() - start_time

            # Record training loss (unweighted mean over mini-batches)
            train_loss = epoch_loss / n_batches
            self.history.train_loss.append(train_loss)

            # Validation metrics
            y_val_pred, _ = self.predict(X_val)
            val_loss = self.compute_loss(y_val_pred, Y_val)
            self.history.val_loss.append(val_loss)

            # Compute R² scores
            y_train_pred, _ = self.predict(X_train)
            train_r2 = self._compute_r2(y_train_pred, Y_train)
            val_r2 = self._compute_r2(y_val_pred, Y_val)
            self.history.train_r2.append(train_r2)
            self.history.val_r2.append(val_r2)

            # Early stopping bookkeeping
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_weights = [W.copy() for W in self.weights]
                best_biases = [b.copy() for b in self.biases]
                patience_counter = 0
            else:
                patience_counter += 1

            if verbose and epoch % 5 == 0:
                print(f"Epoch {epoch:3d}: Loss={train_loss:.4f}, "
                      f"R2={train_r2:.4f}, Val R2={val_r2:.4f} ({epoch_time:.1f}s)")

            if patience_counter >= patience:
                if verbose:
                    print(f"Early stopping at epoch {epoch}")
                break

        # Restore best weights
        self.weights = best_weights
        self.biases = best_biases

        # Guard: with epochs < 1 there is no history to summarize.
        if verbose and self.history.val_r2:
            best_val_r2 = max(self.history.val_r2)
            print(f"\nBest Val R2: {best_val_r2:.4f}\n")

        return self

    def _compute_r2(self, y_pred: np.ndarray, y_true: np.ndarray) -> float:
        """Compute R² (coefficient of determination)."""
        y_pred = y_pred.flatten()
        y_true = y_true.flatten()
        ss_res = np.sum((y_pred - y_true) ** 2)
        # Small constant avoids division by zero for constant targets
        ss_tot = np.sum((y_true - np.mean(y_true)) ** 2) + 1e-10
        return float(1 - ss_res / ss_tot)

    # ------------------------------------------------------------------
    # Utilities
    # ------------------------------------------------------------------

    def compress_weights(self, rank: int):
        """
        Replace each weight tensor with ``self.G.truncate(W, rank)``.

        ``G.truncate`` is presumably a truncated ★_G-SVD; its exact behavior
        is defined by the algebra object. The relative Frobenius error per
        layer is printed.

        Parameters
        ----------
        rank : int
            Target rank for truncation
        """
        print(f"Compressing to rank {rank}...")

        for l, W_orig in enumerate(self.weights):
            W_comp = self.G.truncate(W_orig, rank)

            orig_norm = np.linalg.norm(W_orig) + 1e-10
            err = np.linalg.norm(W_orig - W_comp) / orig_norm

            self.weights[l] = W_comp
            print(f"Layer {l + 1}: {err * 100:.2f}% error")

    def count_parameters(self) -> int:
        """Count total number of trainable parameters (weights + biases)."""
        return sum(W.size + b.size for W, b in zip(self.weights, self.biases))

    def summary(self):
        """Print network architecture summary."""
        print("\n" + "=" * 50)
        print("Neural Star-G Network Summary")
        print("=" * 50)
        print(f"Group order: {self.G.n}")
        print(f"Layer sizes: {self.layers}")
        print(f"Learning rate: {self.learning_rate}")
        print(f"Weight decay: {self.weight_decay}")
        print("-" * 50)

        total_params = 0
        for l, (W, b) in enumerate(zip(self.weights, self.biases)):
            layer_params = W.size + b.size
            total_params += layer_params

            print(f"Layer {l + 1}:")
            print(f"  Weight shape: {W.shape}")
            print(f"  Bias shape:   {b.shape}")
            print(f"  Parameters:   {layer_params:,}")

        print("-" * 50)
        print(f"Total parameters: {total_params:,}")
        print("=" * 50 + "\n")

    def save(self, filepath: str):
        """
        Save network parameters and Adam state to an ``.npz`` file.

        ``np.savez`` appends ``.npz`` to ``filepath`` if it lacks that suffix.

        Parameters
        ----------
        filepath : str
            Path to save file (.npz format)
        """
        save_dict = {
            'layers': np.array(self.layers),
            'learning_rate': self.learning_rate,
            'weight_decay': self.weight_decay,
            't': self.t,
        }

        for l in range(len(self.weights)):
            save_dict[f'weight_{l}'] = self.weights[l]
            save_dict[f'bias_{l}'] = self.biases[l]
            save_dict[f'm_weight_{l}'] = self.m_weights[l]
            save_dict[f'v_weight_{l}'] = self.v_weights[l]
            save_dict[f'm_bias_{l}'] = self.m_biases[l]
            save_dict[f'v_bias_{l}'] = self.v_biases[l]

        np.savez(filepath, **save_dict)
        print(f"Model saved to {filepath}")

    def load(self, filepath: str):
        """
        Load network parameters and Adam state written by ``save``.

        The architecture (``layers``) is taken from the file and the
        activation list is rebuilt to match it. The archive is closed before
        returning.

        Parameters
        ----------
        filepath : str
            Path to saved file (.npz format)
        """
        # Context manager closes the underlying zip file; arrays read from an
        # NpzFile are independent in-memory copies, so they outlive it.
        with np.load(filepath) as data:
            self.layers = data['layers'].tolist()
            self.learning_rate = float(data['learning_rate'])
            self.weight_decay = float(data['weight_decay'])
            self.t = int(data['t'])

            n_layers = len(self.layers) - 1
            layer_ids = range(n_layers)

            self.weights = [data[f'weight_{l}'] for l in layer_ids]
            self.biases = [data[f'bias_{l}'] for l in layer_ids]
            self.m_weights = [data[f'm_weight_{l}'] for l in layer_ids]
            self.v_weights = [data[f'v_weight_{l}'] for l in layer_ids]
            self.m_biases = [data[f'm_bias_{l}'] for l in layer_ids]
            self.v_biases = [data[f'v_bias_{l}'] for l in layer_ids]

        # Keep activations consistent with the (possibly different) depth.
        self.activations = self._build_activations(n_layers)

        print(f"Model loaded from {filepath}")


# =============================================================================
# Convenience Functions
# =============================================================================

def create_star_g_network(
    G: Any,
    input_dim: int,
    hidden_dims: List[int],
    output_dim: int,
    **kwargs
) -> NeuralStarGFramework:
    """
    Convenience function to create a Star-G neural network.

    Builds ``layer_sizes = [input_dim, *hidden_dims, output_dim]`` and
    forwards it to ``NeuralStarGFramework``.

    Parameters
    ----------
    G : StarGAlgebra
        Group algebra instance
    input_dim : int
        Input feature dimension
    hidden_dims : iterable of int
        Hidden layer dimensions (list, tuple or any iterable; may be empty)
    output_dim : int
        Output dimension
    **kwargs
        Additional arguments passed to NeuralStarGFramework

    Returns
    -------
    net : NeuralStarGFramework
        Initialized network
    """
    # Unpacking (rather than list concatenation) accepts tuples, ranges and
    # other iterables, not only lists.
    layer_sizes = [input_dim, *hidden_dims, output_dim]
    return NeuralStarGFramework(G, layer_sizes, **kwargs)


# =============================================================================
# Example Usage and Tests
# =============================================================================

if __name__ == "__main__":
    import os
    import tempfile

    # Import the StarGAlgebra (assuming it's in the same directory)
    try:
        from star_g_algebra import StarGAlgebra
    except ImportError:
        print("StarGAlgebra not found. Creating minimal mock for testing.")

        class StarGAlgebra:
            """Minimal mock for testing."""
            def __init__(self, group_type, n):
                self.n = n
                self.is_cyclic = True

            def star_g(self, A, B):
                """Simple FFT-based star product for cyclic groups."""
                from numpy.fft import fft, ifft
                Ahat = fft(A, axis=2)
                Bhat = fft(B, axis=2)
                Chat = np.zeros((A.shape[0], B.shape[1], A.shape[2]), dtype=complex)
                for k in range(A.shape[2]):
                    Chat[:, :, k] = Ahat[:, :, k] @ Bhat[:, :, k]
                C = ifft(Chat, axis=2)
                return np.real(C)

            def truncate(self, A, k):
                """No-op stand-in: returns A unchanged (no real truncation)."""
                return A

    print("=" * 60)
    print("NeuralStarGFramework Test Suite")
    print("=" * 60)

    # Create group algebra
    n_group = 8
    G = StarGAlgebra('cyclic', n_group)
    print(f"\nGroup order: {G.n}")

    # Create network
    input_dim = 16
    hidden_dims = [32, 16]
    output_dim = 1
    layer_sizes = [input_dim] + hidden_dims + [output_dim]

    net = NeuralStarGFramework(G, layer_sizes, learning_rate=0.001)

    # Print summary
    net.summary()

    # Generate synthetic data: target is the sum of per-feature group means
    print("Generating synthetic data...")
    n_train = 100
    n_val = 20

    X_train = np.random.randn(n_train, input_dim, n_group)
    Y_train = np.sum(np.mean(X_train, axis=2), axis=1) + np.random.randn(n_train) * 0.1

    X_val = np.random.randn(n_val, input_dim, n_group)
    Y_val = np.sum(np.mean(X_val, axis=2), axis=1) + np.random.randn(n_val) * 0.1

    print(f"X_train shape: {X_train.shape}")
    print(f"Y_train shape: {Y_train.shape}")

    # Test forward pass
    print("\nTesting forward pass...")
    output, cache = net.forward(X_train[:5])
    print(f"Output shape: {output.shape}")

    # Test prediction
    print("\nTesting prediction...")
    y_pred, _ = net.predict(X_train[:5])
    print(f"Predictions shape: {y_pred.shape}")
    print(f"Sample predictions: {y_pred[:3]}")

    # Test training (short run)
    print("\nTesting training (10 epochs)...")
    net.train(
        X_train, Y_train,
        X_val, Y_val,
        epochs=10,
        batch_size=16,
        verbose=True,
        patience=5
    )

    # Test compression
    print("\nTesting weight compression...")
    net.compress_weights(rank=2)

    # Test save/load inside a temporary directory so the model file is
    # always cleaned up, even if loading raises.
    print("\nTesting save/load...")
    with tempfile.TemporaryDirectory() as tmp_dir:
        model_path = os.path.join(tmp_dir, "test_model.npz")
        net.save(model_path)

        net2 = NeuralStarGFramework(G, layer_sizes)
        net2.load(model_path)

    # Verify loaded model gives same predictions
    y_pred1, _ = net.predict(X_val[:3])
    y_pred2, _ = net2.predict(X_val[:3])
    match = bool(np.allclose(y_pred1, y_pred2))

    print(f"Original predictions: {y_pred1}")
    print(f"Loaded predictions:   {y_pred2}")
    print(f"Match: {match}")

    print("\n" + "=" * 60)
    if match:
        print("All tests completed successfully!")
    else:
        print("Save/load round trip FAILED: predictions differ.")
    print("=" * 60)

    # Signal failure to the calling shell / CI instead of always exiting 0.
    if not match:
        raise SystemExit(1)