# WalkXR-AI / scripts / run_agent_simulation.py
# scripts/run_agent_simulation.py
import sys
from pathlib import Path

project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

import json
import os
import time
from datetime import datetime, timezone
from pathlib import Path
from uuid import uuid4

import requests

from data_sources.sheets_loader import load_scenarios_from_sheet
from schemas.emotional_state import BasicEmotionalState
from rewards.reward_function import compute_reward
from agent_utils.emotion_log import log_emotion, extract_latent_emotion

# --- Config ---
# Google Sheet that holds the scenarios, and the worksheet tab read from it.
SHEET_URL = (
    "https://docs.google.com/spreadsheets/d/"
    "1KJ2gf_FrF8XXWgwan9yWE0T7NGwcBLT7HOdHKnXFcUY/edit?usp=sharing"
)
TAB_NAME = "All Output"
# Number of user/agent exchanges simulated per run.
NUM_TURNS = 5
# Directory receiving the per-run log files; created up front when missing.
LOG_DIR = "logs"
Path(LOG_DIR).mkdir(parents=True, exist_ok=True)

print("Loading scenarios from sheet...")
# --- Load Scenario ---
# Pull every scenario from the configured sheet tab. Only the first one is
# simulated; an empty (or otherwise falsy) result aborts the run.
scenarios = load_scenarios_from_sheet(SHEET_URL, tab_name=TAB_NAME)
if scenarios:
    scenario = scenarios[0]
else:
    raise ValueError("No scenarios loaded from sheet.")

print("Loaded scenario: {}_{}".format(scenario.persona_id, scenario.module_id))

# Project imports used for agent selection.
# NOTE(review): these run only after the scenario has been loaded; whether that
# ordering is required (e.g. import side effects) is not visible here -- confirm
# before hoisting them to the top of the file.
# NOTE(review): `Agent` is not referenced in the visible part of this script.
from agent_utils.matching import find_best_matching_agent
from data_sources.agent_loader import load_agents
from schemas.agent_schema import Agent
from schemas.goal_schema import Goal

# Stand-in goal handed to the matcher in place of a scenario-specific goal.
dummy_goal = Goal(
    name="Default Goal",
    description="Placeholder goal for simulation case lacking context.",
)

# Load the available agents and pick one for the placeholder goal.
# NOTE(review): `agent` is never read again in the visible part of this script;
# the chat requests below do not reference it -- confirm whether it is needed.
agents = load_agents()
agent = find_best_matching_agent(dummy_goal, agents)

# --- Initialize Agent ---
# Starting emotional state, plus the running conversation history that is
# sent along with every chat request.
state = BasicEmotionalState(emotion="neutral", intensity=0.0)
history = list()

# --- Prepare Logging ---
# Scenario identifier reused in console output, log entries and the log name.
scenario_uid = "{}_{}".format(scenario.persona_id, scenario.module_id)

print("Starting simulation with {} turns...".format(NUM_TURNS))
print("Scenario ID: {}".format(scenario_uid))
print("Prompt: {}...".format(scenario.prompt[:100]))
print("Initial emotion: {}".format(scenario.emotion_before))
print("Desired AI tone: {}".format(scenario.tone))

# A short random suffix keeps repeated runs of one scenario in separate files.
run_suffix = uuid4().hex[:6]
log_path = Path(LOG_DIR, f"sim_{scenario_uid}_{run_suffix}.jsonl")

# Scripted user messages for every turn after the first. Once the script is
# exhausted the last entry is reused, so NUM_TURNS may exceed five.
_FOLLOW_UP_INPUTS = (
    "Can you help me understand what's happening here? I feel anxious about this social interaction.",
    "What should I do in this kind of situation? Any advice?",
    "How can I manage these feelings better next time?",
    "Thank you for the help. Any final thoughts?",
)

# Main simulation loop: for each turn, send a scripted user message to the
# local chat endpoint, score the reply, and append one JSON line to the log.
# Failures in the API call, emotion extraction, reward computation and emotion
# logging are reported on the console and recorded, but never abort the run.
with open(log_path, "w", encoding="utf-8") as f:
    for turn in range(NUM_TURNS):
        # Generate contextual user input based on SimulationCase
        if turn == 0:
            # Use the scenario prompt as the initial user input
            user_input = f"I'm experiencing this situation: {scenario.prompt}. I'm feeling {scenario.emotion_before}."
        else:
            # Turns 1..3 map to entries 0..2; every later turn gets the closing line.
            user_input = _FOLLOW_UP_INPUTS[min(turn, len(_FOLLOW_UP_INPUTS)) - 1]

        print(f"\n[Turn {turn+1}] User: {user_input[:100]}...")

        # Agent response with error handling. A monotonic clock is used so the
        # measured latency cannot be skewed by system clock adjustments.
        start_time = time.perf_counter()
        try:
            res = requests.post(
                "http://localhost:8000/chat",
                json={
                    "user_id": "sim_user",
                    "stage": "demo",
                    "message": user_input,
                    "history": history,
                },
                timeout=30,
            )
            res.raise_for_status()  # Raise exception for bad status codes
            latency = time.perf_counter() - start_time
            response_data = res.json()
            agent_response = response_data["response_text"]
            api_success = True
        except Exception as e:
            # Deliberately broad: any failure (network, HTTP status, malformed
            # payload) is recorded as a failed turn instead of ending the run.
            latency = time.perf_counter() - start_time
            agent_response = f"API_ERROR: {e}"
            api_success = False
            print(f"API Error: {e}")

        print(f"[Turn {turn+1}] Agent: {agent_response[:100]}...")
        print(f"[Turn {turn+1}] Latency: {latency:.2f}s")

        # NOTE(review): failed turns are appended too, so the "API_ERROR: ..."
        # text is sent back to the service as prior agent output on later
        # turns -- confirm this is intended.
        history.append((user_input, agent_response))

        # Extract emotion with error handling
        try:
            latent_emotion = extract_latent_emotion(agent_response)
        except Exception as e:
            latent_emotion = {"emotion": "unknown", "confidence": 0.0}
            print(f"Emotion extraction error: {e}")

        # Placeholder QA score. It is keyed on whether the API call succeeded
        # rather than on the word "error" appearing in the reply, which would
        # also penalise legitimate answers that merely mention errors.
        qa_score = 4.0 if api_success else 1.0

        # Compute reward with error handling
        # NOTE(review): `state` is never updated inside this loop, so the reward
        # always sees the initial emotion -- confirm whether `latent_emotion`
        # was meant to feed back into it.
        try:
            reward = compute_reward(
                qa_score=qa_score,
                agent_emotion=state.emotion,
                mood_arc_target="neutral",
                latency=latency,
            )
            reward_dict = reward.__dict__ if hasattr(reward, "__dict__") else {"total": reward}
        except Exception as e:
            reward_dict = {"total": 0.0, "error": str(e)}
            print(f"Reward computation error: {e}")

        # --- Create full log entry ---
        log_entry = {
            "turn": turn,
            # Timezone-aware UTC timestamp; datetime.utcnow() is deprecated and
            # returns a naive value that is easy to misread as local time.
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "scenario_id": scenario_uid,
            "persona_id": scenario.persona_id,
            "module_id": scenario.module_id,
            "scenario_prompt": scenario.prompt,
            "initial_emotion": scenario.emotion_before,
            "desired_ai_behavior": scenario.ai_behavior,
            "desired_tone": scenario.tone,
            "user_input": user_input,
            "agent_response": agent_response,
            "emotional_state": state.model_dump(),
            "latent_emotion": latent_emotion,
            "latency": latency,
            "qa_score": qa_score,
            "reward": reward_dict,
            "api_success": api_success,
        }

        try:
            log_emotion({
                "agent_response": agent_response,
                "latent_emotion": latent_emotion,
                "turn": turn,
                "timestamp": log_entry["timestamp"],
            })
        except Exception as e:
            print(f"Emotion logging error: {e}")

        # default=str is deliberate: a reward or emotion object that is not
        # JSON-serialisable is stringified instead of aborting the whole run.
        f.write(json.dumps(log_entry, default=str) + "\n")
        f.flush()  # Ensure data is written immediately
        print(f"[Turn {turn+1}] Logged.")

# Final console summary: where the log was written, how many exchanges were
# recorded, and a preview of the scenario and its expected AI behavior.
# The emoji prefixes were mis-encoded (UTF-8 bytes decoded with a legacy
# code page) and are restored to the intended characters here.
print(f"\n✅ Simulation complete! Log saved to: {log_path}")
print(f"📊 Final history has {len(history)} turns")
print(f"🎯 Scenario: {scenario.prompt[:100]}...")
print(f"🎭 Expected AI behavior: {scenario.ai_behavior[:100]}...")