% Repository path: tensor-group-sym/experiments/starG_mlp.m
%% starG_mlp.m
%% MLP functions for _G comparison experiments
%% LH & SU & Claude 2026

classdef starG_mlp
    %STARG_MLP Fully connected ReLU network for scalar regression.
    %   All methods are static. Parameters are cell arrays with one entry per
    %   layer: W{l} is (fan_out x fan_in) and b{l} is (fan_out x 1). Inputs are
    %   passed with one sample per row; hidden layers use ReLU and the single
    %   output unit is linear. Training minimises mean squared error (MSE).

    methods (Static)

        function [W, b, history] = train(X_train, Y_train, X_val, Y_val, hidden_layers, varargin)
            %TRAIN Train MLP with Adam optimizer
            %
            %   [W, b, history] = starG_mlp.train(X_train, Y_train, X_val, Y_val, hidden_layers, ...)
            %
            %   X_train, X_val  - samples as rows (same number of columns)
            %   Y_train, Y_val  - targets; flattened to column vectors
            %   hidden_layers   - hidden layer widths (row or column vector;
            %                     [] gives a purely linear model)
            %
            %   Optional parameters:
            %       'epochs'        - Maximum epochs (default: 200)
            %       'learningRate'  - Adam step size, constant (default: 0.01)
            %       'batchSize'     - Mini-batch size (default: 32)
            %       'patience'      - Early stopping patience (default: 20)
            %       'verbose'       - Print progress (default: false); uses
            %                         starG_helpers.computeR2 for the R2 columns
            %
            %   Returns the parameters from the epoch with the lowest validation
            %   MSE. history.train_loss / history.val_loss hold one entry per
            %   epoch actually run; train_loss is the mean of per-batch MSEs.

            p = inputParser;
            addParameter(p, 'epochs', 200);
            addParameter(p, 'learningRate', 0.01);
            addParameter(p, 'batchSize', 32);
            addParameter(p, 'patience', 20);
            addParameter(p, 'verbose', false);
            parse(p, varargin{:});

            epochs = p.Results.epochs;
            lr = p.Results.learningRate;
            batch_size = p.Results.batchSize;
            patience = p.Results.patience;
            verbose = p.Results.verbose;

            Y_train = Y_train(:);
            Y_val = Y_val(:);

            % hidden_layers(:)' accepts both row and column vectors.
            [W, b] = starG_mlp.initParams([size(X_train, 2), hidden_layers(:)', 1]);
            n_layers = numel(W);

            history = struct();
            history.train_loss = zeros(epochs, 1);
            history.val_loss = zeros(epochs, 1);

            best_val_loss = inf;
            best_W = W;
            best_b = b;
            wait = 0;

            n_train = size(X_train, 1);
            n_batches = ceil(n_train / batch_size);

            % Adam first/second moment estimates, shaped like each parameter.
            m_W = cellfun(@(w) zeros(size(w)), W, 'UniformOutput', false);
            v_W = cellfun(@(w) zeros(size(w)), W, 'UniformOutput', false);
            m_b = cellfun(@(v) zeros(size(v)), b, 'UniformOutput', false);
            v_b = cellfun(@(v) zeros(size(v)), b, 'UniformOutput', false);

            beta1 = 0.9;
            beta2 = 0.999;
            adam_eps = 1e-8;  % not named "eps" to avoid shadowing the built-in
            t = 0;

            final_epoch = 0;
            for epoch = 1:epochs
                perm = randperm(n_train);
                epoch_loss = 0;

                for batch = 1:n_batches
                    t = t + 1;
                    batch_idx = perm((batch - 1) * batch_size + 1 : min(batch * batch_size, n_train));

                    % Gradients of the batch MSE, all taken at the current
                    % (pre-step) parameters.
                    [batch_loss, dW, db] = starG_mlp.lossAndGrad( ...
                        X_train(batch_idx, :), Y_train(batch_idx), W, b);
                    epoch_loss = epoch_loss + batch_loss;

                    % Adam bias-correction factors are the same for every layer.
                    bc1 = 1 - beta1^t;
                    bc2 = 1 - beta2^t;
                    for l = 1:n_layers
                        m_W{l} = beta1 * m_W{l} + (1 - beta1) * dW{l};
                        v_W{l} = beta2 * v_W{l} + (1 - beta2) * (dW{l}.^2);
                        m_b{l} = beta1 * m_b{l} + (1 - beta1) * db{l};
                        v_b{l} = beta2 * v_b{l} + (1 - beta2) * (db{l}.^2);

                        W{l} = W{l} - lr * (m_W{l} / bc1) ./ (sqrt(v_W{l} / bc2) + adam_eps);
                        b{l} = b{l} - lr * (m_b{l} / bc1) ./ (sqrt(v_b{l} / bc2) + adam_eps);
                    end
                end

                history.train_loss(epoch) = epoch_loss / n_batches;

                Y_val_pred = starG_mlp.forward(X_val, W, b);
                history.val_loss(epoch) = mean((Y_val_pred - Y_val).^2);

                if history.val_loss(epoch) < best_val_loss
                    best_val_loss = history.val_loss(epoch);
                    best_W = W;
                    best_b = b;
                    wait = 0;
                else
                    wait = wait + 1;
                end

                if verbose && mod(epoch, 20) == 0
                    Y_train_pred = starG_mlp.forward(X_train, W, b);
                    train_r2 = starG_helpers.computeR2(Y_train_pred, Y_train);
                    val_r2 = starG_helpers.computeR2(Y_val_pred, Y_val);
                    fprintf('    Epoch %3d: loss=%.4f, val_loss=%.4f, R2=%.3f, val_R2=%.3f\n', ...
                            epoch, history.train_loss(epoch), history.val_loss(epoch), train_r2, val_r2);
                end

                final_epoch = epoch;
                if wait >= patience
                    if verbose
                        fprintf('    Early stopping at epoch %d\n', epoch);
                    end
                    break;
                end
            end

            W = best_W;
            b = best_b;
            history.train_loss = history.train_loss(1:final_epoch);
            history.val_loss = history.val_loss(1:final_epoch);
        end


        function Y = forward(X, W, b)
            %FORWARD Forward pass through MLP
            %   Y = starG_mlp.forward(X, W, b) evaluates the network on the rows
            %   of X and returns one prediction per row (column vector for a
            %   single output unit). Hidden layers use ReLU, the last is linear.
            A = X';
            n_layers = length(W);

            for l = 1:n_layers
                Z = W{l} * A + b{l};
                if l < n_layers
                    A = max(0, Z);  % ReLU
                else
                    A = Z;  % Linear
                end
            end

            Y = A';
        end


        function [W, b] = trainSimple(X_train, Y_train, X_val, Y_val, hidden, epochs, lr, verbose)
            %TRAINSIMPLE Simple MLP training (legacy interface)
            %
            %   [W, b] = starG_mlp.trainSimple(X_train, Y_train, X_val, Y_val, hidden, epochs, lr)
            %   [W, b] = starG_mlp.trainSimple(..., verbose)
            %
            %   Plain mini-batch SGD with step size lr on the MSE loss, batch
            %   size min(64, n_train), early stopping after 15 epochs without
            %   validation improvement. hidden may be a row or column vector.
            %   Returns the parameters with the lowest validation MSE.
            if nargin < 8, verbose = false; end

            [W, b] = starG_mlp.initParams([size(X_train, 2), hidden(:)', 1]);
            nL = numel(W);

            n = size(X_train, 1);
            bs = min(64, n);
            best_loss = inf;
            best_W = W;
            best_b = b;
            wait = 0;
            patience = 15;

            Y_train = Y_train(:);
            Y_val = Y_val(:);

            for epoch = 1:epochs
                perm = randperm(n);
                for i = 1:bs:n
                    idx = perm(i:min(i+bs-1, n));
                    % Compute every layer's gradient before stepping any layer,
                    % so backpropagation never sees already-updated weights.
                    [~, dW, db] = starG_mlp.lossAndGrad(X_train(idx, :), Y_train(idx), W, b);
                    for l = 1:nL
                        W{l} = W{l} - lr * dW{l};
                        b{l} = b{l} - lr * db{l};
                    end
                end

                Yp = starG_mlp.forward(X_val, W, b);
                val_loss = mean((Yp - Y_val).^2);
                if val_loss < best_loss
                    best_loss = val_loss;
                    best_W = W;
                    best_b = b;
                    wait = 0;
                else
                    wait = wait + 1;
                end
                if wait >= patience
                    break;
                end

                if verbose && mod(epoch, 20) == 0
                    fprintf('  Epoch %d: val_loss = %.4f\n', epoch, val_loss);
                end
            end

            W = best_W;
            b = best_b;
        end

    end

    methods (Static, Access = private)

        function [W, b] = initParams(layer_sizes)
            %INITPARAMS He-initialised weights and zero biases.
            %   layer_sizes = [n_input, hidden..., n_output]; W{l} is drawn with
            %   randn scaled by sqrt(2 / fan_in), b{l} is all zeros.
            n_layers = numel(layer_sizes) - 1;
            W = cell(n_layers, 1);
            b = cell(n_layers, 1);
            for l = 1:n_layers
                fan_in = layer_sizes(l);
                fan_out = layer_sizes(l+1);
                W{l} = randn(fan_out, fan_in) * sqrt(2 / fan_in);
                b{l} = zeros(fan_out, 1);
            end
        end


        function [loss, dW, db] = lossAndGrad(X, Y, W, b)
            %LOSSANDGRAD Batch MSE and its exact gradients w.r.t. W and b.
            %   X holds samples as rows, Y is a column of targets. The loss is
            %   mean((f(X) - Y).^2); its 1/n factor is applied exactly once, in
            %   the output-layer error, so dW/db are sums over the batch.
            n_layers = numel(W);
            n = size(X, 1);

            % Forward pass, caching pre-activations Z and activations A.
            A = cell(n_layers + 1, 1);
            Z = cell(n_layers, 1);
            A{1} = X';
            for l = 1:n_layers
                Z{l} = W{l} * A{l} + b{l};
                if l < n_layers
                    A{l+1} = max(0, Z{l});  % ReLU
                else
                    A{l+1} = Z{l};          % Linear output
                end
            end

            residual = A{n_layers + 1}' - Y;
            loss = mean(residual.^2);

            % Backward pass: dZ is d(loss)/d(Z{l}) for the current layer.
            dW = cell(n_layers, 1);
            db = cell(n_layers, 1);
            dZ = 2 * residual' / n;
            for l = n_layers:-1:1
                dW{l} = dZ * A{l}';
                db{l} = sum(dZ, 2);
                if l > 1
                    dZ = (W{l}' * dZ) .* (Z{l-1} > 0);  % ReLU derivative
                end
            end
        end

    end
end