%% starG_helpers.m
%% Utility functions for ★_G comparison experiments
%% LH & SU & Claude 2026

classdef starG_helpers
    % STARG_HELPERS  Static utility functions for the star_G comparison experiments.
    %   All methods are static; call them as starG_helpers.<name>(...).

    methods (Static)

        function r2 = computeR2(y_pred, y_true)
            % COMPUTER2  Coefficient of determination R^2 of predictions.
            %   r2 = starG_helpers.computeR2(y_pred, y_true)
            %   Both inputs are flattened to column vectors and must have the
            %   same number of elements.
            %   Returns 1 - SS_res / SS_tot.
            y_pred = y_pred(:);
            y_true = y_true(:);
            ss_res = sum((y_pred - y_true).^2);
            % The small epsilon avoids division by zero when y_true is constant.
            % In that case any nonzero residual yields a very negative R^2.
            ss_tot = sum((y_true - mean(y_true)).^2) + 1e-10;
            r2 = 1 - ss_res / ss_tot;
        end

        function n_params = countMLPParams(layers)
            % COUNTMLPPARAMS  Number of weights plus biases in a fully connected MLP.
            %   n_params = starG_helpers.countMLPParams(layers)
            %   layers is a vector of layer widths [in, hidden..., out].
            %   Each consecutive pair (a, b) contributes a*b weights and b biases.
            %   Fewer than two layers gives 0.
            n_params = 0;
            for l = 1:length(layers)-1
                n_params = n_params + layers(l) * layers(l+1) + layers(l+1);
            end
        end

        function X_inv = computeInvariantFeatures(X)
            % COMPUTEINVARIANTFEATURES  Pool features over the 3rd dimension.
            %   X_inv = starG_helpers.computeInvariantFeatures(X)
            %   Reduces X along dim 3 with mean, std (N-1 normalisation), min and max.
            %   It concatenates the four results along dim 2.
            %   For X of size (n_samples x n_feat x n_rot), as produced by
            %   generateMolecularData, the result is n_samples x 4*n_feat.
            %   Its column blocks are [mean, std, min, max].
            X_mean = mean(X, 3);
            X_std = std(X, 0, 3);
            X_min = min(X, [], 3);
            X_max = max(X, [], 3);
            X_inv = [X_mean, X_std, X_min, X_max];
        end

        function [X, Y] = generateMolecularData(n_samples, n_feat, n_rot)
            % GENERATEMOLECULARDATA  Synthetic point-cloud "molecules" with rotated views.
            %   [X, Y] = starG_helpers.generateMolecularData(n_samples, n_feat, n_rot)
            %
            %   Each sample is a random cloud of 4-10 atoms, centred at the origin.
            %   Atom types Z are drawn from 1..9, with values 2..5 remapped to 6.
            %   The target Y(i) is computed from interatomic distances and Z before
            %   any rotation, so it does not depend on the rotation index.
            %
            %   X(i, :, g) holds n_feat features of the cloud rotated about the
            %   z-axis by 2*pi*(g-1)/n_rot:
            %     1-3 : sum_a Z_a * (x_a, y_a, z_a)
            %     4   : sum_a Z_a * r_a
            %     5   : sum_a Z_a * r_a^2
            %     5+k : sum_a Z_a * cos(k * atan2(y_a, x_a)) * exp(-r_a^2 / 8)
            %   Features 3-5 are unchanged by z-rotations; 1, 2 and 5+k are not.
            %   If n_feat < 5, only the first n_feat of the features listed above
            %   are returned.
            %
            %   Outputs: X is n_samples x n_feat x n_rot, Y is n_samples x 1.
            X = zeros(n_samples, n_feat, n_rot);
            Y = zeros(n_samples, 1);
            n_harm = max(n_feat - 5, 0);   % number of angular-harmonic features
            n_buf = max(n_feat, 5);        % room for the 5 base features

            for i = 1:n_samples
                % The RNG draw order is kept fixed so seeded runs stay reproducible.
                n_atoms = randi([4, 10]);
                pos = randn(n_atoms, 3) * 2;
                pos = pos - mean(pos, 1);
                Z = randi([1, 9], n_atoms, 1);
                Z(Z > 1 & Z < 6) = 6;

                % Pairwise Euclidean distances are computed without the toolbox
                % function pdist. The strict lower triangle, taken in column-major
                % order, matches pdist's (1,2),(1,3),...,(2,3),... ordering.
                diffs = permute(pos, [1 3 2]) - permute(pos, [3 1 2]);
                D = sqrt(sum(diffs.^2, 3));
                dists = D(tril(true(n_atoms), -1));
                if isempty(dists), dists = 1; end
                Y(i) = mean(dists) + 0.3 * std(dists) + sum(Z.^2) / 500;

                for g = 1:n_rot
                    theta = 2 * pi * (g - 1) / n_rot;
                    R = [cos(theta), -sin(theta), 0;
                         sin(theta),  cos(theta), 0;
                         0,           0,          1];
                    pos_rot = (R * pos')';

                    r = sqrt(sum(pos_rot.^2, 2));            % per-atom radius
                    phi = atan2(pos_rot(:, 2), pos_rot(:, 1)); % azimuthal angle
                    radial_w = Z .* exp(-r.^2 / 8);

                    % The buffer always holds the 5 base features.
                    % It is truncated to n_feat below, so n_feat < 5 no longer errors.
                    feat = zeros(1, n_buf);
                    feat(1:3) = Z' * pos_rot;
                    feat(4) = Z' * r;
                    feat(5) = Z' * r.^2;
                    for k = 1:n_harm
                        feat(5 + k) = sum(radial_w .* cos(k * phi));
                    end
                    X(i, :, g) = feat(1:n_feat);
                end
            end
        end

        function x = roundp(x, p)
            % ROUNDP  Round x to p decimal places (elementwise).
            %   x = starG_helpers.roundp(x, p)
            %   Negative p rounds to tens, hundreds and so on.
            scale = 10^p;
            x = round(x * scale) / scale;
        end

    end
end