| import torch |
| import torch.nn as nn |
|
|
| from typing import Any, Tuple, Union |
|
|
| from utils import ( |
| ImageType, |
| crop_image_part, |
| ) |
|
|
| from layers import ( |
| SpectralConv2d, |
| InitLayer, |
| SLEBlock, |
| UpsampleBlockT1, |
| UpsampleBlockT2, |
| DownsampleBlockT1, |
| DownsampleBlockT2, |
| Decoder, |
| ) |
|
|
| from huggan.pytorch.huggan_mixin import HugGANModelHubMixin |
|
|
|
|
class Generator(nn.Module, HugGANModelHubMixin):
    """FastGAN-style generator.

    Upsamples a latent code from 4x4 up to 1024x1024 resolution, with
    skip-layer excitation (SLE) connections feeding low-resolution feature
    maps into the high-resolution blocks. Produces two images: the full
    1024x1024 output and an auxiliary 128x128 output (used by the
    discriminator's small track during training).
    """

    def __init__(self, in_channels: int,
                       out_channels: int):
        super().__init__()

        # Feature-map width at each spatial resolution.
        self._channels = {
            4:    1024,
            8:    512,
            16:   256,
            32:   128,
            64:   128,
            128:  64,
            256:  32,
            512:  16,
            1024: 8,
        }

        ch = self._channels

        # Projects the latent input to a 4x4 feature map.
        self._init = InitLayer(in_channels=in_channels, out_channels=ch[4])

        # Upsampling chain, 4x4 -> 1024x1024 (two block variants alternate).
        self._upsample_8    = UpsampleBlockT2(in_channels=ch[4],   out_channels=ch[8])
        self._upsample_16   = UpsampleBlockT1(in_channels=ch[8],   out_channels=ch[16])
        self._upsample_32   = UpsampleBlockT2(in_channels=ch[16],  out_channels=ch[32])
        self._upsample_64   = UpsampleBlockT1(in_channels=ch[32],  out_channels=ch[64])
        self._upsample_128  = UpsampleBlockT2(in_channels=ch[64],  out_channels=ch[128])
        self._upsample_256  = UpsampleBlockT1(in_channels=ch[128], out_channels=ch[256])
        self._upsample_512  = UpsampleBlockT2(in_channels=ch[256], out_channels=ch[512])
        self._upsample_1024 = UpsampleBlockT1(in_channels=ch[512], out_channels=ch[1024])

        # SLE skips: low-resolution maps modulate high-resolution ones.
        self._sle_64  = SLEBlock(in_channels=ch[4],  out_channels=ch[64])
        self._sle_128 = SLEBlock(in_channels=ch[8],  out_channels=ch[128])
        self._sle_256 = SLEBlock(in_channels=ch[16], out_channels=ch[256])
        self._sle_512 = SLEBlock(in_channels=ch[32], out_channels=ch[512])

        # Output heads: 1x1 conv for the 128px image, 3x3 conv for 1024px,
        # both squashed to [-1, 1] with tanh.
        self._out_128 = nn.Sequential(
            SpectralConv2d(
                in_channels=ch[128],
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                padding='same',
                bias=False,
            ),
            nn.Tanh(),
        )

        self._out_1024 = nn.Sequential(
            SpectralConv2d(
                in_channels=ch[1024],
                out_channels=out_channels,
                kernel_size=3,
                stride=1,
                padding='same',
                bias=False,
            ),
            nn.Tanh(),
        )

    def forward(self, input: torch.Tensor) -> \
        Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(image_1024, image_128)`` generated from ``input``."""
        feat_4  = self._init(input)
        feat_8  = self._upsample_8(feat_4)
        feat_16 = self._upsample_16(feat_8)
        feat_32 = self._upsample_32(feat_16)

        # Each high-resolution map is gated by an earlier low-resolution one.
        feat_64  = self._sle_64(feat_4,   self._upsample_64(feat_32))
        feat_128 = self._sle_128(feat_8,  self._upsample_128(feat_64))
        feat_256 = self._sle_256(feat_16, self._upsample_256(feat_128))
        feat_512 = self._sle_512(feat_32, self._upsample_512(feat_256))

        feat_1024 = self._upsample_1024(feat_512)

        return self._out_1024(feat_1024), self._out_128(feat_128)
|
|
|
|
class Discriminrator(nn.Module, HugGANModelHubMixin):
    """FastGAN-style discriminator with self-supervised decoders.

    Scores a 1024x1024 image (main track) together with its 128x128
    counterpart (small track), concatenating the logits of both heads
    into one flat tensor. For non-fake inputs it additionally returns
    decoded reconstructions used by the auxiliary reconstruction loss.

    NOTE(review): the class name keeps the original (misspelled)
    "Discriminrator" spelling for backward compatibility with callers.
    """

    def __init__(self, in_channels: int):
        super().__init__()

        # Feature-map width at each spatial resolution.
        self._channels = {
            4:    1024,
            8:    512,
            16:   256,
            32:   128,
            64:   128,
            128:  64,
            256:  32,
            512:  16,
            1024: 8,
        }

        ch = self._channels

        # Stem: two strided spectral convolutions take 1024px down to 512px
        # worth of features (each conv halves the spatial size).
        self._init = nn.Sequential(
            SpectralConv2d(
                in_channels=in_channels,
                out_channels=ch[1024],
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.LeakyReLU(negative_slope=0.2),
            SpectralConv2d(
                in_channels=ch[1024],
                out_channels=ch[512],
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(num_features=ch[512]),
            nn.LeakyReLU(negative_slope=0.2),
        )

        # Main-track downsampling chain, 512 -> 16.
        self._downsample_256 = DownsampleBlockT2(in_channels=ch[512], out_channels=ch[256])
        self._downsample_128 = DownsampleBlockT2(in_channels=ch[256], out_channels=ch[128])
        self._downsample_64  = DownsampleBlockT2(in_channels=ch[128], out_channels=ch[64])
        self._downsample_32  = DownsampleBlockT2(in_channels=ch[64],  out_channels=ch[32])
        self._downsample_16  = DownsampleBlockT2(in_channels=ch[32],  out_channels=ch[16])

        # SLE skips: earlier (larger) maps modulate later (smaller) ones.
        self._sle_64 = SLEBlock(in_channels=ch[512], out_channels=ch[64])
        self._sle_32 = SLEBlock(in_channels=ch[256], out_channels=ch[32])
        self._sle_16 = SLEBlock(in_channels=ch[128], out_channels=ch[16])

        # Small track: processes the 128px image down to an 8x8 map.
        self._small_track = nn.Sequential(
            SpectralConv2d(
                in_channels=in_channels,
                out_channels=ch[256],
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.LeakyReLU(negative_slope=0.2),
            DownsampleBlockT1(in_channels=ch[256], out_channels=ch[128]),
            DownsampleBlockT1(in_channels=ch[128], out_channels=ch[64]),
            DownsampleBlockT1(in_channels=ch[64],  out_channels=ch[32]),
        )

        # Logit head for the main track.
        self._features_large = nn.Sequential(
            SpectralConv2d(
                in_channels=ch[16],
                out_channels=ch[8],
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.BatchNorm2d(num_features=ch[8]),
            nn.LeakyReLU(negative_slope=0.2),
            SpectralConv2d(
                in_channels=ch[8],
                out_channels=1,
                kernel_size=4,
                stride=1,
                padding=0,
                bias=False,
            )
        )

        # Logit head for the small track.
        self._features_small = nn.Sequential(
            SpectralConv2d(
                in_channels=ch[32],
                out_channels=1,
                kernel_size=4,
                stride=1,
                padding=0,
                bias=False,
            ),
        )

        # Decoders for the self-supervised reconstruction loss (real images).
        self._decoder_large = Decoder(in_channels=ch[16], out_channels=3)
        self._decoder_small = Decoder(in_channels=ch[32], out_channels=3)
        self._decoder_piece = Decoder(in_channels=ch[32], out_channels=3)

    def forward(self, images_1024: torch.Tensor,
                      images_128: torch.Tensor,
                      image_type: ImageType) -> \
        Union[
            torch.Tensor,
            Tuple[torch.Tensor, Tuple[Any, Any, Any]]
        ]:
        """Score the image pair; for non-fake inputs also return decodings.

        Returns the flat logits tensor, plus — when ``image_type`` is not
        ``ImageType.FAKE`` — a ``(large, small, piece)`` tuple of decoder
        reconstructions.
        """
        # Main track: 1024px input down to a 16x16 map, with SLE skips.
        feat_512 = self._init(images_1024)
        feat_256 = self._downsample_256(feat_512)
        feat_128 = self._downsample_128(feat_256)

        feat_64 = self._sle_64(feat_512, self._downsample_64(feat_128))
        feat_32 = self._sle_32(feat_256, self._downsample_32(feat_64))
        feat_16 = self._sle_16(feat_128, self._downsample_16(feat_32))

        # Small track: 128px input down to an 8x8 map.
        feat_small = self._small_track(images_128)

        # Flatten both heads' logits into a single vector.
        logits = torch.cat(
            [
                self._features_large(feat_16).view(-1),
                self._features_small(feat_small).view(-1),
            ],
            dim=0,
        )

        if image_type == ImageType.FAKE:
            return logits

        # Real images: also decode reconstructions for the auxiliary loss.
        return logits, (
            self._decoder_large(feat_16),
            self._decoder_small(feat_small),
            self._decoder_piece(crop_image_part(feat_32, image_type)),
        )
|
|