""" Database Module for LLM System
- Contains the `VectorDB` class to manage a vector database using FAISS and Ollama embeddings.
- Provides methods to initialize the database, retrieve embeddings, and perform similarity searches.
"""
import os
from typing import Tuple, Optional
from langchain_core.documents import Document
from langchain_ollama import OllamaEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.runnables import ConfigurableField
# For type hinting
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore
from langchain_core.vectorstores import VectorStoreRetriever
# config:
from llm_system.config import VECTOR_DB_PERSIST_DIR, VECTOR_DB_INDEX_NAME
from logger import get_logger
log = get_logger(name="core_database")


class VectorDB:
"""A class to manage the vector database using FAISS and Ollama embeddings.
Args:
embed_model (str): The name of the Ollama embeddings model to use.
retriever_num_docs (int): Number of documents to retrieve for similarity search.
verify_connection (bool): Whether to verify the connection to the embeddings model.
        persist_path (str, optional): Path to the persisted FAISS database. If None, an in-memory DB is created.
        index_name (str, optional): Name of the FAISS index file. Defaults to VECTOR_DB_INDEX_NAME from the config.
## Functions:
+ `get_embeddings()`: Returns the Ollama embeddings model.
+ `get_vector_store()`: Returns the FAISS vector store.
+ `get_retriever()`: Returns the retriever configured for similarity search.
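
    Example (illustrative; the embedding model name here is an assumption, use
    whichever model your Ollama server has pulled):
        >>> vdb = VectorDB(embed_model="nomic-embed-text", persist_path=None, index_name=None)
        >>> retriever = vdb.get_retriever()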
"""

    def __init__(
self, embed_model: str,
retriever_num_docs: int = 5,
verify_connection: bool = False,
persist_path: Optional[str] = VECTOR_DB_PERSIST_DIR,
index_name: Optional[str] = VECTOR_DB_INDEX_NAME
):
self.persist_path: Optional[str] = persist_path
self.index_name: Optional[str] = index_name
log.info(
f"Initializing VectorDB with embeddings='{embed_model}', path='{persist_path}', k={retriever_num_docs} docs."
)
        # The embedding model is pinned to the CPU (num_gpu=0) and kept resident
        # indefinitely (keep_alive=-1).
        # Reason: Ollama kept alternately loading and unloading the LLM and the
        # embedding model on the GPU.
        # Solution: keep the LLM on the GPU and run the embedding model 100% on the CPU.
self.embeddings = OllamaEmbeddings(model=embed_model, num_gpu=0, keep_alive=-1)
if verify_connection:
try:
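                # Cheap round-trip embedding call to confirm the Ollama server is reachable.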
self.embeddings.embed_documents(['a'])
log.info(f"Embeddings model '{embed_model}' initialized and verified.")
except Exception as e:
log.error(f"Failed to initialize Embeddings: {e}")
raise RuntimeError(f"Couldn't initialize Embeddings model '{embed_model}'") from e
else:
log.warning(f"Embeddings '{embed_model}' initialized without connection verification.")
        # FAISS needs at least one document at creation time to infer the embedding
        # dimension, so seed new stores with a dummy document:
dummy_doc = Document(
page_content="Hello World!",
metadata={"user_id": "public", 'source': "test document"}
)
        # Load the FAISS index from disk if a persist path is configured:
        if persist_path and index_name:
            database_file = os.path.join(persist_path, index_name)
            # save_local()/load_local() take the index name *without* the '.faiss'
            # extension, while the file on disk carries it, so strip it here.
            index_base = index_name[:-6] if index_name.endswith('.faiss') else index_name
            if not os.path.exists(database_file):
                self.db = FAISS.from_documents([dummy_doc], embedding=self.embeddings)
                self.db.save_local(persist_path, index_name=index_base)
                log.info("Created a new FAISS vector store on disk with a dummy document.")
            else:
                log.info(f"Found existing FAISS vector store at '{database_file}'.")
                # allow_dangerous_deserialization is required because FAISS persists
                # its docstore with pickle; only load index files this app wrote itself.
                self.db = FAISS.load_local(
                    persist_path, self.embeddings, index_name=index_base,
                    allow_dangerous_deserialization=True)
        # Otherwise, create a temporary in-memory FAISS vector store:
else:
self.db = FAISS.from_documents([dummy_doc], embedding=self.embeddings)
log.info("Created a new FAISS vector store in memory with a dummy document.")
# self.retriever = self.db.as_retriever(
# search_type="similarity",
# search_kwargs={"k": retriever_num_docs, "filter":{"user_id": "public"}},
# )
# log.info(f"Created retriever with k={retriever_num_docs}.")
        # A plain retriever offers no way to pass filters through rag_chain.invoke(),
        # i.e. no way to supply search arguments at runtime.
        # Hence, a configurable retriever is used instead:
# https://github.com/langchain-ai/langchain/issues/9195#issuecomment-2095196865
        # Seed the retriever with retriever_num_docs as the default k; the configurable
        # field below still allows overriding all search kwargs at runtime.
        retriever = self.db.as_retriever(search_kwargs={"k": retriever_num_docs})
configurable_retriever = retriever.configurable_fields(
search_kwargs=ConfigurableField(
id="search_kwargs",
name="Search Kwargs",
description="The search kwargs to use",
)
)
# call it like this:
# configurable_retriever.invoke(
# input="What is the Sun?",
# config={"configurable": {
# "search_kwargs": {
# "k": 5,
# "search_type": "similarity",
# # And here comes the main thing:
# "filter": {
# "$or": [
# {"user_id": "curious_cat"},
# {"user_id": "public"}
# ]
# },
# }
# }}
# )
self.retriever = configurable_retriever
log.info(f"Created configurable retriever.")

    def get_embeddings(self) -> Embeddings:
log.info("Returning the Embeddings model instance.")
return self.embeddings

    def get_vector_store(self) -> VectorStore:
log.info("Returning the FAISS vector store instance.")
return self.db

    def get_retriever(self) -> VectorStoreRetriever:
log.info("Returning the retriever for similarity search.")
return self.retriever # type: ignore[return-value]

    def save_db_to_disk(self) -> bool:
"""Saves the current vector store to disk if a persist path is set.
Returns:
            bool: True if the vector store was saved, or if saving was skipped
                because no persist path is set; False if saving failed.
"""
if self.persist_path and self.index_name:
try:
                # The index file on disk is '<name>.faiss', but save_local() takes the
                # index name without the extension, so strip '.faiss' if present.
if self.index_name.endswith('.faiss'):
index_base_name = self.index_name[:-6]
else:
index_base_name = self.index_name
self.db.save_local(self.persist_path, index_name=index_base_name)
log.info(f"Vector store saved to disk at '{self.persist_path}/{self.index_name}'.")
return True
except Exception as e:
log.error(f"Failed to save vector store to disk: {e}")
return False
else:
log.warning("Skipped saving to disk as no persist path is set.")
return True
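

# --- Minimal usage sketch (illustrative only, not part of the module API) ---
# Assumes a local Ollama server with the 'nomic-embed-text' embedding model pulled;
# swap in whatever embedding model your deployment actually uses.
if __name__ == "__main__":
    vdb = VectorDB(embed_model="nomic-embed-text", persist_path=None, index_name=None)

    # Index a document tagged with a user_id so it can be filtered at query time.
    vdb.get_vector_store().add_documents([
        Document(
            page_content="The Sun is a G-type main-sequence star.",
            metadata={"user_id": "public", "source": "demo"},
        )
    ])

    # Supply search_kwargs at runtime through the configurable retriever.
    docs = vdb.get_retriever().invoke(
        "What is the Sun?",
        config={"configurable": {"search_kwargs": {
            "k": 2,
            "filter": {"user_id": "public"},
        }}},
    )
    print([doc.page_content for doc in docs])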