-- tensor-group-sym / lean / StarG / Equivariance.lean
/-
  StarG/Equivariance.lean — Lean 4.29 / current Mathlib compatible.
  Contains no `sorry`.
  `fourier_power_invariant` is proved via a change of variable, the
  homomorphism property `is_hom`, and invariance of the sum of squares under
  the orthogonal matrix ρ(g) (orthogonality comes from `unitary`).
-/
import StarG.Basic
import StarG.Algebra
import StarG.SVD

set_option linter.unusedSectionVars false
set_option linter.unusedSimpArgs false
set_option linter.unusedVariables false

open Finset BigOperators Matrix

noncomputable section

namespace GroupTensor

variable {G : Type*} [Group G] [Fintype G] [DecidableEq G]
variable {ℓ m p : ℕ}

/-- Left translation of a group tensor by `g`: the translated tensor at `x`
    is the original tensor at `g⁻¹ * x`. `leftAction_one` and
    `leftAction_mul` show this is a left action of `G`. -/
def leftAction (g : G) (A : GroupTensor G ℓ m) : GroupTensor G ℓ m :=
  fun x ↦ A (g⁻¹ * x)

/-- The identity element acts trivially. -/
theorem leftAction_one (A : GroupTensor G ℓ m) :
    leftAction (1 : G) A = A := by
  funext x a b
  -- `A (1⁻¹ * x) a b` reduces to `A x a b` via `1⁻¹ = 1` and `1 * x = x`.
  simp only [leftAction, inv_one, one_mul]

/-- Acting by `g₂` and then by `g₁` is the same as acting once by `g₁ * g₂`. -/
theorem leftAction_mul (g₁ g₂ : G) (A : GroupTensor G ℓ m) :
    leftAction g₁ (leftAction g₂ A) = leftAction (g₁ * g₂) A := by
  funext x a b
  -- Both sides normalise to `A (g₂⁻¹ * (g₁⁻¹ * x)) a b`:
  -- `(g₁ * g₂)⁻¹ = g₂⁻¹ * g₁⁻¹`, then reassociate to the right.
  simp only [leftAction, mul_inv_rev, mul_assoc]

/-- Left equivariance of the `⋆` product: translating the left factor by `g`
    translates the product by `g`. Only `[Group G]` is assumed, so this holds
    for non-commutative groups as well. -/
theorem starG_equivariant_left
    (g : G) (A : GroupTensor G ℓ m) (B : GroupTensor G m p) :
    leftAction g A ⋆ B = leftAction g (A ⋆ B) := by
  funext c i j; simp only [leftAction, starG]
  -- Both sides are now a sum over an index `k`; compare them term by term.
  apply Finset.sum_congr rfl; intro k _
  -- Reindex the remaining sum over `G` by `a ↦ g⁻¹ * a`.
  -- NOTE(review): presumably the summand is `A a i k * B (a⁻¹ * c) k j`
  -- (`starG` is defined elsewhere — confirm); `group` then identifies
  -- `(g⁻¹ * a)⁻¹ * (g⁻¹ * c)` with `a⁻¹ * c`.
  let e : G ≃ G := ⟨(g⁻¹ * ·), (g * ·), fun a => by simp, fun a => by simp⟩
  exact Fintype.sum_equiv e _ _ (fun a => by dsimp [e]; congr 2; group)

end GroupTensor

-- Stated outside `namespace GroupTensor` (and therefore outside the scope of
-- its `variable [Group G]` declarations), so that only `[CommGroup G]` is in
-- scope and no `[Group G]` / `[CommGroup G]` instance diamond arises.
/-- Right-factor analogue of `starG_equivariant_left`. Unlike the left
    version, this requires `G` to be commutative: the proof needs
    `g⁻¹ * (a⁻¹ * c) = a⁻¹ * (g⁻¹ * c)` (the `mul_comm` step in `h`). -/
theorem GroupTensor.starG_equivariant_right
    {G : Type*} [CommGroup G] [Fintype G] [DecidableEq G]
    {ℓ m p : ℕ}
    (g : G) (A : GroupTensor G ℓ m) (B : GroupTensor G m p) :
    A ⋆ GroupTensor.leftAction g B = GroupTensor.leftAction g (A ⋆ B) := by
  funext c i j; simp only [GroupTensor.leftAction, GroupTensor.starG]
  -- Reduce to an equality of summands for each fixed `k` and `a`.
  apply Finset.sum_congr rfl; intro k _
  apply Finset.sum_congr rfl; intro a _
  -- The two sides differ only in the order in which `g⁻¹` and `a⁻¹`
  -- multiply `c`; commutativity swaps them.
  have h : g⁻¹ * (a⁻¹ * c) = a⁻¹ * (g⁻¹ * c) := by
    rw [← mul_assoc, ← mul_assoc, mul_comm g⁻¹ a⁻¹]
  simp [h]

namespace GroupTensor

variable {G : Type*} [Group G] [Fintype G] [DecidableEq G]
variable {ℓ m p : ℕ}


/-- The squared Frobenius norm is unchanged by left translation. -/
theorem frobNormSq_invariant (g : G) (A : GroupTensor G ℓ m) :
    frobNormSq (leftAction g A) = frobNormSq A := by
  simp only [frobNormSq, leftAction]
  -- Reindex the sum over `G` by `x ↦ g⁻¹ * x`; the summands then agree
  -- definitionally.
  exact Fintype.sum_equiv
    { toFun := fun x => g⁻¹ * x
      invFun := fun x => g * x
      left_inv := fun x => by simp
      right_inv := fun x => by simp }
    _ _ (fun _ => rfl)

/-- For every entry `(i, j)`, the sum over all group elements (the
    "DC component") is unchanged by left translation. -/
theorem dc_component_invariant (g : G) (A : GroupTensor G ℓ m)
    (i : Fin ℓ) (j : Fin m) :
    ∑ h : G, leftAction g A h i j = ∑ h : G, A h i j := by
  simp only [leftAction]
  -- Substituting `x ↦ g⁻¹ * x` turns the left-hand sum into the right one.
  exact Fintype.sum_equiv
    { toFun := fun x => g⁻¹ * x
      invFun := fun x => g * x
      left_inv := fun x => by simp
      right_inv := fun x => by simp }
    _ _ (fun _ => rfl)

/-- Left-multiplication by an orthogonal matrix preserves the sum of squares.
    This is the core lemma for Fourier power invariance: if Rᵀ R = I, then
    ∑ s, (∑ u, R s u * v u)² = ∑ u, (v u)².
    The hypothesis `Rᵀ * R = 1` (orthonormal columns) is what lets the
    sum over the row index `s` collapse. `fourier_power_invariant` applies
    this lemma with `R = ρ.ρ g`. -/
private theorem orthog_preserves_sq_sum {d : ℕ} (R : Matrix (Fin d) (Fin d) ℝ)
    (hR : Rᵀ * R = 1) (v : Fin d → ℝ) :
    ∑ s : Fin d, (∑ u : Fin d, R s u * v u) ^ 2 = ∑ u : Fin d, (v u) ^ 2 := by
  -- Extract the entry-wise orthogonality condition: ∑ s, R s u * R s w = δ(u,w)
  have orth : ∀ u w : Fin d, ∑ s : Fin d, R s u * R s w =
      if u = w then 1 else 0 := by
    intro u w
    have := congr_fun (congr_fun hR u) w
    simp only [Matrix.mul_apply, Matrix.transpose_apply, Matrix.one_apply] at this
    exact this
  -- Expand ² = · * · , then distribute both sums
  calc ∑ s, (∑ u, R s u * v u) ^ 2
      = ∑ s, (∑ u, R s u * v u) * (∑ w, R s w * v w) := by
        simp_rw [sq]
    _ = ∑ s, ∑ u, ∑ w, (R s u * R s w) * (v u * v w) := by
        congr 1; ext s; rw [Finset.sum_mul_sum]
        congr 1; ext u; congr 1; ext w; ring
    -- Exchange summation order: bring ∑ s innermost, then factor the
    -- `s`-independent `v u * v w` out of it
    _ = ∑ u, ∑ w, (∑ s, R s u * R s w) * (v u * v w) := by
        rw [Finset.sum_comm]; congr 1; ext u
        rw [Finset.sum_comm]; congr 1; ext w
        rw [← Finset.sum_mul]
    -- Apply orthogonality
    _ = ∑ u, ∑ w, (if u = w then 1 else 0) * (v u * v w) := by
        congr 1; ext u; congr 1; ext w; rw [orth]
    -- Collapse: only the u = w term survives
    _ = ∑ u, (v u) ^ 2 := by
        congr 1; ext u
        rw [Fintype.sum_eq_single u]
        · simp [sq]
        · intro w hw; simp [Ne.symm hw]

/-- Per-irrep Fourier power is invariant under the left group action:
    `fourierBlockNormSq (leftAction g A) ρ = fourierBlockNormSq A ρ`.

    Proof outline:
    (1) change of variable in the Fourier sum (substituting `h = g * h'`),
        so the representation factor becomes `ρ(g * h')`;
    (2) `ρ.is_hom` factors `ρ(g * h') = ρ(g) * ρ(h')`, which gives the
        per-entry relation `fb_rel`: the new block mixes only the row-irrep
        index `s` via `ρ(g)`;
    (3) `ρ.unitary g` supplies `(ρ.ρ g)ᵀ * ρ.ρ g = 1`, and
        `orthog_preserves_sq_sum` shows that, for every fixed `(i, j, t)`,
        the sum of squares over `s` is unchanged. -/
theorem fourier_power_invariant
    (g : G) (A : GroupTensor G ℓ m) (ρ : Irrep G) :
    fourierBlockNormSq (leftAction g A) ρ = fourierBlockNormSq A ρ := by
  simp only [fourierBlockNormSq, matFrobSq]
  -- Step 1: Establish the key per-entry relation:
  --   F_g(i,s)(j,t) = ∑ u, ρ(g)(s,u) * F(i,u)(j,t)
  -- where F = fourierBlock A ρ and F_g = fourierBlock (leftAction g A) ρ.
  have fb_rel : ∀ (i : Fin ℓ) (s : Fin ρ.dim) (j : Fin m) (t : Fin ρ.dim),
      fourierBlock (leftAction g A) ρ ⟨i, s⟩ ⟨j, t⟩ =
      ∑ u : Fin ρ.dim, ρ.ρ g s u * fourierBlock A ρ ⟨i, u⟩ ⟨j, t⟩ := by
    intro i s j t
    simp only [fourierBlock, leftAction]
    -- Change of variable h → g * h' (equiv: h' = g⁻¹ * h)
    rw [show (∑ h : G, A (g⁻¹ * h) i j * ρ.ρ h s t) =
        ∑ h' : G, A h' i j * ρ.ρ (g * h') s t from
      Fintype.sum_equiv
        ⟨(g⁻¹ * ·), (g * ·), fun a => by simp, fun a => by simp⟩
        _ _ (fun h => by dsimp; congr 1; group)]
    -- Apply is_hom: ρ(g * h') = ρ(g) * ρ(h'), expand the matrix product,
    -- then swap the sums over h' and u so both sides have the same shape
    simp_rw [ρ.is_hom, Matrix.mul_apply, Finset.mul_sum]
    rw [Finset.sum_comm]
    congr 1; ext u; congr 1; ext h'; ring
  -- Step 2: Rewrite the LHS using the per-entry relation
  simp_rw [fb_rel]
  -- Goal: ∑ (i,s), ∑ (j,t), (∑ u, ρ.ρ g s u * F(i,u)(j,t))² =
  --       ∑ (i,s), ∑ (j,t), (F(i,s)(j,t))²
  -- Step 3: Convert sums over product types to nested sums
  simp_rw [Fintype.sum_prod_type]
  -- Goal: ∑ i, ∑ s, ∑ j, ∑ t, (∑ u, ρ.ρ g s u * F(i,u)(j,t))² =
  --       ∑ i, ∑ s, ∑ j, ∑ t, (F(i,s)(j,t))²
  -- Step 4: For each fixed (i, j, t), apply orthogonal invariance over s.
  -- First reorder both sides with `Finset.sum_comm` so that `∑ s` is
  -- innermost (shape ∑ i, ∑ j, ∑ t, ∑ s, _). Then close the goal pointwise
  -- with `Finset.sum_congr` and `orthog_preserves_sq_sum`. The reordering
  -- uses explicit `rw [show … from …]` steps rather than `congr 1; ext x`,
  -- which the author reports did not reliably descend into the
  -- `Finset.sum` bodies at this point.
  rw [show (∑ i : Fin ℓ, ∑ s : Fin ρ.dim, ∑ j : Fin m, ∑ t : Fin ρ.dim,
            (∑ u : Fin ρ.dim, ρ.ρ g s u * fourierBlock A ρ ⟨i, u⟩ ⟨j, t⟩)^2) =
         (∑ i : Fin ℓ, ∑ j : Fin m, ∑ t : Fin ρ.dim, ∑ s : Fin ρ.dim,
            (∑ u : Fin ρ.dim, ρ.ρ g s u * fourierBlock A ρ ⟨i, u⟩ ⟨j, t⟩)^2) from by
    refine Finset.sum_congr rfl ?_
    intro i _
    rw [Finset.sum_comm]
    refine Finset.sum_congr rfl ?_
    intro j _
    rw [Finset.sum_comm]]
  -- Same reordering for the right-hand side
  rw [show (∑ i : Fin ℓ, ∑ s : Fin ρ.dim, ∑ j : Fin m, ∑ t : Fin ρ.dim,
            (fourierBlock A ρ ⟨i, s⟩ ⟨j, t⟩)^2) =
         (∑ i : Fin ℓ, ∑ j : Fin m, ∑ t : Fin ρ.dim, ∑ s : Fin ρ.dim,
            (fourierBlock A ρ ⟨i, s⟩ ⟨j, t⟩)^2) from by
    refine Finset.sum_congr rfl ?_
    intro i _
    rw [Finset.sum_comm]
    refine Finset.sum_congr rfl ?_
    intro j _
    rw [Finset.sum_comm]]
  -- Both sides now have shape ∑ i, ∑ j, ∑ t, ∑ s, F²; the inner ∑ s claim
  -- is exactly orthog_preserves_sq_sum applied pointwise.
  refine Finset.sum_congr rfl ?_; intro i _
  refine Finset.sum_congr rfl ?_; intro j _
  refine Finset.sum_congr rfl ?_; intro t _
  exact orthog_preserves_sq_sum (ρ.ρ g) (ρ.unitary g)
    (fun u => fourierBlock A ρ ⟨i, u⟩ ⟨j, t⟩)

end GroupTensor
end