%% run_tests.m - StarGAlgebra Test Suite
%% Run with: >> run_tests
%% LH & SU & Claude 2026
% Print the suite banner.
fprintf('\n');
fprintf('################################################################\n');
fprintf('# StarGAlgebra Test Suite                                     #\n');
fprintf('################################################################\n\n');

% Check GPU
% hasGPU gates the GPU section further down. Detection is wrapped in
% try/catch so that a missing GPU toolbox or a driver problem does not
% abort the CPU correctness tests; in that case the suite simply runs
% without the GPU section.
hasGPU = false;
try
    if gpuDeviceCount > 0
        gpu = gpuDevice;
        fprintf('GPU: %s (%.1f GB)\n\n', gpu.Name, gpu.TotalMemory/1e9);
        hasGPU = true;   % only set once the device was actually queried
    end
catch gpuErr
    % Report why detection failed, then fall through to the no-GPU path.
    fprintf('GPU check failed: %s\n', gpuErr.message);
end
if ~hasGPU
    fprintf('No GPU detected\n\n');
end

%% Correctness Tests
% Builds one StarGAlgebra per entry of `groups`, calls runAllTests on it and
% records the outcome in `results` (column 1: display name, column 2:
% 'PASS' or 'FAIL'). A group is marked PASS when construction and
% runAllTests complete without raising an error, FAIL otherwise.
fprintf([ ...
    '============================================================\n', ...
    '                    CORRECTNESS TESTS                       \n', ...
    '============================================================\n']);

% Each entry: {group type, constructor parameter ([] = none), display name}
groups = {
    {'cyclic', 4, 'Z_4'}
    {'cyclic', 8, 'Z_8'}
    {'klein4', [], 'Klein-4'}
    {'dihedral', 3, 'D_3'}
    {'symmetric', 3, 'S_3'}
    {'quaternion', [], 'Q_8'}
};

nGroups = numel(groups);
results = cell(nGroups, 2);

for k = 1:nGroups
    [gtype, gparam, gname] = groups{k}{:};

    fprintf('\n,  %s , \n', gname);

    % Constructor arguments: the parameter is passed only when one is given.
    ctorArgs = {gtype};
    if ~isempty(gparam)
        ctorArgs{end+1} = gparam; %#ok<SAGROW>
    end

    outcome = 'PASS';
    try
        G = StarGAlgebra(ctorArgs{:});
        G.runAllTests();
    catch ME
        % Any error from construction or from the tests marks the group failed.
        fprintf('ERROR: %s\n', ME.message);
        outcome = 'FAIL';
    end

    results(k, :) = {gname, outcome};
end

%% GPU Tests
% Runs only when a GPU was detected (hasGPU). Exercises the GPU-enabled
% StarGAlgebra for the cyclic group of order 8: its own test suite, a
% CPU-vs-GPU timing table for starG, and the class's benchmark method.
% The whole section is guarded so that a GPU failure is reported and
% recorded in `results` instead of aborting the script before the summary.
if hasGPU
    fprintf('\n============================================================\n');
    fprintf('                      GPU TESTS                             \n');
    fprintf('============================================================\n\n');

    gpuOutcome = 'PASS';
    try
        G = StarGAlgebra('cyclic', 8);
        G_gpu = G.enableGPU();
        dev = gpuDevice;   % currently selected device, used to synchronise timings

        fprintf('\n,  GPU Correctness , \n');
        G_gpu.runAllTests();

        fprintf('\n,  CPU vs GPU Performance , \n');
        sizes = [20, 50, 100, 200];
        nReps = 5;         % timed repetitions per size

        fprintf('%-10s %-12s %-12s %-10s\n', 'Size', 'CPU (s)', 'GPU (s)', 'Speedup');
        fprintf('%s\n', repmat('-', 1, 45));

        for sz = sizes
            A = randn(sz, sz, G.n);
            B = randn(sz, sz, G.n);

            % Warm-up calls so one-time setup cost is not part of the timing.
            G.starG(A, B);
            G_gpu.starG(A, B);
            wait(dev);     % make sure warm-up GPU work has finished

            tic;
            for r = 1:nReps
                G.starG(A, B);
            end
            t_cpu = toc/nReps;

            tic;
            for r = 1:nReps
                G_gpu.starG(A, B);
            end
            % GPU calls can return before the device has finished; block
            % until it is idle so the elapsed time covers the actual work.
            wait(dev);
            t_gpu = toc/nReps;

            fprintf('%-10d %-12.4f %-12.4f %-10.2fx\n', sz, t_cpu, t_gpu, t_cpu/t_gpu);
        end

        fprintf('\n,  GPU Benchmark , \n');
        G_gpu.benchmark([50, 100, 200, 500]);
    catch ME
        fprintf('ERROR: %s\n', ME.message);
        gpuOutcome = 'FAIL';
    end

    % Record the GPU section alongside the per-group results so it shows up
    % in the summary table.
    results(end+1, :) = {'Z_8 (GPU)', gpuOutcome};
end

%% Summary
% Prints the results table (one line per row of `results`: name, then
% outcome) followed by the closing banner.
fprintf([ ...
    '\n============================================================\n', ...
    '                        SUMMARY                             \n', ...
    '============================================================\n\n']);

rowFormat = '%-15s %s\n';
fprintf(rowFormat, 'Group', 'Result');
divider = repmat('-', 1, 25);
fprintf('%s\n', divider);

nRows = size(results, 1);
for row = 1:nRows
    entry = results(row, :);
    fprintf(rowFormat, entry{:});
end

fprintf([ ...
    '\n################################################################\n', ...
    '# Done                                                        #\n', ...
    '################################################################\n']);