# In src/walkxr_ai/agents/small_moments_roleplay_agent.py


import os
from typing import Dict, Any, List, Optional, Tuple, TypedDict
from walkxr_ai.models.ollama_model import OllamaModel
from walkxr_ai.models.base_model import BaseModel
from walkxr_ai.rag.retrieval_engine import RetrievalEngine
from walkxr_ai.api.models import ChatResponse
from langsmith import traceable
from langsmith.run_helpers import get_current_run_tree


# Placeholder for a future factory to handle different model providers.
# from ..models import LanguageModelFactory
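# An illustrative sketch of what that factory might look like (hypothetical
# names; the real design would live alongside the other model classes):
#
#     class LanguageModelFactory:
#         @staticmethod
#         def create(backend: str) -> BaseModel:
#             if backend == "ollama":
#                 return OllamaModel(
#                     model_name=os.environ.get("OLLAMA_MODEL", "llama3")
#                 )
#             raise ValueError(f"Unknown model backend: {backend}")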


# --- LangGraph State Preparation ---
# By defining a state object, we make the agent compatible with graph-based
# orchestration frameworks like LangGraph. The graph manages the state,
# and this agent acts as a node that modifies it.
class AgentState(TypedDict):
    """Represents the state of our agent at any point in time."""

    user_input: str
    retrieved_context: List[str]
    conversation_history: List[Tuple[str, str]]
    response: str
    # Future fields can include emotional analysis, pacing signals, etc.

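# An illustrative sketch of wiring this state into a graph (assumes the
# `langgraph` package; the node name and wiring are placeholders, not the
# project's actual orchestration code):
#
#     from langgraph.graph import StateGraph, END
#
#     def agent_node(state: AgentState) -> AgentState:
#         agent = SmallMomentsRoleplayAgent()
#         result = agent.get_response(
#             state["user_input"], state["conversation_history"]
#         )
#         return {**state, "response": result["response_text"]}
#
#     graph = StateGraph(AgentState)
#     graph.add_node("agent", agent_node)
#     graph.set_entry_point("agent")
#     graph.add_edge("agent", END)
#     app = graph.compile()
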

class SmallMomentsRoleplayAgent:
    """
    An agent designed to act as an emotional companion for the Small Moments Walk.

    This agent is architected for use within a LangGraph flow. It uses a
    Retrieval-Augmented Generation (RAG) pattern to ground its responses in
    relevant user history and walk-specific content.
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """
        Initializes the agent with a specific configuration.

        Args:
            config (dict, optional): Configuration for persona, model, etc.
        """
        if config is None:
            config = {}

        self.persona = config.get(
            "persona",
            "You are a warm, empathetic companion. Your role is to listen, ask gentle questions, and share relatable anecdotes. You are here to help the user feel more comfortable and connected in small, everyday moments.",
        )

        # -- Modular Components --
        # Initialize the retrieval engine. In a deployed app, its settings
        # (e.g., embedding_model_name="all-MiniLM-L6-v2") would come from
        # the config rather than being hard-coded here.
        self.retrieval_engine = RetrievalEngine()

        # The language model can be swapped (e.g., local Ollama vs. OpenAI API).
        # A future LanguageModelFactory would handle this based on config.
        model_name = os.environ.get("AGENT_MODEL", "ollama")
        if model_name == "ollama":
            self.language_model: BaseModel = OllamaModel(
                model_name=os.environ.get("OLLAMA_MODEL", "llama3")
            )
        else:
            # Placeholder for future backends (e.g., OpenAIModel, MockModel)
            raise ValueError(f"[Init Error] Unknown model backend: {model_name}")
        # Once the factory exists, this whole branch collapses to:
        # self.language_model = LanguageModelFactory.create(model_name)
        print(f"[Info] Language model backend set to: {model_name}")

    def _format_prompt(self, state: AgentState) -> str:
        """
        Constructs the full prompt from the agent's state.
        """
        history_str = "\n".join(
            [f"Human: {h}\nAI: {a}" for h, a in state["conversation_history"]]
        )
        context_str = "\n".join([f"- {item}" for item in state["retrieved_context"]])

        prompt = f"""**System Persona:**
{self.persona}

**Retrieved Context:**
{context_str}

**Conversation History:**
{history_str}

**Current User Input:**
Human: {state['user_input']}

**Your Turn:**
AI:"""
        return prompt

    @traceable(
        name="SmallMomentsAgent.get_response",
        tags=["walkxr", "agent"],
        # Evaluated at import time; reflects the configured model rather
        # than hard-coding "llama3".
        metadata={
            "agent_version": "v1.0",
            "llm": os.environ.get("OLLAMA_MODEL", "llama3"),
        },
    )
    def get_response(self, user_input: str, history: List[Tuple[str, str]]) -> Dict[str, Any]:
        """
        Generates a response using the RAG pattern. This method can be called
        directly or wrapped as a node in a LangGraph.

        Args:
            user_input: The latest input from the user.
            history: The full conversation history.

        Returns:
            A ChatResponse-shaped dict with the generated text, the retrieved
            source chunks, and debug info (prompt and LangSmith trace IDs).
        """
        # The @traceable decorator opens the LangSmith run for this call;
        # fetch the current run tree so its IDs can be surfaced for debugging.
        # (tracing_context() yields nothing, so the previous `as ctx` pattern
        # never populated trace_info.)
        run_tree = get_current_run_tree()
        trace_info: Dict[str, Any] = {}
        if run_tree is not None:
            trace_info = {
                "trace_id": str(run_tree.trace_id),
                "run_id": str(run_tree.id),
            }

        # 1. Retrieve context.
        retrieved_context = self.retrieve_context(user_input)

        # 2. Assemble the state for this turn.
        current_state: AgentState = {
            "user_input": user_input,
            "retrieved_context": retrieved_context,
            "conversation_history": history,
            "response": "",  # Filled after generation.
        }

        # 3. Format the prompt.
        prompt = self._format_prompt(current_state)

        # 4. Generate a response with the configured backend.
        print(f"[Debug] Generating response for prompt:\n---\n{prompt}\n---")
        ai_response = self.language_model.generate_response(
            prompt=prompt,
            history=history,
        )

        return ChatResponse(
            response_text=ai_response,
            source_chunks=retrieved_context,
            debug={
                "prompt": prompt,
                "langsmith_trace": trace_info,
            },
        ).dict()

    @traceable(name="SmallMomentsAgent.retrieve_context", tags=["retrieval"])
    def retrieve_context(self, query: str) -> List[str]:
        """
        Queries the retrieval engine and returns the result as a list of
        context strings. Errors are surfaced in-band so a failed retrieval
        degrades the prompt rather than crashing the turn.
        """
        try:
            result = self.retrieval_engine.query(query)
            return [result.response]
        except Exception as e:
            return [f"[Retrieval error] {e}"]