iljung1106 committed · commit 93d1be8 · 1 parent: 5570c3c
Grad CAM to XGrad CAM

Files changed:
- app/visualization.py (+17 -10)
- webui_gradio.py (+8 -8)
app/visualization.py
CHANGED
@@ -71,14 +71,17 @@ def _get_branch_weights(encoder, x: torch.Tensor) -> Dict[str, float]:
     }
 
 
-def _compute_gradcam(
+def _compute_xgradcam(
     encoder,
     x: torch.Tensor,
     target_layer_name: str = "b3",
 ) -> np.ndarray:
     """
-    Compute Grad-CAM heatmap for a ViewEncoder.
-
+    Compute XGrad-CAM heatmap for a ViewEncoder.
+    XGrad-CAM is an improved variant that uses element-wise gradient-activation
+    products normalized by activation sums, providing better localization.
+
+    Reference: Axiom-based Grad-CAM (Fu et al., BMVC 2020)
     Returns a heatmap as numpy array [H, W] normalized to [0, 1].
     """
     # Storage for activations and gradients
@@ -86,10 +89,10 @@ def _compute_gradcam(
     gradients = {}
 
     def forward_hook(module, input, output):
-        activations["value"] = output.detach()
+        activations["value"] = output.detach().clone()
 
     def backward_hook(module, grad_input, grad_output):
-        gradients["value"] = grad_output[0].detach()
+        gradients["value"] = grad_output[0].detach().clone()
 
     # Get the target layer
     target_layer = getattr(encoder, target_layer_name, None)
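Note on the hook change: detach() alone returns a view that shares storage with the live tensor, so a later in-place op could silently alter the captured value; detach().clone() takes a real copy. A minimal sketch of the same capture pattern on a toy layer (not the repo's ViewEncoder):

import torch
import torch.nn as nn

# Toy stand-in for encoder.b3; `store` mirrors the activations/gradients dicts.
layer = nn.Conv2d(3, 8, kernel_size=3, padding=1)
store = {}

fh = layer.register_forward_hook(
    lambda mod, inp, out: store.update(acts=out.detach().clone())
)
bh = layer.register_full_backward_hook(
    lambda mod, gin, gout: store.update(grads=gout[0].detach().clone())
)

layer(torch.randn(1, 3, 32, 32)).norm().backward()  # any scalar objective works
fh.remove(); bh.remove()                            # drop hooks once captured
print(store["acts"].shape, store["grads"].shape)    # both torch.Size([1, 8, 32, 32])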
@@ -124,8 +127,12 @@ def _compute_gradcam(
     if acts is None or grads is None:
         return np.zeros((x.shape[2], x.shape[3]), dtype=np.float32)
 
-    # Grad-CAM weights: global average pooling of the gradients
-    weights = grads.mean(dim=(2, 3), keepdim=True)  # [B, C, 1, 1]
+    # XGrad-CAM: weights = sum(grads * acts, spatial) / (sum(acts, spatial) + eps)
+    # This normalizes by the activation magnitude, improving localization
+    grad_act_product = grads * acts  # [B, C, H, W]
+    sum_grad_act = grad_act_product.sum(dim=(2, 3), keepdim=True)  # [B, C, 1, 1]
+    sum_acts = acts.sum(dim=(2, 3), keepdim=True) + 1e-7  # [B, C, 1, 1]
+    weights = sum_grad_act / sum_acts  # [B, C, 1, 1]
 
     # Weighted combination of activations
     cam = (weights * acts).sum(dim=1, keepdim=True)  # [B, 1, H, W]
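Taken in isolation, this hunk swaps Grad-CAM's channel weights (a global average pool of the gradients) for XGrad-CAM's activation-normalized weights; everything downstream of weights is untouched. A standalone sketch on dummy tensors, assuming the usual ReLU and min-max steps behind the docstring's "normalized to [0, 1]" (those lines sit outside this hunk):

import torch

B, C, H, W = 1, 4, 7, 7
acts = torch.rand(B, C, H, W)    # captured activations A
grads = torch.randn(B, C, H, W)  # captured gradients dL/dA

# Grad-CAM (old): w_c = mean over (h, w) of dL/dA_c
w_old = grads.mean(dim=(2, 3), keepdim=True)                 # [B, C, 1, 1]

# XGrad-CAM (new): w_c = sum(dL/dA_c * A_c) / (sum(A_c) + eps)
w_new = (grads * acts).sum(dim=(2, 3), keepdim=True) / (
    acts.sum(dim=(2, 3), keepdim=True) + 1e-7
)                                                            # [B, C, 1, 1]

# Identical downstream combination for both variants
cam = torch.relu((w_new * acts).sum(dim=1, keepdim=True))    # [B, 1, H, W]
cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)     # -> [0, 1]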
@@ -242,7 +249,7 @@ def analyze_views(
 
         # Grad-CAM
         try:
-            heatmap = _compute_gradcam(enc, x.clone(), target_layer_name="b3")
+            heatmap = _compute_xgradcam(enc, x.clone(), target_layer_name="b3")
             if original_images.get(k) is not None:
                 gradcam_heatmaps[k] = _overlay_heatmap(original_images[k], heatmap, alpha=0.5)
             else:
@@ -263,11 +270,11 @@ def analyze_views(
 
 def format_view_weights_html(analysis: ViewAnalysis) -> str:
     """Format view weights as clean HTML with styled progress bars."""
-    # View labels with descriptions
+    # View labels with descriptions
     view_info = {
         "whole": ("Whole Image", "#4CAF50"),  # green
         "face": ("Face", "#2196F3"),  # blue
-        "eyes": ("Eye", "#FF9800"),  # orange
+        "eyes": ("Eyes", "#FF9800"),  # orange
     }
 
     html_parts = ['<div style="font-family: sans-serif; padding: 10px;">']
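_overlay_heatmap itself is not part of this diff. For orientation, a hypothetical blend in the same spirit, colorizing the [0, 1] heatmap and alpha-compositing it over the source image at alpha=0.5; the repo's actual implementation may differ:

import numpy as np
from PIL import Image
from matplotlib import cm

def overlay_heatmap(img: Image.Image, heatmap: np.ndarray, alpha: float = 0.5) -> Image.Image:
    # heatmap: [H, W] float array in [0, 1]; cm.jet returns RGBA floats in [0, 1]
    colored = (cm.jet(heatmap)[:, :, :3] * 255).astype(np.uint8)
    hm = Image.fromarray(colored).resize(img.size, Image.BILINEAR)
    return Image.blend(img.convert("RGB"), hm, alpha)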
webui_gradio.py
CHANGED
@@ -298,7 +298,7 @@ def classify_and_analyze(
         return ("❌ Provide a whole image.",) + empty_result[1:]
 
     try:
-        # Extract face and eye
+        # Extract face and eyes
         face_pil = None
         eye_pil = None
         if ex is not None:
@@ -319,7 +319,7 @@ def classify_and_analyze(
         preds = topk_predictions_unique_labels(db, z, topk=int(topk))
         rows = [[name, float(score)] for (name, score) in preds]
 
-        # Analysis (Grad-CAM + view weights)
+        # Analysis (XGrad-CAM + view weights)
         views = {"whole": wt, "face": ft, "eyes": et}
         original_images = {"whole": w, "face": face_pil, "eyes": eye_pil}
         analysis = analyze_views(lm.model, views, original_images, lm.device)
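topk_predictions_unique_labels is defined elsewhere in the repo; only its call site is visible here. A hedged sketch of what a unique-label top-k typically does (the flat names/protos layout below is an assumption, not the repo's db format): scan cosine similarities in descending order, keeping each artist's best-scoring prototype until topk distinct names are collected.

import numpy as np

def topk_unique(names: list, protos: np.ndarray, z: np.ndarray, topk: int = 5):
    # names[i] labels prototype protos[i]; z is the query embedding [D]
    sims = protos @ z / (np.linalg.norm(protos, axis=1) * np.linalg.norm(z) + 1e-8)
    out, seen = [], set()
    for i in np.argsort(-sims):              # descending cosine similarity
        if names[i] not in seen:
            seen.add(names[i])
            out.append((names[i], float(sims[i])))
            if len(out) == topk:
                break
    return out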
@@ -546,19 +546,19 @@ def build_ui() -> gr.Blocks:
             gr.Markdown("### 🎯 Classification Results")
             table = gr.Dataframe(headers=["Artist", "Similarity"], datatype=["str", "number"], interactive=False)
 
-            # Grad-CAM heatmaps
-            gr.Markdown("### 🔥 Grad-CAM Attention Maps")
+            # XGrad-CAM heatmaps
+            gr.Markdown("### 🔥 XGrad-CAM Attention Maps")
             gr.Markdown("*Where the model focused in each view:*")
             with gr.Row():
                 gcam_whole = gr.Image(label="Whole Image", type="pil")
                 gcam_face = gr.Image(label="Face", type="pil")
-                gcam_eye = gr.Image(label="Eye", type="pil")
+                gcam_eye = gr.Image(label="Eyes", type="pil")
 
             # Extracted views
             gr.Markdown("### 🖼️ Auto-Extracted Views")
             with gr.Row():
                 face_prev = gr.Image(label="Detected Face", type="pil")
-                eye_prev = gr.Image(label="Detected Eye", type="pil")
+                eye_prev = gr.Image(label="Detected Eyes", type="pil")
 
         run_btn.click(
             classify_and_analyze,
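The run_btn.click(...) call continues beyond this hunk with inputs/outputs lists; in Gradio, the callback's return tuple maps positionally onto the outputs components. A self-contained toy of the same wiring pattern (component and callback names here are illustrative, not the repo's):

import gradio as gr

def fake_analyze(img, topk):
    # stand-in for classify_and_analyze: one return value per output component
    return f"ok (topk={int(topk)})", img

with gr.Blocks() as demo:
    whole_in = gr.Image(type="pil")
    topk_in = gr.Slider(1, 10, value=5, step=1)
    run_btn = gr.Button("Analyze")
    status = gr.Markdown()
    gcam_whole = gr.Image(type="pil")
    run_btn.click(fake_analyze, inputs=[whole_in, topk_in], outputs=[status, gcam_whole])

demo.launch()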
@@ -571,8 +571,8 @@ def build_ui() -> gr.Blocks:
             "### ⚠️ Temporary Prototypes Only\n"
             "Add prototypes using random triplet combinations and K-means clustering (same as eval process).\n"
             "1. Upload multiple whole images\n"
-            "2. Face and eye are auto-extracted from each\n"
-            "3. Random triplets (whole + face + eye) are created\n"
+            "2. Face and eyes are auto-extracted from each\n"
+            "3. Random triplets (whole + face + eyes) are created\n"
             "4. K-means clustering creates K prototype centers\n\n"
             "**These prototypes are session-only** — lost when the Space restarts."
         )
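Steps 3 and 4 of this note compress the prototype-building pipeline into two lines. A hedged sketch of the K-means step (build_prototypes is a hypothetical helper; the repo's eval pipeline defines the real one):

import numpy as np
from sklearn.cluster import KMeans

def build_prototypes(embeddings: np.ndarray, k: int = 4) -> np.ndarray:
    # embeddings: [N, D] triplet embeddings for one artist
    km = KMeans(n_clusters=min(k, len(embeddings)), n_init=10).fit(embeddings)
    centers = km.cluster_centers_                                    # [k, D]
    return centers / np.linalg.norm(centers, axis=1, keepdims=True)  # unit-norm prototypes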