Optimal Transport in linear ICA
- Check whether the ICs extracted under the orthogonality constraint (i.e. every IC after the first) are actually local maxima of the objective in the original, unconstrained space
import numpy as np
import torch
import matplotlib.pyplot as plt
import scipy.stats
from wasserstein_ica import WassersteinICA
# --- 1. Setup Data (Same as before) ---
# Fix both RNG seeds first so the sources and the mixing matrix are reproducible.
n_samples = 5000
np.random.seed(42)
torch.manual_seed(42)

# Four non-Gaussian sources. They all draw from the same global NumPy RNG
# stream, so the order of these calls determines the generated values.
uniform_half_width = np.sqrt(3)
s1 = np.random.laplace(loc=0, scale=1, size=n_samples)
s2 = np.random.uniform(-uniform_half_width, uniform_half_width, n_samples)
s3 = np.random.standard_t(df=3, size=n_samples)
s4 = np.random.beta(0.5, 0.5, size=n_samples)
s4 = (s4 - s4.mean()) / s4.std()  # centre and scale the Beta source to unit std

# One source per row: shape (4, n_samples).
S_true = np.stack((s1, s2, s3, s4))

# Mix the sources with a random square matrix, then hand the mixtures to torch
# as float32.
n_sources = 4
A_true = np.random.randn(n_sources, n_sources)
X_mixed = np.matmul(A_true, S_true)
X_torch = torch.tensor(X_mixed, dtype=torch.float32)
# --- 2. Run Standard Constrained Extraction ---
# Extract the components one at a time; each new component is optimised with
# all previously found weight vectors passed in as `prev_components`.
ica = WassersteinICA(X_torch)
ica.whiten()
constrained_weights = []
print("--- Step 1: Constrained Extraction ---")
for i in range(n_sources):
    ic_number = i + 1

    # Constraint: orthogonal to the components found so far. The very first
    # component has nothing to be orthogonal to, so it gets None.
    if constrained_weights:
        prev = torch.stack(constrained_weights)
    else:
        prev = None

    w, dist = ica.optimize_wasserstein2(
        prev_components=prev, continuous=True, max_iter=500, lr=0.05
    )

    constrained_weights.append(w)
    print(f"IC {ic_number} Found. Score: {dist:.4f}")
--- Step 1: Constrained Extraction ---
IC 1 Found. Score: 0.2514
IC 2 Found. Score: 0.0175
IC 3 Found. Score: 0.0342
IC 4 Found. Score: 0.0418
# --- 3. Run Relaxation Test (Stability Check) ---
# For every constrained IC, drop the orthogonality constraint
# (prev_components is not involved at all here) and run a short projected
# gradient ASCENT of the Wasserstein-2 score on the unit sphere, starting from
# the constrained solution. If the direction barely moves, the constrained
# solution is also (close to) a local maximum of the unconstrained objective.
RELAX_STEPS = 200          # number of fine-tuning iterations per IC
RELAX_LR = 0.01            # smaller than the extraction lr: this is only a check
DRIFT_TOLERANCE_DEG = 5.0  # angular drift below this counts as "STABLE"

print("\n--- Step 2: Unconstrained Relaxation Test ---")
print(f"{'IC':<4} | {'Orig Score':<10} | {'New Score':<10} | {'Drift (Deg)':<10} | {'Status'}")
print("-" * 65)
for i, w_init in enumerate(constrained_weights):
    # Work on a detached copy so the stored constrained weight is never modified.
    w = w_init.clone().detach()
    w.requires_grad_(True)

    for _ in range(RELAX_STEPS):
        # Differentiate the score itself (not its negation) so that stepping
        # ALONG the gradient increases the distance.
        # NOTE: the earlier version back-propagated `-distance` and then did
        # `w += lr * grad`, which is gradient descent on the distance, i.e. it
        # minimised the very quantity this test is supposed to maximise.
        score = ica.wasserstein2_distance(w)
        score.backward()

        with torch.no_grad():
            grad = w.grad
            # Remove the radial part: only the component tangent to the sphere
            # changes the direction of w.
            grad = grad - torch.dot(grad, w) * w
            w += RELAX_LR * grad
            # Retract back onto the unit sphere (no orthogonality projection).
            w /= torch.norm(w)

        # Drop the accumulated gradient so the next backward() starts fresh.
        w.grad = None

    w_relaxed = w.detach()

    # Measure Results
    orig_score = ica.wasserstein2_distance(w_init).item()
    new_score = ica.wasserstein2_distance(w_relaxed).item()

    # Angle between the start and end directions. w_init is expected to be
    # unit-norm already; dividing by the norms keeps the angle correct even if
    # it is only approximately so.
    cos_sim = (
        torch.dot(w_init, w_relaxed)
        / (torch.norm(w_init) * torch.norm(w_relaxed))
    ).item()
    cos_sim = min(1.0, max(-1.0, cos_sim))  # guard arccos against round-off
    angle = np.degrees(np.arccos(abs(cos_sim)))  # abs: w and -w are the same IC

    status = "STABLE" if angle < DRIFT_TOLERANCE_DEG else "DRIFTED"
    print(f"{i+1:<4} | {orig_score:.4f} | {new_score:.4f} | {angle:.2f}° | {status}")
--- Step 2: Unconstrained Relaxation Test ---
IC | Orig Score | New Score | Drift (Deg) | Status
-----------------------------------------------------------------
1 | 0.2514 | 0.2514 | 0.00° | STABLE
2 | 0.0175 | 0.0166 | 2.44° | STABLE
3 | 0.0342 | 0.0335 | 2.02° | STABLE
4 | 0.0418 | 0.0383 | 4.56° | STABLE