Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,318 Bytes
2cda712 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
"""
Plotting Accuracy of different methods for different generative models.
"""
# Importing Libraries
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import os, sys, warnings
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
warnings.filterwarnings("ignore")
import defaults
# Function to get performance of perceptual models
def get_performance(
    results_dataframe: pd.DataFrame,
    generative_models: list,
    metric: str
):
    """
    Collect one metric value per generative model for every row of a results table.

    Each result cell is read as the text form of a 4-element sequence ordered
    as (mAP, mAcc, mAcc_Real, mAcc_fake). The entry selected by `metric` is
    taken from every requested column and rounded to two decimals.

    Args:
        results_dataframe (pd.DataFrame): Results table loaded from a .csv file.
            Must contain a "model" column, one column per entry of
            `generative_models`, and the summary column
            "(mAP, mAcc, mAcc_Real, mAcc_fake)".
        generative_models (list): Column names (generative models) to read,
            in the order they should appear in the output lists.
        metric (str): "mAP", "mAcc" or "mAcc_Real". Any other value selects
            the fourth entry (mAcc_fake).

    Returns:
        dict: Maps each value of the "model" column to a list holding one
        rounded value per requested generative model, followed by the value
        taken from the summary column. Rows sharing a model name overwrite
        earlier ones.

    Raises:
        AssertionError: If an entry of `generative_models` is not a column
            of `results_dataframe`.
    """
    # Validation is raised explicitly (instead of using an `assert` statement)
    # so it still runs under `python -O`; the exception type is kept as
    # AssertionError so existing callers observe the same failure.
    missing = [
        gen_model for gen_model in generative_models
        if gen_model not in results_dataframe.columns
    ]
    if missing:
        raise AssertionError(
            f"Invalid list of generative models: missing columns {missing}"
        )

    # Position of the requested metric inside each cell; unrecognised names
    # fall back to the last position, matching the historical behaviour.
    metric_index = {"mAP": 0, "mAcc": 1, "mAcc_Real": 2}.get(metric, 3)

    # The summary column is always reported last, after the per-generator columns.
    columns = [*generative_models, "(mAP, mAcc, mAcc_Real, mAcc_fake)"]

    performance = {}
    for _, row in results_dataframe.iterrows():
        model = row["model"]
        # NOTE(review): eval() executes whatever text the CSV cell contains.
        # It is kept because cells may hold expressions that ast.literal_eval
        # rejects (presumably e.g. numpy scalar reprs - confirm against the
        # files that are actually written). Only use on trusted result files.
        performance[model] = [
            np.round(float(eval(row[column])[metric_index]), decimals=2)
            for column in columns
        ]
    return performance
def realign_polar_xticks(ax):
    """
    Adjust the text alignment of the x tick labels of a polar axis.

    For every (tick, label) pair the tick angle is mapped through the axis'
    theta direction and theta offset, and the label alignment is chosen from
    the sine / cosine of (pi/2 - mapped angle):

    * sine >= 0.1 -> horizontal alignment 'left', sine <= -0.1 -> 'right'
    * cosine >= 0.5 -> vertical alignment 'bottom', cosine <= -0.5 -> 'top'

    Labels falling between those thresholds keep their current alignment.

    Args:
        ax: Axis object providing get_xticks, get_xticklabels,
            get_theta_direction and get_theta_offset.
    """
    for tick_angle, tick_label in zip(ax.get_xticks(), ax.get_xticklabels()):
        mapped = tick_angle * ax.get_theta_direction() + ax.get_theta_offset()
        rotated = np.pi/2 - mapped
        vertical = np.cos(rotated)
        horizontal = np.sin(rotated)

        # The two horizontal cases are mutually exclusive, as are the vertical ones.
        if horizontal >= 0.1:
            tick_label.set_horizontalalignment('left')
        elif horizontal <= -0.1:
            tick_label.set_horizontalalignment('right')

        if vertical >= 0.5:
            tick_label.set_verticalalignment('bottom')
        elif vertical <= -0.5:
            tick_label.set_verticalalignment('top')
# Paths
# Both result folders live two directory levels above this script's resolved location.
_project_root = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
Our_Results = os.path.join(_project_root, "results")
Prior_Results = os.path.join(_project_root, "prior_methods/prior_results")

# Dataframe Paths
Our_Results_Dataframe_path = os.path.join(
    Our_Results, "extensive_GenImage_GenImage", "default.csv"
)
Prior_Results_Dataframe_path1 = os.path.join(
    Prior_Results, "extensive_GenImage_GenImage", "default.csv"
)
Prior_Results_Dataframe_path2 = os.path.join(
    Prior_Results, "p=0.5_standard_GenImage_GenImage", "default.csv"
)

# Dataframe: read the three tables in order and stack them into a single one
Our_df, Prior_df1, Prior_df2 = (
    pd.read_csv(csv_path)
    for csv_path in (
        Our_Results_Dataframe_path,
        Prior_Results_Dataframe_path1,
        Prior_Results_Dataframe_path2,
    )
)
df = pd.concat([Our_df, Prior_df1, Prior_df2])

# Getting Performance: mAcc of every model on each generative model
performance = get_performance(
    df,
    generative_models=[
        "midjourney", "sdv4", "sdv5", "adm",
        "glide", "wukong", "vqdm", "biggan",
    ],
    metric="mAcc",
)

# GenImage Table Results for variants of Data Augmentation
# (a single entry here; the plot below reads index 0)
Without_Distortion = [performance]
# Plotting Polar Plot for different models
# Display names for the spokes, in the same order as the generative models
# requested from get_performance above.
categories = ["Midjourney", "SDv1.4", "SDv1.5", "Guided", "Glide", "Wukong", "VQDM", "BigGAN"]

# Plotting
fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True))
plt.rcParams.update({'font.size': 14})

# One evenly spaced angle per category; the first angle is repeated at the end
# so that every outline drawn below closes on itself.
angle = list(np.linspace(0, 2 * np.pi, len(categories), endpoint=False))
angle += angle[:1]

ax.set_yticklabels([])
ax.set_xticks(angle[:-1])  # ticks only at the distinct angles, not the repeated one
ax.set_xticklabels(categories)
ax.xaxis.set_tick_params(labelsize=14)
realign_polar_xticks(ax)

# (key in the results dictionary, color, legend label)
plotted_models = [
    ("clip-vit-l-14", "C0", "UnivFD"),
    ("drct-clip-vit-l-14", "C1", "DRCT/UnivFD"),
    ("reiqa", "C3", "ReIQA"),
    ("contrique", "C2", "CONTRIQUE"),
]
for model_name, color, label in plotted_models:
    # The last value of each list comes from the summary column and is not a
    # spoke of the plot, so it is dropped; the first value is then repeated to
    # match the closed list of angles.
    accuracy = list(Without_Distortion[0][model_name][:-1])
    accuracy += accuracy[:1]
    ax.fill(angle, accuracy, color=color, alpha=0.25)
    ax.plot(angle, accuracy, color=color, linewidth=2, label=label)
plt.legend(loc="center")

# The output path is relative to the current working directory; make sure the
# folder exists so saving does not fail on a fresh checkout.
output_path = "plots/GenImage_polar_plot.png"
os.makedirs(os.path.dirname(output_path), exist_ok=True)
plt.savefig(output_path, bbox_inches='tight', pad_inches=0.1, dpi=750)