Ahmad-01 committed
Commit a57e772 · verified · 1 Parent(s): dda1c90

Update app.py

Files changed (1):
  1. app.py +16 -20

app.py CHANGED
@@ -22,11 +22,9 @@ def extract_docs(ds):
     """Extract clean text documents safely from the PubMedQA dataset."""
     docs = []
     for e in ds:
-        # Case 1: if entry is a dictionary
         if isinstance(e, dict):
             ctx = e.get("context", "")
             if isinstance(ctx, dict):
-                # Nested dict with list of contexts
                 text = ctx.get("contexts", [""])
                 if isinstance(text, list):
                     docs.append(" ".join(map(str, text)))
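The two deleted comments described the nesting this helper guards against. For reference, here is a minimal, hypothetical PubMedQA-style entry (a "context" dict holding a "contexts" list) and the string extract_docs produces for it; the sample data below is illustrative and not part of the commit:

# Hypothetical sample in the nested PubMedQA layout the function expects.
sample = [{"context": {"contexts": ["Aspirin inhibits platelet aggregation.",
                                    "It is widely used in cardiology."]}}]
print(extract_docs(sample))
# ['Aspirin inhibits platelet aggregation. It is widely used in cardiology.']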
@@ -34,41 +32,39 @@ def extract_docs(ds):
                 docs.append(str(text))
             else:
                 docs.append(str(ctx))
-        # Case 2: if entry is already a string
         elif isinstance(e, str):
             docs.append(e)
         else:
             docs.append(str(e))
     return docs
 
-# Extract a small subset for demo (fast loading)
 documents = extract_docs(dataset["train"][:500])
-print(f"✅ Loaded {len(documents)} documents.")
+print(f"✅ Loaded {len(documents)} biomedical documents.")
 
 # ------------------------------
-# Step 2. Build embeddings
+# Step 2. Build embeddings (Biomedical)
 # ------------------------------
-print("🔍 Building embeddings...")
-embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
+print("🔍 Building biomedical embeddings...")
+embed_model = SentenceTransformer("pritamdeka/S-PubMedBert-MS-MARCO")
 embeddings = embed_model.encode(documents, show_progress_bar=True)
 embeddings = np.array(embeddings).astype("float32")
 
 index = faiss.IndexFlatL2(embeddings.shape[1])
 index.add(embeddings)
-print("✅ FAISS index built.")
+print("✅ FAISS index built with biomedical embeddings.")
 
 # ------------------------------
-# Step 3. Load generation model
+# Step 3. Load biomedical generation model
 # ------------------------------
-print("⚙️ Loading text generation model...")
-tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
-gen_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")
+print("⚙️ Loading biomedical text generation model...")
+tokenizer = AutoTokenizer.from_pretrained("allenai/biomed-flan-t5-base")
+gen_model = AutoModelForSeq2SeqLM.from_pretrained("allenai/biomed-flan-t5-base")
 
 # ------------------------------
-# Step 4. Define RAG answer function
+# Step 4. Define RAG function
 # ------------------------------
 def rag_answer(question, k=3, max_new_tokens=256):
-    """Retrieve top-k relevant chunks and generate an answer."""
+    """Retrieve top-k relevant biomedical passages and generate an answer."""
     if not question.strip():
         return "Please enter a question.", ""
 
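The hunk above swaps the general-purpose MiniLM encoder and google/flan-t5-base for biomedical variants, but the body of rag_answer (old lines 75-91) sits outside this diff. Below is a minimal sketch of how a retrieve-then-generate step could be written against the objects defined above; the function name, prompt wording, and truncation settings are assumptions for illustration, not the committed code:

# Illustrative sketch only; assumes embed_model, index, tokenizer,
# gen_model, and documents from the code above.
def rag_answer_sketch(question, k=3, max_new_tokens=256):
    # Embed the query with the same model used to build the index.
    q_emb = np.array(embed_model.encode([question])).astype("float32")
    # Retrieve the k nearest passages by L2 distance.
    _, ids = index.search(q_emb, k)
    passages = [documents[i] for i in ids[0]]
    # Pack the retrieved context and the question into one prompt.
    prompt = ("Answer the question using the context.\n\nContext:\n"
              + "\n".join(passages) + f"\n\nQuestion: {question}")
    inputs = tokenizer(prompt, return_tensors="pt",
                       truncation=True, max_length=512)
    output = gen_model.generate(**inputs, max_new_tokens=max_new_tokens)
    answer = tokenizer.decode(output[0], skip_special_tokens=True)
    return answer, "\n\n".join(passages)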
@@ -92,18 +88,18 @@ def ask(question, k, max_tokens):
     answer, sources = rag_answer(question, k, max_tokens)
     return answer, sources
 
-with gr.Blocks(title="🏥 MedQuery AI — Healthcare Knowledge Assistant") as demo:
+with gr.Blocks(title="🏥 MedQuery AI — Biomedical RAG Assistant") as demo:
     gr.Markdown(
         """
-        # 🏥 MedQuery AI — Healthcare Knowledge Assistant
-        Ask any **clinical or biomedical question**, and the app retrieves relevant PubMed data
-        and generates concise, evidence-based answers using Retrieval-Augmented Generation (RAG).
+        # 🏥 MedQuery AI — Biomedical Knowledge Assistant
+        This app retrieves relevant PubMed-style passages and generates concise,
+        **evidence-based biomedical answers** using Retrieval-Augmented Generation (RAG).
         """
     )
 
     with gr.Row():
         question = gr.Textbox(
-            label="Ask a medical question",
+            label="Ask a biomedical or clinical question",
             placeholder="e.g. What are the diagnostic criteria for hypertension?"
         )
     with gr.Row():
 
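The diff ends before these components are connected to ask. A hypothetical wiring, assuming slider controls for k and max_tokens to match ask's three-argument signature; the component names below do not come from the commit:

# Illustrative UI wiring only; assumes the ask(question, k, max_tokens)
# helper shown in the last hunk.
with gr.Blocks() as demo_sketch:
    question = gr.Textbox(label="Ask a biomedical or clinical question")
    k = gr.Slider(1, 10, value=3, step=1, label="Passages to retrieve (k)")
    max_tokens = gr.Slider(32, 512, value=256, step=32, label="Max new tokens")
    answer = gr.Textbox(label="Answer")
    sources = gr.Textbox(label="Retrieved sources")
    gr.Button("Ask").click(ask, inputs=[question, k, max_tokens],
                           outputs=[answer, sources])

demo_sketch.launch()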