# Importing Libraries
import numpy as np
from pytorch_lightning.callbacks import TQDMProgressBar
import pytorch_lightning as pl
import os, sys, warnings
warnings.filterwarnings('ignore')
from tqdm import tqdm
import getpass
import argparse, csv
import joblib
import contextlib
# Progress bar
class LitProgressBar(TQDMProgressBar):
    def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # Print a newline so the next output starts on a fresh line after the validation bar.
        print()
        return super().on_validation_epoch_end(trainer, pl_module)
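# A minimal usage sketch (not part of the original module): the custom progress bar
# is assumed to be passed to a pl.Trainer via its `callbacks` argument. The
# `max_epochs`, `model`, and `datamodule` names below are hypothetical placeholders.
#
#     trainer = pl.Trainer(max_epochs=10, callbacks=[LitProgressBar()])
#     trainer.fit(model, datamodule=datamodule)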
# Allows tqdm to be used when extracting features for multiple images in parallel.
@contextlib.contextmanager
def tqdm_joblib(tqdm_object):
    """
    - Allows tqdm to be used when extracting features for multiple images in parallel.
    - Useful when multi-threading on CPUs with joblib.
    - Context manager that patches joblib to report into the tqdm progress bar given as argument.
    """
    class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack):
        def __call__(self, *args, **kwargs):
            # Advance the progress bar by the number of tasks completed in this batch.
            tqdm_object.update(n=self.batch_size)
            return super().__call__(*args, **kwargs)

    # Temporarily swap in the patched callback, then restore the original on exit.
    old_batch_callback = joblib.parallel.BatchCompletionCallBack
    joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback
    try:
        yield tqdm_object
    finally:
        joblib.parallel.BatchCompletionCallBack = old_batch_callback
        tqdm_object.close()
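# A minimal usage sketch (not part of the original module): wrap a joblib.Parallel
# call with tqdm_joblib so each completed batch advances the bar. `image_paths` and
# `extract_features` are hypothetical placeholders.
#
#     with tqdm_joblib(tqdm(total=len(image_paths), desc="Extracting features")):
#         features = joblib.Parallel(n_jobs=8)(
#             joblib.delayed(extract_features)(p) for p in image_paths
#         )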
# Parsing arguments for config files
def parser_args():
    """
    Parse command-line arguments for config files, i.e. .yaml files.
    """
    def nullable_str(s):
        # Treat 'null', 'none', and the empty string as a missing value.
        if s.lower() in ['null', 'none', '']:
            return None
        return s

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', '-c', type=nullable_str, help='config file path')
    return parser.parse_args()
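# A minimal usage sketch (not part of the original module): parse the config path and
# load it with PyYAML. The `yaml` dependency and the config layout are assumptions.
#
#     import yaml
#     args = parser_args()
#     if args.config is not None:
#         with open(args.config) as f:
#             cfg = yaml.safe_load(f)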
# Options for GenImage dataset
def get_GenImage_options():
    """
    Get the train and test image sources (generators) for the GenImage dataset.
    """
    # Image Sources
    test_image_sources = ["midjourney", "sdv4", "sdv5", "adm", "glide", "wukong", "vqdm", "biggan"]
    train_image_sources = ["sdv4"]
    return train_image_sources, test_image_sources
# Options for UnivFD dataset
def get_UnivFD_options():
    """
    Get the train and test image sources (generators) for the UnivFD dataset.
    """
    # Image Sources
    test_image_sources = [
        "progan", "cyclegan", "biggan", "stylegan", "gaugan", "stargan", "deepfake",
        "seeingdark", "san", "crn", "imle", "guided",
        "ldm_200", "ldm_200_cfg", "ldm_100",
        "glide_100_27", "glide_50_27", "glide_100_10", "dalle"]
    train_image_sources = ["progan"]
    return train_image_sources, test_image_sources
# Options for DRCT dataset
def get_DRCT_options():
    """
    Get the train and test image sources (generators) for the DRCT dataset.
    """
    # Image Sources
    test_image_sources = [
        'ldm-text2im-large-256', 'stable-diffusion-v1-4', 'stable-diffusion-v1-5', 'stable-diffusion-2-1', 'stable-diffusion-xl-base-1.0', 'stable-diffusion-xl-refiner-1.0',
        'sd-turbo', 'sdxl-turbo',
        'lcm-lora-sdv1-5', 'lcm-lora-sdxl',
        'sd-controlnet-canny', 'sd21-controlnet-canny', 'controlnet-canny-sdxl-1.0',
        'stable-diffusion-inpainting', 'stable-diffusion-2-inpainting', 'stable-diffusion-xl-1.0-inpainting-0.1']
    train_image_sources = ["stable-diffusion-v1-4"]
    return train_image_sources, test_image_sources
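# A minimal usage sketch (not part of the original module): select the image-source
# lists by dataset name. The `dataset_name` variable is a hypothetical placeholder.
#
#     getters = {"GenImage": get_GenImage_options,
#                "UnivFD": get_UnivFD_options,
#                "DRCT": get_DRCT_options}
#     train_image_sources, test_image_sources = getters[dataset_name]()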
# Write results to a .csv file
def write_results_csv(test_set_metrics, test_image_sources, f_model_name, save_path):
    # Assertions
    assert len(test_set_metrics) == len(test_image_sources), "len(test_set_metrics) and len(test_image_sources) do not match."
    # Create the .csv file with a header row if it doesn't exist yet
    if not os.path.exists(save_path):
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        with open(save_path, mode='w', newline='') as filename:
            writer = csv.writer(filename, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
            writer.writerow(["model", "(mAP, mAcc, mAcc_Real, mAcc_fake, mcc)"] + test_image_sources)
    # Results: one (ap, acc, r_acc, f_acc, mcc) tuple per test image source; the
    # second column holds their mean and is filled in after the loop
    Append_List = [f_model_name, (0, 0, 0, 0, 0)]
    metrics = []
    for i, _ in enumerate(test_image_sources):
        O = test_set_metrics[i][1:]
        ap = np.round(O[0] * 100, decimals=2)
        acc = np.round(O[4] * 100, decimals=2)
        r_acc = np.round(O[5] * 100, decimals=2)
        f_acc = np.round(O[6] * 100, decimals=2)
        mcc = np.round(O[8], decimals=2)
        Append_List.append(tuple([ap, acc, r_acc, f_acc, mcc]))
        metrics.append([ap, acc, r_acc, f_acc, mcc])
    # Mean metrics across all test image sources
    metrics = np.round(np.mean(metrics, axis=0), decimals=2)
    Append_List[1] = tuple(metrics)
    # Appending
    assert len(Append_List) == len(test_image_sources) + 2, "len(Append_List) != len(test_image_sources) + 2"
    with open(save_path, mode='a', newline='') as filename:
        writer = csv.writer(filename, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
        writer.writerow(Append_List)
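# A minimal usage sketch (not part of the original module): each per-source metric row
# is assumed to start with a name followed by metric values, so that after dropping the
# first element the indices used above (0, 4, 5, 6, 8) map to ap, acc, r_acc, f_acc, mcc.
# All values and paths below are dummy placeholders.
#
#     train_srcs, test_srcs = get_GenImage_options()
#     dummy_row = ["src", 0.99, 0, 0, 0, 0.95, 0.96, 0.94, 0, 0.90]
#     test_set_metrics = [dummy_row for _ in test_srcs]
#     write_results_csv(test_set_metrics, test_srcs, "my_model", "results/GenImage.csv")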