"""
PyTorch Lightning Module of training of deep-learning models
Notes:
- Using ".to(torch.float32)" to resolving precision issues while using different models.
"""
# Importing Libraries
import numpy as np
from sklearn.model_selection import train_test_split
import torch
torch.set_float32_matmul_precision('medium')
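# 'medium' lets float32 matmuls use reduced internal precision on GPUs that support it,
# trading a little accuracy for speed; elsewhere it has no practical effect.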
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
import pytorch_lightning as pl
import torchmetrics
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import loggers as pl_loggers
import os, sys, warnings
warnings.filterwarnings("ignore")
sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
from yaml import safe_load
from functions.dataset import Image_Dataset
import functions.preprocess as preprocess
from functions.loss_optimizers_metrics import *
import functions.utils as utils
import functions.module as module
import defaults
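
# Expected `config` layout, inferred from the keys accessed in this file; the placeholder
# values below are illustrative only and are supplied by the caller:
#
#   config = {
#       "dataset": {"dataset_type": ..., "separateAugmentation": ..., "model_name": ..., "f_model_name": ...},
#       "train_loss_fn": {...},   # kwargs forwarded to get_loss_function
#       "val_loss_fn": {...},     # kwargs forwarded to get_loss_function
#       "optimizer": {...},       # kwargs forwarded to get_optimizer
#       "checkpoints": {"resume_dirname": ..., "resume_filename": ..., "checkpoint_dirname": ...},
#       "train_settings": {"train": ..., "monitor": ..., "mode": ..., "batch_size": ..., "num_workers": ...},
#       "trainer": {...},         # kwargs forwarded to pl.Trainer
#   }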
# Lightning Module
class Model_LightningModule(pl.LightningModule):
def __init__(self, classifier, config):
super().__init__()
self.save_hyperparameters()
self.config = config
        # Classifier model is passed in as a constructor argument (not built inside the module)
self.classifier = classifier
# Loss
self.train_lossfn = get_loss_function(**self.config["train_loss_fn"])
self.val_lossfn = get_loss_function(**self.config["val_loss_fn"])
# Metrics
self.train_accuracy_fn = torchmetrics.Accuracy(task="binary")
self.val_accuracy_fn = torchmetrics.Accuracy(task="binary")
# Training-Step
def training_step(self, batch, batch_idx):
if len(batch) == 2:
X, y_true = batch
# Extracting features using Backbone Feature Extractor
with torch.no_grad():
X = feature_extractor_module(X)
else:
X1, X2, y_true = batch
# Extracting features using Backbone Feature Extractor
with torch.no_grad():
X = feature_extractor_module(X1, X2)
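        # Flatten the backbone features and cast to float32 so downstream layers see a
        # consistent dtype regardless of which backbone produced them (see the module docstring).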
X = torch.flatten(X, start_dim=1).to(torch.float32)
X_input = preprocess.select_feature_indices(X, self.config["dataset"]["f_model_name"])
y_true_classes = torch.argmax(y_true, dim=1)
latent_features, y_pred = self.classifier(X_input)
y_pred_classes = torch.argmax(y_pred, dim=1)
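        # The loss callable returned by get_loss_function consumes the latent features,
        # the logits, and the integer class targets.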
self.train_loss = self.train_lossfn(latent_features, y_pred, y_true_classes)
self.train_acc = self.train_accuracy_fn(y_pred_classes, y_true_classes)
self.log_dict(
{
"train_loss": self.train_loss,
"train_acc": self.train_acc
},
on_step=True, on_epoch=False, prog_bar=True, sync_dist=True
)
return self.train_loss
# Validation-Step
def validation_step(self, batch, batch_idx, dataloader_idx=0):
if len(batch) == 2:
X, y_true = batch
# Extracting features using Backbone Feature Extractor
with torch.no_grad():
X = feature_extractor_module(X)
else:
X1, X2, y_true = batch
# Extracting features using Backbone Feature Extractor
with torch.no_grad():
X = feature_extractor_module(X1, X2)
X = torch.flatten(X, start_dim=1).to(torch.float32)
X_input = preprocess.select_feature_indices(X, self.config["dataset"]["f_model_name"])
y_true_classes = torch.argmax(y_true, dim=1)
latent_features, y_pred = self.classifier(X_input)
y_pred_classes = torch.argmax(y_pred, dim=1)
self.val_loss = self.val_lossfn(latent_features, y_pred, y_true_classes)
self.val_acc = self.val_accuracy_fn(y_pred_classes, y_true_classes)
self.log_dict(
{
"val_loss": self.val_loss,
"val_acc": self.val_acc,
},
on_step=False, on_epoch=True, prog_bar=True, sync_dist=True
)
# Prediction-Step
def predict_step(self, batch, batch_idx, dataloader_idx=0):
if len(batch) == 2:
X, y_true = batch
# Extracting features using Backbone Feature Extractor
with torch.no_grad():
X = feature_extractor_module(X)
else:
X1, X2, y_true = batch
# Extracting features using Backbone Feature Extractor
with torch.no_grad():
X = feature_extractor_module(X1, X2)
X = torch.flatten(X, start_dim=1).to(torch.float32)
X_input = preprocess.select_feature_indices(X, self.config["dataset"]["f_model_name"])
        # Only the raw predictions and the one-hot labels are returned; class-index conversions are not needed here.
        _, y_pred = self.classifier(X_input)
return y_pred, y_true
# Configure Optimizers
def configure_optimizers(self):
optimizer = get_optimizer(
self.classifier.parameters(),
**self.config["optimizer"]
)
return [optimizer]
# Main Function
def run_on_images(feature_extractor, classifier, config, test_real_images_paths, test_fake_images_paths, preprocess_settings, best_threshold, verbose=True):
# Parameters
dataset_type = config["dataset"]["dataset_type"]
separateAugmentation = config["dataset"]["separateAugmentation"]
model_name = config["dataset"]["model_name"]
f_model_name = config["dataset"]["f_model_name"]
# Paths
main_dataset_dir = defaults.main_dataset_dir
main_checkpoints_dir = defaults.main_checkpoints_dir
# Checkpoints Paths
# Resume Checkpoints
if config["checkpoints"]["resume_dirname"] is not None and config["checkpoints"]["resume_filename"] is not None:
resume_ckpt_path = os.path.join(main_checkpoints_dir, config["checkpoints"]["resume_dirname"], f_model_name, config["checkpoints"]["resume_filename"])
else:
resume_ckpt_path = None
    print ("Resume checkpoint path:", resume_ckpt_path)
# Save Checkpoints
checkpoint_dirpath = os.path.join(main_checkpoints_dir, config["checkpoints"]["checkpoint_dirname"], f_model_name)
os.makedirs(checkpoint_dirpath, exist_ok=True)
# Resuming from checkpoint
    if resume_ckpt_path is not None:
        if os.path.exists(resume_ckpt_path):
            print ("Found the checkpoint at the provided resume_ckpt_path.")
        else:
            raise FileNotFoundError("Resume checkpoint not found at the provided resume_ckpt_path.")
    else:
        if config["train_settings"]["train"]:
            # For Training.
            print ("No resume checkpoint path (resume_ckpt_path) was provided. Starting training from the beginning.")
        else:
            raise ValueError("No resume checkpoint path (resume_ckpt_path) was provided; resume_ckpt_path is required for evaluation.")
# Checkpoint Callbacks
best_checkpoint_callback = ModelCheckpoint(
dirpath=checkpoint_dirpath,
filename="best_model",
monitor=config["train_settings"]["monitor"],
mode=config["train_settings"]["mode"]
)
# Pre-processing Functions
preprocessfn, dual_scale = preprocess.get_preprocessfn(**preprocess_settings)
# Logging
print ()
print (preprocessfn)
print ()
# Datasets
# Images Test Dataset
if config["train_settings"]["train"] == False:
# For images smaller than preprocess_settings["input_image_dimensions"] which only occur for BigGAN fake images in GenImage dataset, we do the following:
"""
- During inference, we avoid Resizing to reduce the effect of resizing artifacts.
- We process the images at (224,224) or their smaller resolution unless the feature extraction model requires (224,224) inputs.
"""
if model_name == "resnet50" or model_name == "hyperiqa" or model_name == "tres" or model_name == "clip-resnet50" or model_name == "clip-vit-l-14":
# Updated Pre-Processing Settings
Fixed_Input_preprocess_settings = preprocess_settings.copy()
Fixed_Input_preprocess_settings["input_image_dimensions"] = (224,224)
# Preprocessing Function
Fixed_Input_preprocessfn, Fixed_Input_dual_scale = preprocess.get_preprocessfn(**Fixed_Input_preprocess_settings)
Test_Dataset = Image_Dataset(
real_images_paths=test_real_images_paths,
fake_images_paths=test_fake_images_paths,
preprocessfn=Fixed_Input_preprocessfn,
dual_scale=Fixed_Input_dual_scale,
resize=preprocess_settings["resize"],
separateAugmentation=separateAugmentation
)
else:
Test_Dataset = Image_Dataset(
real_images_paths=test_real_images_paths,
fake_images_paths=test_fake_images_paths,
preprocessfn=preprocessfn,
dual_scale=dual_scale,
resize=preprocess_settings["resize"],
separateAugmentation=separateAugmentation
)
# DataLoaders
# Test DataLoaders
if config["train_settings"]["train"] == False:
Test_Dataloader = DataLoader(
dataset=Test_Dataset,
batch_size=config["train_settings"]["batch_size"],
num_workers=config["train_settings"]["num_workers"],
shuffle=False,
)
print ("-"*25 + " Datasets and DataLoaders Ready " + "-"*25)
    # Global variable: the frozen backbone feature extractor, consumed inside the LightningModule steps above
global feature_extractor_module
feature_extractor_module = feature_extractor
# Detect device from feature extractor (it's already on the correct device)
device = next(feature_extractor_module.parameters()).device
feature_extractor_module.eval()
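    # Freeze the backbone: it is used purely as a fixed feature extractor, so its weights
    # stay out of autograd and out of the optimizer.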
for params in feature_extractor_module.parameters():
params.requires_grad = False
# Lightning Module
Model = Model_LightningModule(classifier, config)
# PyTorch Lightning Trainer
# Override accelerator based on detected device
trainer_config = config["trainer"].copy()
if device.type == "mps":
trainer_config["accelerator"] = "mps"
trainer_config["devices"] = 1
elif device.type == "cuda":
trainer_config["accelerator"] = "cuda"
trainer_config["devices"] = [device.index] if device.index is not None else [0]
else: # cpu
trainer_config["accelerator"] = "cpu"
trainer_config["devices"] = 1
trainer = pl.Trainer(
**trainer_config,
callbacks=[best_checkpoint_callback, utils.LitProgressBar()],
precision=32
)
# Evaluating
# Predictions on Test Dataset
test_y_pred_y_true = trainer.predict(
model=Model,
dataloaders=Test_Dataloader,
ckpt_path=resume_ckpt_path
)
test_set_metrics = []
y_pred, y_true = concatenate_predictions(y_pred_y_true=test_y_pred_y_true)
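    # Column 1 of y_pred is assumed to hold the score of the fake/generated class; the one-hot
    # y_true labels are collapsed back to class indices for the metric computation.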
y_pred = y_pred[:, 1]
y_true = np.argmax(y_true, axis=1)
ap, acc0, r_acc0, f_acc0, acc1, r_acc1, f_acc1, mcc0, mcc1, _ = calculate_metrics(y_pred=y_pred, y_true=y_true, threshold=best_threshold)
test_set_metrics.append([0, ap, acc0, r_acc0, f_acc0, acc1, r_acc1, f_acc1, mcc0, mcc1])
return test_set_metrics, best_threshold, y_pred, y_true
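
# Minimal usage sketch (hypothetical names; the actual caller lives elsewhere in this repository):
#
#   metrics, threshold, y_pred, y_true = run_on_images(
#       feature_extractor=backbone,                  # frozen torch.nn.Module, already on the target device
#       classifier=classifier_head,                  # trained classifier consumed by Model_LightningModule
#       config=config,                               # dict with the layout sketched near the top of this file
#       test_real_images_paths=real_paths,           # list of paths to real test images
#       test_fake_images_paths=fake_paths,           # list of paths to fake/generated test images
#       preprocess_settings=preprocess_settings,     # kwargs for preprocess.get_preprocessfn
#       best_threshold=0.5,                          # decision threshold, e.g. selected on validation data
#   )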