Shri committed on
Commit 5038bcb · 1 Parent(s): ac6fa79

feat: tokenization and semantic search endpoints

alembic/versions/dd61202db14f_add_knowledgebase_chunk.py ADDED
@@ -0,0 +1,33 @@
+ """add: knowledgebase,chunk
+
+ Revision ID: dd61202db14f
+ Revises: b33e3b5b7af9
+ Create Date: 2025-11-17 23:28:11.537932
+
+ """
+ from typing import Sequence, Union
+
+ from alembic import op
+ import sqlalchemy as sa
+ import sqlmodel.sql.sqltypes
+
+
+ # revision identifiers, used by Alembic.
+ revision: str = 'dd61202db14f'
+ down_revision: Union[str, Sequence[str], None] = 'b33e3b5b7af9'
+ branch_labels: Union[str, Sequence[str], None] = None
+ depends_on: Union[str, Sequence[str], None] = None
+
+
+ def upgrade() -> None:
+     """Upgrade schema."""
+     # ### commands auto generated by Alembic - please adjust! ###
+     pass
+     # ### end Alembic commands ###
+
+
+ def downgrade() -> None:
+     """Downgrade schema."""
+     # ### commands auto generated by Alembic - please adjust! ###
+     pass
+     # ### end Alembic commands ###
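The auto-generated migration above is still a stub: both upgrade() and downgrade() are pass, so no tables are created yet. Below is a minimal sketch of what the upgrade could look like if it is meant to mirror the models added in src/chatbot/models.py (knowledge_base, and knowledge_chunk with a 768-dimension pgvector column); the actual DDL is an assumption, not part of this commit:

```python
# Hypothetical fill-in for the empty migration, derived from src/chatbot/models.py.
import sqlalchemy as sa
import sqlmodel.sql.sqltypes
from alembic import op
from pgvector.sqlalchemy import Vector


def upgrade() -> None:
    """Upgrade schema."""
    # pgvector must be available in the target database for the Vector column.
    op.execute("CREATE EXTENSION IF NOT EXISTS vector")

    op.create_table(
        "knowledge_base",
        sa.Column("id", sa.Uuid(), primary_key=True),
        sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
        sa.Column("description", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
        sa.Column("created_at", sa.DateTime(), nullable=False),
    )
    op.create_table(
        "knowledge_chunk",
        sa.Column("id", sa.Uuid(), primary_key=True),
        sa.Column("kb_id", sa.Uuid(), sa.ForeignKey("knowledge_base.id"), nullable=False),
        sa.Column("chunk_index", sa.Integer(), nullable=False),
        sa.Column("chunk_text", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
        sa.Column("embedding", Vector(768), nullable=True),
    )


def downgrade() -> None:
    """Downgrade schema."""
    op.drop_table("knowledge_chunk")
    op.drop_table("knowledge_base")
```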
src/chatbot/embedding.py ADDED
@@ -0,0 +1,87 @@
+ # To run this file you need model.onnx_data in the assets/onnx folder; you can obtain it from https://huggingface.co/onnx-community/embeddinggemma-300m-ONNX/tree/main/onnx
+
+ import asyncio
+ import os
+ from typing import List
+
+ import numpy as np
+ import onnxruntime as ort
+ from transformers import AutoTokenizer
+
+ BASE_DIR = os.path.dirname(__file__)
+
+ TOKENIZER_DIR = os.path.abspath(os.path.join(BASE_DIR, "..", "assets", "tokenizer"))
+
+ MODEL_DIR = os.path.abspath(
+     os.path.join(BASE_DIR, "..", "assets", "onnx", "model.onnx")
+ )
+
+
+ class EmbeddingModel:
+     def __init__(self):
+         print(TOKENIZER_DIR)
+         self.tokenizer = AutoTokenizer.from_pretrained(
+             TOKENIZER_DIR, local_files_only=True
+         )
+
+         sess_options = ort.SessionOptions()
+         providers = ["CPUExecutionProvider"]
+
+         self.session = ort.InferenceSession(
+             MODEL_DIR, sess_options, providers=providers
+         )
+
+         self.input_names = [inp.name for inp in self.session.get_inputs()]
+         self.output_names = [out.name for out in self.session.get_outputs()]
+
+     def _run_sync(
+         self, input_ids: np.ndarray, attention_mask: np.ndarray
+     ) -> List[float]:
+         inputs = {}
+
+         if "input_ids" in self.input_names:
+             inputs["input_ids"] = input_ids
+         else:
+             inputs[self.input_names[0]] = input_ids
+
+         if "attention_mask" in self.input_names:
+             inputs["attention_mask"] = attention_mask
+         elif len(self.input_names) > 1:
+             inputs[self.input_names[1]] = attention_mask
+
+         outputs = self.session.run(self.output_names, inputs)
+         emb = outputs[0]
+
+         if emb.ndim == 3:
+             emb_vector = emb.mean(axis=1)[0]
+         elif emb.ndim == 2:
+             emb_vector = emb[0]
+         else:
+             emb_vector = np.asarray(emb).flatten()
+
+         return emb_vector.astype(float).tolist()
+
+     async def embed_text(self, text: str, max_length: int = 512) -> List[float]:
+
+         encoded = self.tokenizer(
+             text,
+             return_tensors="np",
+             truncation=True,
+             padding="longest",
+             max_length=max_length,
+         )
+
+         input_ids = encoded["input_ids"].astype(np.int64)
+         attention_mask = encoded.get("attention_mask", np.ones_like(input_ids)).astype(
+             np.int64
+         )
+
+         loop = asyncio.get_event_loop()
+         vector = await loop.run_in_executor(
+             None, self._run_sync, input_ids, attention_mask
+         )
+
+         return vector
+
+
+ embedding_model = EmbeddingModel()
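A minimal usage sketch for the embedding module above, assuming the tokenizer and ONNX assets are present locally. embed_text is async because ONNX inference is pushed to a thread-pool executor:

```python
import asyncio

from src.chatbot.embedding import embedding_model


async def main() -> None:
    # Embed a single query string into a plain Python list of floats.
    vector = await embedding_model.embed_text("How many leave days do I get?")
    print(len(vector))  # expected to match the Vector(768) column in models.py
    print(vector[:5])   # first few components


if __name__ == "__main__":
    asyncio.run(main())
```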
src/chatbot/models.py CHANGED
@@ -1,2 +1,28 @@
  import uuid
- import sqlmodel
+ from datetime import datetime
+ from typing import List
+
+ from pgvector.sqlalchemy import Vector
+ from sqlalchemy import Column
+ from sqlmodel import Field, Relationship, SQLModel
+
+
+ class KnowledgeBase(SQLModel, table=True):
+     __tablename__ = "knowledge_base"
+     id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
+     name: str = Field(nullable=False)
+     description: str | None = None
+     created_at: datetime = Field(default_factory=datetime.now)
+     knowledge_chunk: List["KnowledgeChunk"] = Relationship(
+         back_populates="knowledge_base"
+     )
+
+
+ class KnowledgeChunk(SQLModel, table=True):
+     __tablename__ = "knowledge_chunk"
+     id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
+     kb_id: uuid.UUID = Field(foreign_key="knowledge_base.id", nullable=False)
+     chunk_index: int
+     chunk_text: str
+     embedding: List[float] = Field(sa_column=Column(Vector(768)))
+     knowledge_base: "KnowledgeBase" = Relationship(back_populates="knowledge_chunk")
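Not part of this commit, but for context: because embedding is declared as a pgvector Vector(768) column, pgvector's SQLAlchemy comparators are available on it, so the cosine-distance search done with raw SQL in router.py could also be expressed at the ORM level. A sketch, assuming pgvector's cosine_distance() comparator (which compiles to the <=> operator):

```python
# Hypothetical ORM-level nearest-neighbour query over KnowledgeChunk.embedding.
from sqlmodel import select

from src.chatbot.models import KnowledgeChunk


def nearest_chunks_stmt(query_vec: list[float], top_k: int = 3):
    # Smaller cosine distance means a closer match; LIMIT keeps the top_k rows.
    return (
        select(KnowledgeChunk)
        .order_by(KnowledgeChunk.embedding.cosine_distance(query_vec))
        .limit(top_k)
    )
```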
src/chatbot/router.py CHANGED
@@ -0,0 +1,111 @@
+ import os
+ import shutil
+ import tempfile
+ from typing import Optional
+
+ from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
+ from sqlalchemy import text
+ from sqlmodel.ext.asyncio.session import AsyncSession
+
+ from src.core.database import get_async_session
+
+ from .embedding import embedding_model
+ from .schemas import (
+     SemanticSearchRequest,
+     SemanticSearchResult,
+     TokenizeRequest,
+     TokenizeResponse,
+     UploadKBResponse,
+ )
+ from .service import process_pdf_and_store
+
+ router = APIRouter(prefix="/chatbot", tags=["chatbot"])
+
+
+ # Before hitting this endpoint, make sure model.onnx and model.onnx_data are available in the assets/onnx folder.
+ @router.post("/upload-pdf", response_model=UploadKBResponse)
+ async def upload_pdf(
+     file: UploadFile = File(...),
+     name: str = Form(...),
+     description: Optional[str] = Form(None),
+     session: AsyncSession = Depends(get_async_session),
+ ):
+     if not file.filename.endswith(".pdf"):
+         raise HTTPException(
+             status_code=400, detail="Only PDF files are supported for now."
+         )
+
+     tmp_dir = tempfile.mkdtemp()
+     tmp_path = os.path.join(tmp_dir, file.filename)
+     try:
+         with open(tmp_path, "wb") as out_f:
+             shutil.copyfileobj(file.file, out_f)
+
+         with open(tmp_path, "rb") as fobj:
+             result = await process_pdf_and_store(fobj, name, description, session)
+
+         return UploadKBResponse(
+             kb_id=result["kb_id"],
+             name=result["name"],
+             chunks_stored=result["chunks_stored"],
+         )
+     finally:
+         try:
+             os.remove(tmp_path)
+             os.rmdir(tmp_dir)
+         except Exception:
+             pass
+
+
+ @router.post("/tokenize", response_model=TokenizeResponse)
+ async def tokenize_text(payload: TokenizeRequest):
+     try:
+         encoded = embedding_model.tokenizer(
+             payload.text,
+             return_tensors="np",
+             truncation=True,
+             padding="longest",
+             max_length=512,
+         )
+
+         return TokenizeResponse(
+             input_ids=encoded["input_ids"][0].tolist(),
+             attention_mask=encoded["attention_mask"][0].tolist(),
+         )
+
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+
+
+ @router.post("/semantic-search", response_model=list[SemanticSearchResult])
+ async def semantic_search(
+     payload: SemanticSearchRequest, session: AsyncSession = Depends(get_async_session)
+ ):
+
+     if len(payload.embedding) == 0:
+         raise HTTPException(status_code=400, detail="Embedding cannot be empty.")
+
+     q_vector = payload.embedding
+     top_k = payload.top_k or 3
+
+     sql = text(
+         """
+         SELECT id, kb_id, chunk_text, embedding <=> :query_vec AS score
+         FROM knowledge_chunk
+         ORDER BY embedding <=> :query_vec
+         LIMIT :top_k
+         """
+     )
+
+     rows = await session.exec(sql, params={"query_vec": q_vector, "top_k": top_k})
+     rows = rows.fetchall()
+
+     return [
+         SemanticSearchResult(
+             chunk_id=str(r.id),
+             kb_id=str(r.kb_id),
+             text=r.chunk_text,
+             score=float(r.score),
+         )
+         for r in rows
+     ]
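A rough client-side sketch of exercising the three endpoints above, assuming the API runs at http://localhost:8000 (the URL, file name, and query vector are illustrative only). Note that /semantic-search expects the caller to supply a pre-computed query embedding whose dimensionality matches the stored vectors:

```python
import httpx

BASE_URL = "http://localhost:8000/chatbot"  # assumed local dev server

with httpx.Client(base_url=BASE_URL, timeout=120) as client:
    # 1. Upload a PDF: the server chunks it, embeds each chunk, and stores the vectors.
    with open("handbook.pdf", "rb") as f:  # illustrative file
        upload = client.post(
            "/upload-pdf",
            files={"file": ("handbook.pdf", f, "application/pdf")},
            data={"name": "Employee Handbook", "description": "HR policies"},
        )
    print(upload.json())  # {"kb_id": "...", "name": "...", "chunks_stored": N}

    # 2. Tokenize text with the same tokenizer the embedder uses.
    tokens = client.post("/tokenize", json={"text": "How many leave days do I get?"})
    print(tokens.json()["input_ids"][:10])

    # 3. Semantic search: the embedding must be produced by the caller
    #    (e.g. via embedding_model.embed_text) and match the column dimension (768).
    query_vec = [0.0] * 768  # placeholder vector, illustration only
    hits = client.post("/semantic-search", json={"embedding": query_vec, "top_k": 3})
    for hit in hits.json():
        print(hit["score"], hit["text"][:80])
```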
src/chatbot/schemas.py CHANGED
@@ -0,0 +1,36 @@
+ import uuid
+ from typing import List, Optional
+
+ from pydantic import BaseModel
+
+
+ class UploadKBResponse(BaseModel):
+     kb_id: uuid.UUID
+     name: str
+     chunks_stored: int
+
+
+ class UploadKBRequest(BaseModel):
+     name: str
+     description: Optional[str] = None
+
+
+ class TokenizeRequest(BaseModel):
+     text: str
+
+
+ class TokenizeResponse(BaseModel):
+     input_ids: List[int]
+     attention_mask: List[int]
+
+
+ class SemanticSearchRequest(BaseModel):
+     embedding: List[float]
+     top_k: Optional[int] = 3
+
+
+ class SemanticSearchResult(BaseModel):
+     chunk_id: str
+     kb_id: str
+     text: str
+     score: float
src/chatbot/service.py CHANGED
@@ -1,2 +1,45 @@
- from typing import List
- from uuid import UUID
+ import os
+
+ from sqlmodel.ext.asyncio.session import AsyncSession
+
+ from .embedding import embedding_model
+ from .models import KnowledgeBase, KnowledgeChunk
+ from .utils import (
+     chunk_sentences_with_overlap,
+     extract_text_from_pdf_fileobj,
+     split_into_sentences,
+ )
+
+ DEFAULT_MAX_WORDS = int(os.getenv("CHUNK_MAX_WORDS", "200"))
+ DEFAULT_OVERLAP = int(os.getenv("CHUNK_OVERLAP_WORDS", "40"))
+
+
+ async def process_pdf_and_store(
+     fileobj, kb_name: str, kb_description: str | None, session: AsyncSession
+ ):
+     raw_text = extract_text_from_pdf_fileobj(fileobj)
+
+     sentences = split_into_sentences(raw_text)
+
+     chunks = chunk_sentences_with_overlap(
+         sentences, max_words=DEFAULT_MAX_WORDS, overlap_words=DEFAULT_OVERLAP
+     )
+
+     kb = KnowledgeBase(name=kb_name, description=kb_description)
+     session.add(kb)
+     await session.commit()
+     await session.refresh(kb)
+
+     chunk_objs = []
+     for idx, chunk_text in enumerate(chunks):
+         emb = await embedding_model.embed_text(chunk_text)
+
+         chunk = KnowledgeChunk(
+             kb_id=kb.id, chunk_index=idx, chunk_text=chunk_text, embedding=emb
+         )
+         session.add(chunk)
+         chunk_objs.append(chunk)
+
+     await session.commit()
+
+     return {"kb_id": kb.id, "name": kb_name, "chunks_stored": len(chunk_objs)}
src/chatbot/utils.py CHANGED
@@ -0,0 +1,57 @@
+ import re
+ from typing import List
+ import PyPDF2
+
+
+ def clean_text(text: str) -> str:
+     text = re.sub(r'\s+', ' ', text)
+     text = re.sub(r'\s+([,.!?;:])', r'\1', text)
+     text = re.sub(r'[_\-]{2,}', ' ', text)
+     text = re.sub(r'\.{2,}', '.', text)
+     text = re.sub(r'\s{2,}', ' ', text)
+     return text.strip()
+
+
+ def extract_text_from_pdf_fileobj(fileobj) -> str:
+     reader = PyPDF2.PdfReader(fileobj)
+     all_text = []
+     for page in reader.pages:
+         page_text = page.extract_text()
+         if page_text:
+             all_text.append(page_text)
+     return clean_text(" ".join(all_text))
+
+
+ def split_into_sentences(text: str) -> List[str]:
+     sentence_endings = re.compile(r'(?<=[.!?])\s+')
+     sentences = sentence_endings.split(text)
+     return [s.strip() for s in sentences if s.strip()]
+
+
+ def chunk_sentences_with_overlap(sentences: List[str], max_words: int = 200, overlap_words: int = 40) -> List[str]:
+     chunks = []
+     current = []
+     current_len = 0
+
+     for sentence in sentences:
+         words = sentence.split()
+         wc = len(words)
+
+         if current_len + wc > max_words and current:
+             chunks.append(" ".join(current))
+
+             if overlap_words > 0:
+                 last_words = " ".join(" ".join(current).split()[-overlap_words:])
+                 current = [last_words] if last_words else []
+                 current_len = len(last_words.split())
+             else:
+                 current = []
+                 current_len = 0
+
+         current.append(sentence)
+         current_len += wc
+
+     if current:
+         chunks.append(" ".join(current))
+
+     return chunks
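A small worked example of the chunking helpers above, with deliberately tiny limits so the overlap is visible (the real defaults are 200 words per chunk with a 40-word overlap):

```python
from src.chatbot.utils import chunk_sentences_with_overlap, split_into_sentences

text = (
    "Employees accrue one leave day per month. "
    "Unused days carry over to the next year. "
    "Carry-over is capped at ten days. "
    "Requests must be approved by a manager."
)

sentences = split_into_sentences(text)  # 4 sentences of 6-8 words each

# Force a new chunk after ~15 words, carrying the last 5 words forward.
chunks = chunk_sentences_with_overlap(sentences, max_words=15, overlap_words=5)

for i, chunk in enumerate(chunks):
    print(i, chunk)
# Every chunk after the first starts with the last 5 words of the previous one,
# so context at chunk boundaries is preserved for embedding.
```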
src/main.py CHANGED
@@ -1,12 +1,12 @@
- from src.assets.router import router as assets
- from src.profile.router import router as profile
  from fastapi import FastAPI

+ from src.assets.router import router as assets
  from src.auth.router import router as auth_router
+ from src.chatbot.router import router as chatbot
  from src.core.database import init_db
  from src.home.router import router as home_router
- from src.auth.router import router as auth_router
  from src.leave.router import router as leave
+ from src.profile.router import router as profile

  app = FastAPI(title="Yuvabe App API")

@@ -22,6 +22,8 @@ app.include_router(assets)

  app.include_router(leave)

+ app.include_router(chatbot)
+

  @app.get("/")
  def root():