# NOTE(review): removed Hugging Face Spaces file-page artifacts (status lines,
# file size, commit hashes, line-number gutter) that were pasted above the
# module and made it syntactically invalid Python.
import os
import shutil
import tempfile
from typing import Optional
from uuid import UUID
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
from sqlalchemy import text
from sqlmodel.ext.asyncio.session import AsyncSession
from src.auth.utils import get_current_user
from src.core.database import get_async_session
from .schemas import ManualTextRequest
from .service import store_manual_text
from .embedding import embedding_model
from .schemas import (
SemanticSearchRequest,
SemanticSearchResult,
TokenizeRequest,
TokenizeResponse,
UploadKBResponse,
)
from .service import process_pdf_and_store
# All chatbot endpoints below are mounted under the /chatbot URL prefix.
router = APIRouter(prefix="/chatbot", tags=["chatbot"])
@router.post("/tokenize", response_model=TokenizeResponse)
async def tokenize_text(
    payload: TokenizeRequest,
    user_id: UUID = Depends(get_current_user),
):
    """Tokenize ``payload.text`` with the shared embedding model's tokenizer.

    The text is truncated/padded to at most 512 tokens; the input IDs and
    attention mask of the single encoded sequence are returned as lists.

    Raises:
        HTTPException: 500 if tokenization fails for any reason.
    """
    try:
        encoded = embedding_model.tokenizer(
            payload.text,
            return_tensors="np",
            truncation=True,
            padding="longest",
            max_length=512,
        )
        return TokenizeResponse(
            input_ids=encoded["input_ids"][0].tolist(),
            attention_mask=encoded["attention_mask"][0].tolist(),
        )
    except Exception as e:
        # Chain the original exception (raise ... from) so server logs keep
        # the real root cause instead of an unrelated secondary traceback.
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/semantic-search", response_model=list[SemanticSearchResult])
async def semantic_search(
    payload: SemanticSearchRequest,
    session: AsyncSession = Depends(get_async_session),
    user_id: UUID = Depends(get_current_user),
):
    """Return the top-k knowledge chunks nearest to a precomputed embedding.

    Ranking uses pgvector's ``<#>`` (negative inner product) operator, so
    smaller scores mean closer matches.

    Raises:
        HTTPException: 400 if the embedding is empty or top_k is negative.
    """
    if not payload.embedding:
        raise HTTPException(status_code=400, detail="Embedding cannot be empty.")
    # Default to 3 results when top_k is unset (or 0, since 0 is falsy).
    top_k = payload.top_k or 3
    if top_k < 1:
        # Reject negative values here rather than letting the database fail
        # on a negative LIMIT with an opaque 500.
        raise HTTPException(status_code=400, detail="top_k must be positive.")
    # pgvector accepts a bracketed text literal ("[0.1,0.2,...]") for the
    # query vector; it is still passed as a bound parameter, not interpolated.
    q_vector_str = "[" + ",".join(str(x) for x in payload.embedding) + "]"
    sql = text(
        """
        SELECT id, kb_id, chunk_text,image_url,
               embedding <#> :query_vec AS score
        FROM knowledge_chunk
        ORDER BY embedding <#> :query_vec ASC
        LIMIT :top_k
        """
    )
    result = await session.execute(
        sql, {"query_vec": q_vector_str, "top_k": top_k}
    )
    rows = result.fetchall()
    return [
        SemanticSearchResult(
            chunk_id=str(r.id),
            kb_id=str(r.kb_id),
            text=r.chunk_text,
            image_url=r.image_url,
            score=float(r.score),
        )
        for r in rows
    ]
# Before enabling the endpoints below, make sure model.data and model.onnx_data are available in the asset/onnx folder.
# @router.post("/upload-pdf", response_model=UploadKBResponse)
# async def upload_pdf(
# file: UploadFile = File(...),
# name: str = Form(...),
# description: Optional[str] = Form(None),
# session: AsyncSession = Depends(get_async_session),
# ):
# if not file.filename.endswith(".pdf"):
# raise HTTPException(
# status_code=400, detail="Only PDF files are supported for now."
# )
# tmp_dir = tempfile.mkdtemp()
# tmp_path = os.path.join(tmp_dir, file.filename)
# try:
# with open(tmp_path, "wb") as out_f:
# shutil.copyfileobj(file.file, out_f)
# with open(tmp_path, "rb") as fobj:
# result = await process_pdf_and_store(fobj, name, description, session)
# return UploadKBResponse(
# kb_id=result["kb_id"],
# name=result["name"],
# chunks_stored=result["chunks_stored"],
# )
# finally:
# try:
# os.remove(tmp_path)
# os.rmdir(tmp_dir)
# except Exception:
# pass
# @router.post("/manual-add-chunk")
# async def manual_add_chunk(
# payload: ManualTextRequest,
# session: AsyncSession = Depends(get_async_session)
# ):
# return await store_manual_text(
# kb_id=payload.kb_id,
# text=payload.text,
# session=session
# )
# @router.post("/test-semantic", response_model=list[SemanticSearchResult])
# async def test_semantic(
# query: str,
# top_k: int = 3,
# session: AsyncSession = Depends(get_async_session)
# ):
# embedding = await embedding_model.embed_text(query)
# q_vec = "[" + ",".join(map(str, embedding)) + "]"
# sql = text("""
# SELECT id, kb_id, chunk_text,
# embedding <#> :vec AS score
# FROM knowledge_chunk
# ORDER BY embedding <#> :vec ASC
# LIMIT :k
# """)
# result = await session.execute(sql, {"vec": q_vec, "k": top_k})
# rows = result.fetchall()
# return [
# SemanticSearchResult(
# chunk_id=str(r.id),
# kb_id=str(r.kb_id),
# text=r.chunk_text,
# score=float(r.score),
# )
# for r in rows
# ]
|