|
|
""" Database Module for LLM System |
|
|
- Contains the `VectorDB` class to manage a vector database using FAISS and Ollama embeddings. |
|
|
- Provides methods to initialize the database, access the embeddings model and vector store, obtain a configurable retriever for similarity search, and persist the store to disk.
|
|
""" |
|
|
|
|
|
import os |
|
|
from typing import Tuple, Optional |
|
|
from langchain_core.documents import Document |
|
|
from langchain_ollama import OllamaEmbeddings |
|
|
from langchain_community.vectorstores import FAISS |
|
|
from langchain_core.runnables import ConfigurableField |
|
|
|
|
|
|
|
|
from langchain_core.embeddings import Embeddings |
|
|
from langchain_core.vectorstores import VectorStore |
|
|
from langchain_core.vectorstores import VectorStoreRetriever |
|
|
|
|
|
|
|
|
from llm_system.config import VECTOR_DB_PERSIST_DIR, VECTOR_DB_INDEX_NAME |
|
|
|
|
|
from logger import get_logger |
|
|
# Logger used by everything in this module; all records go to the "core_database" logger.
_LOGGER_NAME = "core_database"
log = get_logger(name=_LOGGER_NAME)
|
|
|
|
|
|
|
|
class VectorDB:
    """Manage a FAISS vector store that uses an Ollama embeddings model.

    Construction does three things:

    - Creates the embeddings client and, optionally, probes it.
    - Loads the FAISS store from ``persist_path`` if an index already exists
      there. Otherwise it creates a store seeded with a placeholder document
      and, when persistence is configured, writes it to disk.
    - Wraps the store's retriever so that ``search_kwargs`` can be overridden
      at runtime.

    Args:
        embed_model (str): The name of the Ollama embeddings model to use.
        retriever_num_docs (int): Default number of documents (``k``) the
            retriever returns per similarity search. Callers can still
            override it per invocation through the ``search_kwargs``
            configurable field.
        verify_connection (bool): If True, embed a probe string during
            construction and raise ``RuntimeError`` if that fails.
        persist_path (str, optional): Directory of the persisted FAISS
            database. If falsy (or if ``index_name`` is falsy), the store
            lives in memory only. Defaults to ``VECTOR_DB_PERSIST_DIR``.
        index_name (str, optional): Name of the FAISS index file inside
            ``persist_path``, with or without the ``.faiss`` suffix.
            Defaults to ``VECTOR_DB_INDEX_NAME``.

    Raises:
        RuntimeError: If ``verify_connection`` is True and the embeddings
            model cannot embed the probe string.

    ## Functions:
        + `get_embeddings()`: Returns the Ollama embeddings model.
        + `get_vector_store()`: Returns the FAISS vector store.
        + `get_retriever()`: Returns the retriever configured for similarity search.
        + `save_db_to_disk()`: Persists the vector store if a persist path is set.
    """

    def __init__(
        self, embed_model: str,
        retriever_num_docs: int = 5,
        verify_connection: bool = False,
        persist_path: Optional[str] = VECTOR_DB_PERSIST_DIR,
        index_name: Optional[str] = VECTOR_DB_INDEX_NAME
    ):
        self.persist_path: Optional[str] = persist_path
        self.index_name: Optional[str] = index_name

        log.info(
            f"Initializing VectorDB with embeddings='{embed_model}', path='{persist_path}', k={retriever_num_docs} docs."
        )

        # NOTE(review): num_gpu=0 appears intended to keep the model on CPU and
        # keep_alive=-1 to keep it loaded in Ollama indefinitely. Confirm both
        # against the Ollama model-options documentation.
        self.embeddings = OllamaEmbeddings(model=embed_model, num_gpu=0, keep_alive=-1)

        if verify_connection:
            try:
                # Cheap round-trip so an unreachable server/model fails fast.
                self.embeddings.embed_documents(['a'])
            except Exception as e:
                log.error(f"Failed to initialize Embeddings: {e}")
                raise RuntimeError(f"Couldn't initialize Embeddings model '{embed_model}'") from e
            log.info(f"Embeddings model '{embed_model}' initialized and verified.")
        else:
            log.warning(f"Embeddings '{embed_model}' initialized without connection verification.")

        # New stores are built from this single placeholder document, presumably
        # because FAISS.from_documents needs at least one document.
        # NOTE(review): it carries user_id "public", so it may show up in search results.
        dummy_doc = Document(
            page_content="Hello World!",
            metadata={"user_id": "public", 'source': "test document"}
        )

        if persist_path and index_name:
            index_base_name = self._index_base_name(index_name)
            # Check the exact file that save_local/load_local use for this base
            # name, so a non-default index_name round-trips correctly instead of
            # silently using LangChain's default "index" name.
            database_file = os.path.join(persist_path, index_base_name + ".faiss")

            if not os.path.exists(database_file):
                self.db = FAISS.from_documents([dummy_doc], embedding=self.embeddings)
                self.db.save_local(persist_path, index_name=index_base_name)
                log.info("Created a new FAISS vector store on disk with a dummy document.")
            else:
                log.info(f"Found existing FAISS vector store at '{database_file}'.")
                # SECURITY: load_local unpickles the docstore (.pkl). This is only
                # safe if the files under persist_path were written by save_local
                # here; never point persist_path at untrusted data.
                self.db = FAISS.load_local(
                    persist_path, self.embeddings,
                    index_name=index_base_name,
                    allow_dangerous_deserialization=True)
        else:
            self.db = FAISS.from_documents([dummy_doc], embedding=self.embeddings)
            log.info("Created a new FAISS vector store in memory with a dummy document.")

        # Use retriever_num_docs as the default k. Without search_kwargs,
        # as_retriever() would fall back to LangChain's own default k.
        retriever = self.db.as_retriever(search_kwargs={"k": retriever_num_docs})

        # Expose search_kwargs as a runtime-configurable field, so callers can
        # override it per invocation (e.g. a different k or a metadata filter).
        self.retriever = retriever.configurable_fields(
            search_kwargs=ConfigurableField(
                id="search_kwargs",
                name="Search Kwargs",
                description="The search kwargs to use",
            )
        )
        log.info("Created configurable retriever.")

    @staticmethod
    def _index_base_name(index_name: str) -> str:
        """Return ``index_name`` without a trailing ``.faiss`` suffix.

        LangChain's ``FAISS.save_local``/``load_local`` take the base name and
        add the ``.faiss`` / ``.pkl`` extensions themselves.
        """
        suffix = ".faiss"
        if index_name.endswith(suffix):
            return index_name[:-len(suffix)]
        return index_name

    def get_embeddings(self) -> Embeddings:
        """Return the Ollama embeddings model created in ``__init__``."""
        log.info("Returning the Embeddings model instance.")
        return self.embeddings

    def get_vector_store(self) -> VectorStore:
        """Return the FAISS vector store (loaded from disk or newly created)."""
        log.info("Returning the FAISS vector store instance.")
        return self.db

    def get_retriever(self) -> VectorStoreRetriever:
        """Return the retriever used for similarity search.

        Its default ``k`` is ``retriever_num_docs``, and ``search_kwargs`` can
        be overridden at runtime via the ``"search_kwargs"`` configurable field.

        NOTE(review): the returned object is the result of
        ``configurable_fields(...)`` (a configurable runnable wrapping the
        retriever), not a bare ``VectorStoreRetriever`` as the annotation says.
        """
        log.info("Returning the retriever for similarity search.")
        return self.retriever

    def save_db_to_disk(self) -> bool:
        """Save the current vector store to disk if persistence is configured.

        The index is written under ``persist_path`` using the base name of
        ``index_name`` (``.faiss`` suffix stripped), which is the same name
        ``__init__`` loads from.

        Returns:
            bool: True if the store was saved, or if saving was skipped because
            ``persist_path`` / ``index_name`` is not set. False if saving raised.
        """
        if not (self.persist_path and self.index_name):
            log.warning("Skipped saving to disk as no persist path is set.")
            return True

        try:
            self.db.save_local(
                self.persist_path, index_name=self._index_base_name(self.index_name))
        except Exception as e:
            # Best-effort persistence: report the failure via the return value.
            log.error(f"Failed to save vector store to disk: {e}")
            return False

        log.info(f"Vector store saved to disk at '{self.persist_path}/{self.index_name}'.")
        return True
|
|
|