| |
| |
| |
| |
| |
| |
| |
|
|
| import numpy as np |
| import math |
| import functools |
| import os |
|
|
| import torch |
| import torch.nn as nn |
| from torch.nn import init |
| import torch.optim as optim |
| import torch.nn.functional as F |
|
|
| |
| import sys |
|
|
| sys.path.insert(1, os.path.join(sys.path[0], "..")) |
| import BigGAN_PyTorch.layers as layers |
|
|
| |
| from BigGAN_PyTorch.diffaugment_utils import DiffAugment |
|
|
| |
| |
| |
def G_arch(ch=64, attention="64", ksize="333333", dilation="111111"):
    """Build the generator architecture table, keyed by output resolution.

    Args:
        ch: Base channel width; every channel count is ``ch`` times a fixed
            per-block multiplier.
        attention: Underscore-separated resolutions (for example ``"64"`` or
            ``"32_64"``) at which the ``"attention"`` flag is set.  Empty
            segments are ignored, so ``""`` or a trailing ``"_"`` simply
            selects nothing instead of raising ``ValueError``.
        ksize: Not used by this function; kept for call compatibility.
        dilation: Not used by this function; kept for call compatibility.

    Returns:
        A dict with keys 512, 256, 128, 64 and 32.  Each value is a dict with
        ``"in_channels"``, ``"out_channels"``, ``"upsample"`` and
        ``"resolution"`` lists (one entry per block) plus an ``"attention"``
        dict mapping each block resolution to a bool.
    """
    # Parse the attention spec once instead of once per resolution per entry.
    attn_resolutions = {int(item) for item in attention.split("_") if item}

    # (input multipliers, output multipliers) per target resolution.  The
    # first block always produces an 8x8 map and each block doubles it.
    multipliers = {
        512: ([16, 16, 8, 8, 4, 2, 1], [16, 8, 8, 4, 2, 1, 1]),
        256: ([16, 16, 8, 8, 4, 2], [16, 8, 8, 4, 2, 1]),
        128: ([16, 16, 8, 4, 2], [16, 8, 4, 2, 1]),
        64: ([16, 16, 8, 4], [16, 8, 4, 2]),
        32: ([4, 4, 4], [4, 4, 4]),
    }

    arch = {}
    for target, (in_mult, out_mult) in multipliers.items():
        block_resolutions = [2 ** i for i in range(3, 3 + len(in_mult))]
        arch[target] = {
            "in_channels": [ch * item for item in in_mult],
            "out_channels": [ch * item for item in out_mult],
            "upsample": [True] * len(in_mult),
            "resolution": block_resolutions,
            "attention": {
                res: (res in attn_resolutions) for res in block_resolutions
            },
        }
    return arch
|
|
|
|
class Generator(nn.Module):
    """Image generator assembled from the block list returned by ``G_arch``.

    A latent vector is projected by a linear layer, reshaped to a
    ``bottom_width`` x ``bottom_width`` feature map, run through one
    ``layers.GBlock`` per architecture entry (optionally followed by a
    ``layers.Attention`` layer) and finally mapped to 3 channels and squashed
    with ``tanh``.  A conditioning vector built from class labels and/or
    instance features is handed to every block as its second argument.
    """

    def __init__(
        self,
        G_ch=64,
        dim_z=128,
        bottom_width=4,
        resolution=128,
        G_kernel_size=3,
        G_attn="64",
        n_classes=1000,
        num_G_SVs=1,
        num_G_SV_itrs=1,
        G_shared=True,
        shared_dim=0,
        hier=False,
        cross_replica=False,
        mybn=False,
        G_activation=nn.ReLU(inplace=False),
        G_lr=5e-5,
        G_B1=0.0,
        G_B2=0.999,
        adam_eps=1e-8,
        BN_eps=1e-5,
        SN_eps=1e-12,
        G_mixed_precision=False,
        G_fp16=False,
        G_init="ortho",
        skip_init=False,
        no_optim=False,
        G_param="SN",
        norm_style="bn",
        class_cond=True,
        embedded_optimizer=True,
        instance_cond=False,
        G_shared_feat=True,
        shared_dim_feat=2048,
        **kwargs
    ):
        """Build all layers and, unless disabled, an embedded Adam optimizer.

        Args:
            G_ch: Base channel width handed to ``G_arch``.
            dim_z: Latent size.  With ``hier`` it is rounded down to a
                multiple of the number of latent chunks.
            bottom_width: Side length of the first feature map.
            resolution: Key selecting the entry of the ``G_arch`` table.
            G_kernel_size: Stored as ``self.kernel_size``; the convolutions
                created here always use ``kernel_size=3``.
            G_attn: Attention spec string handed to ``G_arch``.
            n_classes: Number of rows of the class embedding.
            num_G_SVs, num_G_SV_itrs, SN_eps: Forwarded to
                ``layers.SNConv2d`` / ``layers.SNLinear`` when
                ``G_param == "SN"``.
            G_shared: If true, ``self.shared`` is an ``nn.Embedding`` and the
                layers built by ``layers.ccbn`` use a bias-free linear;
                otherwise ``self.shared`` is ``layers.identity()`` and they
                use ``nn.Embedding``.
            shared_dim: Width of the class embedding; values <= 0 fall back
                to ``dim_z``.
            hier: If true, the latent is split into chunks; the first feeds
                the input linear layer and each remaining chunk is
                concatenated to the conditioning vector of one block.
            cross_replica, mybn, norm_style, BN_eps: Forwarded to
                ``layers.ccbn`` (the first two also to ``layers.bn``).
            G_activation: Activation module shared by all blocks and the
                output layer.
            G_lr, G_B1, G_B2, adam_eps: Optimizer hyper-parameters.
            G_mixed_precision: Use ``utils.Adam16`` instead of ``optim.Adam``.
            G_fp16: Stored as ``self.fp16``; not otherwise used here.
            G_init: Weight init style, see ``init_weights``.
            skip_init: Skip the call to ``init_weights``.
            no_optim, embedded_optimizer: No optimizer is created when
                ``no_optim`` is true or ``embedded_optimizer`` is false.
            G_param: ``"SN"`` selects the spectral-norm layer classes from
                ``layers``; anything else selects plain ``nn`` layers.
            class_cond, instance_cond: Select which embeddings contribute to
                the conditioning-vector size given to ``layers.ccbn``.
            G_shared_feat: If true, ``self.shared_feat`` is a linear layer
                from 2048 inputs to ``shared_dim_feat``; else an identity.
            shared_dim_feat: Output width of the feature projection.
            **kwargs: Accepted and ignored.
        """
        super().__init__()

        # --- plain configuration -------------------------------------------
        self.ch = G_ch
        self.dim_z = dim_z
        self.bottom_width = bottom_width
        self.resolution = resolution
        self.kernel_size = G_kernel_size
        self.attention = G_attn
        self.n_classes = n_classes
        self.G_shared = G_shared
        self.shared_dim = shared_dim if shared_dim > 0 else dim_z
        self.hier = hier
        self.cross_replica = cross_replica
        self.mybn = mybn
        # Assigning a module registers it, so this is the first submodule.
        self.activation = G_activation
        self.init = G_init
        self.G_param = G_param
        self.norm_style = norm_style
        self.BN_eps = BN_eps
        self.SN_eps = SN_eps
        self.fp16 = G_fp16
        self.G_shared_feat = G_shared_feat
        self.shared_dim_feat = shared_dim_feat
        self.arch = G_arch(self.ch, self.attention)[resolution]

        # --- latent layout ---------------------------------------------------
        if not self.hier:
            self.num_slots = 1
            self.z_chunk_size = 0
        else:
            # One chunk for the input linear layer plus one per block.
            self.num_slots = len(self.arch["in_channels"]) + 1
            self.z_chunk_size = self.dim_z // self.num_slots
            # Drop the remainder so the latent splits evenly.
            self.dim_z = self.z_chunk_size * self.num_slots

        # --- layer factories -------------------------------------------------
        if self.G_param != "SN":
            self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
            self.which_linear = nn.Linear
        else:
            self.which_conv = functools.partial(
                layers.SNConv2d,
                kernel_size=3,
                padding=1,
                num_svs=num_G_SVs,
                num_itrs=num_G_SV_itrs,
                eps=self.SN_eps,
            )
            self.which_linear = functools.partial(
                layers.SNLinear,
                num_svs=num_G_SVs,
                num_itrs=num_G_SV_itrs,
                eps=self.SN_eps,
            )

        self.which_embedding = nn.Embedding
        if self.G_shared:
            bn_linear = functools.partial(self.which_linear, bias=False)
        else:
            bn_linear = self.which_embedding

        # Size of the conditioning vector each ccbn layer receives.
        if class_cond or instance_cond:
            input_sz_bn = self.z_chunk_size
            if class_cond:
                input_sz_bn += self.shared_dim
            if instance_cond:
                input_sz_bn += self.shared_dim_feat
        else:
            input_sz_bn = self.n_classes

        self.which_bn = functools.partial(
            layers.ccbn,
            which_linear=bn_linear,
            cross_replica=self.cross_replica,
            mybn=self.mybn,
            input_size=input_sz_bn,
            norm_style=self.norm_style,
            eps=self.BN_eps,
        )

        # --- conditioning embeddings and input projection --------------------
        if G_shared:
            self.shared = self.which_embedding(n_classes, self.shared_dim)
        else:
            self.shared = layers.identity()

        if G_shared_feat:
            self.shared_feat = self.which_linear(2048, self.shared_dim_feat)
        else:
            self.shared_feat = layers.identity()

        self.linear = self.which_linear(
            self.dim_z // self.num_slots,
            self.arch["in_channels"][0] * (self.bottom_width ** 2),
        )

        # --- main stack ------------------------------------------------------
        stages = []
        for index, out_ch in enumerate(self.arch["out_channels"]):
            upsample_fn = None
            if self.arch["upsample"][index]:
                upsample_fn = functools.partial(F.interpolate, scale_factor=2)

            stage = [
                layers.GBlock(
                    in_channels=self.arch["in_channels"][index],
                    out_channels=out_ch,
                    which_conv=self.which_conv,
                    which_bn=self.which_bn,
                    activation=self.activation,
                    upsample=upsample_fn,
                )
            ]

            stage_resolution = self.arch["resolution"][index]
            if self.arch["attention"][stage_resolution]:
                print(
                    "Adding attention layer in G at resolution %d" % stage_resolution
                )
                stage.append(layers.Attention(out_ch, self.which_conv))

            stages.append(stage)

        # Nested ModuleLists so every layer is registered as a submodule.
        self.blocks = nn.ModuleList([nn.ModuleList(stage) for stage in stages])

        # --- output head -----------------------------------------------------
        last_ch = self.arch["out_channels"][-1]
        self.output_layer = nn.Sequential(
            layers.bn(last_ch, cross_replica=self.cross_replica, mybn=self.mybn),
            self.activation,
            self.which_conv(last_ch, 3),
        )

        if not skip_init:
            self.init_weights()

        # --- optional embedded optimizer --------------------------------------
        if no_optim or not embedded_optimizer:
            return

        self.lr, self.B1, self.B2, self.adam_eps = G_lr, G_B1, G_B2, adam_eps
        if G_mixed_precision:
            print("Using fp16 adam in G...")
            import utils

            optimizer_cls = utils.Adam16
        else:
            optimizer_cls = optim.Adam
        self.optim = optimizer_cls(
            params=self.parameters(),
            lr=self.lr,
            betas=(self.B1, self.B2),
            weight_decay=0,
            eps=self.adam_eps,
        )

    def init_weights(self):
        """Initialise conv/linear/embedding weights and count their parameters.

        ``self.init`` selects the scheme: ``"ortho"``, ``"N02"`` (normal with
        std 0.02) or ``"glorot"`` / ``"xavier"``.  Any other value prints a
        notice for each affected module and leaves its weights untouched.
        The parameter total of the visited modules is stored in
        ``self.param_count`` and printed.
        """
        self.param_count = 0
        targets = (nn.Conv2d, nn.Linear, nn.Embedding)
        for module in self.modules():
            if not isinstance(module, targets):
                continue
            style = self.init
            if style == "ortho":
                init.orthogonal_(module.weight)
            elif style == "N02":
                init.normal_(module.weight, 0, 0.02)
            elif style in ["glorot", "xavier"]:
                init.xavier_uniform_(module.weight)
            else:
                print("Init style not recognized...")
            self.param_count += sum(p.data.nelement() for p in module.parameters())
        print("Param count for Gs initialized parameters: %d" % self.param_count)

    def get_condition_embeddings(self, cl=None, feat=None):
        """Return the conditioning vector for the given labels and features.

        Args:
            cl: Class labels passed through ``self.shared``, or ``None``.
            feat: Instance features passed through ``self.shared_feat``, or
                ``None``.

        Returns:
            The available embeddings concatenated along the last dimension,
            or an empty list when both inputs are ``None``.
        """
        pieces = []
        if cl is not None:
            pieces.append(self.shared(cl))
        if feat is not None:
            pieces.append(self.shared_feat(feat))
        if not pieces:
            return pieces
        return torch.cat(pieces, dim=-1)

    def forward(self, z, label=None, feats=None):
        """Generate images from latents and optional conditioning.

        Args:
            z: Latent batch; split along dim 1 when ``self.hier`` is set.
            label: Optional class labels.
            feats: Optional instance features.

        Returns:
            ``tanh`` of the output layer applied to the final feature map.
        """
        cond = self.get_condition_embeddings(label, feats)

        if not self.hier:
            ys = [cond] * len(self.blocks)
        else:
            # First chunk drives the input projection; each remaining chunk is
            # appended to the conditioning vector of one block.
            chunks = torch.split(z, self.z_chunk_size, 1)
            z = chunks[0]
            ys = [torch.cat([cond, chunk], 1) for chunk in chunks[1:]]

        h = self.linear(z)
        side = self.bottom_width
        h = h.view(h.size(0), -1, side, side)

        for index, stage in enumerate(self.blocks):
            for layer in stage:
                h = layer(h, ys[index])

        return torch.tanh(self.output_layer(h))
|
|
|
|
| |
def D_arch(ch=64, attention="64", ksize="333333", dilation="111111"):
    """Build the discriminator architecture table, keyed by input resolution.

    Args:
        ch: Base channel width; every channel count except the 3-channel
            image input is ``ch`` times a fixed per-block multiplier.
        attention: Underscore-separated resolutions (for example ``"64"`` or
            ``"32_64"``) at which the ``"attention"`` flag is set.  Empty
            segments are ignored, so ``""`` or a trailing ``"_"`` simply
            selects nothing instead of raising ``ValueError``.
        ksize: Not used by this function; kept for call compatibility.
        dilation: Not used by this function; kept for call compatibility.

    Returns:
        A dict with keys 256, 128, 64 and 32.  Each value is a dict with
        ``"in_channels"``, ``"out_channels"``, ``"downsample"`` and
        ``"resolution"`` lists (one entry per block) plus an ``"attention"``
        dict mapping powers of two to a bool.
    """
    # Parse the attention spec once instead of once per resolution per entry.
    attn_resolutions = {int(item) for item in attention.split("_") if item}

    def attention_flags(stop_exponent):
        # Flags for resolutions 4, 8, ..., 2 ** (stop_exponent - 1).
        return {
            2 ** i: (2 ** i in attn_resolutions) for i in range(2, stop_exponent)
        }

    def scaled(multipliers):
        return [ch * item for item in multipliers]

    arch = {}
    arch[256] = {
        "in_channels": [3] + scaled([1, 2, 4, 8, 8, 16]),
        "out_channels": scaled([1, 2, 4, 8, 8, 16, 16]),
        "downsample": [True] * 6 + [False],
        "resolution": [128, 64, 32, 16, 8, 4, 4],
        "attention": attention_flags(8),
    }
    arch[128] = {
        "in_channels": [3] + scaled([1, 2, 4, 8, 16]),
        "out_channels": scaled([1, 2, 4, 8, 16, 16]),
        "downsample": [True] * 5 + [False],
        "resolution": [64, 32, 16, 8, 4, 4],
        "attention": attention_flags(8),
    }
    arch[64] = {
        "in_channels": [3] + scaled([1, 2, 4, 8]),
        "out_channels": scaled([1, 2, 4, 8, 16]),
        "downsample": [True] * 4 + [False],
        "resolution": [32, 16, 8, 4, 4],
        "attention": attention_flags(7),
    }
    arch[32] = {
        "in_channels": [3] + scaled([4, 4, 4]),
        "out_channels": scaled([4, 4, 4, 4]),
        "downsample": [True, True, False, False],
        "resolution": [16, 16, 16, 16],
        "attention": attention_flags(6),
    }
    return arch
|
|
|
|
class Discriminator(nn.Module):
    """Image discriminator assembled from the block list returned by ``D_arch``.

    The input runs through one ``layers.DBlock`` per architecture entry
    (optionally followed by a ``layers.Attention`` layer).  The activated
    feature map is summed over its two spatial dimensions and a linear layer
    produces the output.  When labels and/or instance features are supplied,
    the inner product between their embedding and the pooled features is
    added to that output.
    """

    def __init__(
        self,
        D_ch=64,
        D_wide=True,
        resolution=128,
        D_kernel_size=3,
        D_attn="64",
        n_classes=1000,
        num_D_SVs=1,
        num_D_SV_itrs=1,
        D_activation=nn.ReLU(inplace=False),
        D_lr=2e-4,
        D_B1=0.0,
        D_B2=0.999,
        adam_eps=1e-8,
        SN_eps=1e-12,
        output_dim=1,
        D_mixed_precision=False,
        D_fp16=False,
        D_init="ortho",
        skip_init=False,
        D_param="SN",
        class_cond=True,
        embedded_optimizer=True,
        instance_cond=False,
        instance_sz=2048,
        **kwargs
    ):
        """Build all layers and, unless disabled, an embedded Adam optimizer.

        Args:
            D_ch: Base channel width handed to ``D_arch``.
            D_wide: Forwarded to ``layers.DBlock`` as ``wide``.
            resolution: Key selecting the entry of the ``D_arch`` table.
            D_kernel_size: Stored as ``self.kernel_size``; the convolutions
                created here always use ``kernel_size=3``.
            D_attn: Attention spec string handed to ``D_arch``.
            n_classes: Number of rows of the class embedding.
            num_D_SVs, num_D_SV_itrs, SN_eps: Forwarded to the spectral-norm
                layer classes when ``D_param == "SN"``.
            D_activation: Activation module shared by all blocks and applied
                before pooling.
            D_lr, D_B1, D_B2, adam_eps: Optimizer hyper-parameters.
            output_dim: Width of the unconditional output.
            D_mixed_precision: Use ``utils.Adam16`` instead of ``optim.Adam``.
            D_fp16: Stored as ``self.fp16``; not otherwise used here.
            D_init: Weight init style, see ``init_weights``.
            skip_init: Skip the call to ``init_weights``.
            D_param: ``"SN"`` selects the spectral-norm layer classes from
                ``layers``; anything else selects plain ``nn`` layers.
            class_cond: Create the class embedding ``self.embed``.
            embedded_optimizer: Create ``self.optim``.
            instance_cond: Create the feature projection ``self.linear_feat``.
            instance_sz: Input width of the feature projection.
            **kwargs: Accepted and ignored.
        """
        super(Discriminator, self).__init__()

        self.ch = D_ch
        self.D_wide = D_wide
        self.resolution = resolution
        self.kernel_size = D_kernel_size
        self.attention = D_attn
        self.n_classes = n_classes
        # Assigning a module registers it, so this is the first submodule.
        self.activation = D_activation
        self.init = D_init
        self.D_param = D_param
        self.SN_eps = SN_eps
        self.fp16 = D_fp16
        self.arch = D_arch(self.ch, self.attention)[resolution]

        # Layer factories.
        if self.D_param == "SN":
            self.which_conv = functools.partial(
                layers.SNConv2d,
                kernel_size=3,
                padding=1,
                num_svs=num_D_SVs,
                num_itrs=num_D_SV_itrs,
                eps=self.SN_eps,
            )
            self.which_linear = functools.partial(
                layers.SNLinear,
                num_svs=num_D_SVs,
                num_itrs=num_D_SV_itrs,
                eps=self.SN_eps,
            )
            self.which_embedding = functools.partial(
                layers.SNEmbedding,
                num_svs=num_D_SVs,
                num_itrs=num_D_SV_itrs,
                eps=self.SN_eps,
            )
        else:
            # Without this branch the factories were never defined and
            # construction failed with AttributeError on ``self.which_conv``.
            # Mirrors the non-SN fallback the Generator uses.
            self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
            self.which_linear = nn.Linear
            self.which_embedding = nn.Embedding

        # Main stack: one DBlock per entry, optionally followed by attention.
        stages = []
        for index, out_ch in enumerate(self.arch["out_channels"]):
            stage = [
                layers.DBlock(
                    in_channels=self.arch["in_channels"][index],
                    out_channels=out_ch,
                    which_conv=self.which_conv,
                    wide=self.D_wide,
                    activation=self.activation,
                    # The first block receives the raw image.
                    preactivation=(index > 0),
                    downsample=(
                        nn.AvgPool2d(2) if self.arch["downsample"][index] else None
                    ),
                )
            ]
            stage_resolution = self.arch["resolution"][index]
            if self.arch["attention"][stage_resolution]:
                print(
                    "Adding attention layer in D at resolution %d" % stage_resolution
                )
                stage.append(layers.Attention(out_ch, self.which_conv))
            stages.append(stage)

        # Nested ModuleLists so every layer is registered as a submodule.
        self.blocks = nn.ModuleList([nn.ModuleList(stage) for stage in stages])

        # Unconditional output head.
        feature_dim = self.arch["out_channels"][-1]
        self.linear = self.which_linear(feature_dim, output_dim)

        # Conditioning heads.  With both kinds of conditioning each head gets
        # half of the feature width, because forward() concatenates the two
        # embeddings before multiplying with the pooled features.
        if class_cond and instance_cond:
            self.linear_feat = self.which_linear(instance_sz, feature_dim // 2)
            self.embed = self.which_embedding(self.n_classes, feature_dim // 2)
        elif class_cond:
            self.embed = self.which_embedding(self.n_classes, feature_dim)
        elif instance_cond:
            self.linear_feat = self.which_linear(instance_sz, feature_dim)

        if not skip_init:
            self.init_weights()

        # Optional embedded optimizer.
        if embedded_optimizer:
            self.lr, self.B1, self.B2, self.adam_eps = D_lr, D_B1, D_B2, adam_eps
            if D_mixed_precision:
                print("Using fp16 adam in D...")
                import utils

                optimizer_cls = utils.Adam16
            else:
                optimizer_cls = optim.Adam
            self.optim = optimizer_cls(
                params=self.parameters(),
                lr=self.lr,
                betas=(self.B1, self.B2),
                weight_decay=0,
                eps=self.adam_eps,
            )

    def init_weights(self):
        """Initialise conv/linear/embedding weights and count their parameters.

        ``self.init`` selects the scheme: ``"ortho"``, ``"N02"`` (normal with
        std 0.02) or ``"glorot"`` / ``"xavier"``.  Any other value prints a
        notice for each affected module and leaves its weights untouched.
        The parameter total of the visited modules is stored in
        ``self.param_count`` and printed.
        """
        self.param_count = 0
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear, nn.Embedding)):
                if self.init == "ortho":
                    init.orthogonal_(module.weight)
                elif self.init == "N02":
                    init.normal_(module.weight, 0, 0.02)
                elif self.init in ["glorot", "xavier"]:
                    init.xavier_uniform_(module.weight)
                else:
                    print("Init style not recognized...")
                self.param_count += sum(
                    p.data.nelement() for p in module.parameters()
                )
        print("Param count for D" "s initialized parameters: %d" % self.param_count)

    def forward(self, x, y=None, feat=None):
        """Score a batch of images, optionally conditioned.

        Args:
            x: Input batch handed to the first block.
            y: Optional class labels; requires ``self.embed``.
            feat: Optional instance features; requires ``self.linear_feat``.

        Returns:
            The linear output on the pooled features, plus (when ``y`` and/or
            ``feat`` is given) the sum over dim 1 of the conditioning
            embedding multiplied element-wise with the pooled features.
        """
        h = x
        for blocklist in self.blocks:
            for block in blocklist:
                h = block(h)

        # Sum over the two spatial dimensions after the activation.
        h = torch.sum(self.activation(h), [2, 3])
        out = self.linear(h)

        if y is not None and feat is not None:
            cond = torch.cat([self.embed(y), self.linear_feat(feat)], dim=-1)
            out = out + torch.sum(cond * h, 1, keepdim=True)
        elif y is not None:
            out = out + torch.sum(self.embed(y) * h, 1, keepdim=True)
        elif feat is not None:
            out = out + torch.sum(self.linear_feat(feat) * h, 1, keepdim=True)
        return out
|
|
|
|
| |
| |
class G_D(nn.Module):
    """Runs a generator and a discriminator in a single forward call."""

    def __init__(self, G, D, optimizer_G=None, optimizer_D=None):
        """Store the two networks and, optionally, their optimizers."""
        super().__init__()
        self.G, self.D = G, D
        self.optimizer_G, self.optimizer_D = optimizer_G, optimizer_D

    def forward(
        self,
        z,
        gy,
        feats_g=None,
        x=None,
        dy=None,
        feats=None,
        train_G=False,
        return_G_z=False,
        split_D=False,
        policy=False,
        DA=False,
    ):
        """Generate samples and score them (and optionally real data) with D.

        Args:
            z: Latents for the generator.
            gy: Labels for the generated samples.
            feats_g: Optional instance features for the generated samples.
            x: Optional batch of real images.
            dy: Labels for ``x``.
            feats: Instance features for ``x``.
            train_G: Whether gradients are enabled while running G.
            return_G_z: When ``x`` is ``None``, also return the generated
                samples.
            split_D: Run D separately on generated and real data instead of
                on one concatenated batch.
            policy: Forwarded to ``DiffAugment`` when ``DA`` is set.
            DA: Apply ``DiffAugment`` to the D input (joint mode only).

        Returns:
            ``(D_fake, D_real)`` when ``x`` is given; otherwise ``D_fake``,
            or ``(D_fake, G_z)`` when ``return_G_z`` is set.
        """
        # G only tracks gradients when it is the network being trained.
        with torch.set_grad_enabled(train_G):
            G_z = self.G(z, gy, feats_g)

        has_real = x is not None

        if split_D:
            fake_score = self.D(G_z, gy, feats_g)
            if has_real:
                real_score = self.D(x, dy, feats)
                return fake_score, real_score
            if return_G_z:
                return fake_score, G_z
            return fake_score

        # Joint mode: stack generated and real inputs into one D batch.
        D_input = torch.cat([G_z, x], 0) if has_real else G_z
        D_class = gy if dy is None else torch.cat([gy, dy], 0)
        if feats_g is None:
            # Real features are only used when generated ones are present.
            D_feats = None
        elif feats is None:
            D_feats = feats_g
        else:
            D_feats = torch.cat([feats_g, feats], 0)

        if DA:
            D_input = DiffAugment(D_input, policy=policy)

        D_out = self.D(D_input, D_class, D_feats)
        if has_real:
            # Undo the concatenation: fake scores first, then real scores.
            return torch.split(D_out, [G_z.shape[0], x.shape[0]])
        if return_G_z:
            return D_out, G_z
        return D_out
|
|