Spaces:
Sleeping
Sleeping
from typing import Optional, Type

import torch.nn as nn
def build_normalization(norm_type: str, dim: Optional[int] = None) -> Type[nn.Module]:
    """
    Overview:
        Look up the normalization layer class that corresponds to ``norm_type`` (and ``dim`` where relevant).
        The class itself is returned, not an instance, so the caller still has to instantiate it,
        e.g. ``build_normalization('BN', 2)(num_features)``. For beginners,
        refer to [this article](https://zhuanlan.zhihu.com/p/34879333) to learn more about batch normalization.
    Arguments:
        - norm_type (:obj:`str`): Type of the normalization. Currently supports ['BN', 'LN', 'IN', 'SyncBN'].
        - dim (:obj:`Optional[int]`): Dimension of the normalization. Required when norm_type is in ['BN', 'IN'] \
            (supported values: 1, 2, 3); not used to select the class when norm_type is in ['LN', 'SyncBN'].
    Returns:
        - norm_func (:obj:`Type[nn.Module]`): The corresponding normalization class, e.g. ``nn.BatchNorm2d``.
    Raises:
        - NotImplementedError: If ``dim`` is given and norm_type is not in ['BN', 'IN', 'LN', 'SyncBN'].
        - KeyError: If no class matches, e.g. norm_type 'BN'/'IN' without ``dim``, an unsupported ``dim``, \
            or an unknown norm_type passed without ``dim``.
    """
    if dim is None:
        # Without a dimension the type name alone is the lookup key; 'BN'/'IN' therefore end in KeyError below.
        key = norm_type
    elif norm_type in ('BN', 'IN'):
        # Batch/instance norm have one class per dimension, e.g. 'BN2' -> nn.BatchNorm2d.
        key = norm_type + str(dim)
    elif norm_type in ('LN', 'SyncBN'):
        # A single class serves every ``dim`` here, so ``dim`` is not part of the key.
        key = norm_type
    else:
        raise NotImplementedError("not support indicated dim when creates {}".format(norm_type))
    norm_func = {
        'BN1': nn.BatchNorm1d,
        'BN2': nn.BatchNorm2d,
        'BN3': nn.BatchNorm3d,
        'LN': nn.LayerNorm,
        'IN1': nn.InstanceNorm1d,
        'IN2': nn.InstanceNorm2d,
        'IN3': nn.InstanceNorm3d,
        'SyncBN': nn.SyncBatchNorm,
    }.get(key)
    if norm_func is None:
        raise KeyError("invalid norm type: {}".format(key))
    return norm_func