Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| PyTorch Lightning Module for training deep-learning models | |
| Notes: | |
| - Using ".to(torch.float32)" to resolve precision issues when using different models. | |
| """ | |
| # Importing Libraries | |
| import numpy as np | |
| from sklearn.model_selection import train_test_split | |
| import torch | |
| torch.set_float32_matmul_precision('medium') | |
| import torch.nn as nn | |
| import torchvision | |
| from torch.utils.data import DataLoader | |
| import pytorch_lightning as pl | |
| import torchmetrics | |
| from pytorch_lightning.callbacks import ModelCheckpoint | |
| from pytorch_lightning import loggers as pl_loggers | |
| import os, sys, warnings | |
| warnings.filterwarnings("ignore") | |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))) | |
| import random | |
| from functions.dataset import Image_Dataset | |
| import functions.dataset_utils as dataset_utils | |
| from functions.loss_optimizers_metrics import * | |
| import functions.utils as utils | |
| import prior_methods.prior_functions.prior_preprocess as prior_preprocess | |
| import defaults | |
| # Lightning Module | |
class Model_LightningModule(pl.LightningModule):
    """Lightning wrapper around an externally supplied feature extractor and classifier.

    The two networks are NOT attributes of this module. Every step reads the
    module-level globals ``feature_extractor_module`` and ``classifier_module``,
    which ``run`` assigns before instantiating this class.

    NOTE(review): the indentation of the original file was lost, so the exact
    extent of each ``torch.no_grad()`` block could not be recovered. Each step
    is reconstructed here as running entirely without gradient tracking;
    confirm against the upstream file if this module is ever used for fitting.

    Args:
        config: Mapping providing the keyword arguments for the loss functions
            (keys ``"train_loss_fn"`` and ``"val_loss_fn"``) and for the
            optimizer (key ``"optimizer"``).
    """

    def __init__(self, config):
        super().__init__()
        self.save_hyperparameters()
        self.config = config

        # Loss
        self.train_lossfn = get_loss_function(**self.config["train_loss_fn"])
        self.val_lossfn = get_loss_function(**self.config["val_loss_fn"])

        # Metrics
        self.train_accuracy_fn = torchmetrics.Accuracy(task="binary")
        self.val_accuracy_fn = torchmetrics.Accuracy(task="binary")

    @staticmethod
    def _score_batch(batch):
        """Run the global feature extractor and classifier on one ``(X, y_true)`` batch."""
        X, y_true = batch
        with torch.no_grad():
            features = feature_extractor_module(X)
            y_pred = classifier_module(features)
        return y_pred, y_true

    @staticmethod
    def _class_indices(y_pred, y_true):
        """Turn raw predictions and targets into the class tensors used for loss/accuracy."""
        with torch.no_grad():
            # Targets are reduced with argmax over dim 1
            # (presumably one-hot encoded -- TODO confirm against Image_Dataset).
            y_true_classes = torch.argmax(y_true, dim=1)
            if y_pred.shape[1] == 1:
                # Single-output head: threshold the score at 0.5.
                y_pred_classes = y_pred >= 0.5
            else:
                y_pred_classes = torch.argmax(y_pred, dim=1)
        return y_pred_classes, y_true_classes

    # Training-Step
    def training_step(self, batch, batch_idx):
        """Compute and log the step-level training loss and accuracy; return the loss."""
        y_pred, y_true = self._score_batch(batch)
        y_pred_classes, y_true_classes = self._class_indices(y_pred, y_true)
        with torch.no_grad():
            self.train_loss = self.train_lossfn(y_pred, y_true_classes)
            self.train_acc = self.train_accuracy_fn(y_pred_classes, y_true_classes)
        self.log_dict(
            {
                "train_loss": self.train_loss,
                "train_acc": self.train_acc,
            },
            on_step=True, on_epoch=False, prog_bar=True, sync_dist=True
        )
        return self.train_loss

    # Validation-Step
    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Compute and log the epoch-level validation loss and accuracy."""
        y_pred, y_true = self._score_batch(batch)
        y_pred_classes, y_true_classes = self._class_indices(y_pred, y_true)
        with torch.no_grad():
            self.val_loss = self.val_lossfn(y_pred, y_true_classes)
            self.val_acc = self.val_accuracy_fn(y_pred_classes, y_true_classes)
        self.log_dict(
            {
                "val_loss": self.val_loss,
                "val_acc": self.val_acc,
            },
            on_step=False, on_epoch=True, prog_bar=True, sync_dist=True
        )

    # Prediction-Step
    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return ``(y_pred, y_true)`` for one batch: raw classifier output and raw targets."""
        # The class indices the original computed here were never used, so
        # only the raw scores and targets are produced.
        return self._score_batch(batch)

    # Configure Optimizers
    def configure_optimizers(self):
        """Build the optimizer from ``config["optimizer"]`` and return it in a list."""
        # The original call read `None **self.config["optimizer"]` (missing
        # comma), which evaluates `None ** dict` and raises TypeError.
        # NOTE(review): the first positional argument is passed as None, as in
        # the original; get_optimizer's signature is not visible here, so
        # confirm that None is acceptable for it.
        optimizer = get_optimizer(
            None,
            **self.config["optimizer"]
        )
        return [optimizer]
| # Main Function | |
def run(feature_extractor, classifier, config, train_image_sources, test_image_sources, preprocess_settings, best_threshold, verbose=True):
    """Evaluate a feature extractor + classifier pair on one test set per image source.

    The function builds a validation split and the per-source test datasets,
    publishes the two networks through the module-level globals read by
    ``Model_LightningModule``, optionally derives a decision threshold from the
    validation split, and finally computes metrics for every test source.

    Args:
        feature_extractor: Network moved to "cuda", put in eval mode and frozen;
            stored in the global ``feature_extractor_module``.
        classifier: Network moved to "cuda", put in eval mode and frozen;
            stored in the global ``classifier_module``.
        config: Nested mapping. Keys read here:
            ``dataset`` (``dataset_type``, ``separateAugmentation``,
            ``model_name``, ``f_model_name``), ``checkpoints``
            (``resume_dirname``, ``resume_filename``, ``checkpoint_dirname``),
            ``train_settings`` (``train``, ``monitor``, ``mode``,
            ``batch_size``, ``num_workers``) and ``trainer`` (keyword
            arguments forwarded to ``pl.Trainer``).
        train_image_sources: Not used by this function.
        test_image_sources: Sequence of image sources; one test dataset and
            one metrics row is produced per entry.
        preprocess_settings: Keyword arguments for
            ``prior_preprocess.get_preprocessfn``; its ``"resize"`` entry is
            also forwarded to ``Image_Dataset``.
        best_threshold: Decision threshold handed to ``calculate_metrics``.
            When ``None`` it is computed from predictions on the validation split.
        verbose: Not used by this function.

    Returns:
        ``(test_set_metrics, best_threshold)`` where ``test_set_metrics`` holds
        one row ``[index, ap, acc0, r_acc0, f_acc0, acc1, r_acc1, f_acc1,
        mcc0, mcc1]`` per test source.

    Raises:
        AssertionError: If a resume checkpoint path is configured but missing
            on disk, or if none is configured while ``train`` is falsy.
    """
    # Global Variables: (feature_extractor, classifier)
    global feature_extractor_module
    global classifier_module

    # Parameters
    dataset_cfg = config["dataset"]
    ckpt_cfg = config["checkpoints"]
    train_cfg = config["train_settings"]
    dataset_type = dataset_cfg["dataset_type"]
    separateAugmentation = dataset_cfg["separateAugmentation"]
    model_name = dataset_cfg["model_name"]
    f_model_name = dataset_cfg["f_model_name"]
    evaluating = not train_cfg["train"]

    # Paths
    main_dataset_dir = defaults.main_dataset_dir
    main_checkpoints_dir = defaults.main_prior_checkpoints_dir

    # Checkpoints Paths
    # Resume Checkpoints
    if ckpt_cfg["resume_dirname"] is not None and ckpt_cfg["resume_filename"] is not None:
        resume_ckpt_path = os.path.join(main_checkpoints_dir, ckpt_cfg["resume_dirname"], f_model_name, ckpt_cfg["resume_filename"])
    else:
        resume_ckpt_path = None
    print (resume_ckpt_path)

    # Save Checkpoints
    checkpoint_dirpath = os.path.join(main_checkpoints_dir, dataset_type, ckpt_cfg["checkpoint_dirname"], f_model_name)
    os.makedirs(checkpoint_dirpath, exist_ok=True)

    # Resuming from checkpoint
    # NOTE(review): the path is only checked for existence; nothing in this
    # function loads it.
    # Explicit raises (instead of `assert False`) keep the same exception type
    # but are not stripped when Python runs with -O.
    if resume_ckpt_path is not None:
        if not os.path.exists(resume_ckpt_path):
            raise AssertionError("Resume checkpoint not found at resume_ckpt_path provided.")
        print ("Found the checkpoint at resume_ckpt_path provided.")
    elif train_cfg["train"]:
        # For Training.
        print ("No path is provided for resume checkpoint (resume_ckpt_path) provided. Starting training from the begining.")
    else:
        raise AssertionError("No path is provided for resume checkpoint (resume_ckpt_path) provided. resume_ckpt_path is required for evaluation.")

    # Checkpoint Callbacks
    # NOTE(review): constructed as in the original, but never passed to the Trainer below.
    best_checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dirpath,
        filename="best_model",
        monitor=train_cfg["monitor"],
        mode=train_cfg["mode"]
    )

    # Pre-processing Functions
    preprocessfn, dual_scale = prior_preprocess.get_preprocessfn(**preprocess_settings)

    # Logging
    print ()
    print (preprocessfn)
    print ()

    # Datasets
    # Images Train and Val Paths
    train_val_real_images_paths, train_val_fake_images_paths = dataset_utils.dataset_img_paths(
        dataset_type=dataset_type,
        status="train"
    )

    # Train-Val Split: sort, shuffle with a fixed seed, keep the last 20% for
    # validation. The leading 80% (train part) is not needed by this function.
    for paths in (train_val_real_images_paths, train_val_fake_images_paths):
        paths.sort()
        random.Random(0).shuffle(paths)
    val_real_images_paths = train_val_real_images_paths[int(0.8 * len(train_val_real_images_paths)):]
    val_fake_images_paths = train_val_fake_images_paths[int(0.8 * len(train_val_fake_images_paths)):]

    # Images Validation Dataset
    Val_Dataset = Image_Dataset(
        real_images_paths=val_real_images_paths,
        fake_images_paths=val_fake_images_paths,
        preprocessfn=preprocessfn,
        dual_scale=dual_scale,
        resize=preprocess_settings["resize"],
        separateAugmentation=separateAugmentation
    )

    # Images Test Dataset
    # NOTE(review): test datasets/loaders are only built when `train` is falsy,
    # yet the test predictions further down always run; with `train` truthy the
    # function fails with NameError there, as it did originally.
    if evaluating:
        # Feature extractors listed here get a fixed (224,224) input size.
        fixed_input_models = ("clip-resnet50", "clip-vit-l-14", "drct-clip-vit-l-14", "drct-convnext-b")
        Test_Datasets = []
        for source in test_image_sources:
            test_real_images_paths = dataset_utils.get_image_paths(
                dataset_type=dataset_type,
                status="val",
                image_sources=[source],
                label="real"
            )
            test_fake_images_paths = dataset_utils.get_image_paths(
                dataset_type=dataset_type,
                status="val",
                image_sources=[source],
                label="fake"
            )

            # - During inference, we avoid Resizing to reduce the effect of resizing artifacts.
            # - We process the images at (224,224) or their smaller resolution unless the
            #   feature extraction model requires (224,224) inputs.
            if model_name in fixed_input_models:
                # Updated Pre-Processing Settings (shallow copy keeps the caller's dict intact)
                Fixed_Input_preprocess_settings = preprocess_settings.copy()
                Fixed_Input_preprocess_settings["input_image_dimensions"] = (224,224)
                # Preprocessing Function
                source_preprocessfn, source_dual_scale = prior_preprocess.get_preprocessfn(**Fixed_Input_preprocess_settings)
            else:
                source_preprocessfn, source_dual_scale = preprocessfn, dual_scale

            Test_Datasets.append(
                Image_Dataset(
                    real_images_paths=test_real_images_paths,
                    fake_images_paths=test_fake_images_paths,
                    preprocessfn=source_preprocessfn,
                    dual_scale=source_dual_scale,
                    resize=preprocess_settings["resize"],
                    separateAugmentation=separateAugmentation
                )
            )

    # DataLoaders
    # Val DataLoader
    Val_Dataloader = DataLoader(
        dataset=Val_Dataset,
        batch_size=train_cfg["batch_size"],
        num_workers=train_cfg["num_workers"],
        shuffle=False,
    )

    # Test DataLoaders
    if evaluating:
        Test_Dataloaders = [
            DataLoader(
                dataset=test_dataset,
                batch_size=train_cfg["batch_size"],
                num_workers=train_cfg["num_workers"],
                shuffle=False,
            )
            for test_dataset in Test_Datasets
        ]
    print ("-"*25 + " Datasets and DataLoaders Ready " + "-"*25)

    # Publish the networks through the globals read by Model_LightningModule,
    # in eval mode and with every parameter frozen.
    feature_extractor_module = feature_extractor
    feature_extractor_module.to("cuda")
    feature_extractor_module.eval()
    for params in feature_extractor_module.parameters():
        params.requires_grad = False

    classifier_module = classifier
    classifier_module.to("cuda")
    classifier_module.eval()
    for params in classifier_module.parameters():
        # The original assigned `classifier_module.requires_grad = False`,
        # which only set a stray attribute on the module and froze nothing.
        params.requires_grad = False

    # Lightning Module
    Model = Model_LightningModule(config)

    # PyTorch Lightning Trainer
    trainer = pl.Trainer(
        **config["trainer"],
        callbacks=[utils.LitProgressBar()],
        precision=32
    )

    # Evaluating
    # Finding Best Threshold
    if best_threshold is None:
        print ("-"*10, "Calculating best_threshold", "-"*10)

        # Predictions on Validation Dataset
        val_y_pred_y_true = trainer.predict(
            model=Model,
            dataloaders=Val_Dataloader,
        )
        val_y_pred, val_y_true = concatenate_predictions(y_pred_y_true=val_y_pred_y_true)

        # Calculating Threshold
        # Column 1 of the predictions is used as the score
        # (presumably the "fake" class -- TODO confirm against the classifier).
        val_y_pred = val_y_pred[:, 1]
        val_y_true = np.argmax(val_y_true, axis=1)
        # With threshold=None the last returned value is taken as the threshold.
        *_, best_threshold = calculate_metrics(y_pred=val_y_pred, y_true=val_y_true, threshold=None)

    # Predictions on Test Dataset
    test_y_pred_y_true = trainer.predict(
        model=Model,
        dataloaders=Test_Dataloaders,
    )

    # With exactly one test source the predictions are treated as a flat list
    # of batch outputs rather than one list per dataloader; wrap them so both
    # cases share the loop below.
    if len(test_image_sources) == 1:
        test_y_pred_y_true = [test_y_pred_y_true]

    test_set_metrics = []
    for i, _ in enumerate(test_image_sources):
        y_pred, y_true = concatenate_predictions(y_pred_y_true=test_y_pred_y_true[i])
        y_pred = y_pred[:, 1]
        y_true = np.argmax(y_true, axis=1)
        ap, acc0, r_acc0, f_acc0, acc1, r_acc1, f_acc1, mcc0, mcc1, _ = calculate_metrics(y_pred=y_pred, y_true=y_true, threshold=best_threshold)
        test_set_metrics.append([i, ap, acc0, r_acc0, f_acc0, acc1, r_acc1, f_acc1, mcc0, mcc1])
    return test_set_metrics, best_threshold