iljung1106 commited on
Commit
93d1be8
Β·
1 Parent(s): 5570c3c

Grad CAM to XGrad CAM

Browse files
Files changed (2) hide show
  1. app/visualization.py +17 -10
  2. webui_gradio.py +8 -8
app/visualization.py CHANGED
@@ -71,14 +71,17 @@ def _get_branch_weights(encoder, x: torch.Tensor) -> Dict[str, float]:
71
  }
72
 
73
 
74
- def _compute_gradcam(
75
  encoder,
76
  x: torch.Tensor,
77
  target_layer_name: str = "b3",
78
  ) -> np.ndarray:
79
  """
80
- Compute Grad-CAM heatmap for a ViewEncoder.
81
- Uses gradients of the output w.r.t. an intermediate feature map.
 
 
 
82
  Returns a heatmap as numpy array [H, W] normalized to [0, 1].
83
  """
84
  # Storage for activations and gradients
@@ -86,10 +89,10 @@ def _compute_gradcam(
86
  gradients = {}
87
 
88
  def forward_hook(module, input, output):
89
- activations["value"] = output.detach()
90
 
91
  def backward_hook(module, grad_input, grad_output):
92
- gradients["value"] = grad_output[0].detach()
93
 
94
  # Get the target layer
95
  target_layer = getattr(encoder, target_layer_name, None)
@@ -124,8 +127,12 @@ def _compute_gradcam(
124
  if acts is None or grads is None:
125
  return np.zeros((x.shape[2], x.shape[3]), dtype=np.float32)
126
 
127
- # Compute Grad-CAM weights (global average pooling of gradients)
128
- weights = grads.mean(dim=(2, 3), keepdim=True) # [B, C, 1, 1]
 
 
 
 
129
 
130
  # Weighted combination of activations
131
  cam = (weights * acts).sum(dim=1, keepdim=True) # [B, 1, H, W]
@@ -242,7 +249,7 @@ def analyze_views(
242
 
243
  # Grad-CAM
244
  try:
245
- heatmap = _compute_gradcam(enc, x.clone(), target_layer_name="b3")
246
  if original_images.get(k) is not None:
247
  gradcam_heatmaps[k] = _overlay_heatmap(original_images[k], heatmap, alpha=0.5)
248
  else:
@@ -263,11 +270,11 @@ def analyze_views(
263
 
264
  def format_view_weights_html(analysis: ViewAnalysis) -> str:
265
  """Format view weights as clean HTML with styled progress bars."""
266
- # View labels with descriptions (eye = singular)
267
  view_info = {
268
  "whole": ("Whole Image", "#4CAF50"), # green
269
  "face": ("Face", "#2196F3"), # blue
270
- "eyes": ("Eye Region", "#FF9800"), # orange
271
  }
272
 
273
  html_parts = ['<div style="font-family: sans-serif; padding: 10px;">']
 
71
  }
72
 
73
 
74
+ def _compute_xgradcam(
75
  encoder,
76
  x: torch.Tensor,
77
  target_layer_name: str = "b3",
78
  ) -> np.ndarray:
79
  """
80
+ Compute XGrad-CAM heatmap for a ViewEncoder.
81
+ XGrad-CAM is an improved variant that uses element-wise gradient-activation
82
+ products normalized by activation sums, providing better localization.
83
+
84
+ Reference: Axiom-based Grad-CAM (Fu et al., BMVC 2020)
85
  Returns a heatmap as numpy array [H, W] normalized to [0, 1].
86
  """
87
  # Storage for activations and gradients
 
89
  gradients = {}
90
 
91
  def forward_hook(module, input, output):
92
+ activations["value"] = output.detach().clone()
93
 
94
  def backward_hook(module, grad_input, grad_output):
95
+ gradients["value"] = grad_output[0].detach().clone()
96
 
97
  # Get the target layer
98
  target_layer = getattr(encoder, target_layer_name, None)
 
127
  if acts is None or grads is None:
128
  return np.zeros((x.shape[2], x.shape[3]), dtype=np.float32)
129
 
130
+ # XGrad-CAM: weights = sum(grads * acts, spatial) / (sum(acts, spatial) + eps)
131
+ # This normalizes by the activation magnitude, improving localization
132
+ grad_act_product = grads * acts # [B, C, H, W]
133
+ sum_grad_act = grad_act_product.sum(dim=(2, 3), keepdim=True) # [B, C, 1, 1]
134
+ sum_acts = acts.sum(dim=(2, 3), keepdim=True) + 1e-7 # [B, C, 1, 1]
135
+ weights = sum_grad_act / sum_acts # [B, C, 1, 1]
136
 
137
  # Weighted combination of activations
138
  cam = (weights * acts).sum(dim=1, keepdim=True) # [B, 1, H, W]
 
249
 
250
  # Grad-CAM
251
  try:
252
+ heatmap = _compute_xgradcam(enc, x.clone(), target_layer_name="b3")
253
  if original_images.get(k) is not None:
254
  gradcam_heatmaps[k] = _overlay_heatmap(original_images[k], heatmap, alpha=0.5)
255
  else:
 
270
 
271
  def format_view_weights_html(analysis: ViewAnalysis) -> str:
272
  """Format view weights as clean HTML with styled progress bars."""
273
+ # View labels with descriptions
274
  view_info = {
275
  "whole": ("Whole Image", "#4CAF50"), # green
276
  "face": ("Face", "#2196F3"), # blue
277
+ "eyes": ("Eyes", "#FF9800"), # orange
278
  }
279
 
280
  html_parts = ['<div style="font-family: sans-serif; padding: 10px;">']
webui_gradio.py CHANGED
@@ -298,7 +298,7 @@ def classify_and_analyze(
298
  return ("❌ Provide a whole image.",) + empty_result[1:]
299
 
300
  try:
301
- # Extract face and eye region
302
  face_pil = None
303
  eye_pil = None
304
  if ex is not None:
@@ -319,7 +319,7 @@ def classify_and_analyze(
319
  preds = topk_predictions_unique_labels(db, z, topk=int(topk))
320
  rows = [[name, float(score)] for (name, score) in preds]
321
 
322
- # Analysis (Grad-CAM + view weights)
323
  views = {"whole": wt, "face": ft, "eyes": et}
324
  original_images = {"whole": w, "face": face_pil, "eyes": eye_pil}
325
  analysis = analyze_views(lm.model, views, original_images, lm.device)
@@ -546,19 +546,19 @@ def build_ui() -> gr.Blocks:
546
  gr.Markdown("### 🎯 Classification Results")
547
  table = gr.Dataframe(headers=["Artist", "Similarity"], datatype=["str", "number"], interactive=False)
548
 
549
- # Grad-CAM heatmaps
550
- gr.Markdown("### πŸ”₯ Grad-CAM Attention Maps")
551
  gr.Markdown("*Where the model focused in each view:*")
552
  with gr.Row():
553
  gcam_whole = gr.Image(label="Whole Image", type="pil")
554
  gcam_face = gr.Image(label="Face", type="pil")
555
- gcam_eye = gr.Image(label="Eye Region", type="pil")
556
 
557
  # Extracted views
558
  gr.Markdown("### πŸ‘οΈ Auto-Extracted Views")
559
  with gr.Row():
560
  face_prev = gr.Image(label="Detected Face", type="pil")
561
- eye_prev = gr.Image(label="Detected Eye", type="pil")
562
 
563
  run_btn.click(
564
  classify_and_analyze,
@@ -571,8 +571,8 @@ def build_ui() -> gr.Blocks:
571
  "### ⚠️ Temporary Prototypes Only\n"
572
  "Add prototypes using random triplet combinations and K-means clustering (same as eval process).\n"
573
  "1. Upload multiple whole images\n"
574
- "2. Face and eye region are auto-extracted from each\n"
575
- "3. Random triplets (whole + face + eye) are created\n"
576
  "4. K-means clustering creates K prototype centers\n\n"
577
  "**These prototypes are session-only** β€” lost when the Space restarts."
578
  )
 
298
  return ("❌ Provide a whole image.",) + empty_result[1:]
299
 
300
  try:
301
+ # Extract face and eyes
302
  face_pil = None
303
  eye_pil = None
304
  if ex is not None:
 
319
  preds = topk_predictions_unique_labels(db, z, topk=int(topk))
320
  rows = [[name, float(score)] for (name, score) in preds]
321
 
322
+ # Analysis (XGrad-CAM + view weights)
323
  views = {"whole": wt, "face": ft, "eyes": et}
324
  original_images = {"whole": w, "face": face_pil, "eyes": eye_pil}
325
  analysis = analyze_views(lm.model, views, original_images, lm.device)
 
546
  gr.Markdown("### 🎯 Classification Results")
547
  table = gr.Dataframe(headers=["Artist", "Similarity"], datatype=["str", "number"], interactive=False)
548
 
549
+ # XGrad-CAM heatmaps
550
+ gr.Markdown("### πŸ”₯ XGrad-CAM Attention Maps")
551
  gr.Markdown("*Where the model focused in each view:*")
552
  with gr.Row():
553
  gcam_whole = gr.Image(label="Whole Image", type="pil")
554
  gcam_face = gr.Image(label="Face", type="pil")
555
+ gcam_eye = gr.Image(label="Eyes", type="pil")
556
 
557
  # Extracted views
558
  gr.Markdown("### πŸ‘οΈ Auto-Extracted Views")
559
  with gr.Row():
560
  face_prev = gr.Image(label="Detected Face", type="pil")
561
+ eye_prev = gr.Image(label="Detected Eyes", type="pil")
562
 
563
  run_btn.click(
564
  classify_and_analyze,
 
571
  "### ⚠️ Temporary Prototypes Only\n"
572
  "Add prototypes using random triplet combinations and K-means clustering (same as eval process).\n"
573
  "1. Upload multiple whole images\n"
574
+ "2. Face and eyes are auto-extracted from each\n"
575
+ "3. Random triplets (whole + face + eyes) are created\n"
576
  "4. K-means clustering creates K prototype centers\n\n"
577
  "**These prototypes are session-only** β€” lost when the Space restarts."
578
  )