# aegisai/backend/agents/vision_agent.py
import base64
import numpy as np
from collections import deque
from typing import Dict, Any, Optional
from google.genai import types
from google.genai.types import GenerateContentConfig, Content, Part
from agents.base_agent import BaseAgent
from config.settings import settings, VISION_AGENT_PROMPT

class VisionAgent(BaseAgent):
    """Performs real-time security analysis on video frames using Gemini.

    This agent maintains a temporal history of previous detections to improve
    contextual awareness and reduce false positives across a video stream.
    Each history entry is a one-line summary of a validated detection; the
    most recent entries are sent back to the model as "TEMPORAL CONTEXT".
    """

    def __init__(self, **kwargs):
        """Initializes the VisionAgent with a 10-frame sliding window history.

        Args:
            **kwargs: Forwarded unchanged to ``BaseAgent.__init__``.
        """
        super().__init__(**kwargs)
        self.max_history = 10
        # deque(maxlen=...) drops the oldest summary automatically once full.
        self.frame_history = deque(maxlen=self.max_history)

    async def process(
        self,
        frame: Optional[np.ndarray] = None,
        base64_image: Optional[str] = None,
        frame_number: int = 0,
    ) -> Dict[str, Any]:
        """Analyzes visual input for security threats and returns a structured report.

        ``frame`` takes precedence over ``base64_image`` when both are given.
        Failures while preparing the image, calling the model, or parsing its
        response are logged and converted into a non-incident error result
        (see ``_default_result``) instead of being raised.

        Args:
            frame: Optional numpy array (OpenCV format) of the video frame.
            base64_image: Optional base64 encoded string of the image.
            frame_number: Current frame index used for temporal tracking.

        Returns:
            Dict[str, Any]: Structured analysis containing incident status,
                type, severity, confidence, reasoning, subjects, and
                recommended actions.
        """
        try:
            image_bytes = self._prepare_image_bytes(frame, base64_image)
            context = self._build_context()

            user_prompt = "Analyze the input based on the security protocol."
            if context:
                user_prompt += f"\n\nTEMPORAL CONTEXT:\n{context}"

            response = await self.client.aio.models.generate_content(
                model=self.model_name,
                contents=[
                    types.Content(role="user", parts=[
                        types.Part.from_text(text=user_prompt),
                        types.Part.from_bytes(data=image_bytes, mime_type="image/jpeg")
                    ])
                ],
                config=GenerateContentConfig(
                    system_instruction=VISION_AGENT_PROMPT,
                    temperature=settings.TEMPERATURE,
                    response_mime_type="application/json"
                ),
            )
            result = self._parse_json_response(response.text)
        except Exception as e:
            self.logger.error(f"Vision API Error: {str(e)}")
            return self._default_result(f"API Error: {str(e)}")

        # The model may return valid JSON that is not an object (e.g. a list);
        # _validate_result needs a dict, so treat anything else as a parse failure.
        if not isinstance(result, dict) or not result:
            return self._default_result("JSON parsing failed")

        validated = self._validate_result(result)
        self._update_history(frame_number, validated)
        return validated

    def _prepare_image_bytes(self, frame: Optional[np.ndarray], base64_str: Optional[str]) -> bytes:
        """Normalizes image input into bytes, handling OpenCV frames and Base64 strings.

        Args:
            frame: Raw image array; encoded to JPEG with OpenCV when provided.
            base64_str: Base64 string (with or without data URI prefix). May
                contain embedded whitespace/newlines and may lack padding.

        Returns:
            bytes: Encoded image data (JPEG for the ``frame`` path; the decoded
                base64 payload as-is otherwise).

        Raises:
            ValueError: If the frame cannot be JPEG-encoded, if no usable
                image source is provided, or if the base64 payload is invalid.
        """
        if frame is not None:
            import cv2  # local import: OpenCV is only needed for the ndarray path
            success, buffer = cv2.imencode(".jpg", frame)
            if not success:
                raise ValueError("Could not encode frame to JPEG")
            return buffer.tobytes()

        if base64_str:
            # Drop a data-URI prefix such as "data:image/jpeg;base64,".
            # NOTE(review): the decoded bytes are later sent as image/jpeg even
            # if the data URI declared another type (e.g. PNG) — confirm callers.
            if "," in base64_str:
                base64_str = base64_str.split(",")[-1]

            # Remove whitespace/newlines (e.g. line-wrapped base64) BEFORE the
            # padding calculation; b64decode discards them, so counting them
            # would produce the wrong amount of padding.
            base64_str = "".join(base64_str.split())
            if not base64_str:
                raise ValueError("No valid image source provided")

            # Ensure proper padding for the base64 decoder
            missing_padding = len(base64_str) % 4
            if missing_padding:
                base64_str += "=" * (4 - missing_padding)

            return base64.b64decode(base64_str)

        raise ValueError("No valid image source provided")

    def _build_context(self) -> str:
        """Returns the newline-separated history of recent detections ("" if none)."""
        return "\n".join(self.frame_history)

    def _update_history(self, frame_num: int, result: Dict) -> None:
        """Adds the current detection to the sliding window history."""
        summary = f"Frame {frame_num}: {result.get('type', 'normal')} ({result.get('severity', 'low')})"
        self.frame_history.append(summary)

    @staticmethod
    def _coerce_confidence(raw: Any) -> int:
        """Converts a model-supplied confidence to an int clamped to [0, 100].

        Accepts ints, floats, and numeric strings (including "85.5"); anything
        unparseable (None, "high", NaN, infinity) yields 0.
        """
        try:
            value = raw if isinstance(raw, (int, float)) else float(raw)
            return max(0, min(100, int(value)))
        except (TypeError, ValueError, OverflowError):
            return 0

    @staticmethod
    def _coerce_incident(raw: Any) -> bool:
        """Interprets the model's ``incident`` field as a real boolean.

        String values are matched case-insensitively against "true"/"yes"/"1";
        without this, ``bool("false")`` would be True.
        """
        if isinstance(raw, str):
            return raw.strip().lower() in {"true", "yes", "1"}
        return bool(raw)

    def _validate_result(self, result: Dict) -> Dict:
        """Normalizes a parsed model response into the agent's result schema.

        Confidence is coerced to an int in [0, 100]. ``incident`` is coerced to
        a bool and forced to False when confidence is below
        ``settings.CONFIDENCE_THRESHOLD``. Missing or null fields fall back to
        defaults so the returned dict always has the same keys.
        """
        confidence = self._coerce_confidence(result.get("confidence", 0))

        # Incident is only True if confidence reaches the system threshold
        is_incident = self._coerce_incident(result.get("incident", False))
        if confidence < settings.CONFIDENCE_THRESHOLD:
            is_incident = False

        return {
            "incident": is_incident,
            "type": result.get("type") or "unknown",
            "severity": str(result.get("severity") or "low").lower(),
            "confidence": confidence,
            "reasoning": result.get("reasoning") or "No explanation",
            "subjects": result.get("subjects") or [],
            "recommended_actions": result.get("recommended_actions") or []
        }

    def _default_result(self, error_msg: str) -> Dict[str, Any]:
        """Returns a safe, non-incident result in case of processing errors.

        Carries the same keys as ``_validate_result`` so callers can rely on a
        single result shape regardless of success or failure.
        """
        return {
            "incident": False,
            "type": "error",
            "severity": "low",
            "confidence": 0,
            "reasoning": error_msg,
            "subjects": [],
            "recommended_actions": []
        }