# ot-in-linear-ica/exp/other/coupling_diag.py
# Concept figure: push-forward of a source Gaussian mu onto a target Gaussian nu.
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

def set_thesis_theme():
    """Apply the thesis figure style to matplotlib's global ``rcParams``.

    The style is a frameless, tick-free canvas: all four spines, tick marks
    and tick labels are hidden. Figures are 8x4 inches at 300 dpi, with 12 pt
    serif text. Because this mutates the global ``plt.rcParams``, call it
    before creating the figures it should apply to.
    """
    theme = {}

    # Canvas size and resolution.
    theme['figure.figsize'] = (8, 4)
    theme['figure.dpi'] = 300

    # Remove the axes frame entirely.
    for spine_key in ('axes.spines.top', 'axes.spines.right',
                      'axes.spines.left', 'axes.spines.bottom'):
        theme[spine_key] = False

    # Suppress tick marks and tick labels on both axes.
    for tick_key in ('xtick.bottom', 'xtick.labelbottom',
                     'ytick.left', 'ytick.labelleft'):
        theme[tick_key] = False

    # Typography.
    theme['font.family'] = 'serif'
    theme['font.size'] = 12

    plt.rcParams.update(theme)
set_thesis_theme()

# Parameters of the two 1-D Gaussians. Every coordinate in the figure below
# (curve supports, labels, arrow, axis limits) is derived from these values,
# so changing a mean or spread keeps the annotations aligned with the curves.
SOURCE_MEAN, SOURCE_STD = 0.0, 1.0
TARGET_MEAN, TARGET_STD = 8.0, 1.2
SUPPORT_HALF_WIDTH = 4.0   # each curve is drawn on [mean - w, mean + w]
N_GRID = 500               # number of x samples per curve
X_MARGIN = 1.0             # empty space beyond the outermost curve support
SOURCE_COLOR = '#1f77b4'
TARGET_COLOR = '#d62728'

# Vertical layout (data coordinates).
LABEL_Y = -0.05            # baseline of the measure names under each curve
ARROW_Y = 0.2              # height of the transport-map arrow
ARROW_INSET = 1.5          # arrow starts/ends this far from the two means

fig, ax = plt.subplots()

# Define the two distributions
x_mu = np.linspace(SOURCE_MEAN - SUPPORT_HALF_WIDTH,
                   SOURCE_MEAN + SUPPORT_HALF_WIDTH, N_GRID)
x_nu = np.linspace(TARGET_MEAN - SUPPORT_HALF_WIDTH,
                   TARGET_MEAN + SUPPORT_HALF_WIDTH, N_GRID)

y_mu = norm.pdf(x_mu, SOURCE_MEAN, SOURCE_STD)
y_nu = norm.pdf(x_nu, TARGET_MEAN, TARGET_STD)

# Plot the distributions, each labelled directly beneath its mean.
ax.plot(x_mu, y_mu, color=SOURCE_COLOR, lw=2)
ax.fill_between(x_mu, 0, y_mu, color=SOURCE_COLOR, alpha=0.3)
ax.text(SOURCE_MEAN, LABEL_Y, r'Source Measure $\mu$', ha='center', fontsize=12)

ax.plot(x_nu, y_nu, color=TARGET_COLOR, lw=2)
ax.fill_between(x_nu, 0, y_nu, color=TARGET_COLOR, alpha=0.3)
ax.text(TARGET_MEAN, LABEL_Y, r'Target Measure $\nu$', ha='center', fontsize=12)

# Draw the Transport Map arrow from just right of the source mean to just
# left of the target mean.
arrow_start = SOURCE_MEAN + ARROW_INSET
arrow_end = TARGET_MEAN - ARROW_INSET
ax.annotate('', xy=(arrow_end, ARROW_Y), xytext=(arrow_start, ARROW_Y),
            arrowprops=dict(arrowstyle="->", color="black", lw=2))

# Caption and push-forward equation, centred between the two means and
# stacked above the arrow.
midpoint = (SOURCE_MEAN + TARGET_MEAN) / 2
ax.text(midpoint, ARROW_Y + 0.02, 'Transport Map',
        ha='center', va='bottom', fontsize=12)
ax.text(midpoint, ARROW_Y + 0.08, r'$\nu = T_{\#}\mu$',
        ha='center', va='bottom', fontsize=16)

ax.set_ylim(-0.1, 0.5)
ax.set_xlim(SOURCE_MEAN - SUPPORT_HALF_WIDTH - X_MARGIN,
            TARGET_MEAN + SUPPORT_HALF_WIDTH + X_MARGIN)

# Use the explicit figure handle rather than pyplot's implicit "current figure".
fig.tight_layout()
fig.savefig('ot_pushforward_concept.png')
plt.show()