import gradio as gr
import cv2
import time
import logging
import os
import re
import pickle
import json
import pandas as pd
from symspellpy import SymSpell, Verbosity
from rapidocr import RapidOCR, EngineType, LangCls, LangDet, LangRec, ModelType, OCRVersion
import numpy as np
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ============================================================================
# Database Initialization (from src/database/init.py)
# ============================================================================
def initialize_database():
    # Assuming data/Dataset.csv is relative to the current script.
    # Adjust the path if necessary: app.py is in the root, data is in ./data.
    data_path = os.path.join(os.path.dirname(__file__), "data/Dataset.csv")
    try:
        df = pd.read_csv(data_path, encoding='utf-8')
    except UnicodeDecodeError:
        df = pd.read_csv(data_path, encoding='latin1')
    drug_db = set(df["Combined_Drugs"].astype(str).str.lower().str.strip())

    sym_spell = SymSpell(max_dictionary_edit_distance=2, prefix_length=7)
    for drug in drug_db:
        sym_spell.create_dictionary_entry(drug, 100000)
        # Also index the individual words of multi-word drug names
        parts = drug.split()
        if len(parts) > 1:
            for p in parts:
                sym_spell.create_dictionary_entry(p, 100000)

    # Map each token to the set of full drug names that contain it
    drug_token_index = {}
    for full in drug_db:
        for tok in full.split():
            drug_token_index.setdefault(tok, set()).add(full)

    ANCHOR_PREFIXES = ["tab", "cap"]
    ANCHORS = [r"tab\.?", r"cap\.?"]
    ANCHOR_PATTERN = re.compile(r"\b(" + "|".join(ANCHORS) + r")\b", re.IGNORECASE)

    return {
        'drug_db': drug_db,
        'sym_spell': sym_spell,
        'drug_token_index': drug_token_index,
        'ANCHOR_PREFIXES': ANCHOR_PREFIXES,
        'ANCHOR_PATTERN': ANCHOR_PATTERN,
    }
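# Illustrative shape of the structures built above (entries are hypothetical,
# not taken from the actual Dataset.csv):
#   drug_db          = {"clopitab 75", "sobisistab", ...}
#   drug_token_index = {"clopitab": {"clopitab 75"}, "75": {"clopitab 75"}, ...}
# drug_token_index lets a single recognized token recover the full drug name.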
# Initialize Database Globally
logger.info("Initializing database...")
cache_path = os.path.join(os.path.dirname(__file__), "cache/database_cache.pkl")
try:
    with open(cache_path, 'rb') as f:
        cache = pickle.load(f)
    drug_db = cache['drug_db']
    sym_spell = cache['sym_spell']
    drug_token_index = cache['drug_token_index']
    ANCHOR_PREFIXES = cache['ANCHOR_PREFIXES']
    ANCHOR_PATTERN = cache['ANCHOR_PATTERN']
    logger.info("Database loaded from cache.")
except FileNotFoundError:
    logger.info("Cache not found. Initializing from CSV...")
    cache = initialize_database()
    drug_db = cache['drug_db']
    sym_spell = cache['sym_spell']
    drug_token_index = cache['drug_token_index']
    ANCHOR_PREFIXES = cache['ANCHOR_PREFIXES']
    ANCHOR_PATTERN = cache['ANCHOR_PATTERN']
    # Save cache for faster startup next time
    os.makedirs(os.path.dirname(cache_path), exist_ok=True)
    with open(cache_path, 'wb') as f:
        pickle.dump(cache, f)
    logger.info("Database initialized and cached.")
# ============================================================================
# Helper Functions (from src/utils/drug_matching.py)
# ============================================================================
def is_potential_med_line(text: str, ANCHOR_PATTERN) -> bool:
    # User requested ONLY the anchor check (TAB/CAP markers)
    return ANCHOR_PATTERN.search(text.lower()) is not None
def validate_drug_match(term: str, drug_db, drug_token_index):
    """
    Map a SymSpell term to a canonical database drug, or None if noise.
    """
    if term in drug_db:
        return term
    if term in drug_token_index:
        # Pick one canonical name; change the selection logic if needed
        return sorted(drug_token_index[term])[0]
    return None
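# Illustrative example (assumes "clopitab 75" is an entry in the database;
# actual contents will differ):
#   validate_drug_match("clopitab", drug_db, drug_token_index)
#   -> "clopitab 75"   (token hit resolved via drug_token_index)
#   validate_drug_match("xyzzy", drug_db, drug_token_index)
#   -> None            (noise)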
def normalize_anchored_tokens(raw_text: str, ANCHOR_PREFIXES):
    """
    Use TAB/CAP/T. as anchors, not something to delete:
    - 'TABCLOPITAB75MG TAB'   -> ['clopitab']
    - 'TAB SOBISISTAB'        -> ['sobisistab']
    - 'TABSTARPRESSXL25MGTAB' -> ['starpressxl']
    """
    t = raw_text.lower()
    # Remove dosages and bare numbers but keep anchor letters
    t = re.sub(r"\d+\s*(mg|ml|gm|%|u|mcg)", " ", t)
    t = re.sub(r"\d+", " ", t)
    # Remove punctuation (including full-width parentheses)
    t = re.sub(r"[^\w\s]", " ", t)
    tokens = t.split()

    normalized = []
    skip_next = False
    for i, tok in enumerate(tokens):
        if skip_next:
            skip_next = False
            continue
        base = tok
        # Case 1: token starts with an anchor as prefix (no space)
        for pref in ANCHOR_PREFIXES:
            if base.startswith(pref) and len(base) > len(pref):
                base = base[len(pref):]
                break
        # Case 2: token is a pure anchor and should attach to the next token
        if base in ["tab", "cap", "t"]:
            if i + 1 < len(tokens):
                merged = tokens[i + 1]
                for pref in ANCHOR_PREFIXES:
                    if merged.startswith(pref) and len(merged) > len(pref):
                        merged = merged[len(pref):]
                        break
                base = merged
                skip_next = True
            else:
                continue
        normalized.append(base.strip())
    return normalized
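# Worked example (from the docstring cases): "TABCLOPITAB75MG TAB" lowercases
# to "tabclopitab75mg tab"; the dosage regex strips "75mg", leaving tokens
# ["tabclopitab", "tab"]; Case 1 strips the "tab" prefix from the first token,
# and the trailing pure anchor "tab" has no next token to attach to, so the
# result is ["clopitab"].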
# ============================================================================
# OCR Processor (from src/ocr/processor.py)
# ============================================================================
# Initialize OCR Engine Globally
logger.info("Initializing OCR engine...")
ocr_engine = RapidOCR(
    params={
        "Global.max_side_len": 2000,
        "Det.engine_type": EngineType.ONNXRUNTIME,
        "Det.lang_type": LangDet.EN,
        "Det.model_type": ModelType.MOBILE,
        "Det.ocr_version": OCRVersion.PPOCRV4,
        "Cls.engine_type": EngineType.ONNXRUNTIME,
        "Cls.lang_type": LangCls.CH,
        "Cls.model_type": ModelType.MOBILE,
        "Cls.ocr_version": OCRVersion.PPOCRV4,
        "Rec.engine_type": EngineType.ONNXRUNTIME,
        "Rec.lang_type": LangRec.EN,
        "Rec.model_type": ModelType.MOBILE,
        "Rec.ocr_version": OCRVersion.PPOCRV4,
    }
)
logger.info("OCR engine initialized.")
def process_image_ocr(image_input):
    # Load the image with cv2 if given a path; otherwise use the array directly
    if isinstance(image_input, str):
        img = cv2.imread(image_input)
        if img is None:
            raise ValueError(f"Could not load image from {image_input}")
    else:
        img = image_input

    # Run OCR
    ocr_result = ocr_engine(
        img,
        use_det=True,
        use_cls=True,
        use_rec=True,
        text_score=0.3,
        box_thresh=0.3,
        unclip_ratio=2.0,
        return_word_box=False,
    )
    # Guard: fall back to an empty tuple if no text was detected
    ocr_data = ocr_result.txts if ocr_result.txts is not None else ()

    found_meds_with_originals = {}
    for item in ocr_data:
        text_lower = item.lower()
        # Simplified line-level gate: ONLY check anchors
        if not is_potential_med_line(text_lower, ANCHOR_PATTERN):
            continue
        # Skip doctor name lines
        if "dr." in text_lower or "dr " in text_lower:
            continue

        # Anchor-aware tokens
        candidate_tokens = normalize_anchored_tokens(item, ANCHOR_PREFIXES)
        # Save the original normalized text for exact-match checking
        normalized_text_str = " ".join(candidate_tokens)

        # Optional SymSpell segmentation on the normalized tokens
        if candidate_tokens:
            segmentation = sym_spell.word_segmentation(" ".join(candidate_tokens))
            candidate_tokens = segmentation.corrected_string.split()

        line_matches = []
        i = 0
        n = len(candidate_tokens)
        while i < n:
            match_found = False
            # Greedy longest match: try phrases of length 5 down to 1
            for length in range(min(5, n - i), 0, -1):
                phrase = " ".join(candidate_tokens[i : i + length])
                # Check the exact phrase against the DB
                if phrase in drug_db:
                    # Found a multi-word (or single-word) drug match
                    if phrase in normalized_text_str:
                        line_matches.append((phrase, "exact", phrase))
                    else:
                        line_matches.append((phrase, "fuzzy", phrase))
                    i += length
                    match_found = True
                    break
            if match_found:
                continue

            # Fallback: single-token processing (fuzzy / partial)
            word = candidate_tokens[i]
            i += 1

            # Check for an exact match first (as a single token)
            canonical = validate_drug_match(word, drug_db, drug_token_index)
            if canonical:
                # Coverage check: the detected word must cover a significant
                # portion of the canonical name
                if len(word) / len(canonical) < 0.6:
                    continue
                if word in normalized_text_str:
                    line_matches.append((canonical, "exact", word))
                else:
                    line_matches.append((canonical, "fuzzy", word))
                continue

            # Fuzzy matching
            if len(word) < 3:
                continue
            suggestions = sym_spell.lookup(
                word, Verbosity.CLOSEST, max_edit_distance=1
            )
            if not suggestions:
                continue
            cand = suggestions[0].term
            canonical = validate_drug_match(cand, drug_db, drug_token_index)
            if canonical:
                # Coverage check for the fuzzy match too
                if len(word) / len(canonical) < 0.6:
                    continue
                line_matches.append((canonical, "fuzzy", word))

        # Filter matches for this line: prefer exact matches when any exist
        exact_matches = [m for m in line_matches if m[1] == "exact"]
        final_matches = exact_matches if exact_matches else line_matches

        for canonical, _match_type, _matched_word in final_matches:
            found_meds_with_originals.setdefault(canonical, [])
            if item not in found_meds_with_originals[canonical]:
                found_meds_with_originals[canonical].append(item)

    return found_meds_with_originals
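# Illustrative return shape (drug names here are hypothetical, not from the
# actual database):
#   {"clopitab": ["TAB CLOPITAB 75MG"], "sobisistab": ["TAB SOBISISTAB"]}
# i.e. canonical drug name -> list of OCR lines on which it was found.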
# ============================================================================
# Gradio Interface
# ============================================================================
def process_input(image_input):
    """Gradio interface handler."""
    if image_input is None:
        return "Please upload an image.", {}
    try:
        # Convert RGB (Gradio) to BGR (OpenCV)
        image_bgr = cv2.cvtColor(image_input, cv2.COLOR_RGB2BGR)

        start_time = time.time()
        # Use the robust processor with the in-memory image
        found_meds_dict = process_image_ocr(image_bgr)
        elapsed_time = time.time() - start_time

        drugs_list = sorted(found_meds_dict.keys())
        drugs_count = len(drugs_list)

        # Summary text
        summary = f"Found {drugs_count} medication(s) in {elapsed_time:.3f}s"

        # JSON output with all medications
        medications_json = {
            "total_medications": drugs_count,
            "processing_time": f"{elapsed_time:.3f}s",
            "medications": [
                {
                    "id": idx + 1,
                    "name": drug.title(),
                    "original_text": found_meds_dict[drug],
                }
                for idx, drug in enumerate(drugs_list)
            ],
        }
        return summary, medications_json
    except Exception as e:
        logger.error(f"Processing error: {e}")
        return f"Error: {str(e)}", {}
logger.info("Starting ...")
with gr.Blocks(title="Fast OCR") as demo:
    gr.Markdown("# Prescription OCR")
    gr.Markdown("Upload a prescription image to extract medications.")
    gr.Markdown("**Optimized with MOBILE models + ONNX Runtime for maximum speed**")

    with gr.Row():
        image_input = gr.Image(type="numpy", label="Upload Prescription")
        output_text = gr.Textbox(label="Summary", lines=2)
    with gr.Row():
        medications_json = gr.JSON(label="Extracted Medications")

    submit_btn = gr.Button("Extract Medications", variant="primary")
    submit_btn.click(process_input, inputs=image_input, outputs=[output_text, medications_json])

    gr.Examples(
        examples=[
            ["examples/test1.jpeg"],
            ["examples/test2.jpeg"],
            ["examples/test3.jpeg"],
            ["examples/test4.jpeg"],
            ["examples/test5.jpeg"],
            ["examples/test6.png"],
            ["examples/test7.png"],
        ],
        inputs=image_input,
        outputs=[output_text, medications_json],
        fn=process_input,
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.queue(max_size=10)
    demo.launch(max_threads=4, show_error=True)