krishnasrikard
Codes
2cda712
# Importing Libraries
import torch
import torch.nn as nn
import timm
from torchinfo import summary
import os, sys, warnings
warnings.filterwarnings('ignore')
from tqdm import tqdm
import joblib, time
import features.ResNet50 as ResNet50
import features.CLIP as CLIP
import features.CONTRIQUE as CONTRIQUE
import features.ReIQA as ReIQA
import features.ARNIQA as ARNIQA
import features.TReS as TReS
import features.HyperIQA as HyperIQA
import defaults
class Classifier_Arch1(nn.Module):
def __init__(self,
input_dim:int,
hidden_layers:list
) -> None:
super().__init__()
layers = []
for i,_ in enumerate(hidden_layers):
if i==0:
layers.append(nn.Linear(in_features=input_dim, out_features=hidden_layers[i]))
layers.append(nn.BatchNorm1d(hidden_layers[i]))
layers.append(nn.GELU())
else:
layers.append(nn.Linear(in_features=hidden_layers[i-1], out_features=hidden_layers[i]))
layers.append(nn.BatchNorm1d(hidden_layers[i]))
layers.append(nn.GELU())
layers.append(nn.Linear(in_features=hidden_layers[i], out_features=2))
self.layers = nn.Sequential(*layers)
def forward(self, x):
x = self.layers(x)
return x
class Classifier_Arch2(nn.Module):
def __init__(self,
input_dim:int,
hidden_layers:list,
) -> None:
super().__init__()
assert len(hidden_layers) == 1, "Invalid Hidden Size"
self.layer1 = nn.Linear(in_features=input_dim, out_features=hidden_layers[0])
self.act_layer1 = nn.ReLU()
self.layer2 = nn.Linear(in_features=hidden_layers[0], out_features=2)
def forward(self, x):
features = self.act_layer1(self.layer1(x))
preds = self.layer2(features)
return features, preds
# Get Model for feature extraction
def get_model(model_name, device):
# Get model
if model_name == "resnet50":
model = ResNet50.Compute_ResNet50(
device=device
)
elif model_name == "clip-resnet50":
model = CLIP.Compute_CLIP(
model_name="RN50",
device=device
)
elif model_name == "clip-vit-l-14":
model = CLIP.Compute_CLIP(
model_name="ViT-L/14",
device=device
)
elif model_name == "contrique":
model = CONTRIQUE.Compute_CONTRIQUE(
model_path=os.path.join(defaults.main_feature_ckpts_dir, "contrique_feature_extractor", "CONTRIQUE_checkpoint25.tar"),
device=device
)
elif model_name == "reiqa":
model = ReIQA.Compute_ReIQA(
model_path=os.path.join(defaults.main_feature_ckpts_dir, "reiqa_feature_extractors"),
device=device
)
elif model_name == "arniqa":
model = ARNIQA.Compute_ARNIQA(
device=device
)
elif model_name == "tres":
model = TReS.Compute_TReS(
model_path=os.path.join(defaults.main_feature_ckpts_dir, "tres_model/bestmodel_1_2021.zip"),
device=device
)
elif model_name == "hyperiqa":
model = HyperIQA.Compute_HyperIQA(
model_path=os.path.join(defaults.main_feature_ckpts_dir, "hyperiqa_model/koniq_pretrained.pkl"),
device=device
)
else:
assert False, f"Update get_model() in extract_features.py to work with the model name: '{model_name}'"
return model
# Calling Main function
if __name__ == '__main__':
None