{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f13a3805-d2b5-40af-a713-9c341eb7caff",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "========================================\n",
      "StarGAlgebra Verification Suite\n",
      "Group order: 8, Cyclic: True, Abelian: True\n",
      "Identity at index: 0, GPU: False\n",
      "========================================\n",
      "\n",
      "Test 1: Group Axioms\n",
      "  PASS\n",
      "\n",
      "Test 2: Convolution Tensor\n",
      "  PASS\n",
      "\n",
      "Test 3: 1D Convolution\n",
      "  Direct vs Inverse: 1.18e-16\n",
      "  Direct vs Optimized: 1.42e-16\n",
      "  PASS\n",
      "\n",
      "Test 4: StarG Product\n",
      "  Direct vs Main: error=3.01e+00 (0.0003s vs 0.0023s)\n",
      "  FAIL\n",
      "\n",
      "Test 5: Conjugate Transpose\n",
      "  PASS (error=0.00e+00)\n",
      "\n",
      "Test 6: Identity\n",
      "  ||I*A - A||/||A|| = 1.17e-16\n",
      "  ||A*I - A||/||A|| = 1.17e-16\n",
      "  PASS\n",
      "\n",
      "Test 7: Associativity\n",
      "  PASS (error=1.61e-16)\n",
      "\n",
      "Test 8: SVD\n",
      "  FAIL (error=5.37e-01)\n",
      "\n",
      "========================================\n",
      "All tests completed.\n",
      "========================================\n"
     ]
    }
   ],
   "source": [
    "import StarGAlgebra\n",
    "import numpy as np\n",
    "from numpy.fft import fft, ifft, fft2, ifft2\n",
    "from scipy.linalg import dft, svd\n",
    "from itertools import permutations\n",
    "from typing import Optional, Tuple, List, Dict, Union, Any\n",
    "import warnings\n",
    "\n",
    "# Fix NumPy's global RNG seed so the random tensors below (and every printed\n",
    "# error value) are identical on each Restart & Run All. np.random.seed seeds\n",
    "# the legacy global generator, which StarGAlgebra may also draw from in its\n",
    "# self-tests (unverified - confirm in StarGAlgebra.py).\n",
    "SEED = 0\n",
    "np.random.seed(SEED)\n",
    "\n",
    "# Create the star-G algebra over the cyclic group of order 8\n",
    "alg = StarGAlgebra.StarGAlgebra('cyclic', 8)\n",
    "\n",
    "# Enable GPU if available (the recorded run reports GPU: False)\n",
    "alg.enable_gpu()\n",
    "\n",
    "# Create random test tensors; the trailing axis of length 8 presumably\n",
    "# indexes the group elements - confirm against StarGAlgebra\n",
    "A = np.random.randn(4, 5, 8)\n",
    "B = np.random.randn(5, 3, 8)\n",
    "\n",
    "# Compute the star-G product (C is not used further in this cell)\n",
    "C = alg.star_g(A, B)\n",
    "\n",
    "# Compute the star-G SVD of A (U, S, V are not used further in this cell)\n",
    "U, S, V = alg.star_g_svd(A)\n",
    "\n",
    "# Run the library's built-in verification suite.\n",
    "# NOTE(review): the recorded run reports FAIL for Test 4 (StarG Product,\n",
    "# error ~3) and Test 8 (SVD, error ~0.5). Those tests presumably exercise the\n",
    "# code paths behind star_g / star_g_svd, so treat C and U, S, V above as\n",
    "# unverified until the failures are resolved.\n",
    "alg.run_all_tests()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9c6bf6f4-af6a-4fcc-8453-506d7b8ca1bd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing Cyclic Group Z_5:\n",
      "========================================\n",
      "StarGAlgebra Verification Suite\n",
      "Group order: 5, Cyclic: True, Abelian: True\n",
      "Identity at index: 0, GPU: False\n",
      "========================================\n",
      "\n",
      "Test 1: Group Axioms\n",
      "  PASS\n",
      "\n",
      "Test 2: Convolution Tensor\n",
      "  PASS\n",
      "\n",
      "Test 3: 1D Convolution\n",
      "  Direct vs Inverse: 1.51e-16\n",
      "  Direct vs Optimized: 1.15e-16\n",
      "  PASS\n",
      "\n",
      "Test 4: StarG Product\n",
      "  Direct vs Main: error=3.16e+00 (0.0003s vs 0.0004s)\n",
      "  FAIL\n",
      "\n",
      "Test 5: Conjugate Transpose\n",
      "  PASS (error=0.00e+00)\n",
      "\n",
      "Test 6: Identity\n",
      "  ||I*A - A||/||A|| = 1.10e-16\n",
      "  ||A*I - A||/||A|| = 1.10e-16\n",
      "  PASS\n",
      "\n",
      "Test 7: Associativity\n",
      "  PASS (error=2.79e-16)\n",
      "\n",
      "Test 8: SVD\n",
      "  FAIL (error=1.35e+00)\n",
      "\n",
      "========================================\n",
      "All tests completed.\n",
      "========================================\n",
      "\n",
      "============================================================\n",
      "\n",
      "Testing Klein-4 Group:\n",
      "========================================\n",
      "StarGAlgebra Verification Suite\n",
      "Group order: 4, Cyclic: False, Abelian: True\n",
      "Identity at index: 0, GPU: False\n",
      "========================================\n",
      "\n",
      "Test 1: Group Axioms\n",
      "  PASS\n",
      "\n",
      "Test 2: Convolution Tensor\n",
      "  PASS\n",
      "\n",
      "Test 3: 1D Convolution\n",
      "  Direct vs Inverse: 0.00e+00\n",
      "  Direct vs Optimized: 0.00e+00\n",
      "  PASS\n",
      "\n",
      "Test 4: StarG Product\n",
      "  Direct vs Main: error=2.52e+00 (0.0003s vs 0.0002s)\n",
      "  FAIL\n",
      "\n",
      "Test 5: Conjugate Transpose\n",
      "  PASS (error=0.00e+00)\n",
      "\n",
      "Test 6: Identity\n",
      "  ||I*A - A||/||A|| = 0.00e+00\n",
      "  ||A*I - A||/||A|| = 0.00e+00\n",
      "  PASS\n",
      "\n",
      "Test 7: Associativity\n",
      "  PASS (error=2.85e-16)\n",
      "\n",
      "Test 8: SVD\n",
      "  SKIP (non-cyclic)\n",
      "\n",
      "========================================\n",
      "All tests completed.\n",
      "========================================\n",
      "\n",
      "============================================================\n",
      "\n",
      "Testing Dihedral Group D_3:\n",
      "========================================\n",
      "StarGAlgebra Verification Suite\n",
      "Group order: 6, Cyclic: False, Abelian: False\n",
      "Identity at index: 0, GPU: False\n",
      "========================================\n",
      "\n",
      "Test 1: Group Axioms\n",
      "  PASS\n",
      "\n",
      "Test 2: Convolution Tensor\n",
      "  PASS\n",
      "\n",
      "Test 3: 1D Convolution\n",
      "  Direct vs Inverse: 1.33e-16\n",
      "  Direct vs Optimized: 0.00e+00\n",
      "  PASS\n",
      "\n",
      "Test 4: StarG Product\n",
      "  Direct vs Main: error=2.28e+00 (0.0003s vs 0.0003s)\n",
      "  FAIL\n",
      "\n",
      "Test 5: Conjugate Transpose\n",
      "  PASS (error=0.00e+00)\n",
      "\n",
      "Test 6: Identity\n",
      "  ||I*A - A||/||A|| = 0.00e+00\n",
      "  ||A*I - A||/||A|| = 0.00e+00\n",
      "  PASS\n",
      "\n",
      "Test 7: Associativity\n",
      "  PASS (error=3.09e-16)\n",
      "\n",
      "Test 8: SVD\n",
      "  SKIP (non-cyclic)\n",
      "\n",
      "========================================\n",
      "All tests completed.\n",
      "========================================\n",
      "\n",
      "============================================================\n",
      "\n",
      "Testing Stable SVD:\n",
      "  Invariant magnitudes shape: (3, 3)\n",
      "  Trace invariants: [ 61.32955127  61.32955127 499.55595912  16.8220484 ]\n",
      "  Feature vector length: 21\n",
      "\n",
      "============================================================\n",
      "\n",
      "Running benchmark:\n",
      "\n",
      "=== Performance Benchmark ===\n",
      "Group: n=5, Cyclic=True, GPU=False\n",
      "Size       Direct (s)      Optimized (s)   Speedup   \n",
      ", , , , , , , , , , , , , , , , , , -\n",
      "5          0.0009          0.0001          7.8       x\n",
      "10         0.0059          0.0001          51.2      x\n",
      "20         0.0317          0.0002          145.3     x\n"
     ]
    }
   ],
   "source": [
    "# Execute StarGAlgebra.py as a script. The explicit %-prefix and .py extension\n",
    "# keep this independent of IPython automagic and consistent with the cells\n",
    "# below. Its recorded output covers the verification suite for Z_5, Klein-4\n",
    "# and D_3, plus stable-SVD and benchmark checks.\n",
    "%run StarGAlgebra.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "eb26f4a8-ca46-4b9b-aa11-e3ca5ec31ed5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================================\n",
      "StarGHelpers Test Suite\n",
      "============================================================\n",
      "\n",
      "1. Testing compute_r2:\n",
      "   Perfect prediction R²: 1.0000\n",
      "   Good prediction R²:    0.9970\n",
      "   Mean prediction R²:    0.0000\n",
      "\n",
      "2. Testing count_mlp_params:\n",
      "   Architecture [10, 20, 1]: 241 parameters\n",
      "   Architecture [784, 256, 128, 10]: 235,146 parameters\n",
      "   Architecture [100, 64, 32, 16, 1]: 9,089 parameters\n",
      "\n",
      "3. Testing compute_invariant_features:\n",
      "   Input shape:  (5, 8, 4)\n",
      "   Output shape: (5, 32)\n",
      "   Expected:     (5, 32)\n",
      "\n",
      "4. Testing generate_molecular_data:\n",
      "   Generated X shape: (100, 16, 8)\n",
      "   Generated Y shape: (100, 1)\n",
      "   Y statistics: min=2.205, max=7.406, mean=5.542\n",
      "\n",
      "   Checking rotational structure:\n",
      "   Feature magnitudes across rotations: [242.13491428 242.34517849 242.15745494 242.39983417 242.13491428\n",
      " 242.34517849 242.15745494 242.39983417]\n",
      "\n",
      "5. Testing roundp:\n",
      "   3.14159265 -> p=2: 3.14, p=4: 3.1416\n",
      "   2.71828182 -> p=2: 2.72, p=4: 2.7183\n",
      "   1.41421356 -> p=2: 1.41, p=4: 1.4142\n",
      "   Array [1.234567 8.901234] -> p=3: [1.235 8.901]\n",
      "\n",
      "============================================================\n",
      "All tests completed successfully!\n",
      "============================================================\n"
     ]
    }
   ],
   "source": [
    "# Execute the helper module's self-test script (explicit %-prefix so this\n",
    "# does not depend on IPython automagic being enabled)\n",
    "%run starG_helpers.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "50143ca4-d7fe-484d-a53b-7ed9c31571b1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "StarGAlgebra not found. Creating minimal mock for testing.\n",
      "============================================================\n",
      "NeuralStarGFramework Test Suite\n",
      "============================================================\n",
      "\n",
      "Group order: 8\n",
      "\n",
      "==================================================\n",
      "Neural Star-G Network Summary\n",
      "==================================================\n",
      "Group order: 8\n",
      "Layer sizes: [16, 32, 16, 1]\n",
      "Learning rate: 0.001\n",
      "Weight decay: 0.0001\n",
      ", , , , , , , , , , , , , , , , --\n",
      "Layer 1:\n",
      "  Weight shape: (32, 16, 8)\n",
      "  Bias shape:   (32, 1, 8)\n",
      "  Parameters:   4,352\n",
      "Layer 2:\n",
      "  Weight shape: (16, 32, 8)\n",
      "  Bias shape:   (16, 1, 8)\n",
      "  Parameters:   4,224\n",
      "Layer 3:\n",
      "  Weight shape: (1, 16, 8)\n",
      "  Bias shape:   (1, 1, 8)\n",
      "  Parameters:   136\n",
      ", , , , , , , , , , , , , , , , --\n",
      "Total parameters: 8,712\n",
      "==================================================\n",
      "\n",
      "Generating synthetic data...\n",
      "X_train shape: (100, 16, 8)\n",
      "Y_train shape: (100,)\n",
      "\n",
      "Testing forward pass...\n",
      "Output shape: (5, 1, 8)\n",
      "\n",
      "Testing prediction...\n",
      "Predictions shape: (5,)\n",
      "Sample predictions: [-2.25953238  0.11158253  0.95744442]\n",
      "\n",
      "Testing training (10 epochs)...\n",
      "\n",
      "Training Neural Star_G Network\n",
      "Epochs: 10, Batch: 16, LR: 0.0010\n",
      "\n",
      "Epoch   5: Loss=11.9024, R2=-4.9001, Val R2=-3.9899 (1.5s)\n",
      "Epoch  10: Loss=9.0742, R2=-3.5459, Val R2=-2.9917 (1.5s)\n",
      "\n",
      "Best Val R2: -2.9917\n",
      "\n",
      "\n",
      "Testing weight compression...\n",
      "Compressing to rank 2...\n",
      "Layer 1: 0.00% error\n",
      "Layer 2: 0.00% error\n",
      "Layer 3: 0.00% error\n",
      "\n",
      "Testing save/load...\n",
      "Model saved to test_model.npz\n",
      "Model loaded from test_model.npz\n",
      "Original predictions: [ 4.84074015 -2.60443789  1.72507838]\n",
      "Loaded predictions:   [ 4.84074015 -2.60443789  1.72507838]\n",
      "Match: True\n",
      "\n",
      "============================================================\n",
      "All tests completed successfully!\n",
      "============================================================\n"
     ]
    }
   ],
   "source": [
    "# Execute the NeuralStarGFramework self-test script (explicit %-prefix so\n",
    "# this does not depend on IPython automagic being enabled).\n",
    "# NOTE(review): the recorded output begins with 'StarGAlgebra not found.\n",
    "# Creating minimal mock for testing.', i.e. this suite ran against a mock\n",
    "# rather than the real algebra. Check how NeuralStarGFramework.py imports it\n",
    "# before trusting the 'All tests completed successfully!' summary.\n",
    "%run NeuralStarGFramework.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28c26cc4-3eb6-4bd5-a00e-29527cade2f3",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
