%% NeuralStarGFramework.m - CLEAN FIXED VERSION
%% Neural Network Framework Built on ★_G Algebra
%% LH & SU & Claude
%% ============================================================================
classdef NeuralStarGFramework < handle
properties
% State of the network. All properties are public and are populated by the
% constructor, adamUpdate and train.
G % algebra object; this class reads G.n and calls G.starG and G.truncate
layers % layerSizes vector exactly as passed to the constructor
weights % cell (one entry per layer) of fan_out x fan_in x G.n arrays
biases % cell (one entry per layer) of fan_out x 1 x G.n arrays
activations % cell of function handles: max(0, x) for hidden layers, identity for the last
useGPU % value of the 'useGPU' constructor option
learningRate % step size used by adamUpdate
weightDecay % adamUpdate multiplies each weight tensor by (1 - weightDecay) after its step
m_weights % Adam first-moment buffers, same shapes as weights
v_weights % Adam second-moment buffers, same shapes as weights
m_biases % Adam first-moment buffers, same shapes as biases
v_biases % Adam second-moment buffers, same shapes as biases
t % Adam step counter; starts at 0 and is incremented by adamUpdate
trainLoss % epochs x 1 history of mean batch loss, filled by train
valLoss % epochs x 1 history of validation loss, filled by train
trainR2 % epochs x 1 history of training R^2, filled by train
valR2 % epochs x 1 history of validation R^2, filled by train
end
methods
function obj = NeuralStarGFramework(G, layerSizes, varargin)
    % NeuralStarGFramework  Build a network with one layer per consecutive
    % pair of entries in layerSizes.
    %
    %   obj = NeuralStarGFramework(G, layerSizes)
    %   obj = NeuralStarGFramework(G, layerSizes, Name, Value, ...)
    %
    %   G          - algebra object; only G.n is read here. Every weight and
    %                bias tensor gets a third dimension of length G.n.
    %   layerSizes - vector [in, hidden..., out] with at least two entries.
    %
    %   Name-value options:
    %     'learningRate' - step size stored for adamUpdate (default 0.001)
    %     'useGPU'       - flag stored on the object (default false)
    %     'weightDecay'  - decay coefficient stored for adamUpdate
    %                      (default 1e-4, the previously hard-coded value)
    %
    %   Errors with NeuralStarGFramework:invalidLayerSizes when layerSizes
    %   has fewer than two entries.

    % Parse optional arguments
    p = inputParser;
    addParameter(p, 'learningRate', 0.001);
    addParameter(p, 'useGPU', false);
    addParameter(p, 'weightDecay', 1e-4);
    parse(p, varargin{:});

    % Without this check a one-element layerSizes fails further down with
    % an unhelpful "index must be positive" error on activations{0}.
    if length(layerSizes) < 2
        error('NeuralStarGFramework:invalidLayerSizes', ...
            'layerSizes must contain at least an input and an output size.');
    end

    obj.G = G;
    obj.layers = layerSizes;
    obj.learningRate = p.Results.learningRate;
    obj.useGPU = p.Results.useGPU;
    obj.weightDecay = p.Results.weightDecay;
    obj.t = 0;

    nLayers = length(layerSizes) - 1;
    obj.weights = cell(nLayers, 1);
    obj.biases = cell(nLayers, 1);
    obj.m_weights = cell(nLayers, 1);
    obj.v_weights = cell(nLayers, 1);
    obj.m_biases = cell(nLayers, 1);
    obj.v_biases = cell(nLayers, 1);

    % Activations: ReLU for hidden layers, identity for the output layer
    obj.activations = cell(nLayers, 1);
    for l = 1:nLayers-1
        obj.activations{l} = @(x) max(0, x);
    end
    obj.activations{nLayers} = @(x) x;

    % Initialize weights with N(0, 2/(fan_in + fan_out)) entries, zero
    % biases, and zeroed Adam moment buffers of matching shape.
    for l = 1:nLayers
        fan_in = layerSizes(l);
        fan_out = layerSizes(l+1);
        scale = sqrt(2 / (fan_in + fan_out));
        obj.weights{l} = scale * randn(fan_out, fan_in, G.n);
        obj.biases{l} = zeros(fan_out, 1, G.n);
        obj.m_weights{l} = zeros(size(obj.weights{l}));
        obj.v_weights{l} = zeros(size(obj.weights{l}));
        obj.m_biases{l} = zeros(size(obj.biases{l}));
        obj.v_biases{l} = zeros(size(obj.biases{l}));
    end
end
function [output, cache] = forward(obj, X)
    % FORWARD  Propagate X through every layer.
    %
    %   X      - 3-D array (batch x features x n_g), or a 2-D array
    %            (features x n_g) that is treated as a single sample.
    %   output - batch x out_dim x n_g activation of the last layer.
    %   cache  - (nLayers+1) x 1 cell; cache{1} is the (reshaped) input and
    %            cache{l+1} is the activation produced by layer l.
    %
    %   NOTE(review): obj.G.starG(W, A) is assumed to return an
    %   out_dim x 1 x n_g array for W (out_dim x in_dim x n_g) and
    %   A (in_dim x 1 x n_g) -- confirm against the G implementation.
    nLayers = length(obj.weights);

    % Handle dimensions
    if ndims(X) == 2
        batch_size = 1;
        n_feat = size(X, 1);
        X = reshape(X, [1, n_feat, obj.G.n]);
    else
        batch_size = size(X, 1);
    end

    cache = cell(nLayers + 1, 1);
    cache{1} = X;
    A = X;
    for l = 1:nLayers
        W = obj.weights{l};
        b = obj.biases{l};
        [out_dim, in_dim, n_g] = size(W);
        Z = zeros(batch_size, out_dim, n_g);
        for i = 1:batch_size
            % A(i,:,:) is 1 x in_dim x n_g, so a plain reshape yields the
            % in_dim x 1 x n_g column without touching the values. The
            % earlier squeeze + ' path conjugated complex activations
            % whenever in_dim == 1 or n_g == 1.
            A_i_3d = reshape(A(i, :, :), [in_dim, 1, n_g]);
            Z_i = obj.G.starG(W, A_i_3d) + b;
            Z(i, :, :) = reshape(Z_i, [1, out_dim, n_g]);
        end
        A = obj.activations{l}(Z);
        cache{l + 1} = A;
    end
    output = A;
end
function y = invariantPool(obj, X)
    % INVARIANTPOOL  Average X over its third dimension, then over its
    % second, leaving one value per row of the first dimension.
    overThird = mean(X, 3);
    y = mean(overThird, 2);
end
function [y_pred, cache] = predict(obj, X)
    % PREDICT  Forward pass followed by invariant pooling.
    %   y_pred - pooled output with singleton dimensions squeezed away
    %   cache  - the cache returned by forward
    [netOut, cache] = forward(obj, X);
    pooled = invariantPool(obj, netOut);
    y_pred = squeeze(pooled);
end
function loss = computeLoss(obj, y_pred, y_true)
    % COMPUTELOSS  Mean squared error between predictions and targets.
    % Both inputs are flattened to columns first, so row/column
    % orientation does not matter.
    residual = y_pred(:) - y_true(:);
    loss = mean(residual .^ 2);
end
function grads = backward(obj, X, y_true)
    % BACKWARD  Stochastic central-difference gradient estimate.
    %
    %   grads = backward(obj, X, y_true) returns a struct with cell fields
    %   'weights' and 'biases' (one entry per layer, same shapes as the
    %   parameters). For each tensor at most maxSamples randomly chosen
    %   entries are probed with +/- epsilon; unprobed entries stay zero and
    %   the result is scaled by numel / n_sample.
    %
    %   Bias gradients are now estimated as well (they were previously
    %   left at zero). Parameters are temporarily perturbed on the handle
    %   object and restored afterwards, also when predict/computeLoss
    %   throws.
    epsilon = 1e-5;
    maxSamples = 30;
    nLayers = length(obj.weights);
    grads.weights = cell(nLayers, 1);
    grads.biases = cell(nLayers, 1);
    paramNames = {'weights', 'biases'};

    for l = 1:nLayers
        for k = 1:numel(paramNames)
            name = paramNames{k};
            P_orig = obj.(name){l};
            g = zeros(size(P_orig));
            n_total = numel(P_orig);
            n_sample = min(maxSamples, n_total);
            sample_idx = randperm(n_total, n_sample);
            try
                for idx = sample_idx
                    obj.(name){l}(idx) = P_orig(idx) + epsilon;
                    loss_plus = obj.computeLoss(obj.predict(X), y_true);
                    obj.(name){l}(idx) = P_orig(idx) - epsilon;
                    loss_minus = obj.computeLoss(obj.predict(X), y_true);
                    g(idx) = (loss_plus - loss_minus) / (2 * epsilon);
                    obj.(name){l}(idx) = P_orig(idx);
                end
            catch err
                % obj is a handle: never leave a perturbed entry behind
                obj.(name){l} = P_orig;
                rethrow(err);
            end
            if n_sample > 0
                % Compensate for probing only a subset of the entries
                g = g * (n_total / n_sample);
            end
            grads.(name){l} = g;
        end
    end
end
function obj = adamUpdate(obj, grads)
    % ADAMUPDATE  Apply one Adam step to the weights and biases.
    %
    %   grads.weights{l} - gradient for obj.weights{l} (required)
    %   grads.biases{l}  - gradient for obj.biases{l} (optional; layers
    %                      without a bias gradient are left untouched)
    %
    %   Increments obj.t, updates the moment buffers with bias correction,
    %   then multiplies each weight tensor by (1 - obj.weightDecay).
    %   Biases are not decayed. Previously biases were never updated.
    beta1 = 0.9;
    beta2 = 0.999;
    stab = 1e-8; % denominator guard; not named eps to avoid shadowing the builtin
    obj.t = obj.t + 1;
    corr1 = 1 - beta1^obj.t;
    corr2 = 1 - beta2^obj.t;
    hasBiasGrads = isfield(grads, 'biases');

    for l = 1:length(obj.weights)
        gW = grads.weights{l};
        obj.m_weights{l} = beta1 * obj.m_weights{l} + (1 - beta1) * gW;
        obj.v_weights{l} = beta2 * obj.v_weights{l} + (1 - beta2) * (gW.^2);
        m_hat = obj.m_weights{l} / corr1;
        v_hat = obj.v_weights{l} / corr2;
        obj.weights{l} = obj.weights{l} - obj.learningRate * m_hat ./ (sqrt(v_hat) + stab);
        obj.weights{l} = obj.weights{l} * (1 - obj.weightDecay);

        if hasBiasGrads && numel(grads.biases) >= l && ~isempty(grads.biases{l})
            gB = grads.biases{l};
            obj.m_biases{l} = beta1 * obj.m_biases{l} + (1 - beta1) * gB;
            obj.v_biases{l} = beta2 * obj.v_biases{l} + (1 - beta2) * (gB.^2);
            mb_hat = obj.m_biases{l} / corr1;
            vb_hat = obj.v_biases{l} / corr2;
            obj.biases{l} = obj.biases{l} - obj.learningRate * mb_hat ./ (sqrt(vb_hat) + stab);
        end
    end
end
function obj = train(obj, X_train, Y_train, X_val, Y_val, varargin)
    % TRAIN  Mini-batch training with early stopping on validation loss.
    %
    %   obj = train(obj, X_train, Y_train, X_val, Y_val, Name, Value, ...)
    %
    %   X_train, X_val - inputs indexed as X(sample, :, :)
    %   Y_train, Y_val - target vectors (row or column)
    %
    %   Name-value options:
    %     'epochs'    - maximum number of epochs (default 100)
    %     'batchSize' - samples per mini-batch (default 32)
    %     'verbose'   - print progress every 5 epochs (default true)
    %     'patience'  - epochs without validation-loss improvement before
    %                   stopping (default 20)
    %
    %   Fills obj.trainLoss, obj.valLoss, obj.trainR2 and obj.valR2 (each
    %   epochs x 1; entries after an early stop stay zero) and restores the
    %   weights/biases with the lowest validation loss before returning.
    p = inputParser;
    addParameter(p, 'epochs', 100);
    addParameter(p, 'batchSize', 32);
    addParameter(p, 'verbose', true);
    addParameter(p, 'patience', 20);
    parse(p, varargin{:});
    epochs = p.Results.epochs;
    batchSize = p.Results.batchSize;
    verbose = p.Results.verbose;
    patience = p.Results.patience;

    n_train = size(X_train, 1);
    n_batches = ceil(n_train / batchSize);
    obj.trainLoss = zeros(epochs, 1);
    obj.valLoss = zeros(epochs, 1);
    obj.trainR2 = zeros(epochs, 1);
    obj.valR2 = zeros(epochs, 1);

    % R^2 on flattened vectors: without (:) a row-vector target minus a
    % column prediction expands to an N x N matrix.
    r2score = @(pred, truth) 1 - sum((pred(:) - truth(:)).^2) / ...
        (sum((truth(:) - mean(truth(:))).^2) + 1e-10);

    best_val_loss = inf;
    patience_counter = 0;
    best_weights = obj.weights;
    best_biases = obj.biases;
    epochsRun = 0; % stays 0 when epochs == 0 so the summary print is skipped

    if verbose
        fprintf('\nTraining Neural Star_G Network\n');
        fprintf('Epochs: %d, Batch: %d, LR: %.4f\n\n', epochs, batchSize, obj.learningRate);
    end

    for epoch = 1:epochs
        epochsRun = epoch;
        timer = tic; % explicit handle: immune to tic calls made elsewhere
        perm = randperm(n_train);
        epoch_loss = 0;
        for batch = 1:n_batches
            batch_start = (batch - 1) * batchSize + 1;
            batch_end = min(batch * batchSize, n_train);
            batch_idx = perm(batch_start:batch_end);
            X_batch = X_train(batch_idx, :, :);
            Y_batch = Y_train(batch_idx);
            y_pred = obj.predict(X_batch);
            batch_loss = obj.computeLoss(y_pred, Y_batch);
            epoch_loss = epoch_loss + batch_loss;
            grads = obj.backward(X_batch, Y_batch);
            obj = obj.adamUpdate(grads);
        end
        epoch_time = toc(timer);

        obj.trainLoss(epoch) = epoch_loss / n_batches;
        y_val_pred = obj.predict(X_val);
        obj.valLoss(epoch) = obj.computeLoss(y_val_pred, Y_val);
        y_train_pred = obj.predict(X_train);
        obj.trainR2(epoch) = r2score(y_train_pred, Y_train);
        obj.valR2(epoch) = r2score(y_val_pred, Y_val);

        % Track the best parameters seen so far by validation loss
        if obj.valLoss(epoch) < best_val_loss
            best_val_loss = obj.valLoss(epoch);
            best_weights = obj.weights;
            best_biases = obj.biases;
            patience_counter = 0;
        else
            patience_counter = patience_counter + 1;
        end

        if verbose && mod(epoch, 5) == 0
            fprintf('Epoch %3d: Loss=%.4f, R2=%.4f, Val R2=%.4f (%.1fs)\n', ...
                epoch, obj.trainLoss(epoch), obj.trainR2(epoch), obj.valR2(epoch), epoch_time);
        end

        if patience_counter >= patience
            if verbose
                fprintf('Early stopping at epoch %d\n', epoch);
            end
            break;
        end
    end

    obj.weights = best_weights;
    obj.biases = best_biases;
    if verbose && epochsRun > 0
        fprintf('\nBest Val R2: %.4f\n\n', max(obj.valR2(1:epochsRun)));
    end
end
function obj = compressWeights(obj, rank)
    % COMPRESSWEIGHTS  Replace every weight tensor by obj.G.truncate(W, rank)
    % and print the relative Frobenius error of each replacement.
    fprintf('Compressing to rank %d...\n', rank);
    layerCount = length(obj.weights);
    layer = 0;
    while layer < layerCount
        layer = layer + 1;
        dense = obj.weights{layer};
        reduced = obj.G.truncate(dense, rank);
        delta = dense(:) - reduced(:);
        % 1e-10 keeps the ratio finite for an all-zero tensor
        relErr = norm(delta) / (norm(dense(:)) + 1e-10);
        obj.weights{layer} = reduced;
        fprintf('Layer %d: %.2f%% error\n', layer, relErr * 100);
    end
end
end
end