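"""Gradio app for X-ray regression: upload an X-ray image, pick a trained
checkpoint, and get the regression output alongside a GradCAM visualization."""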
import logging

import torch
from PIL import Image
from torchvision import transforms
import gradio as gr
import numpy as np
import spaces
from pytorch_grad_cam.utils.image import show_cam_on_image

from scripts.trainer import XrayReg
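# Checkpoint paths below are relative to the working directory, where
# download_models() places the Hugging Face snapshot at startup.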
# Model selection dropdown options (extendable)
MODEL_OPTIONS = {
    "XrayReg (912yp4l6) [vit_large_patch16_224_in21k]": {
        "ckpt": "xray_regression_noaug/912yp4l6/checkpoints/epoch=99-step=5900.ckpt",
        "model_name": "vit_large_patch16_224_in21k",
    },
    "XrayReg (ie399gjr) [vit_small_patch16_224_in21k]": {
        "ckpt": "xray_regression_noaug/ie399gjr/checkpoints/epoch=99-step=5900.ckpt",
        "model_name": "vit_small_patch16_224_in21k",
    },
    "XrayReg (kcku20nx) [vit_large_patch16_224_in21k]": {
        "ckpt": "xray_regression_noaug/kcku20nx/checkpoints/epoch=99-step=5900.ckpt",
        "model_name": "vit_large_patch16_224_in21k",
    },
    "XrayReg (ohtmkj0i) [vit_base_patch16_224_in21k]": {
        "ckpt": "xray_regression_noaug/ohtmkj0i/checkpoints/epoch=99-step=5900.ckpt",
        "model_name": "vit_base_patch16_224_in21k",
    },
    "XrayReg (vlk8qrkx) [vit_large_patch16_224_in21k]": {
        "ckpt": "xray_regression_noaug/vlk8qrkx/checkpoints/epoch=99-step=5900.ckpt",
        "model_name": "vit_large_patch16_224_in21k",
    },
}

def preprocess_image(inp):
    """
    Preprocess the input image.

    Returns:
        input_tensor: Tensor to be fed into the model.
        rgb_img: NumPy array normalized to [0, 1] for GradCAM visualization.
    """
    try:
        preprocess = transforms.Compose(
            [
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
            ]
        )
        input_tensor = preprocess(inp).unsqueeze(0)
        rgb_img = np.array(inp.resize((224, 224))).astype(np.float32) / 255.0
        return input_tensor, rgb_img
    except Exception as e:
        logging.error("Error in image preprocessing: %s", e)
        raise

def preprocess_image_custom(inp):
    """Preprocess as single-channel grayscale for the model, and keep an RGB
    copy (normalized to [0, 1]) for the GradCAM overlay."""
    preprocess = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.Grayscale(num_output_channels=1),
            transforms.ToTensor(),
        ]
    )
    input_tensor = preprocess(inp).unsqueeze(0)
    rgb_img = np.array(inp.resize((224, 224)).convert("RGB")).astype(np.float32) / 255.0
    return input_tensor, rgb_img

def predict_custom(model, input_tensor):
    with torch.no_grad():
        input_tensor = (
            input_tensor.cuda() if torch.cuda.is_available() else input_tensor
        )
        pred = model(input_tensor)
        pred = pred.cpu().numpy().flatten()[0]
    return float(pred)

def load_custom_model(model_key):
    model_info = MODEL_OPTIONS[model_key]
    # The Lightning checkpoint restores the full XrayReg module; only the inner
    # backbone is needed for inference and GradCAM.
    model = XrayReg.load_from_checkpoint(model_info["ckpt"], map_location="cpu")
    model = model.model
    model.eval()
    # GradCAM needs gradients, so re-enable them even in eval mode.
    for param in model.parameters():
        param.requires_grad = True
    return model

def predict_and_cam_custom(inp, model):
    input_tensor, rgb_img = preprocess_image_custom(inp)
    model = model.cuda() if torch.cuda.is_available() else model
    model.eval()
    for param in model.parameters():
        param.requires_grad = True
    value = predict_custom(model, input_tensor)
    # GradCAM for regression: target the last conv layer and the single output.
    from pytorch_grad_cam import GradCAM
    from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

    # Use the last Conv2d module as the target layer (for timm ViTs this is the
    # patch-embedding projection).
    target_layers = [
        layer
        for name, layer in model.named_modules()
        if isinstance(layer, torch.nn.Conv2d)
    ][-1:]
    gradcam = GradCAM(model=model, target_layers=target_layers)
    targets = [ClassifierOutputTarget(0)]  # For regression, target the single output.
    # Keep the input on the same device as the model for the CAM forward/backward pass.
    device = next(model.parameters()).device
    grayscale_cam = gradcam(input_tensor=input_tensor.to(device), targets=targets)[0]
    cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
    cam_pil = Image.fromarray(cam_image)
    # Return as tuple (number, image) for Gradio
    return value, cam_pil

def create_interface_custom():
    # Cache loaded models so a checkpoint is not reloaded on every prediction.
    from functools import lru_cache

    @lru_cache(maxsize=None)
    def cached_load_model(model_key):
        return load_custom_model(model_key)

    def predict_wrapper(inp, model_key):
        model = cached_load_model(model_key)
        return predict_and_cam_custom(inp, model)

    interface = gr.Interface(
        fn=predict_wrapper,
        inputs=[
            gr.Image(type="pil"),
            gr.Dropdown(list(MODEL_OPTIONS.keys()), label="Model"),
        ],
        outputs=[
            gr.Number(label="Regression Output"),
            gr.Image(type="pil", label="GradCAM Visualization"),
        ],
        examples=None,
        title="Xray Regression Gradio App",
        description="Upload an X-ray image and select a model to get regression output and GradCAM visualization.",
        allow_flagging="never",
        live=True,  # Re-run prediction when the image or dropdown changes.
    )
    return interface

def download_models():
    import huggingface_hub

    repo_name = "SuperSecureHuman/xray-reg-models"
    local_dir = "./"
    huggingface_hub.snapshot_download(
        repo_id=repo_name,
        local_dir=local_dir,
    )

def main():
    logging.basicConfig(level=logging.INFO)
    # Download models if not already present
    try:
        download_models()
    except Exception as e:
        logging.error("Error downloading models: %s", e)
        raise SystemExit(1)
    interface = create_interface_custom()
    interface.launch()


if __name__ == "__main__":
    main()