# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable= arguments-differ
"""Custom neural network layers in model_zoo."""

__all__ = ['Concurrent', 'HybridConcurrent', 'Identity', 'SparseEmbedding',
           'SyncBatchNorm', 'PixelShuffle1D', 'PixelShuffle2D',
           'PixelShuffle3D']

import warnings
from .... import nd, context
from ...block import HybridBlock, Block
from ...nn import Sequential, HybridSequential, BatchNorm


class Concurrent(Sequential):
    """Lays `Block` s concurrently.

    This block feeds its input to all children blocks, and
    produce the output by concatenating all the children blocks' outputs
    on the specified axis.

    Example::

        net = Concurrent()
        # use net's name_scope to give children blocks appropriate names.
        with net.name_scope():
            net.add(nn.Dense(10, activation='relu'))
            net.add(nn.Dense(20))
            net.add(Identity())

    Parameters
    ----------
    axis : int, default -1
        The axis on which to concatenate the outputs.
    """
    def __init__(self, axis=-1, prefix=None, params=None):
        super(Concurrent, self).__init__(prefix=prefix, params=params)
        self.axis = axis

    def forward(self, x):
        """Apply every child block to `x` and concatenate the results on `self.axis`."""
        out = [block(x) for block in self._children.values()]
        return nd.concat(*out, dim=self.axis)


class HybridConcurrent(HybridSequential):
    """Lays `HybridBlock` s concurrently.

    This block feeds its input to all children blocks, and
    produce the output by concatenating all the children blocks' outputs
    on the specified axis.

    Example::

        net = HybridConcurrent()
        # use net's name_scope to give children blocks appropriate names.
        with net.name_scope():
            net.add(nn.Dense(10, activation='relu'))
            net.add(nn.Dense(20))
            net.add(Identity())

    Parameters
    ----------
    axis : int, default -1
        The axis on which to concatenate the outputs.
    """
    def __init__(self, axis=-1, prefix=None, params=None):
        super(HybridConcurrent, self).__init__(prefix=prefix, params=params)
        self.axis = axis

    def hybrid_forward(self, F, x):
        """Apply every child block to `x` and concatenate the results on `self.axis`."""
        out = [block(x) for block in self._children.values()]
        return F.concat(*out, dim=self.axis)


class Identity(HybridBlock):
    """Block that passes through the input directly.

    This block can be used in conjunction with HybridConcurrent
    block for residual connection.

    Example::

        net = HybridConcurrent()
        # use net's name_scope to give child Blocks appropriate names.
        with net.name_scope():
            net.add(nn.Dense(10, activation='relu'))
            net.add(nn.Dense(20))
            net.add(Identity())
    """
    def __init__(self, prefix=None, params=None):
        super(Identity, self).__init__(prefix=prefix, params=params)

    def hybrid_forward(self, F, x):
        """Return `x` unchanged."""
        return x


class SparseEmbedding(Block):
    r"""Turns non-negative integers (indexes/tokens) into dense vectors
    of fixed size. eg. [4, 20] -> [[0.25, 0.1], [0.6, -0.2]]

    This SparseBlock is designed for distributed training with extremely large
    input dimension. Both weight and gradient w.r.t. weight are `RowSparseNDArray`.

    Note: if `sparse_grad` is set to True, the gradient w.r.t weight will be
    sparse. Only a subset of optimizers support sparse gradients, including SGD, AdaGrad
    and Adam. By default lazy updates is turned on, which may perform differently
    from standard updates. For more details, please check the Optimization API at:
    https://mxnet.incubator.apache.org/api/python/optimization/optimization.html

    Parameters
    ----------
    input_dim : int
        Size of the vocabulary, i.e. maximum integer index + 1.
    output_dim : int
        Dimension of the dense embedding.
    dtype : str or np.dtype, default 'float32'
        Data type of output embeddings.
    weight_initializer : Initializer
        Initializer for the `embeddings` matrix.

    Inputs:
        - **data**: (N-1)-D tensor with shape: `(x1, x2, ..., xN-1)`.
    Output:
        - **out**: N-D tensor with shape: `(x1, x2, ..., xN-1, output_dim)`.
    """
    def __init__(self, input_dim, output_dim, dtype='float32',
                 weight_initializer=None, **kwargs):
        super(SparseEmbedding, self).__init__(**kwargs)
        # Keyword arguments forwarded to `nd.Embedding` on every forward call;
        # also used by `__repr__`.
        self._kwargs = {'input_dim': input_dim, 'output_dim': output_dim,
                        'dtype': dtype, 'sparse_grad': True}
        self.weight = self.params.get('weight', shape=(input_dim, output_dim),
                                      init=weight_initializer, dtype=dtype,
                                      grad_stype='row_sparse', stype='row_sparse')

    def forward(self, x):
        """Look up the embedding rows selected by `x`."""
        # Only the rows indexed by `x` are retrieved from the row-sparse weight.
        weight = self.weight.row_sparse_data(x)
        return nd.Embedding(x, weight, name='fwd', **self._kwargs)

    def __repr__(self):
        s = '{block_name}({input_dim} -> {output_dim}, {dtype})'
        return s.format(block_name=self.__class__.__name__,
                        **self._kwargs)


class SyncBatchNorm(BatchNorm):
    """Cross-GPU Synchronized Batch normalization (SyncBN)

    Standard BN [1]_ implementation only normalize the data within each device.
    SyncBN normalizes the input within the whole mini-batch.
    We follow the implementation described in the paper [2]_.

    Note: Current implementation of SyncBN does not support FP16 training.
    For FP16 inference, use standard nn.BatchNorm instead of SyncBN.

    Parameters
    ----------
    in_channels : int, default 0
        Number of channels (feature maps) in input data. If not specified,
        initialization will be deferred to the first time `forward` is called
        and `in_channels` will be inferred from the shape of input data.
    num_devices : int, default number of visible GPUs
        Number of devices passed to the operator as `ndev`. If None, the value
        returned by `context.num_gpus()` is used (or 1 when that is 0), and a
        `UserWarning` is emitted.
    momentum: float, default 0.9
        Momentum for the moving average.
    epsilon: float, default 1e-5
        Small float added to variance to avoid dividing by zero.
    center: bool, default True
        If True, add offset of `beta` to normalized tensor.
        If False, `beta` is ignored.
    scale: bool, default True
        If True, multiply by `gamma`. If False, `gamma` is not used.
        When the next layer is linear (also e.g. `nn.relu`),
        this can be disabled since the scaling
        will be done by the next layer.
    use_global_stats: bool, default False
        If True, use global moving statistics instead of local batch-norm. This will force
        change batch-norm into a scale shift operator.
        If False, use local batch-norm.
    beta_initializer: str or `Initializer`, default 'zeros'
        Initializer for the beta weight.
    gamma_initializer: str or `Initializer`, default 'ones'
        Initializer for the gamma weight.
    running_mean_initializer: str or `Initializer`, default 'zeros'
        Initializer for the running mean.
    running_variance_initializer: str or `Initializer`, default 'ones'
        Initializer for the running variance.


    Inputs:
        - **data**: input tensor with arbitrary shape.
    Outputs:
        - **out**: output tensor with the same shape as `data`.

    Reference:
        .. [1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating \
          deep network training by reducing internal covariate shift." *ICML 2015*
        .. [2] Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, \
          Ambrish Tyagi, and Amit Agrawal. "Context Encoding for Semantic Segmentation." *CVPR 2018*
    """
    def __init__(self, in_channels=0, num_devices=None, momentum=0.9, epsilon=1e-5,
                 center=True, scale=True, use_global_stats=False, beta_initializer='zeros',
                 gamma_initializer='ones', running_mean_initializer='zeros',
                 running_variance_initializer='ones', **kwargs):
        super(SyncBatchNorm, self).__init__(
            axis=1, momentum=momentum, epsilon=epsilon,
            center=center, scale=scale,
            use_global_stats=use_global_stats,
            beta_initializer=beta_initializer,
            gamma_initializer=gamma_initializer,
            running_mean_initializer=running_mean_initializer,
            running_variance_initializer=running_variance_initializer,
            in_channels=in_channels, **kwargs)
        num_devices = self._get_num_devices() if num_devices is None else num_devices
        # Replaces the operator kwargs set up by the parent constructor with the
        # ones expected by `F.contrib.SyncBatchNorm`.
        self._kwargs = {'eps': epsilon, 'momentum': momentum,
                        'fix_gamma': not scale, 'use_global_stats': use_global_stats,
                        'ndev': num_devices, 'key': self.prefix}

    def _get_num_devices(self):
        """Return the GPU count from `context.num_gpus()`, or 1 if it is 0."""
        warnings.warn("Caution using SyncBatchNorm: "
                      "if not using all the GPUs, please mannually set num_devices",
                      UserWarning)
        num_devices = context.num_gpus()
        num_devices = num_devices if num_devices > 0 else 1
        return num_devices

    def hybrid_forward(self, F, x, gamma, beta, running_mean, running_var):
        return F.contrib.SyncBatchNorm(x, gamma, beta, running_mean, running_var,
                                       name='fwd', **self._kwargs)


class PixelShuffle1D(HybridBlock):

    r"""Pixel-shuffle layer for upsampling in 1 dimension.

    Pixel-shuffling is the operation of taking groups of values along
    the *channel* dimension and regrouping them into blocks of pixels
    along the ``W`` dimension, thereby effectively multiplying that dimension
    by a constant factor in size.

    For example, a feature map of shape :math:`(fC, W)` is reshaped
    into :math:`(C, fW)` by forming little value groups of size :math:`f`
    and arranging them in a grid of size :math:`W`.

    Parameters
    ----------
    factor : int or 1-tuple of int
        Upsampling factor, applied to the ``W`` dimension.

    Inputs:
        - **data**: Tensor of shape ``(N, f*C, W)``.
    Outputs:
        - **out**: Tensor of shape ``(N, C, W*f)``.

    Examples
    --------
    >>> pxshuf = PixelShuffle1D(2)
    >>> x = mx.nd.zeros((1, 8, 3))
    >>> pxshuf(x).shape
    (1, 4, 6)
    """

    def __init__(self, factor):
        super(PixelShuffle1D, self).__init__()
        # Accept both a plain int and the documented 1-tuple form, mirroring
        # the int-or-tuple handling in PixelShuffle2D / PixelShuffle3D.
        try:
            self._factor = int(factor)
        except TypeError:
            factors = tuple(int(fac) for fac in factor)
            assert len(factors) == 1, "wrong length {}".format(len(factors))
            self._factor = factors[0]

    def hybrid_forward(self, F, x):
        """Perform pixel-shuffling on the input."""
        f = self._factor
                                             # (N, C*f, W)
        x = F.reshape(x, (0, -4, -1, f, 0))  # (N, C, f, W)
        x = F.transpose(x, (0, 1, 3, 2))     # (N, C, W, f)
        x = F.reshape(x, (0, 0, -3))         # (N, C, W*f)
        return x

    def __repr__(self):
        return "{}({})".format(self.__class__.__name__, self._factor)


class PixelShuffle2D(HybridBlock):

    r"""Pixel-shuffle layer for upsampling in 2 dimensions.

    Pixel-shuffling is the operation of taking groups of values along
    the *channel* dimension and regrouping them into blocks of pixels
    along the ``H`` and ``W`` dimensions, thereby effectively multiplying
    those dimensions by a constant factor in size.

    For example, a feature map of shape :math:`(f^2 C, H, W)` is reshaped
    into :math:`(C, fH, fW)` by forming little :math:`f \times f` blocks
    of pixels and arranging them in an :math:`H \times W` grid.

    Pixel-shuffling together with regular convolution is an alternative,
    learnable way of upsampling an image by arbitrary factors. It is reported
    to help overcome checkerboard artifacts that are common in upsampling with
    transposed convolutions (also called deconvolutions). See the paper
    `Real-Time Single Image and Video Super-Resolution Using an Efficient
    Sub-Pixel Convolutional Neural Network <https://arxiv.org/abs/1609.05158>`_
    for further details.

    Parameters
    ----------
    factor : int or 2-tuple of int
        Upsampling factors, applied to the ``H`` and ``W`` dimensions,
        in that order.

    Inputs:
        - **data**: Tensor of shape ``(N, f1*f2*C, H, W)``.
    Outputs:
        - **out**: Tensor of shape ``(N, C, H*f1, W*f2)``.

    Examples
    --------
    >>> pxshuf = PixelShuffle2D((2, 3))
    >>> x = mx.nd.zeros((1, 12, 3, 5))
    >>> pxshuf(x).shape
    (1, 2, 6, 15)
    """

    def __init__(self, factor):
        super(PixelShuffle2D, self).__init__()
        # A scalar factor is applied to both dimensions; otherwise `factor`
        # is treated as an iterable of per-dimension factors.
        try:
            self._factors = (int(factor),) * 2
        except TypeError:
            self._factors = tuple(int(fac) for fac in factor)
            assert len(self._factors) == 2, "wrong length {}".format(len(self._factors))

    def hybrid_forward(self, F, x):
        """Perform pixel-shuffling on the input."""
        f1, f2 = self._factors
                                                      # (N, f1*f2*C, H, W)
        x = F.reshape(x, (0, -4, -1, f1 * f2, 0, 0))  # (N, C, f1*f2, H, W)
        x = F.reshape(x, (0, 0, -4, f1, f2, 0, 0))    # (N, C, f1, f2, H, W)
        x = F.transpose(x, (0, 1, 4, 2, 5, 3))        # (N, C, H, f1, W, f2)
        x = F.reshape(x, (0, 0, -3, -3))              # (N, C, H*f1, W*f2)
        return x

    def __repr__(self):
        return "{}({})".format(self.__class__.__name__, self._factors)


class PixelShuffle3D(HybridBlock):

    r"""Pixel-shuffle layer for upsampling in 3 dimensions.

    Pixel-shuffling (or voxel-shuffling in 3D) is the operation of taking
    groups of values along the *channel* dimension and regrouping them into
    blocks of voxels along the ``D``, ``H`` and ``W`` dimensions, thereby
    effectively multiplying those dimensions by a constant factor in size.

    For example, a feature map of shape :math:`(f^3 C, D, H, W)` is reshaped
    into :math:`(C, fD, fH, fW)` by forming little :math:`f \times f \times f`
    blocks of voxels and arranging them in a :math:`D \times H \times W` grid.

    Pixel-shuffling together with regular convolution is an alternative,
    learnable way of upsampling an image by arbitrary factors. It is reported
    to help overcome checkerboard artifacts that are common in upsampling with
    transposed convolutions (also called deconvolutions). See the paper
    `Real-Time Single Image and Video Super-Resolution Using an Efficient
    Sub-Pixel Convolutional Neural Network <https://arxiv.org/abs/1609.05158>`_
    for further details.

    Parameters
    ----------
    factor : int or 3-tuple of int
        Upsampling factors, applied to the ``D``, ``H`` and ``W``
        dimensions, in that order.

    Inputs:
        - **data**: Tensor of shape ``(N, f1*f2*f3*C, D, H, W)``.
    Outputs:
        - **out**: Tensor of shape ``(N, C, D*f1, H*f2, W*f3)``.

    Examples
    --------
    >>> pxshuf = PixelShuffle3D((2, 3, 4))
    >>> x = mx.nd.zeros((1, 48, 3, 5, 7))
    >>> pxshuf(x).shape
    (1, 2, 6, 15, 28)
    """

    def __init__(self, factor):
        super(PixelShuffle3D, self).__init__()
        # A scalar factor is applied to all three dimensions; otherwise
        # `factor` is treated as an iterable of per-dimension factors.
        try:
            self._factors = (int(factor),) * 3
        except TypeError:
            self._factors = tuple(int(fac) for fac in factor)
            assert len(self._factors) == 3, "wrong length {}".format(len(self._factors))

    def hybrid_forward(self, F, x):
        """Perform pixel-shuffling on the input."""
        # `transpose` doesn't support 8D, need other implementation
        f1, f2, f3 = self._factors
                                                              # (N, C*f1*f2*f3, D, H, W)
        x = F.reshape(x, (0, -4, -1, f1 * f2 * f3, 0, 0, 0))  # (N, C, f1*f2*f3, D, H, W)
        x = F.swapaxes(x, 2, 3)                               # (N, C, D, f1*f2*f3, H, W)
        x = F.reshape(x, (0, 0, 0, -4, f1, f2*f3, 0, 0))      # (N, C, D, f1, f2*f3, H, W)
        x = F.reshape(x, (0, 0, -3, 0, 0, 0))                 # (N, C, D*f1, f2*f3, H, W)
        x = F.swapaxes(x, 3, 4)                               # (N, C, D*f1, H, f2*f3, W)
        x = F.reshape(x, (0, 0, 0, 0, -4, f2, f3, 0))         # (N, C, D*f1, H, f2, f3, W)
        x = F.reshape(x, (0, 0, 0, -3, 0, 0))                 # (N, C, D*f1, H*f2, f3, W)
        x = F.swapaxes(x, 4, 5)                               # (N, C, D*f1, H*f2, W, f3)
        x = F.reshape(x, (0, 0, 0, 0, -3))                    # (N, C, D*f1, H*f2, W*f3)
        return x

    def __repr__(self):
        return "{}({})".format(self.__class__.__name__, self._factors)