#!/bin/bash
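# submit_e3nn.bsub — e3nn-based SE(3)-equivariant baseline
#
# e3nn handles scalars, vectors, and rank-2 tensors via output irreps.
# Array index 1..18 = (target × seed) for 6 targets × 3 seeds, target-major:
# indices 1-3 are the first target with seeds 0-2, 4-6 the second, and so on.
# The output irrep is selected per-target inside this script.
# Submit with: bsub < submit_e3nn.bsub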

#BSUB -J starg_e3nn[1-18]
#BSUB -o logs/e3nn_%I_%J.out
#BSUB -e logs/e3nn_%I_%J.err
#BSUB -q normal
#BSUB -n 8
#BSUB -gpu "num=1:mode=exclusive_process"
#BSUB -W 12:00
#BSUB -M 64GB

set -euo pipefail

# Print an error to stderr and abort the job with a non-zero status, so LSF
# records the array element as failed.
die() { printf 'error: %s\n' "$*" >&2; exit 1; }

# torch 2.6+ defaults weights_only=True in torch.load, which refuses to
# unpickle e3nn 0.4.4's constants.pt (it contains a `slice` global).
# The env var below restores the pre-2.6 weights_only=False behaviour.
# NOTE: the override is process-wide, not scoped to e3nn: every torch.load()
# in this job (including any the training script does on checkpoints or data)
# will also unpickle arbitrary objects. Only acceptable while every file this
# job loads is trusted.
export TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1

# The #BSUB -o/-e paths are relative (logs/...); create logs/ in the job's
# starting directory, before we cd away from it.
mkdir -p logs

# Keep in sync with the #BSUB -J array range: 6 targets × 3 seeds = 1..18.
TARGETS=(gap alpha mu zpve mu_vector alpha_tensor)
SEEDS=(0 1 2)
n_seeds=${#SEEDS[@]}
n_jobs=$(( ${#TARGETS[@]} * n_seeds ))

# Reject indices outside the array range instead of silently wrapping:
# for index 0, bash computes (0-1)%3 = -1 and SEEDS[-1] is the last seed.
: "${LSB_JOBINDEX:?not set; submit this script as an LSF array job}"
if [[ ! "$LSB_JOBINDEX" =~ ^[0-9]+$ ]] \
    || (( LSB_JOBINDEX < 1 || LSB_JOBINDEX > n_jobs )); then
  die "LSB_JOBINDEX=$LSB_JOBINDEX is outside the array range 1..$n_jobs"
fi
# Target-major ordering: consecutive indices share a target and cycle seeds.
TARGET="${TARGETS[$(( (LSB_JOBINDEX - 1) / n_seeds ))]}"
SEED="${SEEDS[$(( (LSB_JOBINDEX - 1) % n_seeds ))]}"

# Vector/tensor targets need non-scalar output irreps; everything else is a
# single scalar.
case "$TARGET" in
  mu_vector)    OUTPUT_IRREPS="1x1o" ;;
  alpha_tensor) OUTPUT_IRREPS="1x2e + 1x0e" ;;
  *)            OUTPUT_IRREPS="1x0e" ;;
esac

WORKDIR=$HOME/starg/python/large_scale
QM9_DIR=${QM9_DIR:-/u/$USER/data/qm9/dsgdb9nsd}

cd "$WORKDIR" || die "cannot cd to $WORKDIR"
# --out_dir results/ below is relative to WORKDIR, so create it here.
mkdir -p results
# Prepend "." without leaving a trailing ':' when PYTHONPATH was empty.
export PYTHONPATH=".${PYTHONPATH:+:$PYTHONPATH}"

echo "[$(date)] host=$(hostname) array=$LSB_JOBINDEX target=$TARGET seed=$SEED irreps=\"$OUTPUT_IRREPS\""
# Best-effort GPU diagnostics; a missing nvidia-smi must not fail the job.
nvidia-smi --query-gpu=name,memory.total --format=csv,noheader || true

# Capture the trainer's status so the job exits with it (set -e would
# otherwise exit before the "done" line is logged).
rc=0
python3 train_baseline_e3nn.py \
    --target        "$TARGET" \
    --output_irreps "$OUTPUT_IRREPS" \
    --qm9_dir       "$QM9_DIR" \
    --seed          "$SEED" \
    --out_dir       results/ \
    --device        cuda || rc=$?

echo "[$(date)] done (exit=$rc)"
exit "$rc"
