1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18# coding: utf-8
19# pylint: disable= arguments-differ
20"""Custom neural network layers in model_zoo."""
21
# Public names of this module, i.e. what a star-import exposes.
__all__ = [
    'Concurrent',
    'HybridConcurrent',
    'Identity',
    'SparseEmbedding',
    'SyncBatchNorm',
    'PixelShuffle1D',
    'PixelShuffle2D',
    'PixelShuffle3D',
]
25
26import warnings
27from .... import nd, context
28from ...block import HybridBlock, Block
29from ...nn import Sequential, HybridSequential, BatchNorm
30
class Concurrent(Sequential):
    """Container that applies its child `Block` s side by side.

    Every child receives the very same input; the individual results are
    then joined into one output by concatenating them along `axis`.

    Example::

        net = Concurrent()
        # use net's name_scope to give children blocks appropriate names.
        with net.name_scope():
            net.add(nn.Dense(10, activation='relu'))
            net.add(nn.Dense(20))
            net.add(Identity())

    Parameters
    ----------
    axis : int, default -1
        The axis on which to concatenate the outputs.
    prefix : optional
        Forwarded unchanged to the parent constructor.
    params : optional
        Forwarded unchanged to the parent constructor.
    """
    def __init__(self, axis=-1, prefix=None, params=None):
        super(Concurrent, self).__init__(prefix=prefix, params=params)
        self.axis = axis

    def forward(self, x):
        """Run every child on `x` and concatenate the results on `self.axis`."""
        branch_outputs = [child(x) for child in self._children.values()]
        return nd.concat(*branch_outputs, dim=self.axis)
62
63
class HybridConcurrent(HybridSequential):
    """Container that applies its child `HybridBlock` s side by side.

    Every child receives the very same input; the individual results are
    then joined into one output by concatenating them along `axis`.

    Example::

        net = HybridConcurrent()
        # use net's name_scope to give children blocks appropriate names.
        with net.name_scope():
            net.add(nn.Dense(10, activation='relu'))
            net.add(nn.Dense(20))
            net.add(Identity())

    Parameters
    ----------
    axis : int, default -1
        The axis on which to concatenate the outputs.
    prefix : optional
        Forwarded unchanged to the parent constructor.
    params : optional
        Forwarded unchanged to the parent constructor.
    """
    def __init__(self, axis=-1, prefix=None, params=None):
        super(HybridConcurrent, self).__init__(prefix=prefix, params=params)
        self.axis = axis

    def hybrid_forward(self, F, x):
        """Run every child on `x` and join the results with `F.concat`.

        `F` is the operator namespace handed in by the caller; only its
        `concat` function is used here.
        """
        branch_outputs = [child(x) for child in self._children.values()]
        return F.concat(*branch_outputs, dim=self.axis)
95
96
class Identity(HybridBlock):
    """Pass-through block: whatever goes in comes out untouched.

    Placing it next to other branches inside a HybridConcurrent block
    carries the raw input along with the branch outputs, which is how a
    residual-style connection can be built.

    Example::

        net = HybridConcurrent()
        # use net's name_scope to give child Blocks appropriate names.
        with net.name_scope():
            net.add(nn.Dense(10, activation='relu'))
            net.add(nn.Dense(20))
            net.add(Identity())

    Parameters
    ----------
    prefix : optional
        Forwarded unchanged to the parent constructor.
    params : optional
        Forwarded unchanged to the parent constructor.
    """
    def __init__(self, prefix=None, params=None):
        super(Identity, self).__init__(prefix=prefix, params=params)

    def hybrid_forward(self, F, x):
        """Return `x` as is; the operator namespace `F` is not used."""
        return x
117
class SparseEmbedding(Block):
    r"""Turns non-negative integers (indexes/tokens) into dense vectors
    of fixed size. eg. [4, 20] -> [[0.25, 0.1], [0.6, -0.2]]

    This SparseBlock is designed for distributed training with extremely large
    input dimension. Both weight and gradient w.r.t. weight are `RowSparseNDArray`.

    Note: this block always requests a sparse gradient (`sparse_grad=True` is
    passed to the embedding operator; it is not a constructor option).
    Only a subset of optimizers support sparse gradients, including SGD, AdaGrad
    and Adam. By default lazy updates is turned on, which may perform differently
    from standard updates. For more details, please check the Optimization API at:
    https://mxnet.incubator.apache.org/api/python/optimization/optimization.html

    Parameters
    ----------
    input_dim : int
        Size of the vocabulary, i.e. maximum integer index + 1.
    output_dim : int
        Dimension of the dense embedding.
    dtype : str or np.dtype, default 'float32'
        Data type of output embeddings.
    weight_initializer : Initializer
        Initializer for the `embeddings` matrix.
    **kwargs
        Forwarded unchanged to the parent constructor.

    Inputs:
        - **data**: (N-1)-D tensor with shape: `(x1, x2, ..., xN-1)`.
    Output:
        - **out**: N-D tensor with shape: `(x1, x2, ..., xN-1, output_dim)`.
    """
    def __init__(self, input_dim, output_dim, dtype='float32',
                 weight_initializer=None, **kwargs):
        super(SparseEmbedding, self).__init__(**kwargs)
        # Operator arguments, built once and reused by forward() and __repr__().
        self._kwargs = dict(input_dim=input_dim, output_dim=output_dim,
                            dtype=dtype, sparse_grad=True)
        # Both the weight and its gradient use row-sparse storage.
        self.weight = self.params.get(
            'weight',
            shape=(input_dim, output_dim),
            init=weight_initializer,
            dtype=dtype,
            grad_stype='row_sparse',
            stype='row_sparse')

    def forward(self, x):
        """Look up the embedding vectors for the indices in `x`."""
        # The weight is fetched through row_sparse_data(), keyed by `x`.
        rows = self.weight.row_sparse_data(x)
        return nd.Embedding(x, rows, name='fwd', **self._kwargs)

    def __repr__(self):
        # Extra entries in _kwargs (sparse_grad) are ignored by str.format.
        template = '{block_name}({input_dim} -> {output_dim}, {dtype})'
        return template.format(block_name=type(self).__name__, **self._kwargs)
164
class SyncBatchNorm(BatchNorm):
    """Cross-GPU Synchronized Batch normalization (SyncBN)

    Standard BN [1]_ implementations only normalize the data within each device.
    SyncBN normalizes the input within the whole mini-batch.
    We follow the implementation described in the paper [2]_.

    Note: Current implementation of SyncBN does not support FP16 training.
    For FP16 inference, use standard nn.BatchNorm instead of SyncBN.

    Parameters
    ----------
    in_channels : int, default 0
        Number of channels (feature maps) in input data. If not specified,
        initialization will be deferred to the first time `forward` is called
        and `in_channels` will be inferred from the shape of input data.
    num_devices : int, default number of visible GPUs
        Number of devices, passed to the operator as `ndev`. When left as
        None, the value reported by `context.num_gpus()` is used (or 1 if
        that is 0) and a `UserWarning` is issued; set it explicitly when
        not all visible GPUs are used.
    momentum: float, default 0.9
        Momentum for the moving average.
    epsilon: float, default 1e-5
        Small float added to variance to avoid dividing by zero.
    center: bool, default True
        If True, add offset of `beta` to normalized tensor.
        If False, `beta` is ignored.
    scale: bool, default True
        If True, multiply by `gamma`. If False, `gamma` is not used.
        When the next layer is linear (also e.g. `nn.relu`),
        this can be disabled since the scaling
        will be done by the next layer.
    use_global_stats: bool, default False
        If True, use global moving statistics instead of local batch-norm. This will force
        change batch-norm into a scale shift operator.
        If False, use local batch-norm.
    beta_initializer: str or `Initializer`, default 'zeros'
        Initializer for the beta weight.
    gamma_initializer: str or `Initializer`, default 'ones'
        Initializer for the gamma weight.
    running_mean_initializer: str or `Initializer`, default 'zeros'
        Initializer for the running mean.
    running_variance_initializer: str or `Initializer`, default 'ones'
        Initializer for the running variance.


    Inputs:
        - **data**: input tensor with arbitrary shape.
    Outputs:
        - **out**: output tensor with the same shape as `data`.

    Reference:
        .. [1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating \
          deep network training by reducing internal covariate shift." *ICML 2015*
        .. [2] Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, \
          Ambrish Tyagi, and Amit Agrawal. "Context Encoding for Semantic Segmentation." *CVPR 2018*
    """
    def __init__(self, in_channels=0, num_devices=None, momentum=0.9, epsilon=1e-5,
                 center=True, scale=True, use_global_stats=False, beta_initializer='zeros',
                 gamma_initializer='ones', running_mean_initializer='zeros',
                 running_variance_initializer='ones', **kwargs):
        # The parent is always configured with axis=1; everything else is
        # passed through from the caller.
        super(SyncBatchNorm, self).__init__(
            axis=1, momentum=momentum, epsilon=epsilon,
            center=center, scale=scale,
            use_global_stats=use_global_stats,
            beta_initializer=beta_initializer,
            gamma_initializer=gamma_initializer,
            running_mean_initializer=running_mean_initializer,
            running_variance_initializer=running_variance_initializer,
            in_channels=in_channels, **kwargs)
        if num_devices is None:
            num_devices = self._get_num_devices()
        # Arguments handed to the SyncBatchNorm operator on every forward call.
        # NOTE(review): 'key' is set to this block's prefix; presumably the
        # operator uses it to tell layers apart across devices - confirm.
        self._kwargs = {'eps': epsilon, 'momentum': momentum,
                        'fix_gamma': not scale, 'use_global_stats': use_global_stats,
                        'ndev': num_devices, 'key': self.prefix}

    def _get_num_devices(self):
        """Return the device count to use when the caller gave none.

        Emits a `UserWarning`, then returns `context.num_gpus()`, or 1 when
        that value is 0.
        """
        warnings.warn("Caution using SyncBatchNorm: "
                      "if not using all the GPUs, please manually set num_devices",
                      UserWarning)
        # Never report fewer than one device.
        return max(context.num_gpus(), 1)

    def hybrid_forward(self, F, x, gamma, beta, running_mean, running_var):
        """Apply `F.contrib.SyncBatchNorm` to `x` with the stored settings."""
        return F.contrib.SyncBatchNorm(x, gamma, beta, running_mean, running_var,
                                       name='fwd', **self._kwargs)
248
class PixelShuffle1D(HybridBlock):

    r"""Pixel-shuffle layer for upsampling in 1 dimension.

    Pixel-shuffling is the operation of taking groups of values along
    the *channel* dimension and regrouping them into blocks of pixels
    along the ``W`` dimension, thereby effectively multiplying that dimension
    by a constant factor in size.

    For example, a feature map of shape :math:`(fC, W)` is reshaped
    into :math:`(C, fW)` by forming little value groups of size :math:`f`
    and arranging them in a grid of size :math:`W`.

    Parameters
    ----------
    factor : int or 1-tuple of int
        Upsampling factor, applied to the ``W`` dimension.

    Raises
    ------
    ValueError
        If `factor` is a sequence whose length is not 1.

    Inputs:
        - **data**: Tensor of shape ``(N, f*C, W)``.
    Outputs:
        - **out**: Tensor of shape ``(N, C, W*f)``.

    Examples
    --------
    >>> pxshuf = PixelShuffle1D(2)
    >>> x = mx.nd.zeros((1, 8, 3))
    >>> pxshuf(x).shape
    (1, 4, 6)
    """

    def __init__(self, factor):
        super(PixelShuffle1D, self).__init__()
        try:
            self._factor = int(factor)
        except TypeError:
            # `factor` is not directly convertible to int: treat it as a
            # sequence so that the documented 1-tuple form is accepted.
            # A non-iterable value still ends in a TypeError, as before.
            factors = tuple(int(fac) for fac in factor)
            if len(factors) != 1:
                raise ValueError("wrong length {}".format(len(factors)))
            self._factor = factors[0]

    def hybrid_forward(self, F, x):
        """Perform pixel-shuffling on the input."""
        f = self._factor
                                             # (N, f*C, W)
        x = F.reshape(x, (0, -4, -1, f, 0))  # (N, C, f, W)
        x = F.transpose(x, (0, 1, 3, 2))     # (N, C, W, f)
        x = F.reshape(x, (0, 0, -3))         # (N, C, W*f)
        return x

    def __repr__(self):
        return "{}({})".format(self.__class__.__name__, self._factor)
295
296
class PixelShuffle2D(HybridBlock):

    r"""Pixel-shuffle layer for upsampling in 2 dimensions.

    Pixel-shuffling is the operation of taking groups of values along
    the *channel* dimension and regrouping them into blocks of pixels
    along the ``H`` and ``W`` dimensions, thereby effectively multiplying
    those dimensions by a constant factor in size.

    For example, a feature map of shape :math:`(f^2 C, H, W)` is reshaped
    into :math:`(C, fH, fW)` by forming little :math:`f \times f` blocks
    of pixels and arranging them in an :math:`H \times W` grid.

    Pixel-shuffling together with regular convolution is an alternative,
    learnable way of upsampling an image by arbitrary factors. It is reported
    to help overcome checkerboard artifacts that are common in upsampling with
    transposed convolutions (also called deconvolutions). See the paper
    `Real-Time Single Image and Video Super-Resolution Using an Efficient
    Sub-Pixel Convolutional Neural Network <https://arxiv.org/abs/1609.05158>`_
    for further details.

    Parameters
    ----------
    factor : int or 2-tuple of int
        Upsampling factors, applied to the ``H`` and ``W`` dimensions,
        in that order. A single int is used for both dimensions.

    Inputs:
        - **data**: Tensor of shape ``(N, f1*f2*C, H, W)``.
    Outputs:
        - **out**: Tensor of shape ``(N, C, H*f1, W*f2)``.

    Examples
    --------
    >>> pxshuf = PixelShuffle2D((2, 3))
    >>> x = mx.nd.zeros((1, 12, 3, 5))
    >>> pxshuf(x).shape
    (1, 2, 6, 15)
    """

    def __init__(self, factor):
        super(PixelShuffle2D, self).__init__()
        try:
            # Scalar case: the same factor applies to H and W.
            factors = (int(factor),) * 2
        except TypeError:
            # Sequence case: one factor per spatial dimension.
            factors = tuple(int(fac) for fac in factor)
            assert len(factors) == 2, "wrong length {}".format(len(factors))
        self._factors = factors

    def hybrid_forward(self, F, x):
        """Perform pixel-shuffling on the input."""
        f1, f2 = self._factors
        # Incoming layout: (N, f1*f2*C, H, W).
        # Step 1: peel the combined factor off the channel axis.
        split_channels = F.reshape(x, (0, -4, -1, f1 * f2, 0, 0))    # (N, C, f1*f2, H, W)
        # Step 2: separate the combined factor into its two parts.
        split_factors = F.reshape(split_channels,
                                  (0, 0, -4, f1, f2, 0, 0))          # (N, C, f1, f2, H, W)
        # Step 3: move each factor right behind the axis it enlarges.
        interleaved = F.transpose(split_factors, (0, 1, 4, 2, 5, 3))  # (N, C, H, f1, W, f2)
        # Step 4: fuse every (axis, factor) pair.
        return F.reshape(interleaved, (0, 0, -3, -3))                # (N, C, H*f1, W*f2)

    def __repr__(self):
        return "{}({})".format(type(self).__name__, self._factors)
357
358
class PixelShuffle3D(HybridBlock):

    r"""Pixel-shuffle layer for upsampling in 3 dimensions.

    Pixel-shuffling (or voxel-shuffling in 3D) is the operation of taking
    groups of values along the *channel* dimension and regrouping them into
    blocks of voxels along the ``D``, ``H`` and ``W`` dimensions, thereby
    effectively multiplying those dimensions by a constant factor in size.

    For example, a feature map of shape :math:`(f^3 C, D, H, W)` is reshaped
    into :math:`(C, fD, fH, fW)` by forming little :math:`f \times f \times f`
    blocks of voxels and arranging them in a :math:`D \times H \times W` grid.

    Pixel-shuffling together with regular convolution is an alternative,
    learnable way of upsampling an image by arbitrary factors. It is reported
    to help overcome checkerboard artifacts that are common in upsampling with
    transposed convolutions (also called deconvolutions). See the paper
    `Real-Time Single Image and Video Super-Resolution Using an Efficient
    Sub-Pixel Convolutional Neural Network <https://arxiv.org/abs/1609.05158>`_
    for further details.

    Parameters
    ----------
    factor : int or 3-tuple of int
        Upsampling factors, applied to the ``D``, ``H`` and ``W``
        dimensions, in that order. A single int is used for all three.

    Inputs:
        - **data**: Tensor of shape ``(N, f1*f2*f3*C, D, H, W)``.
    Outputs:
        - **out**: Tensor of shape ``(N, C, D*f1, H*f2, W*f3)``.

    Examples
    --------
    >>> pxshuf = PixelShuffle3D((2, 3, 4))
    >>> x = mx.nd.zeros((1, 48, 3, 5, 7))
    >>> pxshuf(x).shape
    (1, 2, 6, 15, 28)
    """

    def __init__(self, factor):
        super(PixelShuffle3D, self).__init__()
        try:
            # Scalar case: the same factor applies to D, H and W.
            factors = (int(factor),) * 3
        except TypeError:
            # Sequence case: one factor per spatial dimension.
            factors = tuple(int(fac) for fac in factor)
            assert len(factors) == 3, "wrong length {}".format(len(factors))
        self._factors = factors

    def hybrid_forward(self, F, x):
        """Perform pixel-shuffling on the input."""
        # `transpose` doesn't support 8D, so the factors are peeled off and
        # fused one spatial axis at a time using swapaxes + reshape instead.
        f1, f2, f3 = self._factors
        f23 = f2 * f3
        f123 = f1 * f23
        # Incoming layout: (N, C*f1*f2*f3, D, H, W).
        x = F.reshape(x, (0, -4, -1, f123, 0, 0, 0))    # (N, C, f1*f2*f3, D, H, W)

        # Depth: bring the factors next to D, split off f1, fuse it into D.
        x = F.swapaxes(x, 2, 3)                         # (N, C, D, f1*f2*f3, H, W)
        x = F.reshape(x, (0, 0, 0, -4, f1, f23, 0, 0))  # (N, C, D, f1, f2*f3, H, W)
        x = F.reshape(x, (0, 0, -3, 0, 0, 0))           # (N, C, D*f1, f2*f3, H, W)

        # Height: bring the rest next to H, split off f2, fuse it into H.
        x = F.swapaxes(x, 3, 4)                         # (N, C, D*f1, H, f2*f3, W)
        x = F.reshape(x, (0, 0, 0, 0, -4, f2, f3, 0))   # (N, C, D*f1, H, f2, f3, W)
        x = F.reshape(x, (0, 0, 0, -3, 0, 0))           # (N, C, D*f1, H*f2, f3, W)

        # Width: put f3 behind W and fuse the pair.
        x = F.swapaxes(x, 4, 5)                         # (N, C, D*f1, H*f2, W, f3)
        return F.reshape(x, (0, 0, 0, 0, -3))           # (N, C, D*f1, H*f2, W*f3)

    def __repr__(self):
        return "{}({})".format(type(self).__name__, self._factors)
425