function [feat, norm_params] = extractStarGFeatures(X, G, n_feat_hint, norm_params)
% extractStarGFeatures  Invariant features via a generalized Fourier transform.
%
%   [feat, norm_params] = extractStarGFeatures(X, G, n_feat_hint)
%   Training mode: detects the feature layout from X and fits a z-score
%   normalization, both returned in norm_params.
%
%   [feat, norm_params] = extractStarGFeatures(X, G, n_feat_hint, norm_params)
%   Apply mode: reuses the layout and normalization from a training call.
%
%   Inputs
%     X            n_samples x n_f x n_g array; sample i is treated as the
%                  n_f x n_g matrix reshape(X(i,:,:), n_f, n_g).
%     G            struct with fields
%                    F          n_g x n_g matrix; the transform is
%                               Xhat(i,j,:) = X(i,j,:) * F.
%                    starG_SVD  function handle, called as
%                               [~, S, ~] = G.starG_SVD(T) on a p x q x n_g
%                               tensor T; only the diagonal tubes S(k,k,:)
%                               are used.
%     n_feat_hint  currently unused (kept for interface compatibility).
%     norm_params  optional struct returned by a previous training call.
%
%   Outputs
%     feat         n_samples x (1 + nnz(keep)) matrix: a column of ones
%                  (intercept) followed by the z-scored kept features.
%     norm_params  struct with layout fields p, q, n_svd, inv_mask, n_inv,
%                  eq_idx, K_rows and normalization fields keep, mu, sig.
%
%   Raw feature columns, in order:
%     A. dc_all         n_f         mean over the group dim
%     B. ac_all         n_f         std over the group dim
%     C. col_power      n_g         |Xhat|^2 summed over the n_f rows
%     D. per_row_power  K_rows*n_g  |Xhat|^2 of the first K_rows non-invariant rows
%     E. svd_feat       n_svd       star_G-SVD diagonal tube norms, descending
%     F. inv_feat       n_inv       invariant rows (their value at group index 1)
%     G. stats          4           nuclear norm, spectral norm, condition, entropy
%
%   PERFORMANCE: sections A-D and F are vectorized over samples; sections
%   E and G (both SVD-based) require a per-sample loop.
%
%   LH & Claude 2026
[n_samples, n_f, n_g] = size(X);
% Upper bound on the number of non-invariant rows used for per-row power
% (large enough to include coupled rows for product groups).
MAX_EQ_ROWS = 14;
% ================================================================
% LAYOUT: determine (training) or retrieve from norm_params (apply)
% ================================================================
if nargin < 4 || isempty(norm_params)
    is_training = true;
    [p, q] = best_reshape(n_f);
    n_svd = min(p, q);
    % The layout is detected from the first training sample only.
    sample1 = reshape(X(1, :, :), n_f, n_g);
    row_var = var(sample1, 0, 2);
    % A row counts as invariant when its variance over the group dim is
    % negligible relative to the most variable row.
    inv_mask = row_var < 1e-8 * (max(row_var) + 1e-20);
    n_inv = sum(inv_mask);
    eq_idx = find(~inv_mask);
    K_rows = min(MAX_EQ_ROWS, numel(eq_idx));
else
    is_training = false;
    p = norm_params.p; q = norm_params.q; n_svd = norm_params.n_svd;
    inv_mask = norm_params.inv_mask; n_inv = norm_params.n_inv;
    eq_idx = norm_params.eq_idx; K_rows = norm_params.K_rows;
end
F_mat = G.F; % n_g x n_g
% ================================================================
% VECTORIZED FEATURES (no per-sample loop)
% ================================================================
% NOTE: reshape (not squeeze) is used throughout so that singleton
% dimensions (e.g. n_samples == 1) never flip row/column orientation.
% A. DC COMPONENT: mean across group dim [n_samples x n_f]
dc_all = mean(X, 3);
% B. AC ENERGY: std across group dim [n_samples x n_f]
ac_all = std(X, 0, 3);
% C+D. Generalized Fourier transform of all samples at once:
% Xhat(i,j,k) = sum_g X(i,j,g) * F(g,k), computed as one matrix product
% on the (n_samples*n_f) x n_g unfolding.
Xhat = reshape(reshape(X, [], n_g) * F_mat, n_samples, n_f, n_g);
abs2 = abs(Xhat).^2;   % real-valued power, reused by C and D
% C. Total per-frequency power [n_samples x n_g]
col_power = reshape(sum(abs2, 2), n_samples, n_g);
% D. Per-row Fourier power for the first K_rows non-invariant rows
%    [n_samples x K_rows*n_g]
per_row_power = zeros(n_samples, K_rows * n_g);
for kr = 1:K_rows
    cols = (kr-1)*n_g + 1 : kr*n_g;
    per_row_power(:, cols) = reshape(abs2(:, eq_idx(kr), :), n_samples, n_g);
end
% F. Direct invariant features [n_samples x n_inv]
inv_feat = reshape(X(:, inv_mask, 1), n_samples, n_inv);
% ================================================================
% PER-SAMPLE LOOP (SVD-based features E and G)
% ================================================================
svd_feat = zeros(n_samples, n_svd);
stats = zeros(n_samples, 4);
n_cells = p * q;
for i = 1:n_samples
    Xi = reshape(X(i, :, :), n_f, n_g);
    % E. Star_G-SVD tube norms: zero-pad (or truncate) rows to p*q, fold
    %    into a p x q x n_g tensor, and take the 2-norm of each diagonal tube.
    Xi_padded = Xi;
    if n_f < n_cells
        Xi_padded = [Xi; zeros(n_cells - n_f, n_g)];
    end
    Xi_tensor = reshape(Xi_padded(1:n_cells, :), [p, q, n_g]);
    [~, S_tensor, ~] = G.starG_SVD(Xi_tensor);
    tn = zeros(n_svd, 1);
    for k = 1:n_svd
        tn(k) = norm(reshape(S_tensor(k, k, :), [], 1));
    end
    svd_feat(i, :) = sort(tn, 'descend').';
    % G. Compact statistics from the ordinary singular values of Xi
    sv = svd(Xi);
    stats(i, 1) = sum(sv); % nuclear norm
    stats(i, 2) = sv(1);   % spectral norm
    sv_nz = sv(sv > 1e-10);
    if isempty(sv_nz)
        % Numerically zero sample: the condition number is undefined.
        % Report 0, consistent with the NaN/Inf -> 0 rule applied below.
        stats(i, 3) = 0;
    else
        stats(i, 3) = sv(1) / (sv_nz(end) + 1e-20); % condition
    end
    sv_p = sv_nz / (sum(sv_nz) + 1e-20);
    stats(i, 4) = -sum(sv_p .* log(sv_p + 1e-20)); % entropy (0 if sv_nz empty)
end
% ================================================================
% ASSEMBLE
% ================================================================
feat = [dc_all, ac_all, col_power, per_row_power, svd_feat, inv_feat, stats];
feat(isnan(feat)) = 0;
feat(isinf(feat)) = 0;
% ================================================================
% NORMALIZATION
% ================================================================
if is_training
    % Drop (near-)constant columns, then z-score the rest.
    raw_std = std(feat, 0, 1);
    keep = raw_std >= 1e-8;
    norm_params = struct();
    norm_params.p = p; norm_params.q = q; norm_params.n_svd = n_svd;
    norm_params.inv_mask = inv_mask; norm_params.n_inv = n_inv;
    norm_params.eq_idx = eq_idx; norm_params.K_rows = K_rows;
    norm_params.keep = keep;
    norm_params.mu = mean(feat(:, keep), 1);
    norm_params.sig = std(feat(:, keep), 0, 1);
    norm_params.sig(norm_params.sig < 1e-10) = 1;
end
feat = feat(:, norm_params.keep);
feat = (feat - norm_params.mu) ./ norm_params.sig;
% Intercept column
feat = [ones(n_samples, 1), feat];
end
function [p, q] = best_reshape(n_f)
% best_reshape  Pick a p-by-q grid (p*q <= n_f) that maximizes min(p, q).
%
%   Candidate heights are p = 2 .. floor(sqrt(2*n_f)), each paired with
%   q = floor(n_f / p); candidates with q < 1 are discarded. Among the
%   remaining candidates the one with the largest min(p, q) wins, and on
%   ties the smallest p wins. With no candidate, returns p = n_f, q = 1.
    p = n_f;
    q = 1;
    p_cand = 2:floor(sqrt(2 * n_f));
    q_cand = floor(n_f ./ p_cand);
    valid = q_cand >= 1;
    p_cand = p_cand(valid);
    q_cand = q_cand(valid);
    if isempty(p_cand)
        return;
    end
    % max returns the first index of the maximum -> smallest p on ties.
    [~, best] = max(min(p_cand, q_cand));
    p = p_cand(best);
    q = q_cand(best);
end