# src/walkxr_ai/rag/retrieval_engine.py
import os
import yaml
import logging
from typing import Any, Optional

import chromadb
from llama_index.core import (
    Settings,
    VectorStoreIndex,
    StorageContext,
    SimpleDirectoryReader,
)
from llama_index.core.node_parser import SentenceSplitter, SemanticSplitterNodeParser
from llama_index.embeddings.ollama import OllamaEmbedding
from llama_index.llms.ollama import Ollama
from llama_index.vector_stores.chroma import ChromaVectorStore

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
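
# Expected shape of rag_config.yaml. The keys below mirror exactly what this
# engine reads; the values are illustrative placeholders, not requirements:
#
#   data_dir: "data/"
#   embed_model: "nomic-embed-text"
#   llm_model: "llama3"
#   storage:
#     persist_dir: "storage/chroma"
#     collection_name: "walkxr_docs"
#   chunking:
#     strategy: "semantic"          # or "sentence"
#     semantic:
#       buffer_size: 1
#       breakpoint_percentile_threshold: 95
#     sentence:
#       chunk_size: 512
#       chunk_overlap: 50
#   retrieval:
#     similarity_top_k: 3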


class RetrievalEngine:
    """
    A professional, configurable RAG engine for WalkXR-AI.

    This engine handles document ingestion, indexing, and querying with a focus
    on modularity and configurability to support custom agent development. It
    supports multiple chunking strategies and is designed to be used as a
    'tool' by agentic systems.
    """

    def __init__(self, config_path: Optional[str] = None):
        """
        Initializes the engine by loading configuration, setting up models,
        and preparing the vector store.
        """
        if config_path is None:
            # Default to the rag_config.yaml that lives alongside this module.
            config_path = os.path.join(os.path.dirname(__file__), "rag_config.yaml")

        self.config = self._load_config(config_path)
        self._setup_models()
        self._setup_vector_store()
        self._index = None
        self._query_engine = None
        logging.info("RetrievalEngine initialized successfully.")

    def _load_config(self, config_path: str) -> dict:
        """Loads the RAG configuration from a YAML file."""
        logging.info(f"Loading configuration from: {config_path}")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Configuration file not found at {config_path}")
        with open(config_path, "r") as f:
            return yaml.safe_load(f)

    def _setup_models(self):
        """Sets up the embedding and language models based on the config."""
        logging.info(f"Setting up embedding model: {self.config['embed_model']}")
        Settings.embed_model = OllamaEmbedding(model_name=self.config["embed_model"])

        logging.info(f"Setting up LLM: {self.config['llm_model']}")
        Settings.llm = Ollama(model=self.config["llm_model"], request_timeout=120.0)

    def _setup_vector_store(self):
        """Initializes the ChromaDB client and storage context."""
        persist_dir = self.config["storage"]["persist_dir"]
        if not os.path.exists(persist_dir):
            logging.info(f"Creating persistence directory: {persist_dir}")
            os.makedirs(persist_dir)

        db = chromadb.PersistentClient(path=persist_dir)
        chroma_collection = db.get_or_create_collection(
            self.config["storage"]["collection_name"]
        )
        self.vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
        self.storage_context = StorageContext.from_defaults(
            vector_store=self.vector_store
        )
        logging.info(f"Vector store configured at: {persist_dir}")

    def _get_node_parser(self):
        """
        Selects and configures the node parser based on the chunking strategy
        defined in the config file.
        """
        strategy = self.config["chunking"]["strategy"]
        logging.info(f"Using chunking strategy: {strategy}")

        if strategy == "semantic":
            # Uses the embedding model to find logical breakpoints in text.
            # More advanced and context-aware.
            return SemanticSplitterNodeParser(
                buffer_size=self.config["chunking"]["semantic"]["buffer_size"],
                breakpoint_percentile_threshold=self.config["chunking"]["semantic"][
                    "breakpoint_percentile_threshold"
                ],
                embed_model=Settings.embed_model,
            )
        elif strategy == "sentence":
            # Simple, fixed-size chunking. Faster but less context-aware.
            return SentenceSplitter(
                chunk_size=self.config["chunking"]["sentence"]["chunk_size"],
                chunk_overlap=self.config["chunking"]["sentence"]["chunk_overlap"],
            )
        else:
            raise ValueError(f"Unknown chunking strategy: {strategy}")

    def ingest_documents(self):
        """
        Loads documents from the data directory, processes them using the
        configured chunking strategy, and builds the vector index.
        """
        data_dir = self.config["data_dir"]
        if not os.path.exists(data_dir):
            raise FileNotFoundError(f"Data directory not found: {data_dir}")

        logging.info("--- Starting Document Ingestion ---")
        logging.info(f"Loading documents from: {data_dir}")
        reader = SimpleDirectoryReader(data_dir, recursive=True)
        documents = reader.load_data()
        logging.info(f"Loaded {len(documents)} document(s).")

        node_parser = self._get_node_parser()
        nodes = node_parser.get_nodes_from_documents(documents)

        logging.info(
            f"Building index with {len(nodes)} nodes. This may take a moment..."
        )
        self._index = VectorStoreIndex(nodes, storage_context=self.storage_context)
        logging.info("--- Ingestion Complete. Index has been built and persisted. ---")

    def _load_index(self):
        """Loads the persisted index from the vector store if it exists."""
        if self._index is None:
            logging.info("Attempting to load existing index from vector store...")
            try:
                self._index = VectorStoreIndex.from_vector_store(
                    self.vector_store,
                )
                logging.info("Index loaded successfully from storage.")
            except Exception as e:
                # Not a critical error; it just means documents still need to
                # be ingested.
                logging.warning(
                    f"Could not load index from storage. It may not exist yet. Error: {e}"
                )
                self._index = None  # Ensure index is None if loading fails

    def get_query_engine(self) -> Any:
        """
        Initializes and returns a query engine. This is the primary method
        for agents to interact with the RAG pipeline.

        Returns:
            A LlamaIndex query engine instance.
        """
        if self._query_engine is None:
            self._load_index()  # Attempt to load the index
            if self._index is None:
                # This is a critical error for querying.
                logging.error(
                    "Index is not loaded. Please run the 'ingest' command first."
                )
                raise ValueError("Index is not loaded. Cannot create query engine.")

            logging.info("Creating query engine...")
            self._query_engine = self._index.as_query_engine(
                similarity_top_k=self.config["retrieval"]["similarity_top_k"]
            )
            logging.info("Query engine created and ready.")

        return self._query_engine
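
    # The class docstring notes this engine is meant to serve as a 'tool' for
    # agentic systems. One way to expose it (an illustrative sketch only; the
    # tool name and description below are placeholders) is via LlamaIndex's
    # QueryEngineTool:
    #
    #   from llama_index.core.tools import QueryEngineTool
    #
    #   rag_tool = QueryEngineTool.from_defaults(
    #       query_engine=RetrievalEngine().get_query_engine(),
    #       name="walkxr_knowledge_base",
    #       description="Answers questions about documents ingested by WalkXR-AI.",
    #   )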

    def query(self, query: str) -> Any:
        """
        High-level interface for running a query against the RAG index.

        Args:
            query (str): The input query string.

        Returns:
            Any: The result from the underlying query engine.
        """
        query_engine = self.get_query_engine()
        return query_engine.query(query)
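

# Minimal usage sketch (illustrative, not part of the engine's API surface).
# It assumes a valid rag_config.yaml alongside this module, a populated data
# directory, and an Ollama server running locally for the configured models.
if __name__ == "__main__":
    engine = RetrievalEngine()
    engine.ingest_documents()  # Only needed on first run or after adding documents.
    response = engine.query("Ask a question about the ingested documents here.")
    print(response)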