""" PyTorch Lightning Module of training of deep-learning models Notes: - Using ".to(torch.float32)" to resolving precision issues while using different models. """ # Importing Libraries import numpy as np from sklearn.model_selection import train_test_split import torch torch.set_float32_matmul_precision('medium') import torch.nn as nn import torchvision from torch.utils.data import DataLoader import pytorch_lightning as pl import torchmetrics from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning import loggers as pl_loggers import os, sys, warnings warnings.filterwarnings("ignore") import random from functions.dataset import Image_Dataset import functions.dataset_utils as dataset_utils import functions.preprocess as preprocess from functions.loss_optimizers_metrics import * import functions.utils as utils import defaults # Lightning Module class Model_LightningModule(pl.LightningModule): def __init__(self, classifier, config): super().__init__() self.save_hyperparameters() self.config = config # Model as Manual Arguments self.classifier = classifier # Loss self.train_lossfn = get_loss_function(**self.config["train_loss_fn"]) self.val_lossfn = get_loss_function(**self.config["val_loss_fn"]) # Metrics self.train_accuracy_fn = torchmetrics.Accuracy(task="binary") self.val_accuracy_fn = torchmetrics.Accuracy(task="binary") # Training-Step def training_step(self, batch, batch_idx): if len(batch) == 2: X, y_true = batch # Extracting features using Backbone Feature Extractor with torch.no_grad(): X = feature_extractor_module(X) else: X1, X2, y_true = batch # Extracting features using Backbone Feature Extractor with torch.no_grad(): X = feature_extractor_module(X1, X2) X = torch.flatten(X, start_dim=1).to(torch.float32) X_input = preprocess.select_feature_indices(X, self.config["dataset"]["f_model_name"]) y_true_classes = torch.argmax(y_true, dim=1) latent_features, y_pred = self.classifier(X_input) y_pred_classes = torch.argmax(y_pred, dim=1) self.train_loss = self.train_lossfn(latent_features, y_pred, y_true_classes) self.train_acc = self.train_accuracy_fn(y_pred_classes, y_true_classes) self.log_dict( { "train_loss": self.train_loss, "train_acc": self.train_acc }, on_step=True, on_epoch=False, prog_bar=True, sync_dist=True ) return self.train_loss # Validation-Step def validation_step(self, batch, batch_idx, dataloader_idx=0): if len(batch) == 2: X, y_true = batch # Extracting features using Backbone Feature Extractor with torch.no_grad(): X = feature_extractor_module(X) else: X1, X2, y_true = batch # Extracting features using Backbone Feature Extractor with torch.no_grad(): X = feature_extractor_module(X1, X2) X = torch.flatten(X, start_dim=1).to(torch.float32) X_input = preprocess.select_feature_indices(X, self.config["dataset"]["f_model_name"]) y_true_classes = torch.argmax(y_true, dim=1) latent_features, y_pred = self.classifier(X_input) y_pred_classes = torch.argmax(y_pred, dim=1) self.val_loss = self.val_lossfn(latent_features, y_pred, y_true_classes) self.val_acc = self.val_accuracy_fn(y_pred_classes, y_true_classes) self.log_dict( { "val_loss": self.val_loss, "val_acc": self.val_acc }, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True ) # Prediction-Step def predict_step(self, batch, batch_idx, dataloader_idx=0): if len(batch) == 2: X, y_true = batch # Extracting features using Backbone Feature Extractor with torch.no_grad(): X = feature_extractor_module(X) else: X1, X2, y_true = batch # Extracting features using Backbone Feature Extractor with 
            with torch.no_grad():
                X = feature_extractor_module(X1, X2)

        X = torch.flatten(X, start_dim=1).to(torch.float32)
        X_input = preprocess.select_feature_indices(X, self.config["dataset"]["f_model_name"])
        y_true_classes = torch.argmax(y_true, dim=1)

        latent_features, y_pred = self.classifier(X_input)
        y_pred_classes = torch.argmax(y_pred, dim=1)

        return y_pred, y_true

    # Configure Optimizers
    def configure_optimizers(self):
        optimizer = get_optimizer(
            self.classifier.parameters(),
            **self.config["optimizer"]
        )
        return [optimizer]


# Main Function
def run(feature_extractor, classifier, config, train_image_sources, test_image_sources, preprocess_settings, best_threshold, verbose=True):
    """Train or evaluate `classifier` on features extracted by `feature_extractor`, as specified by `config`."""

    # Parameters
    dataset_type = config["dataset"]["dataset_type"]
    separateAugmentation = config["dataset"]["separateAugmentation"]
    model_name = config["dataset"]["model_name"]
    f_model_name = config["dataset"]["f_model_name"]

    # Paths
    main_dataset_dir = defaults.main_dataset_dir
    main_checkpoints_dir = defaults.main_checkpoints_dir

    # Checkpoint Paths
    # Resume Checkpoint
    if config["checkpoints"]["resume_dirname"] is not None and config["checkpoints"]["resume_filename"] is not None:
        resume_ckpt_path = os.path.join(main_checkpoints_dir, config["checkpoints"]["resume_dirname"], f_model_name, config["checkpoints"]["resume_filename"])
    else:
        resume_ckpt_path = None
    print("Resume checkpoint path:", resume_ckpt_path)

    # Save Checkpoints
    checkpoint_dirpath = os.path.join(main_checkpoints_dir, config["checkpoints"]["checkpoint_dirname"], f_model_name)
    os.makedirs(checkpoint_dirpath, exist_ok=True)

    # Resuming from checkpoint
    if resume_ckpt_path is not None:
        if os.path.exists(resume_ckpt_path):
            print("Found the checkpoint at the provided resume_ckpt_path.")
        else:
            assert False, "Resume checkpoint not found at the provided resume_ckpt_path."
    else:
        if config["train_settings"]["train"]:
            # For training.
            print("No resume checkpoint (resume_ckpt_path) provided. Starting training from the beginning.")
        else:
            assert False, "No resume checkpoint (resume_ckpt_path) provided. resume_ckpt_path is required for evaluation."
    # Checkpoint Callbacks
    best_checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dirpath,
        filename="best_model",
        monitor=config["train_settings"]["monitor"],
        mode=config["train_settings"]["mode"]
    )

    # Pre-processing Function
    preprocessfn, dual_scale = preprocess.get_preprocessfn(**preprocess_settings)

    # Logging
    print()
    print(preprocessfn)
    print()

    # Datasets
    # Train and Val Image Paths
    train_val_real_images_paths, train_val_fake_images_paths = dataset_utils.dataset_img_paths(
        dataset_type=dataset_type,
        status="train"
    )

    # Train-Val Split (deterministic: sort, then shuffle with a fixed seed)
    train_val_real_images_paths.sort()
    train_val_fake_images_paths.sort()
    random.Random(0).shuffle(train_val_real_images_paths)
    random.Random(0).shuffle(train_val_fake_images_paths)

    real_split_idx = int(0.8 * len(train_val_real_images_paths))
    fake_split_idx = int(0.8 * len(train_val_fake_images_paths))
    train_real_images_paths, val_real_images_paths = train_val_real_images_paths[:real_split_idx], train_val_real_images_paths[real_split_idx:]
    train_fake_images_paths, val_fake_images_paths = train_val_fake_images_paths[:fake_split_idx], train_val_fake_images_paths[fake_split_idx:]

    # Train Image Dataset
    if config["train_settings"]["train"]:
        Train_Dataset = Image_Dataset(
            real_images_paths=train_real_images_paths,
            fake_images_paths=train_fake_images_paths,
            preprocessfn=preprocessfn,
            dual_scale=dual_scale,
            resize=preprocess_settings["resize"],
            separateAugmentation=separateAugmentation,
            ignore_reconstructed_images=False
        )

    # Validation Image Dataset
    Val_Dataset = Image_Dataset(
        real_images_paths=val_real_images_paths,
        fake_images_paths=val_fake_images_paths,
        preprocessfn=preprocessfn,
        dual_scale=dual_scale,
        resize=preprocess_settings["resize"],
        separateAugmentation=separateAugmentation,
        ignore_reconstructed_images=False
    )

    # Test Image Datasets
    if not config["train_settings"]["train"]:
        Test_Datasets = []
        for source in test_image_sources:
            test_real_images_paths = dataset_utils.get_image_paths(
                dataset_type=dataset_type,
                status="val",
                image_sources=[source],
                label="real"
            )
            test_fake_images_paths = dataset_utils.get_image_paths(
                dataset_type=dataset_type,
                status="val",
                image_sources=[source],
                label="fake"
            )

            # Images smaller than preprocess_settings["input_image_dimensions"] only occur for
            # BigGAN fake images in the GenImage dataset. For those we do the following:
            #   - During inference, we avoid resizing to reduce the effect of resizing artifacts.
            #   - We process the images at (224,224) or at their smaller resolution, unless the
            #     feature extraction model requires (224,224) inputs.
""" if model_name == "resnet50" or model_name == "hyperiqa" or model_name == "tres" or model_name == "clip-resnet50" or model_name == "clip-vit-l-14": # Updated Pre-Processing Settings Fixed_Input_preprocess_settings = preprocess_settings.copy() Fixed_Input_preprocess_settings["input_image_dimensions"] = (224,224) # Preprocessing Function Fixed_Input_preprocessfn, Fixed_Input_dual_scale = preprocess.get_preprocessfn(**Fixed_Input_preprocess_settings) Test_Datasets.append( Image_Dataset( real_images_paths=test_real_images_paths, fake_images_paths=test_fake_images_paths, preprocessfn=Fixed_Input_preprocessfn, dual_scale=Fixed_Input_dual_scale, resize=preprocess_settings["resize"], separateAugmentation=separateAugmentation ) ) elif (dataset_type == "GenImage" and source == "biggan") and (preprocess_settings["input_image_dimensions"][0] > 128 and preprocess_settings["input_image_dimensions"][1] > 128): # Updated Pre-Processing Settings GenImage_BigGAN_preprocess_settings = preprocess_settings.copy() GenImage_BigGAN_preprocess_settings["input_image_dimensions"] = (128,128) # Preprocessing Function print ("Using GenImage, BigGAN Preprocessing Function") GenImage_BigGAN_preprocessfn, GenImage_BigGAN_dual_scale = preprocess.get_preprocessfn(**GenImage_BigGAN_preprocess_settings) Test_Datasets.append( Image_Dataset( real_images_paths=test_real_images_paths, fake_images_paths=test_fake_images_paths, preprocessfn=GenImage_BigGAN_preprocessfn, dual_scale=GenImage_BigGAN_dual_scale, resize=preprocess_settings["resize"], separateAugmentation=separateAugmentation ) ) else: Test_Datasets.append( Image_Dataset( real_images_paths=test_real_images_paths, fake_images_paths=test_fake_images_paths, preprocessfn=preprocessfn, dual_scale=dual_scale, resize=preprocess_settings["resize"], separateAugmentation=separateAugmentation ) ) # DataLoaders # Train DataLoader if config["train_settings"]["train"]: Train_Dataloader = DataLoader( dataset=Train_Dataset, batch_size=config["train_settings"]["batch_size"], num_workers=config["train_settings"]["num_workers"], shuffle=True, ) # Val DataLoader Val_Dataloader = DataLoader( dataset=Val_Dataset, batch_size=config["train_settings"]["batch_size"], num_workers=config["train_settings"]["num_workers"], shuffle=False, ) # Test DataLoaders if config["train_settings"]["train"] == False: Test_Dataloaders = [] for i,_ in enumerate(test_image_sources): Test_Dataloaders.append( DataLoader( dataset=Test_Datasets[i], batch_size=config["train_settings"]["batch_size"], num_workers=config["train_settings"]["num_workers"], shuffle=False, ) ) print ("-"*25 + " Datasets and DataLoaders Ready " + "-"*25) # Global Variables: (feature_extractor) global feature_extractor_module feature_extractor_module = feature_extractor feature_extractor_module.to("cuda:{}".format(config["trainer"]["devices"][0])) feature_extractor_module.eval() for params in feature_extractor_module.parameters(): params.requires_grad = False # Assertions assert config["trainer"]["num_nodes"] == 1, "num_nodes should be 1 for single node training. num_nodes > 1 is not supported as our feature extractor is outside the Lightning Module." assert len(config["trainer"]["devices"]) == 1, "devices should be 1 for single GPU training. devices > 1 is not supported as our feature extractor is outside the Lightning Module." 
    # Lightning Module
    Model = Model_LightningModule(classifier, config)

    # PyTorch Lightning Trainer
    trainer = pl.Trainer(
        **config["trainer"],
        callbacks=[best_checkpoint_callback, utils.LitProgressBar()],
        precision=32
    )

    # Training or Evaluating
    if config["train_settings"]["train"]:
        print("-" * 25 + " Starting Training " + "-" * 25)
        trainer.fit(
            model=Model,
            train_dataloaders=Train_Dataloader,
            val_dataloaders=Val_Dataloader,
            ckpt_path=resume_ckpt_path
        )

        # print("Preliminary Evaluation of Training Dataset")
        # trainer.validate(
        #     model=Model,
        #     dataloaders=Train_Dataloader,
        #     ckpt_path=resume_ckpt_path,
        #     verbose=verbose
        # )

        # print("Preliminary Evaluation of Validation Dataset")
        # trainer.validate(
        #     model=Model,
        #     dataloaders=Val_Dataloader,
        #     ckpt_path=resume_ckpt_path,
        # )

    else:
        # Finding Best Threshold
        if best_threshold is None:
            print("-" * 10, "Calculating best_threshold", "-" * 10)

            # Predictions on Validation Dataset
            val_y_pred_y_true = trainer.predict(
                model=Model,
                dataloaders=Val_Dataloader,
                ckpt_path=resume_ckpt_path
            )
            val_y_pred, val_y_true = concatenate_predictions(y_pred_y_true=val_y_pred_y_true)

            # Calculating Threshold
            val_y_pred = val_y_pred[:, 1]
            val_y_true = np.argmax(val_y_true, axis=1)
            _, _, _, _, _, _, _, _, _, best_threshold = calculate_metrics(y_pred=val_y_pred, y_true=val_y_true, threshold=None)

        # Predictions on Test Datasets
        test_y_pred_y_true = trainer.predict(
            model=Model,
            dataloaders=Test_Dataloaders,
            ckpt_path=resume_ckpt_path
        )

        # Single test source: concatenate all predictions directly
        if len(test_image_sources) == 1:
            test_set_metrics = []
            y_pred, y_true = concatenate_predictions(y_pred_y_true=test_y_pred_y_true)
            y_pred = y_pred[:, 1]
            y_true = np.argmax(y_true, axis=1)
            ap, acc0, r_acc0, f_acc0, acc1, r_acc1, f_acc1, mcc0, mcc1, _ = calculate_metrics(y_pred=y_pred, y_true=y_true, threshold=best_threshold)
            test_set_metrics.append([0, ap, acc0, r_acc0, f_acc0, acc1, r_acc1, f_acc1, mcc0, mcc1])
            return test_set_metrics, best_threshold

        # Multiple test sources: one prediction list per dataloader
        test_set_metrics = []
        for i, _ in enumerate(test_image_sources):
            y_pred, y_true = concatenate_predictions(y_pred_y_true=test_y_pred_y_true[i])
            y_pred = y_pred[:, 1]
            y_true = np.argmax(y_true, axis=1)
            ap, acc0, r_acc0, f_acc0, acc1, r_acc1, f_acc1, mcc0, mcc1, _ = calculate_metrics(y_pred=y_pred, y_true=y_true, threshold=best_threshold)
            test_set_metrics.append([i, ap, acc0, r_acc0, f_acc0, acc1, r_acc1, f_acc1, mcc0, mcc1])
        return test_set_metrics, best_threshold
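
# -----------------------------------------------------------------------------
# Illustrative notes (sketch only). The placeholder values below are NOT part of
# the pipeline; they summarize the interfaces this module assumes, inferred from
# how run() and Model_LightningModule read their arguments.
#
# 1) `classifier` is expected to return a tuple `(latent_features, y_pred)`,
#    where `y_pred` holds two class scores per sample; `y_pred[:, 1]` is used as
#    the positive-class score when computing metrics.
#
# 2) `config` must provide at least the keys accessed in this file, for example:
#
#     config = {
#         "dataset": {
#             "dataset_type": "...",          # e.g. "GenImage"
#             "separateAugmentation": False,
#             "model_name": "...",            # backbone name, e.g. "resnet50"
#             "f_model_name": "...",          # feature-selection / checkpoint sub-directory name
#         },
#         "train_loss_fn": {},                # kwargs forwarded to get_loss_function(...)
#         "val_loss_fn": {},                  # kwargs forwarded to get_loss_function(...)
#         "optimizer": {},                    # kwargs forwarded to get_optimizer(...)
#         "checkpoints": {
#             "resume_dirname": None,         # or a directory name under main_checkpoints_dir
#             "resume_filename": None,        # or a checkpoint filename
#             "checkpoint_dirname": "...",
#         },
#         "train_settings": {
#             "train": True,                  # True = fit, False = evaluate
#             "monitor": "val_loss",          # ModelCheckpoint monitor / mode (placeholders)
#             "mode": "min",
#             "batch_size": 32,
#             "num_workers": 4,
#         },
#         "trainer": {                        # remaining kwargs are passed to pl.Trainer
#             "num_nodes": 1,                 # must be 1 (see assertions in run())
#             "devices": [0],                 # a single GPU index
#         },
#     }
# -----------------------------------------------------------------------------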