"""Batch wrapper that runs ``src.data.run_evaluation`` once per session folder."""

import argparse
import shlex
import subprocess
import sys
from pathlib import Path

# Grandparent of the directory that contains this file; relative
# --sessions-root values and the subprocess working directory are anchored here.
PROJECT_ROOT = Path(__file__).resolve().parents[2]


def _resolve_root(path_str: str) -> Path:
    """Return *path_str* as a resolved absolute path.

    Relative paths are interpreted relative to ``PROJECT_ROOT`` rather than
    the current working directory.
    """
    path = Path(path_str)
    if not path.is_absolute():
        path = PROJECT_ROOT / path
    return path.resolve()


def _list_sessions(root: Path) -> list[str]:
    """Return the sorted session identifiers found under *root*.

    A session is any sub-directory of *root* (at any depth) that directly
    contains a ``server_head_round_*.pth`` file; it is identified by its path
    relative to *root*. When no such file exists below a sub-directory, fall
    back to the legacy rule: immediate sub-directories whose name starts with
    ``"20"`` or contains ``"seed"``.

    Returns an empty list when *root* is missing or is not a directory.
    """
    # is_dir() is False for a missing path too; it additionally keeps the
    # iterdir() fallback below from raising NotADirectoryError on a file.
    if not root.is_dir():
        return []

    found_dirs: set[str] = set()
    for weights_file in root.glob("**/server_head_round_*.pth"):
        try:
            rel_dir = weights_file.parent.relative_to(root)
        except ValueError:
            continue
        # Weights sitting directly in root (flat layout) do not make root
        # itself a session; sessions are always sub-directories.
        if str(rel_dir) != ".":
            found_dirs.add(str(rel_dir))

    if not found_dirs:
        # Fallback to legacy behavior for backward compatibility.
        found_dirs = {
            child.name
            for child in root.iterdir()
            if child.is_dir()
            and (child.name.startswith("20") or "seed" in child.name)
        }

    return sorted(found_dirs)


def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the batch evaluation wrapper."""
    parser = argparse.ArgumentParser(
        description="Batch wrapper for src.data.run_evaluation across many sessions."
    )
    parser.add_argument(
        "--sessions-root",
        default="bestweights",
        help="Directory that contains session folders (default: bestweights).",
    )
    parser.add_argument(
        "--only",
        default="",
        help="Comma-separated session IDs to evaluate.",
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=0,
        help="Max number of sessions to run (0 means all).",
    )
    parser.add_argument(
        "--device",
        default="cpu",
        help="Evaluation device passed through to run_evaluation (cpu/mps/cuda).",
    )
    parser.add_argument(
        "--force-prob-threshold",
        type=float,
        default=None,
        help="Optional fixed probability threshold in [0,1] for all sessions.",
    )
    parser.add_argument(
        "--report-tag",
        default="",
        help="Optional report tag suffix passed through to run_evaluation.",
    )
    parser.add_argument(
        "--continue-on-error",
        action="store_true",
        help="Continue remaining sessions even if one evaluation fails.",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Print commands without executing.",
    )
    return parser


def _build_command(args: argparse.Namespace, full_session_path: str) -> list[str]:
    """Build the ``run_evaluation`` argv for one session from the parsed options."""
    cmd = [
        sys.executable,
        "-m",
        "src.data.run_evaluation",
        "--device",
        str(args.device),
        "--session",
        full_session_path,
    ]
    if args.force_prob_threshold is not None:
        cmd.extend(["--force-prob-threshold", str(args.force_prob_threshold)])
    if args.report_tag:
        cmd.extend(["--report-tag", str(args.report_tag)])
    return cmd


def main() -> int:
    """Run ``src.data.run_evaluation`` for every selected session.

    Returns:
        ``0`` when every executed evaluation exited with code 0 (or when
        ``--dry-run`` was given), ``1`` when at least one evaluation failed.

    Raises:
        ValueError: ``--force-prob-threshold`` lies outside ``[0, 1]``.
        FileNotFoundError: no session remains after discovery and filtering.
    """
    args = _build_parser().parse_args()

    threshold = args.force_prob_threshold
    # The chained comparison is also False for NaN, so NaN is rejected here.
    if threshold is not None and not (0.0 <= threshold <= 1.0):
        raise ValueError("--force-prob-threshold must be in [0, 1].")

    sessions_root = _resolve_root(args.sessions_root)
    sessions = _list_sessions(sessions_root)

    only_ids = {s.strip() for s in args.only.split(",") if s.strip()}
    if only_ids:
        sessions = [s for s in sessions if s in only_ids]
    if args.limit > 0:
        sessions = sessions[: args.limit]
    if not sessions:
        # Name the filter when it was used, so an over-restrictive --only is
        # not mistaken for an empty sessions root.
        detail = f" matching --only={args.only!r}" if only_ids else ""
        raise FileNotFoundError(f"No sessions found under {sessions_root}{detail}")

    print(f"[BATCH-EVAL] sessions_root={sessions_root}")
    print(f"[BATCH-EVAL] total_sessions={len(sessions)}")

    # --session is passed relative to <PROJECT_ROOT>/bestweights whenever the
    # sessions root lives inside it; otherwise the bare session id is used.
    bw_dir = PROJECT_ROOT / "bestweights"
    try:
        rel_prefix = sessions_root.relative_to(bw_dir)
    except ValueError:
        rel_prefix = Path("")

    failures: list[str] = []
    for idx, session_id in enumerate(sessions, start=1):
        # e.g., if rel_prefix is '2026-04-09_08-11-48', combine it with '01_seed42'
        if str(rel_prefix) != ".":
            full_session_path = str(rel_prefix / session_id)
        else:
            full_session_path = session_id

        cmd = _build_command(args, full_session_path)

        print(f"\n[BATCH-EVAL] [{idx}/{len(sessions)}] {session_id}")
        # shlex.join quotes arguments containing spaces, so the echoed command
        # is unambiguous and can be pasted into a POSIX shell.
        print(f"[BATCH-EVAL][CMD] {shlex.join(cmd)}")

        if args.dry_run:
            continue

        result = subprocess.run(cmd, cwd=PROJECT_ROOT, check=False)
        if result.returncode != 0:
            failures.append(session_id)
            print(f"[BATCH-EVAL][ERROR] session={session_id} returncode={result.returncode}")
            if not args.continue_on_error:
                break

    if failures:
        print(f"\n[BATCH-EVAL] completed_with_failures={len(failures)} failed_sessions={failures}")
        return 1

    if args.dry_run:
        # Nothing was executed, so do not claim the evaluations succeeded.
        print("\n[BATCH-EVAL] dry run complete; no commands were executed.")
        return 0

    print("\n[BATCH-EVAL] all sessions completed successfully.")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())