import numpy as np
import torch
import scipy.stats


class WassersteinICA:
    def __init__(self, X):
        self.X = X
        self.n = X.shape[1]
        self.whitened = False
        self.epsilon = 1e-7
        # Cache for the analytical target (computed once)
        self.analytical_target = None

    def whiten(self):
        """
        Whiten the data (zero mean, unit variance, uncorrelated).
        """
        X_centered = self.X - torch.mean(self.X, dim=1, keepdim=True)
        cov = torch.matmul(X_centered, X_centered.t()) / (self.n - 1)
        D, E = torch.linalg.eigh(cov)
        D_inv_sqrt = torch.diag(1.0 / torch.sqrt(D + 1e-5))
        self.W_white = torch.matmul(D_inv_sqrt, E.T)
        self.X_white = torch.matmul(self.W_white, X_centered)
        self.whitened = True
        # Pre-compute the exact analytical Gaussian target
        self.analytical_target = self._compute_analytical_target(self.n)

    def _compute_analytical_target(self, n):
        """
        Computes the EXACT 'Average Quantile' for each bin analytically.
        Formula: Target_i = N * (pdf(z_{i-1}) - pdf(z_i))
        This replaces 'sampling' with exact calculus.
        """
        p_edges = np.linspace(0, 1, n + 1)
        z_edges = scipy.stats.norm.ppf(p_edges)
        phi_edges = scipy.stats.norm.pdf(z_edges)
        target_np = n * (phi_edges[:-1] - phi_edges[1:])
        return torch.tensor(target_np, dtype=torch.float32, device=self.X.device)

    # ==========================================
    # VECTORIZED: Core Distance Metric
    # ==========================================
    def wasserstein2_analytical(self, W, cost='l2', dither_sigma=0.0):
        """
        Computes the W distance. Supports both single vectors (legacy) and
        matrices (batched parallel).
        cost: 'l2' for standard Wasserstein, 'logcosh' for robust Huber-like geometry.
        dither_sigma: injects continuous noise to smooth discrete CDF steps.
        """
        assert self.whitened, "Call whiten() before computing distance."
        is_1d = W.dim() == 1
        if is_1d:
            W = W.unsqueeze(0)
        Y = torch.mm(W, self.X_white)
        # DITHERING: inject continuous noise to break discrete ties and smooth the CDF
        if dither_sigma > 0:
            Y = Y + torch.randn_like(Y) * dither_sigma
        sorted_Y, _ = torch.sort(Y, dim=1)
        diff = sorted_Y - self.analytical_target
        if cost == 'l2':
            distances = torch.mean(diff ** 2, dim=1)
        elif cost == 'logcosh':
            # Numerically stable logcosh to prevent NaN gradients on massive outliers
            abs_diff = torch.abs(diff)
            logcosh_diff = abs_diff + torch.log1p(torch.exp(-2.0 * abs_diff)) - np.log(2.0)
            distances = torch.mean(logcosh_diff, dim=1)
        else:
            raise ValueError("cost must be 'l2' or 'logcosh'")
        if is_1d:
            return distances[0]
        return distances

    # ==========================================
    # VECTORIZED: Phase 1 (Deflation & Restarts)
    # ==========================================
    def optimize_wasserstein2(self, prev_components=None, grid_points=100,
                              continuous=True, max_iter=200, lr=0.1, n_restarts=50,
                              decay_rate=0.5, decay_step=50, cost='l2',
                              dither_sigma=0.0):
        """
        Find ONE maximizer of the W distance (deflationary).
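        Runs n_restarts random initializations in parallel as projected gradient
        ascent on the unit sphere, multiplying the learning rate by decay_rate
        every decay_step iterations. cost and dither_sigma are forwarded to
        wasserstein2_analytical; the winning vector is re-scored without
        dithering before it is returned.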
""" if continuous: W_batch = torch.randn(n_restarts, self.X.shape[0], device=self.X.device) if prev_components is not None and prev_components.shape[0] > 0: proj = torch.matmul(W_batch, prev_components.t()) W_batch = W_batch - torch.matmul(proj, prev_components) W_batch = W_batch / torch.norm(W_batch, dim=1, keepdim=True) W_batch.requires_grad_(True) current_lr = lr for i in range(max_iter): if (i + 1) % decay_step == 0: current_lr *= decay_rate # Pass the dither parameter down dist = self.wasserstein2_analytical(W_batch, cost=cost, dither_sigma=dither_sigma).sum() if W_batch.grad is not None: W_batch.grad.zero_() dist.backward() with torch.no_grad(): grad = W_batch.grad if prev_components is not None and prev_components.shape[0] > 0: proj_grad = torch.matmul(grad, prev_components.t()) grad = grad - torch.matmul(proj_grad, prev_components) dot_pw = torch.sum(grad * W_batch.data, dim=1, keepdim=True) grad = grad - dot_pw * W_batch.data grad_norms = torch.norm(grad, dim=1, keepdim=True) grad = torch.where(grad_norms > 1.0, grad / grad_norms, grad) W_batch.data += current_lr * grad if prev_components is not None and prev_components.shape[0] > 0: proj = torch.matmul(W_batch.data, prev_components.t()) W_batch.data = W_batch.data - torch.matmul(proj, prev_components) W_batch.data /= torch.norm(W_batch.data, dim=1, keepdim=True) with torch.no_grad(): # Evaluate final best vector without noise to get the true mathematical distance final_distances = self.wasserstein2_analytical(W_batch, cost=cost, dither_sigma=0.0) best_idx = torch.argmax(final_distances) best_w = W_batch[best_idx].detach().clone() best_dist = final_distances[best_idx].item() return best_w, best_dist else: # Legacy Discrete Grid Search angles = torch.linspace(0, 2 * np.pi, steps=grid_points, device=self.X.device) candidates = torch.stack([torch.cos(angles), torch.sin(angles)], dim=1) if prev_components is not None and prev_components.shape[0] > 0: proj = torch.matmul(candidates, prev_components.t()) candidates = candidates - torch.matmul(proj, prev_components) norms = torch.norm(candidates, dim=1, keepdim=True) mask = norms.squeeze() > 1e-6 candidates = candidates[mask] if candidates.shape[0] == 0: raise ValueError("No valid candidates.") candidates = candidates / norms[mask] dist_best = -np.inf w_best = None for w in candidates: d = self.wasserstein2_analytical(w, cost=cost, dither_sigma=dither_sigma).item() if d > dist_best: dist_best = d w_best = w return w_best, dist_best def _symmetric_decorrelation(self, W): M = torch.mm(W, W.t()) evals, evecs = torch.linalg.eigh(M) d_inv_sqrt = torch.diag(1.0 / torch.sqrt(evals + 1e-5)) inv_sqrt_M = torch.mm(torch.mm(evecs, d_inv_sqrt), evecs.t()) return torch.mm(inv_sqrt_M, W) # ========================================== # VECTORIZED: Phase 2 # ========================================== def optimize_symmetric(self, n_components=None, max_iter=300, lr=1.0, init_w=None, optimizer='sgd', penalty_weight=10.0, use_sinkhorn=False, reg=0.01, sinkhorn_iter=50, cost='l2', dither_sigma=0.0, batch_size=512, n_restarts=None): # Added n_restarts to signature if n_components is None: n_components = self.X.shape[0] # DYNAMIC RESTART LOGIC: Default to dims * 4, capped at 200 if n_restarts is None: n_restarts = n_components * 4 if n_restarts > 200: n_restarts = 200 elif n_restarts < 20: # Floor to ensure basic exploration n_restarts = 20 if init_w is not None: W = init_w.clone().to(self.X.device) else: # If no initial W is provided, generate a batch of random restarts # using the dynamically calculated 
            # n_restarts, decorrelate each candidate, and keep the best scorer
            # as the starting point for the optimizer below.
            W = None
            best_score = -float('inf')
            for _ in range(n_restarts):
                cand = torch.randn(n_components, self.X.shape[0], device=self.X.device)
                cand = self._symmetric_decorrelation(cand)
                score = self.wasserstein2_analytical(cand, cost=cost).sum().item()
                if score > best_score:
                    best_score = score
                    W = cand

        W.requires_grad_(True)

        # Store originals to safely patch during stochastic batching
        original_X_white = self.X_white
        original_n = self.n
        original_target = self.analytical_target

        if optimizer == 'sgd':
            for i in range(max_iter):
                if W.grad is not None:
                    W.grad.zero_()
                if use_sinkhorn:
                    total_dist = self.sinkhorn_distance(W, reg=reg, n_iter=sinkhorn_iter).sum()
                else:
                    total_dist = self.wasserstein2_analytical(
                        W, cost=cost, dither_sigma=dither_sigma).sum()
                loss = -total_dist
                loss.backward()
                with torch.no_grad():
                    grad = W.grad
                    grad_norms = torch.norm(grad, dim=1, keepdim=True)
                    grad = torch.where(grad_norms > 1.0, grad / grad_norms, grad)
                    # Descend the negative distance, i.e. ascend the W distance
                    W -= lr * grad
                    W.data = self._symmetric_decorrelation(W.data)
                W.requires_grad_(True)

        elif optimizer == 'stiefel':
            current_lr = lr
            for i in range(max_iter):
                if W.grad is not None:
                    W.grad.zero_()

                # STOCHASTIC BATCHING: randomly slice the data to inject gradient noise
                if batch_size is not None and batch_size < original_n:
                    indices = torch.randperm(original_n, device=self.X.device)[:batch_size]
                    self.X_white = original_X_white[:, indices]
                    self.n = batch_size
                    self.analytical_target = self._compute_analytical_target(self.n)

                if use_sinkhorn:
                    total_dist = self.sinkhorn_distance(W, reg=reg, n_iter=sinkhorn_iter).sum()
                else:
                    total_dist = self.wasserstein2_analytical(
                        W, cost=cost, dither_sigma=dither_sigma).sum()
                total_dist.backward()

                with torch.no_grad():
                    grad = W.grad
                    # Stiefel projection onto the tangent space of W W^T = I
                    G_Wt = torch.mm(grad, W.data.t())
                    W_Gt = torch.mm(W.data, grad.t())
                    sym = 0.5 * (G_Wt + W_Gt)
                    tangent_grad = grad - torch.mm(sym, W.data)
                    tangent_norms = torch.norm(tangent_grad, dim=1, keepdim=True)
                    tangent_grad = torch.where(tangent_norms > 1.0,
                                               tangent_grad / tangent_norms, tangent_grad)
                    # Apply step with decaying learning rate to allow settling
                    W += current_lr * tangent_grad
                    W.data = self._symmetric_decorrelation(W.data)
                W.requires_grad_(True)
                # Decay learning rate by 1% each step
                current_lr *= 0.99

        elif optimizer == 'lbfgs':
            penalties = [penalty_weight, penalty_weight * 100,
                         penalty_weight * 10000, penalty_weight * 1000000]
            steps = max_iter // len(penalties)
            if steps < 5:
                steps = 5
            for p in penalties:
                optim = torch.optim.LBFGS([W], lr=lr, max_iter=steps, history_size=50,
                                          line_search_fn='strong_wolfe',
                                          tolerance_grad=1e-7, tolerance_change=1e-7)

                def closure():
                    if W.grad is not None:
                        W.grad.zero_()
                    if use_sinkhorn:
                        total_dist = self.sinkhorn_distance(W, reg=reg, n_iter=sinkhorn_iter).sum()
                    else:
                        total_dist = self.wasserstein2_analytical(
                            W, cost=cost, dither_sigma=dither_sigma).sum()
                    gram = torch.mm(W, W.t())
                    trace_gram = torch.trace(gram)
                    trace_gram_sq = torch.trace(torch.mm(gram, gram))
                    # ||W W^T - I||_F^2 expanded: tr(G^2) - 2 tr(G) + n_components
                    ortho_penalty = trace_gram_sq - 2 * trace_gram + n_components
                    loss = -total_dist + (p * ortho_penalty)
                    loss.backward()
                    return loss

                try:
                    optim.step(closure)
                except RuntimeError:
                    break
                with torch.no_grad():
                    W.data = self._symmetric_decorrelation(W)

        # Restore the original full dataset after the optimization loop finishes
        self.X_white = original_X_white
        self.n = original_n
        self.analytical_target = original_target

        return W.detach()

    # ==========================================
    # NEW: OT-Mapping Fixed-Point Rule
    # ==========================================
    def optimize_fixed_point(self, n_components=None, max_iter=100, tol=1e-5,
                             init_w=None, step_size=0.5):
        """
        Calculates the OT mapping to the perfect Gaussian, then steps AWAY from it.
        Acts as gradient ascent on the Wasserstein landscape.
        """
        assert self.whitened, "Call whiten() before optimization."
        if n_components is None:
            n_components = self.X.shape[0]

        # 1. Initialization
        if init_w is not None:
            W = init_w.clone().to(self.X.device)
        else:
            W = torch.randn(n_components, self.X.shape[0], device=self.X.device)
            W = self._symmetric_decorrelation(W)

        # We need the target matrix T broadcast to match dimensions (C x N)
        T = self.analytical_target.unsqueeze(0).expand(n_components, -1)

        for i in range(max_iter):
            # Step 1: Project the data (Y = WX)
            Y = torch.mm(W, self.X_white)

            # Step 2: Find the ranking/sorting indices
            idx = torch.argsort(Y, dim=1)

            # Step 3: Create the "ideal target" (Y_ideal)
            Y_ideal = torch.empty_like(Y)
            Y_ideal.scatter_(1, idx, T)

            # Step 4: The gradient (direction pointing INTO the Gaussian valley)
            G = torch.mm(Y_ideal, self.X_white.t()) / (self.n - 1)

            # Step 5: The anti-Gaussian step (climbing the hill)
            # We subtract G to step AWAY from the Gaussian
            W_new = W - step_size * G

            # Step 6: Symmetrically orthogonalize W_new
            W_new = self._symmetric_decorrelation(W_new)

            # Step 7: Check for convergence
            cos_theta = torch.abs(torch.diag(torch.mm(W_new, W.t())))
            min_cos = torch.min(cos_theta).item()
            W = W_new
            if (1.0 - min_cos) < tol:
                break

        return W.detach()

    # ==========================================
    # LEGACY / BACKWARD COMPATIBILITY FUNCTIONS
    # ==========================================
    def _normal_quantile(self, q):
        q_np = q.cpu().numpy()
        inv_cdf = scipy.stats.norm.ppf(q_np)
        return torch.tensor(inv_cdf, dtype=torch.float32, device=q.device)

    def wasserstein2_distance(self, w):
        assert self.whitened, "Call whiten() before computing distance."
        y = torch.mv(self.X_white.t(), w)
        sorted_y, _ = torch.sort(y)
        steps = torch.arange(1, self.n + 1, dtype=torch.float32, device=self.X.device)
        q = (steps - 0.5) / self.n
        F_n_inv = self._normal_quantile(q)
        return torch.mean((sorted_y - F_n_inv) ** 2)

    def wasserstein1_distance(self, w):
        assert self.whitened, "Call whiten() before computing distance."
        y = torch.mv(self.X_white.t(), w)
        sorted_y, _ = torch.sort(y)
        steps = torch.arange(1, self.n + 1, dtype=torch.float32, device=self.X.device)
        q = (steps - 0.5) / self.n
        F_n_inv = self._normal_quantile(q)
        return torch.mean(torch.abs(sorted_y - F_n_inv))

    def _wasserstein2_gradient_approx(self, w, delta=1e-5):
        grad = torch.zeros_like(w)
        base_val = self.wasserstein2_distance(w)
        for i in range(len(w)):
            w_perturb = w.clone()
            w_perturb[i] += delta
            w_perturb /= torch.norm(w_perturb)
            val = self.wasserstein2_distance(w_perturb)
            grad[i] = (val - base_val) / delta
        return grad

    def sinkhorn_distance(self, W, reg=0.01, n_iter=50):
        """
        Batched entropy-regularized W2 distance (Sinkhorn) in log-space.
        W shape: (n_components, n_dimensions) OR (n_dimensions,)
        """
        assert self.whitened, "Call whiten() before computing distance."
        is_1d = W.dim() == 1
        if is_1d:
            W = W.unsqueeze(0)
        B = W.shape[0]  # Batch size / number of components

        # 1. Project all data at once (shape: B x N)
        Y = torch.mm(W, self.X_white)

        # 2. Target: Gaussian quantiles (shape: N)
        steps = torch.arange(1, self.n + 1, dtype=torch.float32, device=self.X.device)
        q = (steps - 0.5) / self.n
        target = self._normal_quantile(q)

        # 3. Batched cost matrix C: (B, N_y, N_target)
        # Broadcasting: Y is (B, N, 1), target is (1, 1, N)
        C = (Y.unsqueeze(2) - target.view(1, 1, self.n)) ** 2

        # 4. Sinkhorn iterations
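        # Log-domain updates for uniform marginals (mu_i = nu_j = 1/n):
        #   f_i <- reg * ( log(1/n) - logsumexp_j [ (g_j - C_ij) / reg ] )
        #   g_j <- reg * ( log(1/n) - logsumexp_i [ (f_i - C_ij) / reg ] )
        # Working with the potentials f, g instead of multiplicative scalings
        # keeps the exponentials bounded even for small reg.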
        f = torch.zeros(B, self.n, device=self.X.device)
        g = torch.zeros(B, self.n, device=self.X.device)
        log_mu = -torch.log(torch.tensor(self.n, dtype=torch.float32, device=self.X.device))

        for _ in range(n_iter):
            # Update f: sum over the target dimension (dim=2)
            f = reg * (log_mu - torch.logsumexp((g.unsqueeze(1) - C) / reg, dim=2))
            # Update g: sum over the Y dimension (dim=1)
            g = reg * (log_mu - torch.logsumexp((f.unsqueeze(2) - C) / reg, dim=1))

        # 5. Calculate total cost for each batch element
        # log_P shape: (B, N, N)
        log_P = (f.unsqueeze(2) + g.unsqueeze(1) - C) / reg
        distances = torch.sum(torch.exp(log_P) * C, dim=(1, 2))

        if is_1d:
            return distances[0]
        return distances

    def optimize_symmetric_sinkhorn(self, n_components=None, max_iter=300, lr=1.0,
                                    init_w=None, reg=0.05):
        return self.optimize_symmetric(
            n_components=n_components,
            max_iter=max_iter,
            lr=lr,
            init_w=init_w,
            optimizer='lbfgs',
            use_sinkhorn=True,
            reg=reg
        )
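

# ------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original class):
# recovers two independent sources from a synthetic linear mixture.
# The names S, A and X_mixed below are hypothetical.
# ------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    n_samples = 2000

    # Two non-Gaussian sources: uniform and Laplacian
    S = torch.stack([
        torch.rand(n_samples) * 2.0 - 1.0,
        torch.distributions.Laplace(0.0, 1.0).sample((n_samples,)),
    ])
    A = torch.tensor([[1.0, 0.5],
                      [0.3, 1.0]])   # mixing matrix
    X_mixed = A @ S                  # observed mixtures, shape (dims, n_samples)

    ica = WassersteinICA(X_mixed)
    ica.whiten()

    # Symmetric (parallel) extraction of both components on the Stiefel manifold
    W = ica.optimize_symmetric(n_components=2, max_iter=100, optimizer='stiefel')
    estimated_sources = W @ ica.X_white
    print("Unmixing matrix W:\n", W)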