krishnasrikard
Codes
2cda712
"""
Plotting Accuracy of different methods for different generative models.
"""
# Importing Libraries
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import os, sys, warnings
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
warnings.filterwarnings("ignore")
import defaults
# Function to get performance of perceptual models
def get_performance(
    results_dataframe: pd.DataFrame,
    generative_models: list,
    metric: str,
    summary_column: str = "(mAP, mAcc, mAcc_Real, mAcc_fake)",
):
    """
    Collect one rounded evaluation metric per model and generative model.

    Every cell read from the results dataframe is evaluated into a sequence
    ordered like the default summary column header,
    ``(mAP, mAcc, mAcc_Real, mAcc_fake)``. The entry selected by ``metric``
    is rounded to 2 decimals.

    Args:
        results_dataframe (pd.DataFrame): Pandas Dataframe loaded from a .csv
            file. Must have a "model" column plus one column per generative model.
        generative_models (list): Generative-model columns to read, in the
            order they should appear in the output lists.
        metric (str): Evaluation metric: "mAP", "mAcc", "mAcc_Real" or
            "mAcc_fake" ("mAcc_Fake" is accepted as an alias).
        summary_column (str): Column holding the aggregate scores. Its value
            is appended after the per-generative-model values.

    Returns:
        dict: ``{model_name: [score per generative model..., summary score]}``.
        If a model name appears in several rows, the last row wins.

    Raises:
        AssertionError: If any requested generative-model column is missing.
        ValueError: If ``metric`` is not one of the recognised names.
    """
    # Assertions
    assert all(gen_model in results_dataframe.columns for gen_model in generative_models), "Invalid list of generative models"

    # Position of each metric inside the evaluated tuples. Unrecognised names
    # raise instead of silently selecting the mAcc_fake entry.
    metric_indices = {"mAP": 0, "mAcc": 1, "mAcc_Real": 2, "mAcc_fake": 3, "mAcc_Fake": 3}
    if metric not in metric_indices:
        raise ValueError(
            f"Unknown metric {metric!r}; expected one of {sorted(metric_indices)}"
        )
    metric_index = metric_indices[metric]

    def _score(cell):
        # NOTE(review): `eval` executes arbitrary code stored in the CSV, which is
        # only acceptable for trusted, locally generated results files. It is kept
        # instead of `ast.literal_eval` in case cells hold numpy reprs such as
        # "np.float64(0.9)" (resolvable via this module's `np`) — confirm the
        # cell format before switching.
        return np.round(float(eval(cell)[metric_index]), decimals=2)

    # Performance
    performance = {}
    for _, row in results_dataframe.iterrows():
        scores = [_score(row[gen_model]) for gen_model in generative_models]
        # The aggregate score always goes last.
        scores.append(_score(row[summary_column]))
        performance[row["model"]] = scores
    return performance
def realign_polar_xticks(ax):
    """
    Align polar x-tick labels so they sit outside the plot.

    For every tick, the label's direction on screen is derived from the tick
    angle together with the axes' theta direction and offset.
    - Labels pointing clearly right or left (|horizontal component| >= 0.1)
      get 'left' or 'right' horizontal alignment.
    - Labels pointing clearly up or down (|vertical component| >= 0.5) get
      'bottom' or 'top' vertical alignment.
    - A component inside those bands leaves the label's existing alignment
      untouched.

    Args:
        ax: A polar axes object providing get_xticks, get_xticklabels,
            get_theta_direction and get_theta_offset.
    """
    direction = ax.get_theta_direction()
    offset = ax.get_theta_offset()
    for tick_angle, tick_label in zip(ax.get_xticks(), ax.get_xticklabels()):
        # Angle measured from "straight up", so sin -> horizontal, cos -> vertical.
        screen_angle = np.pi / 2 - (tick_angle * direction + offset)
        horizontal = np.sin(screen_angle)
        vertical = np.cos(screen_angle)

        if horizontal >= 0.1:
            tick_label.set_horizontalalignment('left')
        elif horizontal <= -0.1:
            tick_label.set_horizontalalignment('right')

        if vertical >= 0.5:
            tick_label.set_verticalalignment('bottom')
        elif vertical <= -0.5:
            tick_label.set_verticalalignment('top')
# Paths: both results directories are resolved against the parent of this script's directory.
Our_Results, Prior_Results = (
    os.path.join(os.path.dirname(os.path.dirname(os.path.realpath(__file__))), sub_dir)
    for sub_dir in ("results", "prior_methods/prior_results")
)

# GenImage results CSVs: our models plus two sets of prior-method results.
Our_Results_Dataframe_path = os.path.join(Our_Results, "extensive_GenImage_GenImage", "default.csv")
Prior_Results_Dataframe_path1 = os.path.join(Prior_Results, "extensive_GenImage_GenImage", "default.csv")
Prior_Results_Dataframe_path2 = os.path.join(Prior_Results, "p=0.5_standard_GenImage_GenImage", "default.csv")

# Load the three tables (in this order) and stack them into one dataframe.
Our_df, Prior_df1, Prior_df2 = map(
    pd.read_csv,
    (Our_Results_Dataframe_path, Prior_Results_Dataframe_path1, Prior_Results_Dataframe_path2),
)
df = pd.concat([Our_df, Prior_df1, Prior_df2])

# mAcc per model for each GenImage generator (plus the overall score, last).
performance = get_performance(
    results_dataframe=df,
    generative_models=["midjourney", "sdv4", "sdv5", "adm", "glide", "wukong", "vqdm", "biggan"],
    metric="mAcc",
)
# Only one setting is collected here; the plot below reads element 0.
Without_Distortion = [performance]
# Plotting Polar Plot for different models.
# Spoke labels; order must match the generative_models list passed to get_performance.
categories = ["Midjourney", "SDv1.4", "SDv1.5", "Guided", "Glide", "Wukong", "VQDM", "BigGAN"]

# Plotting
fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True))
plt.rcParams.update({'font.size': 14})

# One evenly spaced spoke per category; repeat the first angle so each polygon closes.
angle = list(np.linspace(0, 2 * np.pi, len(categories), endpoint=False))
angle += angle[:1]

ax.set_yticklabels([])
ax.set_xticks(angle[:-1])
ax.set_xticklabels(categories)
ax.xaxis.set_tick_params(labelsize=14)
realign_polar_xticks(ax)

# (model key in the results, plot colour, legend label)
for model_name, color, label in [
    ("clip-vit-l-14", "C0", "UnivFD"),
    ("drct-clip-vit-l-14", "C1", "DRCT/UnivFD"),
    ("reiqa", "C3", "ReIQA"),
    ("contrique", "C2", "CONTRIQUE"),
]:
    # Drop the trailing overall score (appended last by get_performance),
    # then repeat the first value to close the polygon.
    accuracy = list(Without_Distortion[0][model_name][:-1])
    accuracy += accuracy[:1]
    ax.fill(angle, accuracy, color=color, alpha=0.25)
    ax.plot(angle, accuracy, color=color, linewidth=2, label=label)

plt.legend(loc="center")
# savefig does not create missing directories. "plots" is relative to the
# current working directory, so make sure it exists first.
os.makedirs("plots", exist_ok=True)
plt.savefig("plots/GenImage_polar_plot.png", bbox_inches='tight', pad_inches=0.1, dpi=750)