detectNanoBananaImage / functions /contrique_functions.py
krishnasrikard
Codes
2cda712
# Importing Libaries
import torch
import torch.nn as nn
import torchvision
def get_network(name, pretrained=False):
network = {
"VGG16": torchvision.models.vgg16(pretrained=pretrained),
"VGG16_bn": torchvision.models.vgg16_bn(pretrained=pretrained),
"resnet18": torchvision.models.resnet18(pretrained=pretrained),
"resnet34": torchvision.models.resnet34(pretrained=pretrained),
"resnet50": torchvision.models.resnet50(pretrained=pretrained),
"resnet101": torchvision.models.resnet101(pretrained=pretrained),
"resnet152": torchvision.models.resnet152(pretrained=pretrained),
}
if name not in network.keys():
raise KeyError(f"{name} is not a valid network architecture")
return network[name]
class CONTRIQUE_Model(nn.Module):
def __init__(self,
encoder,
n_features,
patch_dim = (2,2),
normalize = True,
projection_dim = 128
):
"""
ResNet50 model for Projector
"""
super(CONTRIQUE_Model, self).__init__()
self.normalize = normalize
self.encoder = nn.Sequential(*list(encoder.children())[:-2])
self.n_features = n_features
self.patch_dim = patch_dim
self.avgpool = nn.AdaptiveAvgPool2d((1,1))
self.avgpool_patch = nn.AdaptiveAvgPool2d(patch_dim)
# MLP for Projector
self.projector = nn.Sequential(
nn.Linear(self.n_features, self.n_features, bias=False),
nn.BatchNorm1d(self.n_features),
nn.ReLU(),
nn.Linear(self.n_features, projection_dim, bias=False),
nn.BatchNorm1d(projection_dim),
)
def forward(self, x_i, x_j):
# global features
h_i = self.encoder(x_i)
h_j = self.encoder(x_j)
# local features
h_i_patch = self.avgpool_patch(h_i)
h_j_patch = self.avgpool_patch(h_j)
h_i_patch = h_i_patch.reshape(-1,self.n_features, self.patch_dim[0]*self.patch_dim[1])
h_j_patch = h_j_patch.reshape(-1,self.n_features, self.patch_dim[0]*self.patch_dim[1])
h_i_patch = torch.transpose(h_i_patch,2,1)
h_i_patch = h_i_patch.reshape(-1, self.n_features)
h_j_patch = torch.transpose(h_j_patch,2,1)
h_j_patch = h_j_patch.reshape(-1, self.n_features)
h_i = self.avgpool(h_i)
h_j = self.avgpool(h_j)
h_i = h_i.view(-1, self.n_features)
h_j = h_j.view(-1, self.n_features)
if self.normalize:
h_i = nn.functional.normalize(h_i, dim=1)
h_j = nn.functional.normalize(h_j, dim=1)
h_i_patch = nn.functional.normalize(h_i_patch, dim=1)
h_j_patch = nn.functional.normalize(h_j_patch, dim=1)
# Global Projections
z_i = self.projector(h_i)
z_j = self.projector(h_j)
# Local Projections
z_i_patch = self.projector(h_i_patch)
z_j_patch = self.projector(h_j_patch)
return z_i, z_j, z_i_patch, z_j_patch, h_i, h_j, h_i_patch, h_j_patch