1# Licensed to the Apache Software Foundation (ASF) under one 2# or more contributor license agreements. See the NOTICE file 3# distributed with this work for additional information 4# regarding copyright ownership. The ASF licenses this file 5# to you under the Apache License, Version 2.0 (the 6# "License"); you may not use this file except in compliance 7# with the License. You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, 12# software distributed under the License is distributed on an 13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14# KIND, either express or implied. See the License for the 15# specific language governing permissions and limitations 16# under the License. 17 18from __future__ import print_function 19import mxnet as mx 20import copy 21from mxnet import gluon 22from mxnet.gluon import contrib 23from mxnet.gluon import nn 24from mxnet.gluon.contrib.nn import ( 25 Concurrent, HybridConcurrent, Identity, SparseEmbedding, PixelShuffle1D, 26 PixelShuffle2D, PixelShuffle3D) 27from mxnet.test_utils import almost_equal, default_context, assert_almost_equal, assert_allclose 28from common import setup_module, with_seed, teardown 29import numpy as np 30 31 32def check_rnn_cell(cell, prefix, in_shape=(10, 50), out_shape=(10, 100), begin_state=None): 33 inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)] 34 outputs, _ = cell.unroll(3, inputs, begin_state=begin_state) 35 outputs = mx.sym.Group(outputs) 36 assert sorted(cell.collect_params().keys()) == [prefix+'h2h_bias', prefix+'h2h_weight', 37 prefix+'i2h_bias', prefix+'i2h_weight'] 38 assert outputs.list_outputs() == [prefix+'t0_out_output', prefix+'t1_out_output', prefix+'t2_out_output'] 39 40 args, outs, auxs = outputs.infer_shape(rnn_t0_data=in_shape, 41 rnn_t1_data=in_shape, 42 rnn_t2_data=in_shape) 43 assert outs == [out_shape]*3 44 45 46def check_rnn_forward(layer, inputs): 47 
inputs.attach_grad() 48 layer.collect_params().initialize() 49 with mx.autograd.record(): 50 layer.unroll(3, inputs, merge_outputs=True)[0].backward() 51 mx.autograd.backward(layer.unroll(3, inputs, merge_outputs=False)[0]) 52 mx.nd.waitall() 53 54 55@with_seed() 56def test_rnn_cells(): 57 check_rnn_forward(contrib.rnn.Conv1DLSTMCell((5, 7), 10, (3,), (3,)), 58 mx.nd.ones((8, 3, 5, 7))) 59 check_rnn_forward(contrib.rnn.Conv1DRNNCell((5, 7), 10, (3,), (3,)), 60 mx.nd.ones((8, 3, 5, 7))) 61 check_rnn_forward(contrib.rnn.Conv1DGRUCell((5, 7), 10, (3,), (3,)), 62 mx.nd.ones((8, 3, 5, 7))) 63 64 net = mx.gluon.rnn.SequentialRNNCell() 65 net.add(contrib.rnn.Conv1DLSTMCell((5, 7), 10, (3,), (3,))) 66 net.add(contrib.rnn.Conv1DRNNCell((10, 5), 11, (3,), (3,))) 67 net.add(contrib.rnn.Conv1DGRUCell((11, 3), 12, (3,), (3,))) 68 check_rnn_forward(net, mx.nd.ones((8, 3, 5, 7))) 69 70 71@with_seed() 72def test_convrnn(): 73 cell = contrib.rnn.Conv1DRNNCell((10, 50), 100, 3, 3, prefix='rnn_') 74 check_rnn_cell(cell, prefix='rnn_', in_shape=(1, 10, 50), out_shape=(1, 100, 48)) 75 76 cell = contrib.rnn.Conv2DRNNCell((10, 20, 50), 100, 3, 3, prefix='rnn_') 77 check_rnn_cell(cell, prefix='rnn_', in_shape=(1, 10, 20, 50), out_shape=(1, 100, 18, 48)) 78 79 cell = contrib.rnn.Conv3DRNNCell((10, 20, 30, 50), 100, 3, 3, prefix='rnn_') 80 check_rnn_cell(cell, prefix='rnn_', in_shape=(1, 10, 20, 30, 50), out_shape=(1, 100, 18, 28, 48)) 81 82 83@with_seed() 84def test_convlstm(): 85 cell = contrib.rnn.Conv1DLSTMCell((10, 50), 100, 3, 3, prefix='rnn_') 86 check_rnn_cell(cell, prefix='rnn_', in_shape=(1, 10, 50), out_shape=(1, 100, 48)) 87 88 cell = contrib.rnn.Conv2DLSTMCell((10, 20, 50), 100, 3, 3, prefix='rnn_') 89 check_rnn_cell(cell, prefix='rnn_', in_shape=(1, 10, 20, 50), out_shape=(1, 100, 18, 48)) 90 91 cell = contrib.rnn.Conv3DLSTMCell((10, 20, 30, 50), 100, 3, 3, prefix='rnn_') 92 check_rnn_cell(cell, prefix='rnn_', in_shape=(1, 10, 20, 30, 50), out_shape=(1, 100, 18, 28, 48)) 93 94 
@with_seed()
def test_convgru():
    cell = contrib.rnn.Conv1DGRUCell((10, 50), 100, 3, 3, prefix='rnn_')
    check_rnn_cell(cell, prefix='rnn_', in_shape=(1, 10, 50), out_shape=(1, 100, 48))

    cell = contrib.rnn.Conv2DGRUCell((10, 20, 50), 100, 3, 3, prefix='rnn_')
    check_rnn_cell(cell, prefix='rnn_', in_shape=(1, 10, 20, 50), out_shape=(1, 100, 18, 48))

    cell = contrib.rnn.Conv3DGRUCell((10, 20, 30, 50), 100, 3, 3, prefix='rnn_')
    check_rnn_cell(cell, prefix='rnn_', in_shape=(1, 10, 20, 30, 50), out_shape=(1, 100, 18, 28, 48))


@with_seed()
def test_conv_fill_shape():
    # The cell is declared with 0 input channels; after a forward pass on
    # data with 5 channels the i2h weight must have been filled in with 5.
    cell = contrib.rnn.Conv1DLSTMCell((0, 7), 10, (3,), (3,))
    cell.hybridize()
    check_rnn_forward(cell, mx.nd.ones((8, 3, 5, 7)))
    assert cell.i2h_weight.shape[1] == 5, cell.i2h_weight.shape[1]


@with_seed()
def test_lstmp():
    """Check LSTMPCell parameter names, output names and projected output shape."""
    nhid = 100
    nproj = 64
    cell = contrib.rnn.LSTMPCell(nhid, nproj, prefix='rnn_')
    inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)]
    outputs, _ = cell.unroll(3, inputs)
    outputs = mx.sym.Group(outputs)
    expected_params = ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_h2r_weight', 'rnn_i2h_bias', 'rnn_i2h_weight']
    expected_outputs = ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output']
    assert sorted(cell.collect_params().keys()) == expected_params
    assert outputs.list_outputs() == expected_outputs, outputs.list_outputs()

    _, outs, _ = outputs.infer_shape(rnn_t0_data=(10,50), rnn_t1_data=(10,50), rnn_t2_data=(10,50))
    # Outputs have the projection size, not the hidden size.
    assert outs == [(10, nproj), (10, nproj), (10, nproj)]


@with_seed()
def test_vardrop():
    """Check that VariationalDropoutCell draws new masks on each recorded unroll
    and after ``reset()``, both imperatively and hybridized."""
    def check_vardrop(drop_inputs, drop_states, drop_outputs):
        cell = contrib.rnn.VariationalDropoutCell(mx.gluon.rnn.RNNCell(100, prefix='rnn_'),
                                                  drop_outputs=drop_outputs,
                                                  drop_states=drop_states,
                                                  drop_inputs=drop_inputs)
        cell.collect_params().initialize(init='xavier')
        input_data = mx.nd.random_uniform(shape=(10, 3, 50), ctx=mx.context.current_context())
        with mx.autograd.record():
            outputs1, _ = cell.unroll(3, input_data, merge_outputs=True)
            mx.nd.waitall()
            outputs2, _ = cell.unroll(3, input_data, merge_outputs=True)
        # Two recorded unrolls of the same input must differ (dropout active).
        assert not almost_equal(outputs1.asnumpy(), outputs2.asnumpy())

        inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)]
        outputs, _ = cell.unroll(3, inputs, merge_outputs=False)
        outputs = mx.sym.Group(outputs)

        _, outs, _ = outputs.infer_shape(rnn_t0_data=(10,50), rnn_t1_data=(10,50), rnn_t2_data=(10,50))
        assert outs == [(10, 100), (10, 100), (10, 100)]

        cell.reset()
        cell.hybridize()
        with mx.autograd.record():
            outputs3, _ = cell.unroll(3, input_data, merge_outputs=True)
            mx.nd.waitall()
            outputs4, _ = cell.unroll(3, input_data, merge_outputs=True)
        assert not almost_equal(outputs3.asnumpy(), outputs4.asnumpy())
        assert not almost_equal(outputs1.asnumpy(), outputs3.asnumpy())

    check_vardrop(0.5, 0.5, 0.5)
    check_vardrop(0.5, 0, 0.5)


@with_seed()
def test_concurrent():
    """Check (Hybrid)Concurrent concatenates branch outputs along ``axis``."""
    model = HybridConcurrent(axis=1)
    model.add(nn.Dense(128, activation='tanh', in_units=10))
    model.add(nn.Dense(64, activation='tanh', in_units=10))
    model.add(nn.Dense(32, in_units=10))
    model2 = Concurrent(axis=1)
    model2.add(nn.Dense(128, activation='tanh', in_units=10))
    model2.add(nn.Dense(64, activation='tanh', in_units=10))
    model2.add(nn.Dense(32, in_units=10))

    # symbol: data + (weight, bias) for each of the three Dense branches
    x = mx.sym.var('data')
    y = model(x)
    assert len(y.list_arguments()) == 7

    # ndarray: 128 + 64 + 32 = 224 concatenated features
    model.initialize(mx.init.Xavier(magnitude=2.24))
    model2.initialize(mx.init.Xavier(magnitude=2.24))
    x = model(mx.nd.zeros((32, 10)))
    x2 = model2(mx.nd.zeros((32, 10)))
    assert x.shape == (32, 224)
    assert x2.shape == (32, 224)
    x.wait_to_read()
    x2.wait_to_read()


@with_seed()
def test_identity():
    model = Identity()
    x = mx.nd.random.uniform(shape=(128, 33, 64))
    assert_almost_equal(model(x), x)


@with_seed()
def test_sparse_embedding():
    """Check SparseEmbedding gradients are 1 on looked-up rows and 0 elsewhere."""
    layer = SparseEmbedding(10, 100)
    layer.initialize()
    trainer = mx.gluon.Trainer(layer.collect_params(), 'sgd')
    x = mx.nd.array([3,4,2,0,1])
    with mx.autograd.record():
        y = layer(x)
    y.backward()
    assert (layer.weight.grad().asnumpy()[:5] == 1).all()
    assert (layer.weight.grad().asnumpy()[5:] == 0).all()


def test_pixelshuffle1d():
    nchan = 2
    up_x = 2
    nx = 3
    shape_before = (1, nchan * up_x, nx)
    shape_after = (1, nchan, nx * up_x)
    layer = PixelShuffle1D(up_x)
    x = mx.nd.arange(np.prod(shape_before)).reshape(shape_before)
    y = layer(x)
    assert y.shape == shape_after
    assert_allclose(
        y,
        [[[0, 3, 1, 4, 2, 5],
          [6, 9, 7, 10, 8, 11]]]
    )


def test_pixelshuffle2d():
    nchan = 2
    up_x = 2
    up_y = 3
    nx = 2
    ny = 3
    shape_before = (1, nchan * up_x * up_y, nx, ny)
    shape_after = (1, nchan, nx * up_x, ny * up_y)
    layer = PixelShuffle2D((up_x, up_y))
    x = mx.nd.arange(np.prod(shape_before)).reshape(shape_before)
    y = layer(x)
    assert y.shape == shape_after
    # - Channels are reshaped to form 2x3 blocks
    # - Within each block, the increment is `nx * ny` when increasing the column
    #   index by 1
    # - Increasing the block index adds an offset of 1
    # - Increasing the channel index adds an offset of `nx * up_x * ny * up_y`
    assert_allclose(
        y,
        [[[[ 0,  6, 12,  1,  7, 13,  2,  8, 14],
           [18, 24, 30, 19, 25, 31, 20, 26, 32],
           [ 3,  9, 15,  4, 10, 16,  5, 11, 17],
           [21, 27, 33, 22, 28, 34, 23, 29, 35]],

          [[36, 42, 48, 37, 43, 49, 38, 44, 50],
           [54, 60, 66, 55, 61, 67, 56, 62, 68],
           [39, 45, 51, 40, 46, 52, 41, 47, 53],
           [57, 63, 69, 58, 64, 70, 59, 65, 71]]]]
    )


def test_pixelshuffle3d():
    nchan = 1
    up_x = 2
    up_y = 1
    up_z = 2
    nx = 2
    ny = 3
    nz = 4
    shape_before = (1, nchan * up_x * up_y * up_z, nx, ny, nz)
    shape_after = (1, nchan, nx * up_x, ny * up_y, nz * up_z)
    layer = PixelShuffle3D((up_x, up_y, up_z))
    x = mx.nd.arange(np.prod(shape_before)).reshape(shape_before)
    y = layer(x)
    assert y.shape == shape_after
    # - Channels are reshaped to form 2x1x2 blocks
    # - Within each block, the increment is `nx * ny * nz` when increasing the
    #   column index by 1, e.g. the block [[[ 0, 24]], [[48, 72]]]
    # - Increasing the block index adds an offset of 1
    assert_allclose(
        y,
        [[[[[ 0, 24,  1, 25,  2, 26,  3, 27],
            [ 4, 28,  5, 29,  6, 30,  7, 31],
            [ 8, 32,  9, 33, 10, 34, 11, 35]],

           [[48, 72, 49, 73, 50, 74, 51, 75],
            [52, 76, 53, 77, 54, 78, 55, 79],
            [56, 80, 57, 81, 58, 82, 59, 83]],

           [[12, 36, 13, 37, 14, 38, 15, 39],
            [16, 40, 17, 41, 18, 42, 19, 43],
            [20, 44, 21, 45, 22, 46, 23, 47]],

           [[60, 84, 61, 85, 62, 86, 63, 87],
            [64, 88, 65, 89, 66, 90, 67, 91],
            [68, 92, 69, 93, 70, 94, 71, 95]]]]]
    )


def test_datasets():
    wikitext2_train = contrib.data.text.WikiText2(root='data/wikitext-2', segment='train')
    wikitext2_val = contrib.data.text.WikiText2(root='data/wikitext-2', segment='validation',
                                                vocab=wikitext2_train.vocabulary)
    wikitext2_test = contrib.data.text.WikiText2(root='data/wikitext-2', segment='test')
    assert len(wikitext2_train) == 59305, len(wikitext2_train)
    assert len(wikitext2_train.vocabulary) == 33278, len(wikitext2_train.vocabulary)
    assert len(wikitext2_train.frequencies) == 33277, len(wikitext2_train.frequencies)
    assert len(wikitext2_val) == 6181, len(wikitext2_val)
    assert len(wikitext2_val.vocabulary) == 33278, len(wikitext2_val.vocabulary)
    assert len(wikitext2_val.frequencies) == 13776, len(wikitext2_val.frequencies)
    assert len(wikitext2_test) == 6974, len(wikitext2_test)
    assert len(wikitext2_test.vocabulary) == 14143, len(wikitext2_test.vocabulary)
    assert len(wikitext2_test.frequencies) == 14142, len(wikitext2_test.frequencies)
    assert wikitext2_test.frequencies['English'] == 32


def test_sampler():
    interval_sampler = contrib.data.IntervalSampler(10, 3)
    assert sorted(list(interval_sampler)) == list(range(10))
    interval_sampler = contrib.data.IntervalSampler(10, 3, rollover=False)
    assert list(interval_sampler) == [0, 3, 6, 9]


class TestRNNLayer(gluon.HybridBlock):
    """HybridBlock wrapper that runs ``contrib.rnn.rnn_cell.dynamic_unroll``
    over a single cell, so the unroll can be exercised under hybridization."""

    # Helper block, not a test class: stop nose/pytest from collecting it
    # just because its name starts with "Test".
    __test__ = False

    def __init__(self, cell_type, hidden_size, layout, prefix=None, params=None):
        super(TestRNNLayer, self).__init__(prefix=prefix, params=params)
        # Created outside name_scope() on purpose so its parameter names are
        # plain 'rnn_*' and line up with a stand-alone cell built with the
        # same prefix (check_unroll indexes params by those names).
        self.cell = cell_type(hidden_size, prefix='rnn_')
        self.layout = layout

    def hybrid_forward(self, F, inputs, states, valid_length):
        # A hybrid block can't receive None, so an empty list stands in for
        # "no valid_length".
        if isinstance(valid_length, list) and len(valid_length) == 0:
            valid_length = None
        return contrib.rnn.rnn_cell.dynamic_unroll(self.cell, inputs, states,
                                                   valid_length=valid_length,
                                                   layout=self.layout)


def check_unroll(cell_type, num_states, layout):
    """Check that ``dynamic_unroll`` (via TestRNNLayer) matches ``cell.unroll``.

    A reference cell is unrolled with ``cell.unroll`` and trained for one SGD
    step. For each hybridization config, a TestRNNLayer that starts from the
    same initial weights must produce matching outputs and states, and must
    hold matching weights after one identical SGD step.

    Parameters
    ----------
    cell_type : class
        Gluon RNN cell class, constructed as ``cell_type(hidden_size, prefix='rnn_')``.
    num_states : int
        Number of recurrent states the cell carries (e.g. 2 for LSTMCell).
    layout : str
        ``'TNC'`` or ``'NTC'``.

    Raises
    ------
    ValueError
        If ``layout`` is neither ``'TNC'`` nor ``'NTC'``.
    """
    batch_size = 20
    input_size = 50
    hidden_size = 30
    seq_len = 10
    if layout == 'TNC':
        rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, batch_size, input_size))
    elif layout == 'NTC':
        rnn_data = mx.nd.normal(loc=0, scale=1, shape=(batch_size, seq_len, input_size))
    else:
        # Fail loudly: silently returning would let a mistyped layout pass.
        raise ValueError("Wrong layout: %r (expected 'TNC' or 'NTC')" % (layout,))
    # Per-sample sequence lengths, rounded into [1, seq_len].
    valid_length = mx.nd.round(mx.nd.random.uniform(low=1, high=seq_len, shape=(batch_size,)))
    state_shape = (batch_size, hidden_size)
    states = [mx.nd.normal(loc=0, scale=1, shape=state_shape) for _ in range(num_states)]

    cell = cell_type(hidden_size, prefix='rnn_')
    cell.initialize(ctx=default_context())
    # One step forward so deferred-shape parameters get materialized.
    if layout == 'TNC':
        cell(rnn_data[0], states)
    else:
        cell(rnn_data[:,0,:], states)
    params1 = cell.collect_params()
    # Snapshot of the pre-training weights to seed every layer below.
    orig_params1 = copy.deepcopy(params1)

    trainer = gluon.Trainer(params1, 'sgd', {'learning_rate' : 0.03})
    with mx.autograd.record():
        res1, states1 = cell.unroll(seq_len, rnn_data, states, valid_length=valid_length,
                                    layout=layout, merge_outputs=True)
    res1.backward()
    trainer.step(batch_size)

    configs = [
            lambda layer: None,
            lambda layer: layer.hybridize(),
            lambda layer: layer.hybridize({'inline_limit': 0}),
            lambda layer: layer.hybridize({'static_alloc': True}),
            lambda layer: layer.hybridize({'static_alloc': True, 'static_shape': True}) ]
    # We can't pass None to a hybrid block, but it accepts an empty list.
    # so we use an empty list to represent valid_length if it's None.
    if valid_length is None:
        valid_length = []
    for config in configs:
        layer = TestRNNLayer(cell_type, hidden_size, layout)
        layer.initialize(ctx=default_context())
        config(layer)
        # First call materializes parameter shapes before we overwrite them.
        res2, states2 = layer(rnn_data, states, valid_length)
        params2 = layer.collect_params()
        for key, val in orig_params1.items():
            params2[key].set_data(copy.deepcopy(val.data()))

        trainer = gluon.Trainer(params2, 'sgd', {'learning_rate' : 0.03})
        with mx.autograd.record():
            res2, states2 = layer(rnn_data, states, valid_length)
        assert_almost_equal(res1, res2, rtol=0.001, atol=0.0001)
        assert len(states1) == len(states2)
        for i in range(len(states1)):
            assert_almost_equal(states1[i], states2[i], rtol=0.001, atol=0.0001)
        res2.backward()
        trainer.step(batch_size)

        for key, val in params1.items():
            weight1 = val.data()
            weight2 = params2[key].data()
            assert_almost_equal(weight1, weight2, rtol=0.001, atol=0.0001)


@with_seed()
def test_contrib_unroll():
    cell_types = [(gluon.rnn.RNNCell, 1), (gluon.rnn.LSTMCell, 2),
                  (gluon.rnn.GRUCell, 1)]
    for cell_type, num_states in cell_types:
        check_unroll(cell_type, num_states, 'TNC')
        check_unroll(cell_type, num_states, 'NTC')


@with_seed()
def test_ModulatedDeformableConvolution():
    """Run a stack of gluon contrib deformable convolution layers, covering
    combinations of kernel size, deformable groups, activation and bias
    options, through a hybridized forward pass on CPU, and check the output
    shape.
    """
    # NOTE(review): despite the test name, the layers built here are
    # ``DeformableConvolution``, not ``ModulatedDeformableConvolution``;
    # confirm which layer this test is meant to cover.
    from mxnet.gluon.contrib.cnn import DeformableConvolution
    net = nn.HybridSequential()
    net.add(
        DeformableConvolution(10, kernel_size=(3, 3), strides=1, padding=0),
        DeformableConvolution(10, kernel_size=(1, 1), strides=1, padding=0),
        DeformableConvolution(10, kernel_size=(5, 5), strides=1, padding=0),
        DeformableConvolution(10, kernel_size=(3, 5), strides=1, padding=0),
        DeformableConvolution(10, kernel_size=(5, 1), strides=1, padding=0, num_deformable_group=2),
        DeformableConvolution(10, kernel_size=(3, 2), strides=1, padding=0, activation='relu',
                              offset_use_bias=False, use_bias=False),
        DeformableConvolution(10, kernel_size=(3, 2), strides=1, padding=0, activation='relu',
                              offset_use_bias=False),
        DeformableConvolution(10, kernel_size=(3, 2), strides=1, padding=0, activation='relu',
                              use_bias=False),
        DeformableConvolution(10, kernel_size=(3, 2), strides=1, padding=0, offset_use_bias=False, use_bias=False),
        DeformableConvolution(10, kernel_size=(3, 2), strides=1, padding=0, offset_use_bias=False),
        DeformableConvolution(12, kernel_size=(3, 2), strides=1, padding=0, use_bias=False),
        DeformableConvolution(12, kernel_size=(3, 2), strides=1, padding=0, use_bias=False, num_deformable_group=4),
    )

    ctx = mx.cpu()

    net.initialize(force_reinit=True, ctx=ctx)
    net.hybridize()

    x = mx.nd.random.uniform(shape=(8, 5, 30, 31), ctx=ctx)
    with mx.autograd.record():
        y = net(x)
    # Every layer uses stride 1 and no padding, so a (kh, kw) kernel shrinks
    # the spatial dims by (kh - 1, kw - 1): (30, 31) -> (4, 14). The last
    # layer has 12 output channels.
    assert y.shape == (8, 12, 4, 14), y.shape


if __name__ == '__main__':
    import nose
    nose.runmodule()