# Importing Libraries import numpy as np from PIL import Image import torch from torchvision import transforms from torchinfo import summary import os,sys,warnings sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) warnings.filterwarnings("ignore") import functions.hyperiqa_functions as hyperiqa_functions import defaults class Compute_HyperIQA(torch.nn.Module): def __init__(self, model_path:str, device:str ): """ Args: model_path (str): Path to weights of HyperIQA model. device (str): Device used while computing features. """ super().__init__() # Device if device is None: self.device = "cuda" if torch.cuda.is_available() else "cpu" else: self.device = device # Load HyperIQA Model self.HyperIQA = hyperiqa_functions.HyperNet(16, 112, 224, 112, 56, 28, 14, 7) self.HyperIQA = self.HyperIQA.to(self.device) self.HyperIQA.load_state_dict(torch.load(model_path, map_location=self.device)) self.HyperIQA.eval() for param in self.HyperIQA.parameters(): param.requires_grad = False def forward(self, img): # Parameters for Target Network parameters = self.HyperIQA(img) semantic_features = parameters["pooled_semantic_features"] # Target Network target_model = hyperiqa_functions.TargetNet(parameters).to(self.device) for param in target_model.parameters(): param.requires_grad = False q,q1,q2,q3,q4 = target_model(parameters["target_in_vec"]) feat_batch = torch.hstack(( torch.flatten(semantic_features, start_dim=1), torch.flatten(parameters["target_in_vec"], start_dim=1), torch.flatten(q1, start_dim=1), torch.flatten(q2, start_dim=1), torch.flatten(q3, start_dim=1), torch.flatten(q4, start_dim=1), )) return feat_batch # Calling Main function if __name__ == '__main__': F = Compute_HyperIQA(model_path=os.path.join("/mnt/LIVELAB2/Detecting-AI-Generated-Images", "feature_extractor_checkpoints/hyperiqa_model/koniq_pretrained.pkl"), device="cuda:0") O = F.forward(torch.randn(1,3,224,224).cuda()) print (O.shape)