# ot-in-linear-ica / src / wasserstein_ica / core.py
import numpy as np 
import torch 
import scipy.stats 

class WassersteinICA:
    """Linear ICA driven by the 1-D Wasserstein distance to a standard Gaussian.

    ``X`` is a (n_features, n_samples) tensor whose columns are samples. After
    ``whiten()``, every unmixing row ``w`` projects the whitened data onto a
    line. The sorted projections are compared with the exact per-bin
    conditional means of N(0, 1), stored in ``analytical_target``. Directions
    whose projections are far from Gaussian are taken as independent
    components. They are found either one at a time
    (``optimize_wasserstein2``) or jointly (``optimize_symmetric``,
    ``optimize_fixed_point``).
    """

    def __init__(self, X):
        """Store the data.

        Args:
            X: torch tensor of shape (n_features, n_samples).
        """
        self.X = X
        self.n = X.shape[1]  # number of samples (columns)
        self.whitened = False
        # Not referenced by any method of this class; kept for external callers.
        self.epsilon = 1e-7
        # Cache for the analytical Gaussian target (filled in by whiten()).
        self.analytical_target = None

    def whiten(self):
        """Whiten the data (zero mean, unit variance, uncorrelated).

        Sets ``W_white`` (whitening matrix), ``X_white`` (whitened data, same
        shape as ``X``), ``whitened`` and the cached ``analytical_target``.
        """
        X_centered = self.X - torch.mean(self.X, dim=1, keepdim=True)
        cov = torch.matmul(X_centered, X_centered.t()) / (self.n - 1)
        D, E = torch.linalg.eigh(cov)
        # Small jitter keeps near-zero eigenvalues from exploding.
        D_inv_sqrt = torch.diag(1.0 / torch.sqrt(D + 1e-5))
        self.W_white = torch.matmul(D_inv_sqrt, E.T)
        self.X_white = torch.matmul(self.W_white, X_centered)
        self.whitened = True
        # Pre-compute the exact analytical Gaussian target.
        self.analytical_target = self._compute_analytical_target(self.n)

    def _compute_analytical_target(self, n):
        """Return the exact mean of N(0, 1) inside each of ``n`` equal-probability bins.

        With bin edges z_i = Phi^{-1}(i / n), the formula is
        Target_i = n * (pdf(z_{i-1}) - pdf(z_i)) = E[Z | z_{i-1} < Z <= z_i].
        This replaces sampling (or midpoint quantiles) with exact calculus.

        Returns:
            (n,) float32 tensor on the device of ``X``.
        """
        p_edges = np.linspace(0, 1, n + 1)
        z_edges = scipy.stats.norm.ppf(p_edges)  # -inf .. +inf
        phi_edges = scipy.stats.norm.pdf(z_edges)  # pdf(+-inf) == 0

        target_np = n * (phi_edges[:-1] - phi_edges[1:])
        return torch.tensor(target_np, dtype=torch.float32, device=self.X.device)

    # ==========================================
    # VECTORIZED: Core Distance Metric
    # ==========================================
    def wasserstein2_analytical(self, W, cost='l2', dither_sigma=0.0):
        """Distance between each projection ``W @ X_white`` and the Gaussian target.

        Args:
            W: (n_features,) vector (legacy) or (k, n_features) matrix of rows
                evaluated in parallel.
            cost: 'l2' gives the mean squared gap between sorted projections
                and target (a squared W2). 'logcosh' gives the mean
                log(cosh(gap)), which grows only linearly for large gaps
                (Huber-like).
            dither_sigma: std of Gaussian noise added to the projections before
                sorting, to break ties and smooth the discrete CDF steps.
                0 disables it.

        Returns:
            A scalar tensor for 1-D ``W``, otherwise a (k,) tensor.

        Raises:
            AssertionError: if ``whiten()`` has not been called.
            ValueError: for an unknown ``cost``.
        """
        assert self.whitened, "Call whiten() before computing distance."

        is_1d = W.dim() == 1
        if is_1d:
            W = W.unsqueeze(0)

        Y = torch.mm(W, self.X_white)

        # DITHERING: inject continuous noise to break discrete ties and smooth the CDF.
        if dither_sigma > 0:
            Y = Y + torch.randn_like(Y) * dither_sigma

        sorted_Y, _ = torch.sort(Y, dim=1)
        diff = sorted_Y - self.analytical_target

        if cost == 'l2':
            distances = torch.mean(diff ** 2, dim=1)
        elif cost == 'logcosh':
            # Numerically stable logcosh: |x| + log1p(exp(-2|x|)) - log 2 never
            # overflows, which prevents NaN gradients on massive outliers.
            abs_diff = torch.abs(diff)
            logcosh_diff = abs_diff + torch.log1p(torch.exp(-2.0 * abs_diff)) - np.log(2.0)
            distances = torch.mean(logcosh_diff, dim=1)
        else:
            raise ValueError("cost must be 'l2' or 'logcosh'")

        if is_1d:
            return distances[0]
        return distances

    # ==========================================
    # VECTORIZED: Phase 1 (Deflation & Restarts)
    # ==========================================
    def optimize_wasserstein2(self, prev_components=None, grid_points=100, continuous=True,
                              max_iter=200, lr=0.1, n_restarts=50, decay_rate=0.5, decay_step=50, cost='l2', dither_sigma=0.0):
        """Find ONE direction maximizing the distance to Gaussian (deflation step).

        Args:
            prev_components: optional (m, n_features) tensor of rows already
                found. The search is restricted to their orthogonal complement.
                NOTE(review): the projection ``w - (w P^T) P`` assumes these
                rows are orthonormal; confirm at call sites.
            grid_points: number of angles tried by the discrete search.
            continuous: if True, run projected gradient ascent on the unit
                sphere from ``n_restarts`` random starts in parallel. If False,
                run a grid search over the unit circle (2-D data only).
            max_iter, lr: ascent budget and initial step size.
            decay_rate, decay_step: ``lr`` is multiplied by ``decay_rate`` every
                ``decay_step`` iterations.
            cost, dither_sigma: forwarded to ``wasserstein2_analytical``.

        Returns:
            ``(w, dist)``: the best unit vector and its distance. For the
            continuous search the distance is re-evaluated without dither.

        Raises:
            ValueError: if grid search is requested on data that is not 2-D,
                or no grid candidate survives deflation.
        """
        if continuous:
            W_batch = torch.randn(n_restarts, self.X.shape[0], device=self.X.device)

            if prev_components is not None and prev_components.shape[0] > 0:
                proj = torch.matmul(W_batch, prev_components.t())
                W_batch = W_batch - torch.matmul(proj, prev_components)

            W_batch = W_batch / torch.norm(W_batch, dim=1, keepdim=True)
            W_batch.requires_grad_(True)
            current_lr = lr

            for i in range(max_iter):
                if (i + 1) % decay_step == 0: current_lr *= decay_rate

                # Restarts are independent, so summing them gives each its own gradient.
                dist = self.wasserstein2_analytical(W_batch, cost=cost, dither_sigma=dither_sigma).sum()

                if W_batch.grad is not None: W_batch.grad.zero_()
                dist.backward()

                with torch.no_grad():
                    grad = W_batch.grad
                    # Keep the step inside the deflated subspace.
                    if prev_components is not None and prev_components.shape[0] > 0:
                        proj_grad = torch.matmul(grad, prev_components.t())
                        grad = grad - torch.matmul(proj_grad, prev_components)

                    # Riemannian gradient on the sphere: drop the radial component.
                    dot_pw = torch.sum(grad * W_batch.data, dim=1, keepdim=True)
                    grad = grad - dot_pw * W_batch.data

                    # Clip each restart's step to unit norm.
                    grad_norms = torch.norm(grad, dim=1, keepdim=True)
                    grad = torch.where(grad_norms > 1.0, grad / grad_norms, grad)

                    W_batch.data += current_lr * grad

                    if prev_components is not None and prev_components.shape[0] > 0:
                        proj = torch.matmul(W_batch.data, prev_components.t())
                        W_batch.data = W_batch.data - torch.matmul(proj, prev_components)

                    W_batch.data /= torch.norm(W_batch.data, dim=1, keepdim=True)

            with torch.no_grad():
                # Evaluate without noise to get the true mathematical distance.
                final_distances = self.wasserstein2_analytical(W_batch, cost=cost, dither_sigma=0.0)
                best_idx = torch.argmax(final_distances)
                best_w = W_batch[best_idx].detach().clone()
                best_dist = final_distances[best_idx].item()

            return best_w, best_dist

        # Legacy discrete grid search over the unit circle.
        if self.X.shape[0] != 2:
            raise ValueError("Discrete grid search (continuous=False) requires 2-D data.")
        # Evenly spaced angles in [0, 2*pi); the endpoint 2*pi would repeat angle 0.
        angles = torch.arange(grid_points, device=self.X.device) * (2 * np.pi / grid_points)
        candidates = torch.stack([torch.cos(angles), torch.sin(angles)], dim=1)
        if prev_components is not None and prev_components.shape[0] > 0:
            proj = torch.matmul(candidates, prev_components.t())
            candidates = candidates - torch.matmul(proj, prev_components)
            norms = torch.norm(candidates, dim=1, keepdim=True)
            # squeeze(1), not squeeze(): keep the mask 1-D even for one candidate.
            mask = norms.squeeze(1) > 1e-6
            candidates = candidates[mask]
            if candidates.shape[0] == 0: raise ValueError("No valid candidates.")
            candidates = candidates / norms[mask]

        with torch.no_grad():
            dists = self.wasserstein2_analytical(candidates, cost=cost, dither_sigma=dither_sigma)
        # argmax returns the first maximum, matching a strict '>' scan.
        best_idx = int(torch.argmax(dists))
        return candidates[best_idx], dists[best_idx].item()

    def _symmetric_decorrelation(self, W):
        """Return ``(W W^T)^{-1/2} W``, i.e. ``W`` with rows (approximately) orthonormalized.

        Accepts a single (C, D) matrix or a batch (..., C, D). Each matrix in
        the batch is decorrelated independently. The 1e-5 jitter keeps the
        inverse square root finite for rank-deficient ``W``.
        """
        M = torch.matmul(W, W.transpose(-2, -1))
        evals, evecs = torch.linalg.eigh(M)
        d_inv_sqrt = torch.diag_embed(1.0 / torch.sqrt(evals + 1e-5))
        inv_sqrt_M = torch.matmul(torch.matmul(evecs, d_inv_sqrt), evecs.transpose(-2, -1))
        return torch.matmul(inv_sqrt_M, W)

    def _batched_objective(self, W, use_sinkhorn, reg, sinkhorn_iter, cost, dither_sigma):
        """Return, for each matrix in a batch, the summed distance of its rows.

        Args:
            W: (R, C, n_features) batch of projection matrices.

        Returns:
            (R,) tensor. Rows are flattened into one (R*C, n_features) call to
            the chosen distance, so every row in every restart is evaluated in
            parallel.
        """
        rows = W.reshape(-1, W.shape[-1])
        if use_sinkhorn:
            dists = self.sinkhorn_distance(rows, reg=reg, n_iter=sinkhorn_iter)
        else:
            dists = self.wasserstein2_analytical(rows, cost=cost, dither_sigma=dither_sigma)
        return dists.reshape(W.shape[:-1]).sum(dim=-1)

    # ==========================================
    # VECTORIZED: Phase 2
    # ==========================================
    def optimize_symmetric(self, n_components=None, max_iter=300, lr=1.0, init_w=None,
                           optimizer='sgd', penalty_weight=10.0, use_sinkhorn=False,
                           reg=0.01, sinkhorn_iter=50, cost='l2', dither_sigma=0.0,
                           batch_size=512, n_restarts=None):
        """Estimate all components jointly (symmetric ICA).

        Maximizes the summed distance-to-Gaussian of the rows of ``W`` while
        keeping the rows orthonormal. When ``init_w`` is None, ``n_restarts``
        random orthonormal starts are optimized in parallel as one batch. The
        start with the largest final objective (full data, no dither) is
        returned.

        Args:
            n_components: number of rows of ``W``; defaults to n_features.
            max_iter: iteration budget. 'lbfgs' splits it across 4 penalty
                stages of at least 5 iterations each.
            lr: step size for 'sgd' / 'stiefel', or the LBFGS learning rate.
            init_w: optional (n_components, n_features) start matrix. A batch
                of shape (R, n_components, n_features) is also accepted.
                Providing it disables the random restarts.
            optimizer: one of:
                'sgd': clipped gradient ascent plus symmetric decorrelation.
                'stiefel': tangent-space ascent on mini-batches, with lr
                    decayed by 1% per step.
                'lbfgs': orthogonality enforced by an increasing penalty.
            penalty_weight: first LBFGS orthogonality weight. Later stages use
                100x, 10^4x and 10^6x this value.
            use_sinkhorn, reg, sinkhorn_iter: optimize ``sinkhorn_distance``
                instead of ``wasserstein2_analytical``.
            cost, dither_sigma: forwarded to ``wasserstein2_analytical``.
            batch_size: samples per 'stiefel' step. None, or a value of at
                least n_samples, uses all samples.
            n_restarts: number of parallel random starts when ``init_w`` is
                None. Defaults to 4 * n_components, clipped to [20, 200].
                Memory grows linearly with it, and with n_samples**2 when
                ``use_sinkhorn`` is set.

        Returns:
            Detached (n_components, n_features) tensor. Its rows are
            re-orthonormalized after every update.

        Raises:
            AssertionError: if ``whiten()`` has not been called.
            ValueError: for an unknown ``optimizer`` or ``n_restarts < 1``.
        """
        assert self.whitened, "Call whiten() before optimization."
        if optimizer not in ('sgd', 'stiefel', 'lbfgs'):
            raise ValueError("optimizer must be 'sgd', 'stiefel' or 'lbfgs'")
        n_features = self.X.shape[0]
        if n_components is None: n_components = n_features

        if init_w is not None:
            # detach() so W is always a fresh leaf whose .grad gets populated.
            W = init_w.detach().clone().to(self.X.device)
            if W.dim() == 2:
                W = W.unsqueeze(0)  # treat as a batch of one restart
        else:
            if n_restarts is None:
                # DYNAMIC RESTART LOGIC: dims * 4, floored at 20, capped at 200.
                n_restarts = min(max(n_components * 4, 20), 200)
            if n_restarts < 1:
                raise ValueError("n_restarts must be >= 1")
            W = torch.randn(n_restarts, n_components, n_features, device=self.X.device)
            W = self._symmetric_decorrelation(W)  # decorrelated per restart
        W.requires_grad_(True)

        def objective():
            # Summed over restarts and rows. In the distance itself rows do not
            # interact, so each row's gradient is the one it would get alone.
            return self._batched_objective(W, use_sinkhorn, reg, sinkhorn_iter,
                                           cost, dither_sigma).sum()

        # Stochastic batching temporarily patches these attributes. They are
        # restored in `finally` even if the optimization raises.
        full_X_white, full_n, full_target = self.X_white, self.n, self.analytical_target
        try:
            if optimizer == 'sgd':
                for _ in range(max_iter):
                    if W.grad is not None: W.grad.zero_()
                    loss = -objective()
                    loss.backward()

                    with torch.no_grad():
                        grad = W.grad
                        grad_norms = torch.norm(grad, dim=-1, keepdim=True)
                        grad = torch.where(grad_norms > 1.0, grad / grad_norms, grad)
                        # grad is d(loss)/dW with loss = -distance, so
                        # subtracting it climbs the distance.
                        W -= lr * grad
                        W.data = self._symmetric_decorrelation(W)

            elif optimizer == 'stiefel':
                use_minibatch = batch_size is not None and batch_size < full_n
                if use_minibatch:
                    # The target depends only on the sample count: compute it once.
                    batch_target = self._compute_analytical_target(batch_size)
                current_lr = lr
                for _ in range(max_iter):
                    if W.grad is not None: W.grad.zero_()

                    if use_minibatch:
                        # STOCHASTIC BATCHING: a random column subset injects gradient noise.
                        indices = torch.randperm(full_n, device=self.X.device)[:batch_size]
                        self.X_white = full_X_white[:, indices]
                        self.n = batch_size
                        self.analytical_target = batch_target

                    objective().backward()

                    with torch.no_grad():
                        grad = W.grad
                        # Project onto the tangent space of the Stiefel manifold.
                        G_Wt = torch.matmul(grad, W.transpose(-2, -1))
                        W_Gt = torch.matmul(W, grad.transpose(-2, -1))
                        sym = 0.5 * (G_Wt + W_Gt)
                        tangent_grad = grad - torch.matmul(sym, W)

                        tangent_norms = torch.norm(tangent_grad, dim=-1, keepdim=True)
                        tangent_grad = torch.where(tangent_norms > 1.0,
                                                   tangent_grad / tangent_norms, tangent_grad)

                        W += current_lr * tangent_grad
                        W.data = self._symmetric_decorrelation(W)

                    # Decay the learning rate by 1% each step to allow settling.
                    current_lr *= 0.99

            else:  # 'lbfgs'
                penalties = [penalty_weight, penalty_weight * 100,
                             penalty_weight * 10000, penalty_weight * 1000000]
                steps = max(max_iter // len(penalties), 5)
                eye = torch.eye(W.shape[-2], device=W.device, dtype=W.dtype)

                for p in penalties:
                    optim = torch.optim.LBFGS([W], lr=lr, max_iter=steps, history_size=50,
                                              line_search_fn='strong_wolfe',
                                              tolerance_grad=1e-7, tolerance_change=1e-7)

                    def closure(penalty=p):
                        if W.grad is not None: W.grad.zero_()
                        gram = torch.matmul(W, W.transpose(-2, -1))
                        # ||W W^T - I||_F^2 per restart (sized by W's actual rows), summed.
                        ortho_penalty = ((gram - eye) ** 2).sum()
                        loss = -objective() + penalty * ortho_penalty
                        loss.backward()
                        return loss

                    try:
                        optim.step(closure)
                    except RuntimeError:
                        # Best-effort: stop escalating the penalty and keep the
                        # current iterate.
                        break

                with torch.no_grad():
                    W.data = self._symmetric_decorrelation(W)
        finally:
            # Restore the full dataset.
            self.X_white, self.n, self.analytical_target = full_X_white, full_n, full_target

        with torch.no_grad():
            if W.shape[0] == 1:
                best = W[0]
            else:
                # Rank restarts on the full data without dither noise. A
                # diverged (NaN) restart must never win.
                scores = self._batched_objective(W, use_sinkhorn, reg, sinkhorn_iter, cost, 0.0)
                scores = torch.where(torch.isnan(scores),
                                     torch.full_like(scores, -float('inf')), scores)
                best = W[int(torch.argmax(scores))]
        return best.detach().clone()

    # ==========================================
    # NEW: OT-Mapping Fixed-Point Rule
    # ==========================================
    def optimize_fixed_point(self, n_components=None, max_iter=100, tol=1e-5, init_w=None, step_size=0.5):
        """Fixed-point iteration that moves W away from its OT map onto the Gaussian.

        Each iteration does three things:
          1. Assign every sample its Gaussian target value by rank (the 1-D OT
             map).
          2. Form G = Y_ideal X_white^T / (n - 1) and step W <- W - step_size * G.
          3. Re-orthonormalize the rows of W.

        Iteration stops when every row's |cos| to its previous value is within
        ``tol`` of 1, or after ``max_iter`` iterations.

        Returns:
            Detached (n_components, n_features) tensor.

        Raises:
            AssertionError: if ``whiten()`` has not been called.
        """
        assert self.whitened, "Call whiten() before optimization."
        if n_components is None: n_components = self.X.shape[0]

        # 1. Initialization
        if init_w is not None:
            W = init_w.clone().to(self.X.device)
        else:
            W = torch.randn(n_components, self.X.shape[0], device=self.X.device)

        W = self._symmetric_decorrelation(W)

        # Target broadcast to (C x N) so it can be scattered row-wise.
        T = self.analytical_target.unsqueeze(0).expand(n_components, -1)

        for i in range(max_iter):
            # Step 1: project the data (Y = WX).
            Y = torch.mm(W, self.X_white)

            # Step 2: find the ranking/sorting indices.
            idx = torch.argsort(Y, dim=1)

            # Step 3: the j-th smallest projection receives the j-th target value.
            Y_ideal = torch.empty_like(Y)
            Y_ideal.scatter_(1, idx, T)

            # Step 4: direction pointing INTO the Gaussian valley.
            G = torch.mm(Y_ideal, self.X_white.t()) / (self.n - 1)

            # Step 5: subtract G to step AWAY from the Gaussian.
            W_new = W - step_size * G

            # Step 6: symmetrically orthogonalize W_new.
            W_new = self._symmetric_decorrelation(W_new)

            # Step 7: check for convergence (rows are sign-invariant, hence abs).
            cos_theta = torch.abs(torch.diag(torch.mm(W_new, W.t())))
            min_cos = torch.min(cos_theta).item()

            W = W_new

            if (1.0 - min_cos) < tol:
                break

        return W.detach()

    # ==========================================
    # LEGACY / BACKWARD COMPATIBILITY FUNCTIONS
    # ==========================================
    def _normal_quantile(self, q):
        """Standard-normal inverse CDF of ``q`` (via scipy on CPU), as float32 on ``q``'s device."""
        q_np = q.cpu().numpy()
        inv_cdf = scipy.stats.norm.ppf(q_np)
        return torch.tensor(inv_cdf, dtype=torch.float32, device=q.device)

    def wasserstein2_distance(self, w):
        """Legacy single-vector squared W2 against midpoint quantiles Phi^{-1}((i - 0.5) / n)."""
        assert self.whitened, "Call whiten() before computing distance."
        y = torch.mv(self.X_white.t(), w)
        sorted_y, _ = torch.sort(y)
        steps = torch.arange(1, self.n + 1, dtype=torch.float32, device=self.X.device)
        q = (steps - 0.5) / self.n
        F_n_inv = self._normal_quantile(q)
        return torch.mean((sorted_y - F_n_inv) ** 2)

    def wasserstein1_distance(self, w):
        """Legacy single-vector W1 (mean absolute gap) against midpoint Gaussian quantiles."""
        assert self.whitened, "Call whiten() before computing distance."
        y = torch.mv(self.X_white.t(), w)
        sorted_y, _ = torch.sort(y)
        steps = torch.arange(1, self.n + 1, dtype=torch.float32, device=self.X.device)
        q = (steps - 0.5) / self.n
        F_n_inv = self._normal_quantile(q)
        return torch.mean(torch.abs(sorted_y - F_n_inv))

    def _wasserstein2_gradient_approx(self, w, delta=1e-5):
        """Forward-difference gradient of ``wasserstein2_distance`` at ``w``.

        Each perturbed vector is renormalized to unit length before
        evaluation. NOTE(review): this presumes ``w`` itself is unit-norm.
        """
        grad = torch.zeros_like(w)
        base_val = self.wasserstein2_distance(w)
        for i in range(len(w)):
            w_perturb = w.clone()
            w_perturb[i] += delta
            w_perturb /= torch.norm(w_perturb)
            val = self.wasserstein2_distance(w_perturb)
            grad[i] = (val - base_val) / delta
        return grad

    def sinkhorn_distance(self, W, reg=0.01, n_iter=50):
        """Batched entropy-regularized W2 cost (Sinkhorn) in log-space.

        Args:
            W: (n_components, n_features) matrix or (n_features,) vector.
            reg: entropic regularization strength.
            n_iter: number of Sinkhorn potential updates.

        Returns:
            (n_components,) tensor, or a scalar for 1-D ``W``. The cost matrix
            is (n_components, n_samples, n_samples), so memory is quadratic in
            the number of samples.
        """
        assert self.whitened, "Call whiten() before computing distance."

        is_1d = W.dim() == 1
        if is_1d:
            W = W.unsqueeze(0)

        B = W.shape[0]  # batch size / number of components

        # 1. Project all data at once (B x N).
        Y = torch.mm(W, self.X_white)

        # 2. Target: midpoint Gaussian quantiles (N,).
        steps = torch.arange(1, self.n + 1, dtype=torch.float32, device=self.X.device)
        q = (steps - 0.5) / self.n
        target = self._normal_quantile(q)

        # 3. Batched cost matrix C: (B, N_y, N_target), via broadcasting.
        C = (Y.unsqueeze(2) - target.view(1, 1, self.n)) ** 2

        # 4. Sinkhorn iterations on the dual potentials with uniform marginals.
        f = torch.zeros(B, self.n, device=self.X.device)
        g = torch.zeros(B, self.n, device=self.X.device)
        log_mu = -torch.log(torch.tensor(self.n, dtype=torch.float32, device=self.X.device))

        for _ in range(n_iter):
            # Update f: log-sum-exp over the target dimension (dim=2).
            f = reg * (log_mu - torch.logsumexp((g.unsqueeze(1) - C) / reg, dim=2))
            # Update g: log-sum-exp over the Y dimension (dim=1).
            g = reg * (log_mu - torch.logsumexp((f.unsqueeze(2) - C) / reg, dim=1))

        # 5. Transport cost <P, C> per batch element; log_P is (B, N, N).
        log_P = (f.unsqueeze(2) + g.unsqueeze(1) - C) / reg
        distances = torch.sum(torch.exp(log_P) * C, dim=(1, 2))

        if is_1d:
            return distances[0]
        return distances

    def optimize_symmetric_sinkhorn(self, n_components=None, max_iter=300, lr=1.0, init_w=None, reg=0.05,
                                    n_restarts=None):
        """Shortcut for ``optimize_symmetric`` with LBFGS on the Sinkhorn objective.

        ``n_restarts`` is forwarded unchanged. Sinkhorn memory is quadratic in
        n_samples per row, so a small value (e.g. 1) may be needed on larger
        datasets.
        """
        return self.optimize_symmetric(
            n_components=n_components,
            max_iter=max_iter,
            lr=lr,
            init_w=init_w,
            optimizer='lbfgs',
            use_sinkhorn=True,
            reg=reg,
            n_restarts=n_restarts,
        )