# src/walkxr_ai/rag/retrieval_engine.py
import os
import yaml
import logging
from typing import Any, Optional

import chromadb
from llama_index.core import (
    Settings,
    VectorStoreIndex,
    StorageContext,
    SimpleDirectoryReader,
)
from llama_index.core.node_parser import SentenceSplitter, SemanticSplitterNodeParser
from llama_index.embeddings.ollama import OllamaEmbedding
from llama_index.llms.ollama import Ollama
from llama_index.vector_stores.chroma import ChromaVectorStore

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)


class RetrievalEngine:
    """
    A professional, configurable RAG engine for WalkXR-AI.

    This engine handles document ingestion, indexing, and querying with a focus
    on modularity and configurability to support custom agent development.
    It supports multiple chunking strategies and is designed to be used as a
    'tool' by agentic systems.
    """

    def __init__(self, config_path: Optional[str] = None):
        """
        Initializes the engine by loading configuration, setting up models,
        and preparing the vector store.
        """
        if config_path is None:
            # Default path relative to the project root
            config_path = os.path.join(os.path.dirname(__file__), "rag_config.yaml")

        self.config = self._load_config(config_path)
        self._setup_models()
        self._setup_vector_store()
        self._index = None
        self._query_engine = None
        logging.info("RetrievalEngine initialized successfully.")

    def _load_config(self, config_path: str) -> dict:
        """Loads the RAG configuration from a YAML file."""
        logging.info(f"Loading configuration from: {config_path}")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Configuration file not found at {config_path}")
        with open(config_path, "r") as f:
            return yaml.safe_load(f)

    def _setup_models(self):
        """Sets up the embedding and language models based on the config."""
        logging.info(f"Setting up embedding model: {self.config['embed_model']}")
        Settings.embed_model = OllamaEmbedding(model_name=self.config["embed_model"])

        logging.info(f"Setting up LLM: {self.config['llm_model']}")
        Settings.llm = Ollama(model=self.config["llm_model"], request_timeout=120.0)

    def _setup_vector_store(self):
        """Initializes the ChromaDB client and storage context."""
        persist_dir = self.config["storage"]["persist_dir"]
        if not os.path.exists(persist_dir):
            logging.info(f"Creating persistence directory: {persist_dir}")
            os.makedirs(persist_dir)

        db = chromadb.PersistentClient(path=persist_dir)
        chroma_collection = db.get_or_create_collection(
            self.config["storage"]["collection_name"]
        )
        self.vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
        self.storage_context = StorageContext.from_defaults(
            vector_store=self.vector_store
        )
        logging.info(f"Vector store configured at: {persist_dir}")

    def _get_node_parser(self):
        """
        Selects and configures the node parser based on the chunking strategy
        defined in the config file.
        """
        strategy = self.config["chunking"]["strategy"]
        logging.info(f"Using chunking strategy: {strategy}")

        if strategy == "semantic":
            # Uses the embedding model to find logical breakpoints in text.
            # More advanced and context-aware.
            return SemanticSplitterNodeParser(
                buffer_size=self.config["chunking"]["semantic"]["buffer_size"],
                breakpoint_percentile_threshold=self.config["chunking"]["semantic"][
                    "breakpoint_percentile_threshold"
                ],
                embed_model=Settings.embed_model,
            )
        elif strategy == "sentence":
            # Simple, fixed-size chunking. Faster but less context-aware.
            return SentenceSplitter(
                chunk_size=self.config["chunking"]["sentence"]["chunk_size"],
                chunk_overlap=self.config["chunking"]["sentence"]["chunk_overlap"],
            )
        else:
            raise ValueError(f"Unknown chunking strategy: {strategy}")

    def ingest_documents(self):
        """
        Loads documents from the data directory, processes them using the
        configured chunking strategy, and builds the vector index.
        """
        data_dir = self.config["data_dir"]
        if not os.path.exists(data_dir):
            raise FileNotFoundError(f"Data directory not found: {data_dir}")

        logging.info("--- Starting Document Ingestion ---")
        logging.info(f"Loading documents from: {data_dir}")
        reader = SimpleDirectoryReader(data_dir, recursive=True)
        documents = reader.load_data()
        logging.info(f"Loaded {len(documents)} document(s).")

        node_parser = self._get_node_parser()
        nodes = node_parser.get_nodes_from_documents(documents)

        logging.info(
            f"Building index with {len(nodes)} nodes. This may take a moment..."
        )
        self._index = VectorStoreIndex(nodes, storage_context=self.storage_context)
        logging.info("--- Ingestion Complete. Index has been built and persisted. ---")

    def _load_index(self):
        """Loads the persisted index from the vector store if it exists."""
        if self._index is None:
            logging.info("Attempting to load existing index from vector store...")
            try:
                self._index = VectorStoreIndex.from_vector_store(
                    self.vector_store,
                )
                logging.info("Index loaded successfully from storage.")
            except Exception as e:
                # This is not a critical error; it just means we need to ingest.
                logging.warning(
                    f"Could not load index from storage. It may not exist yet. Error: {e}"
                )
                self._index = None  # Ensure index is None if loading fails

    def get_query_engine(self) -> Any:
        """
        Initializes and returns a query engine. This is the primary method for
        agents to interact with the RAG pipeline.

        Returns:
            A LlamaIndex query engine instance.
        """
        if self._query_engine is None:
            self._load_index()  # Attempt to load the index

            if self._index is None:
                # This is a critical error for querying.
                logging.error(
                    "Index is not loaded. Please run the 'ingest' command first."
                )
                raise ValueError("Index is not loaded. Cannot create query engine.")

            logging.info("Creating query engine...")
            self._query_engine = self._index.as_query_engine(
                similarity_top_k=self.config["retrieval"]["similarity_top_k"]
            )
            logging.info("Query engine created and ready.")

        return self._query_engine

    def query(self, query: str) -> Any:
        """
        High-level interface for running a query against the RAG index.

        Args:
            query (str): The input query string.

        Returns:
            Any: The result from the underlying query engine.
        """
        query_engine = self.get_query_engine()
        return query_engine.query(query)
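

# --- Illustrative usage (a minimal sketch, not part of the engine) ---
# Assumes a local Ollama server and a rag_config.yaml defining the keys read
# above. The values below are placeholders, not the project's actual config:
#
#   embed_model: nomic-embed-text        # assumed model name
#   llm_model: llama3                    # assumed model name
#   data_dir: data/documents             # assumed path
#   chunking:
#     strategy: sentence                 # or "semantic"
#     sentence: {chunk_size: 512, chunk_overlap: 64}
#     semantic: {buffer_size: 1, breakpoint_percentile_threshold: 95}
#   storage:
#     persist_dir: storage/chroma        # assumed path
#     collection_name: walkxr_docs       # assumed name
#   retrieval:
#     similarity_top_k: 3
if __name__ == "__main__":
    engine = RetrievalEngine()

    # Build the index once; later runs can query the persisted store directly
    # via get_query_engine()/query() without re-ingesting.
    engine.ingest_documents()

    # Hypothetical query string, for illustration only.
    response = engine.query("What chunking strategies does the engine support?")
    print(response)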