from dataclasses import dataclass

from omegaconf import DictConfig
from langchain_core.documents import Document

from .rag_pipeline.rag_validation import (
    get_embedding_model,
    load_faiss_store,
    attach_reranker,
    get_hyde_model,
    get_reader_model,
    get_prompt,
    get_rag_chain,
    retrieve_docs_batched,
    parse_regex,
)
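
# Config keys read directly in this module (a sketch only; the full schema is
# defined by the project's configs and consumed inside rag_validation):
#
#   rag:
#     reranking:
#       enabled: true   # wrap the dense retriever with a reranker
#     hyde:
#       enabled: true   # expand queries with a hypothetical document
#     summary:
#       enabled: false  # summarize retrieved documents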


@dataclass
class Source:
    """A retrieved chunk returned alongside an answer."""

    name: str
    text: str
    index_id: int


class RagQA:
    """Retrieval-augmented QA over a FAISS index, configured via OmegaConf."""

    def __init__(self, conf: DictConfig):
        self.rag_chain = None
        self.hyde_pipeline = None
        self.retriever = None
        self.conf = conf

    def load(self):
        """Build the retriever, optional HyDE/summary model, and reader chain."""
        embedding_model = get_embedding_model(self.conf)
        faiss_store = load_faiss_store(self.conf, embedding_model)
        self.retriever = faiss_store.as_retriever()
        if self.conf.rag.reranking.enabled:
            self.retriever = attach_reranker(self.conf, self.retriever)
        # A single generative model serves both HyDE query expansion and
        # document summarization, so load it if either feature is enabled.
        self.hyde_pipeline = None
        if self.conf.rag.hyde.enabled or self.conf.rag.summary.enabled:
            self.hyde_pipeline = get_hyde_model(self.conf)
        reader_model = get_reader_model(self.conf)
        prompt = get_prompt(self.conf)
        self.rag_chain = get_rag_chain(self.conf, reader_model, prompt)

    @staticmethod
    def _docs_to_sources(docs: list[Document]) -> list[Source]:
        """Map retrieved Documents to Source records via their metadata."""
        return [
            Source(
                name=doc.metadata["source_name"],
                text=doc.metadata["original_page_content"],
                index_id=doc.metadata["chunk_id"],
            )
            for doc in docs
        ]

    def answer(self, question: str) -> tuple[str, list[Source]]:
        """Answer a single question, returning the answer text and its sources."""
        docs = retrieve_docs_batched(
            self.conf,
            self.retriever,
            None,  # Not using the sparse index.
            self.hyde_pipeline,
            self.hyde_pipeline,  # Use the HyDE model for summarization as well.
            [question],
        )
        sources = self._docs_to_sources(docs[0]["docs"])
        chain_output = self.rag_chain.batch(docs)
        batch_answers = [
            parse_regex(row["raw_output"])["answer"] for row in chain_output
        ]
        # Collapse newlines and runs of whitespace in the model output.
        answer = " ".join(batch_answers[0].strip().split())
        return answer, sources
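

# A minimal usage sketch. The config path below is a placeholder, not part of
# this module; point OmegaConf.load at wherever the project's pipeline config
# actually lives.
if __name__ == "__main__":
    from omegaconf import OmegaConf

    conf = OmegaConf.load("configs/rag.yaml")  # placeholder path
    qa = RagQA(conf)
    qa.load()  # builds the retriever and reader chain; required before answer()
    answer, sources = qa.answer("What does the HyDE step do?")
    print(answer)
    for source in sources:
        print(f"[{source.index_id}] {source.name}")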