# In src/walkxr_ai/agents/small_moments_roleplay_agent.py
import os
from typing import Dict, Any, List, Optional, Tuple, TypedDict
from walkxr_ai.models.ollama_model import OllamaModel
from walkxr_ai.models.base_model import BaseModel
from walkxr_ai.rag.retrieval_engine import RetrievalEngine
from walkxr_ai.api.models import ChatResponse
from langsmith import traceable
from langsmith.run_helpers import tracing_context
# Placeholder for a future factory to handle different model providers.
# from ..models import LanguageModelFactory
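
# A minimal sketch of what that factory might look like (illustrative only;
# in the real codebase it would live in walkxr_ai.models and its interface may
# differ). It resolves a backend name to a concrete BaseModel so callers never
# branch on providers themselves.
class _IllustrativeLanguageModelFactory:
    """Illustrative sketch only; not wired into the agent below."""

    @staticmethod
    def create(backend: str, **kwargs: Any) -> BaseModel:
        if backend == "ollama":
            return OllamaModel(model_name=kwargs.get("model_name", "llama3"))
        # Further providers (e.g. an OpenAI-backed model) would be added here.
        raise ValueError(f"Unknown model backend: {backend}")
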
# --- LangGraph State Preparation ---
# By defining a state object, we make the agent compatible with graph-based
# orchestration frameworks like LangGraph. The graph manages the state,
# and this agent acts as a node that modifies it.
class AgentState(TypedDict):
"""Represents the state of our agent at any point in time."""
user_input: str
retrieved_context: List[str]
conversation_history: List[Tuple[str, str]]
response: str
# Future fields can include emotional analysis, pacing signals, etc.
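
# Example of a populated AgentState for a single turn (values are illustrative,
# not fixtures from the codebase):
# {
#     "user_input": "I finally said hi to my neighbor today.",
#     "retrieved_context": ["User wants to feel more at ease in everyday encounters."],
#     "conversation_history": [("Hi there", "Hello! How has your walk felt so far?")],
#     "response": "",
# }
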
class SmallMomentsRoleplayAgent:
"""
An agent designed to act as an emotional companion for the Small Moments Walk.
This agent is architected for use within a LangGraph flow. It uses a
Retrieval-Augmented Generation (RAG) pattern to ground its responses in
relevant user history and walk-specific content.
"""
    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """
        Initializes the agent with a specific configuration.

        Args:
            config (dict, optional): Configuration for persona, model, etc.
        """
        if config is None:
            config = {}
        self.persona = config.get(
            "persona",
            "You are a warm, empathetic companion. Your role is to listen, ask gentle questions, and share relatable anecdotes. You are here to help the user feel more comfortable and connected in small, everyday moments.",
        )
        # -- Modular Components --
        # Initialize the retrieval engine. In production, its settings (e.g. the
        # index path or embedding_model_name="all-MiniLM-L6-v2") would come from config.
        self.retrieval_engine = RetrievalEngine()
        # The language model can be swapped (e.g., local Ollama vs. OpenAI API).
        # A future LanguageModelFactory would handle this based on config.
        backend = os.environ.get("AGENT_MODEL", "ollama")
        if backend == "ollama":
            self.language_model: BaseModel = OllamaModel(
                model_name=os.environ.get("OLLAMA_MODEL", "llama3")
            )
        else:
            # Placeholder for future backends (e.g., OpenAIModel, MockModel).
            raise ValueError(f"[Init Error] Unknown model backend: {backend}")
        # Once the factory exists, the branch above collapses to:
        # self.language_model = LanguageModelFactory.create(backend)
        print(f"[Info] Language model backend set to: {backend}")
    def _format_prompt(self, state: AgentState) -> str:
        """
        Constructs the full prompt from the agent's state.
        """
        history_str = "\n".join(
            [f"Human: {h}\nAI: {a}" for h, a in state["conversation_history"]]
        )
        context_str = "\n".join([f"- {item}" for item in state["retrieved_context"]])
        prompt = f"""**System Persona:**
{self.persona}

**Retrieved Context:**
{context_str}

**Conversation History:**
{history_str}

**Current User Input:**
Human: {state['user_input']}

**Your Turn:**
AI:"""
        return prompt
    @traceable(
        name="SmallMomentsAgent.get_response",
        tags=["walkxr", "agent"],
        metadata={"agent_version": "v1.0", "llm": "llama3"},
    )
    def get_response(self, user_input: str, history: List[Tuple[str, str]]) -> Dict[str, Any]:
        """
        Generates a response using the RAG pattern. This method can be called
        directly or wrapped as a node in a LangGraph.

        Args:
            user_input: The latest input from the user.
            history: The full conversation history.

        Returns:
            A ChatResponse serialized to a dict, containing the generated
            response text, the retrieved source chunks, and debug info.
        """
        with tracing_context() as ctx:
            # The context may be None when no trace is active, so fall back
            # to an empty trace payload.
            trace_info = {}
            if ctx is not None:
                trace_info = {
                    "trace_id": ctx.trace_id,
                    "run_id": ctx.run_id,
                }
            # 1. Retrieve context.
            retrieved_context = self.retrieve_context(user_input)

            # 2. Assemble the state for this turn.
            current_state: AgentState = {
                "user_input": user_input,
                "retrieved_context": retrieved_context,
                "conversation_history": history,
                "response": "",  # To be filled.
            }

            # 3. Format the prompt.
            prompt = self._format_prompt(current_state)

            # 4. Generate a response with the configured language model.
            print(f"[Debug] Generating response for prompt:\n---\n{prompt}\n---")
            ai_response = self.language_model.generate_response(
                prompt=prompt,
                history=history,
            )

            return ChatResponse(
                response_text=ai_response,
                source_chunks=retrieved_context,
                debug={
                    "prompt": prompt,
                    "langsmith_trace": trace_info,
                },
            ).dict()
@traceable(name="SmallMomentsAgent.retrieve_context", tags=["retrieval"])
def retrieve_context(self, query: str) -> List[str]:
try:
result = self.retrieval_engine.query(query)
return [result.response]
except Exception as e:
return [f"[Retrieval error] {str(e)}"]