# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Tests for individual operators.

This module contains operator tests that do not yet exist in the ONNX
backend test framework. Once the corresponding PRs are merged into the
ONNX repo, this file will be retired.
"""
# pylint: disable=too-many-locals,wrong-import-position,import-error
from __future__ import absolute_import
import sys
import os
import unittest
import logging
from collections import namedtuple
import numpy as np
import numpy.testing as npt
from onnx import checker, helper, load_model
from onnx import TensorProto
from mxnet.contrib import onnx as onnx_mxnet
import mxnet as mx
import backend

CURR_PATH = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.insert(0, os.path.join(CURR_PATH, '../../python/unittest'))

logger = logging.getLogger()
logger.setLevel(logging.DEBUG)


def get_rnd(shape, low=-1.0, high=1.0, dtype=np.float32):
    """Generate random test data of the given shape.

    Every branch casts its result to float32, since the ONNX graph inputs
    built by get_input_tensors() below are declared as TensorProto.FLOAT.
    """
    if dtype == np.float32:
        return (np.random.uniform(low, high,
                                  np.prod(shape)).reshape(shape).astype(np.float32))
    elif dtype == np.int32:
        return (np.random.randint(low, high,
                                  np.prod(shape)).reshape(shape).astype(np.float32))
    elif dtype == np.bool_:
        return np.random.choice(a=[False, True], size=shape).astype(np.float32)
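
# Illustrative calls (a sketch; these are not executed by the test suite):
#   get_rnd((2, 3))                      # uniform floats in [-1.0, 1.0)
#   get_rnd((2, 3), 1, 100, np.int32)    # random ints in [1, 100), stored as float32
#   get_rnd((2, 3), dtype=np.bool_)      # 0.0/1.0 coin flips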


def _fix_attributes(attrs, attribute_mapping):
    """Rename, add, or remove attributes when mapping MXNet attrs to ONNX.

    attribute_mapping supports three keys:
      'modify': {mxnet_name: onnx_name} renames,
      'add':    {name: value} insertions,
      'remove': [name] deletions.
    """
    new_attrs = dict(attrs)  # copy so the caller's attribute map is not mutated
    attr_modify = attribute_mapping.get('modify', {})
    for k, v in attr_modify.items():
        new_attrs[v] = new_attrs.pop(k, None)

    attr_add = attribute_mapping.get('add', {})
    for k, v in attr_add.items():
        new_attrs[k] = v

    attr_remove = attribute_mapping.get('remove', [])
    for k in attr_remove:
        if k in new_attrs:
            del new_attrs[k]

    return new_attrs
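
# Illustrative mapping (a trimmed version of the LpPool cases below): the MXNet
# Pooling attributes
#   {'kernel': (4, 5), 'pool_type': 'lp', 'p_value': 2}
# passed through
#   _fix_attributes(attrs, {'modify': {'kernel': 'kernel_shape', 'p_value': 'p'},
#                           'remove': ['pool_type']})
# come out as the ONNX LpPool attributes {'kernel_shape': (4, 5), 'p': 2}.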


def forward_pass(sym, arg, aux, data_names, input_data):
    """Perform a forward pass on the given data.

    :param sym: Symbol
    :param arg: Arg params
    :param aux: Aux params
    :param data_names: Input names (list)
    :param input_data: Input data (list). If there is only one input,
                       pass it as a one-element list: for the input [1, 2],
                       pass input_data=[[1, 2]].
    :return: result of the forward pass (numpy array of the first output)
    """
    data_shapes = []
    data_forward = []
    for idx in range(len(data_names)):
        val = input_data[idx]
        data_shapes.append((data_names[idx], np.shape(val)))
        data_forward.append(mx.nd.array(val))
    # create a CPU module bound for inference only
    mod = mx.mod.Module(symbol=sym, data_names=data_names, context=mx.cpu(), label_names=None)
    mod.bind(for_training=False, data_shapes=data_shapes, label_shapes=None)
    if not arg and not aux:
        mod.init_params()
    else:
        mod.set_params(arg_params=arg, aux_params=aux,
                       allow_missing=True, allow_extra=True)
    # run inference
    batch = namedtuple('Batch', ['data'])
    mod.forward(batch(data_forward), is_train=False)

    return mod.get_outputs()[0].asnumpy()
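
# Minimal usage sketch (hypothetical single-input symbol, no trained params):
#   out = forward_pass(mx.sym.relu(mx.sym.Variable('input1')), None, None,
#                      ['input1'], [get_rnd((2, 3))])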


def get_input_tensors(input_data):
    """Build ONNX input value infos (declared float32) and matching MXNet symbols."""
    input_tensor = []
    input_names = []
    input_sym = []
    for idx, ip in enumerate(input_data):
        name = "input" + str(idx + 1)
        input_sym.append(mx.sym.Variable(name))
        input_names.append(name)
        input_tensor.append(helper.make_tensor_value_info(name,
                                                          TensorProto.FLOAT, shape=np.shape(ip)))
    return input_names, input_tensor, input_sym


def get_onnx_graph(testname, input_names, inputs, output_name, output_shape, attr):
    """Build a single-node ONNX model that wires input_names into one op."""
    outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, shape=output_shape)]

    nodes = [helper.make_node(output_name, input_names, ["output"], **attr)]

    graph = helper.make_graph(nodes, testname, inputs, outputs)

    model = helper.make_model(graph)
    return model
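
# For example (illustrative), with `tensors` from get_input_tensors(),
#   get_onnx_graph('test_not', ['input1'], tensors, 'Not', (3, 4, 5), {})
# yields a model whose graph is the single node Not(input1) -> output,
# which backend.prepare() can then run.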


class TestNode(unittest.TestCase):
    """Tests for models.

    Tests are driven by the test_cases, import_test_cases, and
    export_test_cases tables below; edit those tables to add more tests.
    """
    def test_import_export(self):
        for test in test_cases:
            test_name, mxnet_op, onnx_name, inputs, attrs, mxnet_specific, fix_attrs, check_value, check_shape = test
            with self.subTest(test_name):
                names, input_tensors, inputsym = get_input_tensors(inputs)
                if inputs:
                    test_op = mxnet_op(*inputsym, **attrs)
                    mxnet_output = forward_pass(test_op, None, None, names, inputs)
                    outputshape = np.shape(mxnet_output)
                else:
                    # no-input ops (random samplers): bind and run the symbol directly
                    test_op = mxnet_op(**attrs)
                    shape = attrs.get('shape', (1,))
                    x = mx.nd.zeros(shape, dtype='float32')
                    xgrad = mx.nd.zeros(shape, dtype='float32')
                    exe = test_op.bind(ctx=mx.cpu(), args={'x': x}, args_grad={'x': xgrad})
                    mxnet_output = exe.forward(is_train=False)[0].asnumpy()
                    outputshape = np.shape(mxnet_output)

                if mxnet_specific:
                    # ops that need the full MXNet exporter rather than a hand-built graph
                    onnxmodelfile = onnx_mxnet.export_model(test_op, {}, [np.shape(ip) for ip in inputs],
                                                            [ip.dtype for ip in inputs],
                                                            onnx_name + ".onnx")
                    onnxmodel = load_model(onnxmodelfile)
                else:
                    onnx_attrs = _fix_attributes(attrs, fix_attrs)
                    onnxmodel = get_onnx_graph(test_name, names, input_tensors, onnx_name, outputshape, onnx_attrs)

                bkd_rep = backend.prepare(onnxmodel, operation='export')
                output = bkd_rep.run(inputs)

                if check_value:
                    npt.assert_almost_equal(output[0], mxnet_output)

                if check_shape:
                    npt.assert_equal(output[0].shape, outputshape)

        # round-trip scalar ops: export the MXNet symbol to ONNX, import it back,
        # and check that both forward passes agree
        input1 = get_rnd((1, 10, 2, 3))
        ipsym = mx.sym.Variable("input1")
        for test in test_scalar_ops:
            if test == 'Add':
                outsym = 2 + ipsym
            elif test == 'Sub':
                outsym = ipsym - 2
            elif test == 'rSub':
                outsym = ipsym.__rsub__(2)  # reversed subtraction: 2 - ipsym
            elif test == 'Mul':
                outsym = 2 * ipsym
            elif test == 'Div':
                outsym = ipsym / 2
            elif test == 'Pow':
                outsym = ipsym ** 2
            forward_op = forward_pass(outsym, None, None, ['input1'], [input1])
            converted_model = onnx_mxnet.export_model(outsym, {}, [np.shape(input1)], np.float32,
                                                      onnx_file_path=outsym.name + ".onnx")

            sym, arg_params, aux_params = onnx_mxnet.import_model(converted_model)
            result = forward_pass(sym, arg_params, aux_params, ['input1'], [input1])

            npt.assert_almost_equal(result, forward_op)

    def test_imports(self):
        """Import ONNX ops and compare against the NumPy reference results."""
        for test in import_test_cases:
            test_name, onnx_name, inputs, np_op, attrs = test
            with self.subTest(test_name):
                names, input_tensors, inputsym = get_input_tensors(inputs)
                np_out = [np_op(*inputs, **attrs)]
                output_shape = np.shape(np_out)
                onnx_model = get_onnx_graph(test_name, names, input_tensors, onnx_name, output_shape, attrs)
                bkd_rep = backend.prepare(onnx_model, operation='import')
                mxnet_out = bkd_rep.run(inputs)
                npt.assert_almost_equal(np_out, mxnet_out, decimal=4)

    def test_exports(self):
        """Export MXNet ops to ONNX and validate the models with the ONNX checker."""
        input_shape = (2, 1, 3, 1)
        for test in export_test_cases:
            test_name, onnx_name, mx_op, attrs = test
            with self.subTest(test_name):
                input_sym = mx.sym.var('data')
                outsym = mx_op(input_sym, **attrs)
                converted_model = onnx_mxnet.export_model(outsym, {}, [input_shape], np.float32,
                                                          onnx_file_path=outsym.name + ".onnx")
                model = load_model(converted_model)
                checker.check_model(model)


# test_case = ("test_case_name", mxnet_op, "ONNX_op_name", [input_list], attribute_map,
#              mxnet_specific=True/False,
#              fix_attributes={'modify': {mxnet_attr_name: onnx_attr_name},
#                              'remove': [attr_name],
#                              'add': {attr_name: value}},
#              check_value=True/False, check_shape=True/False)
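#
# For example, the "test_equal" entry below runs mx.sym.broadcast_equal on two
# random inputs, builds a plain ONNX "Equal" node (mxnet_specific=False, no
# attribute fix-ups), and compares output values but not shapes.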
test_cases = [
    ("test_equal", mx.sym.broadcast_equal, "Equal", [get_rnd((1, 3, 4, 5)), get_rnd((1, 5))], {}, False, {}, True,
     False),
    ("test_greater", mx.sym.broadcast_greater, "Greater", [get_rnd((1, 3, 4, 5)), get_rnd((1, 5))], {}, False, {}, True,
     False),
    ("test_less", mx.sym.broadcast_lesser, "Less", [get_rnd((1, 3, 4, 5)), get_rnd((1, 5))], {}, False, {}, True,
     False),
    ("test_and", mx.sym.broadcast_logical_and, "And",
     [get_rnd((3, 4, 5), dtype=np.bool_), get_rnd((3, 4, 5), dtype=np.bool_)], {}, False, {}, True, False),
    ("test_xor", mx.sym.broadcast_logical_xor, "Xor",
     [get_rnd((3, 4, 5), dtype=np.bool_), get_rnd((3, 4, 5), dtype=np.bool_)], {}, False, {}, True, False),
    ("test_or", mx.sym.broadcast_logical_or, "Or",
     [get_rnd((3, 4, 5), dtype=np.bool_), get_rnd((3, 4, 5), dtype=np.bool_)], {}, False, {}, True, False),
    ("test_not", mx.sym.logical_not, "Not", [get_rnd((3, 4, 5), dtype=np.bool_)], {}, False, {}, True, False),
    ("test_square", mx.sym.square, "Pow", [get_rnd((2, 3), dtype=np.int32)], {}, True, {}, True, False),
    ("test_spacetodepth", mx.sym.space_to_depth, "SpaceToDepth", [get_rnd((1, 1, 4, 6))],
     {'block_size': 2}, False, {}, True, False),
    ("test_softmax", mx.sym.SoftmaxOutput, "Softmax", [get_rnd((1000, 1000)), get_rnd(1000)],
     {'ignore_label': 0, 'use_ignore': False}, True, {}, True, False),
    ("test_logistic_regression", mx.sym.LogisticRegressionOutput, "Sigmoid",
     [get_rnd((1000, 1000)), get_rnd((1000, 1000))], {}, True, {}, True, False),
    # TODO: After rewrite, FC would fail this testcase. Commenting this out for now
    # ("test_fullyconnected", mx.sym.FullyConnected, "Gemm", [get_rnd((4, 3)), get_rnd((4, 3)), get_rnd(4)],
    #  {'num_hidden': 4, 'name': 'FC'}, True, {}, True, False),
    ("test_lppool1", mx.sym.Pooling, "LpPool", [get_rnd((2, 3, 20, 20))],
     {'kernel': (4, 5), 'pad': (0, 0), 'stride': (1, 1), 'p_value': 1, 'pool_type': 'lp'}, False,
     {'modify': {'kernel': 'kernel_shape', 'pad': 'pads', 'stride': 'strides', 'p_value': 'p'},
      'remove': ['pool_type']}, True, False),
    ("test_lppool2", mx.sym.Pooling, "LpPool", [get_rnd((2, 3, 20, 20))],
     {'kernel': (4, 5), 'pad': (0, 0), 'stride': (1, 1), 'p_value': 2, 'pool_type': 'lp'}, False,
     {'modify': {'kernel': 'kernel_shape', 'pad': 'pads', 'stride': 'strides', 'p_value': 'p'},
      'remove': ['pool_type']}, True, False),
    ("test_globallppool1", mx.sym.Pooling, "GlobalLpPool", [get_rnd((2, 3, 20, 20))],
     {'kernel': (4, 5), 'pad': (0, 0), 'stride': (1, 1), 'p_value': 1, 'pool_type': 'lp', 'global_pool': True}, False,
     {'modify': {'p_value': 'p'},
      'remove': ['pool_type', 'kernel', 'pad', 'stride', 'global_pool']}, True, False),
    ("test_globallppool2", mx.sym.Pooling, "GlobalLpPool", [get_rnd((2, 3, 20, 20))],
     {'kernel': (4, 5), 'pad': (0, 0), 'stride': (1, 1), 'p_value': 2, 'pool_type': 'lp', 'global_pool': True}, False,
     {'modify': {'p_value': 'p'},
      'remove': ['pool_type', 'kernel', 'pad', 'stride', 'global_pool']}, True, False),
    ("test_roipool", mx.sym.ROIPooling, "MaxRoiPool",
     [[[get_rnd(shape=(8, 6), low=1, high=100, dtype=np.int32)]], [[0, 0, 0, 4, 4]]],
     {'pooled_size': (2, 2), 'spatial_scale': 0.7}, False,
     {'modify': {'pooled_size': 'pooled_shape'}}, True, False),

    # since results would be random, checking for shape alone
    ("test_multinomial", mx.sym.sample_multinomial, "Multinomial",
     [np.array([0, 0.1, 0.2, 0.3, 0.4]).astype("float32")],
     {'shape': (10,)}, False, {'modify': {'shape': 'sample_size'}}, False, True),
    ("test_random_normal", mx.sym.random_normal, "RandomNormal", [],
     {'shape': (2, 2), 'loc': 0, 'scale': 1}, False, {'modify': {'loc': 'mean'}}, False, True),
    ("test_random_uniform", mx.sym.random_uniform, "RandomUniform", [],
     {'shape': (2, 2), 'low': 0.5, 'high': 1.0}, False, {}, False, True)
]

# Scalar ops exercised via the ONNX export/import round trip in test_import_export.
test_scalar_ops = ['Add', 'Sub', 'rSub', 'Mul', 'Div', 'Pow']

# test_case = ("test_case_name", "ONNX_op_name", [input_list], np_op, attribute_map)
import_test_cases = [
    ("test_lpnormalization_default", "LpNormalization", [get_rnd([5, 3, 3, 2])], np.linalg.norm, {'ord': 2, 'axis': -1}),
    ("test_lpnormalization_ord1", "LpNormalization", [get_rnd([5, 3, 3, 2])], np.linalg.norm, {'ord': 1, 'axis': -1}),
    ("test_lpnormalization_ord2", "LpNormalization", [get_rnd([5, 3, 3, 2])], np.linalg.norm, {'ord': 2, 'axis': 1})
]

# test_case = ("test_case_name", "ONNX_op_name", mxnet_op, attribute_map)
export_test_cases = [
    ("test_expand", "Expand", mx.sym.broadcast_to, {'shape': (2, 1, 3, 1)}),
    ("test_tile", "Tile", mx.sym.tile, {'reps': (2, 3)})
]
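
# Illustrative manual equivalent of one test_exports iteration (a sketch using
# the same APIs as above; 'tile.onnx' is an arbitrary output path):
#   path = onnx_mxnet.export_model(mx.sym.tile(mx.sym.var('data'), reps=(2, 3)),
#                                  {}, [(2, 1, 3, 1)], np.float32, 'tile.onnx')
#   checker.check_model(load_model(path))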

if __name__ == '__main__':
    unittest.main()