# Files used:
#   conf.files.context
#   conf.files.index
#   conf.files.embeddings
import csv
import io
import logging
import math

import faiss
import hydra
from datasets import IterableDataset, load_dataset
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from omegaconf import DictConfig, OmegaConf
from tqdm.auto import tqdm

log = logging.getLogger(__name__)


def get_hnsw_pq_index(conf: DictConfig) -> faiss.IndexHNSWPQ:
    """Build an HNSW index over product-quantized vectors."""
    # NOTE: PQ-based indexes must be trained on a representative sample
    # (index.train(...)) before vectors can be added.
    index = faiss.index_factory(
        conf.embeddings.dim,
        f"HNSW{conf.indexes.hnsw_pq.hnsw_m},PQ{conf.indexes.hnsw_pq.pq_m}x{conf.indexes.hnsw_pq.bits}",
    )
    index.hnsw.efConstruction = conf.indexes.hnsw_pq.ef_construction
    index.hnsw.efSearch = conf.indexes.hnsw_pq.ef_search
    return index


def get_hnsw_index(conf: DictConfig) -> faiss.IndexHNSW:
    """Build a flat (non-quantized) HNSW index."""
    # NOTE: reuses the HNSW parameters from the hnsw_pq config section; if the
    # config grows a separate indexes.hnsw block, read from that instead.
    index = faiss.index_factory(
        conf.embeddings.dim,
        f"HNSW{conf.indexes.hnsw_pq.hnsw_m}",
    )
    index.hnsw.efConstruction = conf.indexes.hnsw_pq.ef_construction
    index.hnsw.efSearch = conf.indexes.hnsw_pq.ef_search
    return index


def get_index(conf: DictConfig) -> faiss.Index:
    if conf.indexing.index_type == "hnsw_pq":
        return get_hnsw_pq_index(conf)
    elif conf.indexing.index_type == "hnsw":
        return get_hnsw_index(conf)
    else:
        raise ValueError(f"Unknown index type: {conf.indexing.index_type}")


def get_embedding_model(conf: DictConfig) -> HuggingFaceEmbeddings:
    return HuggingFaceEmbeddings(
        model_name=conf.embeddings.model,
        model_kwargs={"trust_remote_code": True, "device": conf.embeddings.device},
    )


def get_faiss_store(conf: DictConfig, embedding_model: HuggingFaceEmbeddings) -> FAISS:
    index = get_index(conf)
    return FAISS(
        embedding_function=embedding_model,
        index=index,
        docstore=InMemoryDocstore(),
        index_to_docstore_id={},
    )


def get_dataset(conf: DictConfig) -> tuple[IterableDataset, int]:
    # Count lines up front so tqdm can show progress over the streamed dataset.
    with open(conf.files.context, "r") as f:
        num_lines = sum(1 for _ in f)
    dataset = load_dataset(
        "json",
        data_files={"train": [conf.files.context]},
        streaming=True,
        split="train",
    )
    batched_dataset = dataset.batch(batch_size=conf.indexing.batch_size)
    return batched_dataset, num_lines


def add_batch(
    embedding_model: HuggingFaceEmbeddings,
    faiss_store: FAISS,
    csv_writer,  # csv.writer object; the stdlib exposes no public type for it
    f: io.TextIOWrapper,
    batch: dict,
) -> None:
    embeddings = embedding_model.embed_documents(batch["text_content"])
    text_and_embeddings = zip(batch["text_content"], embeddings)
    metadatas = [
        {
            "source_name": source_name,
            "associated_dates": associated_date,
            "chunk_id": chunk_id,
        }
        for source_name, associated_date, chunk_id in zip(
            batch["source_name"], batch["associated_dates"], batch["chunk_id"]
        )
    ]
    faiss_store.add_embeddings(
        text_embeddings=text_and_embeddings, metadatas=metadatas, ids=batch["chunk_id"]
    )
    # Mirror the raw embeddings to a TSV file, flushing each batch so
    # progress survives an interrupted run.
    csv_writer.writerows([[f"{x:.8f}" for x in row] for row in embeddings])
    f.flush()


def save_index(conf: DictConfig, faiss_store: FAISS) -> None:
    faiss_store.save_local(conf.files.index)


@hydra.main(version_base=None, config_path="conf", config_name="config")
def main(conf: DictConfig) -> None:
    log.info("Config:\n%s", OmegaConf.to_yaml(conf))
    embedding_model = get_embedding_model(conf)
    faiss_store = get_faiss_store(conf, embedding_model)
    dataset, num_lines = get_dataset(conf)
    with open(conf.files.embeddings, "a", newline="") as f:
        csv_writer = csv.writer(f, delimiter="\t")
        # Ceiling division so the progress bar accounts for a partial final batch.
        for batch in tqdm(dataset, total=math.ceil(num_lines / conf.indexing.batch_size)):
            add_batch(embedding_model, faiss_store, csv_writer, f, batch)
    save_index(conf, faiss_store)


if __name__ == "__main__":
    main()
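
# ---------------------------------------------------------------------------
# Example Hydra config, for reference. This is a minimal sketch of the
# conf/config.yaml this script expects: the key names are inferred from the
# attribute accesses above, but every value (paths, model name, dimensions)
# is an illustrative assumption, not the project's actual config.
#
#   files:
#     context: data/context.jsonl      # one JSON object per line
#     index: data/faiss_index          # directory written by save_local()
#     embeddings: data/embeddings.tsv
#   embeddings:
#     model: sentence-transformers/all-MiniLM-L6-v2
#     device: cpu
#     dim: 384                         # must match the model's output dim
#   indexes:
#     hnsw_pq:
#       hnsw_m: 32
#       pq_m: 48                       # must divide embeddings.dim evenly
#       bits: 8
#       ef_construction: 200
#       ef_search: 128
#   indexing:
#     index_type: hnsw                 # "hnsw" or "hnsw_pq"
#     batch_size: 64
# ---------------------------------------------------------------------------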