# retrieval-augmented-generation/src/rag_pipeline/build_sparse_index.py
# Files used:
#   conf.files.context       (input: JSON-lines file read by load_documents)
#   conf.files.sparse_index  (output: directory written by save_index)

import json
import logging
from datetime import datetime
import os

import hydra
from omegaconf import DictConfig, OmegaConf
import numpy as np
import pandas as pd
import numpy as np
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from scipy.sparse import csr_matrix, save_npz

# Module-level logger, named after this module.
log = logging.getLogger(__name__)


def save_index(
    conf: DictConfig, idf: np.ndarray, tfidf_matrix: csr_matrix, tfidf_df: pd.DataFrame
) -> None:
    """Write the sparse index into the ``conf.files.sparse_index`` directory.

    The directory is created if it does not exist. Two files are produced:

    * ``tfidf_matrix.npz`` -- ``tfidf_matrix`` stored compressed via
      ``scipy.sparse.save_npz``.
    * ``metadata.json`` -- a JSON object with ``idf`` (the IDF vector as a
      list), ``chunk_ids`` (the row labels of ``tfidf_df``) and ``features``
      (the column labels of ``tfidf_df``).

    Args:
        conf: Configuration exposing ``files.sparse_index`` (output directory).
        idf: IDF weights; serialized with ``tolist()``.
        tfidf_matrix: Sparse TF-IDF matrix to persist.
        tfidf_df: DataFrame whose index and columns supply the row/column labels.
    """
    index_dir = conf.files.sparse_index
    os.makedirs(index_dir, exist_ok=True)

    matrix_path = os.path.join(index_dir, "tfidf_matrix.npz")
    save_npz(matrix_path, tfidf_matrix, compressed=True)

    # Build the payload before opening the file so a serialization failure
    # cannot leave a truncated metadata file behind.
    payload = dict(
        idf=idf.tolist(),
        chunk_ids=tfidf_df.index.tolist(),
        features=tfidf_df.columns.tolist(),
    )

    metadata_path = os.path.join(index_dir, "metadata.json")
    with open(metadata_path, "w") as handle:
        json.dump(payload, handle)


def load_documents(conf: DictConfig) -> tuple[list[int], list[str]]:
    """Read the context file and turn each chunk into a tab-joined entity string.

    ``conf.files.context`` is read as JSON lines. Each non-blank line must be a
    JSON object with these keys:

    * ``chunk_id`` -- converted with ``int()``.
    * ``extracted_entities`` -- list of objects, each with an ``entity`` string.
    * ``extracted_dates`` -- list of ISO-8601 strings; reduced to ``YYYY-MM-DD``.
    * ``associated_dates`` -- list of strings, used as-is.

    The three lists are concatenated in that order; every item is stripped,
    inner tabs are replaced by spaces (tab is the token separator), and items
    that end up empty are dropped. A chunk with nothing left yields ``""``.

    Args:
        conf: Configuration exposing ``files.context`` (path to the input file).

    Returns:
        ``(document_ids, documents)`` -- parallel lists, one entry per
        non-blank input line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a required key is missing from a record.
        ValueError: If ``chunk_id`` is not an integer or a date is not ISO-8601.
    """
    document_ids: list[int] = []
    documents: list[str] = []
    # Decode explicitly as UTF-8 (the JSON interchange encoding) instead of the
    # platform locale, so non-ASCII entities load identically on every OS.
    with open(conf.files.context, "r", encoding="utf-8") as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not line.strip():
                continue

            data = json.loads(line)
            document_ids.append(int(data["chunk_id"]))

            extracted_entities = [x["entity"] for x in data["extracted_entities"]]
            # Keep only the calendar date so timestamps on the same day match.
            extracted_dates = [
                datetime.fromisoformat(x).strftime("%Y-%m-%d")
                for x in data["extracted_dates"]
            ]
            associated_dates = data["associated_dates"]

            all_entities = extracted_entities + extracted_dates + associated_dates
            # Tabs separate tokens in the joined document, so they may not
            # appear inside a single entity.
            all_entities = [x.strip().replace("\t", " ") for x in all_entities]
            all_entities = [x for x in all_entities if x]

            documents.append("\t".join(all_entities))

    return document_ids, documents


def _split_entities(document: str) -> list[str]:
    """Split a tab-joined entity string into its non-empty entity tokens."""
    # Filtering matters for entity-less documents: "".split("\t") is [""],
    # which would otherwise register the empty string as a vocabulary term.
    return [token for token in document.strip().split("\t") if token]


def build_sparse_index(
    document_ids: list[int], documents: list[str]
) -> tuple[np.ndarray, csr_matrix, pd.DataFrame]:
    """Fit a TF-IDF model in which every tab-separated entity is one token.

    Documents are lowercased by the vectorizer and then split on tabs, so each
    whole entity (not each word) becomes a feature. Empty tokens are ignored,
    so a document without entities gets an all-zero row.

    Args:
        document_ids: Row labels, one per document, in the same order.
        documents: Tab-joined entity strings.

    Returns:
        ``(idf, tfidf_matrix, tfidf_df)`` -- the fitted IDF vector, the sparse
        document-term matrix, and a dense DataFrame view of that matrix indexed
        by ``document_ids`` with the feature names as columns.

    Raises:
        ValueError: Raised by scikit-learn when no document contains any
            token (empty vocabulary).
    """
    vectorizer = TfidfVectorizer(
        lowercase=True,
        tokenizer=_split_entities,
        # A custom tokenizer makes token_pattern unused; None says so
        # explicitly and avoids scikit-learn's unused-parameter warning.
        token_pattern=None,
        use_idf=True,
        smooth_idf=True,
    )
    tfidf_matrix = vectorizer.fit_transform(documents)
    # NOTE(review): toarray() materializes the full documents x features matrix
    # in memory. Kept so the returned DataFrame is unchanged for callers, but
    # it may not scale to large corpora.
    tfidf_df = pd.DataFrame(
        tfidf_matrix.toarray(),
        index=document_ids,
        columns=vectorizer.get_feature_names_out(),
    )

    log.info("TF-IDF matrix shape: %s", tfidf_df.shape)
    return vectorizer.idf_, tfidf_matrix, tfidf_df


@hydra.main(version_base=None, config_path="conf", config_name="sparse_index")
def main(conf: DictConfig) -> None:
    """Run the pipeline: log the config, load documents, build and save the index."""
    config_yaml = OmegaConf.to_yaml(conf)
    log.info("Config:\n%s", config_yaml)

    ids, texts = load_documents(conf)
    index_parts = build_sparse_index(ids, texts)
    # index_parts is (idf, tfidf_matrix, tfidf_df), matching save_index's order.
    save_index(conf, *index_parts)


# Script entry point: run the pipeline only when executed directly, not on import.
if __name__ == "__main__":
    main()