# src/walkxr_ai/rag/retrieval_engine.py

import os
import yaml
import logging
from typing import Any, Optional

import chromadb
from llama_index.core import (
    Settings,
    VectorStoreIndex,
    StorageContext,
    SimpleDirectoryReader,
)
from llama_index.core.node_parser import SentenceSplitter, SemanticSplitterNodeParser
from llama_index.embeddings.ollama import OllamaEmbedding
from llama_index.llms.ollama import Ollama
from llama_index.vector_stores.chroma import ChromaVectorStore

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)


class RetrievalEngine:
    """
    A configurable retrieval-augmented generation (RAG) engine for WalkXR-AI.

    This engine handles document ingestion, indexing, and querying with a focus
    on modularity and configurability to support custom agent development. It
    supports multiple chunking strategies and is designed to be used as a tool
    by agentic systems.
    """

    def __init__(self, config_path: Optional[str] = None):
        """
        Initializes the engine by loading configuration, setting up models,
        and preparing the vector store.
        """
        if config_path is None:
            # Default to the rag_config.yaml that sits alongside this module.
            config_path = os.path.join(os.path.dirname(__file__), "rag_config.yaml")

        self.config = self._load_config(config_path)
        self._setup_models()
        self._setup_vector_store()

        self._index = None
        self._query_engine = None
        logging.info("RetrievalEngine initialized successfully.")

    def _load_config(self, config_path: str) -> dict:
        """Loads the RAG configuration from a YAML file."""
        logging.info(f"Loading configuration from: {config_path}")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Configuration file not found at {config_path}")
        with open(config_path, "r") as f:
            return yaml.safe_load(f)
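
    # Expected rag_config.yaml shape, inferred from the keys this engine
    # reads (a sketch; the values shown are illustrative assumptions, not
    # the project's actual defaults):
    #
    #   data_dir: ./data
    #   embed_model: nomic-embed-text
    #   llm_model: llama3
    #   storage:
    #     persist_dir: ./chroma_db
    #     collection_name: walkxr_docs
    #   chunking:
    #     strategy: semantic          # or "sentence"
    #     semantic:
    #       buffer_size: 1
    #       breakpoint_percentile_threshold: 95
    #     sentence:
    #       chunk_size: 512
    #       chunk_overlap: 50
    #   retrieval:
    #     similarity_top_k: 5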

    def _setup_models(self):
        """Sets up the embedding and language models based on the config."""
        logging.info(f"Setting up embedding model: {self.config['embed_model']}")
        Settings.embed_model = OllamaEmbedding(model_name=self.config["embed_model"])
        logging.info(f"Setting up LLM: {self.config['llm_model']}")
        Settings.llm = Ollama(model=self.config["llm_model"], request_timeout=120.0)

    def _setup_vector_store(self):
        """Initializes the ChromaDB client and storage context."""
        persist_dir = self.config["storage"]["persist_dir"]
        if not os.path.exists(persist_dir):
            logging.info(f"Creating persistence directory: {persist_dir}")
        os.makedirs(persist_dir, exist_ok=True)

        db = chromadb.PersistentClient(path=persist_dir)
        # Keep a handle to the raw collection so _load_index can check
        # whether any data has been persisted before loading.
        self._chroma_collection = db.get_or_create_collection(
            self.config["storage"]["collection_name"]
        )
        self.vector_store = ChromaVectorStore(
            chroma_collection=self._chroma_collection
        )
        self.storage_context = StorageContext.from_defaults(
            vector_store=self.vector_store
        )
        logging.info(f"Vector store configured at: {persist_dir}")

    def _get_node_parser(self):
        """
        Selects and configures the node parser based on the chunking strategy
        defined in the config file.
        """
        strategy = self.config["chunking"]["strategy"]
        logging.info(f"Using chunking strategy: {strategy}")

        if strategy == "semantic":
            # Uses the embedding model to find logical breakpoints in text.
            # More advanced and context-aware.
            return SemanticSplitterNodeParser(
                buffer_size=self.config["chunking"]["semantic"]["buffer_size"],
                breakpoint_percentile_threshold=self.config["chunking"]["semantic"][
                    "breakpoint_percentile_threshold"
                ],
                embed_model=Settings.embed_model,
            )
        elif strategy == "sentence":
            # Simple, fixed-size chunking. Faster but less context-aware.
            return SentenceSplitter(
                chunk_size=self.config["chunking"]["sentence"]["chunk_size"],
                chunk_overlap=self.config["chunking"]["sentence"]["chunk_overlap"],
            )
        else:
            raise ValueError(f"Unknown chunking strategy: {strategy}")

    def ingest_documents(self):
        """
        Loads documents from the data directory, processes them using the
        configured chunking strategy, and builds the vector index.
        """
        data_dir = self.config["data_dir"]
        if not os.path.exists(data_dir):
            raise FileNotFoundError(f"Data directory not found: {data_dir}")

        logging.info("--- Starting Document Ingestion ---")
        logging.info(f"Loading documents from: {data_dir}")
        reader = SimpleDirectoryReader(data_dir, recursive=True)
        documents = reader.load_data()
        logging.info(f"Loaded {len(documents)} document(s).")

        node_parser = self._get_node_parser()
        nodes = node_parser.get_nodes_from_documents(documents)

        logging.info(
            f"Building index with {len(nodes)} nodes. This may take a moment..."
        )
        self._index = VectorStoreIndex(nodes, storage_context=self.storage_context)
        logging.info("--- Ingestion Complete. Index has been built and persisted. ---")

    def _load_index(self):
        """Loads the persisted index from the vector store if it exists."""
        if self._index is not None:
            return

        # An empty collection would still "load" as an index over nothing,
        # so check for persisted data first.
        if self._chroma_collection.count() == 0:
            logging.warning(
                "Vector store is empty; no index to load. Ingest documents first."
            )
            return

        logging.info("Attempting to load existing index from vector store...")
        try:
            self._index = VectorStoreIndex.from_vector_store(self.vector_store)
            logging.info("Index loaded successfully from storage.")
        except Exception as e:
            # Not critical here; the caller decides whether a missing index
            # is an error.
            logging.warning(f"Could not load index from storage: {e}")
            self._index = None  # Ensure the index stays None on failure

    def get_query_engine(self) -> Any:
        """
        Initializes and returns a query engine. This is the primary method
        for agents to interact with the RAG pipeline.

        Returns:
            A LlamaIndex query engine instance.
        """
        if self._query_engine is None:
            self._load_index()  # Attempt to load the index
            if self._index is None:
                # Querying without an index is a hard error.
                logging.error(
                    "Index is not loaded. Run ingest_documents() first."
                )
                raise ValueError("Index is not loaded. Cannot create query engine.")

            logging.info("Creating query engine...")
            self._query_engine = self._index.as_query_engine(
                similarity_top_k=self.config["retrieval"]["similarity_top_k"]
            )
            logging.info("Query engine created and ready.")
        return self._query_engine
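
    # Example: exposing the query engine to an agent as a LlamaIndex tool
    # (a sketch; `engine` is an instance of this class, and the tool name
    # and description are illustrative placeholders):
    #
    #   from llama_index.core.tools import QueryEngineTool
    #
    #   tool = QueryEngineTool.from_defaults(
    #       query_engine=engine.get_query_engine(),
    #       name="walkxr_docs",
    #       description="Answers questions about WalkXR project documents.",
    #   )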

    def query(self, query: str) -> Any:
        """
        High-level interface for running a query against the RAG index.

        Args:
            query (str): The input query string.

        Returns:
            Any: The result from the underlying query engine.
        """
        query_engine = self.get_query_engine()
        return query_engine.query(query)
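

if __name__ == "__main__":
    # Minimal usage sketch. Assumes a valid rag_config.yaml next to this
    # module and a running Ollama server with the configured models pulled;
    # the query string is illustrative.
    engine = RetrievalEngine()
    engine.ingest_documents()  # One-time: build and persist the index.
    response = engine.query("What is WalkXR?")
    print(response)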