import gradio as gr
import torch
import torchvision
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from huggingface_hub import hf_hub_download
import os
import io

# Work around "OMP: Error #15" crashes when multiple OpenMP runtimes are loaded
# (common when PyTorch and other libraries each ship their own libiomp).
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# Class names and colors
CLASS_NAMES = {1: 'Nipple', 2: 'Lump'}
CLASS_COLORS = {1: 'white', 2: 'white'}

# Global variables for model (load once at startup)
MODEL = None
DEVICE = None

def preprocess_image(image):
    """Load and preprocess image for Faster R-CNN."""
    # Convert PIL Image to numpy array
    image = np.array(image)
    
    # Already in RGB format from Gradio
    image = image.astype(np.float32) / 255.0  # Normalize to [0,1]
    
    # Normalize using ImageNet mean and std
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image = (image - mean) / std
    
    # HWC -> CHW: torchvision detection models expect (C, H, W) float tensors
    return torch.tensor(image.transpose(2, 0, 1), dtype=torch.float32)

def load_model(checkpoint_path, device):
    """Load Faster R-CNN model with fine-tuned weights."""
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=None)
    in_features = model.roi_heads.box_predictor.cls_score.in_features
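    # len(CLASS_NAMES) + 1 outputs: torchvision detection heads reserve index 0 for background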
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, len(CLASS_NAMES) + 1)
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.to(device).eval()
    return model

def initialize_model():
    """Initialize model at startup by downloading from HF Hub."""
    global MODEL, DEVICE
    
    if MODEL is None:
        DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        print(f"Downloading model from IFMedTech/Lumps repository...")
        # Download model from private HuggingFace repository
        # Token is automatically read from HF_TOKEN environment variable (Spaces secrets)
        try:
            checkpoint_path = hf_hub_download(
                repo_id="IFMedTech/Lumps",
                filename="lumps.pth",
                repo_type="model",
                token=os.environ.get("HF_TOKEN")  # Use token from Spaces secrets
            )
            print(f"Model downloaded to: {checkpoint_path}")
            print(f"Loading model on {DEVICE}...")
            MODEL = load_model(checkpoint_path, DEVICE)
            print(f"Model loaded successfully!")
        except Exception as e:
            print(f"Error loading model: {e}")
            raise RuntimeError(
                f"Failed to load model from IFMedTech/Lumps. "
                f"Please ensure HF_TOKEN is set in Spaces secrets with read access to the private repository. "
                f"Error: {e}"
            )
    
    return MODEL, DEVICE

def predict(image, score_thresh=0.5):
    """Run inference and return image with bounding boxes."""
    # Ensure model is loaded
    model, device = initialize_model()
    
    # Preprocess image
    image_tensor = preprocess_image(image)
    
    # Run inference
    model.eval()
    with torch.no_grad():
        preds = model([image_tensor.to(device)])[0]
    
    boxes, labels, scores = preds['boxes'].cpu().numpy(), preds['labels'].cpu().numpy(), preds['scores'].cpu().numpy()
    
    # Filter based on confidence threshold
    keep = scores >= score_thresh
    boxes, labels, scores = boxes[keep], labels[keep], scores[keep]
    
    # Convert tensor back to image
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image_np = image_tensor.cpu().permute(1, 2, 0).numpy() * std + mean
    image_np = np.clip(image_np, 0, 1)
    
    # Create figure
    fig, ax = plt.subplots(1, figsize=(12, 12))
    ax.imshow(image_np)
    
    # Draw bounding boxes
    for box, label, score in zip(boxes, labels, scores):
        xmin, ymin, xmax, ymax = box
        rect = patches.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, 
                                linewidth=3, edgecolor=CLASS_COLORS.get(label, 'blue'), 
                                facecolor='none')
        ax.add_patch(rect)
        ax.text(xmin, ymin - 10, f"{CLASS_NAMES.get(label, f'class_{label}')} ({score:.2f})", 
                fontsize=14, color='white', backgroundcolor='black', weight='bold')
    
    plt.axis('off')
    
    # Convert plot to image
    buf = io.BytesIO()
    plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    buf.seek(0)
    result_image = Image.open(buf)
    plt.close()
    
    return result_image

# Create Gradio interface
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Upload Breast Image"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.5, step=0.05, label="Confidence Threshold")
    ],
    outputs=gr.Image(type="pil", label="Detection Results"),
    title="Breast Lumps Detection",
    description="""Upload a breast image to detect lumps and nipples using a Faster R-CNN model.\n\n
    ⚠️ **Important Medical Disclaimer**: This is a screening tool for research and assistive purposes only. 
    It should NOT be used as the sole basis for medical diagnosis. All detections must be reviewed and confirmed 
    by qualified medical professionals. This model is not FDA approved or certified for clinical diagnosis.""",
    examples=None,
    allow_flagging="never"
)
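
# HF_TOKEN must be available (as a Spaces secret or a local environment variable) with read
# access to the private IFMedTech/Lumps repository, or the first prediction will fail.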

if __name__ == "__main__":
    demo.launch()