1import random
2import numpy as np
3from math import sqrt
4
5import mxnet as mx
6import mxnet.ndarray as nd
7
8from modules import *
9# pylint: disable-all
10
class Generator(nn.HybridBlock):
    """Progressive-growing synthesis network of the style-based generator.

    Nine ``StyledConvBlock`` stages take the feature map from 4x4 up to
    1024x1024 (every stage after the first is built with ``upsample=True``).
    Each stage has a matching 1x1 ``EqualConv2d`` in ``to_rgb`` that projects
    its features to 3 channels.

    Parameters
    ----------
    code_dim : int, default 512
        Accepted so that ``Generator(code_dim, blur)`` (the call made by
        ``StyledGenerator``) binds ``blur`` to the right parameter. It is not
        used by this block: the channel widths below are fixed.
    blur : bool, default False
        Forwarded to every ``StyledConvBlock``.
    fused : bool, default True
        Forwarded to the four highest-resolution ``StyledConvBlock`` stages.
    """
    def __init__(self, code_dim=512, blur=False, fused=True):
        super().__init__()

        self.progression = nn.HybridSequential()
        with self.progression.name_scope():
            self.progression.add(StyledConvBlock(512, 512, 3, 1, initial=True, blur=blur))  # 4
            self.progression.add(StyledConvBlock(512, 512, 3, 1, upsample=True, blur=blur))  # 8
            self.progression.add(StyledConvBlock(512, 512, 3, 1, upsample=True, blur=blur))  # 16
            self.progression.add(StyledConvBlock(512, 512, 3, 1, upsample=True, blur=blur))  # 32
            self.progression.add(StyledConvBlock(512, 256, 3, 1, upsample=True, blur=blur))  # 64
            self.progression.add(StyledConvBlock(256, 128, 3, 1, upsample=True, fused=fused, blur=blur))  # 128
            self.progression.add(StyledConvBlock(128, 64, 3, 1, upsample=True, fused=fused, blur=blur))  # 256
            self.progression.add(StyledConvBlock(64, 32, 3, 1, upsample=True, fused=fused, blur=blur))  # 512
            self.progression.add(StyledConvBlock(32, 16, 3, 1, upsample=True, fused=fused, blur=blur))  # 1024

        # One RGB projection per stage; input widths mirror the stage outputs above.
        self.to_rgb = nn.HybridSequential()
        with self.to_rgb.name_scope():
            for channels in (512, 512, 512, 512, 256, 128, 64, 32, 16):
                self.to_rgb.add(EqualConv2d(channels, 3, 1))

    def hybrid_forward(self, F, style, noise, step=0, alpha=-1, mixing_range=(-1, -1)):
        """Run the synthesis stages up to ``step`` and return the RGB output.

        Parameters
        ----------
        style : NDArray
            Styles indexed along axis 0 (``style[k]`` is the k-th style);
            presumably shaped (num_styles, batch, code_dim) as built by
            ``StyledGenerator`` -- TODO confirm for other callers.
        noise : sequence
            Per-stage noise; ``noise[i]`` is used at stage ``i`` and ``noise[0]``
            is also the input of the first stage. Each entry is copied onto
            ``style``'s context with ``nd.array``.
        step : int
            Index of the last stage to run (0 -> 4x4 ... 8 -> 1024x1024 per the
            stage comments in ``__init__``).
        alpha : float
            Fade-in weight. When ``0 <= alpha < 1`` and ``step > 0`` the output is
            blended with the 2x nearest-upsampled RGB of the previous stage.
        mixing_range : tuple of int
            ``(-1, -1)`` selects random style mixing (crossover stages sampled
            from ``range(step)``); otherwise stages in
            ``[mixing_range[0], mixing_range[1]]`` use ``style[1]`` and all other
            stages use ``style[0]``.

        Raises
        ------
        ValueError
            From ``random.sample`` when random mixing needs more crossover points
            (``num_styles - 1``) than there are candidate stages (``step``).

        Notes
        -----
        Uses ``nd``, ``.shape`` and ``.context`` directly, so this presumably
        only runs imperatively (not after ``hybridize()``).
        """
        ctx = style[0].context
        num_styles = style.shape[0]

        if num_styles < 2:
            # Single style: a crossover point past the last stage is never reached.
            inject_index = [len(self.progression) + 1]
        else:
            # random.sample returns its picks in random order, but the crossover walk
            # below consumes inject_index left to right, so it must be ascending.
            inject_index = sorted(random.sample(list(range(step)), num_styles - 1))

        crossover = 0
        out = nd.array(noise[0], ctx=ctx)
        out_prev = None

        for i, (conv, to_rgb) in enumerate(zip(self.progression, self.to_rgb)):
            if mixing_range == (-1, -1):
                if crossover < len(inject_index) and i > inject_index[crossover]:
                    crossover = min(crossover + 1, num_styles)
                style_step = style[crossover]
            elif mixing_range[0] <= i <= mixing_range[1]:
                style_step = style[1]
            else:
                style_step = style[0]

            if i > 0 and step > 0:
                # Keep the previous stage's features for the fade-in blend.
                out_prev = out

            out = conv(out, style_step, nd.array(noise[i], ctx=ctx))

            if i == step:
                out = to_rgb(out)

                if i > 0 and 0 <= alpha < 1:
                    skip_rgb = self.to_rgb[i - 1](out_prev)
                    skip_rgb = F.UpSampling(skip_rgb, scale=2, sample_type='nearest')
                    out = (1 - alpha) * skip_rgb + alpha * out

                break

        return out
82
83
class StyledGenerator(nn.HybridBlock):
    r"""Style-based GAN

    Reference:

        Tero Karras, Samuli Laine, Timo Aila. "A Style-Based Generator
        Architecture for Generative Adversarial Networks." *CVPR*, 2019

    Parameters
    ----------
    code_dim : int, default 512
        Input/output width of every ``EqualLinear`` in the mapping network.
    n_mlp : int, default 8
        Number of ``EqualLinear`` + ``LeakyReLU(0.2)`` pairs in the mapping network.
    blur : bool, default False
        Passed to the synthesis ``Generator``.
    """
    def __init__(self, code_dim=512, n_mlp=8, blur=False):
        super().__init__()

        self.generator = Generator(code_dim, blur)

        # Mapping network: PixelNorm, then n_mlp (EqualLinear, LeakyReLU) pairs.
        self.style = nn.HybridSequential()
        with self.style.name_scope():
            self.style.add(PixelNorm())
            for _ in range(n_mlp):
                self.style.add(EqualLinear(code_dim, code_dim))
                self.style.add(nn.LeakyReLU(0.2))

    def hybrid_forward(self, F, x, step=0, alpha=-1, noise=None, mean_style=None,
                       style_weight=0, mixing_range=(-1, -1)):
        """Map latent code(s) to styles and synthesize images at stage ``step``.

        Parameters
        ----------
        x : NDArray or list/tuple of NDArray
            A latent batch, or several batches for style mixing. All mapped
            styles must share one shape because they are stacked on a new axis 0.
        step, alpha, mixing_range
            Forwarded unchanged to ``Generator``.
        noise : list, optional
            Per-stage noise. When None, ``step + 1`` standard-normal arrays of
            shape (batch, 1, 4 * 2**i, 4 * 2**i) are drawn on ``x[0]``'s context.
        mean_style : NDArray, optional
            When given, each style becomes
            ``mean_style + style_weight * (style - mean_style)``
            (the StyleGAN truncation trick).
        style_weight : float, default 0
            Interpolation weight used together with ``mean_style``.
        """
        if not isinstance(x, (list, tuple)):
            x = [x]

        styles = [self.style(code) for code in x]

        batch = x[0].shape[0]

        if noise is None:
            noise = []
            for i in range(step + 1):
                size = 4 * 2 ** i
                noise.append(nd.random.randn(batch, 1, size, size, ctx=x[0].context))

        if mean_style is not None:
            styles = [mean_style + style_weight * (style - mean_style) for style in styles]

        # Stack into (num_styles, batch, dim). An out-of-place stack inherits the
        # styles' context/dtype and is recorded by autograd, whereas filling an
        # nd.empty buffer via nd_styles[i] = style is an in-place assignment that
        # MXNet rejects inside autograd.record().
        nd_styles = nd.stack(*styles, axis=0)

        return self.generator(nd_styles, noise, step, alpha, mixing_range)

    def mean_style(self, x):
        """Return the mapping-network output for ``x`` averaged over axis 0.

        The averaged axis is kept with size 1, so the result can be passed as
        ``mean_style`` to ``hybrid_forward``.
        """
        style = self.style(x).mean(axis=0, keepdims=True)

        return style
147
148
class Discriminator(nn.HybridBlock):
    """Progressive-growing discriminator mirroring ``Generator``.

    ``progression`` runs from the 1024x1024 stage down to the 4x4 stage (see
    the per-stage comments). The last block takes 513 = 512 + 1 input channels
    because ``hybrid_forward`` appends one minibatch standard-deviation map.

    Parameters
    ----------
    fused : bool, default True
        Forwarded to the four highest-resolution ``ConvBlock`` stages.
    from_rgb_activate : bool, default False
        If True, each from-RGB projection is followed by ``LeakyReLU(0.2)``.
    """
    def __init__(self, fused=True, from_rgb_activate=False):
        super().__init__()

        self.progression = nn.HybridSequential()
        with self.progression.name_scope():
            self.progression.add(ConvBlock(16, 32, 3, 1, downsample=True, fused=fused))  # 512
            self.progression.add(ConvBlock(32, 64, 3, 1, downsample=True, fused=fused))  # 256
            self.progression.add(ConvBlock(64, 128, 3, 1, downsample=True, fused=fused))  # 128
            self.progression.add(ConvBlock(128, 256, 3, 1, downsample=True, fused=fused))  # 64
            self.progression.add(ConvBlock(256, 512, 3, 1, downsample=True))  # 32
            self.progression.add(ConvBlock(512, 512, 3, 1, downsample=True))  # 16
            self.progression.add(ConvBlock(512, 512, 3, 1, downsample=True))  # 8
            self.progression.add(ConvBlock(512, 512, 3, 1, downsample=True))  # 4
            self.progression.add(ConvBlock(513, 512, 3, 1, 4, 0))

        def make_from_rgb(out_channel):
            """Build the 1x1 RGB -> ``out_channel`` projection, optionally activated."""
            if not from_rgb_activate:
                return EqualConv2d(3, out_channel, 1)

            module = nn.HybridSequential()
            with module.name_scope():
                module.add(EqualConv2d(3, out_channel, 1))
                module.add(nn.LeakyReLU(0.2))
            return module

        # One projection per stage. Output widths equal each progression stage's
        # input width (the last stage adds the extra std channel itself).
        self.from_rgb = nn.HybridSequential()
        with self.from_rgb.name_scope():
            for channels in (16, 32, 64, 128, 256, 512, 512, 512, 512):
                self.from_rgb.add(make_from_rgb(channels))

        self.n_layer = len(self.progression)

        self.linear = EqualLinear(512, 1)

    def hybrid_forward(self, F, x, step=0, alpha=-1):
        """Score the image batch ``x`` at resolution stage ``step``.

        ``x`` presumably is an NCHW batch whose spatial size matches stage
        ``step`` (4 * 2**step) -- TODO confirm. When ``0 <= alpha < 1`` and
        ``step > 0``, the first stage's output is blended with the 2x
        average-pooled input projected by the next stage's from-RGB layer.
        Returns the output of ``self.linear``.
        """
        for i in range(step, -1, -1):
            index = self.n_layer - i - 1

            if i == step:
                out = self.from_rgb[index](x)

            if i == 0:
                # Minibatch standard deviation: per-location std over the batch,
                # averaged to one scalar and appended as an extra feature map.
                out_mean = nd.mean(out, axis=0, keepdims=True)
                out_std = F.sqrt(nd.mean((out - out_mean) ** 2, axis=0) + 1e-8)
                # mean() yields shape (1,), and broadcast_to requires matching ndim,
                # so lift it to 4-D before expanding to (batch, 1, H, W).
                mean_std = out_std.mean().reshape((1, 1, 1, 1))
                mean_std = mean_std.broadcast_to(
                    (out.shape[0], 1, out.shape[2], out.shape[3]))
                out = F.Concat(out, mean_std, dim=1)

            out = self.progression[index](out)

            if 0 < i == step and 0 <= alpha < 1:
                skip_rgb = F.Pooling(x, kernel=(2, 2), stride=(2, 2), pool_type='avg')
                skip_rgb = self.from_rgb[index + 1](skip_rgb)
                out = (1 - alpha) * skip_rgb + alpha * out

        # The final block is expected to leave 1x1 spatial dims; drop them.
        out = F.squeeze(out, axis=(2, 3))

        return self.linear(out)
223
224