Spaces:
Running
on
Zero
Running
on
Zero
| # Importing Libraries | |
| import torch | |
| import torch.nn as nn | |
| import timm | |
| from torchinfo import summary | |
| import os, sys, warnings | |
| warnings.filterwarnings('ignore') | |
| from tqdm import tqdm | |
| import joblib, time | |
| import features.ResNet50 as ResNet50 | |
| import features.CLIP as CLIP | |
| import features.CONTRIQUE as CONTRIQUE | |
| import features.ReIQA as ReIQA | |
| import features.ARNIQA as ARNIQA | |
| import features.TReS as TReS | |
| import features.HyperIQA as HyperIQA | |
| import defaults | |
| class Classifier_Arch1(nn.Module): | |
| def __init__(self, | |
| input_dim:int, | |
| hidden_layers:list | |
| ) -> None: | |
| super().__init__() | |
| layers = [] | |
| for i,_ in enumerate(hidden_layers): | |
| if i==0: | |
| layers.append(nn.Linear(in_features=input_dim, out_features=hidden_layers[i])) | |
| layers.append(nn.BatchNorm1d(hidden_layers[i])) | |
| layers.append(nn.GELU()) | |
| else: | |
| layers.append(nn.Linear(in_features=hidden_layers[i-1], out_features=hidden_layers[i])) | |
| layers.append(nn.BatchNorm1d(hidden_layers[i])) | |
| layers.append(nn.GELU()) | |
| layers.append(nn.Linear(in_features=hidden_layers[i], out_features=2)) | |
| self.layers = nn.Sequential(*layers) | |
| def forward(self, x): | |
| x = self.layers(x) | |
| return x | |
| class Classifier_Arch2(nn.Module): | |
| def __init__(self, | |
| input_dim:int, | |
| hidden_layers:list, | |
| ) -> None: | |
| super().__init__() | |
| assert len(hidden_layers) == 1, "Invalid Hidden Size" | |
| self.layer1 = nn.Linear(in_features=input_dim, out_features=hidden_layers[0]) | |
| self.act_layer1 = nn.ReLU() | |
| self.layer2 = nn.Linear(in_features=hidden_layers[0], out_features=2) | |
| def forward(self, x): | |
| features = self.act_layer1(self.layer1(x)) | |
| preds = self.layer2(features) | |
| return features, preds | |
| # Get Model for feature extraction | |
| def get_model(model_name, device): | |
| # Get model | |
| if model_name == "resnet50": | |
| model = ResNet50.Compute_ResNet50( | |
| device=device | |
| ) | |
| elif model_name == "clip-resnet50": | |
| model = CLIP.Compute_CLIP( | |
| model_name="RN50", | |
| device=device | |
| ) | |
| elif model_name == "clip-vit-l-14": | |
| model = CLIP.Compute_CLIP( | |
| model_name="ViT-L/14", | |
| device=device | |
| ) | |
| elif model_name == "contrique": | |
| model = CONTRIQUE.Compute_CONTRIQUE( | |
| model_path=os.path.join(defaults.main_feature_ckpts_dir, "contrique_feature_extractor", "CONTRIQUE_checkpoint25.tar"), | |
| device=device | |
| ) | |
| elif model_name == "reiqa": | |
| model = ReIQA.Compute_ReIQA( | |
| model_path=os.path.join(defaults.main_feature_ckpts_dir, "reiqa_feature_extractors"), | |
| device=device | |
| ) | |
| elif model_name == "arniqa": | |
| model = ARNIQA.Compute_ARNIQA( | |
| device=device | |
| ) | |
| elif model_name == "tres": | |
| model = TReS.Compute_TReS( | |
| model_path=os.path.join(defaults.main_feature_ckpts_dir, "tres_model/bestmodel_1_2021.zip"), | |
| device=device | |
| ) | |
| elif model_name == "hyperiqa": | |
| model = HyperIQA.Compute_HyperIQA( | |
| model_path=os.path.join(defaults.main_feature_ckpts_dir, "hyperiqa_model/koniq_pretrained.pkl"), | |
| device=device | |
| ) | |
| else: | |
| assert False, f"Update get_model() in extract_features.py to work with the model name: '{model_name}'" | |
| return model | |
| # Calling Main function | |
| if __name__ == '__main__': | |
| None |