""" PyTorch Lightning Module of training of deep-learning models Notes: - Using ".to(torch.float32)" to resolving precision issues while using different models. """ # Importing Libraries import numpy as np from sklearn.model_selection import train_test_split import torch torch.set_float32_matmul_precision('medium') import torch.nn as nn import torchvision from torch.utils.data import DataLoader import pytorch_lightning as pl import torchmetrics from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning import loggers as pl_loggers import torch torch.set_float32_matmul_precision('medium') import os, sys, warnings warnings.filterwarnings("ignore") sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) from yaml import safe_load from functions.dataset import Image_Dataset import functions.preprocess as preprocess from functions.loss_optimizers_metrics import * import functions.utils as utils import functions.module as module import defaults # Lightning Module class Model_LightningModule(pl.LightningModule): def __init__(self, classifier, config): super().__init__() self.save_hyperparameters() self.config = config # Model as Manual Arguments self.classifier = classifier # Loss self.train_lossfn = get_loss_function(**self.config["train_loss_fn"]) self.val_lossfn = get_loss_function(**self.config["val_loss_fn"]) # Metrics self.train_accuracy_fn = torchmetrics.Accuracy(task="binary") self.val_accuracy_fn = torchmetrics.Accuracy(task="binary") # Training-Step def training_step(self, batch, batch_idx): if len(batch) == 2: X, y_true = batch # Extracting features using Backbone Feature Extractor with torch.no_grad(): X = feature_extractor_module(X) else: X1, X2, y_true = batch # Extracting features using Backbone Feature Extractor with torch.no_grad(): X = feature_extractor_module(X1, X2) X = torch.flatten(X, start_dim=1).to(torch.float32) X_input = preprocess.select_feature_indices(X, self.config["dataset"]["f_model_name"]) y_true_classes = torch.argmax(y_true, dim=1) latent_features, y_pred = self.classifier(X_input) y_pred_classes = torch.argmax(y_pred, dim=1) self.train_loss = self.train_lossfn(latent_features, y_pred, y_true_classes) self.train_acc = self.train_accuracy_fn(y_pred_classes, y_true_classes) self.log_dict( { "train_loss": self.train_loss, "train_acc": self.train_acc }, on_step=True, on_epoch=False, prog_bar=True, sync_dist=True ) return self.train_loss # Validation-Step def validation_step(self, batch, batch_idx, dataloader_idx=0): if len(batch) == 2: X, y_true = batch # Extracting features using Backbone Feature Extractor with torch.no_grad(): X = feature_extractor_module(X) else: X1, X2, y_true = batch # Extracting features using Backbone Feature Extractor with torch.no_grad(): X = feature_extractor_module(X1, X2) X = torch.flatten(X, start_dim=1).to(torch.float32) X_input = preprocess.select_feature_indices(X, self.config["dataset"]["f_model_name"]) y_true_classes = torch.argmax(y_true, dim=1) latent_features, y_pred = self.classifier(X_input) y_pred_classes = torch.argmax(y_pred, dim=1) self.val_loss = self.val_lossfn(latent_features, y_pred, y_true_classes) self.val_acc = self.val_accuracy_fn(y_pred_classes, y_true_classes) self.log_dict( { "val_loss": self.val_loss, "val_acc": self.val_acc, }, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True ) # Prediction-Step def predict_step(self, batch, batch_idx, dataloader_idx=0): if len(batch) == 2: X, y_true = batch # Extracting features using Backbone Feature Extractor with torch.no_grad(): X = feature_extractor_module(X) else: X1, X2, y_true = batch # Extracting features using Backbone Feature Extractor with torch.no_grad(): X = feature_extractor_module(X1, X2) X = torch.flatten(X, start_dim=1).to(torch.float32) X_input = preprocess.select_feature_indices(X, self.config["dataset"]["f_model_name"]) y_true_classes = torch.argmax(y_true, dim=1) latent_features, y_pred = self.classifier(X_input) y_pred_classes = torch.argmax(y_pred, dim=1) return y_pred, y_true # Configure Optimizers def configure_optimizers(self): optimizer = get_optimizer( self.classifier.parameters(), **self.config["optimizer"] ) return [optimizer] # Main Function def run_on_images(feature_extractor, classifier, config, test_real_images_paths, test_fake_images_paths, preprocess_settings, best_threshold, verbose=True): # Parameters dataset_type = config["dataset"]["dataset_type"] separateAugmentation = config["dataset"]["separateAugmentation"] model_name = config["dataset"]["model_name"] f_model_name = config["dataset"]["f_model_name"] # Paths main_dataset_dir = defaults.main_dataset_dir main_checkpoints_dir = defaults.main_checkpoints_dir # Checkpoints Paths # Resume Checkpoints if config["checkpoints"]["resume_dirname"] is not None and config["checkpoints"]["resume_filename"] is not None: resume_ckpt_path = os.path.join(main_checkpoints_dir, config["checkpoints"]["resume_dirname"], f_model_name, config["checkpoints"]["resume_filename"]) else: resume_ckpt_path = None print (resume_ckpt_path) # Save Checkpoints checkpoint_dirpath = os.path.join(main_checkpoints_dir, config["checkpoints"]["checkpoint_dirname"], f_model_name) os.makedirs(checkpoint_dirpath, exist_ok=True) # Resuming from checkpoint if resume_ckpt_path is not None: if os.path.exists(resume_ckpt_path): print ("Found the checkpoint at resume_ckpt_path provided.") else: assert False, "Resume checkpoint not found at resume_ckpt_path provided." else: if config["train_settings"]["train"]: # For Training. print ("No path is provided for resume checkpoint (resume_ckpt_path) provided. Starting training from the begining.") else: assert False, "No path is provided for resume checkpoint (resume_ckpt_path) provided. resume_ckpt_path is required for evaluation." # Checkpoint Callbacks best_checkpoint_callback = ModelCheckpoint( dirpath=checkpoint_dirpath, filename="best_model", monitor=config["train_settings"]["monitor"], mode=config["train_settings"]["mode"] ) # Pre-processing Functions preprocessfn, dual_scale = preprocess.get_preprocessfn(**preprocess_settings) # Logging print () print (preprocessfn) print () # Datasets # Images Test Dataset if config["train_settings"]["train"] == False: # For images smaller than preprocess_settings["input_image_dimensions"] which only occur for BigGAN fake images in GenImage dataset, we do the following: """ - During inference, we avoid Resizing to reduce the effect of resizing artifacts. - We process the images at (224,224) or their smaller resolution unless the feature extraction model requires (224,224) inputs. """ if model_name == "resnet50" or model_name == "hyperiqa" or model_name == "tres" or model_name == "clip-resnet50" or model_name == "clip-vit-l-14": # Updated Pre-Processing Settings Fixed_Input_preprocess_settings = preprocess_settings.copy() Fixed_Input_preprocess_settings["input_image_dimensions"] = (224,224) # Preprocessing Function Fixed_Input_preprocessfn, Fixed_Input_dual_scale = preprocess.get_preprocessfn(**Fixed_Input_preprocess_settings) Test_Dataset = Image_Dataset( real_images_paths=test_real_images_paths, fake_images_paths=test_fake_images_paths, preprocessfn=Fixed_Input_preprocessfn, dual_scale=Fixed_Input_dual_scale, resize=preprocess_settings["resize"], separateAugmentation=separateAugmentation ) else: Test_Dataset = Image_Dataset( real_images_paths=test_real_images_paths, fake_images_paths=test_fake_images_paths, preprocessfn=preprocessfn, dual_scale=dual_scale, resize=preprocess_settings["resize"], separateAugmentation=separateAugmentation ) # DataLoaders # Test DataLoaders if config["train_settings"]["train"] == False: Test_Dataloader = DataLoader( dataset=Test_Dataset, batch_size=config["train_settings"]["batch_size"], num_workers=config["train_settings"]["num_workers"], shuffle=False, ) print ("-"*25 + " Datasets and DataLoaders Ready " + "-"*25) # Global Variables: (feature_extractor) global feature_extractor_module feature_extractor_module = feature_extractor # Detect device from feature extractor (it's already on the correct device) device = next(feature_extractor_module.parameters()).device feature_extractor_module.eval() for params in feature_extractor_module.parameters(): params.requires_grad = False # Lightning Module Model = Model_LightningModule(classifier, config) # PyTorch Lightning Trainer # Override accelerator based on detected device trainer_config = config["trainer"].copy() if device.type == "mps": trainer_config["accelerator"] = "mps" trainer_config["devices"] = 1 elif device.type == "cuda": trainer_config["accelerator"] = "cuda" trainer_config["devices"] = [device.index] if device.index is not None else [0] else: # cpu trainer_config["accelerator"] = "cpu" trainer_config["devices"] = 1 trainer = pl.Trainer( **trainer_config, callbacks=[best_checkpoint_callback, utils.LitProgressBar()], precision=32 ) # Evaluating # Predictions on Test Dataset test_y_pred_y_true = trainer.predict( model=Model, dataloaders=Test_Dataloader, ckpt_path=resume_ckpt_path ) test_set_metrics = [] y_pred, y_true = concatenate_predictions(y_pred_y_true=test_y_pred_y_true) y_pred = y_pred[:, 1] y_true = np.argmax(y_true, axis=1) ap, acc0, r_acc0, f_acc0, acc1, r_acc1, f_acc1, mcc0, mcc1, _ = calculate_metrics(y_pred=y_pred, y_true=y_true, threshold=best_threshold) test_set_metrics.append([0, ap, acc0, r_acc0, f_acc0, acc1, r_acc1, f_acc1, mcc0, mcc1]) return test_set_metrics, best_threshold, y_pred, y_true