from typing import List, Tuple, Dict

import gradio as gr
import torch
import torch.nn.functional as F
import timm
from timm.data import create_transform
from timm.models import create_model
from timm.utils import AttentionExtract
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

def get_attention_models() -> List[str]:
    """Get a list of timm models that have attention blocks."""
    all_models = timm.list_pretrained()
    # FIXME focusing on ViT-style models for the initial implementation
    attention_models = [
        model for model in all_models
        if model.lower().startswith(('vit', 'deit', 'beit', 'eva'))
    ]
    return attention_models

def load_model(model_name: str) -> Tuple[torch.nn.Module, AttentionExtract]:
    """Load a model from timm and prepare it for attention extraction."""
    # Disable fused attention so the explicit softmax node exists in the graph
    timm.layers.set_fused_attn(False)
    model = create_model(model_name, pretrained=True)
    model.eval()
    # 'fx' traces the model graph; 'hooks' is an alternative method, and
    # matching names for attention nodes or modules could also be exposed.
    extractor = AttentionExtract(model, method='fx')
    return model, extractor

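# NOTE: calling `extractor(tensor)` returns a dict mapping attention node names
# to post-softmax attention tensors of shape (batch, num_heads, num_tokens,
# num_tokens). The exact keys depend on the traced graph (e.g. something like
# 'blocks.0.attn.softmax' for a ViT); treat that example name as an assumption.
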
def process_image(
        image: Image.Image,
        model: torch.nn.Module,
        extractor: AttentionExtract,
) -> Dict[str, torch.Tensor]:
    """Process the input image and get the attention maps."""
    # Build the eval transform from the model's pretrained config
    config = model.pretrained_cfg
    transform = create_transform(
        input_size=config['input_size'],
        crop_pct=config['crop_pct'],
        mean=config['mean'],
        std=config['std'],
        interpolation=config['interpolation'],
        is_training=False,
    )

    # Preprocess the image and add a batch dimension
    tensor = transform(image).unsqueeze(0)

    # Extract attention maps; no gradients are needed for visualization
    with torch.no_grad():
        attention_maps = extractor(tensor)
    return attention_maps

def apply_mask(image: np.ndarray, mask: np.ndarray, color: Tuple[float, float, float], alpha: float = 0.5) -> np.ndarray:
    """Alpha-blend a [0, 1] mask onto an image, tinting masked regions with `color`."""
    # Broadcast the 2D mask across the three color channels
    mask = np.repeat(mask[:, :, np.newaxis], 3, axis=2)
    color = np.array(color)
    # Keep the image where the mask is weak, tint it where the mask is strong
    masked_image = image * (1 - alpha * mask) + alpha * mask * color[np.newaxis, np.newaxis, :] * 255
    return masked_image.astype(np.uint8)

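# Attention rollout (Abnar & Zuidema, 2020) estimates how information flows from
# the input patches to the class token: the per-layer attention matrices are
# multiplied together, each first averaged with the identity to account for
# residual connections, i.e. rollout = prod_l ((A_l + I) / 2), with rows
# renormalized at every step. The class-token row of the product scores each
# patch's overall contribution, which the function below returns as a 2D mask.
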
def rollout(attentions, discard_ratio, head_fusion, num_prefix_tokens=1):
    # Based on https://github.com/jacobgil/vit-explain/blob/main/vit_rollout.py
    result = torch.eye(attentions[0].size(-1))
    with torch.no_grad():
        for attention in attentions:
            if head_fusion.startswith('mean'):
                # mean_std fusion doesn't appear to make sense with rollout
                attention_heads_fused = attention.mean(dim=0)
            elif head_fusion == 'max':
                attention_heads_fused = attention.amax(dim=0)
            elif head_fusion == 'min':
                attention_heads_fused = attention.amin(dim=0)
            else:
                raise ValueError(f"Attention head fusion type not supported: {head_fusion}")

            # Zero out the lowest attentions, but don't discard the prefix tokens
            flat = attention_heads_fused.view(-1)
            _, indices = flat.topk(int(flat.size(-1) * discard_ratio), -1, False)
            indices = indices[indices >= num_prefix_tokens]
            flat[indices] = 0

            I = torch.eye(attention_heads_fused.size(-1))
            a = (attention_heads_fused + 1.0 * I) / 2
            a = a / a.sum(dim=-1, keepdim=True)  # row-normalize so each token's attention sums to 1

            result = torch.matmul(a, result)

    # Look at the total attention between the prefix tokens (usually class tokens)
    # and the image patches
    # FIXME this is token 0 vs non-prefix right now, need to cover other cases (> 1 prefix, no prefix, etc)
    mask = result[0, num_prefix_tokens:]
    width = int(mask.size(-1) ** 0.5)
    mask = mask.reshape(width, width).numpy()
    mask = mask / np.max(mask)
    return mask

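# Illustrative shapes (assuming a ViT-Base/16 at 224x224 input): each entry of
# `attentions` is (12, 197, 197) once the batch dim is removed (12 heads,
# 196 patch tokens + 1 class token), and the returned mask is a 14x14 array
# scaled to [0, 1].
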
def visualize_attention(
        image: Image.Image,
        model_name: str,
        head_fusion: str,
        discard_ratio: float,
) -> Tuple[List[Image.Image], Image.Image]:
    """Visualize attention maps and rollout for the given image and model."""
    model, extractor = load_model(model_name)
    attention_maps = process_image(image, model, extractor)

    # FIXME handle wider range of models that may not have num_prefix_tokens attr
    num_prefix_tokens = getattr(model, 'num_prefix_tokens', 1)  # Default to 1 class token if not specified

    # Convert PIL Image to numpy array
    image_np = np.array(image)

    # Create visualizations
    visualizations = []
    attentions_for_rollout = []
    for layer_name, attn_map in attention_maps.items():
        print(f"Attention map shape for {layer_name}: {attn_map.shape}")

        attn_map = attn_map[0]  # Remove batch dimension
        attentions_for_rollout.append(attn_map)
        attn_map = attn_map[:, :, num_prefix_tokens:]  # Remove prefix tokens for visualization

        if head_fusion == 'mean_std':
            attn_map = attn_map.mean(0) / attn_map.std(0)
        elif head_fusion == 'mean':
            attn_map = attn_map.mean(0)
        elif head_fusion == 'max':
            attn_map = attn_map.amax(0)
        elif head_fusion == 'min':
            attn_map = attn_map.amin(0)
        else:
            raise ValueError(f"Invalid head fusion method: {head_fusion}")

        # Use the first token's attention (usually the class token)
        # FIXME handle different prefix token scenarios
        attn_map = attn_map[0]

        # Reshape the attention map to 2D
        num_patches = int(attn_map.shape[0] ** 0.5)
        attn_map = attn_map.reshape(num_patches, num_patches)

        # Interpolate to match image size (attn_map is already a tensor; just add batch & channel dims)
        attn_map = attn_map.unsqueeze(0).unsqueeze(0)
        attn_map = F.interpolate(attn_map, size=(image_np.shape[0], image_np.shape[1]), mode='bilinear', align_corners=False)
        attn_map = attn_map.squeeze().cpu().numpy()

        # Normalize attention map to [0, 1]
        attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min())

        # Create visualization
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))

        # Original image
        ax1.imshow(image_np)
        ax1.set_title("Original Image")
        ax1.axis('off')

        # Attention map overlay
        masked_image = apply_mask(image_np, attn_map, color=(1, 0, 0))  # Red mask
        ax2.imshow(masked_image)
        ax2.set_title(f'Attention Map for {layer_name}')
        ax2.axis('off')

        plt.tight_layout()

        # Convert plot to image
        fig.canvas.draw()
        vis_image = Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
        visualizations.append(vis_image)
        plt.close(fig)
    # Calculate rollout
    rollout_mask = rollout(attentions_for_rollout, discard_ratio, head_fusion, num_prefix_tokens)

    # Create rollout visualization
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))

    # Original image
    ax1.imshow(image_np)
    ax1.set_title("Original Image")
    ax1.axis('off')

    # Rollout overlay
    rollout_mask_pil = Image.fromarray((rollout_mask * 255).astype(np.uint8))
    rollout_mask_resized = np.array(rollout_mask_pil.resize((image_np.shape[1], image_np.shape[0]), Image.BICUBIC)) / 255.0
    masked_image = apply_mask(image_np, rollout_mask_resized, color=(1, 0, 0))  # Red mask
    ax2.imshow(masked_image)
    ax2.set_title('Attention Rollout')
    ax2.axis('off')

    plt.tight_layout()

    # Convert plot to image
    fig.canvas.draw()
    rollout_image = Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
    plt.close(fig)

    return visualizations, rollout_image

# Create Gradio interface
iface = gr.Interface(
    fn=visualize_attention,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Dropdown(choices=get_attention_models(), label="Select Model"),
        gr.Dropdown(
            choices=['mean_std', 'mean', 'max', 'min'],
            label="Head Fusion Method",
            value='mean',  # Default value
        ),
        gr.Slider(0, 1, 0.9, label="Discard Ratio", info="Ratio of lowest attentions to discard"),
    ],
    outputs=[
        gr.Gallery(label="Attention Maps"),
        gr.Image(label="Attention Rollout"),
    ],
    title="Attention Map Visualizer for timm Models",
    description="Upload an image and select a timm model to visualize its attention maps.",
)

iface.launch()