tensor-group-sym / lean / StarG / ProductGroup.lean
ProductGroup.lean
Raw
/-
  StarG/ProductGroup.lean — targets Lean 4.29 with a current Mathlib.
  Contains no `sorry`.
  `irrepProd.is_hom` and `irrepProd.unitary` are proved by converting the
  flat sum over `Fin (d₁ * d₂)` into a double sum through `finProdFinEquiv`,
  and then applying the mixed-product and unitarity properties of
  Kronecker products.
-/
import StarG.Basic
import StarG.SVD

set_option linter.unusedSectionVars false
set_option linter.unusedSimpArgs false

open Finset BigOperators Matrix

noncomputable section

namespace GroupTensor

variable {G₁ G₂ : Type*}
  [Group G₁] [Fintype G₁] [DecidableEq G₁]
  [Group G₂] [Fintype G₂] [DecidableEq G₂]
variable {ℓ m p : ℕ}

/-- `convTensor` on the product group `G₁ × G₂` factors as the product of the
    component `convTensor`s on `G₁` and on `G₂`.

    The proof unfolds `convTensor` and uses `Prod.mk_mul_mk` / `Prod.mk.injEq`
    to turn the product-group equation into its two components. It then
    closes each of the four cases on `a₁ * b₁ = c₁` and `a₂ * b₂ = c₂` with
    `simp_all`. NOTE(review): the case split suggests `convTensor a b c`
    is an indicator of `a * b = c`; confirm against its definition in
    `StarG.Basic`. -/
theorem convTensor_prod
    (a₁ b₁ c₁ : G₁) (a₂ b₂ c₂ : G₂) :
    convTensor (a₁, a₂) (b₁, b₂) (c₁, c₂) =
    convTensor a₁ b₁ c₁ * convTensor a₂ b₂ c₂ := by
  simp only [convTensor, Prod.mk_mul_mk, Prod.mk.injEq]
  by_cases h1 : a₁ * b₁ = c₁ <;> by_cases h2 : a₂ * b₂ = c₂ <;> simp_all

/-- Split a flat index `k : Fin (n₁ * n₂)` into its pair of component indices
    via `finProdFinEquiv.symm`. It is an `abbrev` so that the proofs in
    `irrepProd` can unfold it (`simp only [..., splitIdx]`) back to the exact
    term `finProdFinEquiv.symm k`. That term is the form `sum_finProdFin_eq`
    matches, so the body should not be rewritten to an equivalent but
    syntactically different term. -/
private abbrev splitIdx {n₁ n₂ : ℕ} (k : Fin (n₁ * n₂)) : Fin n₁ × Fin n₂ :=
  finProdFinEquiv.symm k

/-- Convert a sum over `Fin (d₁ * d₂)` whose body is indexed via
    `finProdFinEquiv.symm` into a double sum over `Fin d₁` and `Fin d₂`.

    The summand `f` takes a pair and is passed explicitly, so callers can
    write `rw [sum_finProdFin_eq (fun p => ...)]` and avoid higher-order
    unification problems with `rw`.

    The codomain `β` may be any additive commutative monoid. It is inferred
    from `f`, so real-valued uses behave exactly as before. -/
private lemma sum_finProdFin_eq {β : Type*} [AddCommMonoid β] {d₁ d₂ : ℕ}
    (f : Fin d₁ × Fin d₂ → β) :
    (∑ x : Fin (d₁ * d₂), f (finProdFinEquiv.symm x)) =
    ∑ a : Fin d₁, ∑ b : Fin d₂, f (a, b) :=
  -- Reindex the flat sum along the equivalence, then split the sum over the
  -- product type into an iterated sum.
  (Equiv.sum_comp finProdFinEquiv.symm f).trans (Fintype.sum_prod_type f)

/-- Tensor (Kronecker) product of an irrep of `G₁` and an irrep of `G₂`,
    giving a representation of `G₁ × G₂` of dimension `ρ₁.dim * ρ₂.dim`.

    `splitIdx` (i.e. `finProdFinEquiv.symm`) splits a flat index
    `i : Fin (ρ₁.dim * ρ₂.dim)` into a pair `(i₁, i₂)`. The matrix of
    `(g₁, g₂)` is then defined entry-wise as
    `(ρ₁ ⊗ ρ₂)(g₁,g₂)_{ij} = ρ₁(g₁)_{i₁,j₁} * ρ₂(g₂)_{i₂,j₂}`.

    The `is_hom` and `unitary` proofs are the standard Kronecker-product facts:
    the mixed-product property, and `ρᵀ * ρ = 1` being preserved. Both are
    carried out entry-wise, using `sum_finProdFin_eq` to turn the flat sum
    over `Fin (ρ₁.dim * ρ₂.dim)` into a double sum.

    NOTE(review): this definition supplies only the fields listed below. It
    proves the homomorphism and unitarity properties, not irreducibility of
    the product. Confirm what `Irrep` is meant to guarantee before relying
    on that. -/
def irrepProd (ρ₁ : Irrep G₁) (ρ₂ : Irrep G₂) : Irrep (G₁ × G₂) where
  dim := ρ₁.dim * ρ₂.dim
  dim_pos := Nat.mul_pos ρ₁.dim_pos ρ₂.dim_pos
  -- Entry (i, j) is the product of the component entries selected by
  -- splitting both flat indices with `splitIdx`.
  ρ := fun ⟨g₁, g₂⟩ => Matrix.of fun i j =>
    ρ₁.ρ g₁ (splitIdx i).1 (splitIdx j).1 *
    ρ₂.ρ g₂ (splitIdx i).2 (splitIdx j).2
  is_hom := by
    -- (ρ₁⊗ρ₂)((g₁,g₂)*(h₁,h₂)) = (ρ₁⊗ρ₂)(g₁,g₂) * (ρ₁⊗ρ₂)(h₁,h₂)
    -- by the mixed-product property of Kronecker products.
    intro ⟨g₁, g₂⟩ ⟨h₁, h₂⟩
    ext i j
    -- Work entry-wise. Unfolding `splitIdx` exposes `finProdFinEquiv.symm`,
    -- which is the form `sum_finProdFin_eq` matches below.
    simp only [Matrix.mul_apply, Matrix.of_apply, Prod.mk_mul_mk, splitIdx]
    -- Rewrite LHS using is_hom for each factor
    conv_lhs => rw [ρ₁.is_hom g₁ h₁, ρ₂.is_hom g₂ h₂]
    simp only [Matrix.mul_apply]
    -- LHS: (∑ a, ρ₁.ρ g₁ _ a * ρ₁.ρ h₁ a _) * (∑ b, ρ₂.ρ g₂ _ b * ρ₂.ρ h₂ b _)
    -- Expand to double sum
    rw [Finset.sum_mul_sum]
    -- Convert RHS: single sum over Fin(d₁*d₂) → double sum via finProdFinEquiv.
    -- NOTE(review): `rw` rewrites on both sides of the goal, so the
    -- surrounding `symm` … `symm` pair only flips the goal and flips it back.
    -- It looks unnecessary but is kept as-is.
    symm
    rw [sum_finProdFin_eq (fun p =>
      ρ₁.ρ g₁ (finProdFinEquiv.symm i).1 p.1 *
        ρ₂.ρ g₂ (finProdFinEquiv.symm i).2 p.2 *
      (ρ₁.ρ h₁ p.1 (finProdFinEquiv.symm j).1 *
        ρ₂.ρ h₂ p.2 (finProdFinEquiv.symm j).2))]
    symm
    -- Both sides are now ∑ a, ∑ b, ... . Peel off the two binders and close
    -- each summand, which agrees up to ring rearrangement.
    congr 1; ext a; congr 1; ext b; ring
  unitary := by
    -- (ρ₁(g₁)⊗ρ₂(g₂))ᵀ * (ρ₁(g₁)⊗ρ₂(g₂)) = (ρ₁ᵀρ₁) ⊗ (ρ₂ᵀρ₂) = I⊗I = I
    intro ⟨g₁, g₂⟩
    ext i j
    simp only [Matrix.mul_apply, Matrix.transpose_apply, Matrix.of_apply,
               Matrix.one_apply, splitIdx]
    -- Convert sum over Fin(d₁*d₂) to double sum
    rw [sum_finProdFin_eq (fun p =>
      (ρ₁.ρ g₁ p.1 (finProdFinEquiv.symm i).1 *
        ρ₂.ρ g₂ p.2 (finProdFinEquiv.symm i).2) *
      (ρ₁.ρ g₁ p.1 (finProdFinEquiv.symm j).1 *
        ρ₂.ρ g₂ p.2 (finProdFinEquiv.symm j).2))]
    -- Rearrange: group ρ₁ terms and ρ₂ terms separately, so the double sum
    -- has the shape ∑ a, ∑ b, F a * G b required by `Finset.sum_mul_sum`.
    simp_rw [show ∀ a b,
      (ρ₁.ρ g₁ a (finProdFinEquiv.symm i).1 *
        ρ₂.ρ g₂ b (finProdFinEquiv.symm i).2) *
      (ρ₁.ρ g₁ a (finProdFinEquiv.symm j).1 *
        ρ₂.ρ g₂ b (finProdFinEquiv.symm j).2) =
      (ρ₁.ρ g₁ a (finProdFinEquiv.symm i).1 *
        ρ₁.ρ g₁ a (finProdFinEquiv.symm j).1) *
      (ρ₂.ρ g₂ b (finProdFinEquiv.symm i).2 *
        ρ₂.ρ g₂ b (finProdFinEquiv.symm j).2)
      from fun _ _ => by ring]
    -- Factor double sum into product of sums
    rw [← Finset.sum_mul_sum]
    -- Each factor is (ρᵢᵀ * ρᵢ)(_, _) = I(_, _) by unitarity
    -- h₁: specialise `ρ₁.unitary g₁` at the pair of first components.
    have h₁ : ∑ a : Fin ρ₁.dim,
        ρ₁.ρ g₁ a (finProdFinEquiv.symm i).1 *
        ρ₁.ρ g₁ a (finProdFinEquiv.symm j).1 =
      if (finProdFinEquiv.symm i).1 = (finProdFinEquiv.symm j).1
        then 1 else 0 := by
      have := congr_fun (congr_fun (ρ₁.unitary g₁)
        (finProdFinEquiv.symm i).1) (finProdFinEquiv.symm j).1
      simp only [Matrix.mul_apply, Matrix.transpose_apply,
                 Matrix.one_apply] at this
      exact this
    -- h₂: specialise `ρ₂.unitary g₂` at the pair of second components.
    have h₂ : ∑ b : Fin ρ₂.dim,
        ρ₂.ρ g₂ b (finProdFinEquiv.symm i).2 *
        ρ₂.ρ g₂ b (finProdFinEquiv.symm j).2 =
      if (finProdFinEquiv.symm i).2 = (finProdFinEquiv.symm j).2
        then 1 else 0 := by
      have := congr_fun (congr_fun (ρ₂.unitary g₂)
        (finProdFinEquiv.symm i).2) (finProdFinEquiv.symm j).2
      simp only [Matrix.mul_apply, Matrix.transpose_apply,
                 Matrix.one_apply] at this
      exact this
    rw [h₁, h₂]
    -- Product of Kronecker deltas = Kronecker delta on product type. We
    -- analyse the four sub-cases by case-splitting on whether each
    -- component projection of `finProdFinEquiv.symm` agrees, and connect
    -- the equality `i = j` to those projections via injectivity of
    -- `finProdFinEquiv.symm`. (A direct `rw` of the corresponding iff into
    -- the if-then-else condition fails because the `Decidable` instance
    -- depends on the condition; case analysis avoids the issue.)
    by_cases hij : i = j
    · -- i = j: the components must agree, so both ifs reduce to 1.
      have hi1 : (finProdFinEquiv.symm i).1 = (finProdFinEquiv.symm j).1 := by
        rw [hij]
      have hi2 : (finProdFinEquiv.symm i).2 = (finProdFinEquiv.symm j).2 := by
        rw [hij]
      rw [if_pos hi1, if_pos hi2, mul_one, if_pos hij]
    · -- i ≠ j: at least one component must differ, so the LHS product is 0.
      rw [if_neg hij]
      by_cases hi1 :
        (finProdFinEquiv.symm i).1 = (finProdFinEquiv.symm j).1
      · by_cases hi2 :
          (finProdFinEquiv.symm i).2 = (finProdFinEquiv.symm j).2
        · -- both components agree: would force i = j, contradicting hij.
          exact absurd
            (finProdFinEquiv.symm.injective (Prod.ext hi1 hi2)) hij
        · rw [if_pos hi1, if_neg hi2, mul_zero]
      · rw [if_neg hi1, zero_mul]

/-- On a product group, the Fourier block of a convolution `A ⋆ B`, taken at
    the product irrep `irrepProd ρ₁ ρ₂`, equals the product of the Fourier
    blocks of `A` and `B` at that same irrep. This is `fourier_multiplicative`
    specialised to `irrepProd ρ₁ ρ₂`. -/
theorem starG_prod_fourier
    (A : GroupTensor (G₁ × G₂) ℓ m)
    (B : GroupTensor (G₁ × G₂) m p)
    (ρ₁ : Irrep G₁) (ρ₂ : Irrep G₂) :
    fourierBlock (A ⋆ B) (irrepProd ρ₁ ρ₂) =
    fourierBlock A (irrepProd ρ₁ ρ₂) * fourierBlock B (irrepProd ρ₁ ρ₂) := by
  -- `irrepProd ρ₁ ρ₂` is an ordinary `Irrep (G₁ × G₂)`, so the general
  -- convolution theorem applies to it unchanged.
  exact fourier_multiplicative A B (irrepProd ρ₁ ρ₂)

variable {G₃ : Type*} [Group G₃] [Fintype G₃] [DecidableEq G₃]

/-- Three-factor version of `convTensor_prod`: on `G₁ × G₂ × G₃`, i.e.
    `G₁ × (G₂ × G₃)`, `convTensor` is the product of the three component
    `convTensor`s. -/
theorem convTensor_prod3
    (a₁ b₁ c₁ : G₁) (a₂ b₂ c₂ : G₂) (a₃ b₃ c₃ : G₃) :
    convTensor (a₁, a₂, a₃) (b₁, b₂, b₃) (c₁, c₂, c₃) =
    convTensor a₁ b₁ c₁ * convTensor a₂ b₂ c₂ * convTensor a₃ b₃ c₃ := by
  -- Split off the `G₁` factor, treating `G₂ × G₃` as a single group.
  have outer := convTensor_prod a₁ b₁ c₁ (a₂, a₃) (b₂, b₃) (c₂, c₃)
  -- Then split the remaining `G₂ × G₃` pair.
  have inner := convTensor_prod a₂ b₂ c₂ a₃ b₃ c₃
  -- The goal is now `x * (y * z) = x * y * z`; reassociate the RHS.
  rw [outer, inner, mul_assoc]

end GroupTensor
end