# app_fullbody_pretrained.py - Full-Body X-ray Analysis (Pretrained) + PDF + Suspect Boxes
# -----------------------------------------------------------------------------
# Run: python -m uvicorn app_fullbody_pretrained:api --host 0.0.0.0 --port 7860
# -----------------------------------------------------------------------------
| import os | |
| os.system("pip uninstall -y google-generativeai google-api-core googleapis-common-protos grpcio || true") | |
| os.system("pip install --upgrade google-generativeai==0.8.5 google-api-core protobuf grpcio --quiet") | |
| import os, io, json, tempfile, hashlib, datetime, pathlib | |
| from io import BytesIO | |
| from typing import Optional, Dict, List, Any, Tuple | |
| import numpy as np | |
| from PIL import Image | |
| import cv2 | |
| import unicodedata | |
| import torch | |
| import torch.nn.functional as F | |
| import base64 | |
| # FastAPI imports | |
| from fastapi import FastAPI, UploadFile, File | |
| from fastapi.responses import JSONResponse, Response | |
| from fastapi.staticfiles import StaticFiles | |
| from transformers import DPTFeatureExtractor, DPTForDepthEstimation | |
| import matplotlib.pyplot as plt | |
| from mpl_toolkits.mplot3d import Axes3D # noqa: F401 | |
| import os | |
| #import openai | |
| from fastapi import FastAPI, UploadFile, File | |
| from pydantic import BaseModel | |
| # Hugging Face Space base URL | |
| HF_SPACE_URL = os.getenv("HF_SPACE_URL", "https://abbhy123ghh-x-ray-analysis.hf.space").rstrip("/") | |
# Create app
| api = FastAPI() | |
| os.makedirs("static", exist_ok=True) | |
| api.mount("/static", StaticFiles(directory="static"), name="static") | |
| from transformers import DPTFeatureExtractor, DPTForDepthEstimation | |
| # Optional dependencies | |
| try: import pydicom; HAVE_DICOM=True | |
| except Exception: HAVE_DICOM=False | |
| try: import torchxrayvision as xrv; HAVE_XRV=True | |
| except Exception: HAVE_XRV=False | |
| try: | |
| from transformers import (AutoProcessor, CLIPModel, AutoImageProcessor, AutoModelForImageClassification) | |
| HAVE_TRF=True | |
| except Exception: HAVE_TRF=False | |
| try: from skimage.segmentation import slic; from skimage.util import img_as_float; HAVE_SKIMG=True | |
| except Exception: HAVE_SKIMG=False | |
| try: from fpdf import FPDF; HAVE_PDF=True | |
| except Exception: HAVE_PDF=False | |
| import uvicorn | |
| DEVICE = "cuda" if torch.cuda.is_available() else "cpu" | |
# -----------------------------
# Utility functions + configs
# (to_uint8, read_any_xray, chest_disease_probs, analyze_fullbody, pdf_report_bytes, etc.
#  are defined further below; the model logic is unchanged.)
# -----------------------------
| # Root | |
| # ----------------------------- | |
@api.get("/")  # route path assumed; decorator missing in source
async def root():
| return {"ok": True, "message": "X-ray API is running"} | |
| def to_base64(img: np.ndarray) -> str: | |
| pil = Image.fromarray(img).convert("RGB") | |
| buf = io.BytesIO() | |
| pil.save(buf, format="PNG") | |
| return "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode() | |
# ========================================================================
# FastAPI endpoints - updated for multi-region analysis
# ========================================================================
@api.post("/analyze")  # route path assumed; decorator missing in source
async def analyze(file: UploadFile = File(...)):
| """Analyze an uploaded X-ray and return meta + annotated overlays (JSON).""" | |
| try: | |
| suffix = (file.filename or "").lower() | |
| raw = await file.read() | |
| # Read DICOM or normal image | |
| if suffix.endswith(".dcm") and HAVE_DICOM: | |
| ds = pydicom.dcmread(io.BytesIO(raw)) | |
| arr = ds.pixel_array.astype(np.float32) | |
| arr -= arr.min(); arr /= (arr.max() - arr.min() + 1e-6) | |
| gray_u8 = (arr * 255.0).clip(0, 255).astype(np.uint8) | |
| else: | |
| gray_u8 = np.array(Image.open(io.BytesIO(raw)).convert("L")) | |
| gray_u8 = to_uint8(gray_u8) | |
| # Run full-body multi-region analysis | |
| res = analyze_fullbody(gray_u8) | |
# Flatten meta_pred so the frontend can show Age/View/Gender/CTR
| meta_pred = res.get("meta_pred", {}) | |
| res["age"] = meta_pred.get("age", {}).get("label", "N/A") | |
| res["view"] = meta_pred.get("view", {}).get("label", "N/A") | |
| res["gender"] = meta_pred.get("sex", {}).get("label", "N/A") | |
| res["ctr"] = "--" # placeholder | |
| # Convert original to base64 | |
| res["original_img"] = to_base64(gray_u8) | |
# Overlay with boxes + labels
# --- Try segmentation overlay ---
overlay_img, overlay_meta = findings_overlay(gray_u8, res)
# --- Fallback if overlay is None (always show something) ---
| if overlay_img is None: | |
| print("β οΈ Overlay not generated β using edge-based overlay fallback.") | |
| overlay_img = edges_overlay(gray_u8) | |
| overlay_meta = {"label": "Edge-based fallback", "boxes": []} | |
| res["overlay_img"] = to_base64(overlay_img) | |
| res["overlay_meta"] = overlay_meta | |
# Generate structured AI radiology report
| doctor_report_html = generate_medical_report(res) | |
| return JSONResponse({ | |
| "ok": True, | |
| "meta": { | |
| "age": res["age"], | |
| "view": res["view"], | |
| "gender": res["gender"], | |
| "ctr": res["ctr"] | |
| }, | |
| "original_img": res["original_img"], | |
| "overlay_img": res["overlay_img"], | |
| "overlay_meta": res["overlay_meta"], | |
| "tasks": res.get("tasks", []), | |
| "region": res.get("region", "unknown"), | |
| "region_conf": res.get("region_conf", 0.0), | |
| "meta_pred": res.get("meta_pred", {}), | |
| "doctor_report_html": doctor_report_html # π©» formatted HTML report | |
| }) | |
| except Exception as e: | |
| print("β Analyze failed:", e) | |
| return JSONResponse({"error": str(e)}, status_code=400) | |
| from fastapi import UploadFile, File | |
| from fastapi.responses import Response, JSONResponse | |
| from reportlab.lib.pagesizes import A4 | |
| from reportlab.lib import colors | |
| from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle | |
| from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle, Image as RLImage | |
| import io, datetime, json | |
| from PIL import Image | |
| import numpy as np | |
| import pydicom | |
# ===========================================================
# Unified PDF Generator (Drlogy-style Layout for All Regions)
# ===========================================================
| def generate_clinical_pdf(gray_img, overlay_img, res, doctor_report): | |
| buffer = io.BytesIO() | |
| doc = SimpleDocTemplate( | |
| buffer, pagesize=A4, rightMargin=36, leftMargin=36, topMargin=60, bottomMargin=36 | |
| ) | |
| styles = getSampleStyleSheet() | |
| normal = ParagraphStyle('Normal', parent=styles['Normal'], fontSize=10, leading=13) | |
| bold = ParagraphStyle('Bold', parent=normal, fontName='Helvetica-Bold') | |
| header = ParagraphStyle('Header', parent=normal, alignment=1, fontName='Helvetica-Bold', fontSize=14) | |
| footer = ParagraphStyle('Footer', parent=normal, fontSize=8, textColor=colors.gray) | |
| story = [] | |
| # --- HEADER --- | |
| story.append(Paragraph("<b>DRLOGY IMAGING CENTER</b>", header)) | |
| story.append(Paragraph("X-Ray | CT-Scan | MRI | USG", bold)) | |
| story.append(Paragraph("105-108, SMART VISION COMPLEX, HEALTHCARE ROAD, MUMBAI - 689578", normal)) | |
| story.append(Paragraph("π 0123456789 | βοΈ [email protected]", normal)) | |
| story.append(Spacer(1, 10)) | |
| # --- DYNAMIC TITLE --- | |
| region_title = res.get("region", "General Region").replace("_", " ").title() | |
| story.append(Paragraph(f"<b>AI RADIOLOGY REPORT β {region_title.upper()}</b>", header)) | |
| story.append(Spacer(1, 10)) | |
| # --- PATIENT INFO --- | |
| patient_table = [ | |
| ["Name:", res.get("patient_name", "N/A"), "Age:", f"{res.get('age', 'N/A')} yrs"], | |
| ["Sex:", res.get("gender", "N/A"), "Study Date:", datetime.datetime.now().strftime("%d %b %Y, %I:%M %p")], | |
| ["Region:", region_title, "AI Confidence:", f"{res.get('region_conf', 0):.2f}"], | |
| ] | |
| table = Table(patient_table, colWidths=[70, 160, 90, 150]) | |
| table.setStyle(TableStyle([ | |
| ("BOX", (0, 0), (-1, -1), 0.5, colors.gray), | |
| ("INNERGRID", (0, 0), (-1, -1), 0.25, colors.gray), | |
| ("FONTNAME", (0, 0), (-1, -1), "Helvetica"), | |
| ("FONTSIZE", (0, 0), (-1, -1), 9), | |
| ("BACKGROUND", (0, 0), (-1, 0), colors.whitesmoke) | |
| ])) | |
| story.append(table) | |
| story.append(Spacer(1, 12)) | |
| # --- TECHNICAL DETAILS --- | |
| story.append(Paragraph("<b>TECHNICAL DETAILS</b>", bold)) | |
| story.append(Paragraph( | |
| "AI-assisted analysis performed using a multi-region deep learning model. " | |
| "Single-view radiograph analyzed for anatomical abnormalities.", normal)) | |
| story.append(Spacer(1, 10)) | |
| # --- FINDINGS --- | |
| story.append(Paragraph("<b>AI FINDINGS</b>", bold)) | |
| story.append(Paragraph( | |
| "Below are the observed findings as per AI-based analysis of the selected body part.", normal)) | |
| story.append(Spacer(1, 5)) | |
| story.append(Paragraph(doctor_report.replace("\n", "<br/>"), normal)) | |
| story.append(Spacer(1, 10)) | |
| # --- IMPRESSION --- | |
| story.append(Paragraph("<b>AI IMPRESSION</b>", bold)) | |
| story.append(Paragraph( | |
| "Based on the AI modelβs interpretation, no acute abnormality was confidently detected unless otherwise stated above.", | |
| normal)) | |
| story.append(Spacer(1, 10)) | |
| # --- RECOMMENDATIONS --- | |
| story.append(Paragraph("<b>RECOMMENDATIONS</b>", bold)) | |
| story.append(Paragraph( | |
| "Clinical correlation is advised. If clinical suspicion persists, further radiographic or advanced imaging is recommended.", | |
| normal)) | |
| story.append(Spacer(1, 10)) | |
| # --- IMAGES --- | |
| story.append(Paragraph("<b>IMAGES:</b>", bold)) | |
| img_buf = io.BytesIO() | |
| Image.fromarray(gray_img).save(img_buf, format="PNG") | |
| img_buf.seek(0) | |
| story.append(RLImage(img_buf, width=200, height=200)) | |
| if overlay_img is not None: | |
| overlay_buf = io.BytesIO() | |
| Image.fromarray(overlay_img).save(overlay_buf, format="PNG") | |
| overlay_buf.seek(0) | |
| story.append(RLImage(overlay_buf, width=200, height=200)) | |
| story.append(Spacer(1, 20)) | |
| # --- FOOTER --- | |
| story.append(Paragraph("**** End of Report ****", footer)) | |
| story.append(Spacer(1, 6)) | |
| story.append(Paragraph("Dr. AI Radiologist (MD)", footer)) | |
| story.append(Paragraph(f"Generated on: {datetime.datetime.now():%d %b %Y, %I:%M %p}", footer)) | |
| story.append(Paragraph("This AI report is supportive, not a substitute for radiologist diagnosis.", footer)) | |
| doc.build(story) | |
| buffer.seek(0) | |
| return buffer.getvalue() | |
# ===========================================================
# Main Endpoint - Analyze & Generate PDF
# ===========================================================
| from fastapi import UploadFile, File | |
| from fastapi.responses import Response, JSONResponse | |
| from reportlab.lib.pagesizes import A4 | |
| from reportlab.lib import colors | |
| from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle | |
| from reportlab.platypus import ( | |
| SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle, Image as RLImage | |
| ) | |
| from reportlab.lib.units import mm | |
| import io, datetime, json, re | |
| from PIL import Image | |
| import numpy as np | |
| import pydicom | |
@api.post("/analyze_pdf")  # route path assumed; decorator missing in source
async def analyze_pdf(file: UploadFile = File(...)):
| """Generate a professional AI radiology report PDF.""" | |
| try: | |
| suffix = (file.filename or "").lower() | |
| raw = await file.read() | |
| # --- Read DICOM / Image --- | |
| if suffix.endswith(".dcm") and HAVE_DICOM: | |
| ds = pydicom.dcmread(io.BytesIO(raw)) | |
| arr = ds.pixel_array.astype(np.float32) | |
| arr -= arr.min(); arr /= (arr.max()-arr.min()+1e-6) | |
| gray_u8 = (arr*255).clip(0,255).astype(np.uint8) | |
| else: | |
| gray_u8 = np.array(Image.open(io.BytesIO(raw)).convert("L")) | |
| gray_u8 = to_uint8(gray_u8) | |
| # --- Model Analysis --- | |
| res = analyze_fullbody(gray_u8) | |
| overlay_img, _ = findings_overlay(gray_u8, res) | |
| if overlay_img is None: | |
| overlay_img = edges_overlay(gray_u8) | |
| # --- Gemini AI Report --- | |
| try: | |
| prompt = f""" | |
| You are a senior radiologist. Generate a structured radiology report. | |
| Data: {json.dumps(res, indent=2)} | |
| Include clear prose for: Technical Details, Findings, Impression, Recommendations. | |
| Avoid markdown, bullets, or extra headings. | |
| """ | |
| gemini_response = gemini_model.generate_content(prompt) | |
| doctor_report = getattr(gemini_response, "text", "AI report unavailable.").strip() | |
| except Exception: | |
| doctor_report = "AI medical report could not be generated." | |
| clean_report = re.sub(r"[#*_]+", "", doctor_report) | |
| clean_report = re.sub(r"\n{2,}", "\n", clean_report.strip()) | |
| # --- PDF Template --- | |
| buffer = io.BytesIO() | |
| doc = SimpleDocTemplate(buffer, pagesize=A4, leftMargin=35, rightMargin=35, topMargin=35, bottomMargin=40) | |
| styles = getSampleStyleSheet() | |
| normal = ParagraphStyle("NormalCustom", parent=styles["Normal"], fontSize=10, leading=14) | |
| heading = ParagraphStyle("Heading", parent=styles["Heading2"], fontSize=13, textColor=colors.HexColor("#004C99")) | |
| center_heading = ParagraphStyle("CenterHeading", parent=styles["Heading2"], fontSize=14, | |
| textColor=colors.white, alignment=1, spaceAfter=10) | |
| footer = ParagraphStyle("Footer", parent=styles["Normal"], fontSize=9, textColor=colors.gray, alignment=1) | |
| story = [] | |
| # === HEADER BAR === | |
| story.append(Table( | |
| [[Paragraph("<b>DRLOGY IMAGING CENTER</b>", center_heading)]], | |
| colWidths=[500], | |
| style=[ | |
| ('BACKGROUND', (0,0), (-1,-1), colors.HexColor("#004C99")), | |
| ('BOTTOMPADDING', (0,0), (-1,-1), 6), | |
| ('TOPPADDING', (0,0), (-1,-1), 6) | |
| ] | |
| )) | |
| story.append(Paragraph("<font color='grey'>X-Ray | CT-Scan | MRI | USG</font>", normal)) | |
| story.append(Paragraph("105-108, SMART VISION COMPLEX, HEALTHCARE ROAD, MUMBAI - 689578", normal)) | |
| story.append(Paragraph("π 0123456789 β [email protected]", normal)) | |
| story.append(Spacer(1, 10)) | |
| story.append(Table([[ '', '' ]], colWidths=[480], | |
| style=[('LINEBELOW',(0,0),(-1,-1),1,colors.HexColor("#004C99"))])) | |
| story.append(Spacer(1, 8)) | |
| # === TITLE SECTION === | |
| title_table = Table( | |
| [[Paragraph("<b><font size=13 color='#004C99'>AI RADIOLOGY REPORT</font></b>", heading), | |
| Paragraph(datetime.datetime.now().strftime("%d %b %Y, %I:%M %p"), normal)]], | |
| colWidths=[340, 140], | |
| style=[ | |
| ("VALIGN", (0,0), (-1,-1), "MIDDLE"), | |
| ("ALIGN", (1,0), (1,0), "RIGHT") | |
| ] | |
| ) | |
| story.append(title_table) | |
| story.append(Spacer(1, 12)) | |
| # === PATIENT INFO === | |
| story.append(Paragraph("<b>PATIENT INFORMATION</b>", heading)) | |
| patient_info = [ | |
| ["Name", "N/A"], | |
| ["Age", f"{res.get('meta_pred', {}).get('age', {}).get('label', 'N/A')}"], | |
| ["Sex", f"{res.get('meta_pred', {}).get('sex', {}).get('label', 'N/A')}"], | |
| ["Region", res.get("region", "Unknown").title()], | |
| ["AI Confidence", f"{res.get('region_conf', 0):.2f}"], | |
| ] | |
| table = Table(patient_info, colWidths=[120, 360]) | |
| table.setStyle(TableStyle([ | |
| ("BOX", (0,0), (-1,-1), 0.5, colors.grey), | |
| ("INNERGRID", (0,0), (-1,-1), 0.25, colors.lightgrey), | |
| ("BACKGROUND", (0,0), (0,-1), colors.HexColor("#EAF1FB")), | |
| ])) | |
| story.append(table) | |
| story.append(Spacer(1, 10)) | |
| # === MAIN BODY === | |
| story.append(Paragraph("<b>CLINICAL REPORT SUMMARY</b>", heading)) | |
| story.append(Spacer(1, 4)) | |
| story.append(Paragraph(clean_report.replace("\n", "<br/>"), normal)) | |
| story.append(Spacer(1, 15)) | |
| # === IMAGES === | |
| story.append(Paragraph("<b>REFERENCE IMAGES</b>", heading)) | |
| story.append(Spacer(1, 4)) | |
| img_buf = io.BytesIO() | |
| Image.fromarray(gray_u8).save(img_buf, format="PNG") | |
| img_buf.seek(0) | |
| story.append(RLImage(img_buf, width=180, height=180)) | |
| story.append(Spacer(1, 6)) | |
| overlay_buf = io.BytesIO() | |
| Image.fromarray(overlay_img).save(overlay_buf, format="PNG") | |
| overlay_buf.seek(0) | |
| story.append(RLImage(overlay_buf, width=180, height=180)) | |
| story.append(Spacer(1, 20)) | |
| # === FOOTER === | |
| story.append(Table([[ '', '' ]], colWidths=[480], | |
| style=[('LINEBELOW',(0,0),(-1,-1),1,colors.HexColor("#004C99"))])) | |
| story.append(Spacer(1, 6)) | |
| story.append(Paragraph("<b>*** End of Report ***</b>", footer)) | |
| story.append(Spacer(1, 3)) | |
| story.append(Paragraph("Generated on: " + datetime.datetime.now().strftime("%d %b %Y, %I:%M %p"), footer)) | |
| story.append(Paragraph("This report is AI-generated and intended for research/clinical support only.", footer)) | |
| story.append(Paragraph("<font color='#004C99'>Β© 2025 Drlogy Imaging Center | 24x7 Services</font>", footer)) | |
| doc.build(story) | |
| buffer.seek(0) | |
| return Response( | |
| content=buffer.read(), | |
| media_type="application/pdf", | |
| headers={"Content-Disposition": "attachment; filename=ai_xray_report.pdf"} | |
| ) | |
| except Exception as e: | |
| import traceback; traceback.print_exc() | |
| return JSONResponse({"error": str(e)}, status_code=400) | |
# ===============================================
# # Runway Gen-2 - Convert X-ray to 3D Video
# # ===============================================
| # import requests | |
| # import io | |
| # import base64 | |
| # import time | |
| # from fastapi import FastAPI, UploadFile, File | |
| # from fastapi.responses import JSONResponse | |
| # from PIL import Image | |
| # app = FastAPI() | |
| # PIXVERSE_API_KEY = "sk-4720d9d2ae7bd362bcece9906ffc6848" | |
| # @api.post("/xray_to_3dvideo") | |
| # async def xray_to_3dvideo(file: UploadFile = File(...)): | |
| # try: | |
| # raw = await file.read() | |
| # img = Image.open(io.BytesIO(raw)).convert("RGB") | |
| # buf = io.BytesIO() | |
| # img.save(buf, format="PNG") | |
| # base64_img = base64.b64encode(buf.getvalue()).decode() | |
# # FIXED PAYLOAD & HEADER
| # payload = { | |
| # "image": base64_img, # was 'image_base64' | |
| # "prompt": "A grayscale rotating 3D visualization of a medical X-ray scan, showing bones with realistic depth and cinematic lighting.", | |
| # "duration": 4, | |
| # "motion": "rotate", # was 'motion_mode' | |
| # "quality": "540p" | |
| # } | |
| # headers = { | |
| # "Authorization": f"Bearer {PIXVERSE_API_KEY}", # was 'x-api-key' | |
| # "Content-Type": "application/json" | |
| # } | |
# # 1. Start PixVerse job
| # start_response = requests.post( | |
| # "https://api.segmind.com/v1/pixverse-image2video", | |
| # json=payload, | |
| # headers=headers | |
| # ) | |
| # start_data = start_response.json() | |
| # print("PixVerse start_data:", start_data) | |
# # 2. Extract task ID or instant video URL
| # task_id = start_data.get("task_id") or start_data.get("id") | |
| # video_url = ( | |
| # start_data.get("video_url") | |
| # or start_data.get("video") | |
| # or start_data.get("output", {}).get("video_url") | |
| # ) | |
| # if not task_id and not video_url: | |
| # return JSONResponse({ | |
| # "error": start_data, | |
| # "message": "β PixVerse did not return a task ID or video URL (check API key or payload)" | |
| # }, status_code=400) | |
| # # If instant video available | |
| # if video_url: | |
| # return JSONResponse({ | |
| # "ok": True, | |
| # "video_url": video_url, | |
| # "message": "β 3D X-ray video generated instantly via PixVerse" | |
| # }) | |
# # 3. Poll task status
| # poll_url = f"https://api.segmind.com/v1/tasks/{task_id}" | |
| # for attempt in range(30): # 2.5 minutes max | |
| # time.sleep(5) | |
| # status_response = requests.get(poll_url, headers=headers) | |
| # status_data = status_response.json() | |
| # status = status_data.get("status", "").lower() | |
| # if status == "succeeded": | |
| # output = status_data.get("output", {}) | |
| # video_url = ( | |
| # output.get("video_url") | |
| # or output.get("video") | |
| # or output.get("assets", {}).get("video") | |
| # ) | |
| # if video_url: | |
| # return JSONResponse({ | |
| # "ok": True, | |
| # "video_url": video_url, | |
| # "message": "β 3D X-ray video generated successfully via PixVerse" | |
| # }) | |
| # else: | |
| # return JSONResponse({ | |
| # "error": "Video URL missing in succeeded task", | |
| # "data": status_data | |
| # }, status_code=500) | |
| # if status == "failed": | |
| # return JSONResponse({ | |
| # "error": "β Video generation task failed", | |
| # "data": status_data | |
| # }, status_code=500) | |
| # return JSONResponse({ | |
| # "error": "β³ Video generation timed out", | |
| # "task_status": status, | |
| # "task_id": task_id | |
| # }, status_code=202) | |
| # except Exception as e: | |
| # print("3D video backend error:", str(e)) | |
| # return JSONResponse({"error": str(e)}, status_code=500) | |
| # =============================================== | |
# Fully Fixed Version - 3D Local Video Generator (with Overlay)
| # =============================================== | |
| from fastapi import UploadFile, File | |
| from fastapi.responses import JSONResponse | |
| import io, os, torch, numpy as np, imageio, cv2 | |
| #from torchvision.models.video import r3d_18 | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| from mpl_toolkits.mplot3d import Axes3D | |
| HF_SPACE_URL = os.getenv("HF_SPACE_URL", "https://abbhy123ghh-x-ray-analysis.hf.space").rstrip("/") | |
@api.post("/generate_3d_local")  # route path assumed; decorator missing in source
async def generate_3d_local(file: UploadFile = File(...)):
| """ | |
| Generate a pseudo-3D rotating video from a 2D X-ray image. | |
| Adds AI overlay and realistic depth-based surface projection. | |
| """ | |
| try: | |
# 1. Read and preprocess input image
| raw = await file.read() | |
| img = Image.open(io.BytesIO(raw)).convert("L").resize((224, 224)) | |
| gray_u8 = np.array(img, dtype=np.uint8) | |
| img_np = np.array(img, dtype=np.float32) / 255.0 | |
# 2. Try AI overlay from your model
| try: | |
| res = analyze_fullbody(gray_u8) | |
| overlay_img, _ = findings_overlay(gray_u8, res) | |
| if overlay_img is not None: | |
| base_img = overlay_img | |
| else: | |
| base_img = np.stack([gray_u8] * 3, axis=-1) | |
| except Exception as overlay_error: | |
| print("β οΈ Overlay generation failed:", overlay_error) | |
| base_img = np.stack([gray_u8] * 3, axis=-1) | |
| blended = cv2.cvtColor(base_img, cv2.COLOR_BGR2RGB) | |
# 3. Create pseudo depth volume
| depth_slices = [img_np * (i / 10) for i in range(10)] | |
| volume = np.stack(depth_slices, axis=0) # [D,H,W] | |
| volume = torch.tensor(volume).unsqueeze(0).unsqueeze(0) | |
| if volume.shape[1] == 1: | |
| volume = volume.repeat(1, 3, 1, 1, 1) | |
| print("β Final volume shape:", volume.shape) | |
| # # 4οΈβ£ Dummy forward pass for consistency (no inference) | |
| # model = r3d_18(weights=None) | |
| # model.eval() | |
| # with torch.no_grad(): | |
| # _ = model(volume) | |
# 5. Generate 3D surface rotation frames
| frames = [] | |
| H, W = gray_u8.shape | |
| x, y = np.meshgrid(np.linspace(0, 1, W), np.linspace(0, 1, H)) | |
| depth = cv2.GaussianBlur(gray_u8.astype(np.float32), (9, 9), 0) | |
| for angle in range(0, 360, 5): | |
| fig = plt.figure(figsize=(3, 3)) | |
| ax = fig.add_subplot(111, projection="3d") | |
| ax.view_init(30, angle) | |
| ax.plot_surface( | |
| x, y, depth / 255.0, | |
| facecolors=blended / 255.0, | |
| rstride=2, cstride=2, | |
| linewidth=0, antialiased=False | |
| ) | |
| ax.axis("off") | |
| fig.canvas.draw() | |
| frame = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) | |
| frame = frame.reshape(fig.canvas.get_width_height()[::-1] + (3,)) | |
| frames.append(frame) | |
| plt.close(fig) | |
# 6. Save output video (MP4 preferred, fallback to GIF)
| os.makedirs("static", exist_ok=True) | |
| out_path = "static/xray_3d_local.mp4" | |
| try: | |
| import imageio_ffmpeg | |
| writer = imageio.get_writer( | |
| out_path, format="ffmpeg", mode="I", | |
| fps=12, codec="libx264", quality=8 | |
| ) | |
| for frame in frames: | |
| writer.append_data(frame) | |
| writer.close() | |
| video_url = f"{HF_SPACE_URL}/static/xray_3d_local.mp4" | |
| msg = "β 3D MP4 video (with overlay) generated successfully" | |
| except Exception as e: | |
| print("β οΈ FFmpeg failed, saving GIF instead:", e) | |
| gif_path = out_path.replace(".mp4", ".gif") | |
| imageio.mimsave(gif_path, frames, duration=0.08) | |
| video_url = f"{HF_SPACE_URL}/static/xray_3d_local.gif" | |
| msg = "β 3D GIF (with overlay) generated successfully" | |
| return JSONResponse({ | |
| "ok": True, | |
| "message": msg, | |
| "video_url": video_url | |
| }) | |
| except Exception as e: | |
| import traceback | |
| traceback.print_exc() | |
| return JSONResponse({"error": str(e)}, status_code=500) | |
| # ============================================================ | |
# MONAI + PyVista 3D Full-Body Visualization (Final HF Safe)
| # ============================================================ | |
| from fastapi import UploadFile, File | |
| from fastapi.responses import JSONResponse | |
| import io, os, numpy as np, torch | |
| from PIL import Image | |
| from monai.networks.nets import UNet | |
| import torch.nn.functional as F | |
| import pyvista as pv | |
| from pyvista import start_xvfb | |
| import imageio.v3 as iio | |
| HF_SPACE_URL = os.getenv("HF_SPACE_URL", "https://abbhy123ghh-x-ray-analysis.hf.space").rstrip("/") | |
@api.post("/generate_3d_monai")  # route path assumed; decorator missing in source
async def generate_3d_monai(file: UploadFile = File(...)):
| """ | |
| Generate a realistic 3D volumetric visualization for any X-ray region | |
| (skull, chest, limb, etc.) using MONAI for enhancement and PyVista for | |
| volumetric rendering. Fully CPU-compatible for Hugging Face Spaces. | |
| """ | |
| try: | |
| # --- Step 1: Read and preprocess image --- | |
| raw = await file.read() | |
| img = Image.open(io.BytesIO(raw)).convert("L").resize((256, 256)) | |
| gray = np.array(img, dtype=np.float32) / 255.0 | |
| # --- Step 2: Build synthetic 3D volume (simulate multiple slices) --- | |
| depth_slices = 16 | |
| volume = np.stack([gray * (1 - i / depth_slices) for i in range(depth_slices)], axis=0) | |
| volume = np.expand_dims(volume, axis=0) # [1, D, H, W] | |
| # --- Step 3: MONAI lightweight 3D UNet (shallow for CPU) --- | |
| x = torch.tensor(volume, dtype=torch.float32).unsqueeze(0) # [1, 1, D, H, W] | |
| pad_d = (16 - x.shape[2] % 16) % 16 | |
| pad_h = (16 - x.shape[3] % 16) % 16 | |
| pad_w = (16 - x.shape[4] % 16) % 16 | |
| x = F.pad(x, (0, pad_w, 0, pad_h, 0, pad_d)) | |
| model = UNet( | |
| spatial_dims=3, | |
| in_channels=1, | |
| out_channels=1, | |
| channels=(8, 16, 32), | |
| strides=(2, 2), | |
| ) | |
| model.eval() | |
| with torch.no_grad(): | |
| y = torch.sigmoid(model(x)) | |
| seg = y[0, 0].cpu().numpy() | |
| seg = (seg - seg.min()) / (seg.max() - seg.min() + 1e-8) | |
| seg = np.clip(seg * 255, 0, 255).astype(np.uint8) | |
| # --- Step 4: Create PyVista grid (Hugging Face compatible) --- | |
| start_xvfb() | |
| grid = pv.ImageData(dimensions=seg.shape) | |
| grid.spacing = (1, 1, 1) | |
| grid.origin = (0, 0, 0) | |
| grid.point_data["values"] = seg.flatten(order="F") | |
| # --- Step 5: Enhanced 3D visualization with lighting --- | |
| plotter = pv.Plotter(off_screen=True, window_size=[512, 512]) | |
| plotter.add_volume( | |
| grid, | |
| cmap="bone", | |
| opacity="sigmoid_5", | |
| shade=True, | |
| diffuse=1.0, | |
| specular=0.3, | |
| specular_power=15, | |
| ) | |
| plotter.set_background("black") | |
| # Add realistic light | |
| light = pv.Light( | |
| position=(seg.shape[2]*2, seg.shape[1]*2, seg.shape[0]), | |
| color='white', | |
| intensity=1.5, | |
| ) | |
| plotter.add_light(light) | |
| frames = [] | |
| for angle in range(0, 360, 5): | |
| plotter.camera_position = [ | |
| (seg.shape[2]*2, 0, seg.shape[0]/2), | |
| (seg.shape[2]/2, seg.shape[1]/2, seg.shape[0]/2), | |
| (0, 0, 1), | |
| ] | |
| plotter.camera.azimuth = angle | |
| img_frame = plotter.screenshot(return_img=True) | |
| frames.append(img_frame) | |
| plotter.close() | |
| # --- Step 6: Save rotating MP4 video --- | |
| os.makedirs("static", exist_ok=True) | |
| out_path = "static/xray_monai_3d_fullbody.mp4" | |
| iio.imwrite(out_path, frames, fps=12, codec="libx264") | |
| return JSONResponse({ | |
| "ok": True, | |
| "message": "β MONAI + PyVista 3D full-body visualization generated successfully", | |
| "video_url": f"{HF_SPACE_URL}/{out_path}" | |
| }) | |
| except Exception as e: | |
| import traceback; traceback.print_exc() | |
| return JSONResponse({"error": str(e)}, status_code=500) | |
| # ----------------------------- | |
| # Config | |
| # ----------------------------- | |
| CLIP_CANDIDATES = [ | |
| "microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224", | |
| "openai/clip-vit-base-patch32", | |
| ] | |
| CLIP_TEMP = float(os.getenv("CLIP_TEMP", "1.0")) | |
| CHEST_TEMP = float(os.getenv("CHEST_TEMP", "1.2")) | |
| CAL_TAU = float(os.getenv("CAL_TAU", "1.0")) # <-- ADDED: site calibration temperature | |
| TB_REPO = os.getenv("TB_REPO", "").strip() | |
| TB_LABEL = os.getenv("TB_LABEL", "Tuberculosis") | |
| HF_TOKEN = os.getenv("HF_TOKEN", None) | |
| DEFAULT_THRESHOLDS: Dict[str, float] = { | |
| # chest | |
| "Pneumonia": 0.70, | |
| "Pulmonary Fibrosis": 0.75, | |
| "COPD (proxy)": 0.75, | |
| "Lung Cancer (proxy)": 0.70, | |
| "Tuberculosis": 0.70, | |
| # non-chest | |
| "Compression fracture": 0.65, | |
| "Scoliosis (proxy)": 0.65, | |
| "Hip fracture": 0.65, | |
| "OA (proxy)": 0.65, | |
| "Fracture (upper_ext)": 0.65, | |
| "Fracture (lower_ext)": 0.65, | |
| "Implant/Hardware": 0.65, | |
| "Skull fracture": 0.65, | |
| } | |
| ABSTAIN_MARGIN = float(os.getenv("ABSTAIN_MARGIN", "0.08")) | |
| REGION_CONF_MIN = float(os.getenv("REGION_CONF_MIN", "0.40")) | |
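# Illustrative sketch (not used by the pipeline): one plausible way the calibration
# temperature CAL_TAU, the per-label DEFAULT_THRESHOLDS, and ABSTAIN_MARGIN above could
# be combined into a screening decision. The real decision logic lives in
# analyze_fullbody; the helper name below is hypothetical.
def _example_screening_decision(label: str, prob: float) -> str:
    # Temperature-scale the probability in logit space (CAL_TAU > 1 softens it toward 0.5).
    logit = np.log(prob + 1e-6) - np.log(1.0 - prob + 1e-6)
    p_cal = float(1.0 / (1.0 + np.exp(-logit / CAL_TAU)))
    thr = DEFAULT_THRESHOLDS.get(label, 0.65)
    if abs(p_cal - thr) < ABSTAIN_MARGIN:
        return "abstain"  # too close to the threshold to call either way
    return "positive" if p_cal >= thr else "negative"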
| REGION_NAMES = ["skull","chest","spine","pelvis","upper_ext","lower_ext"] | |
| REGION_PROMPTS = { | |
| "skull": ["X-ray of skull", "Head/skull radiograph"], | |
| "chest": ["Chest X-ray", "Thorax radiograph"], | |
| "spine": ["X-ray of spine", "Spine radiograph"], | |
| "pelvis": ["Pelvis or hip X-ray", "Pelvic radiograph"], | |
| "upper_ext": ["Upper extremity X-ray (shoulder to hand)", "Arm/wrist/hand X-ray"], | |
| "lower_ext": ["Lower extremity X-ray (hip to foot)", "Leg/ankle/foot X-ray"], | |
| } | |
| REGION_DISPLAY = { | |
| "chest":"chest", | |
| "skull":"skull", | |
| "spine":"spine", | |
| "pelvis":"pelvis/hip", | |
| "upper_ext":"upper extremity", | |
| "lower_ext":"lower extremity", | |
| "unknown":"unknown", | |
| } | |
| # Tasks (non-chest) | |
| TASKS: Dict[str, List[Dict[str, Any]]] = { | |
| "spine": [ | |
| {"name": "Compression fracture", "zs": ("X-ray of spine with compression fracture", "X-ray of spine without fracture")}, | |
| {"name": "Scoliosis (proxy)", "zs": ("X-ray of spine with scoliosis", "X-ray of a straight spine")}, | |
| ], | |
| "pelvis": [ | |
| {"name": "Hip fracture", "zs": ("Pelvis or hip X-ray with hip fracture", "Pelvis or hip X-ray without fracture")}, | |
| {"name": "OA (proxy)", "zs": ("Pelvis or hip X-ray with osteoarthritis", "Pelvis or hip X-ray with normal joint space")}, | |
| ], | |
| "upper_ext": [ | |
| {"name": "Fracture (upper_ext)", "zs": ("Upper extremity X-ray with fracture", "Upper extremity X-ray without fracture")}, | |
| {"name": "Implant/Hardware", "zs": ("X-ray with orthopedic implant or hardware", "X-ray without orthopedic implant")}, | |
| ], | |
| "lower_ext": [ | |
| {"name": "Fracture (lower_ext)", "zs": ("Lower extremity X-ray with fracture", "Lower extremity X-ray without fracture")}, | |
| {"name": "Implant/Hardware", "zs": ("X-ray with orthopedic implant or hardware", "X-ray without orthopedic implant")}, | |
| ], | |
| "skull": [ | |
| {"name": "Skull fracture", "zs": ("Skull X-ray with fracture", "Skull X-ray without fracture")}, | |
| ], | |
| } | |
| # Sub-label descriptions for every screening item | |
| CHEST_DESCRIPTIONS = { | |
| "Pneumonia": "Patchy/lobar opacities suggesting infection/inflammation; confirm clinically Β± labs.", | |
| "Pulmonary Fibrosis": "Chronic interstitial changes; consider HRCT for characterization.", | |
| "COPD (proxy)": "Proxy via emphysema/hyperinflation; COPD diagnosis needs spirometry/history.", | |
| "Lung Cancer (proxy)": "Proxy via mass/nodule; suspicious lesions need CT Β± biopsy.", | |
| "Tuberculosis": "CXR overlaps with other diseases; definitive Dx needs microbiology/NAAT.", | |
| "Asthma": "Typically not detectable on CXR; diagnosis is clinical with spirometry.", | |
| } | |
| TASK_DESCRIPTIONS = { | |
| "Compression fracture": "Height loss of vertebral body; triage to CT/MRI if clinical concern.", | |
| "Scoliosis (proxy)": "Curve or rotation; clinical significance depends on Cobb angle (not provided).", | |
| "Hip fracture": "Disruption of trabeculae/cortical lines; confirm with dedicated views/CT.", | |
| "OA (proxy)": "Joint space narrowing, osteophytes, subchondral changes; clinical correlation.", | |
| "Fracture (upper_ext)": "Cortical break or lucent line; immobilize and obtain dedicated views.", | |
| "Fracture (lower_ext)": "Cortical break or lucent line; weight-bearing precautions and follow-up.", | |
| "Implant/Hardware": "Presence of orthopedic material; assess alignment/loosening if applicable.", | |
| "Skull fracture": "Linear or depressed lucency; correlate with neuro status/CT.", | |
| } | |
| CHEST_TARGETS = { | |
| "Pneumonia": ["Pneumonia"], | |
| "Pulmonary Fibrosis": ["Fibrosis"], | |
| "COPD (proxy)": ["Emphysema"], | |
| "Lung Cancer (proxy)": ["Mass", "Nodule"], | |
| "Tuberculosis": ["Tuberculosis", "TB"], | |
| "Asthma": [], | |
| } | |
| # ----------------------------- | |
| # Utilities | |
| # ----------------------------- | |
| def to_uint8(img): | |
| if isinstance(img, Image.Image): | |
| img = np.array(img) | |
| if img.dtype == np.uint8: | |
| return img | |
| img = img.astype(np.float32) | |
| img -= img.min(); rng = img.max() - img.min() | |
| if rng > 1e-6: | |
| img = img / (img.max() + 1e-6) | |
| return (img * 255.0).clip(0, 255).astype(np.uint8) | |
| def read_any_xray(path_or_buffer) -> np.ndarray: | |
| """ | |
| Reads a DICOM (.dcm) or image (PNG/JPG) and returns an 8-bit grayscale (uint8) array. | |
| DICOM handling: | |
| - Applies RescaleSlope/Intercept if present. | |
| - Applies Window Center/Width (first value if multiple). | |
| - Correctly inverts MONOCHROME1 (after VOI windowing). | |
| """ | |
| name = getattr(path_or_buffer, "name", str(path_or_buffer)).lower() | |
| # Non-DICOM: use PIL and convert to 8-bit grayscale | |
| if not name.endswith(".dcm"): | |
| return np.array(Image.open(path_or_buffer).convert("L")) | |
| if not HAVE_DICOM: | |
| raise RuntimeError("Install pydicom to read DICOM files.") | |
| ds = pydicom.dcmread(path_or_buffer) | |
| # Raw pixel array -> float | |
| arr = ds.pixel_array.astype(np.float32) | |
| # 1) Modality LUT (Rescale Slope/Intercept) | |
| slope = float(getattr(ds, "RescaleSlope", 1.0)) | |
| intercept = float(getattr(ds, "RescaleIntercept", 0.0)) | |
| if slope != 1.0 or intercept != 0.0: | |
| arr = arr * slope + intercept | |
| # Helper: coerce WindowCenter/Width to float (first value if multi) | |
| def _first_float(x): | |
| if x is None: | |
| return None | |
| try: | |
| # pydicom may give MultiValue/Sequence/DSfloat | |
| return float(np.atleast_1d(x)[0]) | |
| except Exception: | |
| try: | |
| return float(x) | |
| except Exception: | |
| return None | |
| # 2) VOI LUT: Windowing (if available) | |
| wc = _first_float(getattr(ds, "WindowCenter", None)) | |
| ww = _first_float(getattr(ds, "WindowWidth", None)) | |
| if ww is not None and ww > 0: | |
| lo = wc - ww / 2.0 | |
| hi = wc + ww / 2.0 | |
| arr = np.clip(arr, lo, hi) | |
| arr = (arr - lo) / (hi - lo + 1e-6) # -> 0..1 | |
| else: | |
| # Fallback min-max | |
| mn = float(np.min(arr)) | |
| mx = float(np.max(arr)) | |
| arr = (arr - mn) / (mx - mn + 1e-6) | |
| # 3) Photometric interpretation (invert MONOCHROME1) | |
| phot = str(getattr(ds, "PhotometricInterpretation", "")).upper() | |
| if phot == "MONOCHROME1": | |
| arr = 1.0 - arr | |
| # Return uint8 image | |
| return (arr * 255.0).clip(0, 255).astype(np.uint8) | |
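# Worked example of the windowing step above (illustrative numbers): with
# WindowCenter = 40 and WindowWidth = 400, lo = 40 - 200 = -160 and hi = 40 + 200 = 240,
# so a rescaled pixel value of 40 maps to (40 - (-160)) / 400 = 0.5, i.e. mid-gray.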
| def edges_overlay(gray_u8: np.ndarray) -> np.ndarray: | |
| try: | |
| edges = cv2.Canny(gray_u8, 50, 150) | |
| overlay = np.stack([gray_u8]*3, axis=-1) | |
| overlay[edges>0] = [255, 80, 80] | |
| return overlay | |
| except Exception: | |
| return np.stack([gray_u8]*3, axis=-1) | |
# Pseudo-3D surface rendering helpers
| import matplotlib.pyplot as plt | |
| from mpl_toolkits.mplot3d import Axes3D # noqa: F401 | |
| def pseudo3d_surface(gray_u8: np.ndarray) -> np.ndarray: | |
| """Generate a pseudo-3D surface plot from a 2D X-ray.""" | |
| try: | |
| H, W = gray_u8.shape | |
| X, Y = np.meshgrid(np.arange(W), np.arange(H)) | |
| Z = gray_u8.astype(np.float32) | |
| fig = plt.figure(figsize=(6, 6)) | |
| ax = fig.add_subplot(111, projection='3d') | |
| ax.plot_surface(X, Y, Z, cmap="bone", linewidth=0, antialiased=True) | |
| ax.set_axis_off() | |
| fig.tight_layout(pad=0) | |
| buf = io.BytesIO() | |
| plt.savefig(buf, format="png", bbox_inches="tight", pad_inches=0) | |
| plt.close(fig) | |
| buf.seek(0) | |
| img = Image.open(buf).convert("RGB") | |
| return np.array(img) | |
| except Exception as e: | |
| print("Pseudo-3D generation failed:", e) | |
| return np.stack([gray_u8]*3, axis=-1) | |
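# NOTE: HAVE_DEPTH, DEPTH_EXTRACTOR and DEPTH_MODEL used below are assumed to be set by a
# DPT depth-model loader defined elsewhere in this module; when depth support is
# unavailable the function returns a flat RGB stack of the input instead.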
| def xray_to_depth_3d(gray_u8: np.ndarray) -> np.ndarray: | |
| """Convert 2D X-ray into pseudo-3D surface using Intel/dpt-hybrid-midas.""" | |
| if not HAVE_DEPTH: | |
| return np.stack([gray_u8]*3, axis=-1) | |
| try: | |
| pil = Image.fromarray(gray_u8).convert("RGB") | |
| inputs = DEPTH_EXTRACTOR(images=pil, return_tensors="pt").to(DEVICE) | |
| with torch.no_grad(): | |
| outputs = DEPTH_MODEL(**inputs) | |
| predicted_depth = outputs.predicted_depth.squeeze().cpu().numpy() | |
| # Normalize depth | |
| depth_min, depth_max = predicted_depth.min(), predicted_depth.max() | |
| depth_norm = (predicted_depth - depth_min) / (depth_max - depth_min + 1e-6) | |
| # Render pseudo-3D | |
| H, W = depth_norm.shape | |
| X, Y = np.meshgrid(np.arange(W), np.arange(H)) | |
| Z = depth_norm * 255 | |
| fig = plt.figure(figsize=(6, 6)) | |
| ax = fig.add_subplot(111, projection="3d") | |
| ax.plot_surface(X, Y, Z, cmap="bone", linewidth=0, antialiased=True) | |
| ax.set_axis_off() | |
| fig.tight_layout(pad=0) | |
| buf = io.BytesIO() | |
| plt.savefig(buf, format="png", bbox_inches="tight", pad_inches=0) | |
| plt.close(fig) | |
| buf.seek(0) | |
| img = Image.open(buf).convert("RGB") | |
| return np.array(img) | |
| except Exception as e: | |
| print("3D depth rendering failed:", e) | |
| return np.stack([gray_u8]*3, axis=-1) | |
| def risk_level(prob: Optional[float]) -> str: | |
| if prob is None: return "N/A" | |
| if prob >= 0.70: return "High" | |
| if prob >= 0.50: return "Moderate" | |
| return "Low" | |
| def risk_tag(prob: Optional[float]) -> str: | |
| level = risk_level(prob) | |
| color = {"High":"#E03131","Moderate":"#F59F00","Low":"#2F9E44","N/A":"#868E96"}[level] | |
| pct = "N/A" if prob is None else f"{prob:.0%}" | |
| return f"<span style='background:{color};color:#fff;border-radius:8px;padding:2px 8px;font-size:0.85rem'>{level} {pct}</span>" | |
| def percent_bar(prob: Optional[float]) -> str: | |
| if prob is None: return "" | |
| width = int(round(prob*100)) | |
| return ( | |
| "<div style=\"width:100%;height:8px;background:#2b2b2b;border-radius:6px;overflow:hidden;margin:4px 0 8px\">" | |
| f"<div style=\"width:{width}%;height:8px;background:#4c6ef5\"></div>" | |
| "</div>" | |
| ) | |
| # ----------------------------- | |
| # Extra detail generator for ALL body regions | |
| # ----------------------------- | |
| def generate_extra_details(res: dict) -> str: | |
| region = res.get("region", "unknown") | |
| tasks = res.get("tasks", []) | |
region_texts = {
"skull": (
"**Image Details (Skull)**\n\n"
"View type: frontal or lateral skull radiograph.\n\n"
"Bones visible: cranial vault, facial bones, mandible.\n\n"
"**AI / Computer Vision Context**\n"
"- Region classifier: skull.\n"
"- Tasks: skull fracture detection, implant/hardware.\n"
"- AI approach: CLIP region classification + fracture CNN/ViT.\n\n"
"**Limitations**\n"
"- Small hairline fractures may be missed.\n"
"- CT is gold standard for head trauma."
),
"chest": (
"**Image Details (Chest)**\n\n"
"View type: PA/AP/Lateral chest X-ray.\n\n"
"Visible: lungs, ribs, diaphragm, clavicles, cardiac silhouette.\n\n"
"**AI / Computer Vision Context**\n"
"- Region classifier: chest.\n"
"- Tasks: Pneumonia, fibrosis, COPD, TB, lung cancer proxy.\n"
"- Uses TorchXRayVision DenseNet + GradCAM overlays.\n\n"
"**Limitations**\n"
"- Cannot replace CT for detailed lung/mediastinal lesions.\n"
"- Diagnosis requires clinical + lab correlation."
),
"spine": (
"**Image Details (Spine)**\n\n"
"View type: cervical, thoracic, or lumbar.\n\n"
"Visible: vertebral bodies, intervertebral spaces, alignment.\n\n"
"**AI / Computer Vision Context**\n"
"- Region classifier: spine.\n"
"- Tasks: compression fracture, scoliosis.\n"
"- AI approach: zero-shot CLIP + CNNs for fracture lines.\n\n"
"**Limitations**\n"
"- MRI needed for discs/cord.\n"
"- Subtle wedge fractures may require CT."
),
"pelvis": (
"**Image Details (Pelvis/Hip)**\n\n"
"View type: AP pelvis/hip.\n\n"
"Visible: iliac bones, acetabulum, femoral heads.\n\n"
"**AI / Computer Vision Context**\n"
"- Region classifier: pelvis/hip.\n"
"- Tasks: hip fracture, osteoarthritis.\n"
"- AI approach: CLIP region classification + fracture CNN/transformer.\n\n"
"**Limitations**\n"
"- Subtle fractures may be missed -> CT/MRI often needed."
),
"upper_ext": (
"**Image Details (Upper Extremity)**\n\n"
"View type: arm, shoulder, elbow, wrist, hand.\n\n"
"Visible: humerus, radius, ulna, carpals, metacarpals.\n\n"
"**AI / Computer Vision Context**\n"
"- Region classifier: upper extremity.\n"
"- Tasks: fractures, implants/hardware.\n"
"- AI approach: CLIP + fracture CNNs (MURA dataset fine-tuning).\n\n"
"**Limitations**\n"
"- Soft tissues (ligaments/tendons) invisible.\n"
"- US/MRI needed for rotator cuff, ligament injuries."
),
"lower_ext": (
"**Image Details (Lower Extremity)**\n\n"
"View type: hip, femur, knee, tibia, ankle, foot.\n\n"
"Visible: phalanges, metatarsals, tarsals, tibia, fibula.\n\n"
"**AI / Computer Vision Context**\n"
"- Region classifier: lower extremity.\n"
"- Tasks: fractures, osteoarthritis, implants, deformity alignment.\n"
"- AI approach: CLIP region classification -> fracture CNN/transformer (MURA dataset).\n\n"
"**Limitations**\n"
"- Multiple views required to rule out fractures.\n"
"- Soft tissue injuries invisible on X-ray."
),
"unknown": (
"**Region Unknown**\n\n"
"The model could not confidently classify this X-ray.\n"
"Please check image quality (exposure, blur, cropping)."
),
}
| text = region_texts.get(region, region_texts["unknown"]) | |
| # Append AI task probabilities if available | |
| if tasks: | |
| text += "\n\n**Screening Results:**\n" | |
| for t in tasks: | |
| name = t.get("name") | |
| p = t.get("prob") | |
| prob_str = "N/A" if p is None else f"{p*100:.1f}%" | |
| text += f"- {name}: {prob_str}\n" | |
| return text | |
# =============================================
# AI Radiology Report Generator (fixed version)
# =============================================
import os
import google.generativeai as genai
# Load Gemini API key from Hugging Face secret
genai.configure(api_key=os.getenv("GEMINI_API_KEY"))
# Initialize Gemini model (flash = faster / cheaper)
try:
# New API style (v1)
gemini_model = genai.GenerativeModel("gemini-2.5-flash")
except Exception as e:
print("Warning: falling back to Gemini v1beta format due to:", e)
try:
# Old API style (v1beta)
gemini_model = genai.GenerativeModel("models/gemini-1.5-flash-latest")
except Exception:
# Fallback to PRO model if FLASH not available
gemini_model = genai.GenerativeModel("gemini-1.5-pro")
| def generate_medical_report(findings: dict) -> str: | |
| """ | |
| Generate a professional radiology report from AI findings using Gemini, | |
| and format the output with Markdown and proper newlines. | |
| """ | |
| prompt = f""" | |
| You are a senior radiologist. Based on the following AI findings, | |
| generate a detailed, structured radiology report in clinical style. | |
| Findings: | |
| {json.dumps(findings, indent=2)} | |
| Follow this format: | |
| PATIENT INFORMATION | |
| TECHNICAL DETAILS | |
| RADIOLOGICAL FINDINGS | |
| IMPRESSION | |
| RECOMMENDATIONS | |
| FOLLOW-UP | |
| """ | |
| try: | |
| response = gemini_model.generate_content(prompt) | |
# Extract text from Gemini response safely
| if hasattr(response, "text"): | |
| report_text = response.text.strip() | |
| elif hasattr(response, "candidates") and response.candidates: | |
| report_text = response.candidates[0].content.parts[0].text.strip() | |
| else: | |
| return "AI medical report could not be generated." | |
# Format report into readable HTML (newlines + Markdown)
| formatted_report = format_ai_report(report_text) | |
| return formatted_report | |
| except Exception as e: | |
| import traceback | |
| print("Gemini AI report generation failed:", e) | |
| traceback.print_exc() | |
| return "AI medical report could not be generated." | |
| import re | |
| import markdown | |
| def format_ai_report(text: str) -> str: | |
| """ | |
| Formats Gemini AI report text into readable HTML with proper newlines and Markdown styling. | |
| """ | |
| if not text: | |
| return "<p><i>No report generated.</i></p>" | |
| text = re.sub(r'\n{2,}', '\n\n', text.strip()) # normalize spacing | |
| html_text = markdown.markdown(text) # convert Markdown to HTML | |
| html_text = html_text.replace('\n', '<br>') # keep single line breaks | |
| styled_html = f""" | |
| <div style="font-family:'Segoe UI',sans-serif;line-height:1.6;color:#e0e0e0;"> | |
| {html_text} | |
| </div> | |
| """ | |
| return styled_html | |
# ---------- Saliency -> boxes & overlay ----------
| def _colorize_heatmap(gray_u8: np.ndarray, sal: np.ndarray) -> np.ndarray: | |
| sal = (sal - sal.min()) / (sal.max() - sal.min() + 1e-6) | |
| base = cv2.cvtColor(gray_u8, cv2.COLOR_GRAY2BGR) | |
| heat = cv2.applyColorMap((sal*255).astype(np.uint8), cv2.COLORMAP_JET) | |
| return cv2.addWeighted(base, 0.5, heat, 0.5, 0) | |
| def _topk_boxes_from_saliency(sal: np.ndarray, k: int = 3, | |
| min_area_ratio: float = 0.001, | |
| q: float = 0.92): | |
| """ | |
| Returns list of boxes [{'x':..,'y':..,'w':..,'h':..,'score':..}] from saliency (0..1). | |
| q is the quantile threshold; increase for tighter boxes. | |
| """ | |
| sal = (sal - sal.min()) / (sal.max() - sal.min() + 1e-6) | |
| H, W = sal.shape | |
| thr = float(np.quantile(sal, q)) | |
| mask = (sal >= thr).astype(np.uint8) * 255 | |
| mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, np.ones((3,3), np.uint8), iterations=1) | |
| cnts, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) | |
| areas = [(cv2.contourArea(c), c) for c in cnts] | |
| areas.sort(key=lambda x: x[0], reverse=True) | |
| min_area = H*W*min_area_ratio | |
| boxes = [] | |
| for area, c in areas: | |
| if area < min_area: | |
| continue | |
| x, y, w, h = cv2.boundingRect(c) | |
| roi = sal[y:y+h, x:x+w] | |
| score = float(roi.mean()) | |
| boxes.append({'x': int(x), 'y': int(y), 'w': int(w), 'h': int(h), 'score': score}) | |
| if len(boxes) >= k: | |
| break | |
| return boxes | |
| def _overlay_with_boxes(gray_u8: np.ndarray, sal: np.ndarray, label: str = "Suspect") -> Tuple[np.ndarray, list]: | |
| """Color heatmap + draw yellow rectangles with labels.""" | |
| out = _colorize_heatmap(gray_u8, sal) | |
| boxes = _topk_boxes_from_saliency(sal, k=3, min_area_ratio=0.001, q=0.92) | |
| for i, b in enumerate(boxes, start=1): | |
| x, y, w, h = b['x'], b['y'], b['w'], b['h'] | |
| cv2.rectangle(out, (x, y), (x+w, y+h), (0, 255, 255), 2) | |
| text = f"{label} #{i}" | |
| cv2.putText(out, text, (x, max(0, y-6)), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0,255,255), 1, cv2.LINE_AA) | |
| return out, boxes | |
| # ----------------------------- | |
# PDF report (FPDF) - Unicode-safe + robust PNG embed on Windows
| # ----------------------------- | |
| def pdf_report_bytes(gray_u8: np.ndarray, res: Dict[str, Any], | |
| include_overlay: bool = True, | |
| findings_png: Optional[np.ndarray] = None) -> bytes: | |
| if not HAVE_PDF: | |
| raise RuntimeError("FPDF not installed. Run: pip install fpdf") | |
| def _safe(s): | |
| if s is None: return "" | |
| s = str(s) | |
| s = (s.replace("β", "-").replace("β", "-").replace("\u2011", "-") | |
| .replace("β’", "*").replace("Β±", "+/-") | |
| .replace("β", '"').replace("β", '"').replace("β", "'").replace("β¦", "...")) | |
| s = unicodedata.normalize("NFKD", s).encode("ascii", "ignore").decode("ascii") | |
| return s | |
| pdf = FPDF(format="A4", unit="mm") | |
| pdf.set_auto_page_break(auto=True, margin=12) | |
| title = "AI Radiology Report β Research Use Only" | |
| region = res.get("region", "unknown") | |
| conf = res.get("region_conf", 0.0) | |
| meta = res.get("meta_pred", {}) | |
| timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M") | |
| # ---------------- Page 1 ---------------- | |
| pdf.add_page() | |
| pdf.set_font("Helvetica", "B", 16) | |
| pdf.cell(0, 10, _safe(title), ln=1) | |
| pdf.set_font("Helvetica", "I", 10) | |
| pdf.cell(0, 6, _safe(f"Generated on: {timestamp}"), ln=1) | |
| pdf.cell(0, 6, _safe("Disclaimer: This is an AI-generated report, research use only, not for diagnostic purposes."), ln=1) | |
| pdf.ln(4) | |
| # Patient / Meta Info | |
| pdf.set_font("Helvetica", "B", 12) | |
| pdf.cell(0, 8, "Patient / Meta Information", ln=1) | |
| pdf.set_font("Helvetica", "", 11) | |
| pdf.cell(0, 7, _safe(f"Age (approx): {meta.get('age',{}).get('label','N/A')}"), ln=1) | |
| pdf.cell(0, 7, _safe(f"Sex: {meta.get('sex',{}).get('label','N/A')}"), ln=1) | |
| disp_region = REGION_DISPLAY.get(region, region) | |
| pdf.cell(0, 7, _safe(f"Region (auto-detected): {disp_region} | Confidence: {conf:.2f}"), ln=1) | |
| pdf.ln(3) | |
| # Clinical Indication | |
| pdf.set_font("Helvetica", "B", 12) | |
| pdf.cell(0, 8, "Clinical Indication", ln=1) | |
| pdf.set_font("Helvetica", "", 11) | |
| pdf.multi_cell(0, 6, _safe(f"AI-assisted screening of full-body X-ray. Region auto-detected as {disp_region}.")) | |
| pdf.ln(3) | |
| # Technique | |
| pdf.set_font("Helvetica", "B", 12) | |
| pdf.cell(0, 8, "Technique", ln=1) | |
| pdf.set_font("Helvetica", "", 11) | |
| pdf.multi_cell(0, 6, _safe("Single-view X-ray (projection estimated by model).")) | |
| pdf.ln(3) | |
| # Findings | |
| pdf.set_font("Helvetica", "B", 12) | |
| pdf.cell(0, 8, "Findings", ln=1) | |
| pdf.set_font("Helvetica", "", 11) | |
| details = generate_extra_details(res) | |
| if details: | |
| for line in details.split("\n"): | |
| pdf.multi_cell(0, 6, _safe(line)) | |
| else: | |
| pdf.multi_cell(0, 6, "No detailed findings available.") | |
| pdf.ln(3) | |
| # Disease Screening | |
| pdf.set_font("Helvetica", "B", 12) | |
| pdf.cell(0, 8, "Disease Screening", ln=1) | |
| pdf.set_font("Helvetica", "", 11) | |
| tasks = res.get("tasks", []) | |
| if not tasks: | |
| pdf.multi_cell(0, 6, "No tasks configured for this region or models unavailable.") | |
| else: | |
| for t in tasks: | |
| name = t.get("name", "?") | |
| p = t.get("prob", None) | |
| desc = t.get("desc", "") | |
| risk = risk_level(p) | |
| p_str = "N/A" if p is None else f"{p*100:.1f}%" | |
| pdf.multi_cell(0, 6, _safe(f"{name}: {p_str} | Risk: {risk}")) | |
| if desc: | |
| pdf.set_text_color(120,120,120) | |
| pdf.multi_cell(0, 6, _safe(f" - {desc}")) | |
| pdf.set_text_color(0,0,0) | |
| pdf.ln(1) | |
| pdf.ln(3) | |
| # Impression / Conclusion | |
| pdf.set_font("Helvetica", "B", 12) | |
| pdf.cell(0, 8, "Impression / Conclusion", ln=1) | |
| pdf.set_font("Helvetica", "", 11) | |
| high_risk = [t for t in tasks if risk_level(t.get("prob")) == "High"] | |
| if high_risk: | |
| pdf.multi_cell(0, 6, "Abnormal findings detected β recommend further evaluation / correlation with clinical context.") | |
| else: | |
| pdf.multi_cell(0, 6, "No high-risk abnormalities detected by AI. Findings consistent with low or moderate risk only.") | |
| pdf.ln(4) | |
| # ---------------- Page 2: Original Image ---------------- | |
| fd1, p1 = tempfile.mkstemp(suffix=".png"); os.close(fd1) | |
| try: | |
| Image.fromarray(gray_u8).convert("RGB").save(p1, format="PNG") | |
| pdf.add_page() | |
| pdf.set_font("Helvetica", "B", 12) | |
| pdf.cell(0, 8, "Original X-ray Image", ln=1) | |
| pdf.image(p1, x=10, y=None, w=180) | |
| finally: | |
| try: os.remove(p1) | |
| except Exception: pass | |
| # ---------------- Page 3: Overlay ---------------- | |
| if include_overlay: | |
| fd2, p2 = tempfile.mkstemp(suffix=".png"); os.close(fd2) | |
| try: | |
| if findings_png is not None: | |
| Image.fromarray(findings_png).convert("RGB").save(p2, format="PNG") | |
| title2 = "AI Heatmap with Suspect Boxes" | |
| else: | |
| over = edges_overlay(gray_u8) | |
| Image.fromarray(over).convert("RGB").save(p2, format="PNG") | |
| title2 = "Edges Overlay" | |
| pdf.add_page() | |
| pdf.set_font("Helvetica", "B", 12) | |
| pdf.cell(0, 8, _safe(title2), ln=1) | |
| pdf.image(p2, x=10, y=None, w=180) | |
| finally: | |
| try: os.remove(p2) | |
| except Exception: pass | |
| return pdf.output(dest="S").encode("latin-1") | |
| # ----------------------------- | |
| # Meta: Age / Sex / View (zero-shot via CLIP/BiomedCLIP) | |
| # ----------------------------- | |
| def _region_word(region: str) -> str: | |
| return { | |
| "chest": "Chest X-ray", | |
| "skull": "Skull X-ray", | |
| "spine": "Spine X-ray", | |
| "pelvis": "Pelvis/hip X-ray", | |
| "upper_ext": "Upper extremity X-ray", | |
| "lower_ext": "Lower extremity X-ray", | |
| }.get(region, "X-ray image") | |
| def predict_view(pil_rgb: Image.Image, region: str) -> dict: | |
| if region != "chest": | |
| return {"label": "N/A", "probs": {}} | |
| texts = ["PA chest X-ray", "AP chest X-ray", "Lateral chest X-ray"] | |
| p = clip_probs(pil_rgb, texts, temp=CLIP_TEMP) | |
| idx = int(np.argmax(p)) | |
| return {"label": ["PA","AP","Lateral"][idx], | |
| "probs": {"PA": float(p[0]), "AP": float(p[1]), "Lateral": float(p[2])}} | |
| def predict_sex(pil_rgb: Image.Image, region: str) -> dict: | |
| base = _region_word(region) | |
| texts = [f"{base} of a male patient", f"{base} of a female patient"] | |
| p = clip_probs(pil_rgb, texts, temp=CLIP_TEMP) | |
| idx = int(np.argmax(p)) | |
| return {"label": ["Male","Female"][idx], | |
| "probs": {"Male": float(p[0]), "Female": float(p[1])}} | |
| def predict_age(pil_rgb: Image.Image, region: str) -> dict: | |
| base = _region_word(region) | |
| labels = ["Child (<=14)", "Adolescent (15-19)", "Adult (20-64)", "Elderly (65+)"] | |
| prompts = [f"{base} of a child", f"{base} of an adolescent", f"{base} of an adult", f"{base} of an elderly person"] | |
| anchors = np.array([8.0, 17.0, 40.0, 75.0], dtype=np.float32) | |
| p = clip_probs(pil_rgb, prompts, temp=CLIP_TEMP) | |
| est = float(np.dot(p, anchors)) | |
| return {"label": f"{est:.1f} yrs", "years": est, "probs": {labels[i]: float(p[i]) for i in range(len(labels))}} | |
| def predict_meta(pil_rgb: Image.Image, region: str) -> dict: | |
| if CLIP_MDL is None or CLIP_PROC is None: | |
| return {"age": {"label": "N/A", "years": None, "probs": {}}, | |
| "sex": {"label": "N/A", "probs": {}}, | |
| "view": {"label": "N/A", "probs": {}}} | |
| return {"age": predict_age(pil_rgb, region), "sex": predict_sex(pil_rgb, region), "view": predict_view(pil_rgb, region)} | |
| # ----------------------------- | |
| # CLIP/BiomedCLIP zero-shot | |
| # ----------------------------- | |
| def load_clip_like(): | |
| if not HAVE_TRF: return None, None | |
| for rid in CLIP_CANDIDATES: | |
| try: | |
| proc = AutoProcessor.from_pretrained(rid, token=HF_TOKEN) | |
| mdl = CLIPModel.from_pretrained(rid, token=HF_TOKEN).eval().to(DEVICE) | |
| return mdl, proc | |
| except Exception: | |
| continue | |
| return None, None | |
| CLIP_MDL, CLIP_PROC = load_clip_like() | |
| def clip_probs(pil_rgb: Image.Image, texts: List[str], temp: float = 1.0) -> np.ndarray: | |
| if CLIP_MDL is None or CLIP_PROC is None: | |
| return np.full((len(texts),), 1.0/len(texts), dtype=np.float32) | |
| inputs = CLIP_PROC(text=texts, images=pil_rgb, return_tensors="pt", padding=True).to(DEVICE) | |
| with torch.no_grad(): | |
| out = CLIP_MDL(**inputs) | |
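| # Divide the image-text logits by a temperature before softmax: temp < 1 sharpens, temp > 1 flattens the distribution over prompts. | |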
| logits = out.logits_per_image.squeeze(0) / max(1e-6, temp) | |
| probs = torch.softmax(logits, dim=0).detach().cpu().numpy() | |
| return probs.astype(np.float32) | |
| def clip_binary_prob(pil_rgb: Image.Image, pos_text: str, neg_text: str, temp: float = 1.0) -> float: | |
| probs = clip_probs(pil_rgb, [pos_text, neg_text], temp=temp) | |
| return float(probs[0]) | |
| # ----------------------------- | |
| # Chest ensemble (TorchXRayVision) | |
| # ----------------------------- | |
| def load_xrv_models(): | |
| if not HAVE_XRV: | |
| return None, None | |
| import torch | |
| import torchxrayvision as xrv | |
| # Fix for PyTorch 2.6+ safe-loading policy | |
| # Allow DenseNet class from torchxrayvision to be unpickled safely | |
| try: | |
| torch.serialization.add_safe_globals([xrv.models.DenseNet]) | |
| except Exception as e: | |
| print("Safe global registration failed:", e) | |
| def _mk(weights_key): | |
| try: | |
| m = xrv.models.DenseNet(weights=weights_key) | |
| return m.eval().to(DEVICE) | |
| except Exception as e: | |
| print(f"β οΈ Error loading model {weights_key}:", e) | |
| return None | |
| # Load the CheX model (you can add more if needed) | |
| models = [_mk("densenet121-res224-chex")] | |
| # Build a union of pathology labels across the loaded models (lowercase key -> original display label) | |
| union = {} | |
| for m in models: | |
| if m is None: | |
| continue | |
| for lbl in m.pathologies: | |
| k = lbl.lower() | |
| if k not in union: | |
| union[k] = lbl | |
| return models, union | |
| # Initialize global variables | |
| XRV_MODELS, LOWER_TO_PRETTY = load_xrv_models() | |
| def xrv_logits_dict(gray_u8: np.ndarray, mdl) -> dict: | |
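| # torchxrayvision DenseNets expect a single-channel image scaled to roughly [-1024, 1024]; map 8-bit [0, 255] into that range, then resize to 224x224. | |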
| img01 = gray_u8.astype(np.float32) / 255.0 | |
| img_hu = img01 * 2048.0 - 1024.0 | |
| t = torch.from_numpy(img_hu).unsqueeze(0).unsqueeze(0).to(DEVICE) | |
| t = torch.nn.functional.interpolate(t, size=224, mode="bilinear", align_corners=False) | |
| with torch.no_grad(): | |
| logits = mdl(t).squeeze(0).detach().cpu().numpy() | |
| return {lbl.lower(): float(logits[i]) for i, lbl in enumerate(mdl.pathologies)} | |
| def chest_disease_probs(gray_u8: np.ndarray): | |
| if not HAVE_XRV or XRV_MODELS is None: | |
| return {}, [] | |
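| # Simple test-time augmentation: run every model on the original and a horizontally flipped copy, then average the logits per label. | |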
| imgs = [gray_u8, np.ascontiguousarray(np.flip(gray_u8, axis=1))] | |
| sum_logits, count = {}, {} | |
| for im in imgs: | |
| for mdl in XRV_MODELS: | |
| d = xrv_logits_dict(im, mdl) | |
| for lbl, v in d.items(): | |
| sum_logits[lbl] = sum_logits.get(lbl, 0.0) + v | |
| count[lbl] = count.get(lbl, 0) + 1 | |
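| # Mean logit per label, temperature-scaled by CHEST_TEMP, then a sigmoid gives a per-pathology probability. | |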
| mean_logits = {k: (sum_logits[k]/max(1,count[k]))/max(1e-6, CHEST_TEMP) for k in sum_logits} | |
| label_probs = {k: 1.0/(1.0+np.exp(-v)) for k,v in mean_logits.items()} | |
| requested = {} | |
| for disp, aliases in CHEST_TARGETS.items(): | |
| if not aliases: | |
| requested[disp] = None | |
| continue | |
| p=None | |
| for a in aliases: | |
| p = label_probs.get(a.lower()) | |
| if p is not None: | |
| break | |
| requested[disp] = p | |
| ranked = sorted(label_probs.items(), key=lambda x: x[1], reverse=True)[:10] | |
| ranked_pretty = [(LOWER_TO_PRETTY.get(lbl, lbl.title()), p) for lbl,p in ranked] | |
| return requested, ranked_pretty | |
| # Optional HF binary plugin (e.g., TB) | |
| def load_hf_binary(repo_id: str, token: Optional[str] = None): | |
| if not (HAVE_TRF and repo_id): return None, None, None | |
| proc = AutoImageProcessor.from_pretrained(repo_id, token=token) | |
| mdl = AutoModelForImageClassification.from_pretrained(repo_id, token=token).eval().to(DEVICE) | |
| id2 = getattr(mdl.config, "id2label", None) | |
| return mdl, proc, id2 | |
| def run_hf_binary(mdl, proc, pil_rgb: Image.Image, target_label_name: str) -> Optional[float]: | |
| if mdl is None or proc is None: return None | |
| inputs = proc(images=pil_rgb, return_tensors="pt").to(DEVICE) | |
| with torch.no_grad(): | |
| out = mdl(**inputs) | |
| logits = out.logits.squeeze(0) | |
| if logits.numel() == 1: | |
| return float(torch.sigmoid(logits).item()) | |
| id2 = getattr(mdl.config, "id2label", None) or {} | |
| name2idx = {v.lower(): int(k) for k,v in id2.items()} | |
| idx = name2idx.get(target_label_name.lower()) | |
| if idx is None or idx >= logits.numel(): return None | |
| return float(torch.sigmoid(logits[idx]).item()) | |
| # ----------------------------- | |
| # Saliency for chest (Grad-CAM) and non-chest (CLIP occlusion) | |
| # ----------------------------- | |
| def gradcam_single_xrv(gray_u8: np.ndarray, mdl, class_idx: int, out_size: Tuple[int,int]) -> Optional[np.ndarray]: | |
| """ | |
| Grad-CAM for a single TorchXRayVision DenseNet model & class index. | |
| Returns a saliency map in [0, 1] resized to out_size (height, width), or None on failure. | |
| """ | |
| try: | |
| mdl.eval() | |
| feats = {} | |
| grads = {} | |
| def fwd_hook(module, inp, out): | |
| feats["x"] = out.detach() | |
| def bwd_hook(module, grad_in, grad_out): | |
| grads["x"] = grad_out[0].detach() | |
| h1 = mdl.features.register_forward_hook(fwd_hook) | |
| h2 = mdl.features.register_full_backward_hook(bwd_hook) | |
| img01 = gray_u8.astype(np.float32) / 255.0 | |
| img_hu = img01 * 2048.0 - 1024.0 | |
| t = torch.from_numpy(img_hu).unsqueeze(0).unsqueeze(0).to(DEVICE) | |
| t = torch.nn.functional.interpolate(t, size=224, mode="bilinear", align_corners=False) | |
| t.requires_grad_(True) | |
| mdl.zero_grad(set_to_none=True) | |
| out = mdl(t) # logits | |
| if class_idx >= out.shape[1]: | |
| h1.remove(); h2.remove(); return None | |
| target = out[0, class_idx] | |
| target.backward(retain_graph=True) | |
| fmap = feats["x"] # (1, C, H, W) | |
| grad = grads["x"] # (1, C, H, W) | |
| weights = grad.mean(dim=(2,3), keepdim=True) # (1, C, 1, 1) | |
| cam = torch.relu((weights * fmap).sum(dim=1)).squeeze(0) # (H, W) | |
| cam = cam.detach().cpu().numpy() | |
| cam = cv2.resize(cam, out_size[::-1], interpolation=cv2.INTER_LINEAR) | |
| cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-6) | |
| h1.remove(); h2.remove() | |
| return cam | |
| except Exception: | |
| try: | |
| h1.remove(); h2.remove() | |
| except Exception: | |
| pass | |
| return None | |
| def chest_findings_heatmap(gray_u8: np.ndarray, res: Dict[str,Any]) -> Optional[Tuple[np.ndarray, str]]: | |
| """Return (saliency_map_0..1, target_label_name) for chest via ensemble Grad-CAM.""" | |
| if not HAVE_XRV or XRV_MODELS is None: | |
| return None | |
| chest = res.get("chest", {}) | |
| candidates = [(k, v) for k, v in chest.items() if v is not None] | |
| if not candidates: | |
| return None | |
| target_name, _ = max(candidates, key=lambda kv: kv[1]) | |
| lbl_lowers = [a.lower() for a in CHEST_TARGETS.get(target_name, [])] | |
| if not lbl_lowers: | |
| return None | |
| h, w = gray_u8.shape | |
| cams = [] | |
| for mdl in XRV_MODELS: | |
| idx = None | |
| for a in lbl_lowers: | |
| try: | |
| idx = [s.lower() for s in mdl.pathologies].index(a); break | |
| except ValueError: | |
| continue | |
| if idx is None: | |
| continue | |
| cam = gradcam_single_xrv(gray_u8, mdl, idx, (h, w)) | |
| if cam is not None: | |
| cams.append(cam) | |
| if not cams: | |
| return None | |
| sal = np.mean(cams, axis=0) | |
| sal = (sal - sal.min()) / (sal.max() - sal.min() + 1e-6) | |
| return sal, target_name | |
| def best_nonchest_task_and_prompts(res: Dict[str,Any]) -> Optional[Tuple[str, Tuple[str,str]]]: | |
| region = res.get("region", "unknown") | |
| if region not in TASKS: | |
| return None | |
| # map name -> (pos, neg) | |
| name2prompts = {} | |
| for t in TASKS[region]: | |
| if "zs" in t: | |
| name2prompts[t["name"]] = t["zs"] | |
| # pick highest-prob configured task | |
| best_name = None | |
| best_p = -1.0 | |
| for t in res.get("tasks", []): | |
| if t.get("region") != region: | |
| continue | |
| nm = t.get("name") | |
| if nm in name2prompts and t.get("prob") is not None: | |
| if float(t["prob"]) > best_p: | |
| best_p = float(t["prob"]); best_name = nm | |
| if best_name is None: | |
| # fallback to first available task config | |
| if TASKS[region]: | |
| cand = TASKS[region][0] | |
| if "zs" in cand: | |
| return cand["name"], cand["zs"] | |
| return None | |
| return best_name, name2prompts[best_name] | |
| def clip_occlusion_saliency(pil_rgb: Image.Image, pos_text: str, neg_text: str, n_segments: int = 140) -> Optional[np.ndarray]: | |
| """Return saliency map (0..1) via SLIC-occlusion for CLIP zero-shot.""" | |
| if not (HAVE_SKIMG and CLIP_MDL is not None): | |
| return None | |
| rgb = np.array(pil_rgb.convert("RGB")) | |
| base_p = clip_binary_prob(pil_rgb, pos_text, neg_text, temp=CLIP_TEMP) | |
| h, w, _ = rgb.shape | |
| seg = slic(img_as_float(rgb), n_segments=n_segments, compactness=10, start_label=0) | |
| blurred = cv2.GaussianBlur(rgb, (21,21), 0) | |
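| # For each SLIC superpixel, swap in its blurred version and measure how much the positive-class probability drops; larger drops mark more important regions. | |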
| contrib = np.zeros((h, w), dtype=np.float32) | |
| for lab in np.unique(seg): | |
| mask = (seg == lab) | |
| occluded = rgb.copy() | |
| occluded[mask] = blurred[mask] | |
| p = clip_binary_prob(Image.fromarray(occluded), pos_text, neg_text, temp=CLIP_TEMP) | |
| delta = max(0.0, base_p - p) # drop in prob -> importance | |
| contrib[mask] = delta | |
| sal = (contrib - contrib.min()) / (contrib.max() - contrib.min() + 1e-6) | |
| return sal | |
| def findings_overlay(gray_u8: np.ndarray, res: Dict[str,Any]): | |
| """ | |
| Returns (overlay_bgr, meta_dict) | |
| meta_dict = {"label": <task name>, "boxes": [ {x,y,w,h,score}, ... ]} | |
| """ | |
| region = res.get("region", "unknown") | |
| pil = Image.fromarray(gray_u8).convert("RGB") | |
| if region == "chest": | |
| out = chest_findings_heatmap(gray_u8, res) | |
| if out is None: | |
| return None, None | |
| sal, label = out | |
| overlay, boxes = _overlay_with_boxes(gray_u8, sal, label=label) | |
| return overlay, {"label": label, "boxes": boxes} | |
| else: | |
| pair = best_nonchest_task_and_prompts(res) | |
| if pair is None: | |
| return None, None | |
| name, (pos, neg) = pair | |
| sal = clip_occlusion_saliency(pil, pos, neg) | |
| if sal is None: | |
| return None, None | |
| overlay, boxes = _overlay_with_boxes(gray_u8, sal, label=name) | |
| return overlay, {"label": name, "boxes": boxes} | |
| # ----------------------------- | |
| # Gating helpers | |
| # ----------------------------- | |
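| # abstain_flag: abstain when the probability sits within ABSTAIN_MARGIN of 0.5 (too uncertain to call either way). | |
| # triage_flag: flag for triage when the probability meets or exceeds the per-task threshold. | |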
| def abstain_flag(p: Optional[float], margin: float = ABSTAIN_MARGIN) -> bool: | |
| if p is None: return True | |
| return abs(p - 0.5) < margin | |
| def triage_flag(p: Optional[float], threshold: float) -> bool: | |
| if p is None: return False | |
| return p >= threshold | |
| # ---- ADDED: Quality metrics & gates, calibration, audit log ---- | |
| def quality_metrics(gray_u8): | |
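| # blur_var: variance of the Laplacian (low values suggest a blurry image); under/over: fraction of near-black / near-white pixels. | |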
| H, W = gray_u8.shape[:2] | |
| try: | |
| blur_var = float(cv2.Laplacian(gray_u8, cv2.CV_64F).var()) | |
| except Exception: | |
| blur_var = float('nan') | |
| under = float((gray_u8 < 10).mean()) | |
| over = float((gray_u8 > 245).mean()) | |
| return {"min_side": min(H, W), "blur_var": blur_var, "under_pct": under, "over_pct": over} | |
| def gate_decision(metrics, region_conf, | |
| min_side=256, blur_var_min=80.0, | |
| under_max=0.25, over_max=0.25, | |
| region_conf_min=REGION_CONF_MIN): | |
| ok, reasons = True, [] | |
| if metrics["min_side"] < min_side: ok=False; reasons.append("image too small") | |
| if not np.isnan(metrics["blur_var"]) and metrics["blur_var"] < blur_var_min: ok=False; reasons.append("too blurry") | |
| if metrics["under_pct"] > under_max: ok=False; reasons.append("under-exposed") | |
| if metrics["over_pct"] > over_max: ok=False; reasons.append("over-exposed") | |
| if region_conf < region_conf_min: ok=False; reasons.append("low region confidence (OOD)") | |
| return ok, reasons | |
| def calibrate_prob(p, tau=CAL_TAU): | |
| if p is None: return None | |
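| # Temperature scaling in logit space: p_cal = sigmoid(logit(p) / tau); tau > 1 softens over-confident probabilities toward 0.5. | |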
| p = float(np.clip(p, 1e-6, 1-1e-6)) | |
| logit = np.log(p/(1-p)) | |
| return 1.0/(1.0 + np.exp(-logit/max(1e-6, tau))) | |
| LOG_PATH = pathlib.Path("logs"); LOG_PATH.mkdir(exist_ok=True) | |
| def audit_log(event, payload, raw_bytes=None): | |
| rec = { | |
| "ts": datetime.datetime.utcnow().isoformat()+"Z", | |
| "event": event, | |
| "payload": payload, | |
| "input_sha256": hashlib.sha256(raw_bytes).hexdigest() if raw_bytes else None | |
| } | |
| with open(LOG_PATH / f"audit-{datetime.date.today().isoformat()}.jsonl", "a", encoding="utf-8") as f: | |
| f.write(json.dumps(rec) + "\n") | |
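| # Example (hypothetical call): audit_log("analyze", {"region": "chest"}, raw_bytes=data) | |
| # appends one JSON line per event to logs/audit-YYYY-MM-DD.jsonl, including a SHA-256 of the uploaded bytes. | |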
| # ----------------------------- | |
| # Shared inference | |
| # ----------------------------- | |
| def analyze_fullbody(gray_u8: np.ndarray) -> Dict[str, Any]: | |
| result: Dict[str, Any] = { | |
| "meta": {"version": 1, "device": DEVICE}, | |
| "region": None, "region_conf": None, "region_source": None, | |
| "tasks": [], "chest": {} | |
| } | |
| pil_rgb = Image.fromarray(gray_u8).convert("RGB") | |
| # Region router (CLIP) | |
| texts = [REGION_PROMPTS[r][0] for r in REGION_NAMES] | |
| probs = clip_probs(pil_rgb, texts, temp=CLIP_TEMP) | |
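| # The region with the highest zero-shot similarity wins; its probability is kept as a confidence score and low values downgrade the region to "unknown" below. | |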
| idx = int(np.argmax(probs)) | |
| region = REGION_NAMES[idx]; conf = float(probs[idx]) | |
| result.update({"region": region, "region_conf": conf, "region_source": "zero-shot"}) | |
| # Meta (Age/Sex/View) using region-aware prompts | |
| result["meta_pred"] = predict_meta(pil_rgb, region) | |
| # Low-confidence region -> mark as unknown | |
| if conf < REGION_CONF_MIN: | |
| result["region"] = "unknown" | |
| # Region-specific screening | |
| if region == "chest" and HAVE_XRV and XRV_MODELS is not None: | |
| chest_probs, _ranked = chest_disease_probs(gray_u8) | |
| if TB_REPO and HAVE_TRF: | |
| tb_m, tb_p, _ = load_hf_binary(TB_REPO, HF_TOKEN) | |
| tb_prob = run_hf_binary(tb_m, tb_p, pil_rgb, TB_LABEL) | |
| if tb_prob is not None: | |
| chest_probs["Tuberculosis"] = tb_prob | |
| result["chest"] = chest_probs | |
| for name, p in chest_probs.items(): | |
| thr = DEFAULT_THRESHOLDS.get(name, 0.5) | |
| result["tasks"].append({ | |
| "region": "chest", "name": name, "prob": p, | |
| "threshold": thr, "abstain": abstain_flag(p), | |
| "triage": triage_flag(p, thr), | |
| "method": "xrv-ensemble" if name != TB_LABEL else "hf-binary", | |
| "desc": CHEST_DESCRIPTIONS.get(name, ""), | |
| }) | |
| else: | |
| for t in TASKS.get(region, []): | |
| name = t["name"]; thr = DEFAULT_THRESHOLDS.get(name, 0.6) | |
| prob = None; method = None | |
| if "hf_repo" in t and t.get("hf_repo") and HAVE_TRF: | |
| mdl, proc, _ = load_hf_binary(t["hf_repo"], HF_TOKEN) | |
| label = t.get("label", "positive") | |
| prob = run_hf_binary(mdl, proc, pil_rgb, label) | |
| method = f"hf:{t['hf_repo']}" | |
| elif "zs" in t: | |
| pos, neg = t["zs"] | |
| prob = clip_binary_prob(pil_rgb, pos, neg, temp=CLIP_TEMP) | |
| method = "zero-shot-clip" | |
| result["tasks"].append({ | |
| "region": region, "name": name, "prob": prob, | |
| "threshold": thr, "abstain": abstain_flag(prob), | |
| "triage": triage_flag(prob, thr), "method": method, | |
| "desc": TASK_DESCRIPTIONS.get(name, ""), | |
| }) | |
| return result | |