function [feat, norm_params] = extractStarGFeatures(X, G, n_feat_hint, norm_params)
%% extractStarGFeatures - Invariant features via generalized Fourier
%%
%% [feat, norm_params] = extractStarGFeatures(X, G, n_feat_hint)
%%     Training mode: infers the feature layout from X (using the first
%%     sample), computes the features and fits a z-score normalization.
%% [feat, norm_params] = extractStarGFeatures(X, G, n_feat_hint, norm_params)
%%     Inference mode: reuses the layout and normalization stored in
%%     norm_params; X must have the same number of features (dim 2).
%%
%% Inputs
%%   X            n_samples x n_f x n_g array (one n_f x n_g matrix per sample)
%%   G            struct/object with
%%                  G.F          transform matrix right-multiplied along the
%%                               group dimension (expected n_g x n_g)
%%                  G.starG_SVD  called as [~, S, ~] = G.starG_SVD(T) on a
%%                               p x q x n_g tensor; only S(k,k,:) is read
%%   n_feat_hint  currently unused; kept for interface compatibility
%%   norm_params  (optional) struct returned by a previous training call
%%
%% Outputs
%%   feat         n_samples x (1 + n_kept): a column of ones (intercept)
%%                followed by the kept, z-scored features
%%   norm_params  struct with fields p, q, n_svd, inv_mask, n_inv, eq_idx,
%%                K_rows, keep, mu, sig
%%
%% Raw feature order (before the `keep` column selection):
%%   dc [n_f] | ac [n_f] | col_power [n_g] | per_row_power [K_rows*n_g] |
%%   svd_feat [n_svd] | inv_feat [n_inv] | stats [4]
%%
%% PERFORMANCE: sections A-D and F are vectorized (no per-sample loop).
%% Only the SVD-based sections E and G loop over samples.
%%
%% LH & Claude 2026

    [n_samples, n_f, n_g] = size(X);

    % ================================================================
    % LAYOUT: determine or retrieve from norm_params
    % ================================================================
    if nargin < 4 || isempty(norm_params)
        is_training = true;
        [p, q] = best_reshape(n_f);
        n_svd = min(p, q);

        % Rows that (in the first sample) do not vary along the group
        % dimension are treated as invariant; the others as equivariant.
        % reshape (not squeeze) keeps the n_f x n_g orientation for any
        % singleton dimension.
        sample1 = reshape(X(1, :, :), n_f, n_g);
        row_var = var(sample1, 0, 2);
        inv_mask = row_var < 1e-8 * (max(row_var) + 1e-20);
        n_inv = sum(inv_mask);
        eq_idx = find(~inv_mask);
        K_rows = min(14, length(eq_idx));  % include coupled rows for product groups
    else
        is_training = false;
        p = norm_params.p; q = norm_params.q; n_svd = norm_params.n_svd;
        inv_mask = norm_params.inv_mask; n_inv = norm_params.n_inv;
        eq_idx = norm_params.eq_idx; K_rows = norm_params.K_rows;
        % A logical mask shorter than n_f would silently select a different
        % column set, so reject a mismatched layout up front.
        if numel(inv_mask) ~= n_f
            error('extractStarGFeatures:layoutMismatch', ...
                  'X has %d features but norm_params was fit on %d.', ...
                  n_f, numel(inv_mask));
        end
    end

    F_mat = G.F;  % n_g x n_g

    % ================================================================
    % VECTORIZED FEATURES (no per-sample loop)
    % ================================================================

    % A. DC COMPONENT: mean across group dim  [n_samples x n_f]
    dc_all = mean(X, 3);

    % B. AC ENERGY: std across group dim  [n_samples x n_f]
    ac_all = std(X, 0, 3);

    % C+D. Generalized Fourier transform of all samples at once:
    %   Xhat(i,j,k) = sum_g X(i,j,g) * F(g,k)
    % Flatten to (n_samples*n_f) x n_g, multiply by F, restore the shape.
    Xhat = reshape(reshape(X, [], n_g) * F_mat, n_samples, n_f, n_g);
    Xpow = abs(Xhat).^2;   % real-valued power, n_samples x n_f x n_g

    % C. Total per-frequency power, summed over features  [n_samples x n_g]
    %    reshape instead of squeeze: squeeze turns 1x1xn_g into n_g x 1,
    %    which broke concatenation when n_samples == 1.
    col_power = reshape(sum(Xpow, 2), n_samples, n_g);

    % D. Per-row Fourier power for the first K_rows equivariant rows
    %    [n_samples x K_rows*n_g]; columns are grouped per row, n_g
    %    frequencies each: (kr-1)*n_g + k.
    row_pow = Xpow(:, eq_idx(1:K_rows), :);                 % n_samples x K x n_g
    per_row_power = reshape(permute(row_pow, [1 3 2]), n_samples, K_rows * n_g);

    % F. Direct invariant features: value at the first group index
    %    [n_samples x n_inv]
    inv_feat = X(:, inv_mask, 1);

    % ================================================================
    % PER-SAMPLE LOOP (only for SVD-based features)
    % ================================================================
    svd_feat = zeros(n_samples, n_svd);
    stats = zeros(n_samples, 4);   % nuclear, spectral, condition, entropy
    n_used = p * q;

    for i = 1:n_samples
        Xi = reshape(X(i, :, :), n_f, n_g);

        % E. Star_G-SVD tube norms: reshape the first p*q feature rows
        %    (zero-padded if n_f is smaller) into a p x q x n_g tensor.
        if n_f < n_used
            Xi_pad = [Xi; zeros(n_used - n_f, n_g)];
        else
            Xi_pad = Xi(1:n_used, :);
        end
        [~, S_tensor, ~] = G.starG_SVD(reshape(Xi_pad, [p, q, n_g]));
        tn = zeros(n_svd, 1);
        for k = 1:n_svd
            tn(k) = norm(squeeze(S_tensor(k, k, :)));
        end
        svd_feat(i, :) = sort(tn, 'descend')';

        % G. Compact statistics of the ordinary singular values of Xi
        sv = svd(Xi);
        stats(i, 1) = sum(sv);                              % nuclear norm
        stats(i, 2) = sv(1);                                % spectral norm
        sv_nz = sv(sv > 1e-10);
        if ~isempty(sv_nz)
            stats(i, 3) = sv(1) / (sv_nz(end) + 1e-20);    % condition
            sv_p = sv_nz / (sum(sv_nz) + 1e-20);
            stats(i, 4) = -sum(sv_p .* log(sv_p + 1e-20)); % entropy
        end
        % A numerically zero sample has no nonzero singular values; its
        % condition and entropy stay 0 instead of indexing an empty array.
    end

    % ================================================================
    % ASSEMBLE
    % ================================================================
    feat = [dc_all, ac_all, col_power, per_row_power, svd_feat, inv_feat, stats];
    feat(~isfinite(feat)) = 0;   % NaN and +/-Inf -> 0

    % ================================================================
    % NORMALIZATION
    % ================================================================
    if is_training
        raw_std = std(feat, 0, 1);
        keep = raw_std >= 1e-8;   % drop (near-)constant columns

        norm_params = struct();
        norm_params.p = p; norm_params.q = q; norm_params.n_svd = n_svd;
        norm_params.inv_mask = inv_mask; norm_params.n_inv = n_inv;
        norm_params.eq_idx = eq_idx; norm_params.K_rows = K_rows;
        norm_params.keep = keep;
        norm_params.mu = mean(feat(:, keep), 1);
        norm_params.sig = std(feat(:, keep), 0, 1);
        norm_params.sig(norm_params.sig < 1e-10) = 1;
    end

    feat = feat(:, norm_params.keep);
    feat = (feat - norm_params.mu) ./ norm_params.sig;

    % Intercept column
    feat = [ones(n_samples, 1), feat];
end


function [p, q] = best_reshape(n_f)
%% best_reshape - Choose a p x q grid (p*q <= n_f) for reshaping n_f features
%%
%% Maximizes min(p, q), the number of singular values of a p x q matrix.
%% Among grids with the same min(p, q), it prefers the one that covers
%% more features (larger p*q); on a full tie it keeps the first found.
%% The lossless n_f x 1 layout is the baseline candidate. Without that
%% baseline, n_f = 3 became 2 x 1 and silently dropped a feature with no
%% gain in min(p, q).
    p = n_f; q = 1;
    best_min = min(p, q);
    for pp = 2:floor(sqrt(n_f * 2))
        qq = floor(n_f / pp);
        if qq < 1, continue; end
        cand_min = min(pp, qq);
        if cand_min > best_min || (cand_min == best_min && pp * qq > p * q)
            best_min = cand_min; p = pp; q = qq;
        end
    end
end
