iljung1106 committed
Commit 5570c3c · Parent: 39e77fe

combined classify and analyze

Files changed (2):
  1. app/visualization.py  +28 -34
  2. webui_gradio.py  +57 -102
app/visualization.py CHANGED
@@ -261,41 +261,35 @@ def analyze_views(
     )
 
 
-def format_analysis_text(analysis: ViewAnalysis) -> str:
-    """Format analysis results as markdown text."""
-    lines = ["## 📊 View & Branch Analysis\n"]
-
-    # View weights
-    lines.append("### View Attention Weights")
-    lines.append("How much each view contributed to the final embedding:\n")
+def format_view_weights_html(analysis: ViewAnalysis) -> str:
+    """Format view weights as clean HTML with styled progress bars."""
+    # View labels with descriptions (eye = singular)
+    view_info = {
+        "whole": ("Whole Image", "#4CAF50"),  # green
+        "face": ("Face", "#2196F3"),  # blue
+        "eyes": ("Eye Region", "#FF9800"),  # orange
+    }
+
+    html_parts = ['<div style="font-family: sans-serif; padding: 10px;">']
+    html_parts.append('<h3 style="margin-bottom: 15px;">📊 View Contribution</h3>')
+
     for k in ("whole", "face", "eyes"):
         w = analysis.view_weights.get(k, 0.0)
-        bar_len = int(w * 20)
-        bar = "█" * bar_len + "░" * (20 - bar_len)
-        lines.append(f"- **{k.capitalize()}**: `{bar}` {w:.1%}")
-
-    lines.append("")
-
-    # Branch weights per view
-    lines.append("### Branch Attention Weights (per view)")
-    lines.append("Which style features were most important:\n")
-    branch_names = ["Gram", "Cov", "Spectrum", "Stats"]
-    branch_desc = {
-        "Gram": "texture patterns",
-        "Cov": "color correlations",
-        "Spectrum": "frequency content",
-        "Stats": "mean/variance",
-    }
-
-    for view_name in ("whole", "face", "eyes"):
-        bw = analysis.branch_weights.get(view_name, {})
-        if bw:
-            lines.append(f"\n**{view_name.capitalize()}**:")
-            for b in branch_names:
-                w = bw.get(b, 0.0)
-                bar_len = int(w * 15)
-                bar = "▓" * bar_len + "░" * (15 - bar_len)
-                lines.append(f" - {b} ({branch_desc[b]}): `{bar}` {w:.1%}")
-
-    return "\n".join(lines)
+        label, color = view_info[k]
+        pct = int(w * 100)
+
+        html_parts.append(f'''
+        <div style="margin-bottom: 12px;">
+            <div style="display: flex; justify-content: space-between; margin-bottom: 4px;">
+                <span style="font-weight: 500;">{label}</span>
+                <span style="font-weight: 600; color: {color};">{pct}%</span>
+            </div>
+            <div style="background: #e0e0e0; border-radius: 4px; height: 20px; overflow: hidden;">
+                <div style="background: {color}; width: {pct}%; height: 100%; border-radius: 4px; transition: width 0.3s;"></div>
+            </div>
+        </div>
+        ''')
+
+    html_parts.append('</div>')
+    return "".join(html_parts)
 
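The new formatter reads only `analysis.view_weights`, so it can be smoke-tested without loading a model. A minimal sketch follows; the `FakeAnalysis` stub is hypothetical (the real `ViewAnalysis` also carries branch weights and Grad-CAM heatmaps), and the import assumes the Space's repo layout:

```python
# Hypothetical smoke test for format_view_weights_html (not part of the commit).
from dataclasses import dataclass, field

from app.visualization import format_view_weights_html  # assumes the Space's repo layout


@dataclass
class FakeAnalysis:
    # Stand-in for ViewAnalysis; only view_weights is read by the formatter.
    view_weights: dict = field(default_factory=dict)


analysis = FakeAnalysis(view_weights={"whole": 0.52, "face": 0.31, "eyes": 0.17})
html = format_view_weights_html(analysis)
assert "View Contribution" in html and "52%" in html
```
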
webui_gradio.py CHANGED
@@ -166,7 +166,7 @@ _patch_gradio_client_bool_jsonschema()
 from app.model_io import LoadedModel, embed_triview, load_style_model
 from app.proto_db import PrototypeDB, load_prototype_db, topk_predictions_unique_labels
 from app.view_extractor import AnimeFaceEyeExtractor, ExtractorCfg
-from app.visualization import ViewAnalysis, analyze_views, format_analysis_text
+from app.visualization import ViewAnalysis, analyze_views, format_view_weights_html
 
 
 ROOT = Path(__file__).resolve().parent
@@ -268,16 +268,19 @@ def load_all(ckpt_path: str, proto_path: str, device: str) -> str:
     return f"✅ Loaded checkpoint `{Path(ckpt_path).name}` (stage={lm.stage_i}) and proto DB `{Path(proto_path).name}` (N={db.centers.shape[0]})"
 
 
-def classify(
+def classify_and_analyze(
     whole_img,
     topk: int,
 ):
     """
-    Classify using auto-extracted face/eyes from whole image.
-    Returns: status, table_rows, face_preview, eyes_preview
+    Classify and analyze an image in one pass.
+    Returns: status, table_rows, view_weights_html,
+             gcam_whole, gcam_face, gcam_eye, face_preview, eye_preview
     """
+    empty_result = ("", [], "", None, None, None, None, None)
+
     if APP_STATE.lm is None or APP_STATE.db is None:
-        return "❌ Click **Load** first.", [], None, None
+        return ("❌ Click **Load** first.",) + empty_result[1:]
 
     lm = APP_STATE.lm
     db = APP_STATE.db
@@ -292,88 +292,48 @@ def classify(
 
     w = _to_pil(whole_img)
     if w is None:
-        return "❌ Provide a whole image.", [], None, None
+        return ("❌ Provide a whole image.",) + empty_result[1:]
 
     try:
+        # Extract face and eye region
         face_pil = None
-        eyes_pil = None
+        eye_pil = None
         if ex is not None:
             rgb = np.array(w.convert("RGB"))
-            face_rgb, eyes_rgb = ex.extract(rgb)
+            face_rgb, eye_rgb = ex.extract(rgb)
             if face_rgb is not None:
                 face_pil = Image.fromarray(face_rgb)
-            if eyes_rgb is not None:
-                eyes_pil = Image.fromarray(eyes_rgb)
+            if eye_rgb is not None:
+                eye_pil = Image.fromarray(eye_rgb)
 
+        # Prepare tensors
         wt = _pil_to_tensor(w, lm.T_w)
         ft = _pil_to_tensor(face_pil, lm.T_f) if face_pil is not None else None
-        et = _pil_to_tensor(eyes_pil, lm.T_e) if eyes_pil is not None else None
+        et = _pil_to_tensor(eye_pil, lm.T_e) if eye_pil is not None else None
+
+        # Classification
         z = embed_triview(lm, whole=wt, face=ft, eyes=et)
         preds = topk_predictions_unique_labels(db, z, topk=int(topk))
-    except Exception as ex:
-        return f"❌ Inference failed: {ex}", [], None, None
-
-    rows = [[name, float(score)] for (name, score) in preds]
-    return "✅ OK", rows, (face_pil if "face_pil" in locals() else None), (eyes_pil if "eyes_pil" in locals() else None)
-
-
-def analyze_image(whole_img):
-    """
-    Analyze an image showing view weights, branch weights, and Grad-CAM.
-    Returns: status, analysis_text, whole_gradcam, face_gradcam, eyes_gradcam, face_preview, eyes_preview
-    """
-    if APP_STATE.lm is None:
-        return "❌ Click **Load** first.", "", None, None, None, None, None
-
-    lm = APP_STATE.lm
-    ex = APP_STATE.extractor
-
-    def _to_pil(x):
-        if x is None:
-            return None
-        if isinstance(x, Image.Image):
-            return x
-        return Image.fromarray(x)
-
-    w = _to_pil(whole_img)
-    if w is None:
-        return "❌ Provide a whole image.", "", None, None, None, None, None
-
-    try:
-        # Extract face and eyes
-        face_pil = None
-        eyes_pil = None
-        if ex is not None:
-            rgb = np.array(w.convert("RGB"))
-            face_rgb, eyes_rgb = ex.extract(rgb)
-            if face_rgb is not None:
-                face_pil = Image.fromarray(face_rgb)
-            if eyes_rgb is not None:
-                eyes_pil = Image.fromarray(eyes_rgb)
-
-        # Prepare tensors
-        wt = _pil_to_tensor(w, lm.T_w)
-        ft = _pil_to_tensor(face_pil, lm.T_f) if face_pil is not None else None
-        et = _pil_to_tensor(eyes_pil, lm.T_e) if eyes_pil is not None else None
+        rows = [[name, float(score)] for (name, score) in preds]
 
+        # Analysis (Grad-CAM + view weights)
         views = {"whole": wt, "face": ft, "eyes": et}
-        original_images = {"whole": w, "face": face_pil, "eyes": eyes_pil}
-
-        # Run analysis
+        original_images = {"whole": w, "face": face_pil, "eyes": eye_pil}
         analysis = analyze_views(lm.model, views, original_images, lm.device)
-        analysis_text = format_analysis_text(analysis)
+        view_weights_html = format_view_weights_html(analysis)
 
         return (
-            "✅ Analysis complete",
-            analysis_text,
+            "✅ Done",
+            rows,
+            view_weights_html,
             analysis.gradcam_heatmaps.get("whole"),
             analysis.gradcam_heatmaps.get("face"),
             analysis.gradcam_heatmaps.get("eyes"),
             face_pil,
-            eyes_pil,
+            eye_pil,
         )
     except Exception as e:
-        return f"❌ Analysis failed: {e}", "", None, None, None, None, None
+        return (f"❌ Failed: {e}",) + empty_result[1:]
 
 
 def _gallery_item_to_pil(item) -> Optional[Image.Image]:
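Note the shape discipline in the merged handler: Gradio delivers one value per declared output component, so every exit path pads to the same 8-tuple via `empty_result`. A minimal sketch of that pattern with stub logic (not the Space's real handler):

```python
# Fixed-arity return pattern used by classify_and_analyze (stub logic only).
EMPTY = ("", [], "", None, None, None, None, None)  # status, rows, html, then 5 images

def handler(img, topk: int):
    if img is None:
        # Error exits reuse the empty tail so the output arity never changes.
        return ("❌ Provide a whole image.",) + EMPTY[1:]
    try:
        rows = [["artist_a", 0.92], ["artist_b", 0.88]][: int(topk)]
        return ("✅ Done", rows, "<div>bars</div>", None, None, None, None, None)
    except Exception as e:
        return (f"❌ Failed: {e}",) + EMPTY[1:]

# All paths agree on arity:
assert len(handler(None, 5)) == len(handler(object(), 5)) == 8
```
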
@@ -569,46 +532,38 @@ def build_ui() -> gr.Blocks:
 
         with gr.Tab("Classify"):
             with gr.Row():
-                whole = gr.Image(label="Whole image (required)", type="pil")
-                face_prev = gr.Image(label="Extracted face (auto)", type="pil")
-                eyes_prev = gr.Image(label="Extracted eyes (auto)", type="pil")
-            with gr.Row():
-                topk = gr.Slider(1, 20, value=5, step=1, label="Top-K")
-                run_btn = gr.Button("Run", variant="primary")
-
-            out_status = gr.Markdown("")
-            table = gr.Dataframe(headers=["label", "cosine_sim"], datatype=["str", "number"], interactive=False)
-            run_btn.click(classify, inputs=[whole, topk], outputs=[out_status, table, face_prev, eyes_prev])
-
-        with gr.Tab("Analyze (Grad-CAM)"):
-            gr.Markdown(
-                "### 🔍 View & Branch Analysis with Grad-CAM\n"
-                "Visualize which parts of the image and which style features the model focuses on.\n"
-                "- **View weights**: How much each view (whole/face/eyes) contributed\n"
-                "- **Branch weights**: Which style features (Gram/Cov/Spectrum/Stats) were important\n"
-                "- **Grad-CAM**: Spatial attention heatmaps showing where the model looked"
-            )
-            with gr.Row():
-                analyze_input = gr.Image(label="Whole image", type="pil")
-            analyze_btn = gr.Button("Analyze", variant="primary")
-            analyze_status = gr.Markdown("")
-            analyze_text = gr.Markdown("")
-
-            gr.Markdown("### Grad-CAM Heatmaps")
+                with gr.Column(scale=1):
+                    whole = gr.Image(label="Upload image", type="pil")
+                    with gr.Row():
+                        topk = gr.Slider(1, 20, value=5, step=1, label="Top-K")
+                        run_btn = gr.Button("Run", variant="primary")
+                    out_status = gr.Markdown("")
+
+                with gr.Column(scale=1):
+                    view_weights_display = gr.HTML(label="View Contribution")
+
+            # Classification results
+            gr.Markdown("### 🎯 Classification Results")
+            table = gr.Dataframe(headers=["Artist", "Similarity"], datatype=["str", "number"], interactive=False)
+
+            # Grad-CAM heatmaps
+            gr.Markdown("### 🔥 Grad-CAM Attention Maps")
+            gr.Markdown("*Where the model focused in each view:*")
             with gr.Row():
-                gcam_whole = gr.Image(label="Whole (Grad-CAM)", type="pil")
-                gcam_face = gr.Image(label="Face (Grad-CAM)", type="pil")
-                gcam_eyes = gr.Image(label="Eyes (Grad-CAM)", type="pil")
+                gcam_whole = gr.Image(label="Whole Image", type="pil")
+                gcam_face = gr.Image(label="Face", type="pil")
+                gcam_eye = gr.Image(label="Eye Region", type="pil")
 
-            gr.Markdown("### Extracted Views")
+            # Extracted views
+            gr.Markdown("### 👁️ Auto-Extracted Views")
             with gr.Row():
-                analyze_face = gr.Image(label="Extracted Face", type="pil")
-                analyze_eyes = gr.Image(label="Extracted Eyes", type="pil")
+                face_prev = gr.Image(label="Detected Face", type="pil")
+                eye_prev = gr.Image(label="Detected Eye", type="pil")
 
-            analyze_btn.click(
-                analyze_image,
-                inputs=[analyze_input],
-                outputs=[analyze_status, analyze_text, gcam_whole, gcam_face, gcam_eyes, analyze_face, analyze_eyes],
+            run_btn.click(
+                classify_and_analyze,
+                inputs=[whole, topk],
+                outputs=[out_status, table, view_weights_display, gcam_whole, gcam_face, gcam_eye, face_prev, eye_prev],
             )
 
         with gr.Tab("Add prototype (temporary)"):
@@ -616,8 +571,8 @@ def build_ui() -> gr.Blocks:
                 "### ⚠️ Temporary Prototypes Only\n"
                 "Add prototypes using random triplet combinations and K-means clustering (same as eval process).\n"
                 "1. Upload multiple whole images\n"
-                "2. Face/eyes are auto-extracted from each\n"
-                "3. Random triplets (whole + face + eyes) are created\n"
+                "2. Face and eye region are auto-extracted from each\n"
+                "3. Random triplets (whole + face + eye) are created\n"
                 "4. K-means clustering creates K prototype centers\n\n"
                 "**These prototypes are session-only** — lost when the Space restarts."
             )
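With classify and analyze combined, the Classify tab drives all eight outputs from a single click handler. A runnable sketch of the same wiring; component names mirror the diff, but the handler body is a stand-in rather than the Space's real inference code:

```python
# Standalone wiring sketch (hypothetical handler; layout mirrors the diff).
import gradio as gr

def classify_and_analyze(img, topk):
    status = "✅ Done" if img is not None else "❌ Provide a whole image."
    rows = [["artist_a", 0.92]]  # placeholder predictions
    bars = "<div>view-contribution bars would render here</div>"
    # One value per output component, in the order declared below.
    return status, rows, bars, img, img, img, img, img

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1):
            whole = gr.Image(label="Upload image", type="pil")
            topk = gr.Slider(1, 20, value=5, step=1, label="Top-K")
            run_btn = gr.Button("Run", variant="primary")
            out_status = gr.Markdown("")
        with gr.Column(scale=1):
            view_weights_display = gr.HTML(label="View Contribution")
    table = gr.Dataframe(headers=["Artist", "Similarity"], datatype=["str", "number"], interactive=False)
    with gr.Row():
        gcam_whole = gr.Image(label="Whole Image", type="pil")
        gcam_face = gr.Image(label="Face", type="pil")
        gcam_eye = gr.Image(label="Eye Region", type="pil")
    with gr.Row():
        face_prev = gr.Image(label="Detected Face", type="pil")
        eye_prev = gr.Image(label="Detected Eye", type="pil")
    run_btn.click(
        classify_and_analyze,
        inputs=[whole, topk],
        outputs=[out_status, table, view_weights_display, gcam_whole, gcam_face, gcam_eye, face_prev, eye_prev],
    )

if __name__ == "__main__":
    demo.launch()
```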