|
|
"""Prescription OCR demo: extract medication names from prescription images
using RapidOCR for text detection/recognition and SymSpell for fuzzy matching
against a drug-name database."""

import json
import logging
import os
import pickle
import re
import time

import cv2
import gradio as gr
import numpy as np
import pandas as pd
from rapidocr import RapidOCR, EngineType, LangCls, LangDet, LangRec, ModelType, OCRVersion
from symspellpy import SymSpell, Verbosity

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def initialize_database():
    """Build the drug lookup structures (drug set, SymSpell dictionary,
    token index, and anchor regex) from the bundled CSV dataset."""
    data_path = os.path.join(os.path.dirname(__file__), "data/Dataset.csv")
    try:
        df = pd.read_csv(data_path, encoding='utf-8')
    except UnicodeDecodeError:
        df = pd.read_csv(data_path, encoding='latin1')
    drug_db = set(df["Combined_Drugs"].astype(str).str.lower().str.strip())

    sym_spell = SymSpell(max_dictionary_edit_distance=2, prefix_length=7)

    # Register every drug name, and each word of multi-word names, so SymSpell
    # can correct OCR noise against the full vocabulary. drug_db is already
    # lowercased above.
    for drug in drug_db:
        sym_spell.create_dictionary_entry(drug, 100000)
        parts = drug.split()
        if len(parts) > 1:
            for p in parts:
                sym_spell.create_dictionary_entry(p, 100000)

    # Map each token back to the full drug names containing it, so a single
    # recognized word can be resolved to a canonical entry.
    drug_token_index = {}
    for full in drug_db:
        for tok in full.split():
            drug_token_index.setdefault(tok, set()).add(full)

    # Dosage-form anchors: lines mentioning tablets/capsules are treated as
    # medication lines.
    ANCHOR_PREFIXES = ["tab", "cap"]
    ANCHORS = [r"tab\.?", r"cap\.?"]
    ANCHOR_PATTERN = re.compile(r"\b(" + "|".join(ANCHORS) + r")\b", re.IGNORECASE)

    return {
        'drug_db': drug_db,
        'sym_spell': sym_spell,
        'drug_token_index': drug_token_index,
        'ANCHOR_PREFIXES': ANCHOR_PREFIXES,
        'ANCHOR_PATTERN': ANCHOR_PATTERN
    }
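
# Implementation note: everything in the dict above -- the SymSpell instance
# and the compiled regex included -- round-trips through pickle, which is what
# the cache below relies on.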
|
|
|
|
|
|
|
|
logger.info("Initializing database...") |
|
|
cache_path = os.path.join(os.path.dirname(__file__), "cache/database_cache.pkl") |
|
|
try: |
|
|
with open(cache_path, 'rb') as f: |
|
|
cache = pickle.load(f) |
|
|
drug_db = cache['drug_db'] |
|
|
sym_spell = cache['sym_spell'] |
|
|
drug_token_index = cache['drug_token_index'] |
|
|
ANCHOR_PREFIXES = cache['ANCHOR_PREFIXES'] |
|
|
ANCHOR_PATTERN = cache['ANCHOR_PATTERN'] |
|
|
logger.info("Database loaded from cache.") |
|
|
except FileNotFoundError: |
|
|
logger.info("Cache not found. Initializing from CSV...") |
|
|
cache = initialize_database() |
|
|
drug_db = cache['drug_db'] |
|
|
sym_spell = cache['sym_spell'] |
|
|
drug_token_index = cache['drug_token_index'] |
|
|
ANCHOR_PREFIXES = cache['ANCHOR_PREFIXES'] |
|
|
ANCHOR_PATTERN = cache['ANCHOR_PATTERN'] |
|
|
|
|
|
os.makedirs(os.path.dirname(cache_path), exist_ok=True) |
|
|
with open(cache_path, 'wb') as f: |
|
|
pickle.dump(cache, f) |
|
|
logger.info("Database initialized and cached.") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def is_potential_med_line(text: str, ANCHOR_PATTERN) -> bool:
    """Return True if the line contains a dosage-form anchor (tab/cap)."""
    # ANCHOR_PATTERN is compiled with re.IGNORECASE, so no lowercasing is needed.
    return bool(ANCHOR_PATTERN.search(text))
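
# Illustrative behavior (input lines are hypothetical OCR output):
#   is_potential_med_line("TAB. CLOPITAB 75MG", ANCHOR_PATTERN)  -> True
#   is_potential_med_line("Patient: John Doe", ANCHOR_PATTERN)   -> False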
|
|
|
|
|
def validate_drug_match(term: str, drug_db, drug_token_index):
    """Map a SymSpell term to a canonical database drug, or None if noise."""
    if term in drug_db:
        return term
    if term in drug_token_index:
        # Several full names may share a token; sort for a deterministic pick.
        return sorted(drug_token_index[term])[0]
    return None
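
# Illustrative lookups, assuming the database contains the (hypothetical)
# multi-word entry "clopitab 75":
#   validate_drug_match("clopitab 75", ...) -> "clopitab 75"  (exact entry)
#   validate_drug_match("clopitab", ...)    -> "clopitab 75"  (via token index)
#   validate_drug_match("xyzqt", ...)       -> None           (noise)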
|
|
|
|
|
def normalize_anchored_tokens(raw_text: str, ANCHOR_PREFIXES):
    """
    Use TAB/CAP/T. as anchors, not something to delete:
    - 'TABCLOPITAB75MG TAB'   -> ['clopitab']
    - 'TAB SOBISISTAB'        -> ['sobisistab']
    - 'TABSTARPRESSXL25MGTAB' -> ['starpressxl']
    """
    t = raw_text.lower()

    # Strip dosages ("75mg", "5ml", ...) and any remaining digits.
    t = re.sub(r"\d+\s*(mg|ml|gm|%|u|mcg)", " ", t)
    t = re.sub(r"\d+", " ", t)

    # Replace punctuation with spaces so 'tab.' splits cleanly.
    t = re.sub(r"[^\w\s]", " ", t)

    tokens = t.split()
    normalized = []
    skip_next = False

    for i, tok in enumerate(tokens):
        if skip_next:
            skip_next = False
            continue

        base = tok

        # Case 1: anchor fused onto the drug name ('tabclopitab' -> 'clopitab').
        for pref in ANCHOR_PREFIXES:
            if base.startswith(pref) and len(base) > len(pref):
                base = base[len(pref):]
                break

        # Case 2: standalone anchor; the drug name is the next token.
        if base in ["tab", "cap", "t"]:
            if i + 1 < len(tokens):
                merged = tokens[i + 1]
                for pref in ANCHOR_PREFIXES:
                    if merged.startswith(pref) and len(merged) > len(pref):
                        merged = merged[len(pref):]
                        break
                base = merged
                skip_next = True
            else:
                # Trailing anchor with nothing after it: drop it.
                continue

        base = base.strip()
        if base:
            normalized.append(base)

    return normalized
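
# Worked example on a hypothetical OCR line: "TAB. CLOPITAB 75MG"
#   after lowercasing and stripping dosage/digits/punctuation -> "tab  clopitab"
#   tokens ["tab", "clopitab"]: the bare anchor merges with the next token
#   result -> ["clopitab"]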
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logger.info("Initializing OCR engine...") |
|
|
ocr_engine = RapidOCR( |
|
|
params={ |
|
|
"Global.max_side_len": 2000, |
|
|
"Det.engine_type": EngineType.ONNXRUNTIME, |
|
|
"Det.lang_type": LangDet.EN, |
|
|
"Det.model_type": ModelType.MOBILE, |
|
|
"Det.ocr_version": OCRVersion.PPOCRV4, |
|
|
"Cls.engine_type": EngineType.ONNXRUNTIME, |
|
|
"Cls.lang_type": LangCls.CH, |
|
|
"Cls.model_type": ModelType.MOBILE, |
|
|
"Cls.ocr_version": OCRVersion.PPOCRV4, |
|
|
"Rec.engine_type": EngineType.ONNXRUNTIME, |
|
|
"Rec.lang_type": LangRec.EN, |
|
|
"Rec.model_type": ModelType.MOBILE, |
|
|
"Rec.ocr_version": OCRVersion.PPOCRV4, |
|
|
} |
|
|
) |
|
|
logger.info("OCR engine initialized.") |
|
|
|
|
|
def process_image_ocr(image_input):
    """Run OCR on an image (path or BGR array) and return a mapping of
    canonical drug name -> list of raw OCR lines it was found on."""
    if isinstance(image_input, str):
        img = cv2.imread(image_input)
        if img is None:
            raise ValueError(f"Could not load image from {image_input}")
    else:
        img = image_input

    ocr_result = ocr_engine(
        img,
        use_det=True,
        use_cls=True,
        use_rec=True,
        text_score=0.3,
        box_thresh=0.3,
        unclip_ratio=2.0,
        return_word_box=False,
    )

    # txts can be empty (or None) on blank images; normalize to a list.
    ocr_data = ocr_result.txts or []

    found_meds_with_originals = {}

    for item in ocr_data:
        text_lower = item.lower()

        # Only lines containing a tab/cap anchor are treated as medication lines.
        if not is_potential_med_line(text_lower, ANCHOR_PATTERN):
            continue

        # Skip lines that mention a doctor ("Dr."); they are not medications.
        if "dr." in text_lower or "dr " in text_lower:
            continue

        candidate_tokens = normalize_anchored_tokens(item, ANCHOR_PREFIXES)

        # Pre-segmentation text; tokens still present here count as "exact".
        normalized_text_str = " ".join(candidate_tokens)

        # Re-segment run-together tokens via SymSpell word segmentation.
        if candidate_tokens:
            segmentation = sym_spell.word_segmentation(" ".join(candidate_tokens))
            candidate_tokens = segmentation.corrected_string.split()

        line_matches = []
        i = 0
        n = len(candidate_tokens)

        while i < n:
            match_found = False

            # Greedy longest-phrase match: try up to 5-token drug names first.
            for length in range(min(5, n - i), 0, -1):
                phrase = " ".join(candidate_tokens[i : i + length])

                if phrase in drug_db:
                    if phrase in normalized_text_str:
                        line_matches.append((phrase, "exact", phrase))
                    else:
                        line_matches.append((phrase, "fuzzy", phrase))
                    i += length
                    match_found = True
                    break

            if match_found:
                continue

            word = candidate_tokens[i]
            i += 1

            # Single-token match against the database / token index.
            canonical = validate_drug_match(word, drug_db, drug_token_index)
            if canonical:
                # Reject matches where the token covers too little of the name.
                if len(word) / len(canonical) < 0.6:
                    continue
                if word in normalized_text_str:
                    line_matches.append((canonical, "exact", word))
                else:
                    line_matches.append((canonical, "fuzzy", word))
                continue

            # Fall back to SymSpell spelling correction for longer tokens.
            if len(word) < 3:
                continue
            suggestions = sym_spell.lookup(word, Verbosity.CLOSEST, max_edit_distance=1)
            if not suggestions:
                continue

            canonical = validate_drug_match(suggestions[0].term, drug_db, drug_token_index)
            if canonical:
                if len(word) / len(canonical) < 0.6:
                    continue
                line_matches.append((canonical, "fuzzy", word))

        # Prefer exact matches on a line; fall back to fuzzy ones.
        exact_matches = [m for m in line_matches if m[1] == "exact"]
        final_matches = exact_matches if exact_matches else line_matches

        for canonical, _kind, _matched_token in final_matches:
            # Store the raw OCR line alongside each canonical drug name.
            found_meds_with_originals.setdefault(canonical, [])
            if item not in found_meds_with_originals[canonical]:
                found_meds_with_originals[canonical].append(item)

    return found_meds_with_originals
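
# Example return shape (drug names and OCR lines are hypothetical):
#   {
#       "clopitab 75": ["TAB CLOPITAB 75MG"],
#       "starpress xl": ["TABSTARPRESSXL25MGTAB"],
#   }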
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def process_input(image_input):
    """Gradio interface handler."""
    if image_input is None:
        return "Please upload an image.", {}

    try:
        # Gradio delivers RGB; OpenCV (and the OCR pipeline) expects BGR.
        image_bgr = cv2.cvtColor(image_input, cv2.COLOR_RGB2BGR)

        start_time = time.time()
        found_meds_dict = process_image_ocr(image_bgr)
        elapsed_time = time.time() - start_time

        drugs_list = sorted(found_meds_dict.keys())
        drugs_count = len(drugs_list)

        summary = f"Found {drugs_count} medication(s) in {elapsed_time:.3f}s"

        medications_json = {
            "total_medications": drugs_count,
            "processing_time": f"{elapsed_time:.3f}s",
            "medications": [
                {
                    "id": idx + 1,
                    "name": drug.title(),
                    "original_text": found_meds_dict[drug]
                }
                for idx, drug in enumerate(drugs_list)
            ]
        }

        return summary, medications_json
    except Exception as e:
        logger.exception(f"Processing error: {e}")
        return f"Error: {str(e)}", {}
|
|
|
|
|
logger.info("Starting ...") |
|
|
|
|
|
with gr.Blocks(title="Fast OCR") as demo: |
|
|
gr.Markdown("# Prescription OCR") |
|
|
gr.Markdown("Upload a prescription image to extract medications.") |
|
|
gr.Markdown("**Optimized with MOBILE models + ONNX Runtime for maximum speed**") |
|
|
|
|
|
with gr.Row(): |
|
|
image_input = gr.Image(type="numpy", label="Upload Prescription") |
|
|
output_text = gr.Textbox(label="Summary", lines=2) |
|
|
|
|
|
with gr.Row(): |
|
|
medications_json = gr.JSON(label="Extracted Medications") |
|
|
|
|
|
submit_btn = gr.Button("Extract Medications", variant="primary") |
|
|
submit_btn.click(process_input, inputs=image_input, outputs=[output_text, medications_json]) |
|
|
gr.Examples( |
|
|
examples=[ |
|
|
["examples/test1.jpeg"], |
|
|
["examples/test2.jpeg"], |
|
|
["examples/test3.jpeg"], |
|
|
["examples/test4.jpeg"], |
|
|
["examples/test5.jpeg"], |
|
|
["examples/test6.png"], |
|
|
["examples/test7.png"], |
|
|
], |
|
|
inputs=image_input, |
|
|
outputs=[output_text, medications_json], |
|
|
fn=process_input, |
|
|
cache_examples=False, |
|
|
) |
|
|
|
|
|
if __name__ == "__main__": |
|
|
demo.queue(max_size=10) |
|
|
demo.launch(max_threads=4, show_error=True) |