# Source path: csc8114/code/src/nodes/run_scenario_loop.py
#!/usr/bin/env python3
"""
Scenario loop entry-point for fsl-server and fsl-client nodes.
Automatically runs multiple scenarios from matrix.yaml.
"""
import argparse
import copy
import os
import signal
import subprocess
import sys
import time
from datetime import datetime
from pathlib import Path
from typing import Any

import yaml

PROJECT_ROOT = Path(__file__).resolve().parents[2]

def _deep_merge(base: dict[str, Any], override: dict[str, Any]) -> None:
    """In-place deep merge of override into base."""
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(base.get(key), dict):
            _deep_merge(base[key], value)
        else:
            base[key] = copy.deepcopy(value)

# How long a child node gets to exit after SIGTERM before it is killed.
_CHILD_STOP_GRACE_SECONDS = 5.0


def _load_module_from_path(name: str, path: Path) -> Any:
    """Execute the Python file at *path* and return it as a module named *name*."""
    import importlib.util

    spec = importlib.util.spec_from_file_location(name, str(path))
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot build an import spec for {path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


def _load_grpc_api() -> tuple[Any, Any, Any]:
    """Return ``(grpc, fsl_pb2, fsl_pb2_grpc)``; exit with status 1 if unavailable.

    The generated modules are loaded from ``PROJECT_ROOT / "proto"`` by file
    path rather than through the regular import system.
    """
    try:
        fsl_pb2 = _load_module_from_path("fsl_pb2", PROJECT_ROOT / "proto" / "fsl_pb2.py")
        fsl_pb2_grpc = _load_module_from_path(
            "fsl_pb2_grpc", PROJECT_ROOT / "proto" / "fsl_pb2_grpc.py"
        )
        import grpc
    except Exception as e:
        print(f"[SCENARIO LOOP] Critical: Could not find gRPC generated files in proto/: {e}")
        sys.exit(1)
    return grpc, fsl_pb2, fsl_pb2_grpc


def _wait_for_server(
    run_id: str,
    server_host: str,
    server_port: Any,
    grpc_api: tuple[Any, Any, Any],
) -> None:
    """Block until the server's ``GetInfo`` reports ``scenario_id == run_id``.

    Polls every 3 seconds with a 2-second RPC timeout. There is no overall
    deadline: the call returns only once the server reports the expected id.
    """
    grpc, fsl_pb2, fsl_pb2_grpc = grpc_api
    target = f"{server_host}:{server_port}"

    while True:
        try:
            # The channel is closed after every probe so repeated polling
            # does not accumulate open channels.
            with grpc.insecure_channel(target) as channel:
                stub = fsl_pb2_grpc.FSLServiceStub(channel)
                response = stub.GetInfo(fsl_pb2.Empty(), timeout=2.0)
            current_server_sid = getattr(response, "scenario_id", "unknown")
        except grpc.RpcError:
            # Server might be offline or starting up
            print(f"[SCENARIO LOOP] Server ({server_host}:{server_port}) is offline or restarting. Waiting 3s...")
        except Exception as exc:
            # Anything else (e.g. a missing message type or RPC in the
            # generated code) is not a connectivity problem; surface it
            # instead of reporting it as "offline".
            print(
                f"[SCENARIO LOOP] Unexpected error while probing server "
                f"({server_host}:{server_port}): {type(exc).__name__}: {exc}. Waiting 3s..."
            )
        else:
            if current_server_sid == run_id:
                print(f"[SCENARIO LOOP] Server is READY for {run_id}. Proceeding.")
                return
            print(f"[SCENARIO LOOP] Server is in mismatching scenario: {current_server_sid} != {run_id}. Waiting 3s...")

        time.sleep(3)


def _stop_process(proc: "subprocess.Popen[str]") -> None:
    """Terminate *proc* if it is still running and reap it.

    Sends SIGTERM, waits up to ``_CHILD_STOP_GRACE_SECONDS`` and then kills.
    Must not be called from a signal handler: it may block in ``wait()``.
    """
    if proc.poll() is not None:
        return
    proc.terminate()
    try:
        proc.wait(timeout=_CHILD_STOP_GRACE_SECONDS)
    except subprocess.TimeoutExpired:
        proc.kill()
        proc.wait()


def _run_node(run_id: str, node_module: str, env: dict[str, str], log_file_path: Path) -> int:
    """Run one node subprocess, teeing its output to the console and a log file.

    Returns the child's exit code, or ``-1`` if launching or streaming raised.
    In every case (including a signal-triggered ``SystemExit``) the child is
    stopped and reaped before this function is left.
    """
    proc = None
    try:
        # We use pipe to capture output so we can save it to file AND show it on console
        proc = subprocess.Popen(
            [sys.executable, "-u", "-m", node_module],
            cwd=str(PROJECT_ROOT),
            env=env,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            errors="replace",  # one undecodable byte must not abort the run
            bufsize=1,
        )

        def handle_signal(sig, frame):
            # Keep the handler non-blocking: only signal the child and raise
            # SystemExit. The ``finally`` below does the waiting/reaping.
            print(f"\n[SCENARIO LOOP] Signal received. Terminating {run_id}...")
            proc.terminate()
            sys.exit(0)

        signal.signal(signal.SIGINT, handle_signal)
        signal.signal(signal.SIGTERM, handle_signal)

        with open(log_file_path, "w", encoding="utf-8") as log_fh:
            # Stream output line by line
            for line in proc.stdout:
                sys.stdout.write(line)
                sys.stdout.flush()
                log_fh.write(line)
                log_fh.flush()

        return proc.wait()
    except Exception as e:
        print(f"[SCENARIO LOOP] Error during subprocess execution: {e}")
        return -1
    finally:
        if proc is not None:
            # No-op after a normal exit; otherwise prevents a stray node from
            # outliving this run and colliding with the next one.
            _stop_process(proc)
            if proc.stdout is not None:
                proc.stdout.close()


def main() -> None:
    """Run one scenario, or every scenario x seed combination from the matrix.

    Command line:
        --role {server,client}  selects ``src.nodes.server_node`` or
                                ``src.nodes.client_node``.

    Environment variables read:
        SCENARIO_ID             if non-empty, exec the node module directly
                                (single-scenario mode) and do nothing else.
        FSL_CONFIG_PATH         base config (default ``PROJECT_ROOT/config.yaml``).
        FSL_MATRIX_CONFIG_PATH  matrix file (default ``PROJECT_ROOT/matrix.yaml``).
        SESSION_ID              session label; defaults to the current local
                                timestamp.
        FSL_SERVER_HOST         client only; overrides ``grpc.server_host``.

    For every run a merged config is written to
    ``results/matrix_configs/<run_id>_cfg.yaml`` and the node's output is
    logged to ``results/logs/<session_id>/<run_id>.log``. Failed runs do not
    stop the loop; the process exits with status 1 if any run failed.
    """
    parser = argparse.ArgumentParser(description="FSL Scenario Loop Runner")
    parser.add_argument(
        "--role",
        choices=["server", "client"],
        required=True,
        help="Which node role this container plays.",
    )
    args = parser.parse_args()

    scenario_id = os.environ.get("SCENARIO_ID", "").strip()
    node_module = (
        "src.nodes.server_node" if args.role == "server" else "src.nodes.client_node"
    )

    # ── Single-scenario mode: SCENARIO_ID already provided ──
    if scenario_id:
        print(f"[SCENARIO LOOP] Single-scenario mode: role={args.role} SCENARIO_ID={scenario_id}")
        os.execv(sys.executable, [sys.executable, "-u", "-m", node_module])
        return

    # ── Multi-scenario mode: read matrix ──
    config_path = Path(os.environ.get("FSL_CONFIG_PATH", str(PROJECT_ROOT / "config.yaml")))
    matrix_path = Path(os.environ.get("FSL_MATRIX_CONFIG_PATH", str(PROJECT_ROOT / "matrix.yaml")))

    with config_path.open("r", encoding="utf-8") as fh:
        root_cfg = yaml.safe_load(fh) or {}
    with matrix_path.open("r", encoding="utf-8") as fh:
        matrix_raw = yaml.safe_load(fh) or {}

    # A YAML key that is present but empty loads as None; treat it like a
    # missing key instead of crashing on None.get(...) / len(None).
    matrix_cfg = matrix_raw.get("experiment_matrix") or {}
    scenarios = matrix_cfg.get("scenarios") or []
    raw_seeds = matrix_cfg.get("seeds")
    if raw_seeds is None:
        raw_seeds = [(root_cfg.get("training") or {}).get("seed", 42)]
    seeds = [int(s) for s in raw_seeds]
    session_id = os.environ.get("SESSION_ID", "").strip() or datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    tmp_dir = PROJECT_ROOT / "results" / "matrix_configs"
    tmp_dir.mkdir(parents=True, exist_ok=True)

    failed: list[str] = []
    run_count = 0
    total_runs = len(scenarios) * len(seeds)
    grpc_api = None  # loaded once, the first time a client needs the barrier

    for seed in seeds:
        for scenario in scenarios:
            if run_count > 0:
                # Visual separator between runs (the original multiplied an
                # empty string, which printed nothing).
                print("\n" + "─" * 60)
                print("[SCENARIO LOOP] Waiting 5 seconds for system cleanup...")
                time.sleep(5)

            sid = str(scenario.get("id", "unknown")).strip()
            run_id = f"{sid}_seed{seed}"
            run_count += 1
            print(f"\n[SCENARIO LOOP] ── {run_count}/{total_runs}: Starting {run_id} ──")

            # Prepare merged config: base config + scenario overrides + seed.
            run_cfg = copy.deepcopy(root_cfg)
            _deep_merge(run_cfg, scenario.get("overrides") or {})
            run_cfg.setdefault("training", {})["seed"] = seed
            tmp_config = tmp_dir / f"{run_id}_cfg.yaml"
            with tmp_config.open("w", encoding="utf-8") as fh:
                yaml.safe_dump(run_cfg, fh, sort_keys=False)

            # The child node picks up its run parameters from the environment.
            env = os.environ.copy()
            env["SESSION_ID"] = session_id
            env["SCENARIO_ID"] = run_id
            env["FSL_CONFIG_PATH"] = str(tmp_config)

            # --- Synchronisation Barrier (Client Only) ---
            if args.role == "client":
                grpc_cfg = run_cfg.get("grpc") or {}
                # Prioritize environment variable over config file for distributed runs
                server_host = os.environ.get("FSL_SERVER_HOST") or grpc_cfg.get("server_host", "localhost")
                server_port = grpc_cfg.get("server_port", 50051)

                print(f"[SCENARIO LOOP] Client waiting for Server ({server_host}:{server_port}) to switch to {run_id}...")

                if grpc_api is None:
                    grpc_api = _load_grpc_api()
                _wait_for_server(run_id, server_host, server_port, grpc_api)

            # --- Start Subprocess with Logging ---
            log_dir = PROJECT_ROOT / "results" / "logs" / session_id
            log_dir.mkdir(parents=True, exist_ok=True)
            log_file_path = log_dir / f"{run_id}.log"

            print(f"[SCENARIO LOOP] Logging to: {log_file_path}")

            rc = _run_node(run_id, node_module, env, log_file_path)

            if rc != 0:
                print(f"[SCENARIO LOOP] {run_id} failed (code {rc}). Continuing...")
                failed.append(run_id)
            else:
                print(f"[SCENARIO LOOP] {run_id} completed successfully.")

    print(f"\n[SCENARIO LOOP] All done: {total_runs - len(failed)}/{total_runs} OK")
    if failed:
        sys.exit(1)

if __name__ == "__main__":
    main()