Spaces:
Runtime error
Runtime error
File size: 4,494 Bytes
2cda712 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
# Importing Libraries
import numpy as np
from pytorch_lightning.callbacks import TQDMProgressBar
import pytorch_lightning as pl
import os, sys, warnings
warnings.filterwarnings('ignore')
from tqdm import tqdm
import getpass
import argparse, csv
import joblib
import contextlib
# Progress bar
class LitProgressBar(TQDMProgressBar):
    """TQDM progress bar that emits an empty line when a validation epoch ends.

    Apart from the extra newline (presumably so later output starts on a
    fresh terminal line instead of sharing one with the bar), behavior is
    inherited unchanged from ``TQDMProgressBar``.
    """

    def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # Newline first, then let the stock bar do its normal bookkeeping.
        print()
        result = super().on_validation_epoch_end(trainer, pl_module)
        return result
# Context manager that lets tqdm track progress when extracting features for multiple images in parallel.
@contextlib.contextmanager
def tqdm_joblib(tqdm_object):
    """Route joblib batch-completion events into a tqdm progress bar.

    While the ``with`` block is active, ``joblib.parallel.BatchCompletionCallBack``
    is replaced by a subclass that advances ``tqdm_object`` by each finished
    batch's ``batch_size`` before running the normal callback. On exit
    (normal or via an exception) the original callback class is restored and
    ``tqdm_object`` is closed.

    NOTE: the replacement is a module attribute of ``joblib.parallel``, so it
    presumably affects every joblib ``Parallel`` run while the block is active.

    Args:
        tqdm_object: A tqdm progress bar; it is also what the ``with``
            statement binds.
    """
    saved_callback_cls = joblib.parallel.BatchCompletionCallBack

    class TqdmBatchCompletionCallback(saved_callback_cls):
        def __call__(self, *args, **kwargs):
            # Count the whole batch as done, then defer to joblib's own logic.
            tqdm_object.update(n=self.batch_size)
            return super().__call__(*args, **kwargs)

    joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback
    try:
        yield tqdm_object
    finally:
        # Always undo the patch and release the bar, even if the body raised.
        joblib.parallel.BatchCompletionCallBack = saved_callback_cls
        tqdm_object.close()
# Parsing Arguments for config files
def parser_args(argv=None):
    """Parse command-line arguments that select a config (.yaml) file.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. The default ``None`` makes ``argparse``
            read ``sys.argv[1:]``, which matches the original behavior.

    Returns:
        argparse.Namespace with one attribute, ``config``. It holds the path
        given via ``--config``/``-c``, or ``None`` when the option is omitted
        or its value is ``null``, ``none`` (case-insensitive) or empty.
    """
    def nullable_str(s):
        # Map the sentinel spellings to None so a config can be explicitly unset.
        return None if s.lower() in ('null', 'none', '') else s

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', '-c', type=nullable_str, help='config file path')
    return parser.parse_args(argv)
# Options for GenImage dataset
def get_GenImage_options():
    """Return the image sources used for the GenImage benchmark.

    Returns:
        A ``(train_image_sources, test_image_sources)`` pair of new lists,
        built fresh on every call, so callers may mutate them freely.
    """
    train_sources = ["sdv4"]
    test_sources = [
        "midjourney",
        "sdv4",
        "sdv5",
        "adm",
        "glide",
        "wukong",
        "vqdm",
        "biggan",
    ]
    return train_sources, test_sources
# Options for UnivFD dataset
def get_UnivFD_options():
    """Return the image sources used for the UnivFD benchmark.

    Returns:
        A ``(train_image_sources, test_image_sources)`` pair of new lists,
        built fresh on every call, so callers may mutate them freely.
    """
    train_sources = ["progan"]
    test_sources = [
        "progan", "cyclegan", "biggan", "stylegan", "gaugan", "stargan",
        "deepfake", "seeingdark", "san", "crn", "imle", "guided",
        "ldm_200", "ldm_200_cfg", "ldm_100",
        "glide_100_27", "glide_50_27", "glide_100_10",
        "dalle",
    ]
    return train_sources, test_sources
# Options for DRCT dataset
def get_DRCT_options():
    """Return the image sources used for the DRCT benchmark.

    Returns:
        A ``(train_image_sources, test_image_sources)`` pair of new lists,
        built fresh on every call, so callers may mutate them freely.
    """
    text_to_image = [
        'ldm-text2im-large-256',
        'stable-diffusion-v1-4',
        'stable-diffusion-v1-5',
        'stable-diffusion-2-1',
        'stable-diffusion-xl-base-1.0',
        'stable-diffusion-xl-refiner-1.0',
    ]
    turbo = ['sd-turbo', 'sdxl-turbo']
    lcm_lora = ['lcm-lora-sdv1-5', 'lcm-lora-sdxl']
    controlnet = ['sd-controlnet-canny', 'sd21-controlnet-canny', 'controlnet-canny-sdxl-1.0']
    inpainting = [
        'stable-diffusion-inpainting',
        'stable-diffusion-2-inpainting',
        'stable-diffusion-xl-1.0-inpainting-0.1',
    ]
    # Concatenation preserves the original source order.
    test_sources = text_to_image + turbo + lcm_lora + controlnet + inpainting
    train_sources = ["stable-diffusion-v1-4"]
    return train_sources, test_sources
# Write Results in a .csv file
def write_results_csv(test_set_metrics, test_image_sources, f_model_name, save_path):
    """Append one model's per-source test metrics as a row of a CSV file.

    If ``save_path`` does not exist yet, its parent directory is created
    (when the path has one) and a header row is written first. The data row
    is ``[f_model_name, <mean over sources>, <source 1>, <source 2>, ...]``.
    Every metrics cell is the string form of a 5-tuple
    ``(AP%, Acc%, Real Acc%, Fake Acc%, MCC)`` rounded to 2 decimals. The
    mean is taken over the already-rounded per-source values.

    Args:
        test_set_metrics: One metrics sequence per entry of
            ``test_image_sources``, in the same order. Element 0 of each
            sequence is skipped. Elements 1, 5, 6, 7 and 9 are read as AP,
            accuracy, real-image accuracy, fake-image accuracy and MCC. The
            first four are multiplied by 100 (presumably fractions in [0, 1];
            confirm against the caller).
        test_image_sources: Names of the test sources, used as header columns.
        f_model_name: Value written in the ``model`` column.
        save_path: Path of the CSV file to create or append to.

    Raises:
        ValueError: If the two sequences differ in length, or are empty (the
            mean column would be undefined).
    """
    if len(test_set_metrics) != len(test_image_sources):
        raise ValueError(
            f"len(test_set_metrics) ({len(test_set_metrics)}) and "
            f"len(test_image_sources) ({len(test_image_sources)}) do not match."
        )
    if len(test_image_sources) == 0:
        raise ValueError("test_image_sources must not be empty: the mean row would be undefined.")

    def _round2(value):
        # Convert to a built-in float before and after rounding. csv writes
        # str(tuple), which uses repr() of each element, and under NumPy >= 2 a
        # numpy scalar renders as "np.float64(95.12)" instead of "95.12".
        return float(np.round(float(value), decimals=2))

    per_source = []
    for source_metrics in test_set_metrics:
        values = source_metrics[1:]  # element 0 is not reported
        per_source.append((
            _round2(values[0] * 100),  # AP
            _round2(values[4] * 100),  # accuracy
            _round2(values[5] * 100),  # real-image accuracy
            _round2(values[6] * 100),  # fake-image accuracy
            _round2(values[8]),        # MCC, left on its own scale
        ))

    means = np.round(np.mean(per_source, axis=0), decimals=2)
    row = [f_model_name, tuple(float(m) for m in means), *per_source]

    # Metrics are computed before touching the file, so a bad input no longer
    # leaves a header-only CSV behind.
    is_new_file = not os.path.exists(save_path)
    if is_new_file:
        parent_dir = os.path.dirname(save_path)
        if parent_dir:  # a bare filename has no directory to create
            os.makedirs(parent_dir, exist_ok=True)
    # newline='' is required by the csv module. Without it the writer's '\r\n'
    # terminator is translated to '\r\r\n' on Windows.
    with open(save_path, mode='a', newline='') as handle:
        writer = csv.writer(handle, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
        if is_new_file:
            writer.writerow(["model", "(mAP, mAcc, mAcc_Real, mAcc_fake, mcc)"] + list(test_image_sources))
        writer.writerow(row)