Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torchvision | |
| import numpy as np | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| import matplotlib.patches as patches | |
| from torchvision.models.detection.faster_rcnn import FastRCNNPredictor | |
| from huggingface_hub import hf_hub_download | |
| import os | |
| import io | |
# Setting KMP_DUPLICATE_LIB_OK tells the Intel OpenMP runtime not to abort when a
# second copy of itself is loaded into the process.
# NOTE(review): presumably needed because torch and numpy/MKL bring separate OpenMP
# builds; it silences the check rather than fixing it — confirm it is still required.
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

# Detector class id -> human-readable name. Ids start at 1; load_model sizes the
# classifier head as len(CLASS_NAMES) + 1.
CLASS_NAMES = {1: "Nipple", 2: "Lump"}
# Box edge colour per class id (currently white for every class).
CLASS_COLORS = dict.fromkeys(CLASS_NAMES, "white")

# Populated lazily by initialize_model() on the first prediction, then reused.
MODEL = DEVICE = None
def preprocess_image(image):
    """Convert an input image into a normalised CHW float tensor for Faster R-CNN.

    Args:
        image: A PIL image or array-like of shape (H, W), (H, W, 1), (H, W, 3)
            or (H, W, 4). Pixel values are assumed to be 8-bit (0-255).

    Returns:
        torch.Tensor of shape (3, H, W) and dtype float32: pixels scaled to
        [0, 1] and then standardised with the ImageNet per-channel mean/std.

    Raises:
        ValueError: if the image cannot be interpreted as 1, 3 or 4 channels.
    """
    pixels = np.asarray(image)

    # Gradio normally delivers RGB, but grayscale or RGBA inputs would otherwise
    # fail to broadcast against the 3-channel mean/std (or in the transpose).
    if pixels.ndim == 2:
        pixels = np.stack([pixels] * 3, axis=-1)
    elif pixels.ndim == 3 and pixels.shape[-1] == 1:
        pixels = np.repeat(pixels, 3, axis=-1)
    elif pixels.ndim == 3 and pixels.shape[-1] == 4:
        pixels = pixels[..., :3]  # drop alpha, as PIL's convert("RGB") does
    if pixels.ndim != 3 or pixels.shape[-1] != 3:
        raise ValueError(f"Expected an RGB image, got array of shape {pixels.shape}")

    pixels = pixels.astype(np.float32) / 255.0  # Normalize to [0,1]

    # Normalize using ImageNet mean and std.
    # NOTE(review): torchvision detection models also normalise inputs internally
    # (GeneralizedRCNNTransform), so this standardisation is presumably applied
    # twice. Keep it only if training used the same preprocessing — confirm
    # before removing, otherwise predictions will shift.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    pixels = (pixels - mean) / std

    # HWC -> CHW, the layout torchvision detection models expect.
    return torch.tensor(pixels.transpose(2, 0, 1), dtype=torch.float32)
def load_model(checkpoint_path, device):
    """Build a Faster R-CNN (ResNet-50 FPN) detector and load fine-tuned weights.

    The stock box predictor is replaced by one sized for ``CLASS_NAMES`` plus
    one extra class slot, then the state dict at ``checkpoint_path`` is loaded
    into the model.

    Args:
        checkpoint_path: Path to a saved state dict for this architecture.
        device: Target device for the checkpoint tensors and the model.

    Returns:
        The detector, moved to ``device`` and switched to eval mode.
    """
    # +1: torchvision's FastRCNNPredictor counts the background class too.
    num_classes = len(CLASS_NAMES) + 1

    detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=None)
    # NOTE(review): with weights=None torchvision still defaults weights_backbone to
    # ImageNet weights (fetched, then overwritten by the checkpoint below). Passing
    # weights_backbone=None would also switch FrozenBatchNorm2d -> BatchNorm2d, so
    # verify checkpoint compatibility before changing it.

    head_in_features = detector.roi_heads.box_predictor.cls_score.in_features
    detector.roi_heads.box_predictor = FastRCNNPredictor(head_in_features, num_classes)

    # NOTE(review): torch.load unpickles the file. It comes from a private HF repo;
    # consider weights_only=True (torch>=1.13) since only a state dict is expected.
    state_dict = torch.load(checkpoint_path, map_location=device)
    detector.load_state_dict(state_dict)

    detector.to(device)
    detector.eval()
    return detector
def initialize_model():
    """Return the shared ``(MODEL, DEVICE)`` pair, loading it on first use.

    On the first call this picks CUDA when available (else CPU), downloads
    ``lumps.pth`` from the private Hugging Face repo ``IFMedTech/Lumps`` using
    the ``HF_TOKEN`` environment variable (a Spaces secret), and builds the
    model with ``load_model``. The result is cached in the module globals
    ``MODEL`` and ``DEVICE``, so later calls return immediately.

    Returns:
        tuple: ``(MODEL, DEVICE)``.

    Raises:
        RuntimeError: if the checkpoint cannot be downloaded or cannot be
            loaded into the model; the original exception is chained as
            ``__cause__``.
    """
    global MODEL, DEVICE
    if MODEL is not None:
        return MODEL, DEVICE

    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Downloading model from IFMedTech/Lumps repository...")

    # Download failures (missing/insufficient token, network) are reported
    # separately from load failures so the HF_TOKEN hint is only shown when it
    # can actually be the cause.
    try:
        checkpoint_path = hf_hub_download(
            repo_id="IFMedTech/Lumps",
            filename="lumps.pth",
            repo_type="model",
            token=os.environ.get("HF_TOKEN"),  # Use token from Spaces secrets
        )
    except Exception as e:
        print(f"Error loading model: {e}")
        raise RuntimeError(
            f"Failed to load model from IFMedTech/Lumps. "
            f"Please ensure HF_TOKEN is set in Spaces secrets with read access to the private repository. "
            f"Error: {e}"
        ) from e
    print(f"Model downloaded to: {checkpoint_path}")

    print(f"Loading model on {DEVICE}...")
    try:
        MODEL = load_model(checkpoint_path, DEVICE)
    except Exception as e:
        print(f"Error loading model: {e}")
        raise RuntimeError(
            f"Failed to load model checkpoint {checkpoint_path}. Error: {e}"
        ) from e
    print("Model loaded successfully!")
    return MODEL, DEVICE
def _render_detections(image_np, boxes, labels, scores):
    """Draw labelled bounding boxes over ``image_np`` and return a PIL image.

    Args:
        image_np: (H, W, 3) float RGB image with values in [0, 1].
        boxes: (N, 4) array of ``xmin, ymin, xmax, ymax`` pixel coordinates.
        labels: (N,) integer class ids, looked up in ``CLASS_NAMES`` / ``CLASS_COLORS``.
        scores: (N,) confidence scores.

    Returns:
        PIL.Image.Image: the rendered PNG, fully loaded into memory.
    """
    # Work on the explicit Figure/Axes rather than pyplot's global "current
    # figure", and always close the figure so a failure mid-draw cannot leak it.
    fig, ax = plt.subplots(1, figsize=(12, 12))
    try:
        ax.imshow(image_np)
        for (xmin, ymin, xmax, ymax), label, score in zip(boxes, labels, scores):
            label = int(label)
            ax.add_patch(
                patches.Rectangle(
                    (xmin, ymin), xmax - xmin, ymax - ymin,
                    linewidth=3, edgecolor=CLASS_COLORS.get(label, 'blue'),
                    facecolor='none',
                )
            )
            name = CLASS_NAMES.get(label, f'class_{label}')
            # Caption placed 10 pixels above the top edge of the box.
            ax.text(xmin, ymin - 10, f"{name} ({score:.2f})",
                    fontsize=14, color='white', backgroundcolor='black', weight='bold')
        ax.axis('off')
        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    finally:
        plt.close(fig)
    buf.seek(0)
    result_image = Image.open(buf)
    result_image.load()  # decode now instead of lazily on first access
    return result_image


def predict(image, score_thresh=0.5):
    """Run lump/nipple detection on ``image`` and return an annotated image.

    Args:
        image: Image from the Gradio ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading.
        score_thresh: Minimum confidence score for a detection to be drawn.

    Returns:
        PIL.Image.Image showing the input with the kept detections drawn on it.

    Raises:
        gr.Error: if no image was provided.
        RuntimeError: propagated from ``initialize_model`` when the model
            cannot be downloaded or loaded.
    """
    if image is None:
        # Otherwise numpy fails with an opaque TypeError on the None input.
        raise gr.Error("Please upload an image.")

    model, device = initialize_model()
    image_tensor = preprocess_image(image)

    model.eval()
    with torch.no_grad():
        preds = model([image_tensor.to(device)])[0]
    boxes = preds['boxes'].cpu().numpy()
    labels = preds['labels'].cpu().numpy()
    scores = preds['scores'].cpu().numpy()

    # Filter based on confidence threshold.
    keep = scores >= score_thresh
    boxes, labels, scores = boxes[keep], labels[keep], scores[keep]

    # Undo the ImageNet standardisation from preprocess_image so the displayed
    # image has its original colours.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image_np = image_tensor.cpu().permute(1, 2, 0).numpy() * std + mean
    image_np = np.clip(image_np, 0, 1)

    return _render_detections(image_np, boxes, labels, scores)
# Gradio UI: an uploaded image plus a confidence-threshold slider go in
# (passed to predict as image, score_thresh); the annotated image comes out.
_INPUTS = [
    gr.Image(type="pil", label="Upload Breast Image"),
    gr.Slider(minimum=0.1, maximum=1.0, value=0.5, step=0.05, label="Confidence Threshold"),
]

_DESCRIPTION = """Upload a breast image to detect lumps and nipples using a Faster R-CNN model.\n\n
⚠️ **Important Medical Disclaimer**: This is a screening tool for research and assistive purposes only.
It should NOT be used as the sole basis for medical diagnosis. All detections must be reviewed and confirmed
by qualified medical professionals. This model is not FDA approved or certified for clinical diagnosis."""

# NOTE(review): newer Gradio releases rename allow_flagging to flagging_mode;
# confirm against the Gradio version pinned for this Space.
demo = gr.Interface(
    fn=predict,
    inputs=_INPUTS,
    outputs=gr.Image(type="pil", label="Detection Results"),
    title="Breast Lumps Detection",
    description=_DESCRIPTION,
    examples=None,
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()