""" Database Module for LLM System - Contains the `VectorDB` class to manage a vector database using FAISS and Ollama embeddings. - Provides methods to initialize the database, retrieve embeddings, and perform similarity searches. """ import os from typing import Tuple, Optional from langchain_core.documents import Document from langchain_ollama import OllamaEmbeddings from langchain_community.vectorstores import FAISS from langchain_core.runnables import ConfigurableField # For type hinting from langchain_core.embeddings import Embeddings from langchain_core.vectorstores import VectorStore from langchain_core.vectorstores import VectorStoreRetriever # config: from llm_system.config import VECTOR_DB_PERSIST_DIR, VECTOR_DB_INDEX_NAME from logger import get_logger log = get_logger(name="core_database") class VectorDB: """A class to manage the vector database using FAISS and Ollama embeddings. Args: embed_model (str): The name of the Ollama embeddings model to use. retriever_num_docs (int): Number of documents to retrieve for similarity search. verify_connection (bool): Whether to verify the connection to the embeddings model. persist_path (str, optional): Path to the persisted FAISS database. If None, a new DB is created. index_name (str, optional): Name of the FAISS index file. Defaults to "index.faiss". ## Functions: + `get_embeddings()`: Returns the Ollama embeddings model. + `get_vector_store()`: Returns the FAISS vector store. + `get_retriever()`: Returns the retriever configured for similarity search. """ def __init__( self, embed_model: str, retriever_num_docs: int = 5, verify_connection: bool = False, persist_path: Optional[str] = VECTOR_DB_PERSIST_DIR, index_name: Optional[str] = VECTOR_DB_INDEX_NAME ): self.persist_path: Optional[str] = persist_path self.index_name: Optional[str] = index_name log.info( f"Initializing VectorDB with embeddings='{embed_model}', path='{persist_path}', k={retriever_num_docs} docs." ) # Here, I have configured the model to be loaded on CPU completely. # Reason: Ollama keeps alternately loading and unloading the LLM/Emb model on GPU. # Solution: Load the LLM on GPU and the Embedding model on CPU 100%. 
        self.embeddings = OllamaEmbeddings(model=embed_model, num_gpu=0, keep_alive=-1)
        if verify_connection:
            try:
                self.embeddings.embed_documents(['a'])
                log.info(f"Embeddings model '{embed_model}' initialized and verified.")
            except Exception as e:
                log.error(f"Failed to initialize Embeddings: {e}")
                raise RuntimeError(f"Couldn't initialize Embeddings model '{embed_model}'") from e
        else:
            log.warning(f"Embeddings '{embed_model}' initialized without connection verification.")

        # Create a dummy document to initialize the FAISS vector store:
        dummy_doc = Document(
            page_content="Hello World!",
            metadata={"user_id": "public", "source": "test document"}
        )

        # Load FAISS from disk:
        if persist_path and index_name:
            database_file = os.path.join(persist_path, index_name)
            # save_local()/load_local() take the index name without the '.faiss' extension:
            index_base_name = index_name[:-6] if index_name.endswith('.faiss') else index_name
            if not os.path.exists(database_file):
                self.db = FAISS.from_documents([dummy_doc], embedding=self.embeddings)
                self.db.save_local(persist_path, index_name=index_base_name)
                log.info("Created a new FAISS vector store on disk with a dummy document.")
            else:
                log.info(f"Found existing FAISS vector store at '{database_file}'.")
                self.db = FAISS.load_local(
                    persist_path, self.embeddings, index_name=index_base_name,
                    allow_dangerous_deserialization=True)
        # Otherwise, create a temporary, in-memory FAISS vector store:
        else:
            self.db = FAISS.from_documents([dummy_doc], embedding=self.embeddings)
            log.info("Created a new FAISS vector store in memory with a dummy document.")

        # self.retriever = self.db.as_retriever(
        #     search_type="similarity",
        #     search_kwargs={"k": retriever_num_docs, "filter": {"user_id": "public"}},
        # )
        # log.info(f"Created retriever with k={retriever_num_docs}.")

        # A simple retriever has no way to accept filters through rag_chain.invoke(),
        # i.e., no way to pass arguments at runtime.
        # Hence, using a configurable retriever:
        # https://github.com/langchain-ai/langchain/issues/9195#issuecomment-2095196865
        # The default k applies unless search_kwargs is overridden at call time:
        retriever = self.db.as_retriever(search_kwargs={"k": retriever_num_docs})
        configurable_retriever = retriever.configurable_fields(
            search_kwargs=ConfigurableField(
                id="search_kwargs",
                name="Search Kwargs",
                description="The search kwargs to use",
            )
        )
        # Call it like this:
        # configurable_retriever.invoke(
        #     input="What is the Sun?",
        #     config={"configurable": {
        #         "search_kwargs": {
        #             "k": 5,
        #             "search_type": "similarity",
        #             # And here comes the main thing:
        #             "filter": {
        #                 "$or": [
        #                     {"user_id": "curious_cat"},
        #                     {"user_id": "public"}
        #                 ]
        #             },
        #         }
        #     }}
        # )
        self.retriever = configurable_retriever
        log.info("Created configurable retriever.")

    def get_embeddings(self) -> Embeddings:
        log.info("Returning the Embeddings model instance.")
        return self.embeddings

    def get_vector_store(self) -> VectorStore:
        log.info("Returning the FAISS vector store instance.")
        return self.db

    def get_retriever(self) -> VectorStoreRetriever:
        log.info("Returning the retriever for similarity search.")
        return self.retriever  # type: ignore[return-value]
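    # FAISS persistence note (to the best of my knowledge): save_local() writes two
    # files under the folder, '<index_name>.faiss' (the raw FAISS index) and
    # '<index_name>.pkl' (the docstore and index-to-docstore-id mapping), and
    # load_local() expects the same pair.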
    def save_db_to_disk(self) -> bool:
        """Saves the current vector store to disk if a persist path is set.

        Returns:
            bool: True if the vector store was saved successfully (or saving was
                skipped because no persist path is set), False otherwise.
        """
        if self.persist_path and self.index_name:
            try:
                # Loading checks for 'index.faiss' on disk, but save_local() takes
                # the base name 'index' without the extension:
                if self.index_name.endswith('.faiss'):
                    index_base_name = self.index_name[:-6]
                else:
                    index_base_name = self.index_name
                self.db.save_local(self.persist_path, index_name=index_base_name)
                log.info(f"Vector store saved to disk at '{self.persist_path}/{self.index_name}'.")
                return True
            except Exception as e:
                log.error(f"Failed to save vector store to disk: {e}")
                return False
        else:
            log.warning("Skipped saving to disk as no persist path is set.")
            return True
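

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module API). Assumes a local
# Ollama server with an embeddings model pulled; the model name, user ids, and
# query below are hypothetical placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    vdb = VectorDB(embed_model="nomic-embed-text", verify_connection=True)

    # Add a user-scoped document, then persist the index:
    vdb.get_vector_store().add_documents([
        Document(
            page_content="The Sun is a G-type main-sequence star.",
            metadata={"user_id": "curious_cat", "source": "demo"},
        )
    ])
    vdb.save_db_to_disk()

    # Retrieve with per-user filtering supplied at runtime through the
    # configurable 'search_kwargs' field:
    results = vdb.get_retriever().invoke(
        "What is the Sun?",
        config={"configurable": {"search_kwargs": {
            "k": 3,
            "filter": {"$or": [{"user_id": "curious_cat"}, {"user_id": "public"}]},
        }}},
    )
    for doc in results:
        print(doc.page_content)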