Spaces:
Sleeping
Sleeping
| """ | |
| Gradio demo – visualise benchmark accuracy curves. | |
| Required CSV files (place in the *same* folder as app.py): | |
| ├── aggregated_accuracy.csv | |
| ├── qa_accuracy.csv | |
| ├── ocr_accuracy.csv | |
| └── temporal_accuracy.csv | |
| Each file has the columns | |
| Model,<context‑length‑1>,<context‑length‑2>,… | |
| where the context‑length headers are strings such as `30min`, `60min`, `120min`, … | |
| No further cleaning / renaming is done apart from the cosmetic model-name | |
| replacements defined in `RENAME` (“gpt4.1” → “ChatGPT 4.1”, “Gemini 2.5 pro” → | |
| “Gemini 2.5 Pro”, “LLaMA-3.2B-11B” → “LLaMA-3.2-11B-Vision”). | |
| """ | |
| from pathlib import Path | |
| import pandas as pd | |
| import plotly.graph_objects as go | |
| import gradio as gr | |
| import math | |
| # --------------------------------------------------------------------- # | |
| # Config # | |
| # --------------------------------------------------------------------- # | |
# CSV file backing each benchmark, keyed by the internal benchmark name.
FILES = dict(
    aggregated="aggregated_accuracy.csv",
    qa="qa_accuracy.csv",
    ocr="ocr_accuracy.csv",
    temporal="temporal_accuracy.csv",
)

# Display label for each internal benchmark key.  The insertion order is the
# order in which the labels are offered in the benchmark dropdown.
DISPLAY_LABELS = dict(
    aggregated="Aggregated",
    qa="QA",
    ocr="OCR",
    temporal="Temporal",
)

# Optional per-benchmark pre-selection, written with the *display names*
# exactly as they appear in the Models list.  A benchmark without an entry
# here falls back to the first six models.
DEFAULT_MODELS: dict[str, list[str]] = {
    "aggregated": [
        "Gemini 2.5 Pro",
        "ChatGPT 4.1",
        "Qwen2.5-VL-7B",
        "InternVL2.5-8B",
        "LLaMA-3.2-11B-Vision",
    ],
}

# Regex pattern -> replacement text applied to the "Model" column on load.
RENAME = {
    r"gpt4\.1": "ChatGPT 4.1",
    r"Gemini\s2\.5\spro": "Gemini 2.5 Pro",
    r"LLaMA-3\.2B-11B": "LLaMA-3.2-11B-Vision",
}
| # --------------------------------------------------------------------- # | |
| # Data loading # | |
| # --------------------------------------------------------------------- # | |
def _read_csv(path: str | Path) -> pd.DataFrame:
    """Load one benchmark CSV and normalise its ``Model`` column.

    The ``Model`` values are passed through the regex replacements in
    ``RENAME`` and then cast to ``str``; all other columns are returned as
    pandas parsed them.

    Raises:
        KeyError: if the file has no ``Model`` column.
    """
    frame = pd.read_csv(path)
    renamed = frame["Model"].replace(RENAME, regex=True)
    frame["Model"] = renamed.astype(str)
    return frame
# Every benchmark table, loaded once at import time and keyed by benchmark name.
dfs: dict[str, pd.DataFrame] = dict(zip(FILES, map(_read_csv, FILES.values())))
| # --------------------------------------------------------------------- # | |
| # Colour palette and model metadata # | |
| # --------------------------------------------------------------------- # | |
| import plotly.express as px | |
# Colour-blind-safe qualitative palette shipped with Plotly Express.
SAFE_PALETTE = px.colors.qualitative.Safe

# Sorted union of the model names found in all benchmark tables.  A model's
# position in this list determines its colour and marker in every chart.
ALL_MODELS: list[str] = sorted(
    set().union(*(table["Model"].unique() for table in dfs.values()))
)

# Marker shapes, cycled through by model index.
MARKER_SYMBOLS = [
    "circle", "square", "triangle-up", "diamond", "cross",
    "triangle-down", "x", "triangle-right", "triangle-left", "pentagon",
]

# Context-length columns, taken from the aggregated table (everything except
# the model-name column, matched case-insensitively).
TIME_COLS = [
    column for column in dfs["aggregated"].columns if not column.lower() == "model"
]
| def _pretty_time(label: str) -> str: | |
| """‘30min’ → ‘30min’; ‘120min’ → ‘2hr’; keeps original if no match.""" | |
| if label.endswith("min"): | |
| minutes = int(label[:-3]) | |
| if minutes >= 60: | |
| hours = minutes / 60 | |
| return f"{hours:.0f}hr" if hours.is_integer() else f"{hours:.1f}hr" | |
| return label | |
# Display label for every time column, e.g. "120min" -> "2hr".
TIME_LABELS = dict(zip(TIME_COLS, map(_pretty_time, TIME_COLS)))
| # --------------------------------------------------------------------- # | |
| # Plotting # | |
| # --------------------------------------------------------------------- # | |
def _blank_missing(values) -> list:
    """Return *values* as a list with zero and missing (NaN/None) entries set to None."""
    return [None if pd.isna(value) or value == 0 else value for value in values]


def _y_axis_settings(min_y_val, log_scale: bool) -> tuple[str, list]:
    """Return the ``(axis type, axis range)`` pair for the accuracy axis."""
    if not log_scale:
        return "linear", [0, 100]
    # Fall back to 0.1 if there are no valid (positive) points.
    if min_y_val is None or min_y_val <= 0:
        min_y_val = 0.1
    # Plotly expects log10 exponents for the range of a "log" axis;
    # the upper bound 2 corresponds to 10^2 = 100.
    return "log", [math.floor(math.log10(min_y_val)), 2]


def render_chart(
    benchmark: str,
    models: list[str],
    log_scale: bool,
) -> go.Figure:
    """Build the accuracy-versus-duration line chart for one benchmark.

    Args:
        benchmark: Benchmark name; it is lower-cased and used as the key into
            ``dfs``, so display labels such as ``"QA"`` are accepted.
        models: Display names of the models to draw.  Names that are not
            present in the benchmark table are skipped; ``None`` is treated
            as an empty selection.
        log_scale: When true, the Y axis is logarithmic and starts at the
            power of ten at or below the smallest plotted value; otherwise it
            is linear and spans 0-100.

    Returns:
        A dark-themed Plotly figure with one "lines+markers" trace per model.
        Zero and missing (NaN) accuracies are drawn as gaps in the line.

    Raises:
        KeyError: if the benchmark is not one of the loaded tables.
    """
    bench_key = benchmark.lower()
    df = dfs[bench_key]
    palette = SAFE_PALETTE
    # The X labels are the same for every trace, so build them once.
    x_labels = [TIME_LABELS[c] for c in TIME_COLS]

    fig = go.Figure()
    # Smallest plotted Y value across the selected models (for log scaling).
    min_y_val = None
    for idx, m in enumerate(models or []):
        row = df.loc[df["Model"] == m]
        if row.empty:
            continue
        # Zeros and NaNs become None so Plotly leaves a gap instead of a point;
        # this also keeps NaN out of the minimum used for the log-axis range.
        y = _blank_missing(row[TIME_COLS].values.flatten())
        plotted = [val for val in y if val is not None]
        if plotted:
            cur_min = min(plotted)
            if min_y_val is None or cur_min < min_y_val:
                min_y_val = cur_min

        # The index in the global model list keeps a model's colour and marker
        # the same across benchmarks; the loop index is only a fallback.
        model_idx = ALL_MODELS.index(m) if m in ALL_MODELS else idx
        color = palette[model_idx % len(palette)]
        symbol = MARKER_SYMBOLS[model_idx % len(MARKER_SYMBOLS)]
        fig.add_trace(
            go.Scatter(
                x=x_labels,
                y=y,
                mode="lines+markers",
                name=m,
                line=dict(width=3, color=color),
                marker=dict(size=6, color=color, symbol=symbol),
                connectgaps=False,
            )
        )

    yaxis_type, yaxis_range = _y_axis_settings(min_y_val, log_scale)
    fig.update_layout(
        title=f"{DISPLAY_LABELS.get(bench_key, bench_key.capitalize())} Accuracy Over Time",
        xaxis_title="Video Duration",
        yaxis_title="Accuracy (%)",
        yaxis_type=yaxis_type,
        yaxis_range=yaxis_range,
        legend_title="Model",
        legend=dict(
            orientation="h",
            y=-0.25,
            x=0.5,
            xanchor="center",
            tracegroupgap=8,
            itemwidth=60,
        ),
        margin=dict(t=40, r=20, b=80, l=60),
        template="plotly_dark",
        font=dict(family="Inter,Helvetica,Arial,sans-serif", size=14),
        title_font=dict(size=20, family="Inter,Helvetica,Arial,sans-serif", color="white"),
        xaxis=dict(gridcolor="rgba(255,255,255,0.15)"),
        yaxis=dict(gridcolor="rgba(255,255,255,0.15)"),
        hoverlabel=dict(bgcolor="#1e1e1e", font_color="#eeeeee", bordercolor="#888"),
    )
    return fig
| # --------------------------------------------------------------------- # | |
| # UI # | |
| # --------------------------------------------------------------------- # | |
# Custom stylesheet handed to gr.Blocks: padding and background for the
# controls, a scrollable box for the model list, fonts, and checkbox accent.
CSS = """
#controls {
padding: 8px 12px;
}
.scrollbox {
max-height: 300px;
overflow-y: auto;
}
body, .gradio-container {
font-family: 'Inter', 'Helvetica', sans-serif;
}
.gradio-container h1, .gradio-container h2 {
font-weight: 600;
}
#controls, .scrollbox {
background: rgba(255,255,255,0.02);
border-radius: 6px;
}
input[type="checkbox"]:checked {
accent-color: #FF715E;
}
"""
def available_models(bench: str) -> list[str]:
    """Return the distinct model names of benchmark *bench*, sorted."""
    names = dfs[bench]["Model"].unique()
    return sorted(names)
def default_models(bench: str) -> list[str]:
    """Return the models that are selected by default for benchmark *bench*.

    The configured names from ``DEFAULT_MODELS`` are used, in configured
    order, restricted to models the benchmark actually contains.  When none
    remain (or nothing is configured), the first six available models are
    returned instead.
    """
    choices = available_models(bench)
    preferred = [name for name in DEFAULT_MODELS.get(bench, []) if name in choices]
    return preferred if preferred else choices[:6]
with gr.Blocks(theme=gr.themes.Base(), css=CSS) as demo:
    gr.Markdown(
        """
# 📈 TimeScope
How long can your video model keep up?
"""
    )

    # Top row: benchmark selector and the log-scale toggle.
    with gr.Row():
        benchmark_dd = gr.Dropdown(
            label="Type",
            choices=list(DISPLAY_LABELS.values()),
            value=DISPLAY_LABELS["aggregated"],
            scale=1,
        )
        log_cb = gr.Checkbox(
            label="Log-scale Y-axis",
            value=False,
            scale=1,
        )

    # Chart, then the model checklist, both initialised for "aggregated".
    initial_selection = default_models("aggregated")
    plot_out = gr.Plot(render_chart("Aggregated", initial_selection, False))
    models_cb = gr.CheckboxGroup(
        label="Models",
        choices=available_models("aggregated"),
        value=initial_selection,
        interactive=True,
        elem_classes=["scrollbox"],
    )

    def _update_models(bench: str):
        """Refresh the model checklist (choices and defaults) for *bench*."""
        key = bench.lower()
        # Generic gr.update keeps this compatible across Gradio versions.
        return gr.update(choices=available_models(key), value=default_models(key))

    # Switching benchmark swaps the list of selectable models.
    benchmark_dd.change(
        fn=_update_models,
        inputs=benchmark_dd,
        outputs=models_cb,
        queue=False,
    )

    # Any change to a control redraws the chart from the current control values.
    chart_inputs = [benchmark_dd, models_cb, log_cb]
    for component in chart_inputs:
        component.change(
            fn=render_chart,
            inputs=chart_inputs,
            outputs=plot_out,
            queue=False,
        )

demo.launch(share=True)