"""
PyTorch Lightning Module of training of deep-learning models
Notes:
- Using ".to(torch.float32)" to resolving precision issues while using different models.
"""
# Importing Libraries
import numpy as np
from sklearn.model_selection import train_test_split
import torch
torch.set_float32_matmul_precision('medium')
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
import pytorch_lightning as pl
import torchmetrics
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import loggers as pl_loggers
import os, sys, warnings
warnings.filterwarnings("ignore")
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))))
import random
from functions.dataset import Image_Dataset
import functions.dataset_utils as dataset_utils
from functions.loss_optimizers_metrics import *
import functions.utils as utils
import prior_methods.prior_functions.prior_preprocess as prior_preprocess
import defaults
# Lightning Module
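# NOTE: feature_extractor_module and classifier_module are module-level globals set inside
# run() below (not registered submodules), so they stay frozen and are neither checkpointed
# nor updated by the optimizer.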
class Model_LightningModule(pl.LightningModule):
def __init__(self, config):
super().__init__()
self.save_hyperparameters()
self.config = config
# Loss
self.train_lossfn = get_loss_function(**self.config["train_loss_fn"])
self.val_lossfn = get_loss_function(**self.config["val_loss_fn"])
# Metrics
self.train_accuracy_fn = torchmetrics.Accuracy(task="binary")
self.val_accuracy_fn = torchmetrics.Accuracy(task="binary")
# Training-Step
def training_step(self, batch, batch_idx):
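        # Both modules are frozen, so the full forward pass runs under no_grad;
        # in practice this LightningModule is only used for prediction/evaluation.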
with torch.no_grad():
X, y_true = batch
X = feature_extractor_module(X)
y_true_classes = torch.argmax(y_true, dim=1)
y_pred = classifier_module(X)
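            # A single-column output is treated as a score thresholded at 0.5;
            # otherwise the predicted class is the argmax over the class outputs.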
if y_pred.shape[1] == 1:
y_pred_classes = y_pred >= 0.5
else:
y_pred_classes = torch.argmax(y_pred, dim=1)
self.train_loss = self.train_lossfn(y_pred, y_true_classes)
self.train_acc = self.train_accuracy_fn(y_pred_classes, y_true_classes)
self.log_dict(
{
"train_loss": self.train_loss,
"train_acc": self.train_acc
},
on_step=True, on_epoch=False, prog_bar=True, sync_dist=True
)
return self.train_loss
# Validation-Step
def validation_step(self, batch, batch_idx, dataloader_idx=0):
with torch.no_grad():
X, y_true = batch
X = feature_extractor_module(X)
y_true_classes = torch.argmax(y_true, dim=1)
y_pred = classifier_module(X)
if y_pred.shape[1] == 1:
y_pred_classes = y_pred >= 0.5
else:
y_pred_classes = torch.argmax(y_pred, dim=1)
self.val_loss = self.val_lossfn(y_pred, y_true_classes)
self.val_acc = self.val_accuracy_fn(y_pred_classes, y_true_classes)
self.log_dict(
{
"val_loss": self.val_loss,
"val_acc": self.val_acc,
},
on_step=False, on_epoch=True, prog_bar=True, sync_dist=True
)
# Prediction-Step
def predict_step(self, batch, batch_idx, dataloader_idx=0):
with torch.no_grad():
X, y_true = batch
X = feature_extractor_module(X)
y_true_classes = torch.argmax(y_true, dim=1)
y_pred = classifier_module(X)
if y_pred.shape[1] == 1:
y_pred_classes = y_pred >= 0.5
else:
y_pred_classes = torch.argmax(y_pred, dim=1)
return y_pred, y_true
# Configure Optimizers
def configure_optimizers(self):
        optimizer = get_optimizer(
            None,
            **self.config["optimizer"]
        )
return [optimizer]
# Main Function
def run(feature_extractor, classifier, config, train_image_sources, test_image_sources, preprocess_settings, best_threshold, verbose=True):
# Parameters
dataset_type = config["dataset"]["dataset_type"]
separateAugmentation = config["dataset"]["separateAugmentation"]
model_name = config["dataset"]["model_name"]
f_model_name = config["dataset"]["f_model_name"]
# Paths
main_dataset_dir = defaults.main_dataset_dir
main_checkpoints_dir = defaults.main_prior_checkpoints_dir
# Checkpoints Paths
# Resume Checkpoints
if config["checkpoints"]["resume_dirname"] is not None and config["checkpoints"]["resume_filename"] is not None:
resume_ckpt_path = os.path.join(main_checkpoints_dir, config["checkpoints"]["resume_dirname"], f_model_name, config["checkpoints"]["resume_filename"])
else:
resume_ckpt_path = None
print (resume_ckpt_path)
# Save Checkpoints
checkpoint_dirpath = os.path.join(main_checkpoints_dir, dataset_type, config["checkpoints"]["checkpoint_dirname"], f_model_name)
os.makedirs(checkpoint_dirpath, exist_ok=True)
# Resuming from checkpoint
    if resume_ckpt_path is not None:
        if os.path.exists(resume_ckpt_path):
            print ("Found the checkpoint at the provided resume_ckpt_path.")
        else:
            assert False, "Resume checkpoint not found at the provided resume_ckpt_path."
    else:
        if config["train_settings"]["train"]:
            # For Training.
            print ("No resume checkpoint path (resume_ckpt_path) was provided. Starting training from the beginning.")
        else:
            assert False, "A resume checkpoint path (resume_ckpt_path) is required for evaluation but none was provided."
# Checkpoint Callbacks
best_checkpoint_callback = ModelCheckpoint(
dirpath=checkpoint_dirpath,
filename="best_model",
monitor=config["train_settings"]["monitor"],
mode=config["train_settings"]["mode"]
)
# Pre-processing Functions
preprocessfn, dual_scale = prior_preprocess.get_preprocessfn(**preprocess_settings)
# Logging
print ()
print (preprocessfn)
print ()
# Datasets
# Images Train and Val Paths
train_val_real_images_paths, train_val_fake_images_paths = dataset_utils.dataset_img_paths(
dataset_type=dataset_type,
status="train"
)
# Train-Val Split
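    # Sort for a deterministic file order, then shuffle with a fixed seed so the
    # 80/20 train-val split below is reproducible across runs.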
train_val_real_images_paths.sort()
train_val_fake_images_paths.sort()
random.Random(0).shuffle(train_val_real_images_paths)
random.Random(0).shuffle(train_val_fake_images_paths)
train_real_images_paths, val_real_images_paths = train_val_real_images_paths[:int(0.8 * len(train_val_real_images_paths))], train_val_real_images_paths[int(0.8 * len(train_val_real_images_paths)):]
train_fake_images_paths, val_fake_images_paths = train_val_fake_images_paths[:int(0.8 * len(train_val_fake_images_paths))], train_val_fake_images_paths[int(0.8 * len(train_val_fake_images_paths)):]
# Images Validation Dataset
Val_Dataset = Image_Dataset(
real_images_paths=val_real_images_paths,
fake_images_paths=val_fake_images_paths,
preprocessfn=preprocessfn,
dual_scale=dual_scale,
resize=preprocess_settings["resize"],
separateAugmentation=separateAugmentation
)
# Images Test Dataset
if config["train_settings"]["train"] == False:
Test_Datasets = []
        for source in test_image_sources:
test_real_images_paths = dataset_utils.get_image_paths(
dataset_type=dataset_type,
status="val",
image_sources=[source],
label="real"
)
test_fake_images_paths = dataset_utils.get_image_paths(
dataset_type=dataset_type,
status="val",
image_sources=[source],
label="fake"
)
"""
- During inference, we avoid Resizing to reduce the effect of resizing artifacts.
- We process the images at (224,224) or their smaller resolution unless the feature extraction model requires (224,224) inputs.
"""
if model_name == "clip-resnet50" or model_name == "clip-vit-l-14" or model_name == "drct-clip-vit-l-14" or model_name == "drct-convnext-b":
# Updated Pre-Processing Settings
Fixed_Input_preprocess_settings = preprocess_settings.copy()
Fixed_Input_preprocess_settings["input_image_dimensions"] = (224,224)
# Preprocessing Function
Fixed_Input_preprocessfn, Fixed_Input_dual_scale = prior_preprocess.get_preprocessfn(**Fixed_Input_preprocess_settings)
Test_Datasets.append(
Image_Dataset(
real_images_paths=test_real_images_paths,
fake_images_paths=test_fake_images_paths,
preprocessfn=Fixed_Input_preprocessfn,
dual_scale=Fixed_Input_dual_scale,
resize=preprocess_settings["resize"],
separateAugmentation=separateAugmentation
)
)
else:
Test_Datasets.append(
Image_Dataset(
real_images_paths=test_real_images_paths,
fake_images_paths=test_fake_images_paths,
preprocessfn=preprocessfn,
dual_scale=dual_scale,
resize=preprocess_settings["resize"],
separateAugmentation=separateAugmentation
)
)
# DataLoaders
# Val DataLoader
Val_Dataloader = DataLoader(
dataset=Val_Dataset,
batch_size=config["train_settings"]["batch_size"],
num_workers=config["train_settings"]["num_workers"],
shuffle=False,
)
# Test DataLoaders
if config["train_settings"]["train"] == False:
Test_Dataloaders = []
for i,_ in enumerate(test_image_sources):
Test_Dataloaders.append(
DataLoader(
dataset=Test_Datasets[i],
batch_size=config["train_settings"]["batch_size"],
num_workers=config["train_settings"]["num_workers"],
shuffle=False,
)
)
print ("-"*25 + " Datasets and DataLoaders Ready " + "-"*25)
# Global Variables: (feature_extractor, classifier)
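    # Move both models to the GPU, switch them to eval mode, and freeze all of their
    # parameters; the LightningModule reads them through these module-level globals.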
global feature_extractor_module
feature_extractor_module = feature_extractor
feature_extractor_module.to("cuda")
feature_extractor_module.eval()
for params in feature_extractor_module.parameters():
params.requires_grad = False
global classifier_module
classifier_module = classifier
classifier_module.to("cuda")
classifier_module.eval()
    for params in classifier_module.parameters():
        params.requires_grad = False
# Lightning Module
Model = Model_LightningModule(config)
# PyTorch Lightning Trainer
trainer = pl.Trainer(
**config["trainer"],
callbacks=[utils.LitProgressBar()],
precision=32
)
# Evaluating
# Finding Best Threshold
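    # When no threshold is supplied, it is calibrated on the validation set: predictions are
    # collected, the score of class index 1 is extracted, and calculate_metrics (called with
    # threshold=None) returns the best operating threshold as its last output.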
if best_threshold is None:
print ("-"*10, "Calculating best_threshold", "-"*10)
# Predictions on Validation Dataset
val_y_pred_y_true = trainer.predict(
model=Model,
dataloaders=Val_Dataloader,
)
val_y_pred, val_y_true = concatenate_predictions(y_pred_y_true=val_y_pred_y_true)
# Calculating Threshold
val_y_pred = val_y_pred[:, 1]
val_y_true = np.argmax(val_y_true, axis=1)
_, _, _, _, _, _, _, _, _, best_threshold = calculate_metrics(y_pred=val_y_pred, y_true=val_y_true, threshold=None)
# Predictions on Test Dataset
test_y_pred_y_true = trainer.predict(
model=Model,
dataloaders=Test_Dataloaders,
)
if len(test_image_sources) == 1:
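        # A single test source yields a flat list of batch predictions rather than one
        # list per dataloader, hence the separate handling below.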
test_set_metrics = []
y_pred, y_true = concatenate_predictions(y_pred_y_true=test_y_pred_y_true)
y_pred = y_pred[:, 1]
y_true = np.argmax(y_true, axis=1)
ap, acc0, r_acc0, f_acc0, acc1, r_acc1, f_acc1, mcc0, mcc1, _ = calculate_metrics(y_pred=y_pred, y_true=y_true, threshold=best_threshold)
test_set_metrics.append([0, ap, acc0, r_acc0, f_acc0, acc1, r_acc1, f_acc1, mcc0, mcc1])
return test_set_metrics, best_threshold
test_set_metrics = []
for i, _ in enumerate(test_image_sources):
y_pred, y_true = concatenate_predictions(y_pred_y_true=test_y_pred_y_true[i])
y_pred = y_pred[:, 1]
y_true = np.argmax(y_true, axis=1)
ap, acc0, r_acc0, f_acc0, acc1, r_acc1, f_acc1, mcc0, mcc1, _ = calculate_metrics(y_pred=y_pred, y_true=y_true, threshold=best_threshold)
test_set_metrics.append([i, ap, acc0, r_acc0, f_acc0, acc1, r_acc1, f_acc1, mcc0, mcc1])
return test_set_metrics, best_threshold
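# A minimal, hypothetical usage sketch of run(). The config keys mirror the ones accessed above;
# every concrete value (paths, sources, loss/optimizer settings) is a placeholder that must be
# adapted to the actual setup, and feature_extractor / classifier are assumed to be pre-built
# nn.Module instances.
#
# config = {
#     "dataset": {"dataset_type": "...", "separateAugmentation": False,
#                 "model_name": "clip-vit-l-14", "f_model_name": "..."},
#     "checkpoints": {"resume_dirname": "...", "resume_filename": "...", "checkpoint_dirname": "..."},
#     "train_settings": {"train": False, "monitor": "val_loss", "mode": "min",
#                        "batch_size": 64, "num_workers": 8},
#     "train_loss_fn": {...}, "val_loss_fn": {...}, "optimizer": {...},
#     "trainer": {"accelerator": "gpu", "devices": 1},
# }
# test_set_metrics, best_threshold = run(
#     feature_extractor, classifier, config,
#     train_image_sources=[...], test_image_sources=["..."],
#     preprocess_settings={...}, best_threshold=None,
# )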