% tensor-group-sym / experiments / starG_methods.m
%% starG_methods.m
%% Methods for ⋆_G vs ENN comparison
%% LH & SU & Claude 2026

classdef starG_methods
    %STARG_METHODS Experiment runners for the ⋆_G vs ENN comparison.
    %   Every public method is static, trains one model on the training
    %   split (the validation split is handed to model selection / the
    %   trainer), and returns a struct with at least these fields:
    %     train_time   - seconds measured by that method's own timer
    %     n_params     - parameter count reported for the model
    %     W            - learned weights (format depends on the method)
    %     test_r2      - R² on the test split
    %     rotation_r2  - 1 x n_rot row of test R² values, one per rotation
    %     rotation_std - standard deviation of rotation_r2
    %
    %   Assumed data layout (inferred from the indexing below; TODO confirm
    %   against the data generator): X_* are N x n_feat x n_rot arrays whose
    %   third dimension holds the rotated copies of each sample, and Y_* are
    %   presumably N x 1 target vectors.

    methods (Static)

        function results = runStarGSVD(X_train, X_val, X_test, Y_train, Y_val, Y_test, G, n_feat, n_rot)
            %RUNSTARGSVD Run ⋆_G-SVD feature extraction + ridge regression.
            %   Extracts features with extractStarGFeatures (normalization is
            %   fitted on X_train and reused for X_val / X_test), picks the
            %   ridge penalty with the best validation R², refits with it, and
            %   evaluates on X_test and on n_rot cyclic shifts of X_test along
            %   dimension 3.
            %
            %   Inputs:
            %     G, n_feat - passed straight through to extractStarGFeatures
            %     n_rot     - number of shifts evaluated in the rotation sweep
            %   Extra output field: norm_params (normalization fitted on train).

            fprintf(',  Method 1: _G-SVD + Ridge (Exact Invariance) , \n');
            % Private timer handle: a bare tic/toc shares one global timer that
            % any callee calling tic would silently reset.
            t_start = tic;

            % Features; normalization parameters are fitted on train only.
            [train_feat, norm_params] = extractStarGFeatures(X_train, G, n_feat);
            val_feat = extractStarGFeatures(X_val, G, n_feat, norm_params);
            test_feat = extractStarGFeatures(X_test, G, n_feat, norm_params);

            n_extracted = size(train_feat, 2);
            fprintf('  Exact invariant features: %d\n', n_extracted);

            % Normal-equation terms do not depend on lambda: form them once.
            gram = train_feat' * train_feat;
            rhs = train_feat' * Y_train;
            I_n = eye(n_extracted);

            % Select lambda by validation R² (strict '>' keeps the first best).
            lambdas = [0.001, 0.01, 0.1, 1, 10, 100];
            best_lambda = 0.1;   % fallback if no candidate beats -inf (e.g. all NaN)
            best_val_r2 = -inf;
            for lam = lambdas
                W_temp = (gram + lam * I_n) \ rhs;
                val_r2 = starG_helpers.computeR2(val_feat * W_temp, Y_val);
                if val_r2 > best_val_r2
                    best_val_r2 = val_r2;
                    best_lambda = lam;
                end
            end

            fprintf('  Best lambda: %.4f (val R² = %.4f)\n', best_lambda, best_val_r2);

            % Final model with the selected penalty.
            W_svd = (gram + best_lambda * I_n) \ rhs;

            results.train_time = toc(t_start);
            results.n_params = n_extracted;
            results.norm_params = norm_params;
            results.W = W_svd;

            % Test performance.
            results.test_r2 = starG_helpers.computeR2(test_feat * W_svd, Y_test);

            % Rotation sweep: shift the rotation axis, re-extract, re-predict.
            predict = @(X) extractStarGFeatures(X, G, n_feat, norm_params) * W_svd;
            [results.rotation_r2, results.rotation_std] = starG_methods.rotationSweep( ...
                @(g) starG_helpers.computeR2(predict(circshift(X_test, g-1, 3)), Y_test), ...
                n_rot);

            fprintf('  Test R² = %.4f\n', results.test_r2);
            fprintf('  Rotation σ = %.2e\n', results.rotation_std);

            if results.rotation_std < 1e-12
                fprintf('   Exact rotation invariance!\n\n');
            else
                fprintf('  Rotation variance: %.2e\n\n', results.rotation_std);
            end
        end


        function results = runNeuralStarG(X_train, X_val, X_test, Y_train, Y_val, Y_test, G, n_feat, n_rot, norm_params)
            %RUNNEURALSTARG Run Neural ⋆_G: an MLP on ⋆_G invariant features.
            %   All three splits are featurized by extractStarGFeatures with the
            %   supplied norm_params (presumably the norm_params field returned
            %   by runStarGSVD, so both methods see identically normalized
            %   features). Trains an MLP with hidden layers [64, 32] via
            %   starG_mlp.train using the options listed in the call below.
            %   Extra output field: b (biases returned by starG_mlp.train).

            fprintf(',  Method 2: Neural _G (Exact Invariance) , \n');
            t_start = tic;

            % Features, reusing the given normalization for every split.
            train_feat = extractStarGFeatures(X_train, G, n_feat, norm_params);
            val_feat = extractStarGFeatures(X_val, G, n_feat, norm_params);
            test_feat = extractStarGFeatures(X_test, G, n_feat, norm_params);

            n_input = size(train_feat, 2);
            fprintf('  Input features: %d\n', n_input);

            % Network architecture: input -> hidden layers -> scalar output.
            hidden_layers = [64, 32];
            layer_sizes = [n_input, hidden_layers, 1];
            fprintf('  Architecture: %s\n', mat2str(layer_sizes));

            n_params = starG_helpers.countMLPParams(layer_sizes);
            fprintf('  Parameters: %d\n', n_params);

            fprintf('  Training...\n');
            [W_nn, b_nn, ~] = starG_mlp.train(train_feat, Y_train, val_feat, Y_val, ...
                                              hidden_layers, ...
                                              'epochs', 300, ...
                                              'learningRate', 0.01, ...
                                              'patience', 30, ...
                                              'verbose', true);

            results.train_time = toc(t_start);
            results.n_params = n_params;
            results.W = W_nn;
            results.b = b_nn;

            % Test performance.
            predict = @(F) starG_mlp.forward(F, W_nn, b_nn);
            results.test_r2 = starG_helpers.computeR2(predict(test_feat), Y_test);

            % Rotation sweep: shift the rotation axis, re-extract, re-predict.
            [results.rotation_r2, results.rotation_std] = starG_methods.rotationSweep( ...
                @(g) starG_helpers.computeR2( ...
                    predict(extractStarGFeatures(circshift(X_test, g-1, 3), G, n_feat, norm_params)), ...
                    Y_test), ...
                n_rot);

            fprintf('  Test R² = %.4f\n', results.test_r2);
            fprintf('  Rotation σ = %.2e\n\n', results.rotation_std);
        end


        function results = runStandardMLP(X_train, X_val, X_test, Y_train, Y_val, Y_test, n_feat, n_rot)
            %RUNSTANDARDMLP Run a standard MLP with no invariance mechanism.
            %   Trains and evaluates on slice 1 of the third dimension only; the
            %   rotation sweep feeds slice g of X_test for g = 1..n_rot.
            %   n_feat is used only to report the parameter count.
            %   Extra output field: b (biases returned by starG_mlp.trainSimple).

            fprintf(',  Method 3: Standard MLP (No Invariance) , \n');

            % X(:, :, 1) is already 2-D; no squeeze needed.
            X_train_r1 = X_train(:, :, 1);
            X_val_r1 = X_val(:, :, 1);
            X_test_r1 = X_test(:, :, 1);

            t_start = tic;
            [W_mlp, b_mlp] = starG_mlp.trainSimple(X_train_r1, Y_train, X_val_r1, Y_val, [64, 32], 100, 0.005, false);
            results.train_time = toc(t_start);

            results.n_params = starG_helpers.countMLPParams([n_feat, 64, 32, 1]);
            results.W = W_mlp;
            results.b = b_mlp;

            predict = @(X) starG_mlp.forward(X, W_mlp, b_mlp);
            results.test_r2 = starG_helpers.computeR2(predict(X_test_r1), Y_test);

            [results.rotation_r2, results.rotation_std] = starG_methods.rotationSweep( ...
                @(g) starG_helpers.computeR2(predict(X_test(:, :, g)), Y_test), ...
                n_rot);

            fprintf('  R² = %.4f, σ = %.4f, params = %d\n\n', ...
                    results.test_r2, results.rotation_std, results.n_params);
        end


        function results = runInvariantMLP(X_train, X_val, X_test, Y_train, Y_val, Y_test, n_rot)
            %RUNINVARIANTMLP Run an MLP on hand-crafted invariant features (ENN-style).
            %   Features come from starG_helpers.computeInvariantFeatures applied
            %   to the full arrays; the rotation sweep recomputes them on each
            %   cyclic shift of X_test along dimension 3.
            %   Extra output field: b (biases returned by starG_mlp.trainSimple).

            fprintf(',  Method 4: Invariant MLP (ENN-style) , \n');

            X_train_inv = starG_helpers.computeInvariantFeatures(X_train);
            X_val_inv = starG_helpers.computeInvariantFeatures(X_val);
            X_test_inv = starG_helpers.computeInvariantFeatures(X_test);

            n_inv_feat = size(X_train_inv, 2);

            t_start = tic;
            [W_inv, b_inv] = starG_mlp.trainSimple(X_train_inv, Y_train, X_val_inv, Y_val, [64, 32], 100, 0.005, false);
            results.train_time = toc(t_start);

            results.n_params = starG_helpers.countMLPParams([n_inv_feat, 64, 32, 1]);
            results.W = W_inv;
            results.b = b_inv;

            predict = @(F) starG_mlp.forward(F, W_inv, b_inv);
            results.test_r2 = starG_helpers.computeR2(predict(X_test_inv), Y_test);

            [results.rotation_r2, results.rotation_std] = starG_methods.rotationSweep( ...
                @(g) starG_helpers.computeR2( ...
                    predict(starG_helpers.computeInvariantFeatures(circshift(X_test, g-1, 3))), ...
                    Y_test), ...
                n_rot);

            fprintf('  R² = %.4f, σ = %.2e, params = %d\n\n', ...
                    results.test_r2, results.rotation_std, results.n_params);
        end


        function results = runAugmentedMLP(X_train, X_val, X_test, Y_train, Y_val, Y_test, n_feat, n_rot)
            %RUNAUGMENTEDMLP Run an MLP trained on rotation-augmented data.
            %   Each training sample contributes all n_rot slices X_train(i, :, g)
            %   as separate rows with its target repeated; validation and test
            %   use slice 1, and the rotation sweep feeds slice g of X_test.
            %   Extra output field: b (biases returned by starG_mlp.trainSimple).

            fprintf(',  Method 5: Augmented MLP , \n');

            % Augmented set, rows sample-major: row (i-1)*n_rot + g holds slice
            % g of sample i. permute moves the rotation index first so that the
            % column-major reshape varies it fastest. reshape errors if the
            % width of X_train differs from n_feat.
            n_train = size(X_train, 1);
            X_rot_first = permute(X_train(:, :, 1:n_rot), [3, 1, 2]);   % n_rot x n_train x n_feat
            X_train_aug = double(reshape(X_rot_first, n_train * n_rot, n_feat));
            Y_train_aug = double(repelem(reshape(Y_train(1:n_train), [], 1), n_rot));

            X_val_r1 = X_val(:, :, 1);
            X_test_r1 = X_test(:, :, 1);

            t_start = tic;
            [W_aug, b_aug] = starG_mlp.trainSimple(X_train_aug, Y_train_aug, X_val_r1, Y_val, [64, 32], 80, 0.005, false);
            results.train_time = toc(t_start);

            results.n_params = starG_helpers.countMLPParams([n_feat, 64, 32, 1]);
            results.W = W_aug;
            results.b = b_aug;

            predict = @(X) starG_mlp.forward(X, W_aug, b_aug);
            results.test_r2 = starG_helpers.computeR2(predict(X_test_r1), Y_test);

            [results.rotation_r2, results.rotation_std] = starG_methods.rotationSweep( ...
                @(g) starG_helpers.computeR2(predict(X_test(:, :, g)), Y_test), ...
                n_rot);

            fprintf('  R² = %.4f, σ = %.4f, params = %d\n\n', ...
                    results.test_r2, results.rotation_std, results.n_params);
        end

    end

    methods (Static, Access = private)

        function [rot_r2_row, rot_std] = rotationSweep(r2_at, n_rot)
            %ROTATIONSWEEP Evaluate r2_at(g) for g = 1..n_rot.
            %   r2_at is a function handle returning the test R² for rotation
            %   index g. Returns the values as a 1 x n_rot row vector and the
            %   standard deviation of those values.
            rot_r2 = zeros(n_rot, 1);
            for g = 1:n_rot
                rot_r2(g) = r2_at(g);
            end
            rot_r2_row = rot_r2';
            rot_std = std(rot_r2);
        end

    end
end