Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| Plotting Accuracy of different methods for different generative models. | |
| """ | |
# Importing Libraries
import os
import sys
import warnings

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Put the parent of this script's directory on the import path so the
# project-local ``defaults`` module below can be imported.
_script_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(_script_dir))
# Silence every warning emitted from here on (imports above are unaffected).
warnings.filterwarnings("ignore")
import defaults  # noqa: E402  (must follow the sys.path update above)
# Function to get performance of perceptual models
def get_performance(
    results_dataframe: pd.DataFrame,
    generative_models: list,
    metric: str,
):
    """Collect one rounded metric per model from a results dataframe.

    Every cell of the per-generator columns and of the aggregate column
    ``"(mAP, mAcc, mAcc_Real, mAcc_fake)"`` must hold the text of a Python
    sequence literal such as ``"(0.91, 0.88, 0.95, 0.81)"``. ``metric``
    selects one position of that sequence.

    Args:
        results_dataframe (pd.DataFrame): Pandas Dataframe loaded from .csv file.
            It must contain a ``"model"`` column, one column per entry of
            ``generative_models``, and the aggregate column.
        generative_models (list): Generative-model column names to read, in order.
        metric (str): Evaluation Metric. One of ``"mAP"``, ``"mAcc"``,
            ``"mAcc_Real"`` or ``"mAcc_fake"`` (``"mAcc_Fake"`` is accepted too).

    Returns:
        dict: Maps each row's ``"model"`` value to a list holding one value per
        entry of ``generative_models``, followed by the aggregate-column value.
        Each value is a NumPy float rounded to two decimals. If a model name
        appears in several rows, the last row wins.

    Raises:
        ValueError: If a required column is missing, ``metric`` is unknown, or a
            cell is not a valid Python literal.
    """
    import ast  # local import keeps this helper self-contained

    # Position of each metric inside the stored tuples.
    metric_indices = {
        "mAP": 0,
        "mAcc": 1,
        "mAcc_Real": 2,
        "mAcc_fake": 3,
        "mAcc_Fake": 3,
    }
    if metric not in metric_indices:
        raise ValueError(
            f"Unknown metric {metric!r}; expected one of {sorted(metric_indices)}"
        )
    metric_index = metric_indices[metric]

    aggregate_column = "(mAP, mAcc, mAcc_Real, mAcc_fake)"
    required_columns = ["model", *generative_models, aggregate_column]
    missing = [c for c in required_columns if c not in results_dataframe.columns]
    if missing:
        raise ValueError(
            f"Invalid list of generative models; missing columns: {missing}"
        )

    def _read_cell(cell):
        # ast.literal_eval only accepts literals, so a corrupt or malicious CSV
        # cell cannot execute code the way eval() could.
        return np.round(float(ast.literal_eval(cell)[metric_index]), decimals=2)

    performance = {}
    for _, row in results_dataframe.iterrows():
        model_performance = [_read_cell(row[gen_model]) for gen_model in generative_models]
        # The aggregate (all-generator) score is always appended last.
        model_performance.append(_read_cell(row[aggregate_column]))
        performance[row["model"]] = model_performance
    return performance
def realign_polar_xticks(ax):
    """Set the alignment of each angular tick label of a polar axes by its position.

    Each tick angle is converted to its on-screen direction using the axes'
    theta direction and offset. Labels on the right side (horizontal
    component >= 0.1) are left-aligned, and labels on the left (<= -0.1) are
    right-aligned. Labels near the top (vertical component >= 0.5) are
    bottom-aligned, and labels near the bottom (<= -0.5) are top-aligned.
    Labels near the axes keep their current alignment.
    """
    direction = ax.get_theta_direction()
    offset = ax.get_theta_offset()
    for tick_angle, tick_label in zip(ax.get_xticks(), ax.get_xticklabels()):
        screen_angle = np.pi/2 - (tick_angle * direction + offset)
        horizontal = np.sin(screen_angle)
        vertical = np.cos(screen_angle)
        if horizontal >= 0.1:
            tick_label.set_horizontalalignment('left')
        elif horizontal <= -0.1:
            tick_label.set_horizontalalignment('right')
        if vertical >= 0.5:
            tick_label.set_verticalalignment('bottom')
        elif vertical <= -0.5:
            tick_label.set_verticalalignment('top')
# Paths: both result roots sit under the parent of this script's directory.
_project_root = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
Our_Results = os.path.join(_project_root, "results")
Prior_Results = os.path.join(_project_root, "prior_methods/prior_results")
# GenImage Table Results for variants of Data Augmentation
# CSV tables to merge: our results plus two prior-method result sets.
Our_Results_Dataframe_path = os.path.join(Our_Results, "extensive_GenImage_GenImage", "default.csv")
Prior_Results_Dataframe_path1 = os.path.join(Prior_Results, "extensive_GenImage_GenImage", "default.csv")
Prior_Results_Dataframe_path2 = os.path.join(Prior_Results, "p=0.5_standard_GenImage_GenImage", "default.csv")
# Load each table (in this order) and stack them row-wise.
Our_df, Prior_df1, Prior_df2 = (
    pd.read_csv(csv_path)
    for csv_path in (
        Our_Results_Dataframe_path,
        Prior_Results_Dataframe_path1,
        Prior_Results_Dataframe_path2,
    )
)
df = pd.concat([Our_df, Prior_df1, Prior_df2])
# Per-model "mAcc" scores for each GenImage generator (plus the aggregate score last).
performance = get_performance(
    results_dataframe=df,
    generative_models=["midjourney", "sdv4", "sdv5", "adm", "glide", "wukong", "vqdm", "biggan"],
    metric="mAcc",
)
Without_Distortion = [performance]
# Plotting Polar Plot for different models
# Axis labels, in the same order as the generative_models passed to get_performance.
categories = ["Midjourney", "SDv1.4", "SDv1.5", "Guided", "Glide", "Wukong", "VQDM", "BigGAN"]
# Legend text for each model key; unknown keys fall back to the key itself.
display_names = {
    "clip-vit-l-14": "UnivFD",
    "drct-clip-vit-l-14": "DRCT/UnivFD",
    "contrique": "CONTRIQUE",
    "reiqa": "ReIQA",
}
# Plotting
fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True))
plt.rcParams.update({'font.size': 14})
# One evenly spaced spoke per category; the first angle is repeated so each polygon closes.
angle = list(np.linspace(0, 2 * np.pi, len(categories), endpoint=False))
angle += angle[:1]
ax.set_yticklabels([])
ax.set_xticks(angle[:-1])
ax.set_xticklabels(categories)
ax.xaxis.set_tick_params(labelsize=14)
realign_polar_xticks(ax)
for model_name, color in [("clip-vit-l-14", "C0"), ("drct-clip-vit-l-14", "C1"), ("reiqa", "C3"), ("contrique", "C2")]:
    # Drop the trailing aggregate score; keep only the per-generator values.
    accuracy = list(Without_Distortion[0][model_name][:-1])
    if len(accuracy) != len(categories):
        raise ValueError(
            f"{model_name}: got {len(accuracy)} per-generator values for {len(categories)} categories"
        )
    accuracy += accuracy[:1]
    label = display_names.get(model_name, model_name)
    ax.fill(angle, accuracy, color=color, alpha=0.25)
    ax.plot(angle, accuracy, color=color, linewidth=2, label=label)
plt.legend(loc="center")
# savefig does not create missing directories, so ensure the output folder exists.
os.makedirs("plots", exist_ok=True)
plt.savefig("plots/GenImage_polar_plot.png", bbox_inches='tight', pad_inches=0.1, dpi=750)