# tensor-group-sym / python / starG_helpers.py
"""
starG_helpers.py
Utility functions for ★_G comparison experiments
LH & SU & Claude 2026
"""

import numpy as np
from scipy.spatial.distance import pdist
from typing import List, Tuple, Union


class StarGHelpers:
    """
    Utility functions for ★_G comparison experiments.

    All methods are static and can be called without instantiation.
    """

    @staticmethod
    def compute_r2(y_pred: np.ndarray, y_true: np.ndarray) -> float:
        """
        Compute R² (coefficient of determination) score.

        Both inputs are flattened first, so e.g. shapes ``(n,)`` and
        ``(n, 1)`` can be mixed.

        Parameters
        ----------
        y_pred : array-like
            Predicted values
        y_true : array-like
            True values

        Returns
        -------
        r2 : float
            R² score (1.0 is perfect prediction; predicting the mean of
            ``y_true`` gives a score of approximately 0.0)

        Notes
        -----
        A small epsilon (1e-10) is added to the total sum of squares so a
        constant ``y_true`` does not cause a division by zero; in that case
        any non-zero residual produces a very large negative score.
        """
        y_pred = np.asarray(y_pred).flatten()
        y_true = np.asarray(y_true).flatten()

        ss_res = np.sum((y_pred - y_true) ** 2)
        ss_tot = np.sum((y_true - np.mean(y_true)) ** 2) + 1e-10

        return 1 - ss_res / ss_tot

    @staticmethod
    def count_mlp_params(layers: List[int]) -> int:
        """
        Count total parameters in a Multi-Layer Perceptron.

        Parameters
        ----------
        layers : list of int
            Number of neurons in each layer (including input and output)

        Returns
        -------
        n_params : int
            Total number of trainable parameters (weights + biases).
            Fewer than two layers yields 0.

        Example
        -------
        >>> StarGHelpers.count_mlp_params([784, 256, 128, 10])
        235146

        i.e. (784*256 + 256) + (256*128 + 128) + (128*10 + 10).
        """
        # Each consecutive pair (n_in, n_out) contributes an n_in x n_out
        # weight matrix plus n_out biases.
        return sum(n_in * n_out + n_out for n_in, n_out in zip(layers, layers[1:]))

    @staticmethod
    def compute_invariant_features(X: np.ndarray) -> np.ndarray:
        """
        Compute rotation-invariant features from 3D tensor.

        Computes statistics (mean, std, min, max) across the group dimension
        (axis 2). These statistics do not depend on the order of the group
        elements, so any permutation of axis 2 yields the same output.

        Parameters
        ----------
        X : ndarray of shape (n_samples, n_features, n_rotations)
            Input tensor with group structure

        Returns
        -------
        X_inv : ndarray of shape (n_samples, 4 * n_features)
            Invariant features: [mean, std, min, max] concatenated
        """
        X_mean = np.mean(X, axis=2)
        X_std = np.std(X, axis=2, ddof=0)  # ddof=0 matches MATLAB's default
        X_min = np.min(X, axis=2)
        X_max = np.max(X, axis=2)

        return np.concatenate([X_mean, X_std, X_min, X_max], axis=1)

    @staticmethod
    def generate_molecular_data(
        n_samples: int,
        n_feat: int,
        n_rot: int,
        random_state: Union[None, int, np.random.RandomState] = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Generate synthetic molecular data with rotational symmetry.

        Creates molecules with random atom positions and atomic numbers,
        then generates features under different rotations around the z-axis.

        Parameters
        ----------
        n_samples : int
            Number of molecules to generate
        n_feat : int
            Number of features per rotation. Values below 5 keep only the
            first ``n_feat`` basic features (earlier versions raised
            IndexError in that case).
        n_rot : int
            Number of rotations (group order)
        random_state : None, int or numpy.random.RandomState, optional
            Source of randomness. ``None`` (default) draws from NumPy's
            global RNG, exactly as earlier versions did. An int seeds a fresh
            ``RandomState``. A ``RandomState`` instance is used as-is.

        Returns
        -------
        X : ndarray of shape (n_samples, n_feat, n_rot)
            Feature tensor with rotational structure
        Y : ndarray of shape (n_samples, 1)
            Target values (rotation-invariant molecular property)

        Notes
        -----
        The target Y is computed from rotation-invariant quantities:
        - Mean pairwise distance between atoms
        - Standard deviation of pairwise distances (weighted by 0.3)
        - Sum of squared atomic numbers (divided by 500)

        Feature layout along axis 1, for rotated positions (x, y, z), radius
        r and atomic number Z, each summed over atoms:
        - 0..2: Z*x, Z*y, Z*z
        - 3, 4: Z*r, Z*r**2
        - 5..  : Z*cos(k*atan2(y, x))*exp(-r**2/8), for k = 1, 2, ...
        """
        if random_state is None:
            # Global NumPy stream: identical draws to the original implementation.
            rng = np.random
        elif isinstance(random_state, np.random.RandomState):
            rng = random_state
        else:
            rng = np.random.RandomState(random_state)

        X = np.zeros((n_samples, n_feat, n_rot))
        Y = np.zeros((n_samples, 1))

        # z-axis rotation matrices for every group element,
        # theta_g = 2*pi*g / n_rot; shape (n_rot, 3, 3). Shared by all samples.
        thetas = 2 * np.pi * np.arange(n_rot) / n_rot
        cos_t, sin_t = np.cos(thetas), np.sin(thetas)
        rotations = np.zeros((n_rot, 3, 3))
        rotations[:, 0, 0] = cos_t
        rotations[:, 0, 1] = -sin_t
        rotations[:, 1, 0] = sin_t
        rotations[:, 1, 1] = cos_t
        rotations[:, 2, 2] = 1.0

        n_base = min(5, n_feat)  # basic weighted features that fit in n_feat
        freqs = np.arange(1, n_feat - 4)  # angular frequencies for slots 5, 6, ...

        for i in range(n_samples):
            # Generate random molecule (draw order preserved for reproducibility)
            n_atoms = rng.randint(4, 11)  # [4, 10] inclusive
            pos = rng.randn(n_atoms, 3) * 2
            pos = pos - np.mean(pos, axis=0)  # Center at origin

            # Random atomic numbers, biased toward carbon (6)
            Z = rng.randint(1, 10, n_atoms)  # [1, 9] inclusive
            Z[(Z > 1) & (Z < 6)] = 6  # Replace 2,3,4,5 with 6 (carbon)

            # Compute rotation-invariant target
            dists = pdist(pos)
            if len(dists) == 0:  # defensive: cannot happen with >= 4 atoms
                dists = np.array([1.0])
            Y[i] = np.mean(dists) + 0.3 * np.std(dists) + np.sum(Z ** 2) / 500

            # All rotated copies at once: pos_rot[g, a] = R_g @ pos[a],
            # shape (n_rot, n_atoms, 3).
            pos_rot = pos @ rotations.transpose(0, 2, 1)
            x, y, z = pos_rot[..., 0], pos_rot[..., 1], pos_rot[..., 2]
            r = np.linalg.norm(pos_rot, axis=-1)  # (n_rot, n_atoms)

            # Z-weighted basic features summed over atoms -> (5, n_rot)
            base = np.stack([Z * x, Z * y, Z * z, Z * r, Z * r ** 2]).sum(axis=-1)
            X[i, :n_base, :] = base[:n_base]

            if freqs.size:
                # Angular harmonics with Gaussian decay: (n_harm, n_rot, n_atoms)
                angle = np.arctan2(y, x)
                harmonics = Z * np.cos(freqs[:, None, None] * angle) * np.exp(-r ** 2 / 8)
                X[i, 5:, :] = harmonics.sum(axis=-1)

        return X, Y

    @staticmethod
    def roundp(x: Union[float, np.ndarray], p: int) -> Union[float, np.ndarray]:
        """
        Round to specified number of decimal places.

        Scales by ``10 ** p``, rounds half to even via ``np.round``, then
        scales back.

        Parameters
        ----------
        x : float or ndarray
            Value(s) to round
        p : int
            Number of decimal places

        Returns
        -------
        rounded : float or ndarray
            Rounded value(s); scalar input yields a NumPy float scalar

        Example
        -------
        >>> float(StarGHelpers.roundp(3.14159, 2))
        3.14
        >>> StarGHelpers.roundp(np.array([1.234, 5.678]), 1)
        array([1.2, 5.7])
        """
        scale = 10 ** p
        return np.round(x * scale) / scale


# =============================================================================
# Convenience functions (for direct import)
# =============================================================================

def compute_r2(y_pred, y_true):
    """Module-level shortcut that forwards to ``StarGHelpers.compute_r2``."""
    score = StarGHelpers.compute_r2(y_pred=y_pred, y_true=y_true)
    return score


def count_mlp_params(layers):
    """Module-level shortcut that forwards to ``StarGHelpers.count_mlp_params``."""
    total = StarGHelpers.count_mlp_params(layers=layers)
    return total


def compute_invariant_features(X):
    """Module-level shortcut that forwards to ``StarGHelpers.compute_invariant_features``."""
    invariant = StarGHelpers.compute_invariant_features(X=X)
    return invariant


def generate_molecular_data(n_samples, n_feat, n_rot):
    """Module-level shortcut that forwards to ``StarGHelpers.generate_molecular_data``."""
    dataset = StarGHelpers.generate_molecular_data(
        n_samples=n_samples,
        n_feat=n_feat,
        n_rot=n_rot,
    )
    return dataset


def roundp(x, p):
    """Module-level shortcut that forwards to ``StarGHelpers.roundp``."""
    rounded = StarGHelpers.roundp(x=x, p=p)
    return rounded


# =============================================================================
# Example Usage and Tests
# =============================================================================

if __name__ == "__main__":
    # Seed NumPy's global RNG so every run prints the same numbers and the
    # checks below are deterministic.
    np.random.seed(0)

    print("=" * 60)
    print("StarGHelpers Test Suite")
    print("=" * 60)

    # Test compute_r2: perfect prediction scores 1, predicting the mean ~0.
    print("\n1. Testing compute_r2:")
    y_true = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
    y_pred_perfect = y_true.copy()
    y_pred_good = y_true + np.random.randn(5) * 0.1
    y_pred_bad = np.ones(5) * np.mean(y_true)

    r2_perfect = compute_r2(y_pred_perfect, y_true)
    r2_good = compute_r2(y_pred_good, y_true)
    r2_bad = compute_r2(y_pred_bad, y_true)
    print(f"   Perfect prediction R²: {r2_perfect:.4f}")
    print(f"   Good prediction R²:    {r2_good:.4f}")
    print(f"   Mean prediction R²:    {r2_bad:.4f}")
    assert np.isclose(r2_perfect, 1.0), "perfect prediction should give R² = 1"
    assert np.isclose(r2_bad, 0.0), "predicting the mean should give R² ≈ 0"
    assert r2_bad < r2_good < r2_perfect, "noisy prediction should rank in between"

    # Test count_mlp_params against hand-computed totals.
    print("\n2. Testing count_mlp_params:")
    expected_params = [
        ([10, 20, 1], 241),
        ([784, 256, 128, 10], 235146),
        ([100, 64, 32, 16, 1], 9089),
    ]
    for arch, expected in expected_params:
        n_params = count_mlp_params(arch)
        print(f"   Architecture {arch}: {n_params:,} parameters")
        assert n_params == expected, f"{arch}: expected {expected}, got {n_params}"

    # Test compute_invariant_features: shape, and invariance to reordering
    # the group axis.
    print("\n3. Testing compute_invariant_features:")
    X_test = np.random.randn(5, 8, 4)  # 5 samples, 8 features, 4 rotations
    X_inv = compute_invariant_features(X_test)
    expected_shape = (X_test.shape[0], 4 * X_test.shape[1])
    print(f"   Input shape:  {X_test.shape}")
    print(f"   Output shape: {X_inv.shape}")
    print(f"   Expected:     {expected_shape}")
    assert X_inv.shape == expected_shape, "unexpected invariant-feature shape"
    perm = np.random.permutation(X_test.shape[2])
    assert np.allclose(compute_invariant_features(X_test[:, :, perm]), X_inv), \
        "invariant features should not depend on group-element order"

    # Test generate_molecular_data
    print("\n4. Testing generate_molecular_data:")
    n_samples, n_feat, n_rot = 100, 16, 8
    X, Y = generate_molecular_data(n_samples, n_feat, n_rot)
    print(f"   Generated X shape: {X.shape}")
    print(f"   Generated Y shape: {Y.shape}")
    print(f"   Y statistics: min={Y.min():.3f}, max={Y.max():.3f}, mean={Y.mean():.3f}")
    assert X.shape == (n_samples, n_feat, n_rot), "unexpected X shape"
    assert Y.shape == (n_samples, 1), "unexpected Y shape"

    print("\n   Checking rotational structure:")
    # Features 2-4 (Z-weighted z, r and r²) are unchanged by rotations about
    # the z-axis, so they must agree across every group element.
    assert np.allclose(X[:, 2:5, :], X[:, 2:5, :1]), \
        "z/r/r² features should be identical across rotations"
    # Informational only: per-rotation norm of one sample's feature vector.
    sample_idx = 0
    feat_magnitudes = np.linalg.norm(X[sample_idx], axis=0)
    print(f"   Feature magnitudes across rotations: {feat_magnitudes}")

    # Test roundp against hand-rounded values.
    print("\n5. Testing roundp:")
    expected_rounding = [
        (3.14159265, 3.14, 3.1416),
        (2.71828182, 2.72, 2.7183),
        (1.41421356, 1.41, 1.4142),
    ]
    for val, want2, want4 in expected_rounding:
        got2, got4 = roundp(val, 2), roundp(val, 4)
        print(f"   {val} -> p=2: {got2}, p=4: {got4}")
        assert np.isclose(got2, want2, rtol=0, atol=1e-12), f"roundp({val}, 2)"
        assert np.isclose(got4, want4, rtol=0, atol=1e-12), f"roundp({val}, 4)"

    # Test with array
    arr = np.array([1.234567, 8.901234])
    arr_rounded = roundp(arr, 3)
    print(f"   Array {arr} -> p=3: {arr_rounded}")
    assert np.allclose(arr_rounded, [1.235, 8.901], rtol=0, atol=1e-12), "roundp(array, 3)"

    print("\n" + "=" * 60)
    print("All tests completed successfully!")
    print("=" * 60)