# app_fullbody_pretrained.py - Full-Body X-ray Analysis (Pretrained) + PDF + Suspect Boxes
# -----------------------------------------------------------------------------
# Run: python -m uvicorn app_fullbody_pretrained:api --host 0.0.0.0 --port 7860
# -----------------------------------------------------------------------------
import os
import subprocess
import sys


def _pin_gemini_dependencies() -> None:
    """Reinstall the Gemini client stack before anything imports it.

    This runs at import time, exactly as the original ``os.system`` calls did,
    and must stay ahead of the third-party imports below. Failures are ignored
    (best effort). Set ``SKIP_RUNTIME_PIP=1`` to skip it, e.g. in CI or when
    the environment is already pinned.
    """
    if os.getenv("SKIP_RUNTIME_PIP", "").strip().lower() in {"1", "true", "yes"}:
        return
    # Argument lists (no shell) against the interpreter that is actually running.
    pip = [sys.executable, "-m", "pip"]
    subprocess.run(
        pip + ["uninstall", "-y", "google-generativeai", "google-api-core",
               "googleapis-common-protos", "grpcio"],
        check=False,
    )
    subprocess.run(
        pip + ["install", "--upgrade", "google-generativeai==0.8.5",
               "google-api-core", "protobuf", "grpcio", "--quiet"],
        check=False,
    )


_pin_gemini_dependencies()

import base64
import datetime
import hashlib
import io
import json
import pathlib
import re
import tempfile
import unicodedata
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple
from xml.sax.saxutils import escape as xml_escape

import cv2
import imageio
import imageio.v3 as iio
import matplotlib

# Frames are read back from the canvas pixel buffer, which needs the Agg canvas.
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import pyvista as pv
import torch
import torch.nn.functional as F
import uvicorn
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse, Response
from fastapi.staticfiles import StaticFiles
from monai.networks.nets import UNet
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
from PIL import Image
from pydantic import BaseModel
from pyvista import start_xvfb
from reportlab.lib import colors
from reportlab.lib.pagesizes import A4
from reportlab.lib.styles import ParagraphStyle, getSampleStyleSheet
from reportlab.lib.units import mm
from reportlab.platypus import (
    Image as RLImage,
    Paragraph,
    SimpleDocTemplate,
    Spacer,
    Table,
    TableStyle,
)
from transformers import DPTFeatureExtractor, DPTForDepthEstimation

# Optional dependencies: each flag records whether the import succeeded.
try:
    import pydicom
    HAVE_DICOM = True
except Exception:
    HAVE_DICOM = False
try:
    import torchxrayvision as xrv
    HAVE_XRV = True
except Exception:
    HAVE_XRV = False
try:
    from transformers import (AutoProcessor, CLIPModel, AutoImageProcessor,
                              AutoModelForImageClassification)
    HAVE_TRF = True
except Exception:
    HAVE_TRF = False
try:
    from skimage.segmentation import slic
    from skimage.util import img_as_float
    HAVE_SKIMG = True
except Exception:
    HAVE_SKIMG = False
try:
    from fpdf import FPDF
    HAVE_PDF = True
except Exception:
    HAVE_PDF = False

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Hugging Face Space base URL, used to build absolute links to generated media.
HF_SPACE_URL = os.getenv("HF_SPACE_URL", "https://abbhy123ghh-x-ray-analysis-2.hf.space").rstrip("/")

# Create the app and serve generated files from ./static.
api = FastAPI()
os.makedirs("static", exist_ok=True)
api.mount("/static", StaticFiles(directory="static"), name="static")


# -----------------------------
# Root
# -----------------------------
@api.get("/")
async def root():
    """Liveness probe."""
    return {"ok": True, "message": "X-ray API is running"}


def to_base64(img: np.ndarray) -> str:
    """Encode an image array as a ``data:image/png;base64,...`` URI (RGB PNG)."""
    pil = Image.fromarray(img).convert("RGB")
    buf = io.BytesIO()
    pil.save(buf, format="PNG")
    return "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()


def _load_upload_gray(filename: Optional[str], raw: bytes) -> np.ndarray:
    """Decode uploaded bytes (DICOM or ordinary image) to 8-bit grayscale.

    Delegates to ``read_any_xray`` so DICOM uploads get the same rescale,
    windowing and MONOCHROME1 handling as files read from disk. The buffer is
    given a ``name`` because ``read_any_xray`` picks the decoder from it.
    """
    buf = io.BytesIO(raw)
    buf.name = (filename or "").lower()
    return to_uint8(read_any_xray(buf))


def _png_flowable(img: np.ndarray, size: int) -> RLImage:
    """Wrap an image array as a square ReportLab image flowable of ``size`` points."""
    buf = io.BytesIO()
    Image.fromarray(img).save(buf, format="PNG")
    buf.seek(0)
    return RLImage(buf, width=size, height=size)


# ========================================================================
# FastAPI endpoints - multi-region analysis
# ========================================================================
@api.post("/analyze")
async def analyze(file: UploadFile = File(...)):
    """Analyze an uploaded X-ray and return meta + annotated overlays (JSON).

    Any failure is reported as ``{"error": ...}`` with HTTP 400.
    """
    try:
        raw = await file.read()
        gray_u8 = _load_upload_gray(file.filename, raw)

        # Run full-body multi-region analysis.
        res = analyze_fullbody(gray_u8)

        # Flatten meta_pred so the frontend can show Age/View/Gender/CTR.
        meta_pred = res.get("meta_pred", {})
        res["age"] = meta_pred.get("age", {}).get("label", "N/A")
        res["view"] = meta_pred.get("view", {}).get("label", "N/A")
        res["gender"] = meta_pred.get("sex", {}).get("label", "N/A")
        res["ctr"] = "--"  # placeholder

        res["original_img"] = to_base64(gray_u8)

        # Findings overlay, with an edge-based fallback so something is always shown.
        overlay_img, overlay_meta = findings_overlay(gray_u8, res)
        if overlay_img is None:
            print("⚠️ Overlay not generated β€” using edge-based overlay fallback.")
            overlay_img = edges_overlay(gray_u8)
            overlay_meta = {"label": "Edge-based fallback", "boxes": []}

        res["overlay_img"] = to_base64(overlay_img)
        res["overlay_meta"] = overlay_meta

        doctor_report_html = generate_medical_report(res)

        return JSONResponse({
            "ok": True,
            "meta": {
                "age": res["age"],
                "view": res["view"],
                "gender": res["gender"],
                "ctr": res["ctr"],
            },
            "original_img": res["original_img"],
            "overlay_img": res["overlay_img"],
            "overlay_meta": res["overlay_meta"],
            "tasks": res.get("tasks", []),
            "region": res.get("region", "unknown"),
            "region_conf": res.get("region_conf", 0.0),
            "meta_pred": res.get("meta_pred", {}),
            "doctor_report_html": doctor_report_html,
        })
    except Exception as e:
        print("❌ Analyze failed:", e)
        return JSONResponse({"error": str(e)}, status_code=400)


# ===========================================================
# Unified PDF Generator (Drlogy-style layout for all regions)
# ===========================================================
def generate_clinical_pdf(gray_img, overlay_img, res, doctor_report):
    """Build the clinic-style PDF report and return it as bytes.

    Args:
        gray_img: image array embedded as the first picture.
        overlay_img: optional second image array; skipped when ``None``.
        res: analysis result dict; reads ``region``, ``patient_name``, ``age``,
            ``gender`` and ``region_conf``.
        doctor_report: findings text; newlines become line breaks. It is passed
            to ReportLab as paragraph markup without escaping.
    """
    buffer = io.BytesIO()
    doc = SimpleDocTemplate(
        buffer, pagesize=A4,
        rightMargin=36, leftMargin=36, topMargin=60, bottomMargin=36,
    )

    styles = getSampleStyleSheet()
    normal = ParagraphStyle('Normal', parent=styles['Normal'], fontSize=10, leading=13)
    bold = ParagraphStyle('Bold', parent=normal, fontName='Helvetica-Bold')
    header = ParagraphStyle('Header', parent=normal, alignment=1,
                            fontName='Helvetica-Bold', fontSize=14)
    footer = ParagraphStyle('Footer', parent=normal, fontSize=8, textColor=colors.gray)

    story = []

    # --- HEADER ---
    story.append(Paragraph("DRLOGY IMAGING CENTER", header))
    story.append(Paragraph("X-Ray | CT-Scan | MRI | USG", bold))
    story.append(Paragraph("105-108, SMART VISION COMPLEX, HEALTHCARE ROAD, MUMBAI - 689578", normal))
    story.append(Paragraph("πŸ“ž 0123456789 | βœ‰οΈ drlogyimaging@drlogy.com", normal))
    story.append(Spacer(1, 10))

    # --- DYNAMIC TITLE ---
    region_title = res.get("region", "General Region").replace("_", " ").title()
    story.append(Paragraph(f"AI RADIOLOGY REPORT – {region_title.upper()}", header))
    story.append(Spacer(1, 10))

    # --- PATIENT INFO ---
    patient_table = [
        ["Name:", res.get("patient_name", "N/A"), "Age:", f"{res.get('age', 'N/A')} yrs"],
        ["Sex:", res.get("gender", "N/A"), "Study Date:",
         datetime.datetime.now().strftime("%d %b %Y, %I:%M %p")],
        ["Region:", region_title, "AI Confidence:", f"{res.get('region_conf', 0):.2f}"],
    ]
    table = Table(patient_table, colWidths=[70, 160, 90, 150])
    table.setStyle(TableStyle([
        ("BOX", (0, 0), (-1, -1), 0.5, colors.gray),
        ("INNERGRID", (0, 0), (-1, -1), 0.25, colors.gray),
        ("FONTNAME", (0, 0), (-1, -1), "Helvetica"),
        ("FONTSIZE", (0, 0), (-1, -1), 9),
        ("BACKGROUND", (0, 0), (-1, 0), colors.whitesmoke),
    ]))
    story.append(table)
    story.append(Spacer(1, 12))

    # --- TECHNICAL DETAILS ---
    story.append(Paragraph("TECHNICAL DETAILS", bold))
    story.append(Paragraph(
        "AI-assisted analysis performed using a multi-region deep learning model. "
        "Single-view radiograph analyzed for anatomical abnormalities.", normal))
    story.append(Spacer(1, 10))

    # --- FINDINGS ---
    story.append(Paragraph("AI FINDINGS", bold))
    story.append(Paragraph(
        "Below are the observed findings as per AI-based analysis of the selected body part.", normal))
    story.append(Spacer(1, 5))
    story.append(Paragraph(doctor_report.replace("\n", "<br/>"), normal))
    story.append(Spacer(1, 10))

    # --- IMPRESSION ---
    story.append(Paragraph("AI IMPRESSION", bold))
    story.append(Paragraph(
        "Based on the AI model’s interpretation, no acute abnormality was confidently detected unless otherwise stated above.", normal))
    story.append(Spacer(1, 10))

    # --- RECOMMENDATIONS ---
    story.append(Paragraph("RECOMMENDATIONS", bold))
    story.append(Paragraph(
        "Clinical correlation is advised. If clinical suspicion persists, further radiographic or advanced imaging is recommended.", normal))
    story.append(Spacer(1, 10))

    # --- IMAGES ---
    story.append(Paragraph("IMAGES:", bold))
    story.append(_png_flowable(gray_img, 200))
    if overlay_img is not None:
        story.append(_png_flowable(overlay_img, 200))
    story.append(Spacer(1, 20))

    # --- FOOTER ---
    story.append(Paragraph("**** End of Report ****", footer))
    story.append(Spacer(1, 6))
    story.append(Paragraph("Dr. AI Radiologist (MD)", footer))
    story.append(Paragraph(f"Generated on: {datetime.datetime.now():%d %b %Y, %I:%M %p}", footer))
    story.append(Paragraph("This AI report is supportive, not a substitute for radiologist diagnosis.", footer))

    doc.build(story)
    buffer.seek(0)
    return buffer.getvalue()


# ===========================================================
# Main Endpoint - Analyze & Generate PDF
# ===========================================================
def _rule_line() -> Table:
    """Return a 480-pt horizontal rule, drawn as the bottom border of an empty cell."""
    # One cell for one column width; the original declared two cells here.
    return Table([[""]], colWidths=[480],
                 style=[('LINEBELOW', (0, 0), (-1, -1), 1, colors.HexColor("#004C99"))])


def _gemini_report_text(res: Dict[str, Any]) -> str:
    """Ask the Gemini model for report prose and return it stripped of markdown.

    Any failure (model missing, call error, unserialisable data) yields a
    fixed fallback sentence instead of raising.
    """
    try:
        # default=str keeps the prompt buildable if res holds non-JSON scalars.
        prompt = f"""
        You are a senior radiologist. Generate a structured radiology report.
        Data: {json.dumps(res, indent=2, default=str)}
        Include clear prose for: Technical Details, Findings, Impression, Recommendations.
        Avoid markdown, bullets, or extra headings.
        """
        gemini_response = gemini_model.generate_content(prompt)
        doctor_report = getattr(gemini_response, "text", "AI report unavailable.").strip()
    except Exception:
        doctor_report = "AI medical report could not be generated."

    clean_report = re.sub(r"[#*_]+", "", doctor_report)
    return re.sub(r"\n{2,}", "\n", clean_report.strip())


@api.post("/analyze_pdf")
async def analyze_pdf(file: UploadFile = File(...)):
    """Generate a professional AI radiology report PDF.

    Returns the PDF as an attachment, or ``{"error": ...}`` with HTTP 400.
    """
    try:
        raw = await file.read()
        gray_u8 = _load_upload_gray(file.filename, raw)

        # --- Model analysis ---
        res = analyze_fullbody(gray_u8)
        overlay_img, _ = findings_overlay(gray_u8, res)
        if overlay_img is None:
            overlay_img = edges_overlay(gray_u8)

        clean_report = _gemini_report_text(res)

        # --- PDF template ---
        buffer = io.BytesIO()
        doc = SimpleDocTemplate(buffer, pagesize=A4,
                                leftMargin=35, rightMargin=35, topMargin=35, bottomMargin=40)
        styles = getSampleStyleSheet()
        normal = ParagraphStyle("NormalCustom", parent=styles["Normal"], fontSize=10, leading=14)
        heading = ParagraphStyle("Heading", parent=styles["Heading2"], fontSize=13,
                                 textColor=colors.HexColor("#004C99"))
        center_heading = ParagraphStyle("CenterHeading", parent=styles["Heading2"], fontSize=14,
                                        textColor=colors.white, alignment=1, spaceAfter=10)
        footer = ParagraphStyle("Footer", parent=styles["Normal"], fontSize=9,
                                textColor=colors.gray, alignment=1)

        story = []

        # === HEADER BAR ===
        story.append(Table(
            [[Paragraph("DRLOGY IMAGING CENTER", center_heading)]],
            colWidths=[500],
            style=[
                ('BACKGROUND', (0, 0), (-1, -1), colors.HexColor("#004C99")),
                ('BOTTOMPADDING', (0, 0), (-1, -1), 6),
                ('TOPPADDING', (0, 0), (-1, -1), 6),
            ],
        ))
        story.append(Paragraph("X-Ray | CT-Scan | MRI | USG", normal))
        story.append(Paragraph("105-108, SMART VISION COMPLEX, HEALTHCARE ROAD, MUMBAI - 689578", normal))
        story.append(Paragraph("πŸ“ž 0123456789 βœ‰ drlogyimaging@drlogy.com", normal))
        story.append(Spacer(1, 10))
        story.append(_rule_line())
        story.append(Spacer(1, 8))

        # === TITLE SECTION ===
        title_table = Table(
            [[Paragraph("AI RADIOLOGY REPORT", heading),
              Paragraph(datetime.datetime.now().strftime("%d %b %Y, %I:%M %p"), normal)]],
            colWidths=[340, 140],
            style=[
                ("VALIGN", (0, 0), (-1, -1), "MIDDLE"),
                ("ALIGN", (1, 0), (1, 0), "RIGHT"),
            ],
        )
        story.append(title_table)
        story.append(Spacer(1, 12))

        # === PATIENT INFO ===
        story.append(Paragraph("PATIENT INFORMATION", heading))
        patient_info = [
            ["Name", "N/A"],
            ["Age", f"{res.get('meta_pred', {}).get('age', {}).get('label', 'N/A')}"],
            ["Sex", f"{res.get('meta_pred', {}).get('sex', {}).get('label', 'N/A')}"],
            ["Region", res.get("region", "Unknown").title()],
            ["AI Confidence", f"{res.get('region_conf', 0):.2f}"],
        ]
        table = Table(patient_info, colWidths=[120, 360])
        table.setStyle(TableStyle([
            ("BOX", (0, 0), (-1, -1), 0.5, colors.grey),
            ("INNERGRID", (0, 0), (-1, -1), 0.25, colors.lightgrey),
            ("BACKGROUND", (0, 0), (0, -1), colors.HexColor("#EAF1FB")),
        ]))
        story.append(table)
        story.append(Spacer(1, 10))

        # === MAIN BODY ===
        story.append(Paragraph("CLINICAL REPORT SUMMARY", heading))
        story.append(Spacer(1, 4))
        # Escape first: a stray '<' or '&' in model text would otherwise be
        # read as paragraph markup. Only then add the <br/> line breaks.
        story.append(Paragraph(xml_escape(clean_report).replace("\n", "<br/>"), normal))
        story.append(Spacer(1, 15))

        # === IMAGES ===
        story.append(Paragraph("REFERENCE IMAGES", heading))
        story.append(Spacer(1, 4))
        story.append(_png_flowable(gray_u8, 180))
        story.append(Spacer(1, 6))
        story.append(_png_flowable(overlay_img, 180))
        story.append(Spacer(1, 20))

        # === FOOTER ===
        story.append(_rule_line())
        story.append(Spacer(1, 6))
        story.append(Paragraph("*** End of Report ***", footer))
        story.append(Spacer(1, 3))
        story.append(Paragraph(
            "Generated on: " + datetime.datetime.now().strftime("%d %b %Y, %I:%M %p"), footer))
        story.append(Paragraph(
            "This report is AI-generated and intended for research/clinical support only.", footer))
        story.append(Paragraph("Β© 2025 Drlogy Imaging Center | 24x7 Services", footer))

        doc.build(story)
        buffer.seek(0)
        return Response(
            content=buffer.read(),
            media_type="application/pdf",
            headers={"Content-Disposition": "attachment; filename=ai_xray_report.pdf"},
        )
    except Exception as e:
        import traceback
        traceback.print_exc()
        return JSONResponse({"error": str(e)}, status_code=400)


# NOTE: a commented-out experimental "/xray_to_3dvideo" endpoint (remote
# PixVerse image-to-video job with polling) used to live here. It was removed
# because it embedded an API key in source; that key should be treated as
# leaked. If the endpoint is revived, read the key from the environment.


# ===============================================
# 3D Local Video Generator (with overlay)
# ===============================================
@api.post("/generate_3d_local")
async def generate_3d_local(file: UploadFile = File(...)):
    """
    Generate a pseudo-3D rotating video from a 2D X-ray image.

    The image is resized to 224x224, blurred into a height map and rendered as
    a rotating matplotlib surface, coloured with the findings overlay when one
    is available. Output is written to ``static/`` as MP4, or as GIF if the
    ffmpeg path fails. Errors are returned as ``{"error": ...}`` with HTTP 500.
    """
    try:
        # 1) Read and preprocess the input image.
        raw = await file.read()
        img = Image.open(io.BytesIO(raw)).convert("L").resize((224, 224))
        gray_u8 = np.array(img, dtype=np.uint8)

        # 2) Try the model overlay; fall back to the plain grayscale image.
        try:
            res = analyze_fullbody(gray_u8)
            overlay_img, _ = findings_overlay(gray_u8, res)
            if overlay_img is not None:
                base_img = overlay_img
            else:
                base_img = np.stack([gray_u8] * 3, axis=-1)
        except Exception as overlay_error:
            print("⚠️ Overlay generation failed:", overlay_error)
            base_img = np.stack([gray_u8] * 3, axis=-1)

        blended = cv2.cvtColor(base_img, cv2.COLOR_BGR2RGB)

        # 3) Render one frame per 5 degrees of azimuth.
        frames = []
        H, W = gray_u8.shape
        x, y = np.meshgrid(np.linspace(0, 1, W), np.linspace(0, 1, H))
        depth = cv2.GaussianBlur(gray_u8.astype(np.float32), (9, 9), 0)

        for angle in range(0, 360, 5):
            fig = plt.figure(figsize=(3, 3))
            try:
                ax = fig.add_subplot(111, projection="3d")
                ax.view_init(30, angle)
                ax.plot_surface(
                    x, y, depth / 255.0,
                    facecolors=blended / 255.0,
                    rstride=2, cstride=2, linewidth=0, antialiased=False,
                )
                ax.axis("off")
                fig.canvas.draw()
                # Copy: the canvas buffer does not outlive the closed figure.
                frames.append(np.asarray(fig.canvas.buffer_rgba())[..., :3].copy())
            finally:
                plt.close(fig)

        # 4) Save the output video (MP4 preferred, GIF as fallback).
        os.makedirs("static", exist_ok=True)
        out_path = "static/xray_3d_local.mp4"
        try:
            import imageio_ffmpeg  # noqa: F401  (fails fast when ffmpeg support is absent)
            with imageio.get_writer(
                out_path, format="ffmpeg", mode="I", fps=12, codec="libx264", quality=8
            ) as writer:
                for frame in frames:
                    writer.append_data(frame)
            video_url = f"{HF_SPACE_URL}/static/xray_3d_local.mp4"
            msg = "βœ… 3D MP4 video (with overlay) generated successfully"
        except Exception as e:
            print("⚠️ FFmpeg failed, saving GIF instead:", e)
            gif_path = out_path.replace(".mp4", ".gif")
            imageio.mimsave(gif_path, frames, duration=0.08)
            video_url = f"{HF_SPACE_URL}/static/xray_3d_local.gif"
            msg = "βœ… 3D GIF (with overlay) generated successfully"

        return JSONResponse({"ok": True, "message": msg, "video_url": video_url})

    except Exception as e:
        import traceback
        traceback.print_exc()
        return JSONResponse({"error": str(e)}, status_code=500)


# ============================================================
# MONAI + PyVista 3D Full-Body Visualization
# ============================================================
@api.post("/generate_3d_monai")
async def generate_3d_monai(file: UploadFile = File(...)):
    """
    Generate a rotating volumetric rendering for an uploaded X-ray.

    The image is stacked into a synthetic 16-slice volume, passed through a
    small MONAI 3D UNet and volume-rendered off-screen with PyVista into
    ``static/xray_monai_3d_fullbody.mp4``. Errors are returned as
    ``{"error": ...}`` with HTTP 500.

    NOTE(review): the UNet is constructed here and no weights are loaded in
    this function, so its output looks like an untrained filter. Confirm
    this is intended.
    """
    try:
        # --- Step 1: Read and preprocess image ---
        raw = await file.read()
        img = Image.open(io.BytesIO(raw)).convert("L").resize((256, 256))
        gray = np.array(img, dtype=np.float32) / 255.0

        # --- Step 2: Build synthetic 3D volume (intensity fades with depth) ---
        depth_slices = 16
        volume = np.stack([gray * (1 - i / depth_slices) for i in range(depth_slices)], axis=0)
        volume = np.expand_dims(volume, axis=0)  # [1, D, H, W]

        # --- Step 3: MONAI lightweight 3D UNet ---
        x = torch.tensor(volume, dtype=torch.float32).unsqueeze(0)  # [1, 1, D, H, W]
        # Pad each spatial dim up to a multiple of 16.
        pad_d = (16 - x.shape[2] % 16) % 16
        pad_h = (16 - x.shape[3] % 16) % 16
        pad_w = (16 - x.shape[4] % 16) % 16
        x = F.pad(x, (0, pad_w, 0, pad_h, 0, pad_d))

        model = UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=1,
            channels=(8, 16, 32),
            strides=(2, 2),
        )
        model.eval()
        with torch.no_grad():
            y = torch.sigmoid(model(x))
        seg = y[0, 0].cpu().numpy()
        seg = (seg - seg.min()) / (seg.max() - seg.min() + 1e-8)
        seg = np.clip(seg * 255, 0, 255).astype(np.uint8)

        # --- Step 4: Create PyVista grid ---
        start_xvfb()
        grid = pv.ImageData(dimensions=seg.shape)
        grid.spacing = (1, 1, 1)
        grid.origin = (0, 0, 0)
        grid.point_data["values"] = seg.flatten(order="F")

        # --- Step 5: Volume rendering with lighting ---
        plotter = pv.Plotter(off_screen=True, window_size=[512, 512])
        frames = []
        try:
            plotter.add_volume(
                grid,
                cmap="bone",
                opacity="sigmoid_5",
                shade=True,
                diffuse=1.0,
                specular=0.3,
                specular_power=15,
            )
            plotter.set_background("black")

            light = pv.Light(
                position=(seg.shape[2] * 2, seg.shape[1] * 2, seg.shape[0]),
                color='white',
                intensity=1.5,
            )
            plotter.add_light(light)

            for angle in range(0, 360, 5):
                # Reset the camera each frame, then rotate to the target azimuth.
                plotter.camera_position = [
                    (seg.shape[2] * 2, 0, seg.shape[0] / 2),
                    (seg.shape[2] / 2, seg.shape[1] / 2, seg.shape[0] / 2),
                    (0, 0, 1),
                ]
                plotter.camera.azimuth = angle
                frames.append(plotter.screenshot(return_img=True))
        finally:
            # Release the render window even when rendering fails.
            plotter.close()

        # --- Step 6: Save rotating MP4 video ---
        os.makedirs("static", exist_ok=True)
        out_path = "static/xray_monai_3d_fullbody.mp4"
        iio.imwrite(out_path, frames, fps=12, codec="libx264")

        return JSONResponse({
            "ok": True,
            "message": "βœ… MONAI + PyVista 3D full-body visualization generated successfully",
            "video_url": f"{HF_SPACE_URL}/{out_path}",
        })

    except Exception as e:
        import traceback
        traceback.print_exc()
        return JSONResponse({"error": str(e)}, status_code=500)


# -----------------------------
# Config
# -----------------------------
CLIP_CANDIDATES = [
    "microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224",
    "openai/clip-vit-base-patch32",
]
CLIP_TEMP = float(os.getenv("CLIP_TEMP", "1.0"))
CHEST_TEMP = float(os.getenv("CHEST_TEMP", "1.2"))
CAL_TAU = float(os.getenv("CAL_TAU", "1.0"))  # site calibration temperature

TB_REPO = os.getenv("TB_REPO", "").strip()
TB_LABEL = os.getenv("TB_LABEL", "Tuberculosis")
HF_TOKEN = os.getenv("HF_TOKEN", None)

DEFAULT_THRESHOLDS: Dict[str, float] = {
    # chest
    "Pneumonia": 0.70,
    "Pulmonary Fibrosis": 0.75,
    "COPD (proxy)": 0.75,
    "Lung Cancer (proxy)": 0.70,
    "Tuberculosis": 0.70,
    # non-chest
    "Compression fracture": 0.65,
    "Scoliosis (proxy)": 0.65,
    "Hip fracture": 0.65,
    "OA (proxy)": 0.65,
    "Fracture (upper_ext)": 0.65,
    "Fracture (lower_ext)": 0.65,
    "Implant/Hardware": 0.65,
    "Skull fracture": 0.65,
}
ABSTAIN_MARGIN = float(os.getenv("ABSTAIN_MARGIN", "0.08"))
REGION_CONF_MIN = float(os.getenv("REGION_CONF_MIN", "0.40"))

REGION_NAMES = ["skull", "chest", "spine", "pelvis", "upper_ext", "lower_ext"]
REGION_PROMPTS = {
    "skull": ["X-ray of skull", "Head/skull radiograph"],
    "chest": ["Chest X-ray", "Thorax radiograph"],
    "spine": ["X-ray of spine", "Spine radiograph"],
    "pelvis": ["Pelvis or hip X-ray", "Pelvic radiograph"],
    "upper_ext": ["Upper extremity X-ray (shoulder to hand)", "Arm/wrist/hand X-ray"],
    "lower_ext": ["Lower extremity X-ray (hip to foot)", "Leg/ankle/foot X-ray"],
}
REGION_DISPLAY = {
    "chest": "chest",
    "skull": "skull",
    "spine": "spine",
    "pelvis": "pelvis/hip",
    "upper_ext": "upper extremity",
    "lower_ext": "lower extremity",
    "unknown": "unknown",
}

# Tasks (non-chest): "zs" holds a (positive, negative) prompt pair per task.
TASKS: Dict[str, List[Dict[str, Any]]] = {
    "spine": [
        {"name": "Compression fracture",
         "zs": ("X-ray of spine with compression fracture", "X-ray of spine without fracture")},
        {"name": "Scoliosis (proxy)",
         "zs": ("X-ray of spine with scoliosis", "X-ray of a straight spine")},
    ],
    "pelvis": [
        {"name": "Hip fracture",
         "zs": ("Pelvis or hip X-ray with hip fracture", "Pelvis or hip X-ray without fracture")},
        {"name": "OA (proxy)",
         "zs": ("Pelvis or hip X-ray with osteoarthritis", "Pelvis or hip X-ray with normal joint space")},
    ],
    "upper_ext": [
        {"name": "Fracture (upper_ext)",
         "zs": ("Upper extremity X-ray with fracture", "Upper extremity X-ray without fracture")},
        {"name": "Implant/Hardware",
         "zs": ("X-ray with orthopedic implant or hardware", "X-ray without orthopedic implant")},
    ],
    "lower_ext": [
        {"name": "Fracture (lower_ext)",
         "zs": ("Lower extremity X-ray with fracture", "Lower extremity X-ray without fracture")},
        {"name": "Implant/Hardware",
         "zs": ("X-ray with orthopedic implant or hardware", "X-ray without orthopedic implant")},
    ],
    "skull": [
        {"name": "Skull fracture",
         "zs": ("Skull X-ray with fracture", "Skull X-ray without fracture")},
    ],
}

# Sub-label descriptions for every screening item
CHEST_DESCRIPTIONS = {
    "Pneumonia": "Patchy/lobar opacities suggesting infection/inflammation; confirm clinically Β± labs.",
    "Pulmonary Fibrosis": "Chronic interstitial changes; consider HRCT for characterization.",
    "COPD (proxy)": "Proxy via emphysema/hyperinflation; COPD diagnosis needs spirometry/history.",
    "Lung Cancer (proxy)": "Proxy via mass/nodule; suspicious lesions need CT Β± biopsy.",
    "Tuberculosis": "CXR overlaps with other diseases; definitive Dx needs microbiology/NAAT.",
    "Asthma": "Typically not detectable on CXR; diagnosis is clinical with spirometry.",
}
TASK_DESCRIPTIONS = {
    "Compression fracture": "Height loss of vertebral body; triage to CT/MRI if clinical concern.",
    "Scoliosis (proxy)": "Curve or rotation; clinical significance depends on Cobb angle (not provided).",
    "Hip fracture": "Disruption of trabeculae/cortical lines; confirm with dedicated views/CT.",
    "OA (proxy)": "Joint space narrowing, osteophytes, subchondral changes; clinical correlation.",
    "Fracture (upper_ext)": "Cortical break or lucent line; immobilize and obtain dedicated views.",
    "Fracture (lower_ext)": "Cortical break or lucent line; weight-bearing precautions and follow-up.",
    "Implant/Hardware": "Presence of orthopedic material; assess alignment/loosening if applicable.",
    "Skull fracture": "Linear or depressed lucency; correlate with neuro status/CT.",
}

CHEST_TARGETS = {
    "Pneumonia": ["Pneumonia"],
    "Pulmonary Fibrosis": ["Fibrosis"],
    "COPD (proxy)": ["Emphysema"],
    "Lung Cancer (proxy)": ["Mass", "Nodule"],
    "Tuberculosis": ["Tuberculosis", "TB"],
    "Asthma": [],
}


# -----------------------------
# Utilities
# -----------------------------
def to_uint8(img):
    """Return ``img`` as a uint8 array, min-max scaling non-uint8 input to 0..255.

    Accepts a PIL image or an array. uint8 input is returned unchanged. A
    constant image maps to all zeros; an empty array is returned empty.
    """
    if isinstance(img, Image.Image):
        img = np.array(img)
    if img.dtype == np.uint8:
        return img
    if img.size == 0:
        # min()/max() raise on empty arrays.
        return img.astype(np.uint8)
    img = img.astype(np.float32)
    img -= img.min()
    peak = img.max()  # equals the value range now that the minimum is 0
    if peak > 1e-6:
        img = img / (peak + 1e-6)
    return (img * 255.0).clip(0, 255).astype(np.uint8)


def read_any_xray(path_or_buffer) -> np.ndarray:
    """
    Read a DICOM (.dcm) or image (PNG/JPG) and return an 8-bit grayscale array.

    The decoder is chosen from the ``name`` attribute of ``path_or_buffer``
    (or its string form): names ending in ``.dcm`` are read with pydicom,
    everything else with PIL.

    DICOM handling:
      - Applies RescaleSlope/Intercept if present.
      - Applies Window Center/Width (first value if multiple); falls back to
        min-max scaling when either is missing or the width is not positive.
      - Inverts MONOCHROME1 (after windowing).

    Raises:
        RuntimeError: a DICOM file was given but pydicom is not installed.
    """
    name = getattr(path_or_buffer, "name", str(path_or_buffer)).lower()

    # Non-DICOM: use PIL and convert to 8-bit grayscale.
    if not name.endswith(".dcm"):
        return np.array(Image.open(path_or_buffer).convert("L"))

    if not HAVE_DICOM:
        raise RuntimeError("Install pydicom to read DICOM files.")

    ds = pydicom.dcmread(path_or_buffer)
    arr = ds.pixel_array.astype(np.float32)

    # 1) Modality LUT (Rescale Slope/Intercept)
    slope = float(getattr(ds, "RescaleSlope", 1.0))
    intercept = float(getattr(ds, "RescaleIntercept", 0.0))
    if slope != 1.0 or intercept != 0.0:
        arr = arr * slope + intercept

    def _first_float(x):
        """Coerce a possibly multi-valued DICOM number to float, or None."""
        if x is None:
            return None
        try:
            return float(np.atleast_1d(x)[0])
        except Exception:
            try:
                return float(x)
            except Exception:
                return None

    # 2) VOI LUT: windowing needs both centre and width.
    wc = _first_float(getattr(ds, "WindowCenter", None))
    ww = _first_float(getattr(ds, "WindowWidth", None))
    if wc is not None and ww is not None and ww > 0:
        lo = wc - ww / 2.0
        hi = wc + ww / 2.0
        arr = np.clip(arr, lo, hi)
        arr = (arr - lo) / (hi - lo + 1e-6)  # -> 0..1
    else:
        # Fallback min-max
        mn = float(np.min(arr))
        mx = float(np.max(arr))
        arr = (arr - mn) / (mx - mn + 1e-6)

    # 3) Photometric interpretation (invert MONOCHROME1)
    phot = str(getattr(ds, "PhotometricInterpretation", "")).upper()
    if phot == "MONOCHROME1":
        arr = 1.0 - arr

    return (arr * 255.0).clip(0, 255).astype(np.uint8)


def edges_overlay(gray_u8: np.ndarray) -> np.ndarray:
    """Return an RGB copy of ``gray_u8`` with Canny edges painted [255, 80, 80].

    Falls back to the plain 3-channel grayscale image if edge detection fails.
    """
    try:
        edges = cv2.Canny(gray_u8, 50, 150)
        overlay = np.stack([gray_u8] * 3, axis=-1)
        overlay[edges > 0] = [255, 80, 80]
        return overlay
    except Exception:
        return np.stack([gray_u8] * 3, axis=-1)


# βœ… Add your
# pseudo-3D function here
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401


def _render_surface_rgb(height_map: np.ndarray) -> np.ndarray:
    """Render a 2-D height map as a 3-D surface plot and return it as an RGB array.

    The figure is always closed, even when plotting or saving fails, so
    repeated calls do not accumulate open matplotlib figures.

    Raises:
        ValueError: if ``height_map`` is not two-dimensional (shape unpacking fails).
    """
    rows, cols = height_map.shape
    grid_x, grid_y = np.meshgrid(np.arange(cols), np.arange(rows))

    fig = plt.figure(figsize=(6, 6))
    try:
        ax = fig.add_subplot(111, projection="3d")
        ax.plot_surface(grid_x, grid_y, height_map, cmap="bone", linewidth=0, antialiased=True)
        ax.set_axis_off()
        fig.tight_layout(pad=0)

        buf = io.BytesIO()
        # Save this specific figure rather than whatever pyplot considers "current".
        fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
    finally:
        plt.close(fig)

    buf.seek(0)
    return np.array(Image.open(buf).convert("RGB"))


def pseudo3d_surface(gray_u8: np.ndarray) -> np.ndarray:
    """Generate a pseudo-3D surface plot from a 2D X-ray.

    Pixel intensity is used directly as surface height. On any failure the
    error is printed and the input is returned stacked into three channels.
    """
    try:
        return _render_surface_rgb(gray_u8.astype(np.float32))
    except Exception as e:
        print("Pseudo-3D generation failed:", e)
        return np.stack([gray_u8] * 3, axis=-1)


def xray_to_depth_3d(gray_u8: np.ndarray) -> np.ndarray:
    """Convert 2D X-ray into pseudo-3D surface using Intel/dpt-hybrid-midas.

    Relies on the module-level ``HAVE_DEPTH`` flag and the ``DEPTH_EXTRACTOR`` /
    ``DEPTH_MODEL`` objects defined elsewhere in this file. When depth
    estimation is unavailable or fails, the input is returned stacked into
    three channels.
    """
    if not HAVE_DEPTH:
        return np.stack([gray_u8] * 3, axis=-1)
    try:
        pil = Image.fromarray(gray_u8).convert("RGB")
        inputs = DEPTH_EXTRACTOR(images=pil, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            outputs = DEPTH_MODEL(**inputs)
        predicted_depth = outputs.predicted_depth.squeeze().cpu().numpy()

        # Normalize depth to 0..1, then scale to the 0..255 height range.
        depth_min, depth_max = predicted_depth.min(), predicted_depth.max()
        depth_norm = (predicted_depth - depth_min) / (depth_max - depth_min + 1e-6)

        return _render_surface_rgb(depth_norm * 255)
    except Exception as e:
        print("3D depth rendering failed:", e)
        return np.stack([gray_u8] * 3, axis=-1)


def risk_level(prob: Optional[float]) -> str:
    """Map a probability to "High" (>= 0.70), "Moderate" (>= 0.50), "Low", or "N/A" for None."""
    if prob is None:
        return "N/A"
    if prob >= 0.70:
        return "High"
    if prob >= 0.50:
        return "Moderate"
    return "Low"


def risk_tag(prob: Optional[float]) -> str:
    """Return a short "<level> <percent>" tag for a probability."""
    level = risk_level(prob)
    # NOTE(review): `color` is looked up but never used in the returned string;
    # presumably it belonged to HTML markup that is no longer present - confirm.
    color = {"High": "#E03131", "Moderate": "#F59F00", "Low": "#2F9E44", "N/A": "#868E96"}[level]  # noqa: F841
    pct = "N/A" if prob is None else f"{prob:.0%}"
    return f"{level} {pct}"


def percent_bar(prob: Optional[float]) -> str:
    """Return the bar snippet for a probability, or an empty string for None."""
    if prob is None:
        return ""
    # NOTE(review): `width` is computed but not interpolated below, and the
    # returned pieces contain only newlines; the bar markup appears to have
    # been lost - confirm against the intended template.
    width = int(round(prob * 100))  # noqa: F841
    return (
        "\n"
        f"\n"
        "\n"
    )


# -----------------------------
# Extra detail generator for ALL body regions
# -----------------------------
def generate_extra_details(res: dict) -> str:
    """Build the Markdown description for the detected region.

    Reads ``res["region"]`` (falling back to the "unknown" text for
    unrecognized regions) and, when ``res["tasks"]`` is non-empty, appends one
    "- name: percent" line per task.
    """
    region = res.get("region", "unknown")
    tasks = res.get("tasks", [])

    region_texts = {
        "skull": (
            "🧠 **Image Details (Skull)**\n\n"
            "View type: frontal or lateral skull radiograph.\n\n"
            "Bones visible: cranial vault, facial bones, mandible.\n\n"
            "🔎 **AI / Computer Vision Context**\n"
            "- Region classifier: skull.\n"
            "- Tasks: skull fracture detection, implant/hardware.\n"
            "- AI approach: CLIP region classification + fracture CNN/ViT.\n\n"
            "🚦 **Limitations**\n"
            "- Small hairline fractures may be missed.\n"
            "- CT is gold standard for head trauma."
        ),
        "chest": (
            "🫁 **Image Details (Chest)**\n\n"
            "View type: PA/AP/Lateral chest X-ray.\n\n"
            "Visible: lungs, ribs, diaphragm, clavicles, cardiac silhouette.\n\n"
            "🔎 **AI / Computer Vision Context**\n"
            "- Region classifier: chest.\n"
            "- Tasks: Pneumonia, fibrosis, COPD, TB, lung cancer proxy.\n"
            "- Uses TorchXRayVision DenseNet + GradCAM overlays.\n\n"
            "🚦 **Limitations**\n"
            "- Cannot replace CT for detailed lung/mediastinal lesions.\n"
            "- Diagnosis requires clinical + lab correlation."
        ),
        "spine": (
            "🦴 **Image Details (Spine)**\n\n"
            "View type: cervical, thoracic, or lumbar.\n\n"
            "Visible: vertebral bodies, intervertebral spaces, alignment.\n\n"
            "🔎 **AI / Computer Vision Context**\n"
            "- Region classifier: spine.\n"
            "- Tasks: compression fracture, scoliosis.\n"
            "- AI approach: zero-shot CLIP + CNNs for fracture lines.\n\n"
            "🚦 **Limitations**\n"
            "- MRI needed for discs/cord.\n"
            "- Subtle wedge fractures may require CT."
        ),
        "pelvis": (
            "🦴 **Image Details (Pelvis/Hip)**\n\n"
            "View type: AP pelvis/hip.\n\n"
            "Visible: iliac bones, acetabulum, femoral heads.\n\n"
            "🔎 **AI / Computer Vision Context**\n"
            "- Region classifier: pelvis/hip.\n"
            "- Tasks: hip fracture, osteoarthritis.\n"
            "- AI approach: CLIP region classification + fracture CNN/transformer.\n\n"
            "🚦 **Limitations**\n"
            "- Subtle fractures may be missed → CT/MRI often needed."
        ),
        "upper_ext": (
            "✋ **Image Details (Upper Extremity)**\n\n"
            "View type: arm, shoulder, elbow, wrist, hand.\n\n"
            "Visible: humerus, radius, ulna, carpals, metacarpals.\n\n"
            "🔎 **AI / Computer Vision Context**\n"
            "- Region classifier: upper extremity.\n"
            "- Tasks: fractures, implants/hardware.\n"
            "- AI approach: CLIP + fracture CNNs (MURA dataset fine-tuning).\n\n"
            "🚦 **Limitations**\n"
            "- Soft tissues (ligaments/tendons) invisible.\n"
            "- US/MRI needed for rotator cuff, ligament injuries."
        ),
        "lower_ext": (
            "🦢 **Image Details (Lower Extremity)**\n\n"
            "View type: hip, femur, knee, tibia, ankle, foot.\n\n"
            "Visible: phalanges, metatarsals, tarsals, tibia, fibula.\n\n"
            "🔎 **AI / Computer Vision Context**\n"
            "- Region classifier: lower extremity.\n"
            "- Tasks: fractures, osteoarthritis, implants, deformity alignment.\n"
            "- AI approach: CLIP region classification → fracture CNN/transformer (MURA dataset).\n\n"
            "🚦 **Limitations**\n"
            "- Multiple views required to rule out fractures.\n"
            "- Soft tissue injuries invisible on X-ray."
        ),
        "unknown": (
            "❓ **Region Unknown**\n\n"
            "The model could not confidently classify this X-ray.\n"
            "Please check image quality (exposure, blur, cropping)."
        ),
    }

    text = region_texts.get(region, region_texts["unknown"])

    # Append AI task probabilities if available
    if tasks:
        lines = ["\n\n**Screening Results:**\n"]
        for t in tasks:
            p = t.get("prob")
            prob_str = "N/A" if p is None else f"{p*100:.1f}%"
            lines.append(f"- {t.get('name')}: {prob_str}\n")
        text += "".join(lines)

    return text


# =============================================
# AI Radiology Report Generator (fixed version)
# =============================================
import os
import google.generativeai as genai

# Load Gemini API key from Hugging Face secret
genai.configure(api_key=os.getenv("GEMINI_API_KEY"))

# Initialize Gemini model (flash = faster / cheaper)
try:
    # New API style (v1)
    gemini_model = genai.GenerativeModel("gemini-2.5-flash")
except Exception as e:
    print("⚠️ Falling back to Gemini v1beta format due to:", e)
    try:
        # Old API style (v1beta)
        gemini_model = genai.GenerativeModel("models/gemini-1.5-flash-latest")
    except Exception:
        # Fallback to PRO model if FLASH not available
        gemini_model = genai.GenerativeModel("gemini-1.5-pro")


def generate_medical_report(findings: dict) -> str:
    """
    Generate a professional radiology report from AI findings using Gemini,
    and format the output with Markdown and proper newlines.

    ``findings`` is embedded in the prompt via ``json.dumps``. Any failure
    (serialization, API call, response parsing) is printed with a traceback
    and the fixed fallback sentence is returned instead of raising.
    """
    prompt = f"""
    You are a senior radiologist. Based on the following AI findings, generate a detailed,
    structured radiology report in clinical style.

    Findings:
    {json.dumps(findings, indent=2)}

    Follow this format:
    PATIENT INFORMATION
    TECHNICAL DETAILS
    RADIOLOGICAL FINDINGS
    IMPRESSION
    RECOMMENDATIONS
    FOLLOW-UP
    """

    try:
        response = gemini_model.generate_content(prompt)

        # Extract text from Gemini response safely
        if hasattr(response, "text"):
            report_text = response.text.strip()
        elif hasattr(response, "candidates") and response.candidates:
            report_text = response.candidates[0].content.parts[0].text.strip()
        else:
            return "AI medical report could not be generated."

        # Format report into readable HTML (newlines + Markdown)
        return format_ai_report(report_text)

    except Exception as e:
        import traceback
        print("Gemini AI report generation failed:", e)
        traceback.print_exc()
        return "AI medical report could not be generated."


import re
import markdown


def format_ai_report(text: str) -> str:
    """
    Formats Gemini AI report text into readable HTML with proper newlines
    and Markdown styling.

    Empty input yields a fixed "No report generated." placeholder.
    """
    if not text:
        return "\n\nNo report generated.\n\n"

    text = re.sub(r'\n{2,}', '\n\n', text.strip())  # normalize spacing
    html_text = markdown.markdown(text)  # convert Markdown to HTML
    # NOTE(review): as written this replaces a newline with a newline (a no-op);
    # presumably the replacement used to be a line-break tag - confirm.
    html_text = html_text.replace('\n', '\n')  # keep single line breaks

    # NOTE(review): the wrapper contains only newlines around the HTML; any
    # styling container that once surrounded it is not present here.
    styled_html = f"""\n{html_text}\n"""
    return styled_html


# ---------- Saliency -> boxes & overlay ----------
def _colorize_heatmap(gray_u8: np.ndarray, sal: np.ndarray) -> np.ndarray:
    """Blend a min-max normalized saliency map (JET colormap) 50/50 over the grayscale image.

    The result is a 3-channel image in OpenCV's BGR channel order.
    """
    sal = (sal - sal.min()) / (sal.max() - sal.min() + 1e-6)
    base = cv2.cvtColor(gray_u8, cv2.COLOR_GRAY2BGR)
    heat = cv2.applyColorMap((sal * 255).astype(np.uint8), cv2.COLORMAP_JET)
    return cv2.addWeighted(base, 0.5, heat, 0.5, 0)


def _topk_boxes_from_saliency(sal: np.ndarray, k: int = 3, min_area_ratio: float = 0.001, q: float = 0.92):
    """
    Returns list of boxes [{'x':..,'y':..,'w':..,'h':..,'score':..}] from saliency (0..1).
    q is the quantile threshold; increase for tighter boxes.

    Regions are the external contours of the thresholded map (after a 3x3
    morphological opening), visited from largest to smallest area. Contours
    smaller than ``min_area_ratio`` of the image are skipped and at most ``k``
    boxes are returned. ``score`` is the mean normalized saliency in the box.
    """
    sal = (sal - sal.min()) / (sal.max() - sal.min() + 1e-6)
    H, W = sal.shape
    thr = float(np.quantile(sal, q))
    mask = (sal >= thr).astype(np.uint8) * 255
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, np.ones((3, 3), np.uint8), iterations=1)
    cnts, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    by_area = sorted(((cv2.contourArea(c), c) for c in cnts), key=lambda pair: pair[0], reverse=True)
    min_area = H * W * min_area_ratio

    boxes = []
    for area, c in by_area:
        if area < min_area:
            continue
        x, y, w, h = cv2.boundingRect(c)
        score = float(sal[y:y + h, x:x + w].mean())
        boxes.append({'x': int(x), 'y': int(y), 'w': int(w), 'h': int(h), 'score': score})
        if len(boxes) >= k:
            break
    return boxes


def _overlay_with_boxes(gray_u8: np.ndarray, sal: np.ndarray, label: str = "Suspect") -> Tuple[np.ndarray, list]:
    """Color heatmap + draw yellow rectangles with labels.

    Returns ``(overlay, boxes)`` where ``boxes`` is the list produced by
    ``_topk_boxes_from_saliency`` and each rectangle is captioned "<label> #<i>".
    """
    out = _colorize_heatmap(gray_u8, sal)
    boxes = _topk_boxes_from_saliency(sal, k=3, min_area_ratio=0.001, q=0.92)
    for i, b in enumerate(boxes, start=1):
        x, y, w, h = b['x'], b['y'], b['w'], b['h']
        cv2.rectangle(out, (x, y), (x + w, y + h), (0, 255, 255), 2)
        text = f"{label} #{i}"
        # Clamp the caption baseline so it never goes above the image.
        cv2.putText(out, text, (x, max(0, y - 6)), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 255, 255), 1, cv2.LINE_AA)
    return out, boxes


# -----------------------------
# PDF report (FPDF) — Unicode-safe + robust PNG embed on
# Windows
# -----------------------------
import functools

# Typographic characters mapped to ASCII before the NFKD/ASCII pass in
# pdf_report_bytes. Written as escapes so the table survives any re-encoding
# of this source file.
_PDF_ASCII_MAP = str.maketrans({
    "\u2014": "-",    # em dash
    "\u2013": "-",    # en dash
    "\u2011": "-",    # non-breaking hyphen
    "\u2022": "*",    # bullet
    "\u00b1": "+/-",  # plus-minus
    "\u201c": '"',    # left double quote
    "\u201d": '"',    # right double quote
    "\u2019": "'",    # right single quote
    "\u2026": "...",  # ellipsis
})


def pdf_report_bytes(gray_u8: np.ndarray,
                     res: Dict[str, Any],
                     include_overlay: bool = True,
                     findings_png: Optional[np.ndarray] = None) -> bytes:
    """Render the analysis result as a multi-page PDF and return its bytes.

    Page 1 holds the textual report, page 2 the original image, and (when
    ``include_overlay`` is true) page 3 either ``findings_png`` or an edges
    overlay of ``gray_u8``.

    Args:
        gray_u8: Grayscale image embedded on page 2.
        res: Result dict; the keys read here are "region", "region_conf",
            "meta_pred" and "tasks".
        include_overlay: Whether to append the overlay page.
        findings_png: Optional pre-rendered overlay image.

    Raises:
        RuntimeError: if FPDF is not installed.
    """
    if not HAVE_PDF:
        raise RuntimeError("FPDF not installed. Run: pip install fpdf")

    def _safe(s):
        """Coerce any value to ASCII-only text the core PDF font can encode."""
        if s is None:
            return ""
        s = str(s).translate(_PDF_ASCII_MAP)
        return unicodedata.normalize("NFKD", s).encode("ascii", "ignore").decode("ascii")

    pdf = FPDF(format="A4", unit="mm")
    pdf.set_auto_page_break(auto=True, margin=12)

    def _section(heading: str) -> None:
        """Write a bold section heading, then switch back to the body font."""
        pdf.set_font("Helvetica", "B", 12)
        pdf.cell(0, 8, _safe(heading), ln=1)
        pdf.set_font("Helvetica", "", 11)

    def _image_page(img: np.ndarray, heading: str) -> None:
        """Add a page with a heading and the image, via a temporary PNG file."""
        # Close the descriptor immediately so the path can be reopened by
        # PIL/FPDF (required on Windows).
        fd, path = tempfile.mkstemp(suffix=".png")
        os.close(fd)
        try:
            Image.fromarray(img).convert("RGB").save(path, format="PNG")
            pdf.add_page()
            pdf.set_font("Helvetica", "B", 12)
            pdf.cell(0, 8, _safe(heading), ln=1)
            pdf.image(path, x=10, y=None, w=180)
        finally:
            try:
                os.remove(path)
            except OSError:
                pass  # best-effort cleanup of the temp file

    title = "AI Radiology Report \u2013 Research Use Only"
    region = res.get("region", "unknown")
    # "region_conf" may be present but None; treat that as 0.0 so the
    # fixed-point format below cannot raise.
    conf = res.get("region_conf") or 0.0
    meta = res.get("meta_pred", {})
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M")

    # ---------------- Page 1 ----------------
    pdf.add_page()
    pdf.set_font("Helvetica", "B", 16)
    pdf.cell(0, 10, _safe(title), ln=1)
    pdf.set_font("Helvetica", "I", 10)
    pdf.cell(0, 6, _safe(f"Generated on: {timestamp}"), ln=1)
    pdf.cell(0, 6, _safe("Disclaimer: This is an AI-generated report, research use only, not for diagnostic purposes."), ln=1)
    pdf.ln(4)

    # Patient / Meta Info
    _section("Patient / Meta Information")
    pdf.cell(0, 7, _safe(f"Age (approx): {meta.get('age',{}).get('label','N/A')}"), ln=1)
    pdf.cell(0, 7, _safe(f"Sex: {meta.get('sex',{}).get('label','N/A')}"), ln=1)
    disp_region = REGION_DISPLAY.get(region, region)
    pdf.cell(0, 7, _safe(f"Region (auto-detected): {disp_region} | Confidence: {conf:.2f}"), ln=1)
    pdf.ln(3)

    # Clinical Indication
    _section("Clinical Indication")
    pdf.multi_cell(0, 6, _safe(f"AI-assisted screening of full-body X-ray. Region auto-detected as {disp_region}."))
    pdf.ln(3)

    # Technique
    _section("Technique")
    pdf.multi_cell(0, 6, _safe("Single-view X-ray (projection estimated by model)."))
    pdf.ln(3)

    # Findings
    _section("Findings")
    details = generate_extra_details(res)
    if details:
        for line in details.split("\n"):
            pdf.multi_cell(0, 6, _safe(line))
    else:
        pdf.multi_cell(0, 6, "No detailed findings available.")
    pdf.ln(3)

    # Disease Screening
    _section("Disease Screening")
    tasks = res.get("tasks", [])
    if not tasks:
        pdf.multi_cell(0, 6, "No tasks configured for this region or models unavailable.")
    else:
        for t in tasks:
            name = t.get("name", "?")
            p = t.get("prob", None)
            desc = t.get("desc", "")
            risk = risk_level(p)
            p_str = "N/A" if p is None else f"{p*100:.1f}%"
            pdf.multi_cell(0, 6, _safe(f"{name}: {p_str} | Risk: {risk}"))
            if desc:
                pdf.set_text_color(120, 120, 120)
                pdf.multi_cell(0, 6, _safe(f" - {desc}"))
                pdf.set_text_color(0, 0, 0)
            pdf.ln(1)
    pdf.ln(3)

    # Impression / Conclusion
    _section("Impression / Conclusion")
    high_risk = [t for t in tasks if risk_level(t.get("prob")) == "High"]
    if high_risk:
        # Goes through _safe: the en dash is outside Latin-1 and the core
        # Helvetica font cannot encode it.
        pdf.multi_cell(0, 6, _safe("Abnormal findings detected \u2013 recommend further evaluation / correlation with clinical context."))
    else:
        pdf.multi_cell(0, 6, "No high-risk abnormalities detected by AI. Findings consistent with low or moderate risk only.")
    pdf.ln(4)

    # ---------------- Page 2: Original Image ----------------
    _image_page(gray_u8, "Original X-ray Image")

    # ---------------- Page 3: Overlay ----------------
    if include_overlay:
        if findings_png is not None:
            _image_page(findings_png, "AI Heatmap with Suspect Boxes")
        else:
            _image_page(edges_overlay(gray_u8), "Edges Overlay")

    # Classic PyFPDF returns a str here; fpdf2 returns a bytearray.
    out = pdf.output(dest="S")
    return out.encode("latin-1") if isinstance(out, str) else bytes(out)


# -----------------------------
# Meta: Age / Sex / View (zero-shot via CLIP/BiomedCLIP)
# -----------------------------
def _region_word(region: str) -> str:
    """Return the prompt phrase for a region key ("X-ray image" for unknown keys)."""
    return {
        "chest": "Chest X-ray",
        "skull": "Skull X-ray",
        "spine": "Spine X-ray",
        "pelvis": "Pelvis/hip X-ray",
        "upper_ext": "Upper extremity X-ray",
        "lower_ext": "Lower extremity X-ray",
    }.get(region, "X-ray image")


def predict_view(pil_rgb: Image.Image, region: str) -> dict:
    """Zero-shot PA/AP/Lateral prediction; only attempted for chest images."""
    if region != "chest":
        return {"label": "N/A", "probs": {}}
    views = ["PA", "AP", "Lateral"]
    p = clip_probs(pil_rgb, [f"{v} chest X-ray" for v in views], temp=CLIP_TEMP)
    return {"label": views[int(np.argmax(p))],
            "probs": {v: float(p[i]) for i, v in enumerate(views)}}


def predict_sex(pil_rgb: Image.Image, region: str) -> dict:
    """Zero-shot Male/Female prediction using region-aware prompts."""
    base = _region_word(region)
    texts = [f"{base} of a male patient", f"{base} of a female patient"]
    p = clip_probs(pil_rgb, texts, temp=CLIP_TEMP)
    idx = int(np.argmax(p))
    return {"label": ["Male", "Female"][idx],
            "probs": {"Male": float(p[0]), "Female": float(p[1])}}


def predict_age(pil_rgb: Image.Image, region: str) -> dict:
    """Estimate age as the probability-weighted mean of four age-group anchors."""
    base = _region_word(region)
    labels = ["Child (<=14)", "Adolescent (15-19)", "Adult (20-64)", "Elderly (65+)"]
    prompts = [f"{base} of a child", f"{base} of an adolescent",
               f"{base} of an adult", f"{base} of an elderly person"]
    # One representative age (in years) per group, in the same order as `labels`.
    anchors = np.array([8.0, 17.0, 40.0, 75.0], dtype=np.float32)
    p = clip_probs(pil_rgb, prompts, temp=CLIP_TEMP)
    est = float(np.dot(p, anchors))
    return {"label": f"{est:.1f} yrs", "years": est,
            "probs": {lbl: float(p[i]) for i, lbl in enumerate(labels)}}


def predict_meta(pil_rgb: Image.Image, region: str) -> dict:
    """Return {"age", "sex", "view"} predictions, or "N/A" placeholders without a CLIP model."""
    if CLIP_MDL is None or CLIP_PROC is None:
        return {"age": {"label": "N/A", "years": None, "probs": {}},
                "sex": {"label": "N/A", "probs": {}},
                "view": {"label": "N/A", "probs": {}}}
    return {"age": predict_age(pil_rgb, region),
            "sex": predict_sex(pil_rgb, region),
            "view": predict_view(pil_rgb, region)}


# -----------------------------
# CLIP/BiomedCLIP zero-shot
# -----------------------------
def load_clip_like():
    """Load the first model in ``CLIP_CANDIDATES`` that loads; return (model, processor) or (None, None)."""
    if not HAVE_TRF:
        return None, None
    for rid in CLIP_CANDIDATES:
        try:
            proc = AutoProcessor.from_pretrained(rid, token=HF_TOKEN)
            mdl = CLIPModel.from_pretrained(rid, token=HF_TOKEN).eval().to(DEVICE)
            return mdl, proc
        except Exception:
            continue  # try the next candidate
    return None, None


CLIP_MDL, CLIP_PROC = load_clip_like()


def clip_probs(pil_rgb: Image.Image, texts: List[str], temp: float = 1.0) -> np.ndarray:
    """Return softmax probabilities of the image against each text prompt.

    Logits are divided by ``temp`` (floored at 1e-6) before the softmax.
    Without a loaded CLIP model a uniform distribution is returned.
    """
    if CLIP_MDL is None or CLIP_PROC is None:
        return np.full((len(texts),), 1.0 / len(texts), dtype=np.float32)
    inputs = CLIP_PROC(text=texts, images=pil_rgb, return_tensors="pt", padding=True).to(DEVICE)
    with torch.no_grad():
        out = CLIP_MDL(**inputs)
    logits = out.logits_per_image.squeeze(0) / max(1e-6, temp)
    probs = torch.softmax(logits, dim=0).detach().cpu().numpy()
    return probs.astype(np.float32)


def clip_binary_prob(pil_rgb: Image.Image, pos_text: str, neg_text: str, temp: float = 1.0) -> float:
    """Return the probability of ``pos_text`` in a two-way contest against ``neg_text``."""
    probs = clip_probs(pil_rgb, [pos_text, neg_text], temp=temp)
    return float(probs[0])


# -----------------------------
# Chest ensemble (TorchXRayVision)
# -----------------------------
def load_xrv_models():
    """Load the TorchXRayVision DenseNet(s).

    Returns:
        ``(models, lower_to_pretty)`` where ``models`` holds only the models
        that loaded successfully and ``lower_to_pretty`` maps lower-cased
        pathology names to their original spelling. ``(None, None)`` when the
        package is missing or no model could be loaded, so callers'
        ``XRV_MODELS is None`` checks cover both cases.
    """
    if not HAVE_XRV:
        return None, None

    import torch
    import torchxrayvision as xrv

    # Fix for PyTorch 2.6+ safe-loading policy:
    # allow DenseNet class from torchxrayvision to be unpickled safely
    try:
        torch.serialization.add_safe_globals([xrv.models.DenseNet])
    except Exception as e:
        print("Safe global registration failed:", e)

    def _mk(weights_key):
        try:
            m = xrv.models.DenseNet(weights=weights_key)
            return m.eval().to(DEVICE)
        except Exception as e:
            print(f"⚠️ Error loading model {weights_key}:", e)
            return None

    # Load the CheX model (you can add more if needed). Failed loads are
    # dropped here so downstream loops never see a None model.
    models = [m for m in (_mk("densenet121-res224-chex"),) if m is not None]
    if not models:
        return None, None

    # Combine all pathology labels (first spelling seen wins)
    union = {}
    for m in models:
        for lbl in m.pathologies:
            union.setdefault(lbl.lower(), lbl)

    return models, union


# Initialize global variables
XRV_MODELS, LOWER_TO_PRETTY = load_xrv_models()


def xrv_logits_dict(gray_u8: np.ndarray, mdl) -> dict:
    """Run one XRV model on the image and return {lower-cased pathology: logit}."""
    img01 = gray_u8.astype(np.float32) / 255.0
    # Rescale 0..1 to the -1024..1024 range before feeding the model.
    img_hu = img01 * 2048.0 - 1024.0
    t = torch.from_numpy(img_hu).unsqueeze(0).unsqueeze(0).to(DEVICE)
    t = torch.nn.functional.interpolate(t, size=224, mode="bilinear", align_corners=False)
    with torch.no_grad():
        logits = mdl(t).squeeze(0).detach().cpu().numpy()
    return {lbl.lower(): float(logits[i]) for i, lbl in enumerate(mdl.pathologies)}


def chest_disease_probs(gray_u8: np.ndarray):
    """Ensemble chest pathology probabilities over all models and a horizontal flip.

    Returns:
        ``(requested, ranked_pretty)``: ``requested`` maps each display name in
        ``CHEST_TARGETS`` to the probability of its first matching alias (or
        None); ``ranked_pretty`` lists the ten most probable labels as
        ``(pretty_name, prob)``. ``({}, [])`` when XRV is unavailable.
    """
    if not HAVE_XRV or XRV_MODELS is None:
        return {}, []

    imgs = [gray_u8, np.ascontiguousarray(np.flip(gray_u8, axis=1))]
    sum_logits, count = {}, {}
    for im in imgs:
        for mdl in XRV_MODELS:
            if mdl is None:
                continue
            for lbl, v in xrv_logits_dict(im, mdl).items():
                sum_logits[lbl] = sum_logits.get(lbl, 0.0) + v
                count[lbl] = count.get(lbl, 0) + 1

    # Mean logit per label, temperature-scaled, then sigmoid.
    mean_logits = {k: (sum_logits[k] / max(1, count[k])) / max(1e-6, CHEST_TEMP) for k in sum_logits}
    label_probs = {k: float(1.0 / (1.0 + np.exp(-v))) for k, v in mean_logits.items()}

    requested = {
        disp: next((label_probs[a.lower()] for a in aliases if a.lower() in label_probs), None)
        for disp, aliases in CHEST_TARGETS.items()
    }

    ranked = sorted(label_probs.items(), key=lambda x: x[1], reverse=True)[:10]
    ranked_pretty = [(LOWER_TO_PRETTY.get(lbl, lbl.title()), p) for lbl, p in ranked]
    return requested, ranked_pretty


# Optional HF binary plugin (e.g., TB)
@functools.lru_cache(maxsize=8)
def load_hf_binary(repo_id: str, token: Optional[str] = None):
    """Load a Hugging Face image classifier; return (model, processor, id2label).

    Results are cached per ``(repo_id, token)`` so the model is not reloaded
    on every request. Returns ``(None, None, None)`` when transformers is
    unavailable or ``repo_id`` is empty. Loading errors propagate and are not
    cached.
    """
    if not (HAVE_TRF and repo_id):
        return None, None, None
    proc = AutoImageProcessor.from_pretrained(repo_id, token=token)
    mdl = AutoModelForImageClassification.from_pretrained(repo_id, token=token).eval().to(DEVICE)
    id2 = getattr(mdl.config, "id2label", None)
    return mdl, proc, id2


def run_hf_binary(mdl, proc, pil_rgb: Image.Image, target_label_name: str) -> Optional[float]:
    """Return the sigmoid probability for ``target_label_name``, or None if it cannot be resolved.

    A single-logit head is treated as the positive class directly; otherwise
    the label is looked up (case-insensitively) in ``mdl.config.id2label``.
    """
    if mdl is None or proc is None:
        return None
    inputs = proc(images=pil_rgb, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        out = mdl(**inputs)
    logits = out.logits.squeeze(0)
    if logits.numel() == 1:
        return float(torch.sigmoid(logits).item())
    id2 = getattr(mdl.config, "id2label", None) or {}
    name2idx = {v.lower(): int(k) for k, v in id2.items()}
    idx = name2idx.get(target_label_name.lower())
    if idx is None or idx >= logits.numel():
        return None
    # NOTE(review): a per-logit sigmoid is applied even for multi-class heads;
    # confirm this is intended rather than a softmax over classes.
    return float(torch.sigmoid(logits[idx]).item())


# -----------------------------
# Saliency for chest (Grad-CAM) and non-chest (CLIP occlusion)
# -----------------------------
def gradcam_single_xrv(gray_u8: np.ndarray, mdl, class_idx: int, out_size: Tuple[int, int]) -> Optional[np.ndarray]:
    """
    Grad-CAM for a single TorchXRayVision DenseNet model & class index.
    Returns saliency map in 0..1 resized to out_size.

    ``out_size`` is (height, width). Returns None when ``class_idx`` is out of
    range or anything fails. The hooks registered on ``mdl.features`` are
    always removed before returning.
    """
    handles = []
    try:
        mdl.eval()
        feats = {}
        grads = {}

        def fwd_hook(module, inp, out):
            feats["x"] = out.detach()

        def bwd_hook(module, grad_in, grad_out):
            grads["x"] = grad_out[0].detach()

        handles.append(mdl.features.register_forward_hook(fwd_hook))
        handles.append(mdl.features.register_full_backward_hook(bwd_hook))

        img01 = gray_u8.astype(np.float32) / 255.0
        img_hu = img01 * 2048.0 - 1024.0
        t = torch.from_numpy(img_hu).unsqueeze(0).unsqueeze(0).to(DEVICE)
        t = torch.nn.functional.interpolate(t, size=224, mode="bilinear", align_corners=False)
        t.requires_grad_(True)

        mdl.zero_grad(set_to_none=True)
        out = mdl(t)  # logits
        if class_idx >= out.shape[1]:
            return None
        out[0, class_idx].backward(retain_graph=True)

        fmap = feats["x"]  # (1, C, H, W)
        grad = grads["x"]  # (1, C, H, W)
        weights = grad.mean(dim=(2, 3), keepdim=True)  # (1, C, 1, 1)
        cam = torch.relu((weights * fmap).sum(dim=1)).squeeze(0)  # (H, W)
        cam = cam.detach().cpu().numpy()
        # cv2.resize takes (width, height), hence the reversal.
        cam = cv2.resize(cam, out_size[::-1], interpolation=cv2.INTER_LINEAR)
        return (cam - cam.min()) / (cam.max() - cam.min() + 1e-6)
    except Exception:
        return None
    finally:
        for handle in handles:
            handle.remove()


def chest_findings_heatmap(gray_u8: np.ndarray, res: Dict[str, Any]) -> Optional[Tuple[np.ndarray, str]]:
    """Return (saliency_map_0..1, target_label_name) for chest via ensemble Grad-CAM.

    The target is the entry of ``res["chest"]`` with the highest non-None
    probability. Returns None when no target, alias or CAM is available.
    """
    if not HAVE_XRV or XRV_MODELS is None:
        return None
    chest = res.get("chest", {})
    candidates = [(k, v) for k, v in chest.items() if v is not None]
    if not candidates:
        return None
    target_name, _ = max(candidates, key=lambda kv: kv[1])
    lbl_lowers = [a.lower() for a in CHEST_TARGETS.get(target_name, [])]
    if not lbl_lowers:
        return None

    h, w = gray_u8.shape
    cams = []
    for mdl in XRV_MODELS:
        if mdl is None:
            continue
        pathologies = [s.lower() for s in mdl.pathologies]
        # First alias this model knows about, if any.
        idx = next((pathologies.index(a) for a in lbl_lowers if a in pathologies), None)
        if idx is None:
            continue
        cam = gradcam_single_xrv(gray_u8, mdl, idx, (h, w))
        if cam is not None:
            cams.append(cam)
    if not cams:
        return None
    sal = np.mean(cams, axis=0)
    sal = (sal - sal.min()) / (sal.max() - sal.min() + 1e-6)
    return sal, target_name


def best_nonchest_task_and_prompts(res: Dict[str, Any]) -> Optional[Tuple[str, Tuple[str, str]]]:
    """Pick the highest-probability zero-shot task for the region and return (name, (pos, neg)).

    Falls back to the first configured task of the region when it has "zs"
    prompts; returns None when the region has no usable task.
    """
    region = res.get("region", "unknown")
    if region not in TASKS:
        return None
    # map name -> (pos, neg)
    name2prompts = {t["name"]: t["zs"] for t in TASKS[region] if "zs" in t}

    # pick highest-prob configured task
    best_name = None
    best_p = -1.0
    for t in res.get("tasks", []):
        if t.get("region") != region:
            continue
        nm = t.get("name")
        if nm in name2prompts and t.get("prob") is not None and float(t["prob"]) > best_p:
            best_p = float(t["prob"])
            best_name = nm

    if best_name is None:
        # fallback to first available task config
        if TASKS[region]:
            cand = TASKS[region][0]
            if "zs" in cand:
                return cand["name"], cand["zs"]
        return None
    return best_name, name2prompts[best_name]


def clip_occlusion_saliency(pil_rgb: Image.Image, pos_text: str, neg_text: str, n_segments: int = 140) -> Optional[np.ndarray]:
    """Return saliency map (0..1) via SLIC-occlusion for CLIP zero-shot.

    Each superpixel is replaced by a blurred version in turn; the drop in the
    positive-prompt probability is that superpixel's importance. Returns None
    when scikit-image or the CLIP model is unavailable.
    """
    if not (HAVE_SKIMG and CLIP_MDL is not None):
        return None
    rgb = np.array(pil_rgb.convert("RGB"))
    base_p = clip_binary_prob(pil_rgb, pos_text, neg_text, temp=CLIP_TEMP)
    h, w, _ = rgb.shape
    seg = slic(img_as_float(rgb), n_segments=n_segments, compactness=10, start_label=0)
    blurred = cv2.GaussianBlur(rgb, (21, 21), 0)
    contrib = np.zeros((h, w), dtype=np.float32)

    for lab in np.unique(seg):
        mask = (seg == lab)
        occluded = rgb.copy()
        occluded[mask] = blurred[mask]
        p = clip_binary_prob(Image.fromarray(occluded), pos_text, neg_text, temp=CLIP_TEMP)
        contrib[mask] = max(0.0, base_p - p)  # drop in prob -> importance

    return (contrib - contrib.min()) / (contrib.max() - contrib.min() + 1e-6)


def findings_overlay(gray_u8: np.ndarray, res: Dict[str, Any]):
    """
    Returns (overlay_bgr, meta_dict)
    meta_dict = {"label": <str>, "boxes": [ {x,y,w,h,score}, ... ]}

    Chest images use ensemble Grad-CAM; other regions use CLIP occlusion on
    the best configured task. Returns (None, None) when no saliency map can
    be produced.
    """
    region = res.get("region", "unknown")
    pil = Image.fromarray(gray_u8).convert("RGB")

    if region == "chest":
        out = chest_findings_heatmap(gray_u8, res)
        if out is None:
            return None, None
        sal, label = out
    else:
        pair = best_nonchest_task_and_prompts(res)
        if pair is None:
            return None, None
        label, (pos, neg) = pair
        sal = clip_occlusion_saliency(pil, pos, neg)
        if sal is None:
            return None, None

    overlay, boxes = _overlay_with_boxes(gray_u8, sal, label=label)
    return overlay, {"label": label, "boxes": boxes}


# -----------------------------
# Gating helpers
# -----------------------------
def abstain_flag(p: Optional[float], margin: float = ABSTAIN_MARGIN) -> bool:
    """True when the probability is missing or within ``margin`` of 0.5."""
    if p is None:
        return True
    return abs(p - 0.5) < margin


def triage_flag(p: Optional[float], threshold: float) -> bool:
    """True when the probability is present and at or above ``threshold``."""
    if p is None:
        return False
    return p >= threshold


# ---- ADDED: Quality metrics & gates, calibration, audit log ----
def quality_metrics(gray_u8):
    """Compute size, blur (Laplacian variance) and exposure fractions for an image.

    ``under_pct`` / ``over_pct`` are the fractions of pixels below 10 and
    above 245 respectively. ``blur_var`` is NaN if the Laplacian fails.
    """
    H, W = gray_u8.shape[:2]
    try:
        blur_var = float(cv2.Laplacian(gray_u8, cv2.CV_64F).var())
    except Exception:
        blur_var = float('nan')
    under = float((gray_u8 < 10).mean())
    over = float((gray_u8 > 245).mean())
    return {"min_side": min(H, W), "blur_var": blur_var, "under_pct": under, "over_pct": over}


def gate_decision(metrics, region_conf, min_side=256, blur_var_min=80.0,
                  under_max=0.25, over_max=0.25, region_conf_min=REGION_CONF_MIN):
    """Return (ok, reasons): ok is False when any quality gate fails.

    ``metrics`` is the dict produced by ``quality_metrics``. A NaN blur value
    skips the blur gate.
    """
    reasons = []
    if metrics["min_side"] < min_side:
        reasons.append("image too small")
    if not np.isnan(metrics["blur_var"]) and metrics["blur_var"] < blur_var_min:
        reasons.append("too blurry")
    if metrics["under_pct"] > under_max:
        reasons.append("under-exposed")
    if metrics["over_pct"] > over_max:
        reasons.append("over-exposed")
    if region_conf < region_conf_min:
        reasons.append("low region confidence (OOD)")
    return not reasons, reasons


def calibrate_prob(p, tau=CAL_TAU):
    """Temperature-scale a probability in logit space; None passes through."""
    if p is None:
        return None
    p = float(np.clip(p, 1e-6, 1 - 1e-6))  # keep the logit finite
    logit = np.log(p / (1 - p))
    return 1.0 / (1.0 + np.exp(-logit / max(1e-6, tau)))


LOG_PATH = pathlib.Path("logs"); LOG_PATH.mkdir(exist_ok=True)


def audit_log(event, payload, raw_bytes=None):
    """Append one JSON record to today's ``logs/audit-<date>.jsonl`` file.

    The record holds a UTC timestamp, the event name, the payload and, when
    ``raw_bytes`` is given, its SHA-256 hex digest.
    """
    # Same "naive UTC + Z" format that utcnow() produced, without the
    # deprecated call.
    now_utc = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
    rec = {
        "ts": now_utc.isoformat() + "Z",
        "event": event,
        "payload": payload,
        "input_sha256": hashlib.sha256(raw_bytes).hexdigest() if raw_bytes else None
    }
    with open(LOG_PATH / f"audit-{datetime.date.today().isoformat()}.jsonl", "a", encoding="utf-8") as f:
        f.write(json.dumps(rec) + "\n")


# -----------------------------
# Shared inference
# -----------------------------
def analyze_fullbody(gray_u8: np.ndarray) -> Dict[str, Any]:
    """Route the image to a body region and run that region's screening tasks.

    Returns a dict with "meta", "region", "region_conf", "region_source",
    "meta_pred", "tasks" (one entry per screening task) and "chest" (chest
    probabilities, empty for other regions).
    """
    result: Dict[str, Any] = {
        "meta": {"version": 1, "device": DEVICE},
        "region": None,
        "region_conf": None,
        "region_source": None,
        "tasks": [],
        "chest": {}
    }
    pil_rgb = Image.fromarray(gray_u8).convert("RGB")

    # Region router (CLIP)
    texts = [REGION_PROMPTS[r][0] for r in REGION_NAMES]
    probs = clip_probs(pil_rgb, texts, temp=CLIP_TEMP)
    idx = int(np.argmax(probs))
    region = REGION_NAMES[idx]
    conf = float(probs[idx])
    result.update({"region": region, "region_conf": conf, "region_source": "zero-shot"})

    # Meta (Age/Sex/View) using region-aware prompts
    result["meta_pred"] = predict_meta(pil_rgb, region)

    # Low-confidence region -> unknown
    # NOTE(review): only the reported region changes; the local `region` still
    # drives the screening below. Confirm whether screening should be skipped.
    if conf < REGION_CONF_MIN:
        result["region"] = "unknown"

    # Region-specific screening
    if region == "chest" and HAVE_XRV and XRV_MODELS is not None:
        chest_probs, _ranked = chest_disease_probs(gray_u8)
        if TB_REPO and HAVE_TRF:
            tb_m, tb_p, _ = load_hf_binary(TB_REPO, HF_TOKEN)
            tb_prob = run_hf_binary(tb_m, tb_p, pil_rgb, TB_LABEL)
            if tb_prob is not None:
                chest_probs["Tuberculosis"] = tb_prob
        result["chest"] = chest_probs
        for name, p in chest_probs.items():
            thr = DEFAULT_THRESHOLDS.get(name, 0.5)
            result["tasks"].append({
                "region": "chest",
                "name": name,
                "prob": p,
                "threshold": thr,
                "abstain": abstain_flag(p),
                "triage": triage_flag(p, thr),
                # NOTE(review): the TB entry is stored under "Tuberculosis" but
                # compared against TB_LABEL here; verify the two always match.
                "method": "xrv-ensemble" if name != TB_LABEL else "hf-binary",
                "desc": CHEST_DESCRIPTIONS.get(name, ""),
            })
    else:
        for t in TASKS.get(region, []):
            name = t["name"]
            thr = DEFAULT_THRESHOLDS.get(name, 0.6)
            prob = None
            method = None
            if "hf_repo" in t and t.get("hf_repo") and HAVE_TRF:
                mdl, proc, _ = load_hf_binary(t["hf_repo"], HF_TOKEN)
                label = t.get("label", "positive")
                prob = run_hf_binary(mdl, proc, pil_rgb, label)
                method = f"hf:{t['hf_repo']}"
            elif "zs" in t:
                pos, neg = t["zs"]
                prob = clip_binary_prob(pil_rgb, pos, neg, temp=CLIP_TEMP)
                method = "zero-shot-clip"
            result["tasks"].append({
                "region": region,
                "name": name,
                "prob": prob,
                "threshold": thr,
                "abstain": abstain_flag(prob),
                "triage": triage_flag(prob, thr),
                "method": method,
                "desc": TASK_DESCRIPTIONS.get(name, ""),
            })
    return result