# Importing Libraries
import numpy as np

from pytorch_lightning.callbacks import TQDMProgressBar
import pytorch_lightning as pl

import os, warnings
warnings.filterwarnings('ignore')
from tqdm import tqdm
import argparse, csv
import joblib
import contextlib


# Progress bar
class LitProgressBar(TQDMProgressBar):
	def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
		# Print a newline so the next progress bar starts on a fresh line
		# instead of overwriting the validation bar
		print()
		return super().on_validation_epoch_end(trainer, pl_module)
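
# Example (hypothetical usage; `model` and `dm` stand in for a LightningModule
# and LightningDataModule defined elsewhere):
#
#   trainer = pl.Trainer(callbacks=[LitProgressBar()])
#   trainer.fit(model, datamodule=dm)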


# Allows tqdm to track progress when extracting features for multiple images in parallel.
@contextlib.contextmanager
def tqdm_joblib(tqdm_object):
	"""
	Context manager that patches joblib to report into the tqdm progress bar
	given as argument. Useful when extracting features for many images in
	parallel with joblib on CPU.
	"""
	class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack):
		def __call__(self, *args, **kwargs):
			# Advance the bar by the number of tasks in the completed batch
			tqdm_object.update(n=self.batch_size)
			return super().__call__(*args, **kwargs)

	# Temporarily swap in the patched callback; restore the original and
	# close the bar once the block exits, even on error
	old_batch_callback = joblib.parallel.BatchCompletionCallBack
	joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback
	try:
		yield tqdm_object
	finally:
		joblib.parallel.BatchCompletionCallBack = old_batch_callback
		tqdm_object.close()
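
# Example (hypothetical usage; `extract_features` and `image_paths` stand in
# for the per-image work being parallelised):
#
#   with tqdm_joblib(tqdm(total=len(image_paths), desc="Extracting")):
#       features = joblib.Parallel(n_jobs=8)(
#           joblib.delayed(extract_features)(p) for p in image_paths)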
	

# Parsing Arguments for config files
def parser_args():
	"""
	Parse command-line arguments for the config file path, i.e. a .yaml file.
	"""
	def nullable_str(s):
		# Treat 'null', 'none' and the empty string as a missing config path
		if s.lower() in ['null', 'none', '']:
			return None
		return s
	
	parser = argparse.ArgumentParser()
	parser.add_argument('--config', '-c', type=nullable_str, help='config file path')
	return parser.parse_args()
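
# Example (hypothetical usage; assumes PyYAML is installed and the script is
# run as e.g. `python train.py --config configs/base.yaml`):
#
#   args = parser_args()
#   if args.config is not None:
#       import yaml
#       with open(args.config) as f:
#           cfg = yaml.safe_load(f)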


# Options for GenImage dataset
def get_GenImage_options():
	"""
	Get Image Sources
	"""
	# Image Sources
	test_image_sources = ["midjourney", "sdv4", "sdv5", "adm", "glide", "wukong", "vqdm", "biggan"]
	train_image_sources = ["sdv4"]

	return train_image_sources, test_image_sources


# Options for UnivFD dataset
def get_UnivFD_options():
	"""
	Get Image Sources
	"""
	# Image Sources
	test_image_sources = ["progan", "cyclegan", "biggan", "stylegan", "gaugan", "stargan", "deepfake", "seeingdark", "san", "crn", "imle", "guided", "ldm_200", "ldm_200_cfg", "ldm_100", "glide_100_27", "glide_50_27", "glide_100_10", "dalle"]
	train_image_sources = ["progan"]

	return train_image_sources, test_image_sources


# Options for DRCT dataset
def get_DRCT_options():
	"""
	Get Image Sources
	"""
	# Image Sources
	test_image_sources = [
		'ldm-text2im-large-256', 'stable-diffusion-v1-4', 'stable-diffusion-v1-5', 'stable-diffusion-2-1', 'stable-diffusion-xl-base-1.0', 'stable-diffusion-xl-refiner-1.0',
		'sd-turbo', 'sdxl-turbo', 
		'lcm-lora-sdv1-5', 'lcm-lora-sdxl',
		'sd-controlnet-canny', 'sd21-controlnet-canny', 'controlnet-canny-sdxl-1.0',
		'stable-diffusion-inpainting', 'stable-diffusion-2-inpainting', 'stable-diffusion-xl-1.0-inpainting-0.1']
	train_image_sources = ["stable-diffusion-v1-4"]

	return train_image_sources, test_image_sources


# Write Results in a .csv file
def write_results_csv(test_set_metrics, test_image_sources, f_model_name, save_path):
	# Assertions
	assert len(test_set_metrics) == len(test_image_sources), "len(test_set_metrics) does not match len(test_image_sources)."

	# Create the .csv file with a header row if it doesn't exist yet
	if not os.path.exists(save_path):
		# `or '.'` guards against a bare filename with no directory component
		os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
		with open(save_path, mode='w', newline='') as filename:
			writer = csv.writer(filename, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
			writer.writerow(["model", "(mAP, mAcc, mAcc_Real, mAcc_fake, mcc)"] + test_image_sources)


	# Row to append: model name first, then a placeholder for the mean metrics
	# (filled in below), then one tuple per test source
	append_list = [f_model_name, (0, 0, 0, 0, 0)]

	metrics = []
	for i, _ in enumerate(test_image_sources):
		vals = test_set_metrics[i][1:]
		ap = np.round(vals[0] * 100, decimals=2)
		acc = np.round(vals[4] * 100, decimals=2)
		r_acc = np.round(vals[5] * 100, decimals=2)
		f_acc = np.round(vals[6] * 100, decimals=2)
		mcc = np.round(vals[8], decimals=2)

		append_list.append((ap, acc, r_acc, f_acc, mcc))
		metrics.append([ap, acc, r_acc, f_acc, mcc])

	# Replace the placeholder with the per-metric mean over all test sources
	metrics = np.round(np.mean(metrics, axis=0), decimals=2)
	append_list[1] = tuple(metrics)
	

	# Appending
	assert len(append_list) == len(test_image_sources) + 2, "len(append_list) != len(test_image_sources) + 2"
	with open(save_path, mode='a', newline='') as filename:
		writer = csv.writer(filename, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
		writer.writerow(append_list)
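
# Example (hypothetical call; each per-source entry is a 10-element row laid
# out to match the indices read above: position 1 holds AP, 5 accuracy,
# 6/7 real/fake accuracy, 9 MCC; the remaining slots are zero-filled here):
#
#   dummy = [0.0, 0.95, 0, 0, 0, 0.90, 0.92, 0.88, 0, 0.80]
#   write_results_csv([dummy, dummy], ["progan", "biggan"],
#                     "my_detector", "results/metrics.csv")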