import random
import numpy as np
from math import sqrt

import mxnet as mx
import mxnet.ndarray as nd
# Explicit import so `nn` does not depend only on the star import below.
from mxnet.gluon import nn

from modules import *
# pylint: disable-all


class Generator(nn.HybridBlock):
    """Progressive synthesis network of the style-based generator.

    Holds nine ``StyledConvBlock`` stages (4x4 up to 1024x1024, see the
    per-line comments) and, for each stage, a 1x1 ``EqualConv2d`` that
    projects its features to 3 (RGB) channels.

    NOTE(review): ``hybrid_forward`` calls ``nd`` directly and reads
    ``.shape``, so this block presumably only works imperatively (not
    after ``hybridize()``) — confirm before hybridizing.

    Args:
        fused: forwarded to the upsampling blocks at 128px and above.
        blur: forwarded to every ``StyledConvBlock``.
    """

    def __init__(self, fused=True, blur=False):
        super().__init__()

        self.progression = nn.HybridSequential()
        with self.progression.name_scope():
            self.progression.add(StyledConvBlock(512, 512, 3, 1, initial=True, blur=blur))  # 4
            self.progression.add(StyledConvBlock(512, 512, 3, 1, upsample=True, blur=blur))  # 8
            self.progression.add(StyledConvBlock(512, 512, 3, 1, upsample=True, blur=blur))  # 16
            self.progression.add(StyledConvBlock(512, 512, 3, 1, upsample=True, blur=blur))  # 32
            self.progression.add(StyledConvBlock(512, 256, 3, 1, upsample=True, blur=blur))  # 64
            self.progression.add(StyledConvBlock(256, 128, 3, 1, upsample=True, fused=fused, blur=blur))  # 128
            self.progression.add(StyledConvBlock(128, 64, 3, 1, upsample=True, fused=fused, blur=blur))  # 256
            self.progression.add(StyledConvBlock(64, 32, 3, 1, upsample=True, fused=fused, blur=blur))  # 512
            self.progression.add(StyledConvBlock(32, 16, 3, 1, upsample=True, fused=fused, blur=blur))  # 1024

        # One RGB projection per stage; input channels match that stage's output.
        self.to_rgb = nn.HybridSequential()
        with self.to_rgb.name_scope():
            self.to_rgb.add(EqualConv2d(512, 3, 1))
            self.to_rgb.add(EqualConv2d(512, 3, 1))
            self.to_rgb.add(EqualConv2d(512, 3, 1))
            self.to_rgb.add(EqualConv2d(512, 3, 1))
            self.to_rgb.add(EqualConv2d(256, 3, 1))
            self.to_rgb.add(EqualConv2d(128, 3, 1))
            self.to_rgb.add(EqualConv2d(64, 3, 1))
            self.to_rgb.add(EqualConv2d(32, 3, 1))
            self.to_rgb.add(EqualConv2d(16, 3, 1))

    def hybrid_forward(self, F, style, noise, step=0, alpha=-1, mixing_range=(-1, -1)):
        """Run stages ``0..step`` and return the RGB output of stage ``step``.

        Args:
            F: operator namespace supplied by Gluon.
            style: stacked style codes; ``style[k]`` is the k-th style and
                ``style.shape[0]`` is the number of styles.
            noise: per-stage noise indexed by stage; only ``noise[0..step]``
                is read. ``noise[0]`` is also the input of the first stage.
                Each entry is copied onto ``style[0]``'s context.
            step: index of the last stage to run.
            alpha: when ``0 <= alpha < 1`` (and ``step > 0``) the output is
                ``(1 - alpha) * upsampled_previous_rgb + alpha * rgb``.
            mixing_range: ``(-1, -1)`` selects random style mixing at
                sampled crossover stages; otherwise stages in the inclusive
                range ``mixing_range[0]..mixing_range[1]`` use ``style[1]``
                and all others use ``style[0]``.

        Returns:
            The RGB image produced at stage ``step``.
        """
        ctx = style[0].context
        out = nd.array(noise[0], ctx=ctx)

        n_style = style.shape[0]
        if n_style < 2:
            # Sentinel past the last stage: the crossover never advances.
            inject_index = [len(self.progression) + 1]
        else:
            # Crossover points must be ascending because the loop below walks
            # through them in order. Only `step` candidate stages exist, so
            # clamp the sample size (random.sample raises otherwise).
            n_cross = min(n_style - 1, step)
            inject_index = sorted(random.sample(list(range(step)), n_cross))

        random_mixing = tuple(mixing_range) == (-1, -1)
        crossover = 0
        out_prev = None

        for i, (conv, to_rgb) in enumerate(zip(self.progression, self.to_rgb)):
            if random_mixing:
                if crossover < len(inject_index) and i > inject_index[crossover]:
                    crossover = min(crossover + 1, n_style)
                style_step = style[crossover]
            elif mixing_range[0] <= i <= mixing_range[1]:
                style_step = style[1]
            else:
                style_step = style[0]

            if i > 0 and step > 0:
                # Keep the previous stage's features for the fade-in below.
                out_prev = out

            out = conv(out, style_step, nd.array(noise[i], ctx=ctx))

            if i == step:
                out = to_rgb(out)

                if i > 0 and 0 <= alpha < 1:
                    # Blend with the previous stage's RGB, upsampled 2x.
                    skip_rgb = self.to_rgb[i - 1](out_prev)
                    skip_rgb = F.UpSampling(skip_rgb, scale=2, sample_type='nearest')
                    out = (1 - alpha) * skip_rgb + alpha * out

                break

        return out


class StyledGenerator(nn.HybridBlock):
    r"""Style-based GAN
    Reference:

    Tero Karras, Samuli Laine, Timo Aila. "A Style-Based Generator
    Architecture for Generative Adversarial Networks." *CVPR*, 2019

    Args:
        code_dim: width of the latent code and of every mapping-MLP layer.
        n_mlp: number of ``EqualLinear`` + ``LeakyReLU(0.2)`` layers after
            ``PixelNorm`` in the mapping network.
        blur: forwarded to the synthesis ``Generator``.
        fused: forwarded to the synthesis ``Generator`` (default matches
            ``Generator``'s own default).
    """
    def __init__(self, code_dim=512, n_mlp=8, blur=False, fused=True):
        super().__init__()

        self.generator = Generator(fused=fused, blur=blur)

        self.style = nn.HybridSequential()

        with self.style.name_scope():
            self.style.add(PixelNorm())

            for _ in range(n_mlp):
                self.style.add(EqualLinear(code_dim, code_dim))
                self.style.add(nn.LeakyReLU(0.2))

    def hybrid_forward(self, F, x, step=0, alpha=-1, noise=None, mean_style=None,
                       style_weight=0, mixing_range=(-1, -1)):
        """Map latent code(s) to styles and synthesize an image.

        Args:
            F: operator namespace supplied by Gluon.
            x: a latent batch, or a list/tuple of latent batches (one per
                style used for mixing).
            step, alpha, mixing_range: forwarded to ``Generator``.
            noise: per-stage noise; when ``None``, stage ``i`` gets
                ``randn(batch, 1, 4 * 2**i, 4 * 2**i)`` on ``x[0]``'s context.
            mean_style: if given, each style becomes
                ``mean_style + style_weight * (style - mean_style)``.
            style_weight: interpolation weight used with ``mean_style``.

        Returns:
            Whatever ``Generator`` returns for stage ``step``.
        """
        if not isinstance(x, (list, tuple)):
            x = [x]

        styles = [self.style(latent) for latent in x]

        batch = x[0].shape[0]

        if noise is None:
            ctx = x[0].context
            noise = [nd.random.randn(batch, 1, 4 * 2 ** i, 4 * 2 ** i, ctx=ctx)
                     for i in range(step + 1)]

        if mean_style is not None:
            styles = [mean_style + style_weight * (style - mean_style)
                      for style in styles]

        # Stack on the styles' own context/dtype (nd.empty would allocate on
        # the default context, dragging the noise there too in Generator).
        nd_styles = nd.stack(*styles, axis=0)

        return self.generator(nd_styles, noise, step, alpha, mixing_range)

    def mean_style(self, x):
        """Return the mapped style of ``x`` averaged over axis 0 (kept as a
        leading axis of size 1)."""
        style = self.style(x).mean(axis=0, keepdims=True)

        return style


class Discriminator(nn.HybridBlock):
    """Progressive discriminator mirroring ``Generator``.

    ``progression`` runs from the 1024px input stage down to a final
    ``ConvBlock(513, 512, 3, 1, 4, 0)`` whose extra input channel is the
    minibatch standard-deviation feature; ``from_rgb[k]`` feeds
    ``progression[k]``.

    Args:
        fused: forwarded to the downsampling blocks at 128px and above.
        from_rgb_activate: if true, each from-RGB 1x1 conv is followed by
            ``LeakyReLU(0.2)``.
    """

    def __init__(self, fused=True, from_rgb_activate=False):
        super().__init__()

        self.progression = nn.HybridSequential()
        with self.progression.name_scope():
            self.progression.add(ConvBlock(16, 32, 3, 1, downsample=True, fused=fused))  # 512
            self.progression.add(ConvBlock(32, 64, 3, 1, downsample=True, fused=fused))  # 256
            self.progression.add(ConvBlock(64, 128, 3, 1, downsample=True, fused=fused))  # 128
            self.progression.add(ConvBlock(128, 256, 3, 1, downsample=True, fused=fused))  # 64
            self.progression.add(ConvBlock(256, 512, 3, 1, downsample=True))  # 32
            self.progression.add(ConvBlock(512, 512, 3, 1, downsample=True))  # 16
            self.progression.add(ConvBlock(512, 512, 3, 1, downsample=True))  # 8
            self.progression.add(ConvBlock(512, 512, 3, 1, downsample=True))  # 4
            self.progression.add(ConvBlock(513, 512, 3, 1, 4, 0))

        def make_from_rgb(out_channel):
            # 1x1 projection from 3 input channels, optionally activated.
            if from_rgb_activate:
                module = nn.HybridSequential()
                with module.name_scope():
                    module.add(EqualConv2d(3, out_channel, 1))
                    module.add(nn.LeakyReLU(0.2))
                return module

            return EqualConv2d(3, out_channel, 1)

        self.from_rgb = nn.HybridSequential()
        with self.from_rgb.name_scope():
            self.from_rgb.add(make_from_rgb(16))
            self.from_rgb.add(make_from_rgb(32))
            self.from_rgb.add(make_from_rgb(64))
            self.from_rgb.add(make_from_rgb(128))
            self.from_rgb.add(make_from_rgb(256))
            self.from_rgb.add(make_from_rgb(512))
            self.from_rgb.add(make_from_rgb(512))
            self.from_rgb.add(make_from_rgb(512))
            self.from_rgb.add(make_from_rgb(512))

        self.n_layer = len(self.progression)

        self.linear = EqualLinear(512, 1)

    def hybrid_forward(self, F, x, step=0, alpha=-1):
        """Score images at the resolution selected by ``step``.

        Args:
            F: operator namespace supplied by Gluon.
            x: input images (NCHW, 3 channels — assumed from ``from_rgb``).
            step: number of downsampling stages to run before the final block.
            alpha: when ``0 <= alpha < 1`` (and ``step > 0``) the first
                stage's output is blended with the from-RGB features of a
                2x average-pooled ``x``.

        Returns:
            Output of ``self.linear`` on the flattened final features.
        """
        for i in range(step, -1, -1):
            index = self.n_layer - i - 1

            if i == step:
                out = self.from_rgb[index](x)

            if i == 0:
                # Minibatch standard deviation: per-position std over the
                # batch, averaged to one scalar and appended as a channel.
                out_mean = nd.mean(out, axis=0, keepdims=True)
                out_var = (out - out_mean) ** 2
                out_std = F.sqrt(nd.mean(out_var, axis=0) + 1e-8)
                # .mean() yields shape (1,); broadcast_to needs matching rank.
                mean_std = out_std.mean().reshape((1, 1, 1, 1))
                mean_std = mean_std.broadcast_to(
                    (out.shape[0], 1, out.shape[2], out.shape[3]))
                out = F.Concat(out, mean_std, dim=1)

            out = self.progression[index](out)

            if i > 0 and i == step and 0 <= alpha < 1:
                # Fade-in from the next-lower resolution path.
                skip_rgb = F.Pooling(x, kernel=(2, 2), stride=(2, 2), pool_type='avg')
                skip_rgb = self.from_rgb[index + 1](skip_rgb)
                out = (1 - alpha) * skip_rgb + alpha * out

        # Drop the two trailing spatial axes (size 1 after the last block).
        out = F.squeeze(out, axis=2)
        out = F.squeeze(out, axis=2)

        out = self.linear(out)

        return out