# ai-rag / rag_engine.py
# Hugging Face Space file (author: feersdilaa).
# Renamed from rag-engine.py to rag_engine.py in commit bea0431 (verified).
import os
import time
from pathlib import Path
from typing import List, Optional
from sentence_transformers import SentenceTransformer
from langchain_core.documents import Document
from langchain_chroma import Chroma
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_google_genai import ChatGoogleGenerativeAI
# ---------------------------
# QWEN EMBEDDINGS WRAPPER
# ---------------------------
class QwenHFEmbeddings:
    """LangChain-compatible embedding wrapper around a SentenceTransformer model.

    Exposes the two methods LangChain vector stores expect
    (``embed_documents`` / ``embed_query``), encoding in fixed-size batches.
    """

    def __init__(self, model: str = "Qwen/Qwen3-Embedding-0.6B", batch_size: int = 8):
        print(f"[INIT] Loading embedding model: {model}")
        self.model = SentenceTransformer(model)
        self.batch_size = batch_size

    def _encode(self, texts, prompt_name=None):
        """Encode one string or a list of strings; returns a list of vectors."""
        items = [texts] if isinstance(texts, str) else texts
        step = self.batch_size
        vectors = []
        for start in range(0, len(items), step):
            chunk = items[start:start + step]
            encoded = self.model.encode(
                chunk,
                prompt_name=prompt_name,
                convert_to_numpy=True,
            )
            vectors.extend(encoded.tolist())
        return vectors

    def embed_documents(self, texts):
        """Embed a batch of document texts (no special prompt)."""
        return self._encode(texts)

    def embed_query(self, text):
        """Embed a single query string using the model's 'query' prompt."""
        return self._encode(text, prompt_name="query")[0]
# ============================
# RAG ENGINE
# ============================
class NewsLegalAnalyzer:
    """RAG pipeline that matches Indonesian crime-news text against legal articles.

    Wires together: Qwen embeddings -> persisted Chroma vector store -> MMR
    retriever -> Gemini chat model -> prompt chain. Call ``initialize()`` once,
    then ``analyze(text)`` per news article.
    """

    def __init__(self, db_path: str = "db_hukum_Qwen"):
        self.db_path = Path(db_path)
        # Components are populated lazily by the load_* / setup_* methods below.
        self.embeddings = None
        self.vectordb = None
        self.retriever = None
        self.llm = None
        self.chain = None

    # ---------------------------
    # LOAD MODELS
    # ---------------------------
    def load_embeddings(self) -> bool:
        """Instantiate the Qwen embedding wrapper (downloads the model on first use)."""
        self.embeddings = QwenHFEmbeddings()
        return True

    def load_vector_db(self) -> bool:
        """Open the persisted Chroma collection.

        Raises:
            RuntimeError: if ``load_embeddings()`` was not called first —
                previously ``embedding_function=None`` was passed to Chroma
                and failed later with an opaque error.
            FileNotFoundError: if the database folder does not exist.
        """
        if self.embeddings is None:
            raise RuntimeError("Embeddings belum dimuat. Panggil load_embeddings() dulu.")
        if not self.db_path.exists():
            raise FileNotFoundError("Folder database tidak ditemukan.")
        self.vectordb = Chroma(
            persist_directory=str(self.db_path),
            embedding_function=self.embeddings
        )
        total = len(self.vectordb.get()["ids"])
        print(f"[DB] Loaded {total} documents.")
        return True

    def load_llm(self, model="gemini-2.5-flash-lite") -> bool:
        """Create the Gemini chat model; requires GOOGLE_API_KEY in the environment."""
        if "GOOGLE_API_KEY" not in os.environ:
            raise EnvironmentError("GOOGLE_API_KEY belum diset di Hugging Face Secrets.")
        self.llm = ChatGoogleGenerativeAI(
            model=model,
            temperature=0.4
        )
        return True

    # ---------------------------
    # RETRIEVER
    # ---------------------------
    def setup_retriever(self, k=15, fetch_k=50) -> bool:
        """Build an MMR retriever over the vector store.

        Args:
            k: number of documents returned to the chain.
            fetch_k: candidate pool size for MMR re-ranking.

        Raises:
            RuntimeError: if ``load_vector_db()`` has not been called yet
                (previously an AttributeError on ``None``).
        """
        if self.vectordb is None:
            raise RuntimeError("Vector DB belum dimuat. Panggil load_vector_db() dulu.")
        self.retriever = self.vectordb.as_retriever(
            search_type="mmr",
            search_kwargs={"k": k, "fetch_k": fetch_k}
        )
        return True

    # ---------------------------
    # CREATE CHAIN
    # ---------------------------
    def create_chain(self):
        """Assemble the retrieval-augmented prompt chain.

        Raises:
            RuntimeError: if the retriever or LLM is missing — previously this
                failed later inside LangChain with an unrelated TypeError.
        """
        if self.retriever is None or self.llm is None:
            raise RuntimeError("Retriever/LLM belum siap. Jalankan initialize() dulu.")

        template = """
Anda adalah Asisten Editor Berita Kriminal.
Tugas Anda adalah memberikan pasal yang relevan terhadap kronologi kejadian.
REFERENSI:
{context}
BERITA:
{question}
Jawaban:
"""
        prompt = PromptTemplate(
            template=template,
            input_variables=["context", "question"]
        )

        def format_docs(docs: List[Document]) -> str:
            """Render retrieved docs as a numbered reference list for the prompt."""
            if not docs:
                return "Tidak ada referensi hukum ditemukan."
            # Fallback label when a document lacks the 'sumber_uu' metadata key
            # (previously rendered the literal string 'None' into the context).
            return "\n\n".join(
                f"[{i+1}] {d.metadata.get('sumber_uu', 'Sumber tidak diketahui')}:\n{d.page_content}"
                for i, d in enumerate(docs)
            )

        self.chain = (
            {"context": self.retriever | format_docs, "question": RunnablePassthrough()}
            | prompt
            | self.llm
            | StrOutputParser()
        )

    # ---------------------------
    # RUN ANALYSIS
    # ---------------------------
    def analyze(self, text: str) -> str:
        """Run the chain on a news article and return the LLM's answer string."""
        if not self.chain:
            raise RuntimeError("Chain belum dibuat.")
        return self.chain.invoke(text)

    # ---------------------------
    # INIT ALL
    # ---------------------------
    def initialize(self):
        """Load every component in dependency order; idempotent on success."""
        self.load_embeddings()
        self.load_vector_db()
        self.setup_retriever()
        self.load_llm()
        self.create_chain()
        print("[INIT] Semua komponen siap.")