"""Gradio app: breast lump / nipple detection with a fine-tuned Faster R-CNN.

The checkpoint is downloaded from the private Hugging Face repository
``IFMedTech/Lumps`` (authenticated via the ``HF_TOKEN`` environment variable)
the first time a prediction is requested, then reused for later requests.
"""

import io
import os

# Let duplicate OpenMP runtimes coexist instead of aborting. This is set before
# the numeric libraries below are imported (intentionally out of the usual
# import order) so it is already in effect if one of those imports is what
# loads the second runtime.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

import gradio as gr
import matplotlib.patches as patches
import matplotlib.pyplot as plt  # retained; rendering below uses the OO Figure API
import numpy as np
import torch
import torchvision
from huggingface_hub import hf_hub_download
from matplotlib.figure import Figure
from PIL import Image
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# Class names and box colours keyed by the label ids the detector emits.
# Id 0 is the background class (hence ``len(CLASS_NAMES) + 1`` in load_model).
CLASS_NAMES = {1: 'Nipple', 2: 'Lump'}
CLASS_COLORS = {1: 'white', 2: 'white'}

# ImageNet channel statistics applied in preprocess_image.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])

# Shared model state, populated lazily by initialize_model() on the first
# prediction and reused afterwards.
MODEL = None
DEVICE = None


def preprocess_image(image):
    """Convert a PIL image into the normalised CHW float tensor fed to the model.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input. Any mode is
            accepted: it is converted to 3-channel RGB first, so grayscale or
            transparent uploads do not break the per-channel normalisation.

    Returns:
        ``torch.float32`` tensor of shape ``(3, H, W)``: pixels scaled to
        [0, 1], then standardised with the ImageNet mean/std.
    """
    # NOTE(review): torchvision's Faster R-CNN documents its input as [0, 1]
    # tensors and normalises internally, so this extra standardisation is
    # probably redundant. It is kept because the fine-tuned checkpoint
    # presumably saw the same preprocessing in training -- confirm against the
    # training code before removing it.
    pixels = np.asarray(image.convert("RGB"), dtype=np.float32) / 255.0
    pixels = (pixels - IMAGENET_MEAN) / IMAGENET_STD
    return torch.tensor(pixels.transpose(2, 0, 1), dtype=torch.float32)


def load_model(checkpoint_path, device):
    """Build a Faster R-CNN (ResNet-50 FPN) and load the fine-tuned weights.

    Args:
        checkpoint_path: Path to a ``state_dict`` saved with ``torch.save``.
        device: ``torch.device`` to place the model on.

    Returns:
        The model, moved to ``device`` and switched to eval mode.
    """
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=None)
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # Replace the stock head with one sized for our classes (+1 for background).
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, len(CLASS_NAMES) + 1)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state_dict needs, so loading cannot execute arbitrary code.
    state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device).eval()
    return model


def initialize_model():
    """Return the shared ``(model, device)`` pair, loading it on first use.

    On the first call, the checkpoint is downloaded from the Hugging Face Hub
    (token read from ``HF_TOKEN``) and loaded onto CUDA if available,
    otherwise the CPU.

    Raises:
        RuntimeError: if the download or load fails. ``MODEL`` stays ``None``
            in that case, so the next call retries.
    """
    global MODEL, DEVICE
    if MODEL is None:
        DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print("Downloading model from IFMedTech/Lumps repository...")
        try:
            checkpoint_path = hf_hub_download(
                repo_id="IFMedTech/Lumps",
                filename="lumps.pth",
                repo_type="model",
                token=os.environ.get("HF_TOKEN"),  # Spaces secret
            )
            print(f"Model downloaded to: {checkpoint_path}")
            print(f"Loading model on {DEVICE}...")
            MODEL = load_model(checkpoint_path, DEVICE)
            print("Model loaded successfully!")
        except Exception as e:
            print(f"Error loading model: {e}")
            raise RuntimeError(
                f"Failed to load model from IFMedTech/Lumps. "
                f"Please ensure HF_TOKEN is set in Spaces secrets with read access to the private repository. "
                f"Error: {e}"
            ) from e
    return MODEL, DEVICE


def _render_detections(image_np, boxes, labels, scores):
    """Draw labelled bounding boxes over ``image_np`` and return a PIL image."""
    # A standalone Figure avoids pyplot's global "current figure" state, which
    # is not thread-safe (Gradio typically runs sync handlers in worker
    # threads). It also needs no explicit close: it is garbage-collected.
    fig = Figure(figsize=(12, 12))
    ax = fig.subplots()
    ax.imshow(image_np)

    for box, label, score in zip(boxes, labels, scores):
        label = int(label)
        xmin, ymin, xmax, ymax = box
        ax.add_patch(patches.Rectangle(
            (xmin, ymin), xmax - xmin, ymax - ymin,
            linewidth=3, edgecolor=CLASS_COLORS.get(label, 'blue'), facecolor='none',
        ))
        ax.text(
            xmin, ymin - 10,
            f"{CLASS_NAMES.get(label, f'class_{label}')} ({score:.2f})",
            fontsize=14, color='white', backgroundcolor='black', weight='bold',
        )
    ax.set_axis_off()

    buf = io.BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    buf.seek(0)
    result_image = Image.open(buf)
    result_image.load()  # decode eagerly instead of lazily on first access
    return result_image


def predict(image, score_thresh=0.5):
    """Detect nipples/lumps in ``image`` and return it annotated with boxes.

    Args:
        image: PIL image from the Gradio input, or ``None`` if nothing was uploaded.
        score_thresh: Minimum detection score for a box to be drawn.

    Returns:
        PIL image with one box and "name (score)" label per kept detection.

    Raises:
        gr.Error: if no image was provided (shown to the user in the UI).
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    model, device = initialize_model()

    rgb_image = image.convert("RGB")
    image_tensor = preprocess_image(rgb_image)

    model.eval()
    with torch.no_grad():
        preds = model([image_tensor.to(device)])[0]
    boxes = preds['boxes'].cpu().numpy()
    labels = preds['labels'].cpu().numpy()
    scores = preds['scores'].cpu().numpy()

    # Filter based on confidence threshold
    keep = scores >= score_thresh
    boxes, labels, scores = boxes[keep], labels[keep], scores[keep]

    # Draw on the original RGB pixels; no need to undo the tensor normalisation.
    return _render_detections(np.asarray(rgb_image), boxes, labels, scores)


# Create Gradio interface
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Upload Breast Image"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.5, step=0.05, label="Confidence Threshold"),
    ],
    outputs=gr.Image(type="pil", label="Detection Results"),
    title="Breast Lumps Detection",
    description="""Upload a breast image to detect lumps and nipples using a Faster R-CNN model.\n\n ⚠️ **Important Medical Disclaimer**: This is a screening tool for research and assistive purposes only. It should NOT be used as the sole basis for medical diagnosis. All detections must be reviewed and confirmed by qualified medical professionals. This model is not FDA approved or certified for clinical diagnosis.""",
    examples=None,
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()