Spaces:
Runtime error
Runtime error
File size: 4,146 Bytes
5ce7f98 caa8115 5ce7f98 dda1c90 5ce7f98 dda1c90 caa8115 dda1c90 5ce7f98 dda1c90 5ce7f98 caa8115 a57e772 5ce7f98 a57e772 5ce7f98 a57e772 5ce7f98 a57e772 5ce7f98 a57e772 5ce7f98 a57e772 5ce7f98 a57e772 5ce7f98 a57e772 caa8115 5ce7f98 caa8115 5ce7f98 a57e772 caa8115 a57e772 caa8115 5ce7f98 caa8115 a57e772 caa8115 5ce7f98 caa8115 5ce7f98 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
# app.py
import gradio as gr
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import faiss
import numpy as np
import nltk
# ------------------------------
# Step 0. NLTK setup
# ------------------------------
# Download the "punkt" NLTK resource at startup; quiet=True suppresses the
# download log output.
# NOTE(review): no other use of nltk is visible in this file — confirm the
# download is still required.
nltk.download("punkt", quiet=True)
# ------------------------------
# Step 1. Load dataset
# ------------------------------
print("π Loading PubMedQA dataset...")
# Loads the "pqa_labeled" configuration of the "pubmed_qa" dataset; the
# result is indexed by split name further down (dataset["train"]).
dataset = load_dataset("pubmed_qa", "pqa_labeled")
def _entry_to_text(entry):
    """Flatten one dataset entry into a single document string."""
    if not isinstance(entry, dict):
        # Plain strings pass through unchanged; anything else is stringified.
        return str(entry)
    ctx = entry.get("context", "")
    if not isinstance(ctx, dict):
        return str(ctx)
    passages = ctx.get("contexts", [""])
    if isinstance(passages, list):
        return " ".join(map(str, passages))
    return str(passages)


def extract_docs(ds):
    """Extract clean text documents safely from the PubMedQA dataset.

    Args:
        ds: Either an iterable of records (dicts with a ``"context"`` field,
            plain strings, or arbitrary objects), a column-oriented batch
            (a dict whose ``"context"`` value is a list, which is what
            slicing a Hugging Face ``Dataset`` returns), or a single record
            dict.

    Returns:
        A list with one string per record. For dict records whose
        ``"context"`` is itself a dict, the ``"contexts"`` passages are
        joined with single spaces.
    """
    if isinstance(ds, dict):
        column = ds.get("context")
        if isinstance(column, list):
            # Column-oriented batch: iterating the dict directly would yield
            # the column *names*, so rebuild one record per context entry.
            ds = [{"context": ctx} for ctx in column]
        else:
            # A lone record rather than a collection of records.
            ds = [ds]
    return [_entry_to_text(entry) for entry in ds]
# Select rows explicitly: slicing a Dataset (``[:500]``) returns a dict of
# columns, so iterating it would yield column names instead of records.
train_split = dataset["train"]
documents = extract_docs(train_split.select(range(min(500, len(train_split)))))
print(f"✅ Loaded {len(documents)} biomedical documents.")
# ------------------------------
# Step 2. Build embeddings (Biomedical)
# ------------------------------
print("π Building biomedical embeddings...")
embed_model = SentenceTransformer("pritamdeka/S-PubMedBert-MS-MARCO")
# FAISS works on float32 matrices, so convert the encoder output up front.
embeddings = np.asarray(
    embed_model.encode(documents, show_progress_bar=True), dtype="float32"
)
# Exact (brute-force) L2 index sized to the embedding dimension.
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)
print("✅ FAISS index built with biomedical embeddings.")
# ------------------------------
# Step 3. Load biomedical generation model
# ------------------------------
print("⚙️ Loading biomedical text generation model...")
# NOTE(review): confirm that this model id exists on the Hugging Face Hub and
# is a seq2seq checkpoint — from_pretrained raises at startup otherwise.
tokenizer = AutoTokenizer.from_pretrained("allenai/biomed-flan-t5-base")
gen_model = AutoModelForSeq2SeqLM.from_pretrained("allenai/biomed-flan-t5-base")
# ------------------------------
# Step 4. Define RAG function
# ------------------------------
def rag_answer(question, k=3, max_new_tokens=256):
    """Retrieve top-k relevant biomedical passages and generate an answer.

    Args:
        question: Free-text question. ``None``, empty, or whitespace-only
            input short-circuits with a prompt message.
        k: Number of passages to retrieve. Coerced to ``int`` (UI sliders may
            deliver floats) and clamped to at least 1 and at most the number
            of indexed documents.
        max_new_tokens: Upper bound on generated tokens; coerced to ``int``.

    Returns:
        A ``(answer, sources)`` tuple of strings, where ``sources`` is the
        retrieved passages separated by ``---`` lines.
    """
    if not question or not question.strip():
        return "Please enter a question.", ""

    # FAISS needs an integer k; never request more neighbours than exist.
    k = max(1, min(int(k), len(documents)))
    query_vec = embed_model.encode([question])
    _, indices = index.search(query_vec.astype("float32"), k)
    # FAISS pads missing neighbours with -1, which would otherwise silently
    # pick the last document through negative indexing.
    retrieved = [documents[i] for i in indices[0] if i >= 0]

    context = "\n".join(retrieved)
    prompt = f"Question: {question}\n\nContext:\n{context}\n\nAnswer:"
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    outputs = gen_model.generate(**inputs, max_new_tokens=int(max_new_tokens))
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return answer, "\n\n---\n".join(retrieved)
# ------------------------------
# Step 5. Gradio Interface
# ------------------------------
def ask(question, k, max_tokens):
    """UI callback: run the RAG pipeline and return (answer, sources)."""
    generated, passages = rag_answer(question, k, max_tokens)
    return (generated, passages)
# Build the Blocks UI: a question box, two tuning sliders, a submit button,
# and two output boxes wired to ask().
# NOTE(review): the original indentation was lost; the output boxes are placed
# below the button row rather than inside it — confirm the intended layout.
with gr.Blocks(title="🏥 MedQuery AI — Biomedical RAG Assistant") as demo:
    gr.Markdown(
        """
# 🏥 MedQuery AI — Biomedical Knowledge Assistant
This app retrieves relevant PubMed-style passages and generates concise,
**evidence-based biomedical answers** using Retrieval-Augmented Generation (RAG).
"""
    )
    with gr.Row():
        question = gr.Textbox(
            label="Ask a biomedical or clinical question",
            placeholder="e.g. What are the diagnostic criteria for hypertension?"
        )
    with gr.Row():
        k = gr.Slider(1, 8, step=1, value=3, label="Top-K passages to retrieve")
        max_tokens = gr.Slider(64, 512, step=32, value=256, label="Max tokens for answer")
    with gr.Row():
        submit = gr.Button("Get Answer")
    answer = gr.Textbox(label="AI Answer", lines=4)
    sources = gr.Textbox(label="Retrieved Context", lines=10)
    # Event listeners must be registered inside the Blocks context.
    submit.click(ask, inputs=[question, k, max_tokens], outputs=[answer, sources])
demo.launch()
|