# tensor-group-sym / python / large_scale / starg_torch / octahedral.py
"""Chiral octahedral group O (order 24) and its 5 irreps.

This is used in the Wigner-Eckart experiment to attribute predictive power
to angular momentum channels. The lowest angular momentum in which each irrep
occurs is l = 0, 3, 2, 1, 2 for A_1, A_2, E, T_1, T_2 (A_2 does not appear
for l < 3).
The 24 rotation matrices are built explicitly: identity, 9 face rotations
(±90° and 180° about each coordinate axis), 8 vertex rotations (±120° about
the 4 body diagonals), 6 edge rotations (180° about edge-midpoint axes).
"""

from __future__ import annotations

from typing import List, Tuple

import numpy as np


def _axis_angle(axis: np.ndarray, angle: float) -> np.ndarray:
    """Rotation matrix for ``angle`` radians about ``axis`` (Rodrigues' formula).

    ``axis`` need not be normalized; it is scaled to unit length first.
    The result is I + sin(angle) K + (1 - cos(angle)) K^2, where K is the
    cross-product matrix of the unit axis (K @ v == unit x v).
    """
    unit = axis / np.linalg.norm(axis)
    ux, uy, uz = unit[0], unit[1], unit[2]
    cross = np.array([
        [0, -uz, uy],
        [uz, 0, -ux],
        [-uy, ux, 0],
    ])
    rotation = np.eye(3)
    rotation = rotation + np.sin(angle) * cross
    rotation = rotation + (1 - np.cos(angle)) * (cross @ cross)
    return rotation


def octahedral_rotations() -> List[np.ndarray]:
    """Return the 24 proper rotations of the cube as 3x3 float matrices.

    Emission order: the identity; then, for each coordinate axis x, y, z, the
    angles +90°, 180°, -90° (9 face rotations); then +120° and -120° about each
    of the four body diagonals (8 vertex rotations); finally 180° about each of
    the six edge-midpoint axes (6 edge rotations).
    """
    coordinate_axes = [np.array([1, 0, 0]), np.array([0, 1, 0]), np.array([0, 0, 1])]
    body_diagonals = [np.array([1, 1, 1]), np.array([1, 1, -1]),
                      np.array([1, -1, 1]), np.array([-1, 1, 1])]
    edge_axes = [np.array([1, 1, 0]), np.array([1, -1, 0]),
                 np.array([1, 0, 1]), np.array([1, 0, -1]),
                 np.array([0, 1, 1]), np.array([0, 1, -1])]
    # (axes, angles) per conjugacy family, in the order elements are emitted;
    # for each family the axis loop is outer and the angle loop inner.
    families = [
        (coordinate_axes, [np.pi / 2, np.pi, -np.pi / 2]),
        (body_diagonals, [2 * np.pi / 3, -2 * np.pi / 3]),
        (edge_axes, [np.pi]),
    ]
    rotations = [np.eye(3)]
    rotations.extend(
        _axis_angle(axis.astype(float), theta)
        for axes, angles in families
        for axis in axes
        for theta in angles
    )
    assert len(rotations) == 24, f"expected 24 rotations, got {len(rotations)}"
    return rotations


def _round_to_int(M: np.ndarray, tol: float = 1e-6) -> np.ndarray:
    """Snap ``M`` to an integer array if every entry is within ``tol`` of one.

    Returns ``np.round(M)`` cast to ``int`` when the largest deviation from
    the nearest integers is below ``tol``; otherwise returns ``M`` itself
    (the same object, unmodified).
    """
    nearest = np.round(M)
    worst = np.max(np.abs(M - nearest))
    if not worst < tol:
        return M
    return nearest.astype(int)


def octahedral_group() -> Tuple[np.ndarray, List[np.ndarray]]:
    """Return (multiplication_table, list_of_24_rotation_matrices).

    The matrices come from ``octahedral_rotations()``. ``table[i, j] == k``
    means ``R[i] @ R[j]`` matches ``R[k]`` under
    ``np.allclose(R[i] @ R[j], R[k], atol=1e-6)``; if several ``k`` match,
    the smallest is used.

    Raises:
        RuntimeError: if some product matches none of the rotations.
    """
    R = octahedral_rotations()
    stack = np.stack(R)                                   # (n, 3, 3)
    # products[i, j] == R[i] @ R[j], for all pairs in one broadcasted matmul
    products = stack[:, None] @ stack[None, :]            # (n, n, 3, 3)
    # matches[i, j, k] == np.allclose(products[i, j], R[k], atol=1e-6);
    # argument order is kept so the (asymmetric) rtol term uses |R[k]|.
    matches = np.isclose(
        products[:, :, None], stack[None, None], atol=1e-6
    ).all(axis=(-2, -1))                                  # (n, n, n)
    found = matches.any(axis=-1)
    if not found.all():
        # report the first failing pair in row-major order
        i, j = np.argwhere(~found)[0]
        raise RuntimeError(f"product R[{i}] R[{j}] not in group")
    # argmax picks the first True along k, i.e. the first matching element
    T = matches.argmax(axis=-1).astype(int)
    return T, R


def octahedral_irreps():
    """Construct the 5 irreps of O and assemble F_G.

    A_1: trivial (1-d, first occurs at l=0)
    A_2: sign of the permutation a rotation induces on the four body
         diagonals (1-d, first occurs at l=3)
    E:   2-d, traceless diagonal tensors xx-yy, 2zz-xx-yy (l=2 component)
    T_1: 3-d (l=1, the rotation matrices themselves)
    T_2: 3-d, off-diagonal symmetric tensors xy, xz, yz (l=2 component)

    Returns (F_G, [d_rho]) with F_G of shape (24, 24) and complex dtype: each
    row g is the concatenated row-vectorization of
    [A_1(g), A_2(g), E(g), T_1(g), T_2(g)] divided by sqrt(24), and
    [d_rho] == [1, 1, 2, 3, 3] in the same irrep order. Rows follow the order
    of ``octahedral_rotations()``.

    NOTE(review): every entry is scaled by 1/sqrt(|G|) rather than the usual
    sqrt(d_rho/|G|), so F_G is not unitary (by Schur orthogonality the
    columns belonging to irrep rho have squared norm 1/d_rho) — confirm this
    is what callers expect.

    Raises:
        RuntimeError: if a rotation does not map the body diagonals onto
            body diagonals (up to sign).
    """
    R = octahedral_rotations()
    n = len(R)

    # A_1: trivial
    A1 = np.ones(n)

    # A_2: every element has det +1, so the determinant cannot separate A_2
    # from A_1. Instead use the sign of the permutation each rotation induces
    # on the four body diagonals (O acts on them as S_4).
    A2 = np.array(
        [_perm_sign(_diagonal_permutation(Rg)) for Rg in R], dtype=float
    )

    # T_1: the defining representation (the rotation matrices themselves)
    T1 = R

    # E and T_2: the 5-d traceless symmetric representation in the basis
    # (xx-yy, 2zz-xx-yy, xy, xz, yz). Octahedral rotations are signed
    # permutation matrices, so they send diagonal tensors to diagonal ones and
    # off-diagonal to off-diagonal; the 5x5 matrix is therefore block
    # diagonal, with the 2x2 E block first and the 3x3 T_2 block last.
    Sym2_5 = [_sym2_rep(Rg) for Rg in R]
    E = [M[:2, :2] for M in Sym2_5]
    T2 = [M[2:, 2:] for M in Sym2_5]

    # Assemble F_G row-by-row, irreps in the order A_1, A_2, E, T_1, T_2
    rows = [
        np.concatenate(([A1[g], A2[g]], E[g].ravel(), T1[g].ravel(), T2[g].ravel()))
        for g in range(n)
    ]
    F = np.array(rows, dtype=complex) / np.sqrt(n)
    irrep_dims = [1, 1, 2, 3, 3]
    assert F.shape == (24, 24)
    return F, irrep_dims


def _diagonal_permutation(Rg: np.ndarray) -> np.ndarray:
    """Return ``idx`` with ``Rg @ d_i == ±d_{idx[i]}`` for the 4 body diagonals.

    The diagonals are ordered (1,1,1), (1,1,-1), (1,-1,1), (-1,1,1); matching
    uses ``np.allclose(..., atol=1e-6)`` against both signs.

    Raises:
        RuntimeError: if some image is not (numerically) a body diagonal.
    """
    diags = np.array([[1, 1, 1], [1, 1, -1], [1, -1, 1], [-1, 1, 1]])
    idx = np.zeros(4, dtype=int)
    # row i of diags @ Rg.T is Rg @ diags[i]
    for i, p in enumerate(diags @ Rg.T):
        for j, d in enumerate(diags):
            if np.allclose(p, d, atol=1e-6) or np.allclose(p, -d, atol=1e-6):
                idx[i] = j
                break
        else:
            # previously this silently left idx[i] = 0 and gave a wrong sign
            raise RuntimeError(f"rotation does not permute the body diagonals:\n{Rg}")
    return idx


def _perm_sign(perm: np.ndarray) -> int:
    """Return +1 for an even permutation and -1 for an odd one.

    ``perm[i]`` is the image of position ``i``. Parity is the parity of the
    number of inversions, i.e. pairs ``i < j`` with ``perm[i] > perm[j]``.
    """
    size = len(perm)
    odd = False
    for left in range(size):
        for right in range(left + 1, size):
            if perm[left] > perm[right]:
                odd = not odd
    return -1 if odd else 1


def _sym2_rep(R: np.ndarray) -> np.ndarray:
    """5x5 action of ``R`` on the traceless part of Sym^2(R^3).

    Conjugates the 6-d matrix from ``_sym2_rep_full`` (coordinates
    xx, yy, zz, xy, xz, yz) by the basis (xx-yy, 2zz-xx-yy, xy, xz, yz).
    Each basis row is scaled to unit Euclidean length in those 6-d
    coordinates. Every basis row has zero diagonal sum, i.e. lies in the
    trace-free subspace.
    """
    action = _sym2_rep_full(R)
    basis = np.array(
        [
            [1, -1, 0, 0, 0, 0],   # xx - yy
            [-1, -1, 2, 0, 0, 0],  # 2zz - xx - yy
            [0, 0, 0, 1, 0, 0],    # xy
            [0, 0, 0, 0, 1, 0],    # xz
            [0, 0, 0, 0, 0, 1],    # yz
        ],
        dtype=float,
    )
    basis /= np.linalg.norm(basis, axis=1, keepdims=True)
    return basis @ action @ basis.T


def _sym2_rep_full(R: np.ndarray) -> np.ndarray:
    """6-d symmetric tensor representation (xx, yy, zz, xy, xz, yz).

    Entry [i, j] is the coefficient of input component j in output component
    i of S -> R S R^T. Off-diagonal components are stored once (S_xy, not
    S_xy + S_yx), which is why off-diagonal columns carry two terms.
    """
    index_pairs = ((0, 0), (1, 1), (2, 2), (0, 1), (0, 2), (1, 2))
    out = np.zeros((6, 6))
    for row, (a, b) in enumerate(index_pairs):
        for col, (c, d) in enumerate(index_pairs):
            term = R[a, c] * R[b, d]
            if c != d:
                term = term + R[a, d] * R[b, c]
            out[row, col] = term
    return out