% Source: tensor-group-sym / core / NeuralStarGFramework.m

%% NeuralStarGFramework.m - CLEAN FIXED VERSION
%% Neural Network Framework Built on the Star_G Algebra
%% LH & SU & Claude
%% ============================================================================

classdef NeuralStarGFramework < handle
    % NeuralStarGFramework  Feed-forward regressor whose layers use G.starG.
    %
    %   Layer l stores a weight tensor of size
    %   [layerSizes(l+1), layerSizes(l), G.n] and a bias tensor of size
    %   [layerSizes(l+1), 1, G.n]. For every sample a layer evaluates
    %       activation( G.starG(W, a) + b )
    %   with ReLU on hidden layers and the identity on the last layer.
    %
    %   The algebra object G is supplied by the caller. This class only
    %   relies on G.n, G.starG(W, A) and, in compressWeights,
    %   G.truncate(W, rank); their semantics are defined by G.
    %
    %   Gradients are NOT back-propagated analytically: backward() estimates
    %   them with sampled central finite differences, and adamUpdate()
    %   applies an Adam step followed by multiplicative weight decay.
    %
    %   This is a handle class, so methods mutate the object in place; the
    %   "obj = ..." return values exist only for call-site convenience.

    properties
        G             % algebra object (needs .n, .starG, .truncate)
        layers        % layer sizes as passed to the constructor
        weights       % cell(nLayers,1) of [out, in, G.n] tensors
        biases        % cell(nLayers,1) of [out, 1, G.n] tensors
        activations   % cell(nLayers,1) of function handles
        useGPU        % stored option; not consulted anywhere in this class
        learningRate  % Adam step size
        weightDecay   % per-step multiplicative shrink: W <- W * (1 - weightDecay)
        m_weights     % Adam first-moment estimates for weights
        v_weights     % Adam second-moment estimates for weights
        m_biases      % Adam first-moment estimates for biases
        v_biases      % Adam second-moment estimates for biases
        t             % number of Adam steps taken so far
        trainLoss     % per-epoch mean batch loss (NaN for epochs not run)
        valLoss       % per-epoch validation loss (NaN for epochs not run)
        trainR2       % per-epoch training R^2   (NaN for epochs not run)
        valR2         % per-epoch validation R^2 (NaN for epochs not run)
    end

    methods
        function obj = NeuralStarGFramework(G, layerSizes, varargin)
            % Construct the network and initialise parameters.
            %
            %   G          - algebra object; G.n fixes the third tensor dimension
            %   layerSizes - vector [in, hidden..., out] with at least 2 entries
            %   Name-value options:
            %     'learningRate' (default 0.001)
            %     'useGPU'       (default false; stored only)
            %     'weightDecay'  (default 1e-4)
            p = inputParser;
            addParameter(p, 'learningRate', 0.001);
            addParameter(p, 'useGPU', false);
            addParameter(p, 'weightDecay', 1e-4);
            parse(p, varargin{:});

            % Without at least one layer the activation setup below would
            % index element 0 and fail with an unrelated-looking message.
            if numel(layerSizes) < 2
                error('NeuralStarGFramework:invalidLayerSizes', ...
                    'layerSizes must contain at least an input and an output size.');
            end

            obj.G = G;
            obj.layers = layerSizes;
            obj.learningRate = p.Results.learningRate;
            obj.useGPU = p.Results.useGPU;
            obj.weightDecay = p.Results.weightDecay;
            obj.t = 0;

            nLayers = numel(layerSizes) - 1;
            obj.weights = cell(nLayers, 1);
            obj.biases = cell(nLayers, 1);
            obj.m_weights = cell(nLayers, 1);
            obj.v_weights = cell(nLayers, 1);
            obj.m_biases = cell(nLayers, 1);
            obj.v_biases = cell(nLayers, 1);

            % Activations: ReLU for hidden layers, identity for the output.
            obj.activations = cell(nLayers, 1);
            for l = 1:nLayers-1
                obj.activations{l} = @(x) max(0, x);
            end
            obj.activations{nLayers} = @(x) x;

            % Weights: zero-mean Gaussian scaled by sqrt(2 / (fan_in + fan_out)).
            % Biases and all Adam moments start at zero.
            for l = 1:nLayers
                fan_in = layerSizes(l);
                fan_out = layerSizes(l+1);
                scale = sqrt(2 / (fan_in + fan_out));

                obj.weights{l} = scale * randn(fan_out, fan_in, G.n);
                obj.biases{l} = zeros(fan_out, 1, G.n);

                obj.m_weights{l} = zeros(size(obj.weights{l}));
                obj.v_weights{l} = zeros(size(obj.weights{l}));
                obj.m_biases{l} = zeros(size(obj.biases{l}));
                obj.v_biases{l} = zeros(size(obj.biases{l}));
            end
        end

        function [output, cache] = forward(obj, X)
            % Run the network on X.
            %
            %   X      - [batch, features, G.n] array, or a 2-D
            %            [features, G.n] matrix holding a single sample
            %   output - [batch, layers(end), G.n] activations of the last layer
            %   cache  - cell(nLayers+1,1); cache{1} is the (reshaped) input and
            %            cache{l+1} the activation of layer l
            nLayers = length(obj.weights);

            if ismatrix(X)
                % MATLAB drops trailing singleton dimensions, so when
                % G.n == 1 a batched [batch, features, 1] input arrives as a
                % 2-D matrix. Recognise it by its feature count instead of
                % mistaking it for one [features, G.n] sample.
                isBatchMatrix = nLayers > 0 && obj.G.n == 1 && ...
                    size(X, 2) == size(obj.weights{1}, 2);
                if isBatchMatrix
                    batch_size = size(X, 1);
                else
                    batch_size = 1;
                    n_feat = size(X, 1);
                    X = reshape(X, [1, n_feat, obj.G.n]);
                end
            else
                batch_size = size(X, 1);
            end

            cache = cell(nLayers + 1, 1);
            cache{1} = X;
            A = X;

            for l = 1:nLayers
                W = obj.weights{l};
                b = obj.biases{l};
                [out_dim, in_dim, n_g] = size(W);

                Z = zeros(batch_size, out_dim, n_g);

                for i = 1:batch_size
                    % squeeze yields [in_dim, n_g] for the usual layout; the
                    % transpose covers degenerate shapes (in_dim or n_g equal
                    % to 1) and inputs stored as [batch, n_g, in_dim].
                    A_i = squeeze(A(i, :, :));
                    if size(A_i, 1) ~= in_dim
                        A_i = A_i';
                    end

                    A_i_3d = reshape(A_i, [in_dim, 1, n_g]);
                    Z_i = obj.G.starG(W, A_i_3d);
                    Z_i = Z_i + b;
                    Z(i, :, :) = squeeze(Z_i);
                end

                A = obj.activations{l}(Z);
                cache{l + 1} = A;
            end

            output = A;
        end

        function y = invariantPool(obj, X) %#ok<INUSL>
            % Average X over its third and then its second dimension,
            % leaving one value per row (sample).
            y = mean(mean(X, 3), 2);
        end

        function [y_pred, cache] = predict(obj, X)
            % Forward pass followed by invariantPool; returns one scalar
            % prediction per sample plus the forward cache.
            [output, cache] = obj.forward(X);
            y_pred = obj.invariantPool(output);
            y_pred = squeeze(y_pred);
        end

        function loss = computeLoss(obj, y_pred, y_true) %#ok<INUSL>
            % Mean squared error. Both inputs are vectorised first, so row
            % and column vectors can be mixed freely.
            loss = mean((y_pred(:) - y_true(:)).^2);
        end

        function grads = backward(obj, X, y_true)
            % Estimate loss gradients by sampled central finite differences.
            %
            %   For each layer up to 30 randomly chosen entries of the weight
            %   tensor and of the bias tensor are perturbed by +/- 1e-5; all
            %   other gradient entries stay zero. Sampled entries are scaled
            %   by numel/n_sample.
            %
            %   grads.weights / grads.biases - cells matching obj.weights /
            %   obj.biases in size.
            epsilon = 1e-5;
            maxSamples = 30;
            nLayers = length(obj.weights);
            grads.weights = cell(nLayers, 1);
            grads.biases = cell(nLayers, 1);

            for l = 1:nLayers
                grads.weights{l} = obj.sampledGradient('weights', l, X, y_true, epsilon, maxSamples);
                grads.biases{l} = obj.sampledGradient('biases', l, X, y_true, epsilon, maxSamples);
            end
        end

        function obj = adamUpdate(obj, grads)
            % Apply one Adam step (beta1 = 0.9, beta2 = 0.999) to weights and
            % biases, then shrink the weights by (1 - weightDecay).
            %
            %   grads - struct with field 'weights' and, optionally, 'biases'
            %           (cells of per-layer gradient tensors). Layers without
            %           a bias gradient keep their biases unchanged.
            beta1 = 0.9;
            beta2 = 0.999;
            epsilon = 1e-8;   % not named "eps": that would shadow the built-in

            obj.t = obj.t + 1;
            corr1 = 1 - beta1^obj.t;   % bias-correction denominators
            corr2 = 1 - beta2^obj.t;

            hasBiasGrads = isfield(grads, 'biases') && ~isempty(grads.biases);

            for l = 1:length(obj.weights)
                gW = grads.weights{l};
                obj.m_weights{l} = beta1 * obj.m_weights{l} + (1 - beta1) * gW;
                obj.v_weights{l} = beta2 * obj.v_weights{l} + (1 - beta2) * (gW.^2);

                m_hat = obj.m_weights{l} / corr1;
                v_hat = obj.v_weights{l} / corr2;

                obj.weights{l} = obj.weights{l} - obj.learningRate * m_hat ./ (sqrt(v_hat) + epsilon);
                obj.weights{l} = obj.weights{l} * (1 - obj.weightDecay);

                % Biases get the same Adam rule but no decay.
                if hasBiasGrads && l <= numel(grads.biases) && ~isempty(grads.biases{l})
                    gB = grads.biases{l};
                    obj.m_biases{l} = beta1 * obj.m_biases{l} + (1 - beta1) * gB;
                    obj.v_biases{l} = beta2 * obj.v_biases{l} + (1 - beta2) * (gB.^2);

                    mb_hat = obj.m_biases{l} / corr1;
                    vb_hat = obj.v_biases{l} / corr2;

                    obj.biases{l} = obj.biases{l} - obj.learningRate * mb_hat ./ (sqrt(vb_hat) + epsilon);
                end
            end
        end

        function obj = train(obj, X_train, Y_train, X_val, Y_val, varargin)
            % Mini-batch training with early stopping on validation loss.
            %
            %   X_train, X_val - inputs indexed by sample along dimension 1
            %   Y_train, Y_val - target vectors (row or column)
            %   Name-value options:
            %     'epochs' (100), 'batchSize' (32), 'verbose' (true),
            %     'patience' (20 epochs without validation improvement)
            %
            %   On return the parameters with the lowest validation loss are
            %   restored and trainLoss/valLoss/trainR2/valR2 hold the
            %   per-epoch history (NaN for epochs skipped by early stopping).
            p = inputParser;
            addParameter(p, 'epochs', 100);
            addParameter(p, 'batchSize', 32);
            addParameter(p, 'verbose', true);
            addParameter(p, 'patience', 20);
            parse(p, varargin{:});

            epochs = p.Results.epochs;
            batchSize = p.Results.batchSize;
            verbose = p.Results.verbose;
            patience = p.Results.patience;

            n_train = size(X_train, 1);
            n_batches = ceil(n_train / batchSize);

            % NaN marks "epoch not run", so an early-stopped history cannot
            % be mistaken for epochs with zero loss / zero R^2.
            obj.trainLoss = NaN(epochs, 1);
            obj.valLoss = NaN(epochs, 1);
            obj.trainR2 = NaN(epochs, 1);
            obj.valR2 = NaN(epochs, 1);

            % predict() returns a column; force targets to columns as well so
            % the residuals below never expand into an N-by-N matrix.
            Y_train_col = Y_train(:);
            Y_val_col = Y_val(:);
            ss_tot_train = sum((Y_train_col - mean(Y_train_col)).^2) + 1e-10;
            ss_tot_val = sum((Y_val_col - mean(Y_val_col)).^2) + 1e-10;

            best_val_loss = inf;
            patience_counter = 0;
            best_weights = obj.weights;
            best_biases = obj.biases;
            last_epoch = 0;

            if verbose
                fprintf('\nTraining Neural Star_G Network\n');
                fprintf('Epochs: %d, Batch: %d, LR: %.4f\n\n', epochs, batchSize, obj.learningRate);
            end

            for epoch = 1:epochs
                last_epoch = epoch;
                % Private timer handle: a bare tic/toc pair could be reset by
                % any code called in between.
                epoch_timer = tic;
                perm = randperm(n_train);
                epoch_loss = 0;

                for batch = 1:n_batches
                    batch_start = (batch - 1) * batchSize + 1;
                    batch_end = min(batch * batchSize, n_train);
                    batch_idx = perm(batch_start:batch_end);

                    X_batch = X_train(batch_idx, :, :);
                    Y_batch = Y_train(batch_idx);

                    y_pred = obj.predict(X_batch);
                    batch_loss = obj.computeLoss(y_pred, Y_batch);
                    epoch_loss = epoch_loss + batch_loss;

                    grads = obj.backward(X_batch, Y_batch);
                    obj = obj.adamUpdate(grads);
                end

                epoch_time = toc(epoch_timer);
                obj.trainLoss(epoch) = epoch_loss / n_batches;

                y_val_pred = obj.predict(X_val);
                obj.valLoss(epoch) = obj.computeLoss(y_val_pred, Y_val);

                y_train_pred = obj.predict(X_train);
                ss_res_train = sum((y_train_pred(:) - Y_train_col).^2);
                obj.trainR2(epoch) = 1 - ss_res_train / ss_tot_train;

                ss_res_val = sum((y_val_pred(:) - Y_val_col).^2);
                obj.valR2(epoch) = 1 - ss_res_val / ss_tot_val;

                if obj.valLoss(epoch) < best_val_loss
                    best_val_loss = obj.valLoss(epoch);
                    best_weights = obj.weights;
                    best_biases = obj.biases;
                    patience_counter = 0;
                else
                    patience_counter = patience_counter + 1;
                end

                if verbose && mod(epoch, 5) == 0
                    fprintf('Epoch %3d: Loss=%.4f, R2=%.4f, Val R2=%.4f (%.1fs)\n', ...
                        epoch, obj.trainLoss(epoch), obj.trainR2(epoch), obj.valR2(epoch), epoch_time);
                end

                if patience_counter >= patience
                    if verbose
                        fprintf('Early stopping at epoch %d\n', epoch);
                    end
                    break;
                end
            end

            obj.weights = best_weights;
            obj.biases = best_biases;

            if verbose && last_epoch > 0
                fprintf('\nBest Val R2: %.4f\n\n', max(obj.valR2(1:last_epoch)));
            end
        end

        function obj = compressWeights(obj, rank)
            % Replace every weight tensor by G.truncate(W, rank) and print the
            % relative error norm(W - W_comp) / norm(W) per layer. The meaning
            % of "rank" is defined by G.truncate. Adam moments are left as is.
            fprintf('Compressing to rank %d...\n', rank);
            for l = 1:length(obj.weights)
                W_orig = obj.weights{l};
                W_comp = obj.G.truncate(W_orig, rank);
                err = norm(W_orig(:) - W_comp(:)) / (norm(W_orig(:)) + 1e-10);
                obj.weights{l} = W_comp;
                fprintf('Layer %d: %.2f%% error\n', l, err * 100);
            end
        end
    end

    methods (Access = private)
        function g = sampledGradient(obj, paramName, l, X, y_true, epsilon, maxSamples)
            % Central-difference gradient of the loss w.r.t. a random subset
            % of entries of obj.(paramName){l} ('weights' or 'biases').
            % Unsampled entries are zero; sampled ones are scaled by
            % numel/n_sample. The parameter tensor is restored on exit, also
            % when predict/computeLoss throws.
            if strcmp(paramName, 'biases')
                P_orig = obj.biases{l};
            else
                P_orig = obj.weights{l};
            end

            g = zeros(size(P_orig));
            nParams = numel(P_orig);
            n_sample = min(maxSamples, nParams);
            if n_sample == 0
                return;
            end
            sample_idx = randperm(nParams, n_sample);

            try
                for idx = sample_idx
                    obj.setParamEntry(paramName, l, idx, P_orig(idx) + epsilon);
                    loss_plus = obj.computeLoss(obj.predict(X), y_true);

                    obj.setParamEntry(paramName, l, idx, P_orig(idx) - epsilon);
                    loss_minus = obj.computeLoss(obj.predict(X), y_true);

                    g(idx) = (loss_plus - loss_minus) / (2 * epsilon);
                    obj.setParamEntry(paramName, l, idx, P_orig(idx));
                end
            catch err
                % Never leave a perturbed parameter behind.
                if strcmp(paramName, 'biases')
                    obj.biases{l} = P_orig;
                else
                    obj.weights{l} = P_orig;
                end
                rethrow(err);
            end

            g = g * (nParams / n_sample);
        end

        function setParamEntry(obj, paramName, l, idx, value)
            % Write one entry (linear index idx) of layer l's weight or bias
            % tensor in place.
            if strcmp(paramName, 'biases')
                obj.biases{l}(idx) = value;
            else
                obj.weights{l}(idx) = value;
            end
        end
    end
end