# Lumps-Detection / app.py
import gradio as gr
import torch
import torchvision
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from huggingface_hub import hf_hub_download
import os
import io
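# Work around "duplicate OpenMP runtime" crashes that can occur when torch and matplotlib each bundle libiomp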
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
# Class names and colors
CLASS_NAMES = {1: 'Nipple', 2: 'Lump'}
CLASS_COLORS = {1: 'white', 2: 'white'}
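# Box edge color per class id; labels missing from this dict fall back to 'blue' in predict()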
# Global variables for model (load once at startup)
MODEL = None
DEVICE = None
def preprocess_image(image):
"""Load and preprocess image for Faster R-CNN."""
# Convert PIL Image to numpy array
image = np.array(image)
# Already in RGB format from Gradio
image = image.astype(np.float32) / 255.0 # Normalize to [0,1]
# Normalize using ImageNet mean and std
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
image = (image - mean) / std
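    # HWC -> CHW; torchvision detection models expect a list of per-image CHW float tensors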
return torch.tensor(image.transpose(2, 0, 1), dtype=torch.float32)
def load_model(checkpoint_path, device):
"""Load Faster R-CNN model with fine-tuned weights."""
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=None)
in_features = model.roi_heads.box_predictor.cls_score.in_features
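    # Replace the box-predictor head; the +1 accounts for the implicit background class (index 0)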
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, len(CLASS_NAMES) + 1)
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.to(device).eval()
return model
def initialize_model():
"""Initialize model at startup by downloading from HF Hub."""
global MODEL, DEVICE
if MODEL is None:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Downloading model from IFMedTech/Lumps repository...")
# Download model from private HuggingFace repository
# Token is automatically read from HF_TOKEN environment variable (Spaces secrets)
try:
checkpoint_path = hf_hub_download(
repo_id="IFMedTech/Lumps",
filename="lumps.pth",
repo_type="model",
token=os.environ.get("HF_TOKEN") # Use token from Spaces secrets
)
print(f"Model downloaded to: {checkpoint_path}")
print(f"Loading model on {DEVICE}...")
MODEL = load_model(checkpoint_path, DEVICE)
print(f"Model loaded successfully!")
except Exception as e:
print(f"Error loading model: {e}")
raise RuntimeError(
f"Failed to load model from IFMedTech/Lumps. "
f"Please ensure HF_TOKEN is set in Spaces secrets with read access to the private repository. "
f"Error: {e}"
)
return MODEL, DEVICE
def predict(image, score_thresh=0.5):
"""Run inference and return image with bounding boxes."""
# Ensure model is loaded
model, device = initialize_model()
# Preprocess image
image_tensor = preprocess_image(image)
# Run inference
model.eval()
with torch.no_grad():
preds = model([image_tensor.to(device)])[0]
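    # preds holds 'boxes' (xmin, ymin, xmax, ymax in input-image pixels), 'labels', and 'scores'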
boxes, labels, scores = preds['boxes'].cpu().numpy(), preds['labels'].cpu().numpy(), preds['scores'].cpu().numpy()
# Filter based on confidence threshold
keep = scores >= score_thresh
boxes, labels, scores = boxes[keep], labels[keep], scores[keep]
    # Undo the ImageNet normalization so the tensor can be displayed as a natural-looking RGB image
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
image_np = image_tensor.cpu().permute(1, 2, 0).numpy() * std + mean
image_np = np.clip(image_np, 0, 1)
# Create figure
fig, ax = plt.subplots(1, figsize=(12, 12))
ax.imshow(image_np)
# Draw bounding boxes
for box, label, score in zip(boxes, labels, scores):
xmin, ymin, xmax, ymax = box
rect = patches.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
linewidth=3, edgecolor=CLASS_COLORS.get(label, 'blue'),
facecolor='none')
ax.add_patch(rect)
ax.text(xmin, ymin - 10, f"{CLASS_NAMES.get(label, f'class_{label}')} ({score:.2f})",
fontsize=14, color='white', backgroundcolor='black', weight='bold')
plt.axis('off')
# Convert plot to image
buf = io.BytesIO()
plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
buf.seek(0)
result_image = Image.open(buf)
    plt.close(fig)
return result_image
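# Minimal local smoke test for predict(); "sample_breast_image.png" is a hypothetical file name:
#
#   img = Image.open("sample_breast_image.png").convert("RGB")
#   out = predict(img, score_thresh=0.5)
#   out.save("detections.png")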
# Create Gradio interface
demo = gr.Interface(
fn=predict,
inputs=[
gr.Image(type="pil", label="Upload Breast Image"),
gr.Slider(minimum=0.1, maximum=1.0, value=0.5, step=0.05, label="Confidence Threshold")
],
outputs=gr.Image(type="pil", label="Detection Results"),
title="Breast Lumps Detection",
description="""Upload a breast image to detect lumps and nipples using a Faster R-CNN model.\n\n
⚠️ **Important Medical Disclaimer**: This is a screening tool for research and assistive purposes only.
It should NOT be used as the sole basis for medical diagnosis. All detections must be reviewed and confirmed
by qualified medical professionals. This model is not FDA approved or certified for clinical diagnosis.""",
examples=None,
allow_flagging="never"
)
if __name__ == "__main__":
demo.launch()