File size: 2,052 Bytes
2cda712
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# Importing Libraries
import numpy as np
from PIL import Image

import torch
from torchvision import transforms
from torchinfo import summary

import os,sys,warnings
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
warnings.filterwarnings("ignore")
import functions.hyperiqa_functions as hyperiqa_functions
import defaults

class Compute_HyperIQA(torch.nn.Module):
	"""
	Frozen HyperIQA feature extractor.

	Wraps a pretrained `hyperiqa_functions.HyperNet` and, for every forward pass,
	the `hyperiqa_functions.TargetNet` built from its output. Instead of returning
	a quality score, `forward` returns the concatenation of the pooled semantic
	features, the target-network input vector and four intermediate target-network
	outputs, flattened per sample.
	"""
	def __init__(self,
		model_path:str,
		device:str=None
	):
		"""
		Args:
			model_path (str): Path to weights of HyperIQA model.
			device (str): Device used while computing features. When None,
				"cuda" is used if available, otherwise "cpu".
		"""
		super().__init__()
		# Device
		if device is None:
			self.device = "cuda" if torch.cuda.is_available() else "cpu"
		else:
			self.device = device

		# Load HyperIQA Model
		# NOTE(review): the constructor arguments are passed through unchanged; their
		# meaning is defined by hyperiqa_functions.HyperNet — confirm there before editing.
		self.HyperIQA = hyperiqa_functions.HyperNet(16, 112, 224, 112, 56, 28, 14, 7)
		self.HyperIQA = self.HyperIQA.to(self.device)
		# SECURITY: torch.load unpickles the file, which can execute arbitrary code.
		# Only point model_path at checkpoints from a trusted source.
		state_dict = torch.load(model_path, map_location=self.device)
		self.HyperIQA.load_state_dict(state_dict)

		# The network is used purely as a fixed feature extractor: inference mode, no trainable weights
		self.HyperIQA.eval()
		for param in self.HyperIQA.parameters():
			param.requires_grad = False


	def forward(self, img):
		"""
		Compute the HyperIQA feature vector for a batch of images.

		Args:
			img (torch.Tensor): Input batch accepted by the HyperNet. It is moved to
				`self.device` before use, so it may live on any device.
				(presumably shaped (batch, 3, 224, 224) as in this file's smoke test — TODO confirm)

		Returns:
			torch.Tensor: 2-D tensor with one row per sample, made of (in order) the
			flattened pooled semantic features, the flattened target-network input
			vector, and the flattened q1, q2, q3, q4 outputs of the target network.
		"""
		# Avoid a device-mismatch error when the caller's tensor is not on the model's device
		img = img.to(self.device)

		# Parameters for Target Network
		parameters = self.HyperIQA(img)
		semantic_features = parameters["pooled_semantic_features"]
		target_in_vec = parameters["target_in_vec"]

		# Target Network (rebuilt on every call because it is constructed from this batch's HyperNet output)
		target_model = hyperiqa_functions.TargetNet(parameters).to(self.device)
		for param in target_model.parameters():
			param.requires_grad = False

		# The first output is not part of the feature vector; only the remaining four are kept
		_, q1, q2, q3, q4 = target_model(target_in_vec)

		# Flatten everything except the batch dimension and concatenate column-wise
		feat_batch = torch.hstack((
			torch.flatten(semantic_features, start_dim=1),
			torch.flatten(target_in_vec, start_dim=1),
			torch.flatten(q1, start_dim=1),
			torch.flatten(q2, start_dim=1),
			torch.flatten(q3, start_dim=1),
			torch.flatten(q4, start_dim=1),
		))

		return feat_batch
	

# Smoke test: push one random batch through the extractor and print the feature shape
if __name__ == '__main__':
	# Fall back to CPU so the smoke test also runs on machines without CUDA
	device = "cuda:0" if torch.cuda.is_available() else "cpu"
	checkpoint_path = os.path.join("/mnt/LIVELAB2/Detecting-AI-Generated-Images", "feature_extractor_checkpoints/hyperiqa_model/koniq_pretrained.pkl")
	extractor = Compute_HyperIQA(model_path=checkpoint_path, device=device)
	# Call the module itself (not .forward) so nn.Module's hook machinery is not bypassed
	features = extractor(torch.randn(1,3,224,224).to(device))
	print (features.shape)