""" Database Module for LLM System
- Contains the `VectorDB` class to manage a vector database using FAISS and Ollama embeddings.
- Provides methods to initialize the database, retrieve embeddings, and perform similarity searches.
"""

import os
from typing import Tuple, Optional
from langchain_core.documents import Document
from langchain_ollama import OllamaEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.runnables import ConfigurableField

# For type hinting
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever

# config:
from llm_system.config import VECTOR_DB_PERSIST_DIR, VECTOR_DB_INDEX_NAME

from logger import get_logger
log = get_logger(name="core_database")


class VectorDB:
    """A class to manage the vector database using FAISS and Ollama embeddings.

    Args:
        embed_model (str): The name of the Ollama embeddings model to use.
        retriever_num_docs (int): Number of documents to retrieve for similarity search.
        verify_connection (bool): Whether to verify the connection to the embeddings model.
        persist_path (str, optional): Path to the persisted FAISS database. If None, a new DB is created.
        index_name (str, optional): Name of the FAISS index file. Defaults to "index.faiss".

    ## Functions:
        + `get_embeddings()`: Returns the Ollama embeddings model.
        + `get_vector_store()`: Returns the FAISS vector store.
        + `get_retriever()`: Returns the retriever configured for similarity search.
    """
    def __init__(
        self, embed_model: str,
        retriever_num_docs: int = 5,
        verify_connection: bool = False,
        persist_path: Optional[str] = VECTOR_DB_PERSIST_DIR,
        index_name: Optional[str] = VECTOR_DB_INDEX_NAME
    ):
        self.persist_path: Optional[str] = persist_path
        self.index_name: Optional[str] = index_name

        log.info(
            f"Initializing VectorDB with embeddings='{embed_model}', path='{persist_path}', k={retriever_num_docs} docs."
        )

        # The embedding model is pinned entirely to the CPU (num_gpu=0).
        # Reason: Ollama otherwise keeps swapping the LLM and the embedding model
        # in and out of GPU memory. Keeping the LLM on the GPU and the embedding
        # model 100% on the CPU avoids that thrashing, and keep_alive=-1 keeps
        # the model loaded between calls.
        self.embeddings = OllamaEmbeddings(model=embed_model, num_gpu=0, keep_alive=-1)

        if verify_connection:
            try:
                # A cheap single-token call to confirm the Ollama server is reachable:
                self.embeddings.embed_documents(['a'])
                log.info(f"Embeddings model '{embed_model}' initialized and verified.")

            except Exception as e:
                log.error(f"Failed to initialize Embeddings: {e}")
                raise RuntimeError(f"Couldn't initialize Embeddings model '{embed_model}'") from e
        else:
            log.warning(f"Embeddings '{embed_model}' initialized without connection verification.")

        # Create a dummy document to initialize the FAISS vector store:
        dummy_doc = Document(
            page_content="Hello World!",
            metadata={"user_id": "public", 'source': "test document"}
        )

        # Load FAISS from disk (or create it there if it doesn't exist yet):
        if persist_path and index_name:
            database_file = os.path.join(persist_path, index_name)
            # FAISS save_local()/load_local() expect the index *base* name
            # (e.g. 'index'), while the file on disk is '<base>.faiss':
            index_base_name = index_name[:-6] if index_name.endswith('.faiss') else index_name

            if not os.path.exists(database_file):
                self.db = FAISS.from_documents([dummy_doc], embedding=self.embeddings)
                self.db.save_local(persist_path, index_name=index_base_name)
                log.info("Created a new FAISS vector store on disk with a dummy document.")
            else:
                log.info(f"Found existing FAISS vector store at '{database_file}'.")
                self.db = FAISS.load_local(
                    persist_path, self.embeddings, index_name=index_base_name,
                    allow_dangerous_deserialization=True)

        # Create one temp, in memory, FAISS vector store:
        else:
            self.db = FAISS.from_documents([dummy_doc], embedding=self.embeddings)
            log.info("Created a new FAISS vector store in memory with a dummy document.")

        # A plain retriever would look like this:
        # self.retriever = self.db.as_retriever(
        #     search_type="similarity",
        #     search_kwargs={"k": retriever_num_docs, "filter": {"user_id": "public"}},
        # )
        # However, a plain retriever offers no way to pass filters (or any other
        # search kwargs) at runtime through rag_chain.invoke().
        # Hence, a configurable retriever is used instead:
        # https://github.com/langchain-ai/langchain/issues/9195#issuecomment-2095196865
        retriever = self.db.as_retriever()
        configurable_retriever = retriever.configurable_fields(
            search_kwargs=ConfigurableField(
                id="search_kwargs",
                name="Search Kwargs",
                description="The search kwargs to use",
            )
        )

        # Call it like this (note: search_type is a retriever field, not a
        # search kwarg, so it cannot be set here):
        # configurable_retriever.invoke(
        #     input="What is the Sun?",
        #     config={"configurable": {
        #         "search_kwargs": {
        #             "k": 5,
        #             # And here comes the main thing -- a per-call filter:
        #             "filter": {
        #                 "$or": [
        #                     {"user_id": "curious_cat"},
        #                     {"user_id": "public"}
        #                 ]
        #             },
        #         }
        #     }}
        # )

        self.retriever = configurable_retriever
        log.info(f"Created configurable retriever.")

    def get_embeddings(self) -> Embeddings:
        log.info("Returning the Embeddings model instance.")
        return self.embeddings

    def get_vector_store(self) -> VectorStore:
        log.info("Returning the FAISS vector store instance.")
        return self.db

    def get_retriever(self) -> VectorStoreRetriever:
        log.info("Returning the retriever for similarity search.")
        return self.retriever # type: ignore[return-value]

    def save_db_to_disk(self) -> bool:
        """Saves the current vector store to disk if a persist path is set.
        Returns:
            bool: True if the vector store was saved successfully, False otherwise.
        """

        if self.persist_path and self.index_name:
            try:
                # save_local()/load_local() take the index *base* name ('index'),
                # while the file on disk is '<base>.faiss':
                if self.index_name.endswith('.faiss'):
                    index_base_name = self.index_name[:-6]
                else:
                    index_base_name = self.index_name

                self.db.save_local(self.persist_path, index_name=index_base_name)
                log.info(f"Vector store saved to disk at '{self.persist_path}/{self.index_name}'.")
                return True
            except Exception as e:
                log.error(f"Failed to save vector store to disk: {e}")
                return False
        else:
            log.warning("Skipped saving to disk as no persist path is set.")
            return True
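

if __name__ == "__main__":
    # Minimal manual smoke test (a sketch; assumes a local Ollama server is
    # running and that the 'nomic-embed-text' model -- an assumed name, swap in
    # your own -- has been pulled):
    vdb = VectorDB(
        embed_model="nomic-embed-text",
        verify_connection=True,
        persist_path=None,   # in-memory store seeded with the dummy document
        index_name=None,
    )
    results = vdb.get_retriever().invoke(
        input="Hello",
        config={"configurable": {"search_kwargs": {"k": 1}}},
    )
    for doc in results:
        print(f"{doc.metadata}: {doc.page_content}")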