# retrieval-augmented-generation/src/rag_pipeline/embed_and_build_dense_index.py
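# Files used:
#   conf.files.context     (input: JSON records to embed)
#   conf.files.index       (output: saved FAISS store)
#   conf.files.embeddings  (output: tab-separated embedding rows, appended)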

import logging
import csv
import io

import hydra
from omegaconf import DictConfig, OmegaConf
import faiss
from datasets import load_dataset, IterableDataset
from tqdm.auto import tqdm

from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document

# Module-level logger, named after this module.
log = logging.getLogger(__name__)


def get_hnsw_pq_index(conf: DictConfig) -> faiss.IndexHNSWPQ:
    """Create an empty HNSW+PQ FAISS index from the configuration.

    Args:
        conf: Configuration providing ``embeddings.dim`` and the
            ``indexes.hnsw_pq`` section (``hnsw_m``, ``pq_m``, ``bits``,
            ``ef_construction``, ``ef_search``).

    Returns:
        The index built by ``faiss.index_factory`` with its HNSW
        ``efConstruction`` and ``efSearch`` set from the configuration.
    """
    dim = conf.embeddings.dim
    params = conf.indexes.hnsw_pq

    # Factory description of the form "HNSW<m>,PQ<pq_m>x<bits>".
    description = f"HNSW{params.hnsw_m},PQ{params.pq_m}x{params.bits}"
    hnsw_pq = faiss.index_factory(dim, description)

    hnsw_pq.hnsw.efConstruction = params.ef_construction
    hnsw_pq.hnsw.efSearch = params.ef_search
    return hnsw_pq


def get_hnsw_index(conf: DictConfig) -> faiss.IndexHNSW:
    """Create an empty HNSW FAISS index (no product quantization).

    Args:
        conf: Configuration providing ``embeddings.dim`` and the
            ``indexes.hnsw_pq`` section (``hnsw_m``, ``ef_construction``,
            ``ef_search``).

    Returns:
        The index built by ``faiss.index_factory`` with its HNSW
        ``efConstruction`` and ``efSearch`` set from the configuration.
    """
    dim = conf.embeddings.dim
    # NOTE(review): the HNSW parameters are read from the ``hnsw_pq`` section
    # even though this is the plain HNSW index -- confirm that sharing the
    # section is intended rather than a copy-paste leftover.
    params = conf.indexes.hnsw_pq

    description = f"HNSW{params.hnsw_m}"
    hnsw = faiss.index_factory(dim, description)

    hnsw.hnsw.efConstruction = params.ef_construction
    hnsw.hnsw.efSearch = params.ef_search
    return hnsw


def get_index(conf: DictConfig) -> faiss.Index:
    """Build the FAISS index selected by ``conf.indexing.index_type``.

    Args:
        conf: Configuration; ``indexing.index_type`` must be ``"hnsw_pq"``
            or ``"hnsw"``.

    Returns:
        A freshly constructed FAISS index of the requested kind.

    Raises:
        ValueError: If the configured index type is not recognised.
    """
    index_type = conf.indexing.index_type

    if index_type == "hnsw_pq":
        return get_hnsw_pq_index(conf)
    if index_type == "hnsw":
        return get_hnsw_index(conf)

    raise ValueError(f"Unknown index type: {index_type}")


def get_embedding_model(conf: DictConfig) -> HuggingFaceEmbeddings:
    """Instantiate the HuggingFace embedding model named in the configuration.

    Args:
        conf: Configuration providing ``embeddings.model`` (model name) and
            ``embeddings.device`` (device the model is placed on).

    Returns:
        A ``HuggingFaceEmbeddings`` wrapper around the configured model.
    """
    model_name = conf.embeddings.model
    # ``trust_remote_code`` is always enabled here; only point
    # ``embeddings.model`` at sources you trust.
    loader_options = {
        "trust_remote_code": True,
        "device": conf.embeddings.device,
    }
    return HuggingFaceEmbeddings(model_name=model_name, model_kwargs=loader_options)


def get_faiss_store(conf: DictConfig, embedding_model: HuggingFaceEmbeddings) -> FAISS:
    """Create an empty LangChain FAISS vector store.

    Args:
        conf: Configuration used to build the underlying FAISS index.
        embedding_model: Embedding model registered as the store's
            embedding function.

    Returns:
        A ``FAISS`` store backed by a new index, an empty in-memory
        docstore and an empty index-to-docstore-id mapping.
    """
    store_kwargs = {
        "embedding_function": embedding_model,
        "index": get_index(conf),
        "docstore": InMemoryDocstore(),
        "index_to_docstore_id": {},
    }
    return FAISS(**store_kwargs)


def get_dataset(
    conf,
) -> "tuple[IterableDataset, int]":
    """Open the context file as a streaming, batched dataset.

    The file is read twice: once to count its lines and once (lazily) by
    ``load_dataset`` to stream the records.

    Args:
        conf: Configuration providing ``files.context`` (path to the JSON
            input file) and ``indexing.batch_size``.

    Returns:
        A ``(batched_dataset, num_lines)`` pair. ``batched_dataset`` yields
        batches of ``indexing.batch_size`` records; ``num_lines`` is the raw
        line count of the context file (every line is counted, including
        blank ones), which callers can use to size progress reporting.
    """
    context_path = conf.files.context

    # Decode explicitly as UTF-8 so the count does not depend on the
    # platform's default encoding.
    with open(context_path, "r", encoding="utf-8") as f:
        num_lines = sum(1 for _ in f)

    dataset = load_dataset(
        "json",
        data_files={"train": [context_path]},
        streaming=True,
        split="train",
    )

    batched_dataset = dataset.batch(batch_size=conf.indexing.batch_size)
    return batched_dataset, num_lines


def add_batch(
    embedding_model: HuggingFaceEmbeddings,
    faiss_store: FAISS,
    csv_writer: csv.writer,
    f: io.TextIOWrapper,
    batch: dict,
) -> None:
    """Embed one batch, add it to the FAISS store and log the raw vectors.

    Args:
        embedding_model: Model used to embed the batch texts.
        faiss_store: Vector store receiving the texts, embeddings and
            metadata; the chunk ids are used as document ids.
        csv_writer: Writer that receives one row per embedding, each value
            formatted with 8 decimal places.
        f: File object behind ``csv_writer``; flushed after the batch.
        batch: Column-oriented batch with the keys ``text_content``,
            ``source_name``, ``associated_dates`` and ``chunk_id``.
    """
    texts = batch["text_content"]
    vectors = embedding_model.embed_documents(texts)

    chunk_ids = batch["chunk_id"]
    metadatas = []
    for name, dates, chunk in zip(
        batch["source_name"], batch["associated_dates"], chunk_ids
    ):
        metadatas.append(
            {
                "source_name": name,
                "associated_dates": dates,
                "chunk_id": chunk,
            }
        )

    faiss_store.add_embeddings(
        text_embeddings=zip(texts, vectors),
        metadatas=metadatas,
        ids=chunk_ids,
    )

    formatted_rows = [[f"{value:.8f}" for value in vector] for vector in vectors]
    csv_writer.writerows(formatted_rows)
    # Flush once per batch -- presumably so rows already written survive a
    # failure in a later batch.
    f.flush()


def save_index(conf: DictConfig, faiss_store: FAISS) -> None:
    """Persist the FAISS store to the location given by ``conf.files.index``.

    Args:
        conf: Configuration providing ``files.index``.
        faiss_store: Populated store to write out via ``save_local``.
    """
    index_path = conf.files.index
    faiss_store.save_local(index_path)


@hydra.main(version_base=None, config_path="conf", config_name="config")
def main(conf: DictConfig) -> None:
    """Embed the context dataset, build the FAISS index and save it.

    Every batch is embedded, added to the FAISS store and written as
    tab-separated rows to ``conf.files.embeddings``; the store is saved to
    ``conf.files.index`` after all batches have been processed.

    Args:
        conf: Hydra configuration loaded from ``conf/config``.
    """
    log.info("Config:\n%s", OmegaConf.to_yaml(conf))

    embedding_model = get_embedding_model(conf)
    faiss_store = get_faiss_store(conf, embedding_model)
    dataset, num_lines = get_dataset(conf)

    batch_size = conf.indexing.batch_size
    # Ceiling division: the last batch may be partial but is still yielded,
    # so floor division would leave the progress bar one step short.
    total_batches = (num_lines + batch_size - 1) // batch_size

    # NOTE(review): the embeddings file is opened in append mode, so a rerun
    # adds rows to an existing file while the index is rebuilt from scratch --
    # confirm this is intended.
    with open(conf.files.embeddings, "a", newline="") as f:
        csv_writer = csv.writer(f, delimiter="\t")

        for batch in tqdm(dataset, total=total_batches):
            add_batch(embedding_model, faiss_store, csv_writer, f, batch)

    save_index(conf, faiss_store)


# Script entry point; ``main`` is called without arguments because the
# ``hydra.main`` decorator supplies ``conf``.
if __name__ == "__main__":
    main()