# aegisai/backend/agents/planner_agent.py
from typing import Dict, List, Any
from google.genai.types import GenerateContentConfig, Content, Part
from agents.base_agent import BaseAgent
from config.settings import settings, PLANNER_AGENT_PROMPT

class PlannerAgent(BaseAgent):
    """Converts detected threats into tactical response plans."""

    async def process(self, incident: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generates a list of actionable response steps.

        Args:
            incident: Threat metadata from the VisionAgent.

        Returns:
            List[Dict]: A sequence of steps, each with an action, priority,
                parameters, and reasoning.
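
        Example (illustrative shapes only):
            An incident like ``{"type": "intrusion", "severity": "high",
            "confidence": 0.9}`` yields steps shaped like
            ``{"step": 1, "action": "save_evidence", "priority": "high",
            "parameters": {}, "reasoning": "..."}``.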
        """
        query = PLANNER_AGENT_PROMPT.format(
            incident_type=incident.get("type", "unknown"),
            severity=incident.get("severity", "low"),
            reasoning=incident.get("reasoning", ""),
            confidence=incident.get("confidence", 0),
        )

        response = self.client.models.generate_content(
            model=self.model_name,
            contents=[Content(role="user", parts=[Part.from_text(text=query)])],
            config=GenerateContentConfig(
                temperature=settings.TEMPERATURE,
                response_mime_type="application/json"
            ),
        )
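        # Note: this is the synchronous google-genai call, so awaiting callers
        # block on it. If BaseAgent's client is google-genai's `Client`, the
        # awaitable variant `self.client.aio.models.generate_content(...)`
        # could be used instead (assumption: BaseAgent is not shown here).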

        plan = self._parse_json_response(response.text)
        if plan and isinstance(plan, list):
            return self._validate_plan(plan)
        
        return self._create_fallback_plan(incident)
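
    # `_parse_json_response` is inherited from BaseAgent and not defined in
    # this file. A minimal sketch of what it plausibly does, assuming ordinary
    # json-based parsing (hypothetical; the real helper may differ):
    #
    #     def _parse_json_response(self, text):
    #         try:
    #             return json.loads(text)
    #         except (json.JSONDecodeError, TypeError):
    #             return None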

    def _validate_plan(self, plan: List[Dict]) -> List[Dict]:
        """Normalizes each step, replacing unapproved actions with a safe default."""
        valid_actions = {
            "save_evidence", "send_alert", "log_incident",
            "lock_door", "sound_alarm", "contact_authorities",
            "monitor", "escalate",
        }
        validated = []
        for i, step in enumerate(plan):
            action = step.get("action")
            # Unknown or missing actions are normalized rather than dropped,
            # so the plan keeps its length and stays within the protocol.
            if action not in valid_actions:
                action = "log_incident"
                
            validated.append({
                "step": step.get("step", i + 1),
                "action": action,
                "priority": step.get("priority", "medium"),
                "parameters": step.get("parameters", {}),
                "reasoning": step.get("reasoning", "Standard procedure")
            })
        return validated
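
    # For example, _validate_plan([{"action": "self_destruct"}]) returns
    # [{"step": 1, "action": "log_incident", "priority": "medium",
    #   "parameters": {}, "reasoning": "Standard procedure"}].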

    def _create_fallback_plan(self, incident: Dict) -> List[Dict]:
        """Provides a safe default response if the LLM fails.

        Args:
            incident: The incident metadata dictionary.

        Returns:
            List[Dict]: Three steps for high/critical severity, two otherwise.
        """
        severity = str(incident.get("severity", "low")).lower()

        # Step 1: always preserve evidence first
        plan = [{"step": 1, "action": "save_evidence", "priority": "high",
                 "reasoning": "Fallback safety", "parameters": {}}]

        if severity in ["high", "critical"]:
            # Step 2: immediate alert
            plan.append({"step": 2, "action": "send_alert", "priority": "immediate",
                         "reasoning": "Emergency escalation", "parameters": {}})
            # Step 3: formal logging, so critical incidents always carry an audit trail
            plan.append({"step": 3, "action": "log_incident", "priority": "high",
                         "reasoning": "Audit trail for critical event", "parameters": {}})
        else:
            plan.append({"step": 2, "action": "log_incident", "priority": "medium",
                         "reasoning": "Routine documentation", "parameters": {}})

        return plan
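
# Illustrative usage (not part of the module): a minimal sketch of driving the
# planner, assuming BaseAgent's constructor takes no arguments and wires up
# `client`, `model_name`, and `_parse_json_response`. The incident keys mirror
# those read in `process`; the values below are hypothetical.
#
#     import asyncio
#
#     agent = PlannerAgent()
#     incident = {"type": "intrusion", "severity": "high",
#                 "reasoning": "Person in restricted zone after hours",
#                 "confidence": 0.92}
#     plan = asyncio.run(agent.process(incident))
#     for step in plan:
#         print(step["step"], step["action"], step["priority"])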