File size: 4,146 Bytes
5ce7f98
 
 
 
 
 
 
 
 
caa8115
 
 
 
5ce7f98
 
 
 
 
 
 
 
dda1c90
5ce7f98
 
dda1c90
 
 
 
 
 
 
 
caa8115
dda1c90
 
 
5ce7f98
dda1c90
5ce7f98
 
caa8115
a57e772
5ce7f98
 
a57e772
5ce7f98
a57e772
 
5ce7f98
 
 
 
 
a57e772
5ce7f98
 
a57e772
5ce7f98
a57e772
 
 
5ce7f98
 
a57e772
5ce7f98
 
a57e772
caa8115
 
 
5ce7f98
 
 
 
 
 
 
 
 
 
 
 
 
 
caa8115
5ce7f98
 
 
 
 
a57e772
caa8115
 
a57e772
 
 
caa8115
 
5ce7f98
 
caa8115
a57e772
caa8115
 
5ce7f98
 
 
 
 
 
 
caa8115
5ce7f98
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# app.py
import gradio as gr
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import faiss
import numpy as np
import nltk

# ------------------------------
# Step 0. NLTK setup
# ------------------------------
# Quiet download so the Space log stays clean on repeated startups.
# NOTE(review): "punkt" is not referenced anywhere else in this file —
# confirm it is actually needed before removing.
nltk.download("punkt", quiet=True)

# ------------------------------
# Step 1. Load dataset
# ------------------------------
# Loads the "pqa_labeled" configuration of PubMedQA from the HF Hub
# (network access required at startup).
print("πŸ“š Loading PubMedQA dataset...")
dataset = load_dataset("pubmed_qa", "pqa_labeled")

def extract_docs(ds):
    """Extract clean text documents safely from the PubMedQA dataset.

    Accepts an iterable of examples. Dict examples have their "context"
    field flattened to a single string; string examples pass through;
    anything else is stringified. Always returns a list of str.
    """

    def _to_text(entry):
        # PubMedQA examples are dicts whose "context" holds a nested
        # {"contexts": [...]} structure; fall back gracefully otherwise.
        if isinstance(entry, dict):
            ctx = entry.get("context", "")
            if not isinstance(ctx, dict):
                return str(ctx)
            contexts = ctx.get("contexts", [""])
            if isinstance(contexts, list):
                return " ".join(str(c) for c in contexts)
            return str(contexts)
        if isinstance(entry, str):
            return entry
        return str(entry)

    return [_to_text(e) for e in ds]

# BUG FIX: slicing a HF Dataset (dataset["train"][:500]) returns a *dict of
# columns*, so iterating it yields column names, not examples — extract_docs
# would have produced a handful of column-name strings. .select() keeps the
# per-example dict structure extract_docs expects.
train_split = dataset["train"]
documents = extract_docs(train_split.select(range(min(500, len(train_split)))))
print(f"βœ… Loaded {len(documents)} biomedical documents.")

# ------------------------------
# Step 2. Build embeddings (Biomedical)
# ------------------------------
print("πŸ” Building biomedical embeddings...")
embed_model = SentenceTransformer("pritamdeka/S-PubMedBert-MS-MARCO")
embeddings = embed_model.encode(documents, show_progress_bar=True)
# FAISS requires contiguous float32 vectors.
embeddings = np.array(embeddings).astype("float32")

# Exact L2 flat index — brute-force search is cheap at this corpus size.
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)
print("βœ… FAISS index built with biomedical embeddings.")

# ------------------------------
# Step 3. Load biomedical generation model
# ------------------------------
# Encoder-decoder (seq2seq) checkpoint used by rag_answer() for generation.
# NOTE(review): verify "allenai/biomed-flan-t5-base" resolves on the HF Hub —
# a bad model id makes the app crash at startup.
print("βš™οΈ Loading biomedical text generation model...")
tokenizer = AutoTokenizer.from_pretrained("allenai/biomed-flan-t5-base")
gen_model = AutoModelForSeq2SeqLM.from_pretrained("allenai/biomed-flan-t5-base")

# ------------------------------
# Step 4. Define RAG function
# ------------------------------
def rag_answer(question, k=3, max_new_tokens=256):
    """Retrieve top-k relevant biomedical passages and generate an answer.

    Args:
        question: User question string (None/empty handled gracefully).
        k: Number of passages to retrieve; coerced to int because Gradio
           sliders may deliver floats, which faiss rejects.
        max_new_tokens: Generation budget, likewise coerced to int.

    Returns:
        Tuple of (answer text, retrieved passages joined by separators).
    """
    # Guard both None and whitespace-only input — Gradio can pass None.
    if not question or not question.strip():
        return "Please enter a question.", ""

    k = max(1, int(k))
    max_new_tokens = int(max_new_tokens)

    query_vec = embed_model.encode([question])
    scores, indices = index.search(np.asarray(query_vec, dtype="float32"), k)
    # FAISS pads results with -1 when fewer than k vectors exist; -1 would
    # silently index the *last* document, so filter it out explicitly.
    retrieved = [documents[i] for i in indices[0] if i >= 0]

    context = "\n".join(retrieved)
    prompt = f"Question: {question}\n\nContext:\n{context}\n\nAnswer:"

    # Truncate to the model's max input length rather than erroring out.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    outputs = gen_model.generate(**inputs, max_new_tokens=max_new_tokens)
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)

    return answer, "\n\n---\n".join(retrieved)

# ------------------------------
# Step 5. Gradio Interface
# ------------------------------
def ask(question, k, max_tokens):
    """Gradio click handler: delegate straight to the RAG pipeline."""
    return rag_answer(question, k, max_tokens)

# Build and launch the UI. Widget creation order inside the context manager
# determines on-page layout, so statement order matters here.
with gr.Blocks(title="πŸ₯ MedQuery AI β€” Biomedical RAG Assistant") as demo:
    # Static header/description rendered at the top of the page.
    gr.Markdown(
        """
        # πŸ₯ MedQuery AI β€” Biomedical Knowledge Assistant  
        This app retrieves relevant PubMed-style passages and generates concise,  
        **evidence-based biomedical answers** using Retrieval-Augmented Generation (RAG).  
        """
    )

    # Free-text question input.
    with gr.Row():
        question = gr.Textbox(
            label="Ask a biomedical or clinical question",
            placeholder="e.g. What are the diagnostic criteria for hypertension?"
        )
    # Retrieval/generation knobs; defaults mirror rag_answer's defaults (3, 256).
    with gr.Row():
        k = gr.Slider(1, 8, step=1, value=3, label="Top-K passages to retrieve")
        max_tokens = gr.Slider(64, 512, step=32, value=256, label="Max tokens for answer")

    with gr.Row():
        submit = gr.Button("Get Answer")

    # Two outputs: the generated answer and the raw retrieved passages.
    answer = gr.Textbox(label="AI Answer", lines=4)
    sources = gr.Textbox(label="Retrieved Context", lines=10)

    # ask() returns a 2-tuple that Gradio unpacks into [answer, sources].
    submit.click(ask, inputs=[question, k, max_tokens], outputs=[answer, sources])

demo.launch()