# ui/csv_tab.py
"""
Builds the CSV-upload tab (batch metrics).

- Summary table: **only global scores** (no S/O/A/P). Labels are short (e.g., "BLEU", not "BLEU GLOBAL").
- Detailed table: shows only global F1 columns (colored) and, when available, dark badges for P/R.
- CSV export includes whatever columns the backend produced; the UI renders only the globals.
- Upload "Status" is collapsed into the file input's label.
- Errors (missing CSV, columns not chosen, etc.) are displayed in the status textbox under "Run Evaluation".
"""
import os
import time
import tempfile

import gradio as gr
import pandas as pd

from metrics import compute_all_metrics_batch, BERT_FRIENDLY_TO_MODEL
from ui.widgets import MetricCheckboxGroup, BertCheckboxGroup
from utils.file_utils import smart_read_csv
from utils.colors_utils import get_metric_color
from utils.tokenizer_refgen import generate_diff_html
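
# Expected column layouts (illustrative, inferred from the names used below):
#   - Standardized pairs frame: code_audio_transcription, dsc_reference_free_text,
#     dsc_generated_clinical_report
#   - Backend result frame (only globals are rendered): bleu_global, bleurt_global,
#     rougeL_global_f1 (+ rougeL_global_p / _r), and bertscore_<model>_f1 (+ _p / _r)
#     or bertscore_global_f1 / _p / _r, depending on what the backend produced.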

# ------------------- Summary HTML builder (GLOBAL ONLY) -------------------

def build_summary_html(result_df: pd.DataFrame, selected_metrics: list, bert_models: list | None = None) -> str:
    def safe_stats(col):
        if col not in result_df.columns:
            return None
        s = result_df[col].dropna()
        if s.empty:
            return None
        s = s.astype(float)
        avg, mn, mx = s.mean(), s.min(), s.max()

        def audio_id_for(v):
            subset = result_df[result_df[col].astype(float) == v]
            if not subset.empty and "code_audio_transcription" in subset.columns:
                aid = subset.iloc[0]["code_audio_transcription"]
                try:
                    return int(aid)
                except Exception:
                    return aid
            return ""

        return {"avg": avg, "min": mn, "min_id": audio_id_for(mn), "max": mx, "max_id": audio_id_for(mx)}

    rows = []
    # NOTE: We used to show per-section rows (S/O/A/P) when a single metric was selected.
    # That logic has been **removed**; we now present **only global** rows for all metrics.
    if "BLEU" in selected_metrics:
        s = safe_stats("bleu_global")
        if s:
            rows.append(("bleu_global", s))
    if "BLEURT" in selected_metrics:
        s = safe_stats("bleurt_global")
        if s:
            rows.append(("bleurt_global", s))
    if "ROUGE" in selected_metrics:
        s = safe_stats("rougeL_global_f1")
        if s:
            rows.append(("rougeL_global_f1", s))
    # BERTScore (global only)
    if "BERTSCORE" in selected_metrics and bert_models:
        # NOTE: Previously, if only BERTScore with one model was selected, we added per-section rows.
        # That behavior is **disabled**. We only show global columns:
        #   - bertscore_<short>_f1 (multi-model)
        #   - or bertscore_global_f1 (if that's what the backend produced)
        for friendly in bert_models:
            mid = BERT_FRIENDLY_TO_MODEL.get(friendly)
            if not mid:
                continue
            short = mid.split("/")[-1].replace("-", "_")
            col = f"bertscore_{short}_f1" if f"bertscore_{short}_f1" in result_df.columns else "bertscore_global_f1"
            s = safe_stats(col)
            if s:
                rows.append((col, s))
    if not rows:
        return "<div style='padding:8px;background:#1f1f1f;color:#eee;border-radius:6px;'>No summary available.</div>"

    # Build HTML table
    html = """
    <div style="margin-bottom:12px;overflow-x:auto;">
      <div style="font-weight:600;margin-bottom:4px;color:#f5f5f5;font-size:16px;">Summary Statistics</div>
      <table style="border-collapse:collapse;width:100%;font-family:system-ui,-apple-system,BlinkMacSystemFont,Segoe UI,Roboto,sans-serif;border-radius:8px;overflow:hidden;min-width:500px;">
        <thead><tr>
          <th style="padding:8px 12px;background:#2d3748;color:#fff;text-align:left;font-weight:600;">Metric</th>
          <th style="padding:8px 12px;background:#2d3748;color:#fff;text-align:center;font-weight:600;">Avg</th>
          <th style="padding:8px 12px;background:#2d3748;color:#fff;text-align:center;font-weight:600;">Min (ID)</th>
          <th style="padding:8px 12px;background:#2d3748;color:#fff;text-align:center;font-weight:600;">Max (ID)</th>
        </tr></thead><tbody>
    """
    for col, stat in rows:
        # Pretty names (drop "GLOBAL")
        if col == "bleu_global":
            name = "BLEU"
        elif col == "bleurt_global":
            name = "BLEURT"
        elif col == "rougeL_global_f1":
            name = "ROUGE-L"
        elif col.startswith("bertscore_"):
            if col == "bertscore_global_f1":
                name = "BERTSCORE"
            else:
                label = " ".join(col.split("_")[1:-1]).upper()
                name = f"BERTSCORE {label}" if label else "BERTSCORE"
        else:
            name = col.replace("_", " ").upper()
        avg = f"{stat['avg']:.4f}"
        mn = f"{stat['min']:.4f} ({stat['min_id']})" if stat['min_id'] != "" else f"{stat['min']:.4f}"
        mx = f"{stat['max']:.4f} ({stat['max_id']})" if stat['max_id'] != "" else f"{stat['max']:.4f}"
        # Color scale by metric family (F1)
        if col.startswith("bleu_"):
            ca, cm, cx = get_metric_color(stat['avg'], "BLEU"), get_metric_color(stat['min'], "BLEU"), get_metric_color(stat['max'], "BLEU")
        elif col.startswith("bleurt_"):
            ca, cm, cx = get_metric_color(stat['avg'], "BLEURT"), get_metric_color(stat['min'], "BLEURT"), get_metric_color(stat['max'], "BLEURT")
        elif col.startswith("rougeL_"):
            ca, cm, cx = get_metric_color(stat['avg'], "ROUGE"), get_metric_color(stat['min'], "ROUGE"), get_metric_color(stat['max'], "ROUGE")
        else:
            ca, cm, cx = get_metric_color(stat['avg'], "BERTSCORE"), get_metric_color(stat['min'], "BERTSCORE"), get_metric_color(stat['max'], "BERTSCORE")
        html += f"""
        <tr style="background:#0f1218;">
          <td style="padding:8px 12px;border:1px solid #2f3240;color:#fff;white-space:nowrap;">{name}</td>
          <td style="padding:8px 12px;border:1px solid #2f3240;background:{ca};color:#fff;text-align:center;white-space:nowrap;">{avg}</td>
          <td style="padding:8px 12px;border:1px solid #2f3240;background:{cm};color:#fff;text-align:center;white-space:nowrap;">{mn}</td>
          <td style="padding:8px 12px;border:1px solid #2f3240;background:{cx};color:#fff;text-align:center;white-space:nowrap;">{mx}</td>
        </tr>
        """
    html += "</tbody></table></div>"
    return html

# ------------------- Detailed table (GLOBAL ONLY, F1 colored + dark P/R badges) -------------------

def render_results_table_html(result_df: pd.DataFrame) -> str:
    if result_df is None or result_df.empty:
        return "<div style='padding:8px;background:#1f1f1f;color:#eee;border-radius:6px;'>No results.</div>"

    # Keep only *global* F1 columns (skip *_p/_r and any S/O/A/P)
    def is_global_f1(col: str) -> bool:
        if col == "code_audio_transcription":
            return False
        if col.endswith("_p") or col.endswith("_r"):
            return False
        if col.startswith("bleu_"):
            return col == "bleu_global"
        if col.startswith("bleurt_"):
            return col == "bleurt_global"
        if col.startswith("rougeL_"):
            return col == "rougeL_global_f1"
        if col.startswith("bertscore_"):
            parts = col.split("_")
            # Exclude per-section columns: bertscore_S_f1, etc.
            if len(parts) >= 2 and parts[1] in {"S", "O", "A", "P"}:
                return False
            # Allow model-specific or "bertscore_global_f1"
            return parts[-1] == "f1" or col == "bertscore_global_f1"
        return False

    f1_cols = [c for c in result_df.columns if is_global_f1(c)]

    # Sort for readability: BLEU, BLEURT, ROUGE-L, BERTSCORE (...)
    def _grp_key(col):
        if col.startswith("bleu_"):
            g = 0
        elif col.startswith("bleurt_"):
            g = 1
        elif col.startswith("rougeL_"):
            g = 2
        elif col.startswith("bertscore_"):
            g = 3
        else:
            g = 9
        return (g, col)

    f1_cols = sorted(f1_cols, key=_grp_key)

    # HTML table
    html = [
        "<div style='overflow-x:auto;'>",
        "<div style='font-weight:600;margin:8px 0;color:#f5f5f5;font-size:16px;'>Individual Results</div>",
        "<table style='border-collapse:collapse;width:100%;font-family:system-ui,-apple-system,BlinkMacSystemFont,Segoe UI,Roboto,sans-serif;border-radius:8px;overflow:hidden;'>",
        "<thead><tr>",
        "<th style='padding:8px 12px;background:#2d3748;color:#fff;text-align:left;font-weight:600;white-space:nowrap;'>ID</th>",
    ]

    def pretty_header(col: str) -> str:
        if col == "bleu_global":
            return "BLEU"
        if col == "bleurt_global":
            return "BLEURT"
        if col == "rougeL_global_f1":
            return "ROUGE-L"
        if col.startswith("bertscore_"):
            if col == "bertscore_global_f1":
                return "BERTSCORE"
            label = " ".join(col.split("_")[1:-1]).upper()
            return f"BERTSCORE {label}" if label else "BERTSCORE"
        return col.replace("_", " ").upper()

    for col in f1_cols:
        html.append(
            f"<th style='padding:8px 12px;background:#2d3748;color:#fff;text-align:center;font-weight:600;white-space:nowrap;'>{pretty_header(col)}</th>"
        )
    html.append("</tr></thead><tbody>")

    for _, row in result_df.iterrows():
        rid = row.get("code_audio_transcription", "")
        try:
            rid = int(rid)
        except Exception:
            pass
        html.append("<tr style='background:#0f1218;'>")
        html.append(f"<td style='padding:8px 12px;border:1px solid #2f3240;color:#fff;white-space:nowrap;'>{rid}</td>")
        for col in f1_cols:
            val = row.get(col, None)
            # Figure out the metric family & pick P/R columns accordingly
            metric_kind = "BERTSCORE"
            p_text = r_text = ""
            if col.startswith("bleu_"):
                metric_kind = "BLEU"
                # BLEU: no P/R
            elif col.startswith("bleurt_"):
                metric_kind = "BLEURT"
            elif col.startswith("rougeL_"):
                metric_kind = "ROUGE"
                base = "rougeL_global"  # global root
                pcol, rcol = f"{base}_p", f"{base}_r"
                p = row.get(pcol, None)
                r = row.get(rcol, None)
                p_text = f"P: {p:.4f}" if isinstance(p, (int, float)) else ""
                r_text = f"R: {r:.4f}" if isinstance(r, (int, float)) else ""
            elif col.startswith("bertscore_"):
                metric_kind = "BERTSCORE"
                # Try model-specific columns first
                base = col[:-3] if col.endswith("_f1") else col  # strip trailing _f1
                pcol, rcol = f"{base}_p", f"{base}_r"
                if pcol not in result_df.columns and rcol not in result_df.columns:
                    # Fall back to "bertscore_global" naming
                    pcol, rcol = "bertscore_global_p", "bertscore_global_r"
                p = row.get(pcol, None)
                r = row.get(rcol, None)
                p_text = f"P: {p:.4f}" if isinstance(p, (int, float)) else ""
                r_text = f"R: {r:.4f}" if isinstance(r, (int, float)) else ""
            if isinstance(val, (int, float)):
                bg = get_metric_color(float(val), metric_kind)
                val_text = f"{float(val):.4f}"
            else:
                bg = "transparent"
                val_text = "—"
            # Dark badges for P/R
            pills = []
            if p_text:
                pills.append(
                    "<span style='padding:1px 6px;border-radius:999px;background:rgba(0,0,0,.48);color:#fff;display:inline-block;'>"
                    f"{p_text}</span>"
                )
            if r_text:
                pills.append(
                    "<span style='padding:1px 6px;border-radius:999px;background:rgba(0,0,0,.48);color:#fff;display:inline-block;margin-left:6px;'>"
                    f"{r_text}</span>"
                )
            badges = ""
            if pills:
                badges = "<div style='font-size:12px;margin-top:4px;line-height:1.2;'>" + "".join(pills) + "</div>"
            html.append(
                f"<td style='padding:8px 12px;border:1px solid #2f3240;background:{bg};color:#fff;text-align:center;white-space:nowrap;'>"
                f"{val_text}{badges}</td>"
            )
        html.append("</tr>")
    html.append("</tbody></table></div>")
    return "".join(html)

# ------------------- Tab builder -------------------

def build_csv_tab():
    with gr.Blocks() as tab:
        state_df = gr.State()      # original uploaded DataFrame
        state_pairs = gr.State()   # standardized pairs: id + reference + generated
        state_result = gr.State()  # metrics result DataFrame for export

        gr.Markdown("# RUN AN EXPERIMENT VIA CSV UPLOAD")
        gr.Markdown(
            "Upload a CSV of reference/generated text pairs, map the columns, pick metrics, and run a batch evaluation. \n "
            "F1 is highlighted in color; Precision/Recall appear as small dark badges."
        )
        gr.Markdown("## Experiment Configuration")

        # 1) Upload CSV (status collapsed into the label)
        gr.Markdown("### Upload CSV")
        gr.Markdown("Provide a CSV file containing your data. It should include columns for the reference text, the generated text, and an identifier (e.g., audio ID).")
        with gr.Row():
            file_input = gr.File(label="Upload CSV", file_types=[".csv"])

        # 2) Map Columns
        gr.Markdown("### Map Columns")
        gr.Markdown("Select which columns in your CSV correspond to the reference text, generated text, and audio/example ID.")
        with gr.Row(visible=False) as mapping:
            ref_col = gr.Dropdown(label="Reference Column", choices=[])
            gen_col = gr.Dropdown(label="Generated Column", choices=[])
            id_col = gr.Dropdown(label="Audio ID Column", choices=[])

        # 3) Select Metrics
        gr.Markdown("### Select Metrics")
        metric_selector = MetricCheckboxGroup()
        bert_model_selector = BertCheckboxGroup()

        # ---------- Divider before RESULTS ----------
        gr.HTML("""<div style="height:1px;margin:22px 0;background:
            linear-gradient(90deg, rgba(0,0,0,0) 0%, #4a5568 35%, #4a5568 65%, rgba(0,0,0,0) 100%);"></div>""")
        gr.Markdown("# RESULTS")

        # Emphasize the run button
        gr.HTML("""
        <style>
        #run-eval-btn button {
            background: linear-gradient(135deg, #3b82f6 0%, #2563eb 100%) !important;
            color: #fff !important;
            border: none !important;
            box-shadow: 0 6px 16px rgba(0,0,0,.25);
        }
        #run-eval-btn button:hover { filter: brightness(1.08); transform: translateY(-1px); }
        </style>
        """)

        # 4) Run Evaluation (+ Export control)
        with gr.Row():
            run_btn = gr.Button("🚀 Run Evaluation", variant="primary", elem_id="run-eval-btn")
            download_btn = gr.DownloadButton(label="⬇️ Export full results (CSV)", visible=False)

        # This textbox displays both success and error messages
        output_status = gr.Text()
        summary_output = gr.HTML()
        table_output = gr.HTML()

        # 5) Inspect example
        gr.Markdown("### Inspect an Example")
        gr.Markdown("Pick an example by its ID to view the reference vs generated text with token-level differences highlighted.")
        with gr.Accordion("🔍 Show reference & generated text", open=False):
            pick_id = gr.Dropdown(label="Pick an Audio ID", choices=[])
            ref_disp = gr.Textbox(label="Reference Text", lines=6, interactive=False)
            gen_disp = gr.Textbox(label="Generated Text", lines=6, interactive=False)
            diff_disp = gr.HTML()

        # ---- Handlers ----
        def handle_upload(f):
            if not f:
                # Reset label & hide mapping
                return (
                    None,
                    gr.update(choices=[]), gr.update(choices=[]), gr.update(choices=[]),
                    gr.update(visible=False),
                    gr.update(label="Upload CSV"),
                )
            df = smart_read_csv(f.name)
            cols = list(df.columns)
            return (
                df,
                gr.update(choices=cols, value=None),
                gr.update(choices=cols, value=None),
                gr.update(choices=cols, value=None),
                gr.update(visible=True),
                gr.update(label="Upload CSV — OK: select the columns."),
            )

        def run_batch(df, r, g, i, mets, berts):
            # Pre-flight validation: CSV uploaded?
            if df is None:
                return (
                    "Error: please upload a CSV and select the columns.",
                    "", "", gr.update(choices=[]), None, None, gr.update(visible=False)
                )
            # Columns chosen?
            if not r or not g or not i:
                return (
                    "Error: select the Reference, Generated and Audio ID columns.",
                    "", "", gr.update(choices=[]), None, None, gr.update(visible=False)
                )
            # Columns exist?
            missing = [c for c in [i, r, g] if c not in df.columns]
            if missing:
                return (
                    f"Error: these columns do not exist in the CSV: {missing}",
                    "", "", gr.update(choices=[]), None, None, gr.update(visible=False)
                )
            # Metrics chosen?
            if not mets:
                return (
                    "Error: select at least one metric.",
                    "", "", gr.update(choices=[]), None, None, gr.update(visible=False)
                )
            # Rename into the standard schema (this is what we'll use for "Inspect an Example")
            try:
                sub = df[[i, r, g]].rename(
                    columns={i: "code_audio_transcription", r: "dsc_reference_free_text", g: "dsc_generated_clinical_report"}
                )
            except Exception as e:
                return (
                    f"Error while preparing the data: {e}",
                    "", "", gr.update(choices=[]), None, None, gr.update(visible=False)
                )
            # Compute metrics
            try:
                result = compute_all_metrics_batch(
                    sub,
                    mets,
                    berts if "BERTSCORE" in (mets or []) else None
                )
            except Exception as e:
                return (
                    f"Error while computing metrics: {e}",
                    "", "", gr.update(choices=[]), None, None, gr.update(visible=False)
                )
            # Normalize IDs for the dropdown
            try:
                raw_ids = result["code_audio_transcription"].dropna().unique().tolist()
                ids = []
                for x in raw_ids:
                    try:
                        ids.append(int(x))
                    except Exception:
                        ids.append(x)
                ids = sorted(ids, key=lambda z: (not isinstance(z, int), z))
            except Exception:
                ids = []
            # Build HTML views
            try:
                summary = build_summary_html(result, mets, berts if "BERTSCORE" in (mets or []) else None)
                table = render_results_table_html(result)
            except Exception as e:
                return (
                    f"Error while rendering results: {e}",
                    "", "", gr.update(choices=ids, value=None), None, None, gr.update(visible=False)
                )
            # Keep results for export & show the download button.
            # Also keep the standardized pairs (sub) for the "Inspect an Example" view.
            return (
                "Metrics computed successfully.",
                summary,
                table,
                gr.update(choices=ids, value=None),
                result,
                sub,
                gr.update(visible=True),
            )

        def show_example(pairs_df, audio_id):
            # Use the standardized pairs dataframe (id + reference + generated)
            if pairs_df is None or audio_id is None:
                return "", "", ""
            try:
                row = pairs_df[pairs_df["code_audio_transcription"] == audio_id]
                if row.empty:
                    # Try float cast fallback for IDs that come as strings
                    try:
                        audio_id2 = float(audio_id)
                        row = pairs_df[pairs_df["code_audio_transcription"] == audio_id2]
                    except Exception:
                        return "", "", ""
                if row.empty:
                    return "", "", ""
                row = row.iloc[0]
                ref_txt = row["dsc_reference_free_text"]
                gen_txt = row["dsc_generated_clinical_report"]
                return ref_txt, gen_txt, generate_diff_html(ref_txt, gen_txt)
            except Exception:
                return "", "", ""

        def _export_results_csv(df: pd.DataFrame | None) -> str:
            # Always export with a comma separator; include ALL columns that were computed
            if df is None or df.empty:
                tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".csv")
                with open(tmp.name, "w", encoding="utf-8") as f:
                    f.write("no_data\n")
                return tmp.name
            ts = time.strftime("%Y%m%d_%H%M%S")
            tmp_path = os.path.join(tempfile.gettempdir(), f"automatic_metrics_{ts}.csv")
            df.to_csv(tmp_path, sep=",", index=False)
            return tmp_path

        # ---- Wiring ----
        file_input.change(
            fn=handle_upload,
            inputs=[file_input],
            outputs=[state_df, ref_col, gen_col, id_col, mapping, file_input],  # update label in place
        )
        metric_selector.change(
            fn=lambda ms: gr.update(visible="BERTSCORE" in ms),
            inputs=[metric_selector],
            outputs=[bert_model_selector],
        )
        run_btn.click(
            fn=run_batch,
            inputs=[state_df, ref_col, gen_col, id_col, metric_selector, bert_model_selector],
            outputs=[output_status, summary_output, table_output, pick_id, state_result, state_pairs, download_btn],
        )
        # Use the standardized pairs DF for the example view (fixes KeyError on the original DF)
        pick_id.change(
            fn=show_example,
            inputs=[state_pairs, pick_id],
            outputs=[ref_disp, gen_disp, diff_disp],
        )
        download_btn.click(
            fn=_export_results_csv,
            inputs=[state_result],
            outputs=download_btn,  # path returned; Gradio serves it
        )
    return tab
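

# Minimal local smoke test (illustrative sketch, not part of the app wiring). It assumes the
# sibling modules imported above are importable from the repo root; the toy column names
# simply mirror the ones this file already renders.
if __name__ == "__main__":
    _toy = pd.DataFrame(
        {
            "code_audio_transcription": [1, 2],
            "bleu_global": [0.42, 0.57],
            "rougeL_global_f1": [0.61, 0.70],
            "rougeL_global_p": [0.65, 0.72],
            "rougeL_global_r": [0.58, 0.68],
        }
    )
    # Render both HTML views from the toy frame, then launch the tab standalone.
    print(build_summary_html(_toy, ["BLEU", "ROUGE"]))
    print(render_results_table_html(_toy))
    build_csv_tab().launch()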