import numpy as np
from math import sqrt

import mxnet as mx
import mxnet.ndarray as nd
import mxnet.gluon.nn as nn
from numpy import prod
# pylint: disable-all


def get_weight_key(module):
    # Return the (last) parameter key of `module` whose name contains 'weight'.
    weight_key = None
    for k in module.params.keys():
        if 'weight' in k:
            weight_key = k

    return weight_key


def compute_weight(weight_orig):
    # Equalized learning-rate scaling: fan_in is in_channels times the per-filter kernel size.
    fan_in = weight_orig.shape[1] * weight_orig[0][0].size

    return weight_orig * sqrt(2 / (fan_in + 1e-8))


class FusedUpsample(nn.HybridBlock):
    def __init__(self, in_channel, out_channel, kernel_size, padding=0):
        super().__init__()

        fan_in = in_channel * kernel_size * kernel_size
        self.multiplier = sqrt(2 / fan_in)
        # Deconvolution expects the weight laid out as (in_channel, out_channel, kh, kw).
        self.weight = self.params.get('weight', shape=(in_channel, out_channel, kernel_size, kernel_size),
                                      init=mx.init.Normal(sigma=1))
        self.bias = self.params.get('bias', shape=(out_channel,), init=mx.init.Zero())
        self.pad = (padding, padding)

    def hybrid_forward(self, F, x, **kwargs):
        # Pad the kernel by one pixel on each side and average the four shifted copies,
        # fusing a 2x upsample into a single stride-2 transposed convolution.
        weight = F.pad(kwargs['weight'] * self.multiplier, mode='constant',
                       constant_value=0, pad_width=(0, 0, 0, 0, 1, 1, 1, 1))
        weight = (
            weight[:, :, 1:, 1:]
            + weight[:, :, :-1, 1:]
            + weight[:, :, 1:, :-1]
            + weight[:, :, :-1, :-1]
        ) / 4

        out = F.Deconvolution(x, weight, kwargs['bias'], kernel=weight.shape[-2:], stride=(2, 2),
                              pad=self.pad, num_filter=weight.shape[1], no_bias=False)
        return out
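# Illustrative usage (a minimal sketch, assuming imperative, non-hybridized execution on the
# default CPU context; the channel and spatial sizes below are example values):
#   up = FusedUpsample(in_channel=64, out_channel=32, kernel_size=3, padding=1)
#   up.initialize()
#   y = up(nd.random.normal(shape=(1, 64, 8, 8)))
#   # the stride-2 transposed convolution doubles the spatial size -> (1, 32, 16, 16)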


class FusedDownsample(nn.HybridBlock):
    def __init__(self, in_channel, out_channel, kernel_size, padding=0):
        super().__init__()

        # Convolution expects the weight laid out as (num_filter, in_channel, kh, kw).
        self.weight = self.params.get('weight', shape=(out_channel, in_channel, kernel_size, kernel_size),
                                      grad_req='write', init=mx.init.Normal(sigma=1))
        self.bias = self.params.get('bias', shape=(out_channel,), grad_req='write', init=mx.init.Zero())
        fan_in = in_channel * kernel_size * kernel_size
        self.multiplier = sqrt(2 / (fan_in + 1e-8))
        self.pad = (padding, padding)

    def hybrid_forward(self, F, x, **kwargs):
        # Pad the kernel by one pixel on each side and average the four shifted copies,
        # fusing a 2x downsample into a single stride-2 convolution.
        weight = F.pad(kwargs['weight'] * self.multiplier, mode='constant',
                       constant_value=0, pad_width=(0, 0, 0, 0, 1, 1, 1, 1))
        weight = (
            weight[:, :, 1:, 1:]
            + weight[:, :, :-1, 1:]
            + weight[:, :, 1:, :-1]
            + weight[:, :, :-1, :-1]
        ) / 4

        out = F.Convolution(x, weight, kwargs['bias'], kernel=weight.shape[-2:], stride=(2, 2),
                            pad=self.pad, num_filter=weight.shape[0], no_bias=False)
        return out
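# Illustrative usage (a minimal sketch, assuming imperative execution on a CPU context;
# example values only):
#   down = FusedDownsample(in_channel=32, out_channel=64, kernel_size=3, padding=1)
#   down.initialize()
#   y = down(nd.random.normal(shape=(1, 32, 16, 16)))  # stride-2 convolution halves the spatial size -> (1, 64, 8, 8)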


class PixelNorm(nn.HybridBlock):
    def __init__(self):
        super().__init__()

    def hybrid_forward(self, F, x):
        # Normalize each spatial position to unit RMS across the channel axis.
        return x / F.sqrt(F.mean(x ** 2, axis=1, keepdims=True) + 1e-8)


class Blur(nn.HybridBlock):
    def __init__(self, channel):
        super().__init__()

        self.channel = channel

        # Fixed, non-trainable 3x3 binomial blur kernel, one copy per channel
        # (applied as a depthwise convolution). The kernel is kept as a plain NDArray,
        # so this block is meant to run imperatively (not hybridized).
        weight = np.array([[1, 2, 1], [2, 4, 2], [1, 2, 1]], dtype=np.float32)
        weight = weight.reshape((1, 1, 3, 3))
        weight = weight / (weight.sum() + 1e-8)

        self.weight = nd.array(weight).tile((channel, 1, 1, 1))
        weight_flip = np.flip(weight, 2)
        self.weight_flip = nd.array(weight_flip).tile((channel, 1, 1, 1))  # flipped kernel, unused in the forward pass

    def hybrid_forward(self, F, x, **kwargs):
        weight = self.weight.as_in_context(x.context)
        output = F.Convolution(x, weight, kernel=self.weight.shape[-2:], pad=(1, 1),
                               num_filter=self.channel, num_group=self.channel, no_bias=True)
        return output


class EqualConv2d(nn.HybridBlock):
    def __init__(self, in_dim, out_dim, kernel, padding=0):
        super().__init__()

        with self.name_scope():
            self.weight = self.params.get('weight_orig', shape=(out_dim, in_dim, kernel, kernel), grad_req='write',
                                          init=mx.init.Normal(1))
            self.bias = self.params.get('bias', shape=(out_dim,), grad_req='write', init=mx.init.Zero())
            self.kernel = (kernel, kernel)
            self.channel = out_dim
            self.padding = (padding, padding)

    def hybrid_forward(self, F, x, **kwargs):
        # Equalized learning rate: keep the stored weight at unit variance and apply
        # the He-initialization scale at runtime.
        size = kwargs['weight'].shape
        fan_in = prod(size[1:])
        multiplier = sqrt(2.0 / fan_in)

        out = F.Convolution(x, kwargs['weight'] * multiplier, kwargs['bias'], kernel=self.kernel, pad=self.padding,
                            num_filter=self.channel)
        return out
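# Illustrative usage (a minimal sketch, imperative execution assumed; example values only):
#   conv = EqualConv2d(in_dim=3, out_dim=8, kernel=3, padding=1)
#   conv.initialize()
#   y = conv(nd.random.normal(shape=(1, 3, 16, 16)))  # -> (1, 8, 16, 16)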


class EqualLinear(nn.HybridBlock):
    def __init__(self, in_dim, out_dim):
        super().__init__()

        self.weight = self.params.get('weight_orig', shape=(out_dim, in_dim), grad_req='write', init=mx.init.Normal(1))
        self.bias = self.params.get('bias', shape=(out_dim,), grad_req='write', init=mx.init.Zero())
        self.num_hidden = out_dim

    def hybrid_forward(self, F, x, **kwargs):
        # Equalized learning rate, as in EqualConv2d.
        size = kwargs['weight'].shape
        fan_in = prod(size[1:])
        multiplier = sqrt(2.0 / fan_in)

        out = F.FullyConnected(x, kwargs['weight'] * multiplier, kwargs['bias'], num_hidden=self.num_hidden)

        return out


class AdaptiveInstanceNorm(nn.HybridBlock):
    def __init__(self, in_channel, style_dim):
        super().__init__()

        self.norm = nn.InstanceNorm(in_channels=in_channel)
        self.style = EqualLinear(style_dim, in_channel * 2)
        self.style.initialize()

        # Start from an identity transform: the first half of the bias (gamma) is 1,
        # the second half (beta) is 0.
        mx_params = self.style.collect_params()
        for k in mx_params.keys():
            if 'bias' in k:
                mx_params[k].data()[:in_channel] = 1
                mx_params[k].data()[in_channel:] = 0

    def hybrid_forward(self, F, x, style, **kwargs):
        style = self.style(style).expand_dims(2).expand_dims(3)
        gamma, beta = style.split(num_outputs=2, axis=1)
        out = self.norm(x)
        out = gamma * out + beta

        return out
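# Illustrative usage (a minimal sketch, imperative execution assumed; example values only):
#   adain = AdaptiveInstanceNorm(in_channel=64, style_dim=512)
#   adain.initialize()  # the style sub-block is already initialized in __init__
#   feat = nd.random.normal(shape=(2, 64, 8, 8))
#   w = nd.random.normal(shape=(2, 512))
#   out = adain(feat, w)  # -> (2, 64, 8, 8)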


class NoiseInjection(nn.HybridBlock):
    def __init__(self, channel):
        super().__init__()

        # Per-channel noise strength, zero-initialized so the block starts as an identity.
        self.weight = self.params.get('weight_orig', shape=(1, channel, 1, 1), init=mx.init.Zero())

    def hybrid_forward(self, F, image, noise, **kwargs):
        new_weight = compute_weight(kwargs['weight'])

        return image + new_weight * noise


class ConstantInput(nn.HybridBlock):
    def __init__(self, channel, size=4):
        super().__init__()

        # Learned constant input of the synthesis network; only the batch size of `x` is used.
        self.input = self.params.get('input', shape=(1, channel, size, size), init=mx.init.Normal(sigma=1))

    def hybrid_forward(self, F, x, **kwargs):
        batch = x.shape[0]
        out = kwargs['input'].tile((batch, 1, 1, 1))

        return out


class ConvBlock(nn.HybridBlock):
    def __init__(self, in_channel, out_channel, kernel_size, padding, kernel_size2=None,
                 padding2=None, downsample=False, fused=False):
        super().__init__()

        pad1 = padding
        pad2 = padding if padding2 is None else padding2

        kernel1 = kernel_size
        kernel2 = kernel_size if kernel_size2 is None else kernel_size2

        self.conv1 = nn.HybridSequential()
        with self.conv1.name_scope():
            self.conv1.add(EqualConv2d(in_channel, out_channel, kernel1, padding=pad1))
            self.conv1.add(nn.LeakyReLU(0.2))

        if downsample:
            if fused:
                # Fused stride-2 convolution for downsampling.
                self.conv2 = nn.HybridSequential()
                with self.conv2.name_scope():
                    self.conv2.add(FusedDownsample(out_channel, out_channel, kernel2, padding=pad2))
                    self.conv2.add(nn.LeakyReLU(0.2))

            else:
                # Plain convolution followed by average pooling.
                self.conv2 = nn.HybridSequential()
                with self.conv2.name_scope():
                    self.conv2.add(EqualConv2d(out_channel, out_channel, kernel2, padding=pad2))
                    self.conv2.add(nn.AvgPool2D(pool_size=(2, 2)))
                    self.conv2.add(nn.LeakyReLU(0.2))

        else:
            self.conv2 = nn.HybridSequential()
            with self.conv2.name_scope():
                self.conv2.add(EqualConv2d(out_channel, out_channel, kernel2, padding=pad2))
                self.conv2.add(nn.LeakyReLU(0.2))

    def hybrid_forward(self, F, x):
        out = self.conv1(x)
        out = self.conv2(out)

        return out
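# Illustrative usage (a minimal sketch, imperative execution assumed; example values only):
#   blk = ConvBlock(32, 64, kernel_size=3, padding=1, downsample=True)
#   blk.initialize()
#   y = blk(nd.random.normal(shape=(1, 32, 16, 16)))  # -> (1, 64, 8, 8)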


class StyledConvBlock(nn.HybridBlock):
    def __init__(self, in_channel, out_channel, kernel_size=3, padding=1, style_dim=512,
                 initial=False, upsample=False, fused=False, blur=False):

        super().__init__()

        self.upsample = None
        if initial:
            # First block of the synthesis network: start from a learned constant.
            self.conv1 = ConstantInput(in_channel)
        else:
            if upsample:
                if fused:
                    # Upsampling is fused into the transposed convolution itself.
                    self.upsample = 'fused'
                    self.conv1 = nn.HybridSequential()
                    with self.conv1.name_scope():
                        self.conv1.add(FusedUpsample(in_channel, out_channel, kernel_size, padding=padding))
                        if blur:
                            self.conv1.add(Blur(out_channel))
                else:
                    # Nearest-neighbour upsampling is applied in hybrid_forward before conv1.
                    self.upsample = 'nearest'
                    self.conv1 = nn.HybridSequential()
                    with self.conv1.name_scope():
                        self.conv1.add(EqualConv2d(in_dim=in_channel, out_dim=out_channel,
                                                   kernel=kernel_size, padding=padding))
                        if blur:
                            self.conv1.add(Blur(out_channel))
            else:
                self.conv1 = EqualConv2d(in_dim=in_channel, out_dim=out_channel,
                                         kernel=kernel_size, padding=padding)

        self.noise1 = NoiseInjection(out_channel)
        self.adain1 = AdaptiveInstanceNorm(out_channel, style_dim)
        self.lrelu1 = nn.LeakyReLU(0.2)

        self.conv2 = EqualConv2d(in_dim=out_channel, out_dim=out_channel, kernel=kernel_size, padding=padding)
        self.noise2 = NoiseInjection(out_channel)
        self.adain2 = AdaptiveInstanceNorm(out_channel, style_dim)
        self.lrelu2 = nn.LeakyReLU(0.2)

    def hybrid_forward(self, F, x, style, noise):
        # Upsample (the fused path upsamples inside FusedUpsample).
        if self.upsample == 'nearest':
            x = F.UpSampling(x, scale=2, sample_type='nearest')
        out = self.conv1(x)
        out = self.noise1(out, noise)
        out = self.lrelu1(out)
        out = self.adain1(out, style)

        out = self.conv2(out)
        out = self.noise2(out, noise)
        out = self.lrelu2(out)
        out = self.adain2(out, style)

        return out
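

if __name__ == '__main__':
    # Minimal smoke test (an illustrative sketch, not part of the original module):
    # build the initial 4x4 styled block and run one forward pass on CPU. The shapes
    # below are example values; parameters that were already initialized in __init__
    # (the AdaIN style bias/weight) are skipped by initialize() with a warning.
    ctx = mx.cpu()
    block = StyledConvBlock(512, 512, kernel_size=3, padding=1, style_dim=512, initial=True)
    block.initialize(ctx=ctx)

    latent = nd.random.normal(shape=(2, 512), ctx=ctx)      # only its batch size is used by ConstantInput
    style = nd.random.normal(shape=(2, 512), ctx=ctx)       # style vector fed to the AdaIN layers
    noise = nd.random.normal(shape=(2, 1, 4, 4), ctx=ctx)   # per-pixel noise, broadcast over channels

    out = block(latent, style, noise)
    print(out.shape)  # expected: (2, 512, 4, 4)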