#!/usr/bin/env python3
"""
Scenario loop entry-point for fsl-server and fsl-client nodes.
Automatically runs multiple scenarios from matrix.yaml.
"""
import argparse
import copy
import os
import signal
import subprocess
import sys
import time
from datetime import datetime
from pathlib import Path
from typing import Any
import yaml
# Directory two levels above the one containing this file (parents[0] is the
# file's own directory). Default config, proto and results paths below are
# built relative to it.
PROJECT_ROOT = Path(__file__).resolve().parents[2]
def _deep_merge(base: dict[str, Any], override: dict[str, Any]) -> None:
"""In-place deep merge of override into base."""
for key, value in override.items():
if isinstance(value, dict) and isinstance(base.get(key), dict):
_deep_merge(base[key], value)
else:
base[key] = copy.deepcopy(value)
def _load_grpc_modules() -> tuple[Any, Any, Any]:
    """Import ``grpc`` and load the generated stubs from ``PROJECT_ROOT/proto``.

    The stub files are loaded by path so they are picked up from the proto/
    directory regardless of ``sys.path``.

    Returns:
        A ``(grpc, fsl_pb2, fsl_pb2_grpc)`` tuple of modules.

    Exits the process with status 1 if any of the three cannot be loaded.
    """
    import importlib.util

    def load_module(name: str, path: Path) -> Any:
        spec = importlib.util.spec_from_file_location(name, str(path))
        mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)
        return mod

    try:
        fsl_pb2 = load_module("fsl_pb2", PROJECT_ROOT / "proto" / "fsl_pb2.py")
        fsl_pb2_grpc = load_module("fsl_pb2_grpc", PROJECT_ROOT / "proto" / "fsl_pb2_grpc.py")
        import grpc
    except Exception as e:
        print(f"[SCENARIO LOOP] Critical: Could not find gRPC generated files in proto/: {e}")
        sys.exit(1)
    return grpc, fsl_pb2, fsl_pb2_grpc


def _wait_for_server(
    grpc_modules: tuple[Any, Any, Any],
    server_host: str,
    server_port: Any,
    run_id: str,
) -> None:
    """Block until the server's ``GetInfo`` reports ``scenario_id == run_id``.

    Polls every 3 seconds with a 2 second RPC timeout. There is no overall
    deadline: the call returns only once the server reports the expected run.
    """
    grpc, fsl_pb2, fsl_pb2_grpc = grpc_modules
    target = f"{server_host}:{server_port}"
    while True:
        try:
            # A fresh channel per poll, closed on exit so that repeated
            # polling does not accumulate open channels.
            with grpc.insecure_channel(target) as channel:
                stub = fsl_pb2_grpc.FSLServiceStub(channel)
                response = stub.GetInfo(fsl_pb2.Empty(), timeout=2.0)
            current_server_sid = getattr(response, "scenario_id", "unknown")
            if current_server_sid == run_id:
                print(f"[SCENARIO LOOP] Server is READY for {run_id}. Proceeding.")
                return
            print(f"[SCENARIO LOOP] Server is in mismatching scenario: {current_server_sid} != {run_id}. Waiting 3s...")
        except grpc.RpcError:
            # Server might be offline or starting up
            print(f"[SCENARIO LOOP] Server ({server_host}:{server_port}) is offline or restarting. Waiting 3s...")
        except Exception as exc:
            # Still retry (best effort), but do not disguise a non-RPC error
            # (e.g. a missing stub method or message type) as "offline".
            print(f"[SCENARIO LOOP] Unexpected error while polling Server ({server_host}:{server_port}): {exc!r}. Waiting 3s...")
        time.sleep(3)


def _run_node(node_module: str, env: dict[str, str], log_file_path: Path, run_id: str) -> int:
    """Run one node as a subprocess, teeing its output to console and log file.

    Installs SIGINT/SIGTERM handlers that terminate the child and exit the
    runner with status 0.

    Returns:
        The child's exit code, or -1 if starting or streaming it failed.
    """
    proc = None
    try:
        # We use pipe to capture output so we can save it to file AND show it on console
        proc = subprocess.Popen(
            [sys.executable, "-u", "-m", node_module],
            cwd=str(PROJECT_ROOT),
            env=env,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            bufsize=1,
        )

        def handle_signal(sig, frame):
            print(f"\n[SCENARIO LOOP] Signal received. Terminating {run_id}...")
            proc.terminate()
            sys.exit(0)

        signal.signal(signal.SIGINT, handle_signal)
        signal.signal(signal.SIGTERM, handle_signal)

        with open(log_file_path, "w", encoding="utf-8") as log_fh:
            # Stream line by line, flushing both sinks so the log can be
            # followed live.
            for line in proc.stdout:
                sys.stdout.write(line)
                sys.stdout.flush()
                log_fh.write(line)
                log_fh.flush()
        return proc.wait()
    except Exception as e:
        print(f"[SCENARIO LOOP] Error during subprocess execution: {e}")
        # Do not leave the child running while the loop moves on to the
        # next scenario.
        if proc is not None and proc.poll() is None:
            proc.terminate()
            try:
                proc.wait(timeout=10)
            except subprocess.TimeoutExpired:
                proc.kill()
                proc.wait()
        return -1
    finally:
        if proc is not None and proc.stdout is not None:
            proc.stdout.close()


def main() -> None:
    """Run one scenario, or loop over every (seed, scenario) pair of the matrix.

    Command line:
        --role {server,client}: selects ``src.nodes.server_node`` or
        ``src.nodes.client_node`` as the module to run.

    Environment variables read:
        SCENARIO_ID: if non-empty, the process is replaced by the node module
            (single-scenario mode) and nothing else happens.
        FSL_CONFIG_PATH / FSL_MATRIX_CONFIG_PATH: base config and matrix files
            (default ``config.yaml`` / ``matrix.yaml`` under PROJECT_ROOT).
        SESSION_ID: session name for the log directory (default: timestamp).
        FSL_SERVER_HOST: client only; overrides ``grpc.server_host``.

    In multi-scenario mode, each run gets a merged config written to
    ``results/matrix_configs`` and a log under ``results/logs/<session>``.
    Clients wait for the server to report the same run id before starting.
    Exits with status 1 if any run failed.
    """
    parser = argparse.ArgumentParser(description="FSL Scenario Loop Runner")
    parser.add_argument(
        "--role",
        choices=["server", "client"],
        required=True,
        help="Which node role this container plays.",
    )
    args = parser.parse_args()

    scenario_id = os.environ.get("SCENARIO_ID", "").strip()
    node_module = (
        "src.nodes.server_node" if args.role == "server" else "src.nodes.client_node"
    )

    # ── Single-scenario mode: SCENARIO_ID already provided ──
    if scenario_id:
        # exec replaces the process without flushing Python's buffers, so
        # flush explicitly or the banner is lost when stdout is a pipe.
        print(
            f"[SCENARIO LOOP] Single-scenario mode: role={args.role} SCENARIO_ID={scenario_id}",
            flush=True,
        )
        os.execv(sys.executable, [sys.executable, "-u", "-m", node_module])
        return  # execv does not return on success

    # ── Multi-scenario mode: read matrix ──
    config_path = Path(os.environ.get("FSL_CONFIG_PATH", str(PROJECT_ROOT / "config.yaml")))
    matrix_path = Path(os.environ.get("FSL_MATRIX_CONFIG_PATH", str(PROJECT_ROOT / "matrix.yaml")))
    with config_path.open("r", encoding="utf-8") as fh:
        root_cfg = yaml.safe_load(fh) or {}
    with matrix_path.open("r", encoding="utf-8") as fh:
        matrix_raw = yaml.safe_load(fh) or {}

    # A YAML key with no value parses to None; treat it like a missing key.
    matrix_cfg = matrix_raw.get("experiment_matrix") or {}
    scenarios = matrix_cfg.get("scenarios") or []
    raw_seeds = matrix_cfg.get("seeds")
    if raw_seeds is None:
        raw_seeds = [(root_cfg.get("training") or {}).get("seed", 42)]
    seeds = [int(s) for s in raw_seeds]

    session_id = os.environ.get("SESSION_ID", "").strip() or datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    tmp_dir = PROJECT_ROOT / "results" / "matrix_configs"
    tmp_dir.mkdir(parents=True, exist_ok=True)

    failed: list[str] = []
    run_count = 0
    total_runs = len(scenarios) * len(seeds)
    # Loaded on first use and reused for all later runs.
    grpc_modules = None

    for seed in seeds:
        for scenario in scenarios:
            if run_count > 0:
                print("\n" + "─" * 60)
                print("[SCENARIO LOOP] Waiting 5 seconds for system cleanup...")
                time.sleep(5)

            sid = str(scenario.get("id", "unknown")).strip()
            run_id = f"{sid}_seed{seed}"
            run_count += 1
            print(f"\n[SCENARIO LOOP] ── {run_count}/{total_runs}: Starting {run_id} ──")

            # Prepare merged config
            run_cfg = copy.deepcopy(root_cfg)
            _deep_merge(run_cfg, scenario.get("overrides") or {})
            if run_cfg.get("training") is None:
                run_cfg["training"] = {}
            run_cfg["training"]["seed"] = seed

            tmp_config = tmp_dir / f"{run_id}_cfg.yaml"
            with tmp_config.open("w", encoding="utf-8") as fh:
                yaml.safe_dump(run_cfg, fh, sort_keys=False)

            env = os.environ.copy()
            env["SESSION_ID"] = session_id
            env["SCENARIO_ID"] = run_id
            env["FSL_CONFIG_PATH"] = str(tmp_config)

            # --- Synchronisation Barrier (Client Only) ---
            if args.role == "client":
                grpc_cfg = run_cfg.get("grpc") or {}
                # Prioritize environment variable over config file for distributed runs
                server_host = os.environ.get("FSL_SERVER_HOST") or grpc_cfg.get("server_host", "localhost")
                server_port = grpc_cfg.get("server_port", 50051)
                print(f"[SCENARIO LOOP] Client waiting for Server ({server_host}:{server_port}) to switch to {run_id}...")
                if grpc_modules is None:
                    grpc_modules = _load_grpc_modules()
                _wait_for_server(grpc_modules, server_host, server_port, run_id)

            # --- Start Subprocess with Logging ---
            log_dir = PROJECT_ROOT / "results" / "logs" / session_id
            log_dir.mkdir(parents=True, exist_ok=True)
            log_file_path = log_dir / f"{run_id}.log"
            print(f"[SCENARIO LOOP] Logging to: {log_file_path}")

            rc = _run_node(node_module, env, log_file_path, run_id)
            if rc != 0:
                print(f"[SCENARIO LOOP] {run_id} failed (code {rc}). Continuing...")
                failed.append(run_id)
            else:
                print(f"[SCENARIO LOOP] {run_id} completed successfully.")

    print(f"\n[SCENARIO LOOP] All done: {total_runs - len(failed)}/{total_runs} OK")
    if failed:
        sys.exit(1)
# Run the scenario loop only when executed as a script, not on import.
if __name__ == "__main__":
    main()