# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import print_function

import copy

import numpy as np

import mxnet as mx
from mxnet import gluon
from mxnet.gluon import contrib
from mxnet.gluon import nn
from mxnet.gluon.contrib.nn import (
    Concurrent, HybridConcurrent, Identity, SparseEmbedding, PixelShuffle1D,
    PixelShuffle2D, PixelShuffle3D)
from mxnet.test_utils import almost_equal, default_context, assert_almost_equal, assert_allclose
from common import setup_module, with_seed, teardown


def check_rnn_cell(cell, prefix, in_shape=(10, 50), out_shape=(10, 100), begin_state=None):
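    """Unroll `cell` symbolically for three steps and check its parameter
    names, its output names, and the inferred output shapes."""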
    inputs = [mx.sym.Variable('rnn_t%d_data' % i) for i in range(3)]
    outputs, _ = cell.unroll(3, inputs, begin_state=begin_state)
    outputs = mx.sym.Group(outputs)
    assert sorted(cell.collect_params().keys()) == [prefix+'h2h_bias', prefix+'h2h_weight',
                                                    prefix+'i2h_bias', prefix+'i2h_weight']
    assert outputs.list_outputs() == [prefix+'t0_out_output', prefix+'t1_out_output', prefix+'t2_out_output']

    args, outs, auxs = outputs.infer_shape(rnn_t0_data=in_shape,
                                           rnn_t1_data=in_shape,
                                           rnn_t2_data=in_shape)
    assert outs == [out_shape]*3


def check_rnn_forward(layer, inputs):
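    """Unroll `layer` for three steps, with both merged and unmerged outputs,
    and run backward on the results to make sure both paths execute."""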
    inputs.attach_grad()
    layer.collect_params().initialize()
    with mx.autograd.record():
        layer.unroll(3, inputs, merge_outputs=True)[0].backward()
        mx.autograd.backward(layer.unroll(3, inputs, merge_outputs=False)[0])
    mx.nd.waitall()


@with_seed()
def test_rnn_cells():
    check_rnn_forward(contrib.rnn.Conv1DLSTMCell((5, 7), 10, (3,), (3,)),
                      mx.nd.ones((8, 3, 5, 7)))
    check_rnn_forward(contrib.rnn.Conv1DRNNCell((5, 7), 10, (3,), (3,)),
                      mx.nd.ones((8, 3, 5, 7)))
    check_rnn_forward(contrib.rnn.Conv1DGRUCell((5, 7), 10, (3,), (3,)),
                      mx.nd.ones((8, 3, 5, 7)))

    net = mx.gluon.rnn.SequentialRNNCell()
    net.add(contrib.rnn.Conv1DLSTMCell((5, 7), 10, (3,), (3,)))
    net.add(contrib.rnn.Conv1DRNNCell((10, 5), 11, (3,), (3,)))
    net.add(contrib.rnn.Conv1DGRUCell((11, 3), 12, (3,), (3,)))
    check_rnn_forward(net, mx.nd.ones((8, 3, 5, 7)))


@with_seed()
def test_convrnn():
    cell = contrib.rnn.Conv1DRNNCell((10, 50), 100, 3, 3, prefix='rnn_')
    check_rnn_cell(cell, prefix='rnn_', in_shape=(1, 10, 50), out_shape=(1, 100, 48))

    cell = contrib.rnn.Conv2DRNNCell((10, 20, 50), 100, 3, 3, prefix='rnn_')
    check_rnn_cell(cell, prefix='rnn_', in_shape=(1, 10, 20, 50), out_shape=(1, 100, 18, 48))

    cell = contrib.rnn.Conv3DRNNCell((10, 20, 30, 50), 100, 3, 3, prefix='rnn_')
    check_rnn_cell(cell, prefix='rnn_', in_shape=(1, 10, 20, 30, 50), out_shape=(1, 100, 18, 28, 48))


@with_seed()
def test_convlstm():
    cell = contrib.rnn.Conv1DLSTMCell((10, 50), 100, 3, 3, prefix='rnn_')
    check_rnn_cell(cell, prefix='rnn_', in_shape=(1, 10, 50), out_shape=(1, 100, 48))

    cell = contrib.rnn.Conv2DLSTMCell((10, 20, 50), 100, 3, 3, prefix='rnn_')
    check_rnn_cell(cell, prefix='rnn_', in_shape=(1, 10, 20, 50), out_shape=(1, 100, 18, 48))

    cell = contrib.rnn.Conv3DLSTMCell((10, 20, 30, 50), 100, 3, 3, prefix='rnn_')
    check_rnn_cell(cell, prefix='rnn_', in_shape=(1, 10, 20, 30, 50), out_shape=(1, 100, 18, 28, 48))


@with_seed()
def test_convgru():
    cell = contrib.rnn.Conv1DGRUCell((10, 50), 100, 3, 3, prefix='rnn_')
    check_rnn_cell(cell, prefix='rnn_', in_shape=(1, 10, 50), out_shape=(1, 100, 48))

    cell = contrib.rnn.Conv2DGRUCell((10, 20, 50), 100, 3, 3, prefix='rnn_')
    check_rnn_cell(cell, prefix='rnn_', in_shape=(1, 10, 20, 50), out_shape=(1, 100, 18, 48))

    cell = contrib.rnn.Conv3DGRUCell((10, 20, 30, 50), 100, 3, 3, prefix='rnn_')
    check_rnn_cell(cell, prefix='rnn_', in_shape=(1, 10, 20, 30, 50), out_shape=(1, 100, 18, 28, 48))


@with_seed()
def test_conv_fill_shape():
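    # in_channels=0 defers shape inference to the first forward call; the
    # (8, 3, 5, 7) input carries 5 channels per step, which the i2h weight
    # should pick up.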
    cell = contrib.rnn.Conv1DLSTMCell((0, 7), 10, (3,), (3,))
    cell.hybridize()
    check_rnn_forward(cell, mx.nd.ones((8, 3, 5, 7)))
    assert cell.i2h_weight.shape[1] == 5, cell.i2h_weight.shape[1]


@with_seed()
def test_lstmp():
    nhid = 100
    nproj = 64
    cell = contrib.rnn.LSTMPCell(nhid, nproj, prefix='rnn_')
    inputs = [mx.sym.Variable('rnn_t%d_data' % i) for i in range(3)]
    outputs, _ = cell.unroll(3, inputs)
    outputs = mx.sym.Group(outputs)
    expected_params = ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_h2r_weight', 'rnn_i2h_bias', 'rnn_i2h_weight']
    expected_outputs = ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output']
    assert sorted(cell.collect_params().keys()) == expected_params
    assert outputs.list_outputs() == expected_outputs, outputs.list_outputs()

    args, outs, auxs = outputs.infer_shape(rnn_t0_data=(10, 50), rnn_t1_data=(10, 50), rnn_t2_data=(10, 50))
    assert outs == [(10, nproj), (10, nproj), (10, nproj)]


@with_seed()
def test_vardrop():
    def check_vardrop(drop_inputs, drop_states, drop_outputs):
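        """Unrolling the same dropout cell twice in one recording scope must
        give different outputs, since the masks are resampled on each unroll."""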
        cell = contrib.rnn.VariationalDropoutCell(mx.gluon.rnn.RNNCell(100, prefix='rnn_'),
                                                  drop_outputs=drop_outputs,
                                                  drop_states=drop_states,
                                                  drop_inputs=drop_inputs)
        cell.collect_params().initialize(init='xavier')
        input_data = mx.nd.random_uniform(shape=(10, 3, 50), ctx=mx.context.current_context())
        with mx.autograd.record():
            outputs1, _ = cell.unroll(3, input_data, merge_outputs=True)
            mx.nd.waitall()
            outputs2, _ = cell.unroll(3, input_data, merge_outputs=True)
        assert not almost_equal(outputs1.asnumpy(), outputs2.asnumpy())

        inputs = [mx.sym.Variable('rnn_t%d_data' % i) for i in range(3)]
        outputs, _ = cell.unroll(3, inputs, merge_outputs=False)
        outputs = mx.sym.Group(outputs)

        args, outs, auxs = outputs.infer_shape(rnn_t0_data=(10, 50), rnn_t1_data=(10, 50), rnn_t2_data=(10, 50))
        assert outs == [(10, 100), (10, 100), (10, 100)]

        cell.reset()
        cell.hybridize()
        with mx.autograd.record():
            outputs3, _ = cell.unroll(3, input_data, merge_outputs=True)
            mx.nd.waitall()
            outputs4, _ = cell.unroll(3, input_data, merge_outputs=True)
        assert not almost_equal(outputs3.asnumpy(), outputs4.asnumpy())
        assert not almost_equal(outputs1.asnumpy(), outputs3.asnumpy())

    check_vardrop(0.5, 0.5, 0.5)
    check_vardrop(0.5, 0, 0.5)


def test_concurrent():
    model = HybridConcurrent(axis=1)
    model.add(nn.Dense(128, activation='tanh', in_units=10))
    model.add(nn.Dense(64, activation='tanh', in_units=10))
    model.add(nn.Dense(32, in_units=10))
    model2 = Concurrent(axis=1)
    model2.add(nn.Dense(128, activation='tanh', in_units=10))
    model2.add(nn.Dense(64, activation='tanh', in_units=10))
    model2.add(nn.Dense(32, in_units=10))

    # symbol
    x = mx.sym.var('data')
    y = model(x)
    assert len(y.list_arguments()) == 7

    # ndarray
    model.initialize(mx.init.Xavier(magnitude=2.24))
    model2.initialize(mx.init.Xavier(magnitude=2.24))
    x = model(mx.nd.zeros((32, 10)))
    x2 = model2(mx.nd.zeros((32, 10)))
    assert x.shape == (32, 224)
    assert x2.shape == (32, 224)
    x.wait_to_read()
    x2.wait_to_read()


@with_seed()
def test_identity():
    model = Identity()
    x = mx.nd.random.uniform(shape=(128, 33, 64))
    assert_almost_equal(model(x), x)

@with_seed()
def test_sparse_embedding():
    layer = SparseEmbedding(10, 100)
    layer.initialize()
    trainer = mx.gluon.Trainer(layer.collect_params(), 'sgd')
    x = mx.nd.array([3, 4, 2, 0, 1])
    with mx.autograd.record():
        y = layer(x)
        y.backward()
    # Exactly rows 0-4 were looked up, so only they receive gradient.
    assert (layer.weight.grad().asnumpy()[:5] == 1).all()
    assert (layer.weight.grad().asnumpy()[5:] == 0).all()


def test_pixelshuffle1d():
    nchan = 2
    up_x = 2
    nx = 3
    shape_before = (1, nchan * up_x, nx)
    shape_after = (1, nchan, nx * up_x)
    layer = PixelShuffle1D(up_x)
    x = mx.nd.arange(np.prod(shape_before)).reshape(shape_before)
    y = layer(x)
    assert y.shape == shape_after
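    # - Input channels are grouped in pairs (up_x = 2), one pair per output
    #   channel
    # - Within each pair, the channels are interleaved along x,
    #   e.g. [0, 1, 2] and [3, 4, 5] -> [0, 3, 1, 4, 2, 5]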
    assert_allclose(
        y,
        [[[0, 3, 1, 4, 2, 5],
          [6, 9, 7, 10, 8, 11]]]
    )


def test_pixelshuffle2d():
    nchan = 2
    up_x = 2
    up_y = 3
    nx = 2
    ny = 3
    shape_before = (1, nchan * up_x * up_y, nx, ny)
    shape_after = (1, nchan, nx * up_x, ny * up_y)
    layer = PixelShuffle2D((up_x, up_y))
    x = mx.nd.arange(np.prod(shape_before)).reshape(shape_before)
    y = layer(x)
    assert y.shape == shape_after
    # - Channels are reshaped to form 2x3 blocks
    # - Within each block, the increment is `nx * ny` when increasing the column
    #   index by 1
    # - Increasing the block index adds an offset of 1
    # - Increasing the channel index adds an offset of `nx * up_x * ny * up_y`
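    # For instance, the first output row [0, 6, 12, 1, 7, 13, ...] walks the
    # first 2x3 channel block in steps of nx * ny = 6 and shifts by 1 at each
    # block boundary.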
    assert_allclose(
        y,
        [[[[ 0,  6, 12,  1,  7, 13,  2,  8, 14],
           [18, 24, 30, 19, 25, 31, 20, 26, 32],
           [ 3,  9, 15,  4, 10, 16,  5, 11, 17],
           [21, 27, 33, 22, 28, 34, 23, 29, 35]],

          [[36, 42, 48, 37, 43, 49, 38, 44, 50],
           [54, 60, 66, 55, 61, 67, 56, 62, 68],
           [39, 45, 51, 40, 46, 52, 41, 47, 53],
           [57, 63, 69, 58, 64, 70, 59, 65, 71]]]]
    )


def test_pixelshuffle3d():
    nchan = 1
    up_x = 2
    up_y = 1
    up_z = 2
    nx = 2
    ny = 3
    nz = 4
    shape_before = (1, nchan * up_x * up_y * up_z, nx, ny, nz)
    shape_after = (1, nchan, nx * up_x, ny * up_y, nz * up_z)
    layer = PixelShuffle3D((up_x, up_y, up_z))
    x = mx.nd.arange(np.prod(shape_before)).reshape(shape_before)
    y = layer(x)
    assert y.shape == shape_after
    # - Channels are reshaped to form 2x1x2 blocks
    # - Within each block, the increment is `nx * ny * nz` when increasing the
    #   column index by 1, e.g. the block [[[ 0, 24]], [[48, 72]]]
    # - Increasing the block index adds an offset of 1
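    # For instance, the first output row [0, 24, 1, 25, ...] interleaves
    # x[0, 0, 0, 0, :] = [0, 1, 2, 3] with x[0, 1, 0, 0, :] = [24, 25, 26, 27].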
    assert_allclose(
        y,
        [[[[[ 0, 24,  1, 25,  2, 26,  3, 27],
            [ 4, 28,  5, 29,  6, 30,  7, 31],
            [ 8, 32,  9, 33, 10, 34, 11, 35]],

           [[48, 72, 49, 73, 50, 74, 51, 75],
            [52, 76, 53, 77, 54, 78, 55, 79],
            [56, 80, 57, 81, 58, 82, 59, 83]],

           [[12, 36, 13, 37, 14, 38, 15, 39],
            [16, 40, 17, 41, 18, 42, 19, 43],
            [20, 44, 21, 45, 22, 46, 23, 47]],

           [[60, 84, 61, 85, 62, 86, 63, 87],
            [64, 88, 65, 89, 66, 90, 67, 91],
            [68, 92, 69, 93, 70, 94, 71, 95]]]]]
    )


def test_datasets():
    wikitext2_train = contrib.data.text.WikiText2(root='data/wikitext-2', segment='train')
    wikitext2_val = contrib.data.text.WikiText2(root='data/wikitext-2', segment='validation',
                                                vocab=wikitext2_train.vocabulary)
    wikitext2_test = contrib.data.text.WikiText2(root='data/wikitext-2', segment='test')
    assert len(wikitext2_train) == 59305, len(wikitext2_train)
    assert len(wikitext2_train.vocabulary) == 33278, len(wikitext2_train.vocabulary)
    assert len(wikitext2_train.frequencies) == 33277, len(wikitext2_train.frequencies)
    assert len(wikitext2_val) == 6181, len(wikitext2_val)
    assert len(wikitext2_val.vocabulary) == 33278, len(wikitext2_val.vocabulary)
    assert len(wikitext2_val.frequencies) == 13776, len(wikitext2_val.frequencies)
    assert len(wikitext2_test) == 6974, len(wikitext2_test)
    assert len(wikitext2_test.vocabulary) == 14143, len(wikitext2_test.vocabulary)
    assert len(wikitext2_test.frequencies) == 14142, len(wikitext2_test.frequencies)
    assert wikitext2_test.frequencies['English'] == 32


def test_sampler():
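    # With rollover (the default) the sampler starts a new strided pass after
    # each one finishes, eventually visiting every index in [0, 10); without
    # rollover it stops after the first pass.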
    interval_sampler = contrib.data.IntervalSampler(10, 3)
    assert sorted(list(interval_sampler)) == list(range(10))
    interval_sampler = contrib.data.IntervalSampler(10, 3, rollover=False)
    assert list(interval_sampler) == [0, 3, 6, 9]


class TestRNNLayer(gluon.HybridBlock):
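    """Thin HybridBlock wrapper around `dynamic_unroll`, so that unrolling can
    be exercised both imperatively and hybridized."""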
    def __init__(self, cell_type, hidden_size, layout, prefix=None, params=None):
        super(TestRNNLayer, self).__init__(prefix=prefix, params=params)
        self.cell = cell_type(hidden_size, prefix='rnn_')
        self.layout = layout

    def hybrid_forward(self, F, inputs, states, valid_length):
        # check_unroll passes an empty list in place of None, because a
        # HybridBlock cannot take None as an input.
        if isinstance(valid_length, list) and len(valid_length) == 0:
            valid_length = None
        return contrib.rnn.rnn_cell.dynamic_unroll(self.cell, inputs, states,
                                                   valid_length=valid_length,
                                                   layout=self.layout)

def check_unroll(cell_type, num_states, layout):
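    """Compare `cell.unroll` against `dynamic_unroll` wrapped in a HybridBlock,
    under several hybridization configs: the outputs, states, and weights after
    one SGD step should all match."""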
    batch_size = 20
    input_size = 50
    hidden_size = 30
    seq_len = 10
    if layout == 'TNC':
        rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, batch_size, input_size))
    elif layout == 'NTC':
        rnn_data = mx.nd.normal(loc=0, scale=1, shape=(batch_size, seq_len, input_size))
    else:
        print("Wrong layout")
        return
    valid_length = mx.nd.round(mx.nd.random.uniform(low=1, high=10, shape=(batch_size,)))
    state_shape = (batch_size, hidden_size)
    states = [mx.nd.normal(loc=0, scale=1, shape=state_shape) for _ in range(num_states)]

    cell = cell_type(hidden_size, prefix='rnn_')
    cell.initialize(ctx=default_context())
    # One forward call triggers deferred shape inference, so the parameters
    # can be copied below.
    if layout == 'TNC':
        cell(rnn_data[0], states)
    else:
        cell(rnn_data[:, 0, :], states)
    params1 = cell.collect_params()
    orig_params1 = copy.deepcopy(params1)

    trainer = gluon.Trainer(params1, 'sgd', {'learning_rate': 0.03})
    with mx.autograd.record():
        res1, states1 = cell.unroll(seq_len, rnn_data, states, valid_length=valid_length,
                                    layout=layout, merge_outputs=True)
    res1.backward()
    trainer.step(batch_size)

    configs = [
        lambda layer: None,
        lambda layer: layer.hybridize(),
        lambda layer: layer.hybridize(inline_limit=0),
        lambda layer: layer.hybridize(static_alloc=True),
        lambda layer: layer.hybridize(static_alloc=True, static_shape=True)]
    # We can't pass None to a HybridBlock, but it accepts an empty list,
    # so we use an empty list to represent valid_length when it is None.
    if valid_length is None:
        valid_length = []
    for config in configs:
        layer = TestRNNLayer(cell_type, hidden_size, layout)
        layer.initialize(ctx=default_context())
        config(layer)
        res2, states2 = layer(rnn_data, states, valid_length)
        params2 = layer.collect_params()
        for key, val in orig_params1.items():
            params2[key].set_data(copy.deepcopy(val.data()))

        trainer = gluon.Trainer(params2, 'sgd', {'learning_rate': 0.03})
        with mx.autograd.record():
            res2, states2 = layer(rnn_data, states, valid_length)
        assert_almost_equal(res1, res2, rtol=0.001, atol=0.0001)
        assert len(states1) == len(states2)
        for i in range(len(states1)):
            assert_almost_equal(states1[i], states2[i], rtol=0.001, atol=0.0001)
        res2.backward()
        trainer.step(batch_size)

        for key, val in params1.items():
            weight1 = val.data()
            weight2 = params2[key].data()
            assert_almost_equal(weight1, weight2, rtol=0.001, atol=0.0001)


@with_seed()
def test_contrib_unroll():
    cell_types = [(gluon.rnn.RNNCell, 1), (gluon.rnn.LSTMCell, 2),
                  (gluon.rnn.GRUCell, 1)]
    for cell_type, num_states in cell_types:
        check_unroll(cell_type, num_states, 'TNC')
        check_unroll(cell_type, num_states, 'NTC')

@with_seed()
def test_ModulatedDeformableConvolution():
    """Test the deformable convolution layer with various combinations of
    arguments. The modulated variant currently only supports GPU, so the
    plain DeformableConvolution layer is exercised on CPU here."""
    from mxnet.gluon.contrib.cnn import DeformableConvolution
    net = nn.HybridSequential()
    net.add(
        DeformableConvolution(10, kernel_size=(3, 3), strides=1, padding=0),
        DeformableConvolution(10, kernel_size=(1, 1), strides=1, padding=0),
        DeformableConvolution(10, kernel_size=(5, 5), strides=1, padding=0),
        DeformableConvolution(10, kernel_size=(3, 5), strides=1, padding=0),
        DeformableConvolution(10, kernel_size=(5, 1), strides=1, padding=0, num_deformable_group=2),
        DeformableConvolution(10, kernel_size=(3, 2), strides=1, padding=0, activation='relu',
                              offset_use_bias=False, use_bias=False),
        DeformableConvolution(10, kernel_size=(3, 2), strides=1, padding=0, activation='relu',
                              offset_use_bias=False),
        DeformableConvolution(10, kernel_size=(3, 2), strides=1, padding=0, activation='relu',
                              use_bias=False),
        DeformableConvolution(10, kernel_size=(3, 2), strides=1, padding=0, offset_use_bias=False, use_bias=False),
        DeformableConvolution(10, kernel_size=(3, 2), strides=1, padding=0, offset_use_bias=False),
        DeformableConvolution(12, kernel_size=(3, 2), strides=1, padding=0, use_bias=False),
        DeformableConvolution(12, kernel_size=(3, 2), strides=1, padding=0, use_bias=False, num_deformable_group=4),
    )

    ctx = mx.cpu()

    net.initialize(force_reinit=True, ctx=ctx)
    net.hybridize()

    x = mx.nd.random.uniform(shape=(8, 5, 30, 31), ctx=ctx)
    with mx.autograd.record():
        y = net(x)


if __name__ == '__main__':
    import nose
    nose.runmodule()