# app_fullbody_pretrained.py - Full-Body X-ray Analysis (Pretrained) + PDF + Suspect Boxes
# -----------------------------------------------------------------------------
# Run: python -m uvicorn app_fullbody_pretrained:api --host 0.0.0.0 --port 7860
# -----------------------------------------------------------------------------
import os
# Hugging Face Space workaround: reinstall a known-good google-generativeai stack at startup.
os.system("pip uninstall -y google-generativeai google-api-core googleapis-common-protos grpcio || true")
os.system("pip install --upgrade google-generativeai==0.8.5 google-api-core protobuf grpcio --quiet")
import io, json, tempfile, hashlib, datetime, pathlib
from io import BytesIO
from typing import Optional, Dict, List, Any, Tuple
import numpy as np
from PIL import Image
import cv2
import unicodedata
import torch
import torch.nn.functional as F
import base64
# FastAPI imports
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse, Response
from fastapi.staticfiles import StaticFiles
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D # noqa: F401
from pydantic import BaseModel
# Hugging Face Space base URL
HF_SPACE_URL = os.getenv("HF_SPACE_URL", "https://abbhy123ghh-x-ray-analysis.hf.space").rstrip("/")
# ✅ Create app
api = FastAPI()
os.makedirs("static", exist_ok=True)
api.mount("/static", StaticFiles(directory="static"), name="static")
# Optional dependencies
try: import pydicom; HAVE_DICOM=True
except Exception: HAVE_DICOM=False
try: import torchxrayvision as xrv; HAVE_XRV=True
except Exception: HAVE_XRV=False
try:
from transformers import (AutoProcessor, CLIPModel, AutoImageProcessor, AutoModelForImageClassification)
HAVE_TRF=True
except Exception: HAVE_TRF=False
try: from skimage.segmentation import slic; from skimage.util import img_as_float; HAVE_SKIMG=True
except Exception: HAVE_SKIMG=False
try: from fpdf import FPDF; HAVE_PDF=True
except Exception: HAVE_PDF=False
import uvicorn
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# -----------------------------
# Utility functions + configs (to_uint8, read_any_xray, chest_disease_probs,
# analyze_fullbody, pdf_report_bytes, etc.) are defined further below.
# -----------------------------
# Root
# -----------------------------
@api.get("/")
async def root():
return {"ok": True, "message": "X-ray API is running"}
def to_base64(img: np.ndarray) -> str:
pil = Image.fromarray(img).convert("RGB")
buf = io.BytesIO()
pil.save(buf, format="PNG")
return "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()
# ========================================================================
# ✅ FastAPI endpoints - updated for multi-region analysis
# ========================================================================
@api.post("/analyze")
async def analyze(file: UploadFile = File(...)):
"""Analyze an uploaded X-ray and return meta + annotated overlays (JSON)."""
try:
suffix = (file.filename or "").lower()
raw = await file.read()
# Read DICOM or normal image
if suffix.endswith(".dcm") and HAVE_DICOM:
ds = pydicom.dcmread(io.BytesIO(raw))
arr = ds.pixel_array.astype(np.float32)
arr -= arr.min(); arr /= (arr.max() - arr.min() + 1e-6)
gray_u8 = (arr * 255.0).clip(0, 255).astype(np.uint8)
else:
gray_u8 = np.array(Image.open(io.BytesIO(raw)).convert("L"))
gray_u8 = to_uint8(gray_u8)
# Run full-body multi-region analysis
res = analyze_fullbody(gray_u8)
        # 🔑 Flatten meta_pred so frontend can show Age/View/Gender/CTR
meta_pred = res.get("meta_pred", {})
res["age"] = meta_pred.get("age", {}).get("label", "N/A")
res["view"] = meta_pred.get("view", {}).get("label", "N/A")
res["gender"] = meta_pred.get("sex", {}).get("label", "N/A")
res["ctr"] = "--" # placeholder
# Convert original to base64
res["original_img"] = to_base64(gray_u8)
        # 🔑 Overlay with boxes + labels
        # --- Try segmentation overlay ---
        overlay_img, overlay_meta = findings_overlay(gray_u8, res)
        # --- ✅ Fallback if overlay is None (always show something) ---
        if overlay_img is None:
            print("⚠️ Overlay not generated - using edge-based overlay fallback.")
            overlay_img = edges_overlay(gray_u8)
            overlay_meta = {"label": "Edge-based fallback", "boxes": []}
res["overlay_img"] = to_base64(overlay_img)
res["overlay_meta"] = overlay_meta
        # ✅ Generate structured AI radiology report
doctor_report_html = generate_medical_report(res)
return JSONResponse({
"ok": True,
"meta": {
"age": res["age"],
"view": res["view"],
"gender": res["gender"],
"ctr": res["ctr"]
},
"original_img": res["original_img"],
"overlay_img": res["overlay_img"],
"overlay_meta": res["overlay_meta"],
"tasks": res.get("tasks", []),
"region": res.get("region", "unknown"),
"region_conf": res.get("region_conf", 0.0),
"meta_pred": res.get("meta_pred", {}),
"doctor_report_html": doctor_report_html # 🩻 formatted HTML report
})
except Exception as e:
print("❌ Analyze failed:", e)
return JSONResponse({"error": str(e)}, status_code=400)
from reportlab.lib.pagesizes import A4
from reportlab.lib import colors
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle, Image as RLImage
# ===========================================================
# 🩻 Unified PDF Generator (Drlogy-style Layout for All Regions)
# ===========================================================
def generate_clinical_pdf(gray_img, overlay_img, res, doctor_report):
buffer = io.BytesIO()
doc = SimpleDocTemplate(
buffer, pagesize=A4, rightMargin=36, leftMargin=36, topMargin=60, bottomMargin=36
)
styles = getSampleStyleSheet()
normal = ParagraphStyle('Normal', parent=styles['Normal'], fontSize=10, leading=13)
bold = ParagraphStyle('Bold', parent=normal, fontName='Helvetica-Bold')
header = ParagraphStyle('Header', parent=normal, alignment=1, fontName='Helvetica-Bold', fontSize=14)
footer = ParagraphStyle('Footer', parent=normal, fontSize=8, textColor=colors.gray)
story = []
# --- HEADER ---
story.append(Paragraph("<b>DRLOGY IMAGING CENTER</b>", header))
story.append(Paragraph("X-Ray | CT-Scan | MRI | USG", bold))
story.append(Paragraph("105-108, SMART VISION COMPLEX, HEALTHCARE ROAD, MUMBAI - 689578", normal))
story.append(Paragraph("πŸ“ž 0123456789 | βœ‰οΈ [email protected]", normal))
story.append(Spacer(1, 10))
# --- DYNAMIC TITLE ---
region_title = res.get("region", "General Region").replace("_", " ").title()
story.append(Paragraph(f"<b>AI RADIOLOGY REPORT – {region_title.upper()}</b>", header))
story.append(Spacer(1, 10))
# --- PATIENT INFO ---
patient_table = [
["Name:", res.get("patient_name", "N/A"), "Age:", f"{res.get('age', 'N/A')} yrs"],
["Sex:", res.get("gender", "N/A"), "Study Date:", datetime.datetime.now().strftime("%d %b %Y, %I:%M %p")],
["Region:", region_title, "AI Confidence:", f"{res.get('region_conf', 0):.2f}"],
]
table = Table(patient_table, colWidths=[70, 160, 90, 150])
table.setStyle(TableStyle([
("BOX", (0, 0), (-1, -1), 0.5, colors.gray),
("INNERGRID", (0, 0), (-1, -1), 0.25, colors.gray),
("FONTNAME", (0, 0), (-1, -1), "Helvetica"),
("FONTSIZE", (0, 0), (-1, -1), 9),
("BACKGROUND", (0, 0), (-1, 0), colors.whitesmoke)
]))
story.append(table)
story.append(Spacer(1, 12))
# --- TECHNICAL DETAILS ---
story.append(Paragraph("<b>TECHNICAL DETAILS</b>", bold))
story.append(Paragraph(
"AI-assisted analysis performed using a multi-region deep learning model. "
"Single-view radiograph analyzed for anatomical abnormalities.", normal))
story.append(Spacer(1, 10))
# --- FINDINGS ---
story.append(Paragraph("<b>AI FINDINGS</b>", bold))
story.append(Paragraph(
"Below are the observed findings as per AI-based analysis of the selected body part.", normal))
story.append(Spacer(1, 5))
story.append(Paragraph(doctor_report.replace("\n", "<br/>"), normal))
story.append(Spacer(1, 10))
# --- IMPRESSION ---
story.append(Paragraph("<b>AI IMPRESSION</b>", bold))
story.append(Paragraph(
"Based on the AI model’s interpretation, no acute abnormality was confidently detected unless otherwise stated above.",
normal))
story.append(Spacer(1, 10))
# --- RECOMMENDATIONS ---
story.append(Paragraph("<b>RECOMMENDATIONS</b>", bold))
story.append(Paragraph(
"Clinical correlation is advised. If clinical suspicion persists, further radiographic or advanced imaging is recommended.",
normal))
story.append(Spacer(1, 10))
# --- IMAGES ---
story.append(Paragraph("<b>IMAGES:</b>", bold))
img_buf = io.BytesIO()
Image.fromarray(gray_img).save(img_buf, format="PNG")
img_buf.seek(0)
story.append(RLImage(img_buf, width=200, height=200))
if overlay_img is not None:
overlay_buf = io.BytesIO()
Image.fromarray(overlay_img).save(overlay_buf, format="PNG")
overlay_buf.seek(0)
story.append(RLImage(overlay_buf, width=200, height=200))
story.append(Spacer(1, 20))
# --- FOOTER ---
story.append(Paragraph("**** End of Report ****", footer))
story.append(Spacer(1, 6))
story.append(Paragraph("Dr. AI Radiologist (MD)", footer))
story.append(Paragraph(f"Generated on: {datetime.datetime.now():%d %b %Y, %I:%M %p}", footer))
story.append(Paragraph("This AI report is supportive, not a substitute for radiologist diagnosis.", footer))
doc.build(story)
buffer.seek(0)
return buffer.getvalue()
# ===========================================================
# ✅ Main Endpoint - Analyze & Generate PDF
# ===========================================================
import re
@api.post("/analyze_pdf")
async def analyze_pdf(file: UploadFile = File(...)):
"""Generate a professional AI radiology report PDF."""
try:
suffix = (file.filename or "").lower()
raw = await file.read()
# --- Read DICOM / Image ---
if suffix.endswith(".dcm") and HAVE_DICOM:
ds = pydicom.dcmread(io.BytesIO(raw))
arr = ds.pixel_array.astype(np.float32)
arr -= arr.min(); arr /= (arr.max()-arr.min()+1e-6)
gray_u8 = (arr*255).clip(0,255).astype(np.uint8)
else:
gray_u8 = np.array(Image.open(io.BytesIO(raw)).convert("L"))
gray_u8 = to_uint8(gray_u8)
# --- Model Analysis ---
res = analyze_fullbody(gray_u8)
overlay_img, _ = findings_overlay(gray_u8, res)
if overlay_img is None:
overlay_img = edges_overlay(gray_u8)
# --- Gemini AI Report ---
try:
prompt = f"""
You are a senior radiologist. Generate a structured radiology report.
Data: {json.dumps(res, indent=2)}
Include clear prose for: Technical Details, Findings, Impression, Recommendations.
Avoid markdown, bullets, or extra headings.
"""
gemini_response = gemini_model.generate_content(prompt)
doctor_report = getattr(gemini_response, "text", "AI report unavailable.").strip()
except Exception:
doctor_report = "AI medical report could not be generated."
clean_report = re.sub(r"[#*_]+", "", doctor_report)
clean_report = re.sub(r"\n{2,}", "\n", clean_report.strip())
# --- PDF Template ---
buffer = io.BytesIO()
doc = SimpleDocTemplate(buffer, pagesize=A4, leftMargin=35, rightMargin=35, topMargin=35, bottomMargin=40)
styles = getSampleStyleSheet()
normal = ParagraphStyle("NormalCustom", parent=styles["Normal"], fontSize=10, leading=14)
heading = ParagraphStyle("Heading", parent=styles["Heading2"], fontSize=13, textColor=colors.HexColor("#004C99"))
center_heading = ParagraphStyle("CenterHeading", parent=styles["Heading2"], fontSize=14,
textColor=colors.white, alignment=1, spaceAfter=10)
footer = ParagraphStyle("Footer", parent=styles["Normal"], fontSize=9, textColor=colors.gray, alignment=1)
story = []
# === HEADER BAR ===
story.append(Table(
[[Paragraph("<b>DRLOGY IMAGING CENTER</b>", center_heading)]],
colWidths=[500],
style=[
('BACKGROUND', (0,0), (-1,-1), colors.HexColor("#004C99")),
('BOTTOMPADDING', (0,0), (-1,-1), 6),
('TOPPADDING', (0,0), (-1,-1), 6)
]
))
story.append(Paragraph("<font color='grey'>X-Ray | CT-Scan | MRI | USG</font>", normal))
story.append(Paragraph("105-108, SMART VISION COMPLEX, HEALTHCARE ROAD, MUMBAI - 689578", normal))
story.append(Paragraph("πŸ“ž 0123456789 βœ‰ [email protected]", normal))
story.append(Spacer(1, 10))
        story.append(Table([[""]], colWidths=[480],
            style=[('LINEBELOW',(0,0),(-1,-1),1,colors.HexColor("#004C99"))]))
story.append(Spacer(1, 8))
# === TITLE SECTION ===
title_table = Table(
[[Paragraph("<b><font size=13 color='#004C99'>AI RADIOLOGY REPORT</font></b>", heading),
Paragraph(datetime.datetime.now().strftime("%d %b %Y, %I:%M %p"), normal)]],
colWidths=[340, 140],
style=[
("VALIGN", (0,0), (-1,-1), "MIDDLE"),
("ALIGN", (1,0), (1,0), "RIGHT")
]
)
story.append(title_table)
story.append(Spacer(1, 12))
# === PATIENT INFO ===
story.append(Paragraph("<b>PATIENT INFORMATION</b>", heading))
patient_info = [
["Name", "N/A"],
["Age", f"{res.get('meta_pred', {}).get('age', {}).get('label', 'N/A')}"],
["Sex", f"{res.get('meta_pred', {}).get('sex', {}).get('label', 'N/A')}"],
["Region", res.get("region", "Unknown").title()],
["AI Confidence", f"{res.get('region_conf', 0):.2f}"],
]
table = Table(patient_info, colWidths=[120, 360])
table.setStyle(TableStyle([
("BOX", (0,0), (-1,-1), 0.5, colors.grey),
("INNERGRID", (0,0), (-1,-1), 0.25, colors.lightgrey),
("BACKGROUND", (0,0), (0,-1), colors.HexColor("#EAF1FB")),
]))
story.append(table)
story.append(Spacer(1, 10))
# === MAIN BODY ===
story.append(Paragraph("<b>CLINICAL REPORT SUMMARY</b>", heading))
story.append(Spacer(1, 4))
story.append(Paragraph(clean_report.replace("\n", "<br/>"), normal))
story.append(Spacer(1, 15))
# === IMAGES ===
story.append(Paragraph("<b>REFERENCE IMAGES</b>", heading))
story.append(Spacer(1, 4))
img_buf = io.BytesIO()
Image.fromarray(gray_u8).save(img_buf, format="PNG")
img_buf.seek(0)
story.append(RLImage(img_buf, width=180, height=180))
story.append(Spacer(1, 6))
overlay_buf = io.BytesIO()
Image.fromarray(overlay_img).save(overlay_buf, format="PNG")
overlay_buf.seek(0)
story.append(RLImage(overlay_buf, width=180, height=180))
story.append(Spacer(1, 20))
# === FOOTER ===
        story.append(Table([[""]], colWidths=[480],
            style=[('LINEBELOW',(0,0),(-1,-1),1,colors.HexColor("#004C99"))]))
story.append(Spacer(1, 6))
story.append(Paragraph("<b>*** End of Report ***</b>", footer))
story.append(Spacer(1, 3))
story.append(Paragraph("Generated on: " + datetime.datetime.now().strftime("%d %b %Y, %I:%M %p"), footer))
story.append(Paragraph("This report is AI-generated and intended for research/clinical support only.", footer))
story.append(Paragraph("<font color='#004C99'>Β© 2025 Drlogy Imaging Center | 24x7 Services</font>", footer))
doc.build(story)
buffer.seek(0)
return Response(
content=buffer.read(),
media_type="application/pdf",
headers={"Content-Disposition": "attachment; filename=ai_xray_report.pdf"}
)
except Exception as e:
import traceback; traceback.print_exc()
return JSONResponse({"error": str(e)}, status_code=400)
# ===============================================
# # ✅ PixVerse (Segmind) - Convert X-ray to 3D Video
# # ===============================================
# import requests
# import io
# import base64
# import time
# from fastapi import FastAPI, UploadFile, File
# from fastapi.responses import JSONResponse
# from PIL import Image
# app = FastAPI()
# PIXVERSE_API_KEY = "sk-4720d9d2ae7bd362bcece9906ffc6848"
# @api.post("/xray_to_3dvideo")
# async def xray_to_3dvideo(file: UploadFile = File(...)):
# try:
# raw = await file.read()
# img = Image.open(io.BytesIO(raw)).convert("RGB")
# buf = io.BytesIO()
# img.save(buf, format="PNG")
# base64_img = base64.b64encode(buf.getvalue()).decode()
# # βœ… FIXED PAYLOAD & HEADER
# payload = {
# "image": base64_img, # was 'image_base64'
# "prompt": "A grayscale rotating 3D visualization of a medical X-ray scan, showing bones with realistic depth and cinematic lighting.",
# "duration": 4,
# "motion": "rotate", # was 'motion_mode'
# "quality": "540p"
# }
# headers = {
# "Authorization": f"Bearer {PIXVERSE_API_KEY}", # was 'x-api-key'
# "Content-Type": "application/json"
# }
# # 1️⃣ Start PixVerse job
# start_response = requests.post(
# "https://api.segmind.com/v1/pixverse-image2video",
# json=payload,
# headers=headers
# )
# start_data = start_response.json()
# print("PixVerse start_data:", start_data)
# # 2️⃣ Extract task ID or instant video URL
# task_id = start_data.get("task_id") or start_data.get("id")
# video_url = (
# start_data.get("video_url")
# or start_data.get("video")
# or start_data.get("output", {}).get("video_url")
# )
# if not task_id and not video_url:
# return JSONResponse({
# "error": start_data,
# "message": "❌ PixVerse did not return a task ID or video URL (check API key or payload)"
# }, status_code=400)
# # If instant video available
# if video_url:
# return JSONResponse({
# "ok": True,
# "video_url": video_url,
# "message": "βœ… 3D X-ray video generated instantly via PixVerse"
# })
# # 3️⃣ Poll task status
# poll_url = f"https://api.segmind.com/v1/tasks/{task_id}"
# for attempt in range(30): # 2.5 minutes max
# time.sleep(5)
# status_response = requests.get(poll_url, headers=headers)
# status_data = status_response.json()
# status = status_data.get("status", "").lower()
# if status == "succeeded":
# output = status_data.get("output", {})
# video_url = (
# output.get("video_url")
# or output.get("video")
# or output.get("assets", {}).get("video")
# )
# if video_url:
# return JSONResponse({
# "ok": True,
# "video_url": video_url,
# "message": "βœ… 3D X-ray video generated successfully via PixVerse"
# })
# else:
# return JSONResponse({
# "error": "Video URL missing in succeeded task",
# "data": status_data
# }, status_code=500)
# if status == "failed":
# return JSONResponse({
# "error": "❌ Video generation task failed",
# "data": status_data
# }, status_code=500)
# return JSONResponse({
# "error": "⏳ Video generation timed out",
# "task_status": status,
# "task_id": task_id
# }, status_code=202)
# except Exception as e:
# print("3D video backend error:", str(e))
# return JSONResponse({"error": str(e)}, status_code=500)
# ===============================================
# ✅ Fully Fixed Version - 3D Local Video Generator (with Overlay)
# ===============================================
import imageio
# from torchvision.models.video import r3d_18  # only used by the commented dummy forward pass below
@api.post("/generate_3d_local")
async def generate_3d_local(file: UploadFile = File(...)):
"""
Generate a pseudo-3D rotating video from a 2D X-ray image.
Adds AI overlay and realistic depth-based surface projection.
"""
try:
        # 1️⃣ Read and preprocess input image
raw = await file.read()
img = Image.open(io.BytesIO(raw)).convert("L").resize((224, 224))
gray_u8 = np.array(img, dtype=np.uint8)
img_np = np.array(img, dtype=np.float32) / 255.0
        # 2️⃣ Try AI overlay from your model
try:
res = analyze_fullbody(gray_u8)
overlay_img, _ = findings_overlay(gray_u8, res)
if overlay_img is not None:
base_img = overlay_img
else:
base_img = np.stack([gray_u8] * 3, axis=-1)
except Exception as overlay_error:
print("⚠️ Overlay generation failed:", overlay_error)
base_img = np.stack([gray_u8] * 3, axis=-1)
blended = cv2.cvtColor(base_img, cv2.COLOR_BGR2RGB)
        # 3️⃣ Create pseudo depth volume
depth_slices = [img_np * (i / 10) for i in range(10)]
volume = np.stack(depth_slices, axis=0) # [D,H,W]
volume = torch.tensor(volume).unsqueeze(0).unsqueeze(0)
if volume.shape[1] == 1:
volume = volume.repeat(1, 3, 1, 1, 1)
print("βœ… Final volume shape:", volume.shape)
# # 4️⃣ Dummy forward pass for consistency (no inference)
# model = r3d_18(weights=None)
# model.eval()
# with torch.no_grad():
# _ = model(volume)
        # 5️⃣ Generate 3D surface rotation frames
frames = []
H, W = gray_u8.shape
x, y = np.meshgrid(np.linspace(0, 1, W), np.linspace(0, 1, H))
depth = cv2.GaussianBlur(gray_u8.astype(np.float32), (9, 9), 0)
for angle in range(0, 360, 5):
fig = plt.figure(figsize=(3, 3))
ax = fig.add_subplot(111, projection="3d")
ax.view_init(30, angle)
ax.plot_surface(
x, y, depth / 255.0,
facecolors=blended / 255.0,
rstride=2, cstride=2,
linewidth=0, antialiased=False
)
ax.axis("off")
            fig.canvas.draw()
            # Note: tostring_rgb is removed in Matplotlib >= 3.10; there, use
            # np.asarray(fig.canvas.buffer_rgba())[..., :3] instead.
            frame = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
            frame = frame.reshape(fig.canvas.get_width_height()[::-1] + (3,))
frames.append(frame)
plt.close(fig)
        # 6️⃣ Save output video (MP4 preferred, fallback to GIF)
os.makedirs("static", exist_ok=True)
out_path = "static/xray_3d_local.mp4"
try:
import imageio_ffmpeg
writer = imageio.get_writer(
out_path, format="ffmpeg", mode="I",
fps=12, codec="libx264", quality=8
)
for frame in frames:
writer.append_data(frame)
writer.close()
video_url = f"{HF_SPACE_URL}/static/xray_3d_local.mp4"
msg = "βœ… 3D MP4 video (with overlay) generated successfully"
except Exception as e:
print("⚠️ FFmpeg failed, saving GIF instead:", e)
gif_path = out_path.replace(".mp4", ".gif")
imageio.mimsave(gif_path, frames, duration=0.08)
video_url = f"{HF_SPACE_URL}/static/xray_3d_local.gif"
msg = "βœ… 3D GIF (with overlay) generated successfully"
return JSONResponse({
"ok": True,
"message": msg,
"video_url": video_url
})
except Exception as e:
import traceback
traceback.print_exc()
return JSONResponse({"error": str(e)}, status_code=500)
# ============================================================
# ✅ MONAI + PyVista 3D Full-Body Visualization (Final HF Safe)
# ============================================================
from monai.networks.nets import UNet
import pyvista as pv
from pyvista import start_xvfb
import imageio.v3 as iio
@api.post("/generate_3d_monai")
async def generate_3d_monai(file: UploadFile = File(...)):
"""
Generate a realistic 3D volumetric visualization for any X-ray region
(skull, chest, limb, etc.) using MONAI for enhancement and PyVista for
volumetric rendering. Fully CPU-compatible for Hugging Face Spaces.
"""
try:
# --- Step 1: Read and preprocess image ---
raw = await file.read()
img = Image.open(io.BytesIO(raw)).convert("L").resize((256, 256))
gray = np.array(img, dtype=np.float32) / 255.0
# --- Step 2: Build synthetic 3D volume (simulate multiple slices) ---
depth_slices = 16
volume = np.stack([gray * (1 - i / depth_slices) for i in range(depth_slices)], axis=0)
volume = np.expand_dims(volume, axis=0) # [1, D, H, W]
# --- Step 3: MONAI lightweight 3D UNet (shallow for CPU) ---
x = torch.tensor(volume, dtype=torch.float32).unsqueeze(0) # [1, 1, D, H, W]
pad_d = (16 - x.shape[2] % 16) % 16
pad_h = (16 - x.shape[3] % 16) % 16
pad_w = (16 - x.shape[4] % 16) % 16
x = F.pad(x, (0, pad_w, 0, pad_h, 0, pad_d))
model = UNet(
spatial_dims=3,
in_channels=1,
out_channels=1,
channels=(8, 16, 32),
strides=(2, 2),
)
model.eval()
with torch.no_grad():
y = torch.sigmoid(model(x))
seg = y[0, 0].cpu().numpy()
seg = (seg - seg.min()) / (seg.max() - seg.min() + 1e-8)
seg = np.clip(seg * 255, 0, 255).astype(np.uint8)
# --- Step 4: Create PyVista grid (Hugging Face compatible) ---
start_xvfb()
grid = pv.ImageData(dimensions=seg.shape)
grid.spacing = (1, 1, 1)
grid.origin = (0, 0, 0)
grid.point_data["values"] = seg.flatten(order="F")
# --- Step 5: Enhanced 3D visualization with lighting ---
plotter = pv.Plotter(off_screen=True, window_size=[512, 512])
plotter.add_volume(
grid,
cmap="bone",
opacity="sigmoid_5",
shade=True,
diffuse=1.0,
specular=0.3,
specular_power=15,
)
plotter.set_background("black")
# Add realistic light
light = pv.Light(
position=(seg.shape[2]*2, seg.shape[1]*2, seg.shape[0]),
color='white',
intensity=1.5,
)
plotter.add_light(light)
frames = []
for angle in range(0, 360, 5):
plotter.camera_position = [
(seg.shape[2]*2, 0, seg.shape[0]/2),
(seg.shape[2]/2, seg.shape[1]/2, seg.shape[0]/2),
(0, 0, 1),
]
plotter.camera.azimuth = angle
img_frame = plotter.screenshot(return_img=True)
frames.append(img_frame)
plotter.close()
# --- Step 6: Save rotating MP4 video ---
os.makedirs("static", exist_ok=True)
out_path = "static/xray_monai_3d_fullbody.mp4"
iio.imwrite(out_path, frames, fps=12, codec="libx264")
return JSONResponse({
"ok": True,
"message": "βœ… MONAI + PyVista 3D full-body visualization generated successfully",
"video_url": f"{HF_SPACE_URL}/{out_path}"
})
except Exception as e:
import traceback; traceback.print_exc()
return JSONResponse({"error": str(e)}, status_code=500)
# -----------------------------
# Config
# -----------------------------
CLIP_CANDIDATES = [
"microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224",
"openai/clip-vit-base-patch32",
]
CLIP_TEMP = float(os.getenv("CLIP_TEMP", "1.0"))
CHEST_TEMP = float(os.getenv("CHEST_TEMP", "1.2"))
CAL_TAU = float(os.getenv("CAL_TAU", "1.0")) # <-- ADDED: site calibration temperature
TB_REPO = os.getenv("TB_REPO", "").strip()
TB_LABEL = os.getenv("TB_LABEL", "Tuberculosis")
HF_TOKEN = os.getenv("HF_TOKEN", None)
DEFAULT_THRESHOLDS: Dict[str, float] = {
# chest
"Pneumonia": 0.70,
"Pulmonary Fibrosis": 0.75,
"COPD (proxy)": 0.75,
"Lung Cancer (proxy)": 0.70,
"Tuberculosis": 0.70,
# non-chest
"Compression fracture": 0.65,
"Scoliosis (proxy)": 0.65,
"Hip fracture": 0.65,
"OA (proxy)": 0.65,
"Fracture (upper_ext)": 0.65,
"Fracture (lower_ext)": 0.65,
"Implant/Hardware": 0.65,
"Skull fracture": 0.65,
}
ABSTAIN_MARGIN = float(os.getenv("ABSTAIN_MARGIN", "0.08"))
REGION_CONF_MIN = float(os.getenv("REGION_CONF_MIN", "0.40"))
REGION_NAMES = ["skull","chest","spine","pelvis","upper_ext","lower_ext"]
REGION_PROMPTS = {
"skull": ["X-ray of skull", "Head/skull radiograph"],
"chest": ["Chest X-ray", "Thorax radiograph"],
"spine": ["X-ray of spine", "Spine radiograph"],
"pelvis": ["Pelvis or hip X-ray", "Pelvic radiograph"],
"upper_ext": ["Upper extremity X-ray (shoulder to hand)", "Arm/wrist/hand X-ray"],
"lower_ext": ["Lower extremity X-ray (hip to foot)", "Leg/ankle/foot X-ray"],
}
REGION_DISPLAY = {
"chest":"chest",
"skull":"skull",
"spine":"spine",
"pelvis":"pelvis/hip",
"upper_ext":"upper extremity",
"lower_ext":"lower extremity",
"unknown":"unknown",
}
# Tasks (non-chest)
TASKS: Dict[str, List[Dict[str, Any]]] = {
"spine": [
{"name": "Compression fracture", "zs": ("X-ray of spine with compression fracture", "X-ray of spine without fracture")},
{"name": "Scoliosis (proxy)", "zs": ("X-ray of spine with scoliosis", "X-ray of a straight spine")},
],
"pelvis": [
{"name": "Hip fracture", "zs": ("Pelvis or hip X-ray with hip fracture", "Pelvis or hip X-ray without fracture")},
{"name": "OA (proxy)", "zs": ("Pelvis or hip X-ray with osteoarthritis", "Pelvis or hip X-ray with normal joint space")},
],
"upper_ext": [
{"name": "Fracture (upper_ext)", "zs": ("Upper extremity X-ray with fracture", "Upper extremity X-ray without fracture")},
{"name": "Implant/Hardware", "zs": ("X-ray with orthopedic implant or hardware", "X-ray without orthopedic implant")},
],
"lower_ext": [
{"name": "Fracture (lower_ext)", "zs": ("Lower extremity X-ray with fracture", "Lower extremity X-ray without fracture")},
{"name": "Implant/Hardware", "zs": ("X-ray with orthopedic implant or hardware", "X-ray without orthopedic implant")},
],
"skull": [
{"name": "Skull fracture", "zs": ("Skull X-ray with fracture", "Skull X-ray without fracture")},
],
}
# Sub-label descriptions for every screening item
CHEST_DESCRIPTIONS = {
"Pneumonia": "Patchy/lobar opacities suggesting infection/inflammation; confirm clinically Β± labs.",
"Pulmonary Fibrosis": "Chronic interstitial changes; consider HRCT for characterization.",
"COPD (proxy)": "Proxy via emphysema/hyperinflation; COPD diagnosis needs spirometry/history.",
"Lung Cancer (proxy)": "Proxy via mass/nodule; suspicious lesions need CT Β± biopsy.",
"Tuberculosis": "CXR overlaps with other diseases; definitive Dx needs microbiology/NAAT.",
"Asthma": "Typically not detectable on CXR; diagnosis is clinical with spirometry.",
}
TASK_DESCRIPTIONS = {
"Compression fracture": "Height loss of vertebral body; triage to CT/MRI if clinical concern.",
"Scoliosis (proxy)": "Curve or rotation; clinical significance depends on Cobb angle (not provided).",
"Hip fracture": "Disruption of trabeculae/cortical lines; confirm with dedicated views/CT.",
"OA (proxy)": "Joint space narrowing, osteophytes, subchondral changes; clinical correlation.",
"Fracture (upper_ext)": "Cortical break or lucent line; immobilize and obtain dedicated views.",
"Fracture (lower_ext)": "Cortical break or lucent line; weight-bearing precautions and follow-up.",
"Implant/Hardware": "Presence of orthopedic material; assess alignment/loosening if applicable.",
"Skull fracture": "Linear or depressed lucency; correlate with neuro status/CT.",
}
CHEST_TARGETS = {
"Pneumonia": ["Pneumonia"],
"Pulmonary Fibrosis": ["Fibrosis"],
"COPD (proxy)": ["Emphysema"],
"Lung Cancer (proxy)": ["Mass", "Nodule"],
"Tuberculosis": ["Tuberculosis", "TB"],
"Asthma": [],
}
# -----------------------------
# Utilities
# -----------------------------
def to_uint8(img):
if isinstance(img, Image.Image):
img = np.array(img)
if img.dtype == np.uint8:
return img
img = img.astype(np.float32)
    img -= img.min()
    rng = img.max()  # after the shift, the max equals the original value range
    if rng > 1e-6:
        img = img / (rng + 1e-6)
return (img * 255.0).clip(0, 255).astype(np.uint8)
def read_any_xray(path_or_buffer) -> np.ndarray:
"""
Reads a DICOM (.dcm) or image (PNG/JPG) and returns an 8-bit grayscale (uint8) array.
DICOM handling:
- Applies RescaleSlope/Intercept if present.
- Applies Window Center/Width (first value if multiple).
- Correctly inverts MONOCHROME1 (after VOI windowing).
"""
name = getattr(path_or_buffer, "name", str(path_or_buffer)).lower()
# Non-DICOM: use PIL and convert to 8-bit grayscale
if not name.endswith(".dcm"):
return np.array(Image.open(path_or_buffer).convert("L"))
if not HAVE_DICOM:
raise RuntimeError("Install pydicom to read DICOM files.")
ds = pydicom.dcmread(path_or_buffer)
# Raw pixel array -> float
arr = ds.pixel_array.astype(np.float32)
# 1) Modality LUT (Rescale Slope/Intercept)
slope = float(getattr(ds, "RescaleSlope", 1.0))
intercept = float(getattr(ds, "RescaleIntercept", 0.0))
if slope != 1.0 or intercept != 0.0:
arr = arr * slope + intercept
# Helper: coerce WindowCenter/Width to float (first value if multi)
def _first_float(x):
if x is None:
return None
try:
# pydicom may give MultiValue/Sequence/DSfloat
return float(np.atleast_1d(x)[0])
except Exception:
try:
return float(x)
except Exception:
return None
# 2) VOI LUT: Windowing (if available)
wc = _first_float(getattr(ds, "WindowCenter", None))
ww = _first_float(getattr(ds, "WindowWidth", None))
    if wc is not None and ww is not None and ww > 0:
lo = wc - ww / 2.0
hi = wc + ww / 2.0
arr = np.clip(arr, lo, hi)
arr = (arr - lo) / (hi - lo + 1e-6) # -> 0..1
else:
# Fallback min-max
mn = float(np.min(arr))
mx = float(np.max(arr))
arr = (arr - mn) / (mx - mn + 1e-6)
# 3) Photometric interpretation (invert MONOCHROME1)
phot = str(getattr(ds, "PhotometricInterpretation", "")).upper()
if phot == "MONOCHROME1":
arr = 1.0 - arr
# Return uint8 image
return (arr * 255.0).clip(0, 255).astype(np.uint8)
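# Usage sketch: both modalities normalize to an HxW uint8 array, e.g.
#   gray = read_any_xray("scan.dcm")  # rescaled, windowed, MONOCHROME1-corrected
#   gray = read_any_xray("hand.png")  # plain 8-bit grayscale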
def edges_overlay(gray_u8: np.ndarray) -> np.ndarray:
try:
edges = cv2.Canny(gray_u8, 50, 150)
overlay = np.stack([gray_u8]*3, axis=-1)
overlay[edges>0] = [255, 80, 80]
return overlay
except Exception:
return np.stack([gray_u8]*3, axis=-1)
# ✅ Pseudo-3D surface helpers (matplotlib / mplot3d imported above)
def pseudo3d_surface(gray_u8: np.ndarray) -> np.ndarray:
"""Generate a pseudo-3D surface plot from a 2D X-ray."""
try:
H, W = gray_u8.shape
X, Y = np.meshgrid(np.arange(W), np.arange(H))
Z = gray_u8.astype(np.float32)
fig = plt.figure(figsize=(6, 6))
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(X, Y, Z, cmap="bone", linewidth=0, antialiased=True)
ax.set_axis_off()
fig.tight_layout(pad=0)
buf = io.BytesIO()
plt.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
plt.close(fig)
buf.seek(0)
img = Image.open(buf).convert("RGB")
return np.array(img)
except Exception as e:
print("Pseudo-3D generation failed:", e)
return np.stack([gray_u8]*3, axis=-1)
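# xray_to_depth_3d below references HAVE_DEPTH / DEPTH_EXTRACTOR / DEPTH_MODEL, which
# are not defined anywhere else in this file. A minimal guarded initializer, assuming
# the Intel/dpt-hybrid-midas checkpoint named in its docstring:
try:
    DEPTH_EXTRACTOR = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
    DEPTH_MODEL = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").eval().to(DEVICE)
    HAVE_DEPTH = True
except Exception as _depth_err:
    print("⚠️ Depth model unavailable, falling back to 2D:", _depth_err)
    HAVE_DEPTH = False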
def xray_to_depth_3d(gray_u8: np.ndarray) -> np.ndarray:
"""Convert 2D X-ray into pseudo-3D surface using Intel/dpt-hybrid-midas."""
if not HAVE_DEPTH:
return np.stack([gray_u8]*3, axis=-1)
try:
pil = Image.fromarray(gray_u8).convert("RGB")
inputs = DEPTH_EXTRACTOR(images=pil, return_tensors="pt").to(DEVICE)
with torch.no_grad():
outputs = DEPTH_MODEL(**inputs)
predicted_depth = outputs.predicted_depth.squeeze().cpu().numpy()
# Normalize depth
depth_min, depth_max = predicted_depth.min(), predicted_depth.max()
depth_norm = (predicted_depth - depth_min) / (depth_max - depth_min + 1e-6)
# Render pseudo-3D
H, W = depth_norm.shape
X, Y = np.meshgrid(np.arange(W), np.arange(H))
Z = depth_norm * 255
fig = plt.figure(figsize=(6, 6))
ax = fig.add_subplot(111, projection="3d")
ax.plot_surface(X, Y, Z, cmap="bone", linewidth=0, antialiased=True)
ax.set_axis_off()
fig.tight_layout(pad=0)
buf = io.BytesIO()
plt.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
plt.close(fig)
buf.seek(0)
img = Image.open(buf).convert("RGB")
return np.array(img)
except Exception as e:
print("3D depth rendering failed:", e)
return np.stack([gray_u8]*3, axis=-1)
def risk_level(prob: Optional[float]) -> str:
if prob is None: return "N/A"
if prob >= 0.70: return "High"
if prob >= 0.50: return "Moderate"
return "Low"
def risk_tag(prob: Optional[float]) -> str:
level = risk_level(prob)
color = {"High":"#E03131","Moderate":"#F59F00","Low":"#2F9E44","N/A":"#868E96"}[level]
pct = "N/A" if prob is None else f"{prob:.0%}"
return f"<span style='background:{color};color:#fff;border-radius:8px;padding:2px 8px;font-size:0.85rem'>{level} {pct}</span>"
def percent_bar(prob: Optional[float]) -> str:
if prob is None: return ""
width = int(round(prob*100))
return (
"<div style=\"width:100%;height:8px;background:#2b2b2b;border-radius:6px;overflow:hidden;margin:4px 0 8px\">"
f"<div style=\"width:{width}%;height:8px;background:#4c6ef5\"></div>"
"</div>"
)
# -----------------------------
# Extra detail generator for ALL body regions
# -----------------------------
def generate_extra_details(res: dict) -> str:
region = res.get("region", "unknown")
tasks = res.get("tasks", [])
    region_texts = {
        "skull": (
            "🧠 **Image Details (Skull)**\n\n"
            "View type: frontal or lateral skull radiograph.\n\n"
            "Bones visible: cranial vault, facial bones, mandible.\n\n"
            "🔎 **AI / Computer Vision Context**\n"
            "- Region classifier: skull.\n"
            "- Tasks: skull fracture detection, implant/hardware.\n"
            "- AI approach: CLIP region classification + fracture CNN/ViT.\n\n"
            "🚦 **Limitations**\n"
            "- Small hairline fractures may be missed.\n"
            "- CT is gold standard for head trauma."
        ),
        "chest": (
            "🫁 **Image Details (Chest)**\n\n"
            "View type: PA/AP/Lateral chest X-ray.\n\n"
            "Visible: lungs, ribs, diaphragm, clavicles, cardiac silhouette.\n\n"
            "🔎 **AI / Computer Vision Context**\n"
            "- Region classifier: chest.\n"
            "- Tasks: Pneumonia, fibrosis, COPD, TB, lung cancer proxy.\n"
            "- Uses TorchXRayVision DenseNet + GradCAM overlays.\n\n"
            "🚦 **Limitations**\n"
            "- Cannot replace CT for detailed lung/mediastinal lesions.\n"
            "- Diagnosis requires clinical + lab correlation."
        ),
        "spine": (
            "🦴 **Image Details (Spine)**\n\n"
            "View type: cervical, thoracic, or lumbar.\n\n"
            "Visible: vertebral bodies, intervertebral spaces, alignment.\n\n"
            "🔎 **AI / Computer Vision Context**\n"
            "- Region classifier: spine.\n"
            "- Tasks: compression fracture, scoliosis.\n"
            "- AI approach: zero-shot CLIP + CNNs for fracture lines.\n\n"
            "🚦 **Limitations**\n"
            "- MRI needed for discs/cord.\n"
            "- Subtle wedge fractures may require CT."
        ),
        "pelvis": (
            "🦴 **Image Details (Pelvis/Hip)**\n\n"
            "View type: AP pelvis/hip.\n\n"
            "Visible: iliac bones, acetabulum, femoral heads.\n\n"
            "🔎 **AI / Computer Vision Context**\n"
            "- Region classifier: pelvis/hip.\n"
            "- Tasks: hip fracture, osteoarthritis.\n"
            "- AI approach: CLIP region classification + fracture CNN/transformer.\n\n"
            "🚦 **Limitations**\n"
            "- Subtle fractures may be missed → CT/MRI often needed."
        ),
        "upper_ext": (
            "✋ **Image Details (Upper Extremity)**\n\n"
            "View type: arm, shoulder, elbow, wrist, hand.\n\n"
            "Visible: humerus, radius, ulna, carpals, metacarpals.\n\n"
            "🔎 **AI / Computer Vision Context**\n"
            "- Region classifier: upper extremity.\n"
            "- Tasks: fractures, implants/hardware.\n"
            "- AI approach: CLIP + fracture CNNs (MURA dataset fine-tuning).\n\n"
            "🚦 **Limitations**\n"
            "- Soft tissues (ligaments/tendons) invisible.\n"
            "- US/MRI needed for rotator cuff, ligament injuries."
        ),
        "lower_ext": (
            "🦶 **Image Details (Lower Extremity)**\n\n"
            "View type: hip, femur, knee, tibia, ankle, foot.\n\n"
            "Visible: phalanges, metatarsals, tarsals, tibia, fibula.\n\n"
            "🔎 **AI / Computer Vision Context**\n"
            "- Region classifier: lower extremity.\n"
            "- Tasks: fractures, osteoarthritis, implants, deformity alignment.\n"
            "- AI approach: CLIP region classification → fracture CNN/transformer (MURA dataset).\n\n"
            "🚦 **Limitations**\n"
            "- Multiple views required to rule out fractures.\n"
            "- Soft tissue injuries invisible on X-ray."
        ),
        "unknown": (
            "❓ **Region Unknown**\n\n"
            "The model could not confidently classify this X-ray.\n"
            "Please check image quality (exposure, blur, cropping)."
        ),
    }
text = region_texts.get(region, region_texts["unknown"])
# Append AI task probabilities if available
if tasks:
text += "\n\n**Screening Results:**\n"
for t in tasks:
name = t.get("name")
p = t.get("prob")
prob_str = "N/A" if p is None else f"{p*100:.1f}%"
text += f"- {name}: {prob_str}\n"
return text
# =============================================
# 🧠 AI Radiology Report Generator (fixed version)
# =============================================
import google.generativeai as genai
# ✅ Load Gemini API key from Hugging Face secret
genai.configure(api_key=os.getenv("GEMINI_API_KEY"))
# ✅ Initialize Gemini model (flash = faster / cheaper)
try:
    # ✅ New API style (v1)
    gemini_model = genai.GenerativeModel("gemini-2.5-flash")
except Exception as e:
    print("⚠️ Falling back to Gemini v1beta format due to:", e)
    try:
        # ✅ Old API style (v1beta)
        gemini_model = genai.GenerativeModel("models/gemini-1.5-flash-latest")
    except Exception:
        # ✅ Fallback to PRO model if FLASH not available
        gemini_model = genai.GenerativeModel("gemini-1.5-pro")
def generate_medical_report(findings: dict) -> str:
"""
Generate a professional radiology report from AI findings using Gemini,
and format the output with Markdown and proper newlines.
"""
prompt = f"""
You are a senior radiologist. Based on the following AI findings,
generate a detailed, structured radiology report in clinical style.
Findings:
{json.dumps(findings, indent=2)}
Follow this format:
PATIENT INFORMATION
TECHNICAL DETAILS
RADIOLOGICAL FINDINGS
IMPRESSION
RECOMMENDATIONS
FOLLOW-UP
"""
try:
response = gemini_model.generate_content(prompt)
        # ✅ Extract text from Gemini response safely
if hasattr(response, "text"):
report_text = response.text.strip()
elif hasattr(response, "candidates") and response.candidates:
report_text = response.candidates[0].content.parts[0].text.strip()
else:
return "AI medical report could not be generated."
        # ✅ Format report into readable HTML (newlines + Markdown)
formatted_report = format_ai_report(report_text)
return formatted_report
except Exception as e:
import traceback
print("Gemini AI report generation failed:", e)
traceback.print_exc()
return "AI medical report could not be generated."
import re
import markdown
def format_ai_report(text: str) -> str:
"""
Formats Gemini AI report text into readable HTML with proper newlines and Markdown styling.
"""
if not text:
return "<p><i>No report generated.</i></p>"
text = re.sub(r'\n{2,}', '\n\n', text.strip()) # normalize spacing
html_text = markdown.markdown(text) # convert Markdown to HTML
html_text = html_text.replace('\n', '<br>') # keep single line breaks
styled_html = f"""
<div style="font-family:'Segoe UI',sans-serif;line-height:1.6;color:#e0e0e0;">
{html_text}
</div>
"""
return styled_html
# ---------- Saliency → boxes & overlay ----------
def _colorize_heatmap(gray_u8: np.ndarray, sal: np.ndarray) -> np.ndarray:
sal = (sal - sal.min()) / (sal.max() - sal.min() + 1e-6)
base = cv2.cvtColor(gray_u8, cv2.COLOR_GRAY2BGR)
heat = cv2.applyColorMap((sal*255).astype(np.uint8), cv2.COLORMAP_JET)
return cv2.addWeighted(base, 0.5, heat, 0.5, 0)
def _topk_boxes_from_saliency(sal: np.ndarray, k: int = 3,
min_area_ratio: float = 0.001,
q: float = 0.92):
"""
Returns list of boxes [{'x':..,'y':..,'w':..,'h':..,'score':..}] from saliency (0..1).
q is the quantile threshold; increase for tighter boxes.
"""
sal = (sal - sal.min()) / (sal.max() - sal.min() + 1e-6)
H, W = sal.shape
thr = float(np.quantile(sal, q))
mask = (sal >= thr).astype(np.uint8) * 255
mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, np.ones((3,3), np.uint8), iterations=1)
cnts, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
areas = [(cv2.contourArea(c), c) for c in cnts]
areas.sort(key=lambda x: x[0], reverse=True)
min_area = H*W*min_area_ratio
boxes = []
for area, c in areas:
if area < min_area:
continue
x, y, w, h = cv2.boundingRect(c)
roi = sal[y:y+h, x:x+w]
score = float(roi.mean())
boxes.append({'x': int(x), 'y': int(y), 'w': int(w), 'h': int(h), 'score': score})
if len(boxes) >= k:
break
return boxes
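# e.g., _topk_boxes_from_saliency(sal, k=3, q=0.95) keeps up to three tighter
# high-saliency boxes; lowering q grows the thresholded mask and loosens them.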
def _overlay_with_boxes(gray_u8: np.ndarray, sal: np.ndarray, label: str = "Suspect") -> Tuple[np.ndarray, list]:
"""Color heatmap + draw yellow rectangles with labels."""
out = _colorize_heatmap(gray_u8, sal)
boxes = _topk_boxes_from_saliency(sal, k=3, min_area_ratio=0.001, q=0.92)
for i, b in enumerate(boxes, start=1):
x, y, w, h = b['x'], b['y'], b['w'], b['h']
cv2.rectangle(out, (x, y), (x+w, y+h), (0, 255, 255), 2)
text = f"{label} #{i}"
cv2.putText(out, text, (x, max(0, y-6)), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0,255,255), 1, cv2.LINE_AA)
return out, boxes
# -----------------------------
# PDF report (FPDF) - Unicode-safe + robust PNG embed on Windows
# -----------------------------
def pdf_report_bytes(gray_u8: np.ndarray, res: Dict[str, Any],
include_overlay: bool = True,
findings_png: Optional[np.ndarray] = None) -> bytes:
if not HAVE_PDF:
raise RuntimeError("FPDF not installed. Run: pip install fpdf")
def _safe(s):
if s is None: return ""
s = str(s)
s = (s.replace("β€”", "-").replace("–", "-").replace("\u2011", "-")
.replace("β€’", "*").replace("Β±", "+/-")
.replace("β€œ", '"').replace("”", '"').replace("’", "'").replace("…", "..."))
s = unicodedata.normalize("NFKD", s).encode("ascii", "ignore").decode("ascii")
return s
pdf = FPDF(format="A4", unit="mm")
pdf.set_auto_page_break(auto=True, margin=12)
title = "AI Radiology Report – Research Use Only"
region = res.get("region", "unknown")
conf = res.get("region_conf", 0.0)
meta = res.get("meta_pred", {})
timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M")
# ---------------- Page 1 ----------------
pdf.add_page()
pdf.set_font("Helvetica", "B", 16)
pdf.cell(0, 10, _safe(title), ln=1)
pdf.set_font("Helvetica", "I", 10)
pdf.cell(0, 6, _safe(f"Generated on: {timestamp}"), ln=1)
pdf.cell(0, 6, _safe("Disclaimer: This is an AI-generated report, research use only, not for diagnostic purposes."), ln=1)
pdf.ln(4)
# Patient / Meta Info
pdf.set_font("Helvetica", "B", 12)
pdf.cell(0, 8, "Patient / Meta Information", ln=1)
pdf.set_font("Helvetica", "", 11)
pdf.cell(0, 7, _safe(f"Age (approx): {meta.get('age',{}).get('label','N/A')}"), ln=1)
pdf.cell(0, 7, _safe(f"Sex: {meta.get('sex',{}).get('label','N/A')}"), ln=1)
disp_region = REGION_DISPLAY.get(region, region)
pdf.cell(0, 7, _safe(f"Region (auto-detected): {disp_region} | Confidence: {conf:.2f}"), ln=1)
pdf.ln(3)
# Clinical Indication
pdf.set_font("Helvetica", "B", 12)
pdf.cell(0, 8, "Clinical Indication", ln=1)
pdf.set_font("Helvetica", "", 11)
pdf.multi_cell(0, 6, _safe(f"AI-assisted screening of full-body X-ray. Region auto-detected as {disp_region}."))
pdf.ln(3)
# Technique
pdf.set_font("Helvetica", "B", 12)
pdf.cell(0, 8, "Technique", ln=1)
pdf.set_font("Helvetica", "", 11)
pdf.multi_cell(0, 6, _safe("Single-view X-ray (projection estimated by model)."))
pdf.ln(3)
# Findings
pdf.set_font("Helvetica", "B", 12)
pdf.cell(0, 8, "Findings", ln=1)
pdf.set_font("Helvetica", "", 11)
details = generate_extra_details(res)
if details:
for line in details.split("\n"):
pdf.multi_cell(0, 6, _safe(line))
else:
pdf.multi_cell(0, 6, "No detailed findings available.")
pdf.ln(3)
# Disease Screening
pdf.set_font("Helvetica", "B", 12)
pdf.cell(0, 8, "Disease Screening", ln=1)
pdf.set_font("Helvetica", "", 11)
tasks = res.get("tasks", [])
if not tasks:
pdf.multi_cell(0, 6, "No tasks configured for this region or models unavailable.")
else:
for t in tasks:
name = t.get("name", "?")
p = t.get("prob", None)
desc = t.get("desc", "")
risk = risk_level(p)
p_str = "N/A" if p is None else f"{p*100:.1f}%"
pdf.multi_cell(0, 6, _safe(f"{name}: {p_str} | Risk: {risk}"))
if desc:
pdf.set_text_color(120,120,120)
pdf.multi_cell(0, 6, _safe(f" - {desc}"))
pdf.set_text_color(0,0,0)
pdf.ln(1)
pdf.ln(3)
# Impression / Conclusion
pdf.set_font("Helvetica", "B", 12)
pdf.cell(0, 8, "Impression / Conclusion", ln=1)
pdf.set_font("Helvetica", "", 11)
high_risk = [t for t in tasks if risk_level(t.get("prob")) == "High"]
if high_risk:
pdf.multi_cell(0, 6, "Abnormal findings detected – recommend further evaluation / correlation with clinical context.")
else:
pdf.multi_cell(0, 6, "No high-risk abnormalities detected by AI. Findings consistent with low or moderate risk only.")
pdf.ln(4)
# ---------------- Page 2: Original Image ----------------
fd1, p1 = tempfile.mkstemp(suffix=".png"); os.close(fd1)
try:
Image.fromarray(gray_u8).convert("RGB").save(p1, format="PNG")
pdf.add_page()
pdf.set_font("Helvetica", "B", 12)
pdf.cell(0, 8, "Original X-ray Image", ln=1)
pdf.image(p1, x=10, y=None, w=180)
finally:
try: os.remove(p1)
except Exception: pass
# ---------------- Page 3: Overlay ----------------
if include_overlay:
fd2, p2 = tempfile.mkstemp(suffix=".png"); os.close(fd2)
try:
if findings_png is not None:
Image.fromarray(findings_png).convert("RGB").save(p2, format="PNG")
title2 = "AI Heatmap with Suspect Boxes"
else:
over = edges_overlay(gray_u8)
Image.fromarray(over).convert("RGB").save(p2, format="PNG")
title2 = "Edges Overlay"
pdf.add_page()
pdf.set_font("Helvetica", "B", 12)
pdf.cell(0, 8, _safe(title2), ln=1)
pdf.image(p2, x=10, y=None, w=180)
finally:
try: os.remove(p2)
except Exception: pass
return pdf.output(dest="S").encode("latin-1")
# -----------------------------
# Meta: Age / Sex / View (zero-shot via CLIP/BiomedCLIP)
# -----------------------------
def _region_word(region: str) -> str:
return {
"chest": "Chest X-ray",
"skull": "Skull X-ray",
"spine": "Spine X-ray",
"pelvis": "Pelvis/hip X-ray",
"upper_ext": "Upper extremity X-ray",
"lower_ext": "Lower extremity X-ray",
}.get(region, "X-ray image")
def predict_view(pil_rgb: Image.Image, region: str) -> dict:
if region != "chest":
return {"label": "N/A", "probs": {}}
texts = ["PA chest X-ray", "AP chest X-ray", "Lateral chest X-ray"]
p = clip_probs(pil_rgb, texts, temp=CLIP_TEMP)
idx = int(np.argmax(p))
return {"label": ["PA","AP","Lateral"][idx],
"probs": {"PA": float(p[0]), "AP": float(p[1]), "Lateral": float(p[2])}}
def predict_sex(pil_rgb: Image.Image, region: str) -> dict:
base = _region_word(region)
texts = [f"{base} of a male patient", f"{base} of a female patient"]
p = clip_probs(pil_rgb, texts, temp=CLIP_TEMP)
idx = int(np.argmax(p))
return {"label": ["Male","Female"][idx],
"probs": {"Male": float(p[0]), "Female": float(p[1])}}
def predict_age(pil_rgb: Image.Image, region: str) -> dict:
base = _region_word(region)
labels = ["Child (<=14)", "Adolescent (15-19)", "Adult (20-64)", "Elderly (65+)"]
prompts = [f"{base} of a child", f"{base} of an adolescent", f"{base} of an adult", f"{base} of an elderly person"]
anchors = np.array([8.0, 17.0, 40.0, 75.0], dtype=np.float32)
p = clip_probs(pil_rgb, prompts, temp=CLIP_TEMP)
est = float(np.dot(p, anchors))
return {"label": f"{est:.1f} yrs", "years": est, "probs": {labels[i]: float(p[i]) for i in range(len(labels))}}
def predict_meta(pil_rgb: Image.Image, region: str) -> dict:
if CLIP_MDL is None or CLIP_PROC is None:
return {"age": {"label": "N/A", "years": None, "probs": {}},
"sex": {"label": "N/A", "probs": {}},
"view": {"label": "N/A", "probs": {}}}
return {"age": predict_age(pil_rgb, region), "sex": predict_sex(pil_rgb, region), "view": predict_view(pil_rgb, region)}
# -----------------------------
# CLIP/BiomedCLIP zero-shot
# -----------------------------
def load_clip_like():
if not HAVE_TRF: return None, None
for rid in CLIP_CANDIDATES:
try:
proc = AutoProcessor.from_pretrained(rid, token=HF_TOKEN)
mdl = CLIPModel.from_pretrained(rid, token=HF_TOKEN).eval().to(DEVICE)
return mdl, proc
except Exception:
continue
return None, None
CLIP_MDL, CLIP_PROC = load_clip_like()
def clip_probs(pil_rgb: Image.Image, texts: List[str], temp: float = 1.0) -> np.ndarray:
if CLIP_MDL is None or CLIP_PROC is None:
return np.full((len(texts),), 1.0/len(texts), dtype=np.float32)
inputs = CLIP_PROC(text=texts, images=pil_rgb, return_tensors="pt", padding=True).to(DEVICE)
with torch.no_grad():
out = CLIP_MDL(**inputs)
logits = out.logits_per_image.squeeze(0) / max(1e-6, temp)
probs = torch.softmax(logits, dim=0).detach().cpu().numpy()
return probs.astype(np.float32)
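# e.g., clip_probs(pil, ["Chest X-ray", "Skull X-ray"]) softmaxes over the prompts,
# and degrades to a uniform distribution when no CLIP checkpoint could be loaded.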
def clip_binary_prob(pil_rgb: Image.Image, pos_text: str, neg_text: str, temp: float = 1.0) -> float:
probs = clip_probs(pil_rgb, [pos_text, neg_text], temp=temp)
return float(probs[0])
# -----------------------------
# Chest ensemble (TorchXRayVision)
# -----------------------------
def load_xrv_models():
if not HAVE_XRV:
return None, None
import torch
import torchxrayvision as xrv
    # ✅ Fix for PyTorch 2.6+ safe-loading policy:
    # allow the DenseNet class from torchxrayvision to be unpickled safely.
try:
torch.serialization.add_safe_globals([xrv.models.DenseNet])
except Exception as e:
print("Safe global registration failed:", e)
def _mk(weights_key):
try:
m = xrv.models.DenseNet(weights=weights_key)
return m.eval().to(DEVICE)
except Exception as e:
print(f"⚠️ Error loading model {weights_key}:", e)
return None
# Load the CheX model (you can add more if needed)
models = [_mk("densenet121-res224-chex")]
# Combine all pathology labels
union = {}
for m in models:
if m is None:
continue
for lbl in m.pathologies:
k = lbl.lower()
if k not in union:
union[k] = lbl
return models, union
# Initialize global variables
XRV_MODELS, LOWER_TO_PRETTY = load_xrv_models()
def xrv_logits_dict(gray_u8: np.ndarray, mdl) -> dict:
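    # Map 0..255 grayscale into torchxrayvision's expected [-1024, 1024] range.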
img01 = gray_u8.astype(np.float32) / 255.0
img_hu = img01 * 2048.0 - 1024.0
t = torch.from_numpy(img_hu).unsqueeze(0).unsqueeze(0).to(DEVICE)
t = torch.nn.functional.interpolate(t, size=224, mode="bilinear", align_corners=False)
with torch.no_grad():
logits = mdl(t).squeeze(0).detach().cpu().numpy()
return {lbl.lower(): float(logits[i]) for i, lbl in enumerate(mdl.pathologies)}
def chest_disease_probs(gray_u8: np.ndarray):
if not HAVE_XRV or XRV_MODELS is None:
return {}, []
imgs = [gray_u8, np.ascontiguousarray(np.flip(gray_u8, axis=1))]
sum_logits, count = {}, {}
for im in imgs:
for mdl in XRV_MODELS:
d = xrv_logits_dict(im, mdl)
for lbl, v in d.items():
sum_logits[lbl] = sum_logits.get(lbl, 0.0) + v
count[lbl] = count.get(lbl, 0) + 1
mean_logits = {k: (sum_logits[k]/max(1,count[k]))/max(1e-6, CHEST_TEMP) for k in sum_logits}
label_probs = {k: 1.0/(1.0+np.exp(-v)) for k,v in mean_logits.items()}
requested = {}
for disp, aliases in CHEST_TARGETS.items():
if not aliases:
requested[disp] = None
continue
p=None
for a in aliases:
p = label_probs.get(a.lower())
if p is not None:
break
requested[disp] = p
ranked = sorted(label_probs.items(), key=lambda x: x[1], reverse=True)[:10]
ranked_pretty = [(LOWER_TO_PRETTY.get(lbl, lbl.title()), p) for lbl,p in ranked]
return requested, ranked_pretty
# Optional HF binary plugin (e.g., TB)
def load_hf_binary(repo_id: str, token: Optional[str] = None):
if not (HAVE_TRF and repo_id): return None, None, None
proc = AutoImageProcessor.from_pretrained(repo_id, token=token)
mdl = AutoModelForImageClassification.from_pretrained(repo_id, token=token).eval().to(DEVICE)
id2 = getattr(mdl.config, "id2label", None)
return mdl, proc, id2
def run_hf_binary(mdl, proc, pil_rgb: Image.Image, target_label_name: str) -> Optional[float]:
if mdl is None or proc is None: return None
inputs = proc(images=pil_rgb, return_tensors="pt").to(DEVICE)
with torch.no_grad():
out = mdl(**inputs)
logits = out.logits.squeeze(0)
if logits.numel() == 1:
return float(torch.sigmoid(logits).item())
id2 = getattr(mdl.config, "id2label", None) or {}
name2idx = {v.lower(): int(k) for k,v in id2.items()}
idx = name2idx.get(target_label_name.lower())
if idx is None or idx >= logits.numel(): return None
return float(torch.sigmoid(logits[idx]).item())
# -----------------------------
# Saliency for chest (Grad-CAM) and non-chest (CLIP occlusion)
# -----------------------------
def gradcam_single_xrv(gray_u8: np.ndarray, mdl, class_idx: int, out_size: Tuple[int,int]) -> Optional[np.ndarray]:
"""
Grad-CAM for a single TorchXRayVision DenseNet model & class index.
Returns saliency map in 0..1 resized to out_size.
"""
try:
mdl.eval()
feats = {}
grads = {}
def fwd_hook(module, inp, out):
feats["x"] = out.detach()
def bwd_hook(module, grad_in, grad_out):
grads["x"] = grad_out[0].detach()
h1 = mdl.features.register_forward_hook(fwd_hook)
h2 = mdl.features.register_full_backward_hook(bwd_hook)
img01 = gray_u8.astype(np.float32) / 255.0
img_hu = img01 * 2048.0 - 1024.0
t = torch.from_numpy(img_hu).unsqueeze(0).unsqueeze(0).to(DEVICE)
t = torch.nn.functional.interpolate(t, size=224, mode="bilinear", align_corners=False)
t.requires_grad_(True)
mdl.zero_grad(set_to_none=True)
out = mdl(t) # logits
if class_idx >= out.shape[1]:
h1.remove(); h2.remove(); return None
target = out[0, class_idx]
target.backward(retain_graph=True)
fmap = feats["x"] # (1, C, H, W)
grad = grads["x"] # (1, C, H, W)
weights = grad.mean(dim=(2,3), keepdim=True) # (1, C, 1, 1)
cam = torch.relu((weights * fmap).sum(dim=1)).squeeze(0) # (H, W)
cam = cam.detach().cpu().numpy()
cam = cv2.resize(cam, out_size[::-1], interpolation=cv2.INTER_LINEAR)
cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-6)
h1.remove(); h2.remove()
return cam
except Exception:
try:
h1.remove(); h2.remove()
except Exception:
pass
return None
def chest_findings_heatmap(gray_u8: np.ndarray, res: Dict[str,Any]) -> Optional[Tuple[np.ndarray, str]]:
"""Return (saliency_map_0..1, target_label_name) for chest via ensemble Grad-CAM."""
if not HAVE_XRV or XRV_MODELS is None:
return None
chest = res.get("chest", {})
candidates = [(k, v) for k, v in chest.items() if v is not None]
if not candidates:
return None
target_name, _ = max(candidates, key=lambda kv: kv[1])
lbl_lowers = [a.lower() for a in CHEST_TARGETS.get(target_name, [])]
if not lbl_lowers:
return None
h, w = gray_u8.shape
cams = []
for mdl in XRV_MODELS:
idx = None
for a in lbl_lowers:
try:
idx = [s.lower() for s in mdl.pathologies].index(a); break
except ValueError:
continue
if idx is None:
continue
cam = gradcam_single_xrv(gray_u8, mdl, idx, (h, w))
if cam is not None:
cams.append(cam)
if not cams:
return None
sal = np.mean(cams, axis=0)
sal = (sal - sal.min()) / (sal.max() - sal.min() + 1e-6)
return sal, target_name
def best_nonchest_task_and_prompts(res: Dict[str,Any]) -> Optional[Tuple[str, Tuple[str,str]]]:
region = res.get("region", "unknown")
if region not in TASKS:
return None
# map name -> (pos, neg)
name2prompts = {}
for t in TASKS[region]:
if "zs" in t:
name2prompts[t["name"]] = t["zs"]
# pick highest-prob configured task
best_name = None
best_p = -1.0
for t in res.get("tasks", []):
if t.get("region") != region:
continue
nm = t.get("name")
if nm in name2prompts and t.get("prob") is not None:
if float(t["prob"]) > best_p:
best_p = float(t["prob"]); best_name = nm
if best_name is None:
# fallback to first available task config
if TASKS[region]:
cand = TASKS[region][0]
if "zs" in cand:
return cand["name"], cand["zs"]
return None
return best_name, name2prompts[best_name]
def clip_occlusion_saliency(pil_rgb: Image.Image, pos_text: str, neg_text: str, n_segments: int = 140) -> Optional[np.ndarray]:
"""Return saliency map (0..1) via SLIC-occlusion for CLIP zero-shot."""
if not (HAVE_SKIMG and CLIP_MDL is not None):
return None
rgb = np.array(pil_rgb.convert("RGB"))
base_p = clip_binary_prob(pil_rgb, pos_text, neg_text, temp=CLIP_TEMP)
h, w, _ = rgb.shape
seg = slic(img_as_float(rgb), n_segments=n_segments, compactness=10, start_label=0)
blurred = cv2.GaussianBlur(rgb, (21,21), 0)
contrib = np.zeros((h, w), dtype=np.float32)
for lab in np.unique(seg):
mask = (seg == lab)
occluded = rgb.copy()
occluded[mask] = blurred[mask]
p = clip_binary_prob(Image.fromarray(occluded), pos_text, neg_text, temp=CLIP_TEMP)
delta = max(0.0, base_p - p) # drop in prob -> importance
contrib[mask] = delta
sal = (contrib - contrib.min()) / (contrib.max() - contrib.min() + 1e-6)
return sal
def findings_overlay(gray_u8: np.ndarray, res: Dict[str,Any]):
"""
Returns (overlay_bgr, meta_dict)
meta_dict = {"label": <task name>, "boxes": [ {x,y,w,h,score}, ... ]}
"""
region = res.get("region", "unknown")
pil = Image.fromarray(gray_u8).convert("RGB")
if region == "chest":
out = chest_findings_heatmap(gray_u8, res)
if out is None:
return None, None
sal, label = out
overlay, boxes = _overlay_with_boxes(gray_u8, sal, label=label)
return overlay, {"label": label, "boxes": boxes}
else:
pair = best_nonchest_task_and_prompts(res)
if pair is None:
return None, None
name, (pos, neg) = pair
sal = clip_occlusion_saliency(pil, pos, neg)
if sal is None:
return None, None
overlay, boxes = _overlay_with_boxes(gray_u8, sal, label=name)
return overlay, {"label": name, "boxes": boxes}
# -----------------------------
# Gating helpers
# -----------------------------
def abstain_flag(p: Optional[float], margin: float = ABSTAIN_MARGIN) -> bool:
if p is None: return True
return abs(p - 0.5) < margin
def triage_flag(p: Optional[float], threshold: float) -> bool:
if p is None: return False
return p >= threshold
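# e.g., with ABSTAIN_MARGIN=0.08, p=0.54 abstains (|0.54 - 0.5| < 0.08), while
# p=0.72 against a 0.70 threshold sets triage=True.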
# ---- ADDED: Quality metrics & gates, calibration, audit log ----
def quality_metrics(gray_u8):
H, W = gray_u8.shape[:2]
try:
blur_var = float(cv2.Laplacian(gray_u8, cv2.CV_64F).var())
except Exception:
blur_var = float('nan')
under = float((gray_u8 < 10).mean())
over = float((gray_u8 > 245).mean())
return {"min_side": min(H, W), "blur_var": blur_var, "under_pct": under, "over_pct": over}
def gate_decision(metrics, region_conf,
min_side=256, blur_var_min=80.0,
under_max=0.25, over_max=0.25,
region_conf_min=REGION_CONF_MIN):
ok, reasons = True, []
if metrics["min_side"] < min_side: ok=False; reasons.append("image too small")
if not np.isnan(metrics["blur_var"]) and metrics["blur_var"] < blur_var_min: ok=False; reasons.append("too blurry")
if metrics["under_pct"] > under_max: ok=False; reasons.append("under-exposed")
if metrics["over_pct"] > over_max: ok=False; reasons.append("over-exposed")
if region_conf < region_conf_min: ok=False; reasons.append("low region confidence (OOD)")
return ok, reasons
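# Usage sketch (these gates are not yet wired into the endpoints above):
#   m = quality_metrics(gray_u8)
#   ok, reasons = gate_decision(m, res["region_conf"])
#   if not ok: print("Rejecting study:", reasons)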
def calibrate_prob(p, tau=CAL_TAU):
if p is None: return None
p = float(np.clip(p, 1e-6, 1-1e-6))
logit = np.log(p/(1-p))
return 1.0/(1.0 + np.exp(-logit/max(1e-6, tau)))
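# e.g., with tau=2.0 a raw p=0.88 (logit ~1.99) softens to sigmoid(1.99/2) ~0.73;
# the default CAL_TAU=1.0 leaves probabilities unchanged.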
LOG_PATH = pathlib.Path("logs"); LOG_PATH.mkdir(exist_ok=True)
def audit_log(event, payload, raw_bytes=None):
rec = {
"ts": datetime.datetime.utcnow().isoformat()+"Z",
"event": event,
"payload": payload,
"input_sha256": hashlib.sha256(raw_bytes).hexdigest() if raw_bytes else None
}
with open(LOG_PATH / f"audit-{datetime.date.today().isoformat()}.jsonl", "a", encoding="utf-8") as f:
f.write(json.dumps(rec) + "\n")
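# Usage sketch: audit_log("analyze", {"region": res.get("region")}, raw_bytes=raw)
# appends one JSON line per event to logs/audit-YYYY-MM-DD.jsonl.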
# -----------------------------
# Shared inference
# -----------------------------
def analyze_fullbody(gray_u8: np.ndarray) -> Dict[str, Any]:
result: Dict[str, Any] = {
"meta": {"version": 1, "device": DEVICE},
"region": None, "region_conf": None, "region_source": None,
"tasks": [], "chest": {}
}
pil_rgb = Image.fromarray(gray_u8).convert("RGB")
# Region router (CLIP)
texts = [REGION_PROMPTS[r][0] for r in REGION_NAMES]
probs = clip_probs(pil_rgb, texts, temp=CLIP_TEMP)
idx = int(np.argmax(probs))
region = REGION_NAMES[idx]; conf = float(probs[idx])
result.update({"region": region, "region_conf": conf, "region_source": "zero-shot"})
# Meta (Age/Sex/View) using region-aware prompts
result["meta_pred"] = predict_meta(pil_rgb, region)
    # Low-confidence region → unknown
if conf < REGION_CONF_MIN:
result["region"] = "unknown"
# Region-specific screening
if region == "chest" and HAVE_XRV and XRV_MODELS is not None:
chest_probs, _ranked = chest_disease_probs(gray_u8)
if TB_REPO and HAVE_TRF:
tb_m, tb_p, _ = load_hf_binary(TB_REPO, HF_TOKEN)
tb_prob = run_hf_binary(tb_m, tb_p, pil_rgb, TB_LABEL)
if tb_prob is not None:
chest_probs["Tuberculosis"] = tb_prob
result["chest"] = chest_probs
for name, p in chest_probs.items():
thr = DEFAULT_THRESHOLDS.get(name, 0.5)
result["tasks"].append({
"region": "chest", "name": name, "prob": p,
"threshold": thr, "abstain": abstain_flag(p),
"triage": triage_flag(p, thr),
"method": "xrv-ensemble" if name != TB_LABEL else "hf-binary",
"desc": CHEST_DESCRIPTIONS.get(name, ""),
})
else:
for t in TASKS.get(region, []):
name = t["name"]; thr = DEFAULT_THRESHOLDS.get(name, 0.6)
prob = None; method = None
if "hf_repo" in t and t.get("hf_repo") and HAVE_TRF:
mdl, proc, _ = load_hf_binary(t["hf_repo"], HF_TOKEN)
label = t.get("label", "positive")
prob = run_hf_binary(mdl, proc, pil_rgb, label)
method = f"hf:{t['hf_repo']}"
elif "zs" in t:
pos, neg = t["zs"]
prob = clip_binary_prob(pil_rgb, pos, neg, temp=CLIP_TEMP)
method = "zero-shot-clip"
result["tasks"].append({
"region": region, "name": name, "prob": prob,
"threshold": thr, "abstain": abstain_flag(prob),
"triage": triage_flag(prob, thr), "method": method,
"desc": TASK_DESCRIPTIONS.get(name, ""),
})
return result
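# Optional local entry point; the Space normally starts uvicorn externally
# (see the run command at the top of this file).
if __name__ == "__main__":
    uvicorn.run(api, host="0.0.0.0", port=7860)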