Update vit_patch_generator.py
Browse files- vit_patch_generator.py +0 -19
vit_patch_generator.py
CHANGED
|
@@ -119,10 +119,6 @@ class ViTPatchGenerator(nn.Module):
|
|
| 119 |
'pos_embed',
|
| 120 |
]
|
| 121 |
|
| 122 |
-
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
|
| 123 |
-
if self.abs_pos:
|
| 124 |
-
self._load_embed(state_dict[f'{prefix}pos_embed'], self.pos_embed)
|
| 125 |
-
|
| 126 |
def _load_embed(self, src_embed: torch.Tensor, targ_embed: nn.Parameter):
|
| 127 |
if src_embed.shape != targ_embed.shape:
|
| 128 |
src_size = int(math.sqrt(src_embed.shape[1]))
|
|
@@ -285,18 +281,3 @@ class ViTPatchLinear(nn.Linear):
|
|
| 285 |
**factory
|
| 286 |
)
|
| 287 |
self.patch_size = patch_size
|
| 288 |
-
|
| 289 |
-
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
|
| 290 |
-
if self.bias is not None:
|
| 291 |
-
self.bias.data.copy_(state_dict[f'{prefix}bias'])
|
| 292 |
-
|
| 293 |
-
chk_weight = state_dict[f'{prefix}weight']
|
| 294 |
-
if chk_weight.shape != self.weight.shape:
|
| 295 |
-
src_patch_size = int(math.sqrt(chk_weight.shape[1] // 3))
|
| 296 |
-
|
| 297 |
-
assert (src_patch_size ** 2) * 3 == chk_weight.shape[1], 'Unable to interpolate non-square patch size'
|
| 298 |
-
|
| 299 |
-
chk_weight = rearrange(chk_weight, 'b (c h w) -> b c h w', c=3, h=src_patch_size, w=src_patch_size)
|
| 300 |
-
chk_weight = F.interpolate(chk_weight, size=(self.patch_size, self.patch_size), mode='bicubic', align_corners=True, antialias=False)
|
| 301 |
-
chk_weight = rearrange(chk_weight, 'b c h w -> b (c h w)')
|
| 302 |
-
self.weight.data.copy_(chk_weight)
|
|
|
|
| 119 |
'pos_embed',
|
| 120 |
]
|
| 121 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
def _load_embed(self, src_embed: torch.Tensor, targ_embed: nn.Parameter):
|
| 123 |
if src_embed.shape != targ_embed.shape:
|
| 124 |
src_size = int(math.sqrt(src_embed.shape[1]))
|
|
|
|
| 281 |
**factory
|
| 282 |
)
|
| 283 |
self.patch_size = patch_size
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|