""" Visualization utilities for artist embedding model: - Grad-CAM heatmaps - View attention weights (whole/face/eyes) - Branch attention weights (Gram/Cov/Spectrum/Stats) """ from __future__ import annotations from dataclasses import dataclass from typing import Dict, List, Optional, Tuple import numpy as np import torch import torch.nn.functional as F from PIL import Image @dataclass class ViewAnalysis: """Analysis results for a single inference.""" # View attention weights [3] for whole/face/eyes view_weights: Dict[str, float] # Branch attention weights per view {view_name: {branch_name: weight}} branch_weights: Dict[str, Dict[str, float]] # Grad-CAM heatmaps per view (PIL Images) gradcam_heatmaps: Dict[str, Optional[Image.Image]] # Original images for overlay original_images: Dict[str, Optional[Image.Image]] def _get_branch_weights(encoder, x: torch.Tensor) -> Dict[str, float]: """ Extract branch attention weights from a ViewEncoder forward pass. Returns dict with keys: gram, cov, spectrum, stats """ # We need to do a partial forward to get the branch gate weights with torch.no_grad(): x_lab = encoder._rgb_to_lab(x) f0 = encoder.stem(x_lab) f1 = encoder.b1(f0) f2 = encoder.b2(f1) f3 = encoder.b3(f2) f4 = encoder.b4(f3) g3 = encoder.h_gram3(f3) c3 = encoder.h_cov3(f3) sp3 = encoder.h_sp3(f3) st3 = encoder.h_st3(f3) g4 = encoder.h_gram4(f4) c4 = encoder.h_cov4(f4) sp4 = encoder.h_sp4(f4) st4 = encoder.h_st4(f4) b_gram = torch.cat([g3, g4], dim=1) b_cov = torch.cat([c3, c4], dim=1) b_sp = torch.cat([sp3, sp4], dim=1) b_st = torch.cat([st3, st4], dim=1) flat = torch.cat([b_gram, b_cov, b_sp, b_st], dim=1) gate_logits = encoder.branch_gate(flat) w = torch.softmax(gate_logits, dim=-1) # w is [1, 4] for single image w_np = w[0].cpu().numpy() return { "Gram": float(w_np[0]), "Cov": float(w_np[1]), "Spectrum": float(w_np[2]), "Stats": float(w_np[3]), } def _compute_xgradcam( encoder, x: torch.Tensor, target_layer_name: str = "b3", ) -> np.ndarray: """ Compute XGrad-CAM heatmap for a ViewEncoder. XGrad-CAM is an improved variant that uses element-wise gradient-activation products normalized by activation sums, providing better localization. Reference: Axiom-based Grad-CAM (Fu et al., BMVC 2020) Returns a heatmap as numpy array [H, W] normalized to [0, 1]. """ # Storage for activations and gradients activations = {} gradients = {} def forward_hook(module, input, output): activations["value"] = output.detach().clone() def backward_hook(module, grad_input, grad_output): gradients["value"] = grad_output[0].detach().clone() # Get the target layer target_layer = getattr(encoder, target_layer_name, None) if target_layer is None: # Fallback to b2 or b1 for fallback in ["b2", "b1", "stem"]: target_layer = getattr(encoder, fallback, None) if target_layer is not None: break if target_layer is None: return np.zeros((x.shape[2], x.shape[3]), dtype=np.float32) # Register hooks fwd_handle = target_layer.register_forward_hook(forward_hook) bwd_handle = target_layer.register_full_backward_hook(backward_hook) try: # Forward pass x.requires_grad_(True) output = encoder(x) # Backward pass - use the L2 norm of output as target target = output.norm(dim=1).sum() encoder.zero_grad() target.backward(retain_graph=True) # Get activations and gradients acts = activations.get("value") grads = gradients.get("value") if acts is None or grads is None: return np.zeros((x.shape[2], x.shape[3]), dtype=np.float32) # XGrad-CAM: weights = sum(grads * acts, spatial) / (sum(acts, spatial) + eps) # This normalizes by the activation magnitude, improving localization grad_act_product = grads * acts # [B, C, H, W] sum_grad_act = grad_act_product.sum(dim=(2, 3), keepdim=True) # [B, C, 1, 1] sum_acts = acts.sum(dim=(2, 3), keepdim=True) + 1e-7 # [B, C, 1, 1] weights = sum_grad_act / sum_acts # [B, C, 1, 1] # Weighted combination of activations cam = (weights * acts).sum(dim=1, keepdim=True) # [B, 1, H, W] cam = F.relu(cam) # Only positive contributions # Normalize cam = cam[0, 0].cpu().numpy() if cam.max() > 0: cam = cam / cam.max() # Resize to input size cam_pil = Image.fromarray((cam * 255).astype(np.uint8)) cam_pil = cam_pil.resize((x.shape[3], x.shape[2]), Image.BILINEAR) cam = np.array(cam_pil).astype(np.float32) / 255.0 return cam finally: fwd_handle.remove() bwd_handle.remove() x.requires_grad_(False) def _overlay_heatmap( image: Image.Image, heatmap: np.ndarray, alpha: float = 0.5, colormap: str = "jet", ) -> Image.Image: """Overlay a heatmap on an image.""" import matplotlib.pyplot as plt # Ensure heatmap is 2D and normalized heatmap = np.clip(heatmap, 0, 1) # Get colormap cmap = plt.get_cmap(colormap) heatmap_colored = cmap(heatmap)[:, :, :3] # RGB only, no alpha heatmap_colored = (heatmap_colored * 255).astype(np.uint8) # Resize heatmap to match image heatmap_pil = Image.fromarray(heatmap_colored) heatmap_pil = heatmap_pil.resize(image.size, Image.BILINEAR) # Blend image_rgb = image.convert("RGB") blended = Image.blend(image_rgb, heatmap_pil, alpha) return blended def analyze_views( model: torch.nn.Module, views: Dict[str, Optional[torch.Tensor]], original_images: Dict[str, Optional[Image.Image]], device: torch.device, ) -> ViewAnalysis: """ Perform full analysis on a set of views. Returns view weights, branch weights per view, and Grad-CAM heatmaps. """ model.eval() # Prepare masks masks = {} view_tensors = {} for k in ("whole", "face", "eyes"): if views.get(k) is not None: view_tensors[k] = views[k].unsqueeze(0).to(device) masks[k] = torch.ones(1, dtype=torch.bool, device=device) else: view_tensors[k] = None masks[k] = torch.zeros(1, dtype=torch.bool, device=device) # Get view attention weights from forward pass with torch.no_grad(): _, _, W = model(view_tensors, masks) # W is [1, num_present_views] W_np = W[0].cpu().numpy() # Map W back to view names (only present views have weights) view_order = ["whole", "face", "eyes"] present_views = [k for k in view_order if view_tensors[k] is not None] view_weights = {} for i, k in enumerate(present_views): view_weights[k] = float(W_np[i]) for k in view_order: if k not in view_weights: view_weights[k] = 0.0 # Get branch weights and Grad-CAM for each view branch_weights = {} gradcam_heatmaps = {} # Get encoder (shared or separate) enc_whole = model.enc_whole enc_face = model.enc_face enc_eyes = model.enc_eyes encoders = {"whole": enc_whole, "face": enc_face, "eyes": enc_eyes} for k in view_order: if view_tensors[k] is not None: enc = encoders[k] x = view_tensors[k] # Branch weights try: branch_weights[k] = _get_branch_weights(enc, x) except Exception: branch_weights[k] = {"Gram": 0.25, "Cov": 0.25, "Spectrum": 0.25, "Stats": 0.25} # Grad-CAM try: heatmap = _compute_xgradcam(enc, x.clone(), target_layer_name="b3") if original_images.get(k) is not None: gradcam_heatmaps[k] = _overlay_heatmap(original_images[k], heatmap, alpha=0.5) else: gradcam_heatmaps[k] = None except Exception: gradcam_heatmaps[k] = None else: branch_weights[k] = {} gradcam_heatmaps[k] = None return ViewAnalysis( view_weights=view_weights, branch_weights=branch_weights, gradcam_heatmaps=gradcam_heatmaps, original_images={k: original_images.get(k) for k in view_order}, ) def format_view_weights_html(analysis: ViewAnalysis) -> str: """Format view weights as clean HTML with styled progress bars.""" # View labels with descriptions view_info = { "whole": ("Whole Image", "#4CAF50"), # green "face": ("Face", "#2196F3"), # blue "eyes": ("Eye", "#FF9800"), # orange (single-eye crop) } html_parts = ['