# TuRTLe-Leaderboard / data_processing.py
import pandas as pd
import plotly.express as px

from config.constants import (
    CC_BENCHMARKS,
    DISCARDED_MODELS,
    LC_BENCHMARKS,
    MC_BENCHMARKS,
    NON_RTL_METRICS,
    RTL_METRICS,
    S2R_BENCHMARKS,
    SCATTER_PLOT_X_TICKS,
    TYPE_COLORS,
    Y_AXIS_LIMITS,
)
from utils import (
    filter_bench,
    filter_bench_all,
    filter_NotSoTiny,
    filter_RTLRepo,
    handle_special_cases,
)

# Small state holder that exposes the results and aggregate DataFrames
# for whichever simulator is currently selected.
class Simulator:
    def __init__(self, icarus_df, icarus_agg, verilator_df, verilator_agg, yosys_df=None, yosys_agg=None):
        self.icarus_df = icarus_df
        self.icarus_agg = icarus_agg
        self.verilator_df = verilator_df
        self.verilator_agg = verilator_agg
        self.yosys_df = yosys_df if yosys_df is not None else pd.DataFrame()
        self.yosys_agg = yosys_agg if yosys_agg is not None else pd.DataFrame()
        self.current_simulator = "Icarus" if yosys_df is None else "YoSys EQV"

    def get_current_df(self):
        """Return the per-benchmark results DataFrame for the selected simulator."""
        if self.current_simulator == "Icarus":
            return self.icarus_df
        elif self.current_simulator == "Verilator":
            return self.verilator_df
        elif self.current_simulator == "YoSys EQV":
            return self.yosys_df
        else:
            return self.icarus_df

    def get_current_agg(self):
        """Return the aggregated-scores DataFrame for the selected simulator."""
        if self.current_simulator == "Icarus":
            return self.icarus_agg
        elif self.current_simulator == "Verilator":
            return self.verilator_agg
        elif self.current_simulator == "YoSys EQV":
            return self.yosys_agg
        else:
            return self.icarus_agg

    def set_simulator(self, simulator):
        self.current_simulator = simulator
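
# Illustrative usage (a minimal sketch; the app constructs these DataFrames at load
# time, and the variable names below are placeholders rather than names from this repo):
#
#     state = Simulator(icarus_df, icarus_agg, verilator_df, verilator_agg)
#     state.set_simulator("Verilator")
#     df = state.get_current_df()
#     agg = state.get_current_agg()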


# Main filtering entry point for the leaderboard table.
def filter_leaderboard(task, benchmark, model_type, search_query, max_params, state, name):
    """Filter leaderboard data based on user selections."""
    # Auto-correct the simulator when it does not match the selected task
    if task == "Module Completion" and state.current_simulator != "YoSys EQV":
        state.set_simulator("YoSys EQV")
    elif task != "Module Completion" and state.current_simulator == "YoSys EQV":
        state.set_simulator("Icarus")

    if benchmark == "All":
        # Restrict to the benchmarks that belong to the selected task
        subset = state.get_current_df().copy()
        task_benchmarks = {
            "Spec-to-RTL": S2R_BENCHMARKS,
            "Code Completion": CC_BENCHMARKS,
            "Line Completion": LC_BENCHMARKS,
            "Module Completion": MC_BENCHMARKS,
        }.get(task)
        if task_benchmarks is not None:
            subset = subset[subset["Benchmark"].isin(task_benchmarks)]
    else:
        subset = state.get_current_df()[state.get_current_df()["Benchmark"] == benchmark].copy()

    if model_type != "All":
        # Dropdown labels carry an emoji suffix; compare against the bare type name
        subset = subset[subset["Model Type"] == model_type.split(" ")[0]]
    if search_query:
        subset = subset[subset["Model"].str.contains(search_query, case=False, na=False)]

    max_params = float(max_params)
    if max_params < 995:  # the slider never quite reaches 1000 again after a reset, so >= 995 means "no limit"
        subset = subset[subset["Params"] <= max_params]
    else:
        subset["Params"] = subset["Params"].fillna("Unknown")

    # The "Other Models" view shows only the discarded models; every other view hides them
    if name == "Other Models":
        subset = subset[subset["Model"].isin(DISCARDED_MODELS)]
    else:
        subset = subset[~subset["Model"].isin(DISCARDED_MODELS)]

    if benchmark == "All":
        if task == "Spec-to-RTL":
            return filter_bench_all(subset, state.get_current_agg(), agg_column="Agg S2R", name=name)
        elif task == "Code Completion":
            return filter_bench_all(subset, state.get_current_agg(), agg_column="Agg MC", name=name)
        elif task == "Line Completion":
            return filter_RTLRepo(subset, name=name)
    elif benchmark == "RTL-Repo":
        return filter_RTLRepo(subset, name=name)
    elif benchmark == "NotSoTiny-25-12":
        return filter_NotSoTiny(subset, name=name)
    else:
        agg_column = None
        if benchmark == "VerilogEval S2R":
            agg_column = "Agg VerilogEval S2R"
        elif benchmark == "VerilogEval MC":
            agg_column = "Agg VerilogEval MC"
        elif benchmark == "RTLLM":
            agg_column = "Agg RTLLM"
        elif benchmark == "VeriGen":
            agg_column = "Agg VeriGen"
        return filter_bench(subset, state.get_current_agg(), agg_column, name=name)
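
# Example call from a UI callback (a sketch; the argument values are assumptions about
# how the interface invokes this function, not taken from this file):
#
#     table = filter_leaderboard(
#         task="Spec-to-RTL", benchmark="All", model_type="All",
#         search_query="", max_params=1000, state=state, name="Main",
#     )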


def generate_scatter_plot(benchmark, metric, state, simulator=None):
    """Generate a scatter plot of model size vs. the given metric for a benchmark."""
    benchmark, metric = handle_special_cases(benchmark, metric)

    # NotSoTiny-25-12 is evaluated with the YoSys equivalence flow; every other
    # benchmark in the plot view uses the Icarus results
    if benchmark == "NotSoTiny-25-12":
        state.set_simulator("YoSys EQV")
    else:
        state.set_simulator("Icarus")

    subset = state.get_current_df()[state.get_current_df()["Benchmark"] == benchmark]
    subset = subset[~subset["Model"].isin(DISCARDED_MODELS)]

    # The benchmark may be missing from the current simulator's data
    if subset.empty:
        fig = px.scatter(title=f"No data available for {benchmark} in {state.current_simulator} simulator")
        fig.update_layout(
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            annotations=[
                dict(
                    text=f"No data available for {benchmark} in current simulator",
                    xref="paper",
                    yref="paper",
                    x=0.5,
                    y=0.5,
                    showarrow=False,
                    font=dict(size=14),
                )
            ],
        )
        return fig

    if benchmark == "RTL-Repo":
        subset = subset[subset["Metric"].str.contains("EM", case=False, na=False)]
        detailed_scores = subset.groupby("Model", as_index=False)["Score"].mean()
        detailed_scores.rename(columns={"Score": "Exact Matching (EM)"}, inplace=True)
    elif benchmark == "NotSoTiny-25-12":
        # NotSoTiny uses mean aggregation across all metrics
        detailed_scores = subset.pivot_table(
            index="Model", columns="Metric", values="Score", aggfunc="mean"
        ).reset_index()
        # Rename metrics to match the display names
        detailed_scores.rename(
            columns={
                "STX": "Syntax (STX)",
                "EQV": "Functionality (EQV)",
                "Cov Mean": "Cell Coverage",
            },
            inplace=True,
        )
    else:
        detailed_scores = subset.pivot_table(index="Model", columns="Metric", values="Score").reset_index()

    details = state.get_current_df()[["Model", "Params", "Model Type"]].drop_duplicates("Model")
    scatter_data = pd.merge(detailed_scores, details, on="Model", how="left").dropna(
        subset=["Params", metric]
    )
    scatter_data["x"] = scatter_data["Params"]
    scatter_data["y"] = scatter_data[metric]
    scatter_data["size"] = (scatter_data["x"] ** 0.3) * 40  # sub-linear scaling keeps marker sizes readable
    scatter_data["color"] = scatter_data["Model Type"].map(TYPE_COLORS).fillna("gray")

    y_range = Y_AXIS_LIMITS.get(metric, [0, 80])
    fig = px.scatter(
        scatter_data,
        x="x",
        y="y",
        log_x=True,
        size="size",
        color="Model Type",
        text="Model",
        hover_data={metric: ":.2f"},
        title=f"Params vs. {metric} for {benchmark}",
        labels={"x": "# Params (Log Scale)", "y": metric},
        template="plotly_white",
        height=600,
        width=1200,
    )
    fig.update_traces(
        textposition="top center",
        textfont_size=10,
        marker=dict(opacity=0.8, line=dict(width=0.5, color="black")),
    )
    fig.update_layout(
        xaxis=dict(
            showgrid=True,
            type="log",
            tickmode="array",
            tickvals=SCATTER_PLOT_X_TICKS["tickvals"],
            ticktext=SCATTER_PLOT_X_TICKS["ticktext"],
        ),
        showlegend=False,
        yaxis=dict(range=y_range),
        margin=dict(l=50, r=50, t=50, b=50),
        plot_bgcolor="white",
    )
    return fig
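

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the app): build a tiny
    # synthetic results table with the column names used above and render one plot.
    # The model names and types below are made up; how the figure renders also depends
    # on the repo helpers (handle_special_cases, TYPE_COLORS, Y_AXIS_LIMITS), so treat
    # this only as a quick local sanity check.
    demo = pd.DataFrame(
        {
            "Model": ["model-a", "model-b"],
            "Benchmark": ["RTL-Repo", "RTL-Repo"],
            "Metric": ["EM", "EM"],
            "Score": [42.0, 57.0],
            "Params": [7.0, 70.0],
            "Model Type": ["General", "Coding"],
        }
    )
    state = Simulator(demo, demo, demo, demo)
    fig = generate_scatter_plot("RTL-Repo", "Exact Matching (EM)", state)
    fig.show()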