import numpy as np
from math import sqrt

import mxnet as mx
import mxnet.ndarray as nd
import mxnet.gluon.nn as nn
from numpy import prod
# pylint: disable=all


def get_weight_key(module):
    weight_key = None
    for k in module.params.keys():
        if 'weight' in k:
            weight_key = k

    return weight_key


def compute_weight(weight_orig):
    # He-style scaling: fan_in = in_channels * (kernel height * kernel width).
    fan_in = weight_orig.shape[1] * weight_orig[0][0].size

    return weight_orig * sqrt(2 / (fan_in + 1e-8))


class FusedUpsample(nn.HybridBlock):
    """Equalized-lr transposed convolution with 2x upsampling fused into the kernel."""

    def __init__(self, in_channel, out_channel, kernel_size, padding=0):
        super().__init__()

        fan_in = in_channel * kernel_size * kernel_size
        self.multiplier = sqrt(2 / (fan_in + 1e-8))
        # Deconvolution weight layout is (in_channels, num_filter, kh, kw).
        self.weight = self.params.get('weight', shape=(in_channel, out_channel, kernel_size, kernel_size),
                                      init=mx.init.Normal(sigma=1))
        self.bias = self.params.get('bias', shape=(out_channel,), init=mx.init.Zero())
        self.pad = (padding, padding)

    def hybrid_forward(self, F, x, **kwargs):
        # Zero-pad the kernel by one pixel on each side, then average the four
        # shifted copies; this fuses the upsampling blur into the deconvolution.
        weight = F.pad(kwargs['weight'] * self.multiplier, mode='constant',
                       constant_value=0, pad_width=(0, 0, 0, 0, 1, 1, 1, 1))
        weight = (
            weight[:, :, 1:, 1:]
            + weight[:, :, :-1, 1:]
            + weight[:, :, 1:, :-1]
            + weight[:, :, :-1, :-1]
        ) / 4

        out = F.Deconvolution(x, weight, kwargs['bias'], kernel=weight.shape[-2:], stride=(2, 2),
                              pad=self.pad, num_filter=weight.shape[1], no_bias=False)
        return out


class FusedDownsample(nn.HybridBlock):
    """Equalized-lr strided convolution with 2x downsampling fused into the kernel."""

    def __init__(self, in_channel, out_channel, kernel_size, padding=0):
        super().__init__()

        # Convolution weight layout is (num_filter, in_channels, kh, kw).
        self.weight = self.params.get('weight', shape=(out_channel, in_channel, kernel_size, kernel_size),
                                      grad_req='write', init=mx.init.Normal(sigma=1))
        self.bias = self.params.get('bias', shape=(out_channel,), grad_req='write', init=mx.init.Zero())
        fan_in = in_channel * kernel_size * kernel_size
        self.multiplier = sqrt(2 / (fan_in + 1e-8))
        self.pad = (padding, padding)

    def hybrid_forward(self, F, x, **kwargs):
        weight = F.pad(kwargs['weight'] * self.multiplier, mode='constant',
                       constant_value=0, pad_width=(0, 0, 0, 0, 1, 1, 1, 1))
        weight = (
            weight[:, :, 1:, 1:]
            + weight[:, :, :-1, 1:]
            + weight[:, :, 1:, :-1]
            + weight[:, :, :-1, :-1]
        ) / 4

        out = F.Convolution(x, weight, kwargs['bias'], kernel=weight.shape[-2:], stride=(2, 2),
                            pad=self.pad, num_filter=weight.shape[0], no_bias=False)
        return out


class PixelNorm(nn.HybridBlock):
    """Normalizes each spatial position's feature vector to unit RMS along channels."""

    def __init__(self):
        super().__init__()

    def hybrid_forward(self, F, x):
        return x / F.sqrt(F.mean(x ** 2, axis=1, keepdims=True) + 1e-8)
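
# Illustrative shape-check sketch (an addition, not part of the original model
# code). It assumes imperative (non-hybridized) execution on the default CPU
# context and only exercises the blocks defined above.
def _demo_fused_blocks():
    x = nd.random.normal(shape=(1, 8, 16, 16))

    norm = PixelNorm()
    assert norm(x).shape == x.shape  # channel-wise normalization keeps the shape

    up = FusedUpsample(8, 4, kernel_size=3, padding=1)
    up.initialize()
    # The 3x3 kernel is padded to 4x4; stride-2 deconvolution doubles H and W:
    # (16 - 1) * 2 - 2 * 1 + 4 = 32.
    assert up(x).shape == (1, 4, 32, 32)

    down = FusedDownsample(8, 4, kernel_size=3, padding=1)
    down.initialize()
    # Stride-2 convolution with the padded 4x4 kernel halves H and W:
    # (16 + 2 * 1 - 4) // 2 + 1 = 8.
    assert down(x).shape == (1, 4, 8, 8)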
class Blur(nn.HybridBlock):
    """Depthwise 3x3 binomial blur with a fixed (non-trainable) kernel."""

    def __init__(self, channel):
        super().__init__()

        self.channel = channel

        weight = np.array([[1, 2, 1], [2, 4, 2], [1, 2, 1]], dtype=np.float32)
        weight = weight.reshape((1, 1, 3, 3))
        weight = weight / (weight.sum() + 1e-8)

        self.weight = nd.array(weight).tile((channel, 1, 1, 1))
        weight_flip = np.flip(weight, 2)
        self.weight_flip = nd.array(weight_flip).tile((channel, 1, 1, 1))

    def hybrid_forward(self, F, x, **kwargs):
        # The kernel lives outside the parameter dict, so it is copied to the
        # input's context on every call; this block only runs imperatively.
        weight = nd.array(self.weight, ctx=x.context)
        output = F.Convolution(x, weight, kernel=self.weight.shape[-2:], pad=(1, 1),
                               num_filter=self.channel, num_group=self.channel, no_bias=True)
        return output


class EqualConv2d(nn.HybridBlock):
    """2D convolution with equalized learning rate (He scaling applied at runtime)."""

    def __init__(self, in_dim, out_dim, kernel, padding=0):
        super().__init__()

        with self.name_scope():
            self.weight = self.params.get('weight_orig', shape=(out_dim, in_dim, kernel, kernel),
                                          grad_req='write', init=mx.init.Normal(1))
            self.bias = self.params.get('bias', shape=(out_dim,), grad_req='write', init=mx.init.Zero())
            self.kernel = (kernel, kernel)
            self.channel = out_dim
            self.padding = (padding, padding)

    def hybrid_forward(self, F, x, **kwargs):
        size = kwargs['weight'].shape
        fan_in = prod(size[1:])
        multiplier = sqrt(2.0 / fan_in)

        out = F.Convolution(x, kwargs['weight'] * multiplier, kwargs['bias'], kernel=self.kernel,
                            pad=self.padding, num_filter=self.channel)
        return out


class EqualLinear(nn.HybridBlock):
    """Fully connected layer with equalized learning rate."""

    def __init__(self, in_dim, out_dim):
        super().__init__()

        self.weight = self.params.get('weight_orig', shape=(out_dim, in_dim), grad_req='write',
                                      init=mx.init.Normal(1))
        self.bias = self.params.get('bias', shape=(out_dim,), grad_req='write', init=mx.init.Zero())
        self.num_hidden = out_dim

    def hybrid_forward(self, F, x, **kwargs):
        size = kwargs['weight'].shape
        fan_in = prod(size[1:])
        multiplier = sqrt(2.0 / fan_in)

        out = F.FullyConnected(x, kwargs['weight'] * multiplier, kwargs['bias'], num_hidden=self.num_hidden)

        return out


class AdaptiveInstanceNorm(nn.HybridBlock):
    """Instance norm modulated by per-channel (gamma, beta) computed from a style vector."""

    def __init__(self, in_channel, style_dim):
        super().__init__()

        self.norm = nn.InstanceNorm(in_channels=in_channel)
        self.style = EqualLinear(style_dim, in_channel * 2)
        self.style.initialize()

        # Initialize the style bias so gamma starts at 1 and beta at 0.
        mx_params = self.style.collect_params()
        for k in mx_params.keys():
            if 'bias' in k:
                mx_params[k].data()[:in_channel] = 1
                mx_params[k].data()[in_channel:] = 0

    def hybrid_forward(self, F, x, style, **kwargs):
        style = self.style(style).expand_dims(2).expand_dims(3)
        gamma, beta = style.split(2, 1)
        out = self.norm(x)
        out = gamma * out + beta

        return out


class NoiseInjection(nn.HybridBlock):
    """Adds per-channel scaled noise; the scale is learned and starts at zero."""

    def __init__(self, channel):
        super().__init__()

        self.weight = self.params.get('weight_orig', shape=(1, channel, 1, 1), init=mx.init.Zero())

    def hybrid_forward(self, F, image, noise, **kwargs):
        new_weight = compute_weight(kwargs['weight'])

        return image + new_weight * noise


class ConstantInput(nn.HybridBlock):
    """Learned constant tensor that serves as the generator's first input."""

    def __init__(self, channel, size=4):
        super().__init__()

        self.input = self.params.get('input', shape=(1, channel, size, size), init=mx.init.Normal(sigma=1))

    def hybrid_forward(self, F, x, **kwargs):
        # x is only read for its batch size.
        batch = x.shape[0]
        out = kwargs['input'].tile((batch, 1, 1, 1))

        return out
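
# Illustrative AdaIN sketch (an addition, not original code): a 512-d style
# vector is mapped to per-channel (gamma, beta) pairs that modulate the
# instance-normalized features. Assumes imperative execution; only the
# InstanceNorm child still needs initialization, since AdaptiveInstanceNorm
# initializes its EqualLinear in __init__.
def _demo_adain():
    content = nd.random.normal(shape=(2, 64, 8, 8))
    style = nd.random.normal(shape=(2, 512))

    adain = AdaptiveInstanceNorm(in_channel=64, style_dim=512)
    adain.norm.initialize()
    out = adain(content, style)
    assert out.shape == content.shape  # modulation is per channel, shape unchanged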
class ConvBlock(nn.HybridBlock):
    """Two equalized convolutions with LeakyReLU, optionally downsampling 2x."""

    def __init__(self, in_channel, out_channel, kernel_size, padding, kernel_size2=None,
                 padding2=None, downsample=False, fused=False):
        super().__init__()

        pad1 = padding
        pad2 = padding if padding2 is None else padding2

        kernel1 = kernel_size
        kernel2 = kernel_size if kernel_size2 is None else kernel_size2

        self.conv1 = nn.HybridSequential()
        with self.conv1.name_scope():
            self.conv1.add(EqualConv2d(in_channel, out_channel, kernel1, padding=pad1))
            self.conv1.add(nn.LeakyReLU(0.2))

        self.conv2 = nn.HybridSequential()
        with self.conv2.name_scope():
            if downsample and fused:
                self.conv2.add(FusedDownsample(out_channel, out_channel, kernel2, padding=pad2))
                self.conv2.add(nn.LeakyReLU(0.2))
            elif downsample:
                self.conv2.add(EqualConv2d(out_channel, out_channel, kernel2, padding=pad2))
                self.conv2.add(nn.AvgPool2D(pool_size=(2, 2)))
                self.conv2.add(nn.LeakyReLU(0.2))
            else:
                self.conv2.add(EqualConv2d(out_channel, out_channel, kernel2, padding=pad2))
                self.conv2.add(nn.LeakyReLU(0.2))

    def hybrid_forward(self, F, x):
        out = self.conv1(x)
        out = self.conv2(out)

        return out


class StyledConvBlock(nn.HybridBlock):
    """Style-modulated block: (optional upsample) conv -> noise -> LeakyReLU -> AdaIN, twice."""

    def __init__(self, in_channel, out_channel, kernel_size=3, padding=1, style_dim=512,
                 initial=False, upsample=False, fused=False, blur=False):
        super().__init__()

        self.upsample = None
        if initial:
            # The first generator block starts from a learned constant input.
            self.conv1 = ConstantInput(in_channel)
        elif upsample:
            if fused:
                self.upsample = 'fused'
                self.conv1 = nn.HybridSequential()
                with self.conv1.name_scope():
                    self.conv1.add(FusedUpsample(in_channel, out_channel, kernel_size, padding=padding))
                    if blur:
                        self.conv1.add(Blur(out_channel))
            else:
                self.upsample = 'nearest'
                self.conv1 = nn.HybridSequential()
                with self.conv1.name_scope():
                    self.conv1.add(EqualConv2d(in_dim=in_channel, out_dim=out_channel,
                                               kernel=kernel_size, padding=padding))
                    if blur:
                        self.conv1.add(Blur(out_channel))
        else:
            self.conv1 = EqualConv2d(in_dim=in_channel, out_dim=out_channel,
                                     kernel=kernel_size, padding=padding)

        self.noise1 = NoiseInjection(out_channel)
        self.adain1 = AdaptiveInstanceNorm(out_channel, style_dim)
        self.lrelu1 = nn.LeakyReLU(0.2)

        self.conv2 = EqualConv2d(in_dim=out_channel, out_dim=out_channel, kernel=kernel_size, padding=padding)
        self.noise2 = NoiseInjection(out_channel)
        self.adain2 = AdaptiveInstanceNorm(out_channel, style_dim)
        self.lrelu2 = nn.LeakyReLU(0.2)

    def hybrid_forward(self, F, x, style, noise):
        # Nearest-neighbor upsampling happens outside conv1; the fused path
        # upsamples inside FusedUpsample itself.
        if self.upsample == 'nearest':
            x = F.UpSampling(x, scale=2, sample_type='nearest')
        out = self.conv1(x)
        out = self.noise1(out, noise)
        out = self.lrelu1(out)
        out = self.adain1(out, style)

        out = self.conv2(out)
        out = self.noise2(out, noise)
        out = self.lrelu2(out)
        out = self.adain2(out, style)

        return out
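
# Usage sketch (an addition, not part of the original file): chains the initial
# 4x4 generator block into one nearest-neighbor upsampling block. Assumes
# imperative execution and in_channel == out_channel for the initial block,
# matching how its callers construct it. Calling initialize() skips the
# AdaIN style parameters already initialized in __init__ (with a warning).
if __name__ == '__main__':
    _demo_fused_blocks()
    _demo_adain()

    style = nd.random.normal(shape=(1, 512))

    block0 = StyledConvBlock(512, 512, kernel_size=3, padding=1, initial=True)
    block0.initialize()
    noise0 = nd.random.normal(shape=(1, 1, 4, 4))
    # The first argument is only read for its batch size by ConstantInput.
    out0 = block0(nd.zeros((1, 1)), style, noise0)
    print(out0.shape)  # (1, 512, 4, 4)

    block1 = StyledConvBlock(512, 256, kernel_size=3, padding=1, upsample=True)
    block1.initialize()
    noise1 = nd.random.normal(shape=(1, 1, 8, 8))
    out1 = block1(out0, style, noise1)
    print(out1.shape)  # (1, 256, 8, 8)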