{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "edf6f0d4",
   "metadata": {},
   "source": [
    "# Optimal Transport in linear ICA\n",
    "\n",
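    "- Check whether the ICs found after the first one (each extracted under an orthogonality constraint to the previously found ICs) are also local maxima of the Wasserstein-2 objective in the original, unconstrained space, i.e. whether they stay put when the constraint is removed."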
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "75bcf076",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import scipy.stats\n",
    "from wasserstein_ica import WassersteinICA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b978bfe0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- 1. Synthetic data: four non-Gaussian sources mixed by a random square matrix ---\n",
    "n_samples = 5000\n",
    "np.random.seed(42)\n",
    "torch.manual_seed(42)\n",
    "\n",
    "# Every draw uses the global NumPy RNG, so the list order below fixes the generated data.\n",
    "sources = [\n",
    "    np.random.laplace(0, 1, n_samples),                     # peaked, super-Gaussian\n",
    "    np.random.uniform(-np.sqrt(3), np.sqrt(3), n_samples),  # flat, unit variance\n",
    "    np.random.standard_t(df=3, size=n_samples),             # heavy-tailed\n",
    "    np.random.beta(0.5, 0.5, size=n_samples),               # U-shaped (arcsine)\n",
    "]\n",
    "sources[3] = (sources[3] - sources[3].mean()) / sources[3].std()  # zero mean, unit std\n",
    "\n",
    "S_true = np.stack(sources)\n",
    "n_sources = len(sources)\n",
    "A_true = np.random.randn(n_sources, n_sources)\n",
    "X_mixed = A_true @ S_true\n",
    "X_torch = torch.tensor(X_mixed, dtype=torch.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3ee4ec55",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Step 1: Constrained Extraction ---\n",
      "IC 1 Found. Score: 0.2514\n",
      "IC 2 Found. Score: 0.0175\n",
      "IC 3 Found. Score: 0.0342\n",
      "IC 4 Found. Score: 0.0418\n"
     ]
    }
   ],
   "source": [
    "# --- 2. Constrained (sequential) extraction ---\n",
    "# IC k is optimised under an orthogonality constraint to ICs 1..k-1.\n",
    "ica = WassersteinICA(X_torch)\n",
    "ica.whiten()\n",
    "\n",
    "constrained_weights = []\n",
    "print(\"--- Step 1: Constrained Extraction ---\")\n",
    "for i in range(n_sources):\n",
    "    # First IC: no constraint. Later ICs: pass every weight vector found so far.\n",
    "    prev = None\n",
    "    if constrained_weights:\n",
    "        prev = torch.stack(constrained_weights)\n",
    "\n",
    "    # dist is returned alongside w and reported as this IC's score.\n",
    "    w, dist = ica.optimize_wasserstein2(prev_components=prev, continuous=True, max_iter=500, lr=0.05)\n",
    "    constrained_weights.append(w)\n",
    "    print(f\"IC {i+1} Found. Score: {dist:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ff3d4008",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "--- Step 2: Unconstrained Relaxation Test ---\n",
      "IC   | Orig Score | New Score  | Drift (Deg) | Status\n",
      "-----------------------------------------------------------------\n",
      "1    | 0.2514     | 0.2514     | 0.00°       | STABLE\n",
      "2    | 0.0175     | 0.0166     | 2.44°       | STABLE\n",
      "3    | 0.0342     | 0.0335     | 2.02°       | STABLE\n",
      "4    | 0.0418     | 0.0383     | 4.56°       | STABLE\n"
     ]
    }
   ],
   "source": [
    "# --- 3. Relaxation test: are the constrained ICs local maxima without the constraint? ---\n",
    "# Start from each constrained solution and run projected gradient *ascent* of the\n",
    "# Wasserstein-2 score on the unit sphere, with NO orthogonality constraint. A weight that\n",
    "# barely moves behaves like a local maximum of the unconstrained problem; one that drifts\n",
    "# was only optimal inside its orthogonal subspace.\n",
    "RELAX_STEPS = 200    # ascent steps per IC\n",
    "RELAX_LR = 0.01      # smaller than the extraction LR (0.05): a fine-tuning / check step\n",
    "DRIFT_TOL_DEG = 5.0  # drift below this angle counts as STABLE\n",
    "\n",
    "print(\"\\n--- Step 2: Unconstrained Relaxation Test ---\")\n",
    "print(f\"{'IC':<4} | {'Orig Score':<10} | {'New Score':<10} | {'Drift (Deg)':<11} | {'Status'}\")\n",
    "print(\"-\" * 65)\n",
    "\n",
    "for i, w_init in enumerate(constrained_weights):\n",
    "    # Detached copy: the ascent must not modify constrained_weights or its autograd graph.\n",
    "    w = w_init.detach().clone().requires_grad_(True)\n",
    "\n",
    "    for _ in range(RELAX_STEPS):\n",
    "        score = ica.wasserstein2_distance(w)\n",
    "        (grad,) = torch.autograd.grad(score, w)  # d(score)/dw, without touching w.grad\n",
    "        with torch.no_grad():\n",
    "            grad = grad - torch.dot(grad, w) * w  # project onto the tangent plane at w\n",
    "            w += RELAX_LR * grad                  # ASCENT: step along +grad to raise the score\n",
    "            w /= torch.norm(w)                    # retract back onto the unit sphere\n",
    "\n",
    "    w_relaxed = w.detach()\n",
    "\n",
    "    # Measure results\n",
    "    orig_score = ica.wasserstein2_distance(w_init).item()\n",
    "    new_score = ica.wasserstein2_distance(w_relaxed).item()\n",
    "\n",
    "    # Angle between start and end; abs() because w and -w describe the same component.\n",
    "    cos_sim = torch.dot(w_init, w_relaxed).item()\n",
    "    cos_sim = min(1.0, max(-1.0, cos_sim))  # keep arccos in its domain despite rounding\n",
    "    angle = np.degrees(np.arccos(abs(cos_sim)))\n",
    "\n",
    "    status = \"STABLE\" if angle < DRIFT_TOL_DEG else \"DRIFTED\"\n",
    "    drift = f\"{angle:.2f}°\"\n",
    "    print(f\"{i+1:<4} | {orig_score:<10.4f} | {new_score:<10.4f} | {drift:<11} | {status}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04cc7771",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ot_ica",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
