| import flax.linen as nn |
| from jaxtyping import Array, ArrayLike |
|
|
|
|
| class ConvNeXtBlock(nn.Module): |
| """ConvNext block. See Fig.4 in "A ConvNet for the 2020s" by Liu et al. |
| |
| https://openaccess.thecvf.com/content/CVPR2022/papers/Liu_A_ConvNet_for_the_2020s_CVPR_2022_paper.pdf |
| """ |
| n_dims: int = 64 |
| kernel_size: int = 3 |
| group_features: bool = False |
|
|
| def setup(self) -> None: |
| self.residual = nn.Sequential([ |
| nn.Conv(self.n_dims, kernel_size=(self.kernel_size, self.kernel_size), use_bias=False, |
| feature_group_count=self.n_dims if self.group_features else 1), |
| nn.LayerNorm(), |
| nn.Conv(4 * self.n_dims, kernel_size=(1, 1)), |
| nn.gelu, |
| nn.Conv(self.n_dims, kernel_size=(1, 1)), |
| ]) |
|
|
| def __call__(self, x: ArrayLike) -> Array: |
| return x + self.residual(x) |
|
|
|
|
| class Projection(nn.Module): |
| n_dims: int |
|
|
| @nn.compact |
| def __call__(self, x: ArrayLike) -> Array: |
| x = nn.LayerNorm()(x) |
| x = nn.Conv(self.n_dims, (1, 1))(x) |
| return x |
|
|
|
|
| class ConvNeXt(nn.Module): |
| block_defs: list[tuple] |
|
|
| def setup(self) -> None: |
| layers = [] |
| current_size = self.block_defs[0][0] |
| for block_def in self.block_defs: |
| if block_def[0] != current_size: |
| layers.append(Projection(block_def[0])) |
| layers.append(ConvNeXtBlock(*block_def)) |
| current_size = block_def[0] |
| self.layers = layers |
|
|
| def __call__(self, x: ArrayLike, _: bool) -> Array: |
| for layer in self.layers: |
| x = layer(x) |
| return x |
|
|
|
|