import os
import sys

import typer
from rich.console import Console
from rich.markup import escape
from rich.panel import Panel
# Make the project root (the parent of the directory containing this script)
# importable so the absolute ``src.…`` import below resolves regardless of the
# working directory the script is launched from.
# The root is put at the FRONT of sys.path, and only once, so this checkout's
# ``src`` package takes precedence over any other ``src`` package that may
# already be importable (appending would let such a package shadow ours).
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)
from src.walkxr_ai.rag.retrieval_engine import RetrievalEngine  # noqa: E402

# Typer application; each @app.command() function below becomes a subcommand.
app = typer.Typer(
    help="A command-line tool to manage the WalkXR-AI RAG knowledge base."
)
# Shared Rich console used for all user-facing output.
console = Console()
@app.command()
def ingest():
    """
    Loads documents from the data directory, processes them, and builds the vector store index.
    This command should be run whenever the knowledge base documents are updated.
    """
    # NOTE: the docstring above is the `ingest --help` text shown by Typer,
    # so developer notes live in comments instead.
    console.print(
        Panel(
            "[bold cyan]Starting RAG Ingestion Process[/bold cyan]",
            title="WalkXR-AI RAG Management",
            expand=False,
        )
    )
    try:
        # Only the engine calls can fail in interesting ways; keep the try narrow.
        engine = RetrievalEngine()
        engine.ingest_documents()
    except Exception as e:
        # CLI boundary: report any failure and exit with status 1 instead of a
        # traceback. The exception text is escaped because Rich would otherwise
        # parse bracketed fragments (e.g. "[/data/docs]") as markup tags,
        # garbling the message or raising MarkupError inside this handler.
        console.print(
            f"\n[bold red]❌ An error occurred during ingestion: {escape(str(e))}[/bold red]"
        )
        raise typer.Exit(code=1)
    console.print(
        "\n[bold green]✅ Ingestion complete. The knowledge base is ready.[/bold green]"
    )
@app.command()
def query(
    query_text: str = typer.Argument(..., help="The question to ask the RAG system."),
):
    """
    Queries the existing RAG index and prints the retrieved context.
    This is used to test the retrieval quality of your knowledge base.
    """
    # NOTE: the docstring above is the `query --help` text shown by Typer,
    # so developer notes live in comments instead.
    console.print(
        Panel(
            "[bold cyan]Querying RAG System[/bold cyan]",
            title="WalkXR-AI RAG Management",
            expand=False,
        )
    )
    # The question is user input: escape it so any brackets are shown
    # literally rather than interpreted as Rich markup tags.
    console.print(f"[bold]Query:[/bold] {escape(query_text)}")
    try:
        engine = RetrievalEngine()
        query_engine = engine.get_query_engine()
        response = query_engine.query(query_text)
        # Convert inside the try so a failure here still gets the hint below.
        context = str(response)
    except Exception as e:
        # CLI boundary: report the failure (escaped, see above) and exit with
        # status 1. The most likely cause is a missing index, hence the hint.
        console.print(
            f"\n[bold red]❌ An error occurred during query: {escape(str(e))}[/bold red]"
        )
        console.print(
            "[bold yellow]Hint: Did you run the 'ingest' command first?[/bold yellow]"
        )
        raise typer.Exit(code=1)
    console.print(
        "\n[bold green]✅ Query successful. Retrieved context:[/bold green]"
    )
    # Retrieved context is arbitrary document text (it may contain things like
    # Markdown links "[label](url)" or ":name:" sequences); disable markup and
    # emoji-code parsing so it is printed literally.
    console.print(context, markup=False, emoji=False)
if __name__ == "__main__":
    # Script entry point: run the Typer app, which dispatches to the commands
    # registered with @app.command() in this file.
    app()