File size: 1,932 Bytes
2cda712
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# Importing Libraries
import numpy as np
from PIL import Image

import torch
from torchinfo import summary
import clip

import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))


class Compute_CLIP(torch.nn.Module):
	"""Frozen CLIP image encoder wrapped as a `torch.nn.Module`.

	Loads a CLIP backbone through `clip.load`, disables gradients on all of its
	parameters and exposes `clip_model.encode_image` as the forward pass.
	"""

	def __init__(self,
		model_name:str,
		device:"str | None"=None
	):
		"""
		Args:
			model_name (str): CLIP backbone. Use `clip.available_models()` to see all backbones.
			device (str, optional): Device used while computing CLIP features.
				When None, "cuda" is used if available, otherwise "cpu".
		"""
		super().__init__()
		# Device
		if device is None:
			device = "cuda" if torch.cuda.is_available() else "cpu"
		self.device = device

		# Model and Preprocessing Function
		self.clip_model, self.preprocess_fn = clip.load(model_name, device=self.device)
		# The backbone is a fixed feature extractor: no gradients, inference mode.
		for param in self.clip_model.parameters():
			param.requires_grad = False
		self.clip_model.eval()


	def train(self, mode:bool=True):
		"""Set the wrapper's training flag but keep the CLIP backbone in eval mode.

		`requires_grad = False` only stops gradient updates. Without this override,
		calling `.train()` on the wrapper would also put the backbone's layers into
		training behaviour (e.g. batch-statistics normalisation or dropout, if the
		backbone contains such layers), changing the features it produces.

		Args:
			mode (bool): True for training mode, False for evaluation mode.

		Returns:
			Compute_CLIP: self, as `torch.nn.Module.train` does.
		"""
		super().train(mode)
		self.clip_model.eval()
		return self


	def forward(self, img):
		"""Compute CLIP image features.

		Args:
			img: Image batch, passed unchanged to `clip_model.encode_image`.
				Presumably already preprocessed and on `self.device` -- confirm with callers.

		Returns:
			Whatever `clip_model.encode_image` returns for `img`.
		"""
		return self.clip_model.encode_image(image=img)


class CLIP_Classifier(torch.nn.Module):
	"""Frozen CLIP image encoder followed by a trainable single-output linear head.

	Only `self.classifier` has trainable parameters; every parameter of
	`self.clip_model` has `requires_grad` set to False.
	"""

	def __init__(self,
		model_name:str,
		device:"str | None"=None
	):
		"""
		Args:
			model_name (str): CLIP backbone. Use `clip.available_models()` to see all backbones.
			device (str, optional): Device used while computing CLIP features.
				When None, "cuda" is used if available, otherwise "cpu".
		"""
		super().__init__()
		# Device
		if device is None:
			device = "cuda" if torch.cuda.is_available() else "cpu"
		self.device = device

		# Model and Preprocessing Function
		self.clip_model, self.preprocess_fn = clip.load(model_name, device=self.device)
		# The backbone is a fixed feature extractor: no gradients, inference mode.
		for param in self.clip_model.parameters():
			param.requires_grad = False
		self.clip_model.eval()

		# Linear head, sized to the backbone's embedding and created on the same
		# device so that forward() does not hit a CPU/GPU mismatch.
		feature_dim = self._feature_dim(self.clip_model, model_name)
		self.classifier = torch.nn.Linear(feature_dim, 1).to(self.device)


	@staticmethod
	def _feature_dim(clip_model, model_name:str) -> int:
		"""Return the width of the image embedding produced by `clip_model`.

		Reads `clip_model.visual.output_dim` when it is present, so the head
		matches whichever backbone was loaded. If the attribute is missing, the
		original hard-coded rule is used: 1024 for "RN50", 768 for anything else.

		Args:
			clip_model: Loaded CLIP model.
			model_name (str): Backbone name, used only for the fallback rule.

		Returns:
			int: Number of input features for the linear head.
		"""
		visual = getattr(clip_model, "visual", None)
		output_dim = getattr(visual, "output_dim", None)
		if isinstance(output_dim, int) and output_dim > 0:
			return output_dim
		return 1024 if model_name == "RN50" else 768


	def train(self, mode:bool=True):
		"""Set the training flag on the head but keep the CLIP backbone in eval mode.

		`requires_grad = False` only stops gradient updates. Without this override,
		calling `.train()` would also put the backbone's layers into training
		behaviour (e.g. batch-statistics normalisation or dropout, if the backbone
		contains such layers), changing the features fed to the head.

		Args:
			mode (bool): True for training mode, False for evaluation mode.

		Returns:
			CLIP_Classifier: self, as `torch.nn.Module.train` does.
		"""
		super().train(mode)
		self.clip_model.eval()
		return self


	def forward(self, img):
		"""Compute one output value per image.

		Args:
			img: Image batch, passed unchanged to `clip_model.encode_image`.
				Presumably already preprocessed and on `self.device` -- confirm with callers.

		Returns:
			torch.Tensor: Output of the linear head; its last dimension has size 1.
		"""
		# Cast to float32 so the features match the dtype of the linear head.
		f = self.clip_model.encode_image(image=img).to(torch.float32)
		return self.classifier(f)
	

# Smoke test: embed one random image-sized tensor and print the feature shape
if __name__ == '__main__':
	# Fall back to CPU so the check also runs on machines without a GPU.
	device = "cuda:0" if torch.cuda.is_available() else "cpu"
	F = Compute_CLIP(model_name="ViT-L/14", device=device)
	# Call the module itself rather than .forward() so Module.__call__ hooks run.
	O = F(torch.randn(1,3,224,224, device=device))
	print (O.shape)