| | import math |
| | from functools import reduce |
| | from operator import mul |
| | from ipdb import set_trace |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | import torch.nn as nn |
| | from mmcls.models.backbones import VisionTransformer as _VisionTransformer |
| | from mmcls.models.utils import to_2tuple |
| | from mmcv.cnn.bricks.transformer import PatchEmbed |
| | from torch.nn.modules.batchnorm import _BatchNorm |
| |
|
| |
|
def build_2d_sincos_position_embedding(patches_resolution,
                                       embed_dims,
                                       temperature=10000.,
                                       cls_token=False):
    """Build a fixed 2D sine-cosine position embedding for image patches.

    Args:
        patches_resolution (int | tuple[int, int]): Number of patches along
            (height, width). A single int means a square patch grid.
        embed_dims (int): Embedding dimension. Must be divisible by 4 so the
            sin/cos pairs for both axes fit exactly.
        temperature (float): Frequency temperature, as in the original
            Transformer sinusoidal encoding. Defaults to 10000.
        cls_token (bool): If True, prepend an all-zero embedding for the
            class token. Defaults to False.

    Returns:
        torch.Tensor: Position embedding of shape ``(1, h * w, embed_dims)``,
        or ``(1, h * w + 1, embed_dims)`` when ``cls_token=True``.
    """
    if isinstance(patches_resolution, int):
        patches_resolution = (patches_resolution, patches_resolution)

    h, w = patches_resolution
    grid_w = torch.arange(w, dtype=torch.float32)
    grid_h = torch.arange(h, dtype=torch.float32)
    # 'ij' is the historical default of torch.meshgrid; passing it explicitly
    # preserves behavior while silencing the deprecation warning on
    # PyTorch >= 1.10.
    grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing='ij')
    assert embed_dims % 4 == 0, \
        'Embed dimension must be divisible by 4.'
    pos_dim = embed_dims // 4

    omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
    omega = 1. / (temperature**omega)
    # torch.outer replaces the deprecated list-argument form of
    # torch.einsum('m,d->md', [a, b]); the result is identical.
    out_w = torch.outer(grid_w.flatten(), omega)
    out_h = torch.outer(grid_h.flatten(), omega)

    pos_emb = torch.cat(
        [
            torch.sin(out_w),
            torch.cos(out_w),
            torch.sin(out_h),
            torch.cos(out_h),
        ],
        dim=1,
    )[None, :, :]

    if cls_token:
        # The class token carries no spatial information, so it receives an
        # all-zero position embedding.
        cls_token_pe = torch.zeros([1, 1, embed_dims], dtype=torch.float32)
        pos_emb = torch.cat([cls_token_pe, pos_emb], dim=1)

    return pos_emb
| |
|
| |
|
class VisionTransformer(_VisionTransformer):
    """Vision Transformer.

    A pytorch implement of: `An Image is Worth 16x16 Words: Transformers for
    Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_.

    Part of the code is modified from:
    `<https://github.com/facebookresearch/moco-v3/blob/main/vits.py>`_.

    Args:
        stop_grad_conv1 (bool, optional): whether to stop the gradient of
            convolution layer in `PatchEmbed`. Defaults to False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters. Defaults to -1.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Defaults to False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
        **kwargs: Remaining arguments (``arch``, ``img_size``, ``patch_size``,
            ...) forwarded to the mmcls ``VisionTransformer`` constructor.
    """

    # MoCo v3 architectures: 'mocov3-small' deliberately uses 12 heads
    # (instead of the usual 6 of a standard ViT-S).
    arch_zoo = {
        **dict.fromkeys(
            ['mocov3-s', 'mocov3-small'], {
                'embed_dims': 384,
                'num_layers': 12,
                'num_heads': 12,
                'feedforward_channels': 1536,
            }),
        **dict.fromkeys(
            ['b', 'base'], {
                'embed_dims': 768,
                'num_layers': 12,
                'num_heads': 12,
                'feedforward_channels': 3072
            }),
    }

    def __init__(self,
                 stop_grad_conv1=False,
                 frozen_stages=-1,
                 norm_eval=False,
                 init_cfg=None,
                 **kwargs):
        # BUGFIX: forward **kwargs to the parent constructor. Previously only
        # init_cfg was passed, so arch/img_size/patch_size etc. were silently
        # dropped and the backbone was always built with parent defaults.
        super(VisionTransformer, self).__init__(init_cfg=init_cfg, **kwargs)
        # Fall back to 16, the mmcls default patch size, instead of raising
        # KeyError when the caller relies on the parent's default.
        self.patch_size = kwargs.get('patch_size', 16)
        self.frozen_stages = frozen_stages
        self.norm_eval = norm_eval
        self.init_cfg = init_cfg

        if isinstance(self.patch_embed, PatchEmbed):
            if stop_grad_conv1:
                # MoCo v3 trick: freezing the patch projection stabilizes
                # training of ViT backbones.
                self.patch_embed.projection.weight.requires_grad = False
                self.patch_embed.projection.bias.requires_grad = False

        self._freeze_stages()

    def init_weights(self):
        """Initialize position embedding, patch projection, qkv layers and
        the cls token following the MoCo v3 recipe, unless loading pretrained
        weights."""
        super(VisionTransformer, self).init_weights()

        # Use .get() so an init_cfg dict without a 'type' key does not raise.
        if not (isinstance(self.init_cfg, dict)
                and self.init_cfg.get('type') == 'Pretrained'):

            # Replace the learnable position embedding with a fixed 2D
            # sin-cos embedding (frozen below).
            pos_emb = build_2d_sincos_position_embedding(
                patches_resolution=self.patch_resolution,
                embed_dims=self.embed_dims,
                cls_token=True)
            self.pos_embed.data.copy_(pos_emb)
            self.pos_embed.requires_grad = False

            # Xavier-uniform init of the patch projection, treated as a
            # linear layer with 3 * patch_h * patch_w inputs.
            if isinstance(self.patch_embed, PatchEmbed):
                val = math.sqrt(
                    6. / float(3 * reduce(mul, to_2tuple(self.patch_size), 1) +
                               self.embed_dims))
                nn.init.uniform_(self.patch_embed.projection.weight, -val, val)
                nn.init.zeros_(self.patch_embed.projection.bias)

            # Xavier-uniform init of all linear layers; the fused qkv weight
            # is treated as three stacked projections (hence the // 3).
            for name, m in self.named_modules():
                if isinstance(m, nn.Linear):
                    if 'qkv' in name:
                        val = math.sqrt(
                            6. /
                            float(m.weight.shape[0] // 3 + m.weight.shape[1]))
                        nn.init.uniform_(m.weight, -val, val)
                    else:
                        nn.init.xavier_uniform_(m.weight)
                    # Guard against Linear layers created with bias=False.
                    if m.bias is not None:
                        nn.init.zeros_(m.bias)
            nn.init.normal_(self.cls_token, std=1e-6)

    def _freeze_stages(self):
        """Freeze patch_embed layer, some parameters and stages."""
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

            self.cls_token.requires_grad = False
            self.pos_embed.requires_grad = False

        for i in range(1, self.frozen_stages + 1):
            m = self.layers[i - 1]
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

            # Also freeze the final norm when the last stage is frozen.
            if i == (self.num_layers) and self.final_norm:
                for param in getattr(self, 'norm1').parameters():
                    param.requires_grad = False

    def train(self, mode=True):
        """Set train/eval mode, re-apply stage freezing and optionally keep
        BatchNorm layers in eval mode (frozen running stats)."""
        super(VisionTransformer, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # Only BatchNorm variants are affected; LayerNorm keeps
                # training behavior.
                if isinstance(m, _BatchNorm):
                    m.eval()