1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18import mxnet as mx
19from mxnet.gluon import HybridBlock, nn
20import numpy as np
21import onnxruntime as rt
22from mxnet.test_utils import assert_almost_equal
23import pytest
24import tempfile
25
26def def_model(op_name, dummy_input=False, **params):
27    class Model(HybridBlock):
28        def __init__(self, **kwargs):
29            super(Model, self).__init__(**kwargs)
30
31        def hybrid_forward(self, F, *inputs):
32            names = op_name.split('.')
33            func = F
34            for name in names:
35                func = getattr(func, name)
36            if dummy_input:
37                return func(**params), inputs[0]
38            else:
39                return func(*inputs, **params)
40    return Model
41
42def def_model_from_func(func, dummy_input=False, **params):
43    class Model(HybridBlock):
44        def __init__(self, **kwargs):
45            super(Model, self).__init__(**kwargs)
46
47        def hybrid_forward(self, F, *inputs):
48            if dummy_input:
49                return func(**params), inputs[0]
50            else:
51                return func(*inputs, **params)
52    return Model
53
54def op_export_test(model_name, Model, inputs, tmp_path, dummy_input=False, onnx_map=None, mx_map=None, rtol=None, atol=None):
55    def export_to_onnx(model, model_name, inputs):
56        model_path = '{}/{}'.format(tmp_path, model_name)
57        model.export(model_path, epoch=0)
58        sym_file = '{}-symbol.json'.format(model_path)
59        params_file = '{}-0000.params'.format(model_path)
60        onnx_file = '{}/{}.onnx'.format(tmp_path, model_name)
61        mx.onnx.export_model(sym_file, params_file, [inp.shape for inp in inputs],
62                             [inp.dtype for inp in inputs], onnx_file)
63        return onnx_file
64
65    def onnx_rt(onnx_file, inputs):
66        sess = rt.InferenceSession(onnx_file)
67        dtype_0 = inputs[0].asnumpy().dtype
68        input_dict = dict((sess.get_inputs()[i].name, inputs[i].asnumpy()) for i in range(len(inputs)))
69        pred = sess.run(None, input_dict)
70        return pred
71
72    # create a new model
73    model = Model()
74    model.initialize(ctx=mx.cpu(0))
75    model.hybridize()
76    pred_mx = model(*inputs)
77    onnx_file = export_to_onnx(model, model_name, inputs)
78    pred_onx = onnx_rt(onnx_file, inputs)
79    if dummy_input:
80        pred_mx = pred_mx[0]
81    if isinstance(pred_mx, list):
82        for i in range(len(pred_mx)):
83            pred_onx_i = onnx_map(pred_onx[i]) if onnx_map else pred_onx[i]
84            pred_mx_i = mx_map(pred_mx[i]) if mx_map else pred_mx[i]
85            assert_almost_equal(pred_onx_i, pred_mx_i, equal_nan=True, rtol=rtol, atol=atol)
86    else:
87        pred_onx = onnx_map(pred_onx[0]) if onnx_map else pred_onx[0]
88        pred_mx = mx_map(pred_mx) if mx_map else pred_mx
89        assert_almost_equal(pred_onx, pred_mx, equal_nan=True, rtol=rtol, atol=atol)
90
91
92def test_onnx_export_abs(tmp_path):
93    M = def_model('abs')
94    x = mx.nd.array([[-2, -1], [0, 99]], dtype='float32')
95    op_export_test('abs', M, [x], tmp_path)
96
97
98@pytest.mark.parametrize('dtype', ['float32', 'float64', 'float16', 'int32', 'int64'])
99@pytest.mark.parametrize('params', [[(0, 1), (2,3), (1, 1)],
100                                    [(None, 1), (2, None), None],
101                                    [(0, 0, 0), (None, 4, 5), (None, 1, 2)]])
102def test_onnx_export_slice(tmp_path, dtype, params):
103    M = def_model('slice', begin=params[0], end=params[1], step=params[2])
104    x = mx.nd.arange(start=0, stop=60, dtype=dtype).reshape((3, 4, 5))
105    op_export_test('slice', M, [x], tmp_path)
106
107
108def test_onnx_export_stack(tmp_path):
109    M = def_model('stack')
110    x = mx.nd.array([1, 2], dtype='float32')
111    y = mx.nd.array([3, 4], dtype='float32')
112    op_export_test('stack', M, [x, y], tmp_path)
113
114@pytest.mark.parametrize("dtype", [None, "float32", "float64", "int32", "int64"])
115@pytest.mark.parametrize("shape", [(1), (1,2), (2,3,4), (5,6,7)])
116def test_onnx_export_zeros(tmp_path, dtype, shape):
117    M = def_model('zeros', shape=shape, dtype=dtype, dummy_input=True)
118    x = mx.nd.array([1])
119    op_export_test('zeros', M, [x], tmp_path, dummy_input=True)
120
121@pytest.mark.parametrize("dtype", [None, "float32", "float64", "int32", "int64"])
122@pytest.mark.parametrize("shape", [(1), (1,2), (2,3,4), (5,6,7)])
123def test_onnx_export_ones(tmp_path, dtype, shape):
124    M = def_model('ones', shape=shape, dtype=dtype, dummy_input=True)
125    x = mx.nd.array([0])
126    op_export_test('ones', M, [x], tmp_path, dummy_input=True)
127
128
129@pytest.mark.parametrize('dtype', [None, 'float32', 'float64', 'int32', 'int64'])
130@pytest.mark.parametrize('shape', [(1), (1,2), (2,3,4), (5,6,7)])
131def test_onnx_export_zeros_like(tmp_path, dtype, shape):
132    M = def_model('zeros_like', dtype=dtype)
133    x = mx.random.uniform(0, 1, shape, dtype='float32')
134    op_export_test('zeros_like', M, [x], tmp_path)
135
136
137@pytest.mark.parametrize('dtype', [None, 'float32', 'float64', 'int32', 'int64'])
138@pytest.mark.parametrize('shape', [(1), (1,2), (2,3,4), (5,6,7)])
139def test_onnx_export_ones_like(tmp_path, dtype, shape):
140    M = def_model('ones_like', dtype=dtype)
141    x = mx.random.uniform(0, 1, shape, dtype='float32')
142    op_export_test('ones_like', M, [x], tmp_path)
143
144
145@pytest.mark.parametrize("dtype", ["float32", "float64"])
146@pytest.mark.parametrize("axis", [None,0,1])
147@pytest.mark.parametrize("start", [0, 0.5, 1])
148@pytest.mark.parametrize("step", [0.01, 0.1, 0.5, 1])
149@pytest.mark.parametrize("test_data", [ mx.random.uniform(0, 1, (10,20)), [[0,1,2,3,4,5],[4,5,6,7,8,9],[8,9,10,11,12,13]]])
150def test_onnx_export_arange_like(tmp_path, dtype, axis, start, step, test_data):
151    M = def_model('contrib.arange_like', axis=axis, start=start, step=step)
152    x = mx.nd.array(test_data, dtype=dtype)
153    op_export_test('arange_like', M, [x], tmp_path)
154
155
156@pytest.mark.parametrize("params", [[0, 2, 1], [0, 50, 0.25], [-100, 100, 0.5], [5, None, 1], [-5, None, -1]])
157@pytest.mark.parametrize("dtype", ["float32", "float64", "int32", "int64"])
158def test_onnx_export_arange(tmp_path, dtype, params):
159    start, stop, step = params[0], params[1], params[2]
160    if "int" in dtype:
161        start = int(start)
162        stop = int(stop) if stop != None else None
163        step = int(step)
164        if step == 0:
165            step = 1
166    M = def_model('arange', dummy_input=True, start=start, stop=stop, step=step, dtype=dtype)
167    x = mx.nd.array([1], dtype='float32')
168    op_export_test('arange', M, [x], tmp_path, dummy_input=True)
169
170
171@pytest.mark.parametrize('dtype', ['float32'])
172def test_onnx_export_layernorm(tmp_path, dtype):
173    x = mx.nd.random.uniform(1, 2, (3, 4, 5), dtype=dtype)
174    axes = list(range(np.shape(np.shape(x))[0]))
175    axes.append(-1)
176    for axis in axes:
177        M = def_model('LayerNorm', axis=axis)
178        gamma = mx.random.uniform(0, 1, [np.shape(x)[axis]], dtype=dtype)
179        beta = mx.random.uniform(0, 1, [np.shape(x)[axis]], dtype=dtype)
180        op_export_test('LayerNorm', M, [x, gamma, beta], tmp_path)
181
182
183@pytest.mark.parametrize('dtype', ['float32', 'float64', 'int32'])
184def test_onnx_export_broadcast_axis(tmp_path, dtype):
185    M1 = def_model('broadcast_axis', axis=(0, 2), size=(3, 4))
186    M2 = def_model('broadcast_axis', axis=(0, 2), size=(1, 5))
187    x1 = mx.nd.array([[[1], [2]]], dtype=dtype)
188    op_export_test('broadcast_axis_1', M1, [x1], tmp_path)
189    op_export_test('broadcast_axis_2', M2, [x1], tmp_path)
190    M3 = def_model('broadcast_axis', axis=(1, 4), size=(3, 5))
191    x2 = mx.nd.ones((1, 1, 3, 1, 1, 1), dtype=dtype)
192    op_export_test('broadcast_axis_3', M3, [x2], tmp_path)
193
194
195#TODO: onnxruntime does not support float64 for Where
196@pytest.mark.parametrize('dtype', ['float32'])
197def test_onnx_export_SequenceMask(tmp_path, dtype):
198    M1 = def_model('SequenceMask', use_sequence_length=True, axis=1, value=-5)
199    M2 = def_model('SequenceMask', use_sequence_length=True, axis=0, value=-99)
200    x = mx.nd.array([[[[  1.,   2.,   3.,  3.5]],
201                      [[  4.,   5.,   6.,  6.5]]],
202                     [[[  7.,   8.,   9.,  9.5]],
203                      [[ 10.,  11.,  12., 12.5]]],
204                     [[[ 13.,  14.,  15., 15.5]],
205                      [[ 16.,  17.,  18., 18.5]]]], dtype=dtype)
206    seq_len1 = mx.nd.array([1, 2, 1], dtype=dtype)
207    seq_len2 = mx.nd.array([1, 2], dtype=dtype)
208    op_export_test('SequenceMask_1', M1, [x, seq_len1], tmp_path)
209    op_export_test('SequenceMask_2', M2, [x, seq_len2], tmp_path)
210
211
212@pytest.mark.parametrize('dtype', ['float32'])
213def test_onnx_export_contrib_interleaved_matmul_selfatt_qk(tmp_path, dtype):
214    M1 = def_model('contrib.interleaved_matmul_selfatt_qk', heads=3)
215    x1 = mx.nd.random.uniform(0, 1, (3, 3, 3*3*3), dtype=dtype)
216    op_export_test('contrib_interleaved_matmul_selfatt_qk_1', M1, [x1], tmp_path)
217    M2 = def_model('contrib.interleaved_matmul_selfatt_qk', heads=5)
218    x2 = mx.nd.random.uniform(0, 1, (7, 5, 4*5*6), dtype=dtype)
219    op_export_test('contrib_interleaved_matmul_selfatt_qk_2', M2, [x2], tmp_path)
220
221@pytest.mark.parametrize('dtype', ['float32'])
222def test_onnx_export_contrib_interleaved_matmul_selfatt_valatt(tmp_path, dtype):
223    M = def_model('contrib.interleaved_matmul_selfatt_valatt', heads=6)
224    x = mx.nd.random.uniform(0, 1, (4, 5, 6*7*3), dtype=dtype)
225    att = mx.nd.random.uniform(0, 1, (5*6, 4, 4), dtype=dtype)
226    op_export_test('contrib_interleaved_matmul_selfatt_valatt', M, [x, att], tmp_path)
227
228
229@pytest.mark.parametrize('dtype', ['float32', 'float64', 'int32'])
230def test_onnx_export_slice_axis(tmp_path, dtype):
231    x = mx.nd.array([[  1.,   2.,   3.,   4.],
232                     [  5.,   6.,   7.,   8.],
233                     [  9.,  10.,  11.,  12.]], dtype=dtype)
234    M1 = def_model('slice_axis', axis=0, begin=1, end=3)
235    M2 = def_model('slice_axis', axis=0, begin=1, end=None)
236    M3 = def_model('slice_axis', axis=1, begin=-3, end=-1)
237    M4 = def_model('slice_axis', axis=-1, begin=-3, end=None)
238    op_export_test('slice_axis_1', M1, [x], tmp_path)
239    op_export_test('slice_axis_2', M2, [x], tmp_path)
240    op_export_test('slice_axis_3', M3, [x], tmp_path)
241    op_export_test('slice_axis_4', M4, [x], tmp_path)
242
243
244@pytest.mark.parametrize('dtype', ['float32', 'float64', 'int32', 'int64'])
245def test_onnx_export_reshape(tmp_path, dtype):
246    x = mx.nd.ones((2, 3, 4, 5, 6), dtype=dtype)
247    M1 = def_model('reshape', shape=(6, 1, 0, -1))
248    op_export_test('reshape_1', M1, [x], tmp_path)
249    M2 = def_model('reshape', shape=(3, -1, 0, 0), reverse=True)
250    op_export_test('reshape_2', M2, [x], tmp_path)
251    M3 = def_model('reshape', shape=(5, 1, 1, 1, 1, 0 -1, 0), reverse=True)
252    op_export_test('reshape_3', M3, [x], tmp_path)
253    M4 = def_model('reshape', shape=(-3, -1))
254    op_export_test('reshape_4', M4, [x], tmp_path)
255
256
257@pytest.mark.parametrize('dtype', ['float32', 'float64', 'int32', 'int64'])
258def test_onnx_export_reshape_special_cases(tmp_path, dtype):
259    x1 = mx.nd.ones((8, 9), dtype=dtype)
260    M1 = def_model('reshape', shape=(0, -4, 1, -1))
261    op_export_test('reshape_spec_1', M1, [x1], tmp_path)
262
263    x2 = mx.nd.ones((8, 9, 10), dtype=dtype)
264
265    M2 = def_model('reshape', shape=(0, -4, 3, -1, 10))
266    op_export_test('reshape_spec_2', M2, [x2], tmp_path)
267    M3 = def_model('reshape', shape=(-4, 2, -1, 10, 9))
268    op_export_test('reshape_spec_3', M3, [x2], tmp_path)
269
270    M4 = def_model('reshape', shape=(-3, 0))
271    op_export_test('reshape_spec_4', M4, [x2], tmp_path)
272
273    x3 = mx.nd.ones((1, 2, 3, 4, 5, 6), dtype=dtype)
274    M5 = def_model('reshape', shape=(0, 0, -3, -3))
275    op_export_test('reshape_spec_5', M5, [x3], tmp_path)
276
277    x4 = mx.nd.ones((5, 8, 6, 7), dtype=dtype)
278    M6 = def_model('reshape', shape=(0, -4, -1, 4, 0, 0))
279    op_export_test('reshape_spec_6', M6, [x4], tmp_path)
280
281    x5 = mx.nd.ones((2, 3, 4, 5, 6), dtype=dtype)
282    M7 = def_model('reshape', shape=(0, 0, -4, 2, 2, 0, 0))
283    op_export_test('reshape_spec_7', M7, [x5], tmp_path)
284
285    x6 = mx.nd.ones((8, 7, 6, 5), dtype=dtype)
286    M8 = def_model('reshape', shape=(-4, 1, -1, 0, 0, 0))
287    op_export_test('reshape_spec_8', M8, [x6], tmp_path)
288
289    x7 = mx.nd.ones((1000, 2, 3), dtype=dtype)
290    M9 = def_model('reshape', shape=(-4, 1, 1000, 0, 0))
291    op_export_test('reshape_spec_9', M9, [x7], tmp_path)
292
293    x8 = mx.nd.ones((3, 96, 5), dtype=dtype)
294    M10 = def_model('reshape', shape=(0, -4, 12, -1, 0))
295    op_export_test('reshape_spec_10', M10, [x8], tmp_path)
296
297    x9 = mx.nd.ones((3, 96, 5), dtype=dtype)
298    M11 = def_model('reshape', shape=(0, -4, 16, -1, 0))
299    op_export_test('reshape_spec_11', M11, [x9], tmp_path)
300
301
302@pytest.mark.parametrize('dtype', ['int32', 'int64'])
303def test_onnx_export_embedding(tmp_path, dtype):
304    x = mx.nd.array([[ 1.,  3.],
305                     [ 0.,  2.]], dtype=dtype)
306    y = mx.nd.array([[  0.,   1.,   2.,   3.,   4.],
307                     [  5.,   6.,   7.,   8.,   9.],
308                     [ 10.,  11.,  12.,  13.,  14.],
309                     [ 15.,  16.,  17.,  18.,  19.]], dtype=dtype)
310    M = def_model('Embedding', input_dim=4, output_dim=5)
311    op_export_test('Embedding', M, [x, y], tmp_path)
312
313
314@pytest.mark.parametrize('dtype', ['float32', 'float64', 'int32', 'int64'])
315@pytest.mark.parametrize('num_hidden', [1, 2, 7, 10, 20])
316@pytest.mark.parametrize('no_bias', [True, False])
317@pytest.mark.parametrize('flatten', [True, False])
318def test_onnx_export_fully_connected(tmp_path, dtype, num_hidden, no_bias, flatten):
319    M = def_model('FullyConnected', num_hidden=num_hidden, no_bias=no_bias, flatten=flatten)
320    x = mx.nd.random.uniform(-0.5, 0.5, (3, 4, 5))
321    if (flatten):
322        weight = mx.nd.random.uniform(0, 1, (num_hidden, 4*5))
323    else:
324        weight = mx.nd.random.uniform(0, 1, (num_hidden, 5))
325    args = [x, weight]
326    if not no_bias:
327        args.append(mx.nd.random.uniform(0,1,(num_hidden,)))
328    op_export_test('FullyConnected', M, args, tmp_path)
329
330
331#TODO: onnxruntime does not support float64 for the relu opertors
332@pytest.mark.parametrize('dtype', ['float32', 'float16'])
333@pytest.mark.parametrize('shape', [(1,), (3,), (4, 5), (3, 4, 5)])
334@pytest.mark.parametrize('act_type', ['elu', 'leaky', 'prelu', 'selu', 'gelu'])
335def test_onnx_export_LeakyReLU(tmp_path, dtype, shape, act_type):
336    M = def_model('LeakyReLU', act_type='leaky')
337    x = mx.nd.random.uniform(-0.5, 0.5, shape=shape, dtype=dtype)
338    op_export_test('LeakyReLU', M, [x], tmp_path)
339
340
341@pytest.mark.parametrize('dtype', ['float32', 'float64', 'float16', 'int32', 'int64'])
342def test_onnx_export_Concat(tmp_path, dtype):
343    x = mx.nd.array([[1,1],[2,2]], dtype=dtype)
344    y = mx.nd.array([[3,3],[4,4],[5,5]], dtype=dtype)
345    z = mx.nd.array([[6,6],[7,7],[8,8]], dtype=dtype)
346    M1 = def_model('Concat', dim=0)
347    M2 = def_model('Concat', dim=1)
348    op_export_test('Concat_1', M1, [x, y, z], tmp_path)
349    op_export_test('Concat_2', M2, [y, z], tmp_path)
350
351
352@pytest.mark.parametrize('dtype', ['float32', 'float16'])
353@pytest.mark.parametrize('shape', [(1,), (3,), (4, 5), (3, 4, 5)])
354@pytest.mark.parametrize('act_type', ['tanh', 'relu', 'sigmoid', 'softrelu', 'softsign'])
355def test_onnx_export_Activation(tmp_path, dtype, shape, act_type):
356    M = def_model('Activation', act_type=act_type)
357    x = mx.nd.random.uniform(-0.5, 0.5, shape=shape, dtype=dtype)
358    op_export_test('Activation', M, [x], tmp_path)
359
360
361@pytest.mark.parametrize('dtype', ['float32', 'float64', 'int32', 'int64'])
362@pytest.mark.parametrize('axes', [None, [1,0,2]])
363def test_onnx_export_transpose(tmp_path, dtype, axes):
364    if axes != None:
365        M = def_model('transpose', axes=axes)
366    else:
367        M = def_model('transpose')
368    x = mx.nd.array([[[1,2],[3,4]],[[5,6],[7,8]]], dtype=dtype)
369    op_export_test('transpose', M, [x], tmp_path)
370
371
372@pytest.mark.parametrize('dtype', ['float32', 'float64'])
373@pytest.mark.parametrize('axis', [0, 1, 2])
374def test_onnx_export_expand_dims(tmp_path, dtype, axis):
375    M = def_model('expand_dims', axis=axis)
376    x = mx.nd.random.uniform(0, 1, (2,3,4), dtype=dtype)
377    op_export_test('expand_dims', M, [x], tmp_path)
378
379
380@pytest.mark.parametrize('dtype', ['float32', 'float64', 'int32', 'int64'])
381def test_onnx_export_broadcast_add(tmp_path, dtype):
382    M = def_model('broadcast_add')
383    x = mx.nd.array([[1,1,1],[1,1,1]], dtype=dtype)
384    y = mx.nd.array([[0],[1]], dtype=dtype)
385    op_export_test('broadcast_add', M, [x, y], tmp_path)
386
387
388@pytest.mark.parametrize('dtype', ['float32', 'float64', 'int32', 'int64'])
389def test_onnx_export_broadcast_equal(tmp_path, dtype):
390    M = def_model('broadcast_equal')
391    x = mx.nd.zeros((4,5,6), dtype=dtype)
392    y = mx.nd.ones((4,5,6), dtype=dtype)
393    op_export_test('broadcast_equal', M, [x, y], tmp_path)
394
395
396@pytest.mark.parametrize('dtype', ['float32', 'float64', 'int32', 'int64'])
397def test_onnx_export_broadcast_not_equal(tmp_path, dtype):
398    M = def_model('broadcast_not_equal')
399    x = mx.nd.zeros((4,5,6), dtype=dtype)
400    y = mx.nd.ones((4,5,6), dtype=dtype)
401    op_export_test('broadcast_not_equal', M, [x, y], tmp_path)
402    x1 = mx.nd.ones((4,5,6), dtype=dtype)
403    y1 = mx.nd.ones((5,6), dtype=dtype)
404    op_export_test('broadcast_not_equal', M, [x1, y1], tmp_path)
405
406
407@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'int32', 'int64'])
408def test_onnx_export_broadcast_minimum(tmp_path, dtype):
409    M = def_model('broadcast_minimum')
410    if 'int' in dtype:
411        x = mx.nd.random.randint(0, 1000, (4, 5, 6), dtype=dtype)
412        y = mx.nd.random.randint(0, 1000, (4, 5, 6), dtype=dtype)
413    else:
414        x = mx.nd.random.uniform(0, 1000, (4, 5, 6), dtype=dtype)
415        y = mx.nd.random.uniform(0, 1000, (4, 5, 6), dtype=dtype)
416    op_export_test('broadcast_minimum', M, [x, y], tmp_path)
417
418@pytest.mark.parametrize('dtype', ['float32', 'float64', 'int32', 'int64'])
419@pytest.mark.parametrize('axis', [0, 1, 2, -1])
420def test_onnx_export_stack(tmp_path, dtype, axis):
421    M = def_model('stack', axis=axis)
422    if 'int' in dtype:
423        x = mx.nd.random.randint(0, 10*9, (3,4,5), dtype=dtype)
424        y = mx.nd.random.randint(0, 10*9, (3,4,5), dtype=dtype)
425    else:
426        x = mx.nd.random.normal(0, 10*9, (3,4,5), dtype=dtype)
427        y = mx.nd.random.normal(0, 10*9, (3,4,5), dtype=dtype)
428    op_export_test('stack', M, [x, y], tmp_path)
429
430
431@pytest.mark.parametrize('dtype', ['float32', 'float64'])
432@pytest.mark.parametrize('p', [0.1, 0.2, 0.5, 0.8])
433def test_onnx_export_dropout(tmp_path, dtype, p):
434    M = def_model('Dropout', p=p)
435    x = mx.nd.array([[3,0.5,-0.5,2,7],[2,-0.4,7,3,0.2]], dtype=dtype)
436    op_export_test('Dropout', M, [x], tmp_path)
437
438
439@pytest.mark.parametrize('src_dtype', ['float16', 'float32', 'float64'])
440@pytest.mark.parametrize('dst_dtype', ['bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8'])
441@pytest.mark.parametrize('shape', [(2,3), (4,5,6)])
442def test_onnx_export_cast(tmp_path, src_dtype, dst_dtype, shape):
443    M = def_model('Cast', dtype=dst_dtype)
444    x = mx.nd.ones(shape, dtype=src_dtype)
445    op_export_test('Cast', M, [x], tmp_path)
446
447
448@pytest.mark.parametrize('dtype', ['float16', 'float32'])
449@pytest.mark.parametrize('temperature', [None, .1, 1., 10.])
450def test_onnx_export_softmax(tmp_path, dtype, temperature):
451    x = mx.nd.random.uniform(0, 1, (4, 5, 6), dtype=dtype)
452    M1 = def_model('softmax')
453    op_export_test('softmax_1', M1, [x], tmp_path)
454    M2 = def_model('softmax', use_length=True, axis=0, temperature=temperature)
455    l2 = mx.random.uniform(0, 4, (5, 6)).astype('int32')
456    op_export_test('softmax_2', M2, [x, l2], tmp_path)
457    M3 = def_model('softmax', use_length=True, axis=-1, temperature=temperature)
458    # note that the axis==-1 case uses negative value masking + ONNX softmax
459    # when valid_len==0 the masked values will NOT be 0
460    l3 = mx.random.uniform(1, 6, (4, 5)).astype('int32')
461    op_export_test('softmax_3', M3, [x, l3], tmp_path)
462    M4 = def_model('softmax', use_length=True, axis=1, temperature=temperature)
463    l4 = mx.random.uniform(0, 5, (4, 6)).astype('int32')
464    op_export_test('softmax_4', M4, [x, l4], tmp_path)
465
466
467@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'int32', 'int64'])
468@pytest.mark.parametrize('axis', [0, 1, 2, 3])
469def test_onnx_export_reverse(tmp_path, dtype, axis):
470    x = mx.nd.arange(0, 120, dtype=dtype).reshape((2, 3, 4, 5))
471    M = def_model('reverse', axis=axis)
472    op_export_test('reverse', M, [x], tmp_path)
473
474
475@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'int32', 'int64'])
476@pytest.mark.parametrize('axis', [None, 0, 1, 2, -1, -2, -3])
477@pytest.mark.parametrize('repeats', [2, 1, 3])
478def test_onnx_export_repeat(tmp_path, dtype, axis, repeats):
479    x = mx.nd.arange(0, 27, dtype=dtype).reshape((3, 3, 3))
480    M = def_model('repeat', axis=axis, repeats=repeats)
481    op_export_test('repeat', M, [x], tmp_path)
482
483
484@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'int32', 'int64'])
485@pytest.mark.parametrize('shape', [(1, 3, 224, 224), (2, 2, 5, 8), (2, 4, 17, 23)])
486@pytest.mark.parametrize('params', [{'height': 7, 'width': 13},
487                                    {'height': 10, 'width': 16},
488                                    {'height': 3, 'width': 5},
489                                    {'height': 2, 'width': 4},
490                                    {'scale_height': 3, 'scale_width': 2},
491                                    {'scale_height': 1.7, 'scale_width': 2.3},
492                                    {'scale_height': 0.5, 'scale_width': 0.6},
493                                    {'scale_height': 0.8, 'scale_width': 0.13},
494                                    {'scale_height': 2.5, 'scale_width': 0.5},
495                                    {'scale_height': 3, 'scale_width': 0.2},
496                                    ])
497def test_onnx_export_contrib_BilinearResize2D(tmp_path, dtype, shape, params):
498    x = mx.random.uniform(0, 1, shape)
499    M = def_model('contrib.BilinearResize2D', **params)
500    op_export_test('contrib_BilinearResize2D', M, [x], tmp_path)
501
502
503@pytest.mark.parametrize('topk', [-1, 2, 3, 4])
504@pytest.mark.parametrize('valid_thresh', [0.3, 0.4, 0.8])
505@pytest.mark.parametrize('overlap_thresh', [0.4, 0.7, 1.0])
506def test_onnx_export_contrib_box_nms(tmp_path, topk, valid_thresh, overlap_thresh):
507    # Note that ONNX NMS op only supports float32
508
509    # Also note that onnxruntime's nms has slightly different implementation in handling
510    # overlaps and score ordering when certain boxes are suppressed than that of mxnet
511    # the following test tensors are manually tweaked to avoid such diferences
512    # The purpose of theses tests cases are to show that the high level conversion logic is
513    # laid out correctly
514
515    A = mx.nd.array([[
516                    [[[[0.5, 0.1, 0.1, 0.2, 0.2],
517                    [0.4, 0.1, 0.1, 0.2, 0.2],
518                    [0.7, 0.5, 0.5, 0.9, 0.9],
519                    [0.8, 0.1, 0.9, 0.11, 0.91],
520                    [0.001, 0.01, 0.01, 0.02, 0.02]]]],
521
522                    [[[[0.5, 0.1, 0.1, 0.2, 0.2],
523                    [0.4, 0.1, 0.1, 0.2, 0.2],
524                    [0.7, 0.5, 0.5, 0.9, 0.9],
525                    [0.8, 0.1, 0.9, 0.11, 0.91],
526                    [0.001, 0.01, 0.01, 0.02, 0.02]]]],
527
528                    [[[[0.4, 0.1, 0.1, 0.2, 0.2],
529                    [0.3, 0.1, 0.1, 0.2, 0.2],
530                    [0.7, 0.5, 0.5, 0.9, 0.9],
531                    [0.8, 0.1, 0.9, 0.11, 0.91],
532                    [0.001, 0.01, 0.01, 0.02, 0.02]]]],
533                    ]])
534    M = def_model('contrib.box_nms', coord_start=1, force_suppress=True,
535                  overlap_thresh=overlap_thresh, valid_thresh=valid_thresh, score_index=0,
536                  topk=topk, in_format='corner', out_format='corner')
537    op_export_test('contrib_nms_manual_coner', M, [A], tmp_path)
538
539    B = mx.nd.array([
540                    [[[[0.7, 0.5, 0.5, 0.2, 0.2],
541                    [0.6, 0.48, 0.48, 0.2, 0.2],
542                    [0.8, 0.76, 0.76, 0.2, 0.2],
543                    [0.9, 0.7, 0.7, 0.2, 0.2],
544                    [0.001, 0.5, 0.1, 0.02, 0.02]]]],
545
546                    [[[[0.5, 0.2, 0.2, 0.2, 0.2],
547                    [0.6, 0.4, 0.4, 0.21, 0.21],
548                    [0.7, 0.5, 0.5, 0.9, 0.9],
549                    [0.8, 0.1, 0.9, 0.01, 0.01],
550                    [0.001, 0.6, 0.1, 0.02, 0.02]]]],
551                    ])
552    M = def_model('contrib.box_nms', coord_start=1, force_suppress=True,
553                  overlap_thresh=overlap_thresh, valid_thresh=valid_thresh, score_index=0,
554                  topk=topk, in_format='center', out_format='center')
555    op_export_test('contrib_nms_manual_center', M, [B], tmp_path)
556
557
558@pytest.mark.parametrize("dtype", ["float16", "float32", "float64", "int32", "int64"])
559@pytest.mark.parametrize("scalar", [0., 0.1, 0.5, 1., 5, 555.])
560def test_onnx_export_greater_scalar(tmp_path, dtype, scalar):
561    if 'int' in dtype:
562        scalar = int(scalar)
563        x = mx.nd.arange(0, 12, dtype=dtype).reshape((3, 4))
564    else:
565        x = mx.random.uniform(0, 9999, (5,10), dtype=dtype)
566    M = def_model('_internal._greater_scalar', scalar=scalar)
567    op_export_test('_internal._greater_scalar', M, [x], tmp_path)
568
569
570@pytest.mark.parametrize("dtype", ["float16", "float32", "float64", "int32", "int64"])
571@pytest.mark.parametrize("scalar", [0., 0.1, 0.5, 1., 5, 555.])
572def test_onnx_export_lesser_scalar(tmp_path, dtype, scalar):
573    if 'int' in dtype:
574        scalar = int(scalar)
575        x = mx.nd.arange(0, 12, dtype=dtype).reshape((3, 4))
576    else:
577        x = mx.random.uniform(0, 9999, (5,10), dtype=dtype)
578    M = def_model('_internal._lesser_scalar', scalar=scalar)
579    op_export_test('_internal._lesser_scalar', M, [x], tmp_path)
580
581
582@pytest.mark.parametrize("dtype", ["float16", "float32", "float64", "int32", "int64"])
583@pytest.mark.parametrize("scalar", [0., 0.1, 0.5, 1., 5, 555.])
584def test_onnx_export_equal_scalar(tmp_path, dtype, scalar):
585    if 'int' in dtype:
586        scalar = int(scalar)
587        x = mx.nd.arange(0, 12, dtype=dtype).reshape((3, 4))
588    else:
589        x = mx.random.uniform(0, 9999, (5,10), dtype=dtype)
590    M = def_model('_internal._equal_scalar', scalar=scalar)
591    op_export_test('_internal._equal_scalar', M, [x], tmp_path)
592
593
594@pytest.mark.parametrize('dtype', ["float16", "float32", "int32", "int64"])
595@pytest.mark.parametrize('shape', [(5,), (3,3), (10,2), (20,30,40)])
596@pytest.mark.parametrize('broadcast', [True, False])
597def test_onnx_export_where(tmp_path, dtype, shape, broadcast):
598    M = def_model('where')
599    x = mx.nd.zeros(shape, dtype=dtype)
600    y = mx.nd.ones(shape, dtype=dtype)
601    if broadcast:
602        shape = shape[0:1]
603    cond = mx.nd.random.randint(low=0, high=1, shape=shape, dtype='int32')
604    op_export_test('where', M, [cond, x, y], tmp_path)
605
606
607# onnxruntime does not seem to support float64 and int32
608@pytest.mark.parametrize('dtype', ['float16', 'float32', 'int64'])
609@pytest.mark.parametrize('axis', [0, 2, -1, -2, -3])
610@pytest.mark.parametrize('is_ascend', [True, False, 0, 1, None])
611@pytest.mark.parametrize('k', [1, 4])
612@pytest.mark.parametrize('dtype_i', ['float32', 'int32', 'int64'])
613@pytest.mark.parametrize('ret_typ', ['value', 'indices', 'both'])
614def test_onnx_export_topk(tmp_path, dtype, axis, is_ascend, k, dtype_i, ret_typ):
615    A = mx.random.uniform(0, 100, (4, 5, 6)).astype(dtype)
616    kwargs = {}
617    if is_ascend is not None:
618        kwargs['is_ascend'] = is_ascend
619    M = def_model('topk', axis=axis, k=k, dtype=dtype_i, ret_typ=ret_typ, **kwargs)
620    op_export_test('topk', M, [A], tmp_path)
621
622
623def test_onnx_link_op_with_multiple_outputs(tmp_path):
624    A = mx.random.uniform(0, 100, (4, 5, 6))
625    class Model1(HybridBlock):
626        def __init__(self, **kwargs):
627            super(Model1, self).__init__(**kwargs)
628
629        def hybrid_forward(self, F, x):
630            out1, out2 = F.topk(x, k=3, ret_typ='both')
631            out11 = out1 ** 2
632            out22 = out2 ** 3
633            return out11, out22
634    op_export_test('link_op_with_multiple_outputs_case1', Model1, [A], tmp_path)
635
636    class Model2(HybridBlock):
637        def __init__(self, **kwargs):
638            super(Model2, self).__init__(**kwargs)
639
640        def hybrid_forward(self, F, x):
641            out_ = F.topk(x, k=3, ret_typ='value')
642            out = out_ ** 3
643            return out
644    op_export_test('link_op_with_multiple_outputs_case2', Model2, [A], tmp_path)
645
646    class Model3(HybridBlock):
647        def __init__(self, **kwargs):
648            super(Model3, self).__init__(**kwargs)
649
650        def hybrid_forward(self, F, x):
651            out_ = F.topk(x, k=3, ret_typ='indices')
652            out = out_ ** 3
653            return out
654    op_export_test('link_op_with_multiple_outputs_case3', Model3, [A], tmp_path)
655
656
657# opset 8 MAX only supports float types
658# opset 12 and up suppots float and int
659@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64'])
660@pytest.mark.parametrize('shape', [(3, 4, 5), (1, 4, 1, 7)])
661def test_onnx_maximum_scalar(tmp_path, dtype, shape):
662    x = mx.random.uniform(0, 10, shape).astype(dtype)
663    M = def_model('maximum', right=5)
664    op_export_test('_maximum_scalar', M, [x], tmp_path)
665
666
667# opset 8 Min only supports float types
668# opset 12 and up suppots float and int
669@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64'])
670@pytest.mark.parametrize('shape', [(3, 4, 5), (1, 4, 1, 7)])
671def test_onnx_minimum_scalar(tmp_path, dtype, shape):
672    x = mx.random.uniform(0, 10, shape).astype(dtype)
673    M = def_model('minimum', right=5)
674    op_export_test('_minimum_scalar', M, [x], tmp_path)
675
676
677@pytest.mark.parametrize('dtype', ['float16', 'float32'])
678@pytest.mark.parametrize('fmt', ['corner', 'center'])
679@pytest.mark.parametrize('clip', [-1., 0., .5, 5.])
680def test_onnx_export_contrib_box_decode(tmp_path, dtype, fmt, clip):
681    # ensure data[0] < data[2] and data[1] < data[3] for corner format
682    mul = mx.nd.array([-1, -1, 1, 1], dtype=dtype)
683    data = mx.nd.random.uniform(0, 1, (2, 3, 4), dtype=dtype) * mul
684    anchors = mx.nd.random.uniform(0, 1, (1, 3, 4), dtype=dtype) * mul
685    M1 = def_model('contrib.box_decode', format=fmt, clip=clip)
686    op_export_test('contrib_box_decode', M1, [data, anchors], tmp_path)
687    M2 = def_model('contrib.box_decode', format=fmt, clip=clip, std0=0.3, std1=1.4, std2=0.5, std3=1.6)
688    op_export_test('contrib_box_decode', M1, [data, anchors], tmp_path)
689
690
691@pytest.mark.parametrize('dtype', ['float16', 'float32'])
692def test_onnx_export_contrib_AdaptiveAvgPooling2D(tmp_path, dtype):
693    x = mx.nd.random.uniform(0, 1, (1, 2, 3, 4), dtype=dtype)
694    M1 = def_model('contrib.AdaptiveAvgPooling2D')
695    op_export_test('contrib_AdaptiveAvgPooling2D', M1, [x], tmp_path)
696    M2 = def_model('contrib.AdaptiveAvgPooling2D', output_size=1)
697    op_export_test('contrib_AdaptiveAvgPooling2D', M2, [x], tmp_path)
698    M3 = def_model('contrib.AdaptiveAvgPooling2D', output_size=[1])
699    op_export_test('contrib_AdaptiveAvgPooling2D', M3, [x], tmp_path)
700    M4 = def_model('contrib.AdaptiveAvgPooling2D', output_size=[1,1])
701    op_export_test('contrib_AdaptiveAvgPooling2D', M4, [x], tmp_path)
702
703
704@pytest.mark.parametrize('dtype', ['float16', 'float32', 'int32', 'int64'])
705@pytest.mark.parametrize('shapes', [((3, 3, 3), (1, 3)), ((4, 5, 6, 7), (6, 7))])
706def test_onnx_export_broadcast_mod(tmp_path, dtype, shapes):
707    A = mx.nd.random.uniform(-300, 300, shapes[0]).astype(dtype)
708    B = mx.nd.random.uniform(-30, 30, shapes[1]).astype(dtype)
709    # test when dividend is zero
710    B[-1] = 0
711    M = def_model('broadcast_mod')
712    op_export_test('broadcast_mod', M, [A, B], tmp_path)
713
714
715@pytest.mark.parametrize('dtype', ['int32', 'int64', 'float16', 'float32', 'float64'])
716def test_onnx_export_reshape_like(tmp_path, dtype):
717    if 'int' in dtype:
718        x = mx.nd.random.randint(0, 10, (2, 2, 3, 2), dtype=dtype)
719        y = mx.nd.random.randint(0, 10, (1, 4, 3, 2), dtype=dtype)
720    else:
721        x = mx.nd.random.normal(0, 10, (2, 2, 3, 2), dtype=dtype)
722        y = mx.nd.random.normal(0, 10, (1, 4, 3, 2), dtype=dtype)
723    M1 = def_model('reshape_like')
724    op_export_test('reshape_like1', M1, [x, y], tmp_path)
725    M2 = def_model('reshape_like', lhs_begin=0, lhs_end=2, rhs_begin=1, rhs_end=2)
726    op_export_test('reshape_like2', M2, [x, y], tmp_path)
727    M3 = def_model('reshape_like', lhs_begin=-4, lhs_end=-2, rhs_begin=-3, rhs_end=-2)
728    op_export_test('reshape_like3', M3, [x, y], tmp_path)
729    M4 = def_model('reshape_like', lhs_begin=0, lhs_end=None, rhs_begin=1, rhs_end=None)
730    op_export_test('reshape_like4', M4, [x, y], tmp_path)
731
732
733@pytest.mark.parametrize('dtype', ['int32', 'int64', 'float16', 'float32', 'float64'])
734def test_onnx_export_gather_nd(tmp_path, dtype):
735    # y[0] == dim(x)
736    x1 = mx.random.uniform(-100, 100, (4, 5, 6, 7)).astype(dtype)
737    y1 = mx.random.randint(-4, 4, (4, 4, 4)).astype(dtype)
738    M1 = def_model('gather_nd')
739    op_export_test('gather_nd1', M1, [x1, y1], tmp_path)
740    # y[0] < dim(x)
741    x2 = mx.random.uniform(-100, 100, (4, 5, 6, 7)).astype(dtype)
742    y2 = mx.random.randint(-4, 4, (2,3,4)).astype(dtype)
743    M2 = def_model('gather_nd')
744    op_export_test('gather_nd2', M2, [x2, y2], tmp_path)
745
746
747@pytest.mark.parametrize('dtype', ['float16', 'float32'])
748@pytest.mark.parametrize('shape', [(3, 4, 5, 6), (1, 1, 1, 1)])
749@pytest.mark.parametrize('scale', [1, 2, 3])
750def test_onnx_export_upsampling(tmp_path, dtype, shape, scale):
751    A = mx.random.uniform(0, 1, shape).astype(dtype)
752    M = def_model('UpSampling', scale=scale, sample_type='nearest', num_args=1)
753    op_export_test('UpSampling', M, [A], tmp_path)
754
755
756@pytest.mark.parametrize('dtype', ['int32', 'int64', 'float16', 'float32', 'float64'])
757@pytest.mark.parametrize('params', [((4, 5, 6), (0, 2)), ((4, 5, 6), (0, 1)),
758                                    ((1, 2, 3, 4, 1), (0, 4)),
759                                    ((4, 5, 1, 6), (0, 2))])
760def test_onnx_export_swap_axis(tmp_path, dtype, params):
761    shape = params[0]
762    dim1, dim2 = params[1]
763    x = mx.random.uniform(-100, 100, shape).astype(dtype)
764    M = def_model('SwapAxis', dim1=dim1, dim2=dim2)
765    op_export_test('SwapAxis', M, [x], tmp_path)
766
767
768@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'int32', 'int64'])
769@pytest.mark.parametrize('axes', [None, (0, 1, 2), (-2, -3), (-2, 0)])
770def test_onnx_export_slice_like(tmp_path, dtype, axes):
771    x = mx.nd.random.uniform(0, 1, (4, 5, 6, 7)).astype(dtype)
772    if axes is None:
773        M = def_model('slice_like')
774        y = mx.nd.zeros((2, 3, 4, 5), dtype=dtype)
775        op_export_test('slice_like', M, [x, y], tmp_path)
776    else:
777        M = def_model('slice_like', axes=axes)
778        y1 = mx.nd.zeros((2, 3, 4), dtype=dtype)
779        y2 = mx.nd.zeros((2, 3, 4, 5), dtype=dtype)
780        y3 = mx.nd.zeros((2, 3, 4, 5, 6), dtype=dtype)
781        op_export_test('slice_like_1', M, [x, y1], tmp_path)
782        op_export_test('slice_like_2', M, [x, y2], tmp_path)
783        op_export_test('slice_like_3', M, [x, y3], tmp_path)
784
785
786@pytest.mark.parametrize('dtype', ['float16', 'float32', 'int32', 'int64'])
787@pytest.mark.parametrize('axis', [None, 0, 2, -1])
788@pytest.mark.parametrize('num_outputs', [2, 5])
789def test_onnx_export_slice_channel(tmp_path, dtype, axis, num_outputs):
790    x = mx.nd.zeros((10,20,30,40), dtype=dtype)
791    if axis is None:
792        M = def_model('SliceChannel', num_outputs=num_outputs)
793    else:
794        M = def_model('SliceChannel', axis=axis, num_outputs=num_outputs)
795    op_export_test('SliceChannel', M, [x], tmp_path)
796
797
798@pytest.mark.parametrize('dtype', ['int32', 'int64', 'float16', 'float32', 'float64'])
799@pytest.mark.parametrize('lhs_axes', [[1, 3], [3, 1], [-2, -4], [-4, -2]])
800@pytest.mark.parametrize('rhs_axes', [[1, 3], [3, 1], [-2, -4], [-4, -2]])
801def test_onnx_export_broadcast_like(tmp_path, dtype, lhs_axes, rhs_axes):
802    x = mx.random.normal(0, 10, (2, 1, 1, 1, 6)).astype(dtype)
803    y = mx.random.normal(0, 10, (2, 3, 4, 5, 6)).astype(dtype)
804    M1 = def_model('broadcast_like')
805    op_export_test('broadcast_like1', M1, [x, y], tmp_path)
806    M2 = def_model('broadcast_like', lhs_axes=lhs_axes, rhs_axes=rhs_axes)
807    op_export_test('broadcast_like2', M2, [x, y], tmp_path)
808
809
810@pytest.mark.parametrize('dtype', ['float32'])
811@pytest.mark.parametrize('pooled_size', [(1, 1), (3, 3), (14, 14), (5, 7)])
812@pytest.mark.parametrize('spatial_scale', [1, 0.5, 0.0625])
813@pytest.mark.parametrize('spatial_ratio', [1, 2, 3, 5])
814def test_onnx_export_contrib_ROIAlign(tmp_path, dtype, pooled_size, spatial_scale, spatial_ratio):
815    data = mx.random.uniform(0, 1, (5, 3, 512, 512)).astype(dtype)
816    rois = mx.nd.array([[-1, 0, 0, 0, 0],
817                        [0, 0, 0, 63, 63],
818                        [1, 34, 52, 25, 85],
819                        [2, 50, 50, 100, 100],
820                        [3, 0, 0, 127, 127],
821                        [4, 12, 84, 22, 94],
822                        [0, 0, 0, 1, 1]]).astype(dtype)
823    M = def_model('contrib.ROIAlign', pooled_size=pooled_size, spatial_scale=spatial_scale,
824                  sample_ratio=spatial_ratio)
825    # according to https://mxnet.apache.org/versions/1.7.0/api/python/docs/api/contrib/symbol/index.html#mxnet.contrib.symbol.ROIAlign
826    # the returned value for when batch_id < 0 should be all 0's
827    # however mxnet 1.8 does always behave this way so we set the first roi to 0's manually
828    def mx_map(x):
829        x[0] = 0
830        return x
831    op_export_test('_contrib_ROIAlign', M, [data, rois], tmp_path, mx_map=mx_map)
832
833
834@pytest.mark.parametrize('dtype', ['float32', 'float64'])
835@pytest.mark.parametrize('transpose_a', [True, False])
836@pytest.mark.parametrize('transpose_b', [True, False])
837def test_onnx_export_batch_dot(tmp_path, dtype, transpose_a, transpose_b):
838    x1 = mx.nd.random.normal(0, 10, (2, 3, 4, 5, 6), dtype=dtype)
839    y1 = mx.nd.random.normal(0, 10, (2, 3, 4, 6, 5), dtype=dtype)
840    M1 = def_model('batch_dot')
841    op_export_test('batch_dot1', M1, [x1, y1], tmp_path)
842    x2 = mx.nd.random.normal(0, 10, (2, 3, 4, 5, 5), dtype=dtype)
843    y2 = mx.nd.random.normal(0, 10, (2, 3, 4, 5, 5), dtype=dtype)
844    M2 = def_model('batch_dot', transpose_a=transpose_a, transpose_b=transpose_b)
845    op_export_test('batch_dot2', M2, [x2, y2], tmp_path)
846
847
848@pytest.mark.parametrize('dtype', ['float32'])
849@pytest.mark.parametrize('shape', [(1, 3, 64, 64), (2, 1, 60, 60)])
850@pytest.mark.parametrize('count_include_pad', [True, False])
851@pytest.mark.parametrize('pooling_convention', ['full', 'valid'])
852@pytest.mark.parametrize('kernel', [(3, 3), (4, 5), (14, 14)])
853@pytest.mark.parametrize('stride', [None, (1, 1), (2, 2), (3, 4), (4, 5)])
854@pytest.mark.parametrize('pad', [None, (1, 1), (3, 4), (4, 5)])
855def test_onnx_export_pooling_avg(tmp_path, dtype, shape, count_include_pad, pooling_convention,
856                                 kernel, stride, pad):
857    # mxnet and onnxruntime has different implementation of count_include_pad on the left column
858    # and bottom row
859    if pooling_convention == 'full' and count_include_pad == True:
860        return
861    # onnxruntime requires that pad is smaller than kernel
862    if pad and (pad[0] >= kernel[0] or pad[1] >= kernel[1]):
863        return
864    x = mx.random.uniform(0, 1, shape, dtype=dtype)
865    kwargs = {}
866    if kernel:
867        kwargs['kernel'] = kernel
868    if stride:
869        kwargs['stride'] = stride
870    if pad:
871        kwargs['pad'] = pad
872    M = def_model('Pooling', count_include_pad=count_include_pad, pool_type='avg',
873                  pooling_convention=pooling_convention, layout='NCHW', **kwargs)
874    # Note here we use np.nan_to_num to map the onnx output because onnxruntime AveragePool will
875    # output NaN in some edge cases where mxnet outputs 0
876    op_export_test('pooling_avg', M, [x], tmp_path, onnx_map=np.nan_to_num)
877
878
879@pytest.mark.parametrize('dtype', ['float32'])
880@pytest.mark.parametrize('shape', [(1, 3, 16, 16, 16), (1, 1, 10, 18, 18)])
881@pytest.mark.parametrize('count_include_pad', [True, False])
882@pytest.mark.parametrize('pooling_convention', ['full', 'valid'])
883@pytest.mark.parametrize('kernel', [(1, 1, 1), (3, 3, 3), (1, 7, 7)])
884@pytest.mark.parametrize('stride', [None, (1, 1, 1), (1, 2, 3)])
885@pytest.mark.parametrize('pad', [None, (0, 1, 1), (1, 2, 3)])
886def test_onnx_export_pooling_avg_3d(tmp_path, dtype, shape, count_include_pad, pooling_convention,
887                                    kernel, stride, pad):
888    # mxnet and onnxruntime has different implementation of count_include_pad on the left column
889    # and bottom row
890    if pooling_convention == 'full' and count_include_pad == True:
891        return
892    # onnxruntime requires that pad is smaller than kernel
893    if pad and (pad[0] >= kernel[0] or pad[1] >= kernel[1] or pad[2] >= kernel[2]):
894        return
895    x = mx.random.uniform(0, 1, shape, dtype=dtype)
896    kwargs = {}
897    if kernel:
898        kwargs['kernel'] = kernel
899    if stride:
900        kwargs['stride'] = stride
901    if pad:
902        kwargs['pad'] = pad
903    M = def_model('Pooling', count_include_pad=count_include_pad, pool_type='avg',
904                  pooling_convention=pooling_convention, layout='NCDHW', **kwargs)
905    # Note here we use np.nan_to_num to map the onnx output because onnxruntime AveragePool will
906    # output NaN in some edge cases where mxnet outputs 0
907    def mx_nan_to_num(a):
908        return np.nan_to_num(a.asnumpy())
909    op_export_test('pooling_avg_3d', M, [x], tmp_path, onnx_map=np.nan_to_num, mx_map=mx_nan_to_num)
910
911
912
913@pytest.mark.parametrize('dtype', ['float32'])
914@pytest.mark.parametrize('shape', [(1, 3, 64, 64), (2, 1, 60, 60)])
915@pytest.mark.parametrize('pooling_convention', ['full', 'valid'])
916@pytest.mark.parametrize('kernel', [(3, 3), (4, 5), (14, 14)])
917@pytest.mark.parametrize('stride', [None, (1, 1), (2, 2), (3, 4), (4, 5)])
918@pytest.mark.parametrize('pad', [None, (1, 1), (3, 4), (4, 5)])
919def test_onnx_export_pooling_max(tmp_path, dtype, shape, pooling_convention, kernel, stride, pad):
920    # onnxruntime requires that pad is smaller than kernel
921    if pad and (pad[0] >= kernel[0] or pad[1] >= kernel[1]):
922        return
923    x = mx.random.uniform(0, 1, shape, dtype=dtype)
924    kwargs = {}
925    if kernel:
926        kwargs['kernel'] = kernel
927    if stride:
928        kwargs['stride'] = stride
929    if pad:
930        kwargs['pad'] = pad
931    M = def_model('Pooling', pool_type='max', pooling_convention=pooling_convention,
932                  layout='NCHW', **kwargs)
933    op_export_test('pooling_max', M, [x], tmp_path)
934
935
936@pytest.mark.parametrize('dtype', ['float32'])
937@pytest.mark.parametrize('shape', [(1, 3, 16, 16, 16), (1, 1, 10, 18, 18)])
938@pytest.mark.parametrize('pooling_convention', ['full', 'valid'])
939@pytest.mark.parametrize('kernel', [(1, 1, 1), (3, 3, 3), (1, 7, 7)])
940@pytest.mark.parametrize('stride', [None, (1, 1, 1), (1, 2, 3)])
941@pytest.mark.parametrize('pad', [None, (0, 1, 1), (1, 2, 3)])
942def test_onnx_export_pooling_max_3d(tmp_path, dtype, shape, pooling_convention, kernel, stride, pad):
943    # onnxruntime requires that pad is smaller than kernel
944    if pad and (pad[0] >= kernel[0] or pad[1] >= kernel[1] or pad[2] >= kernel[2]):
945        return
946    x = mx.random.uniform(0, 1, shape, dtype=dtype)
947    kwargs = {}
948    if kernel:
949        kwargs['kernel'] = kernel
950    if stride:
951        kwargs['stride'] = stride
952    if pad:
953        kwargs['pad'] = pad
954    M = def_model('Pooling', pool_type='max', pooling_convention=pooling_convention,
955                  layout='NCDHW', **kwargs)
956    op_export_test('pooling_max_3d', M, [x], tmp_path)
957
958
959@pytest.mark.parametrize('dtype', ['float32'])
960@pytest.mark.parametrize('shape', [(1, 3, 64, 64), (2, 1, 60, 60)])
961@pytest.mark.parametrize('p_value', [1, 2])
962@pytest.mark.parametrize('kernel', [(3, 3), (4, 5), (14, 14)])
963@pytest.mark.parametrize('stride', [None, (1, 1), (2, 2), (3, 4), (4, 5)])
964@pytest.mark.parametrize('pad', [None, (1, 1), (3, 4), (4, 5)])
965def test_onnx_export_pooling_lp(tmp_path, dtype, shape, p_value, kernel, stride, pad):
966    # onnxruntime requires that pad is smaller than kernel
967    if pad and (pad[0] >= kernel[0] or pad[1] >= kernel[1]):
968        return
969    x = mx.random.uniform(0, 1, shape, dtype=dtype)
970    kwargs = {}
971    if kernel:
972        kwargs['kernel'] = kernel
973    if stride:
974        kwargs['stride'] = stride
975    if pad:
976        kwargs['pad'] = pad
977    M = def_model('Pooling', pool_type='lp', pooling_convention='valid',
978                  p_value=p_value, layout='NCHW', **kwargs)
979    op_export_test('pooling_lp', M, [x], tmp_path)
980
981
982@pytest.mark.parametrize('dtype', ['float32'])
983@pytest.mark.parametrize('shape', [(1, 3, 16, 16, 16), (1, 1, 10, 18, 18)])
984@pytest.mark.parametrize('p_value', [1, 2])
985@pytest.mark.parametrize('kernel', [(1, 1, 1), (3, 3, 3), (1, 7, 7)])
986@pytest.mark.parametrize('stride', [None, (1, 1, 1), (1, 2, 3)])
987@pytest.mark.parametrize('pad', [None, (0, 1, 1), (1, 2, 3)])
988def test_onnx_export_pooling_lp_3d(tmp_path, dtype, shape, p_value, kernel, stride, pad):
989    # onnxruntime requires that pad is smaller than kernel
990    if pad and (pad[0] >= kernel[0] or pad[1] >= kernel[1] or pad[2] >= kernel[2]):
991        return
992    x = mx.random.uniform(0, 1, shape, dtype=dtype)
993    kwargs = {}
994    if kernel:
995        kwargs['kernel'] = kernel
996    if stride:
997        kwargs['stride'] = stride
998    if pad:
999        kwargs['pad'] = pad
1000    M = def_model('Pooling', pool_type='lp', pooling_convention='valid',
1001                  p_value=p_value, layout='NCDHW', **kwargs)
1002    op_export_test('pooling_lp_3d', M, [x], tmp_path)
1003
1004
1005@pytest.mark.parametrize('dtype', ['float32'])
1006@pytest.mark.parametrize('shape', [(1, 3, 64, 64), (2, 1, 60, 60)])
1007@pytest.mark.parametrize('pool_type', ['avg', 'max', 'lp'])
1008@pytest.mark.parametrize('p_value', [1, 2])
1009@pytest.mark.parametrize('kernel', [(3, 3), (14, 14)])
1010@pytest.mark.parametrize('stride', [None, (3, 4)])
1011@pytest.mark.parametrize('pad', [None, (3, 4)])
1012def test_onnx_export_pooling_global(tmp_path, dtype, shape, pool_type, p_value, kernel, stride, pad):
1013    # onnxruntime requires that pad is smaller than kernel
1014    if pad and (pad[0] >= kernel[0] or pad[1] >= kernel[1]):
1015        return
1016    x = mx.random.uniform(0, 1, shape, dtype=dtype)
1017    kwargs = {}
1018    if kernel:
1019        kwargs['kernel'] = kernel
1020    if stride:
1021        kwargs['stride'] = stride
1022    if pad:
1023        kwargs['pad'] = pad
1024    # kernel, stride, and pad should have no effect on the results
1025    M = def_model('Pooling', global_pool=True, pool_type=pool_type, pooling_convention='valid',
1026                  p_value=p_value, layout='NCHW', **kwargs)
1027    op_export_test('pooling_global', M, [x], tmp_path)
1028
1029
1030@pytest.mark.parametrize('dtype', ['float32'])
1031@pytest.mark.parametrize('shape', [(1, 3, 16, 16, 16), (1, 1, 10, 18, 18)])
1032@pytest.mark.parametrize('pool_type', ['avg', 'max', 'lp'])
1033@pytest.mark.parametrize('p_value', [1, 2])
1034@pytest.mark.parametrize('kernel', [(1, 1, 1), (3, 3, 3)])
1035@pytest.mark.parametrize('stride', [None, (1, 1, 1)])
1036@pytest.mark.parametrize('pad', [None, (0, 1, 1)])
1037def test_onnx_export_pooling_global_3d(tmp_path, dtype, shape, pool_type, p_value, kernel, stride, pad):
1038    # onnxruntime requires that pad is smaller than kernel
1039    if pad and (pad[0] >= kernel[0] or pad[1] >= kernel[1] or pad[2] >= kernel[2]):
1040        return
1041    x = mx.random.uniform(0, 1, shape, dtype=dtype)
1042    kwargs = {}
1043    if kernel:
1044        kwargs['kernel'] = kernel
1045    if stride:
1046        kwargs['stride'] = stride
1047    if pad:
1048        kwargs['pad'] = pad
1049    # kernel, stride, and pad should have no effect on the results
1050    M = def_model('Pooling', global_pool=True, pool_type=pool_type, pooling_convention='valid',
1051                  p_value=p_value, layout='NCDHW', **kwargs)
1052    op_export_test('pooling_global_3d', M, [x], tmp_path)
1053
1054
1055@pytest.mark.parametrize('dtype', ['float16', 'float32'])
1056def test_onnx_export_log2(tmp_path, dtype):
1057    x = mx.random.normal(0, 10, (2, 3, 4, 5)).astype(dtype)
1058    M = def_model('log2')
1059    op_export_test('log2', M, [x], tmp_path)
1060
1061
1062@pytest.mark.parametrize('dtype', ['int32', 'int64', 'float16', 'float32', 'float64'])
1063@pytest.mark.parametrize('axis', [None, 1, [1,2], -1])
1064@pytest.mark.parametrize('operator', ['sum', 'sum_axis'])
1065def test_onnx_export_sum(tmp_path, dtype, axis, operator):
1066    if 'int' in dtype:
1067        x = mx.nd.random.randint(0, 10, (5, 6, 7, 8), dtype=dtype)
1068    else:
1069        x = mx.nd.random.normal(0, 10, (5, 6, 7, 8), dtype=dtype)
1070    if axis is not None:
1071        M = def_model(operator, axis=axis)
1072    else:
1073        M = def_model(operator)
1074    op_export_test(operator, M, [x], tmp_path)
1075
1076
1077@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'int32', 'int64'])
1078def test_onnx_export_broadcast_mul(tmp_path, dtype):
1079    M = def_model('broadcast_mul')
1080    x = mx.nd.array([[1,2,3],[4,5,6]], dtype=dtype)
1081    y = mx.nd.array([[0],[3]], dtype=dtype)
1082    op_export_test('broadcast_mul', M, [x, y], tmp_path)
1083
1084
1085@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64'])
1086@pytest.mark.parametrize('shape', [(3, 4, 5), (1, 2, 3, 2, 1)])
1087@pytest.mark.parametrize('p', [0, 0.1, 0.5, 1])
1088def test_onnx_export_dropout(tmp_path, dtype, shape, p):
1089    x = mx.random.uniform(-100, 100, shape=shape).astype(dtype)
1090    M = def_model('Dropout', p=p)
1091    op_export_test('Dropuout', M, [x], tmp_path)
1092
1093
1094@pytest.mark.parametrize('dtype', ['float32'])
1095@pytest.mark.parametrize('shape', [(1, 3, 64, 64), (2, 6, 60, 60)])
1096@pytest.mark.parametrize('num_filter', [2, 4, 32])
1097@pytest.mark.parametrize('num_group', [1, 2])
1098@pytest.mark.parametrize('no_bias', [True, False])
1099@pytest.mark.parametrize('kernel', [(3, 3), (4, 5), (14, 14)])
1100@pytest.mark.parametrize('stride', [None, (1, 1), (2, 2), (3, 4), (4, 5)])
1101@pytest.mark.parametrize('pad', [None, (1, 1), (3, 4), (4, 5)])
1102@pytest.mark.parametrize('dilate', [None, (1, 1)])
1103def test_onnx_export_convolution(tmp_path, dtype, shape, num_filter, num_group, no_bias,
1104                                 kernel, stride, pad, dilate):
1105    if shape[1] % num_group:
1106        return
1107    x = mx.random.uniform(0, 1, shape, dtype=dtype)
1108    w_shape = (num_filter,) + (shape[1] // num_group,) + kernel
1109    w = mx.random.uniform(0, 1, w_shape, dtype=dtype)
1110    b_shape = (num_filter)
1111    b = mx.random.uniform(0, 1, b_shape, dtype=dtype)
1112    kwargs = {}
1113    if kernel:
1114        kwargs['kernel'] = kernel
1115    if stride:
1116        kwargs['stride'] = stride
1117    if pad:
1118        kwargs['pad'] = pad
1119    if dilate:
1120        kwargs['dilate'] = dilate
1121    M = def_model('Convolution', num_filter=num_filter, num_group=num_group,  no_bias=no_bias,
1122                  layout='NCHW', **kwargs)
1123    inputs = [x, w] if no_bias else [x, w, b]
1124    op_export_test('convolution', M, inputs, tmp_path)
1125
1126
1127@pytest.mark.parametrize('dtype', ['float32'])
1128@pytest.mark.parametrize('shape', [(1, 4, 16, 16, 16), (1, 3, 10, 18, 18)])
1129@pytest.mark.parametrize('num_filter', [2, 4, 32])
1130@pytest.mark.parametrize('num_group', [1, 2])
1131@pytest.mark.parametrize('no_bias', [True, False])
1132@pytest.mark.parametrize('kernel', [(3, 3, 3), (1, 1, 1), (1, 7, 7)])
1133@pytest.mark.parametrize('stride', [None, (1, 1, 1), (1, 2, 3)])
1134@pytest.mark.parametrize('pad', [None, (0, 1, 1), (1, 2, 3)])
1135@pytest.mark.parametrize('dilate', [None, [2, 2, 2]])
1136def test_onnx_export_convolution_3D(tmp_path, dtype, shape, num_filter, num_group, no_bias,
1137                                 kernel, stride, pad, dilate):
1138    if shape[1] % num_group:
1139        return
1140    x = mx.random.uniform(0, 1, shape, dtype=dtype)
1141    w_shape = (num_filter,) + (shape[1] // num_group,) + kernel
1142    w = mx.random.uniform(0, 1, w_shape, dtype=dtype)
1143    b_shape = (num_filter)
1144    b = mx.random.uniform(0, 1, b_shape, dtype=dtype)
1145    kwargs = {}
1146    if kernel:
1147        kwargs['kernel'] = kernel
1148    if stride:
1149        kwargs['stride'] = stride
1150    if pad:
1151        kwargs['pad'] = pad
1152    if dilate:
1153        kwargs['dilate'] = dilate
1154    M = def_model('Convolution', num_filter=num_filter, num_group=num_group,  no_bias=no_bias,
1155                  layout='NCDHW', **kwargs)
1156    inputs = [x, w] if no_bias else [x, w, b]
1157    op_export_test('convolution', M, inputs, tmp_path)
1158
1159
1160@pytest.mark.parametrize('dtype', ['float16', 'float32'])
1161@pytest.mark.parametrize('num_outputs', [1, 3, 9])
1162@pytest.mark.parametrize('axis', [1, 2, -1, -2])
1163@pytest.mark.parametrize('squeeze_axis', [True, False, 0, 1])
1164def test_onnx_export_slice_channel(tmp_path, dtype, num_outputs, axis, squeeze_axis):
1165    shape = (3, 9, 18)
1166    if squeeze_axis and shape[axis] != num_outputs:
1167        return
1168    M = def_model('SliceChannel', num_outputs=num_outputs, axis=axis, squeeze_axis=squeeze_axis)
1169    x = mx.random.uniform(0, 1, shape, dtype=dtype)
1170    op_export_test('slice_channel', M, [x], tmp_path)
1171
1172
1173@pytest.mark.parametrize('dtype', ['float32', 'float64'])
1174@pytest.mark.parametrize('momentum', [0.9, 0.5, 0.1])
1175def test_onnx_export_batchnorm(tmp_path, dtype, momentum):
1176    x = mx.nd.random.normal(0, 10, (2, 3, 4, 5)).astype(dtype)
1177    gamma = mx.nd.random.normal(0, 10, (3)).astype(dtype)
1178    beta = mx.nd.random.normal(0, 10, (3)).astype(dtype)
1179    moving_mean = mx.nd.random.normal(0, 10, (3)).astype(dtype)
1180    moving_var = mx.nd.abs(mx.nd.random.normal(0, 10, (3))).astype(dtype)
1181    M = def_model('BatchNorm', eps=1e-5, momentum=momentum, fix_gamma=False, use_global_stats=False)
1182    op_export_test('BatchNorm1', M, [x, gamma, beta, moving_mean, moving_var], tmp_path)
1183
1184
1185# onnxruntime does not seem to support float64 and int32
1186@pytest.mark.parametrize('dtype', ['float32', 'int64'])
1187@pytest.mark.parametrize('axis', [0, 2, -1, -2, -3])
1188@pytest.mark.parametrize('is_ascend', [True, False, 0, 1, None])
1189@pytest.mark.parametrize('dtype_i', ['float32', 'int32', 'int64'])
1190def test_onnx_export_argsort(tmp_path, dtype, axis, is_ascend, dtype_i):
1191    A = mx.random.uniform(0, 100, (4, 5, 6)).astype(dtype)
1192    kwargs = {}
1193    if is_ascend is not None:
1194        kwargs['is_ascend'] = is_ascend
1195    M = def_model('argsort', axis=axis, dtype=dtype_i, **kwargs)
1196    op_export_test('argsort', M, [A], tmp_path)
1197
1198
1199@pytest.mark.parametrize('dtype', ['int32', 'int64', 'float16', 'float32', 'float64'])
1200@pytest.mark.parametrize('reps', [(2, 3), (2, ), (2, 3, 4)])
1201def test_onnx_export_tile(tmp_path, dtype, reps):
1202    x = mx.nd.random.normal(0, 100, (5, 6)).astype(dtype)
1203    M = def_model('tile', reps=reps)
1204    op_export_test('tile', M, [x], tmp_path)
1205
1206
1207@pytest.mark.parametrize('dtype', ['int32', 'int64', 'float16', 'float32', 'float64'])
1208@pytest.mark.parametrize('axis', [-3, -2, -1, 0, 1, 2])
1209@pytest.mark.parametrize('mode', ['clip', 'wrap'])
1210def test_onnx_export_take(tmp_path, dtype, axis, mode):
1211    x = mx.nd.random.normal(0, 10, (3, 4, 5)).astype(dtype)
1212    y = mx.random.randint(-100, 100, (6, 7)).astype(dtype)
1213    M1 = def_model('take')
1214    op_export_test('take1', M1, [x, y], tmp_path)
1215    M2 = def_model('take', axis=axis, mode=mode)
1216    op_export_test('take2', M2, [x, y], tmp_path)
1217
1218
1219@pytest.mark.parametrize('dtype', ['int32', 'int64', 'float16', 'float32', 'float64'])
1220@pytest.mark.parametrize('axis', [-3, -2, -1, 0, 1, 2])
1221def test_onnx_export_take_raise(tmp_path, dtype, axis):
1222    x = mx.nd.random.normal(0, 10, (3, 4, 5)).astype(dtype)
1223    y = mx.random.randint(0, 3, (6, 7)).astype(dtype)
1224    M = def_model('take', axis=axis, mode='raise')
1225    op_export_test('take', M, [x, y], tmp_path)
1226
1227
1228# onnxruntime currently does not support int32
1229@pytest.mark.parametrize("dtype", ["float16", "float32", "int64"])
1230@pytest.mark.parametrize("depth", [1, 3, 5, 10])
1231@pytest.mark.parametrize("shape", [(1,1), (1,5), (5,5), (3,4,5)])
1232def test_onnx_export_one_hot(tmp_path, dtype, depth, shape):
1233    M = def_model('one_hot', depth=depth, dtype=dtype)
1234    x = mx.random.randint(0, 10, shape).astype('int64')
1235    op_export_test('one_hot', M, [x], tmp_path)
1236
1237
1238@pytest.mark.parametrize('dtype', ['int32', 'int64', 'float16', 'float32', 'float64'])
1239@pytest.mark.parametrize('params', [((6, 5, 4), [1, 2, 4, 5, 6]),
1240                                     ((7, 3, 5), [1, 7, 4]),
1241                                     ((3, 2, 1), [1, 2])])
1242def test_onnx_export_sequence_reverse(tmp_path, dtype, params):
1243    x = mx.nd.random.uniform(0, 10, params[0]).astype(dtype)
1244    M1 = def_model('SequenceReverse')
1245    op_export_test('SequenceReverse1', M1, [x], tmp_path)
1246    seq_len = mx.nd.array(params[1])
1247    M1 = def_model('SequenceReverse', use_sequence_length=True)
1248    op_export_test('SequenceReverse1', M1, [x, seq_len], tmp_path)
1249
1250
1251@pytest.mark.parametrize('mode', ['lstm', 'gru', 'rnn_tanh', 'rnn_relu'])
1252@pytest.mark.parametrize('dtype', ['float32'])
1253@pytest.mark.parametrize('state_size', [16, 32])
1254@pytest.mark.parametrize('input_size', [16, 32, 64])
1255@pytest.mark.parametrize('num_layers', [1, 2])
1256@pytest.mark.parametrize('batch_size', [1, 2, 4])
1257@pytest.mark.parametrize('seq_length', [16])
1258@pytest.mark.parametrize('bidirectional', [True, False])
1259def test_onnx_export_RNN(tmp_path, mode, dtype, state_size, input_size, num_layers, batch_size, seq_length, bidirectional):
1260    # TODO: The current implementation fails assertion checks for large parm/state_size.
1261    # for num_layers >= 2, input_size must equal to state_size
1262    if num_layers >= 2 and input_size != state_size:
1263        return
1264    # Currently only bidirectional supports lstm with num_layers = 1
1265    if bidirectional and (mode != 'lstm' or num_layers != 1):
1266        return
1267
1268    b = 1
1269    if bidirectional:
1270        b = 2
1271
1272    factor = 1
1273    if mode == 'gru':
1274        factor = 3
1275    elif mode == 'lstm':
1276        factor = 4
1277
1278    M = def_model('RNN', mode=mode, state_size=state_size, state_outputs=True,  num_layers=num_layers, p=0, bidirectional=bidirectional)
1279    x = mx.nd.random.normal(0, 10, (seq_length, batch_size, input_size)).astype(dtype)
1280    param = mx.nd.random.normal(0, 1, [b*num_layers*factor*state_size*input_size +
1281                                       b*num_layers*factor*state_size*state_size +
1282                                       b*num_layers*2*factor*state_size]).astype(dtype)
1283    state = mx.nd.random.uniform(-1, 1, [b*num_layers, batch_size, state_size]).astype(dtype)
1284    if mode == 'lstm':
1285        cell = mx.nd.random.uniform(-1, 1, [b*num_layers, batch_size, state_size]).astype(dtype)
1286        op_export_test('rnn', M, [x, param, state, cell], tmp_path)
1287    elif mode == 'rnn_relu':
1288        # set large atol as relu can outputs big numbers
1289        op_export_test('rnn', M, [x, param, state], tmp_path, atol=1e20)
1290    else:
1291        op_export_test('rnn', M, [x, param, state], tmp_path, atol=1e-2)
1292
1293
1294@pytest.mark.parametrize('dtype', ['float16', 'float32', 'int32', 'int64'])
1295@pytest.mark.parametrize('shapes', [((3, 3, 3), (1, 3)), ((4, 5, 6, 7), (6, 7))])
1296def test_onnx_export_broadcast_lesser_equal(tmp_path, dtype, shapes):
1297    A = mx.nd.random.uniform(0, 5, shapes[0]).astype('int32').astype(dtype)
1298    B = mx.nd.random.uniform(0, 5, shapes[1]).astype('int32').astype(dtype)
1299    M = def_model('broadcast_lesser_equal')
1300    op_export_test('broadcast_lesser_equal', M, [A, B], tmp_path)
1301
1302
1303@pytest.mark.parametrize('dtype', ['float16', 'float32', 'int32', 'int64'])
1304@pytest.mark.parametrize('shapes', [((3, 3, 3), (1, 3)), ((4, 5, 6, 7), (6, 7))])
1305def test_onnx_export_broadcast_greater_equal(tmp_path, dtype, shapes):
1306    A = mx.nd.random.uniform(0, 5, shapes[0]).astype('int32').astype(dtype)
1307    B = mx.nd.random.uniform(0, 5, shapes[1]).astype('int32').astype(dtype)
1308    M = def_model('broadcast_greater_equal')
1309    op_export_test('broadcast_greater_equal', M, [A, B], tmp_path)
1310
1311
1312@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64'])
1313@pytest.mark.parametrize('shape', [(3, 4, 5), (6, 7), (8,)])
1314def test_onnx_export_contrib_div_sqrt_dim(tmp_path, dtype, shape):
1315    A = mx.nd.random.uniform(-100, 100, shape).astype(dtype)
1316    M = def_model('contrib.div_sqrt_dim')
1317    op_export_test('contrib_div_sqrt_dim', M, [A], tmp_path)
1318
1319
1320@pytest.mark.parametrize('dtype', ['float16', 'float32'])
1321@pytest.mark.parametrize('shape', [(3, 4, 5), (6, 7), (8,)])
1322@pytest.mark.parametrize('operator', ['sin', 'cos', 'tan', 'tanh', 'arcsin', 'arccos', 'arctan',
1323                                      'sigmoid', 'relu', 'exp', 'identity', 'BlockGrad', 'MakeLoss'])
1324def test_onnx_export_ufunc(tmp_path, dtype, shape, operator):
1325    A = mx.nd.random.uniform(-100, 100, shape).astype(dtype)
1326    M = def_model(operator)
1327    op_export_test('ufunc', M, [A], tmp_path)
1328
1329
1330@pytest.mark.parametrize('dtype', ['float32'])
1331@pytest.mark.parametrize('shape', [(1, 3, 64, 64), (2, 6, 60, 60)])
1332@pytest.mark.parametrize('num_filter', [4, 16, 256])
1333@pytest.mark.parametrize('num_group', [1, 2])
1334@pytest.mark.parametrize('no_bias', [False, True])
1335@pytest.mark.parametrize('kernel', [(2, 2), (3, 4)])
1336@pytest.mark.parametrize('stride', [(1, 1), (2, 2)])
1337@pytest.mark.parametrize('pad', [None, (0, 0), (1, 1)])
1338@pytest.mark.parametrize('dilate', [None, (1, 1)])
1339@pytest.mark.parametrize('adj', [(0, 0), (1, 1)])
1340def test_onnx_export_deconvolution(tmp_path, dtype, shape, num_filter, num_group, no_bias,
1341                                 kernel, stride, pad, dilate, adj):
1342    for i in range(len(stride)):
1343        if stride[i] <= adj[i]:
1344            return
1345    if shape[1] % num_group:
1346        return
1347    x = mx.random.uniform(0, 1, shape, dtype=dtype)
1348    w_shape = (shape[1],) + (num_filter // num_group,) + kernel
1349    w = mx.random.uniform(0, 1, w_shape, dtype=dtype)
1350    b_shape = (num_filter)
1351    b = mx.random.uniform(0, 1, b_shape, dtype=dtype)
1352    kwargs = {}
1353    if kernel:
1354        kwargs['kernel'] = kernel
1355    if stride:
1356        kwargs['stride'] = stride
1357    if pad:
1358        kwargs['pad'] = pad
1359    if dilate:
1360        kwargs['dilate'] = dilate
1361    if adj:
1362        kwargs['adj'] = adj
1363    M = def_model('Deconvolution', num_filter=num_filter, num_group=num_group,  no_bias=no_bias,
1364                  layout='NCHW', **kwargs)
1365    inputs = [x, w] if no_bias else [x, w, b]
1366    op_export_test('deconvolution', M, inputs, tmp_path)
1367
1368
1369@pytest.mark.parametrize('dtype', ['float32', 'float16', 'float64'])
1370@pytest.mark.parametrize('mode', ['edge', 'constant', 'reflect'])
1371@pytest.mark.parametrize('params', [((3, 4, 5, 6), (0, 0, 0, 0, 2, 3, 4, 5)),
1372                                    ((7, 6, 5, 4, 3), (0, 0, 0, 0, 4, 4, 3, 3, 2, 1))])
1373def test_onnx_export_pad(tmp_path, dtype, mode, params):
1374     kwargs = {}
1375     kwargs['constant_value'] = 9999.55
1376     kwargs['pad_width'] = params[1]
1377     x = mx.random.uniform(0, 1, shape=params[0], dtype=dtype)
1378     M = def_model('pad', mode=mode, **kwargs)
1379     op_export_test('pad', M, [x], tmp_path)
1380
1381
1382# Note that due to ONNX limitation, the behavior for when inputs > 2-D is different from that of
1383# MXNet
1384@pytest.mark.parametrize('dtype', ['float32', 'float64'])
1385@pytest.mark.parametrize('params', [((4, 5), (5, 6), False, False),
1386                                    ((5, 4), (5, 6), True, False),
1387                                    ((5, 4), (6, 5), True, True),
1388                                    ((4, 5), (6, 5), False, True),
1389                                    ((4, 5), (5), False, False),
1390                                    ((4,), (4, 5), False, False),
1391                                    ((4, 5), (5,), False, False)])
1392def test_onnx_export_dot(tmp_path, dtype, params):
1393    A = mx.random.uniform(0, 1, params[0], dtype=dtype)
1394    B = mx.random.uniform(0, 1, params[1], dtype=dtype)
1395    M = def_model('dot', transpose_a=params[2], transpose_b=params[3])
1396    op_export_test('dot', M, [A, B], tmp_path)
1397
1398
1399@pytest.mark.parametrize('dtype', ['float32', 'float64', 'int32', 'int64'])
1400@pytest.mark.parametrize('shape', [(3, 4, 5, 6), (7, 8)])
1401def test_onnx_export_flatten(tmp_path, dtype, shape):
1402    x = mx.random.uniform(0, 1, shape, dtype='float32').astype(dtype)
1403    M = def_model('flatten')
1404    op_export_test('flatten', M, [x], tmp_path)
1405
1406
1407# Note that due to ONNX limitation, the behavior for when inputs > 2-D is different from that of
1408# MXNet
1409@pytest.mark.parametrize('dtype', ['float32', 'float64'])
1410@pytest.mark.parametrize('alpha', [1, 1.5])
1411@pytest.mark.parametrize('params', [((4, 5), (5, 4), False, False),
1412                                    ((4, 5, 6), (4, 6, 5), False, False),
1413                                    ((4, 5, 6, 7), (4, 5, 6, 7), True, False),
1414                                    ((4, 5, 6, 7), (4, 5, 6, 7), False, True),
1415                                    ((4, 5, 9, 7), (4, 5, 6, 9), True, True)])
1416def test_onnx_export_linalg_gemm2(tmp_path, dtype, alpha, params):
1417    A = mx.random.uniform(0, 1, params[0], dtype=dtype)
1418    B = mx.random.uniform(0, 1, params[1], dtype=dtype)
1419    M = def_model('linalg.gemm2', alpha=alpha, transpose_a=params[2], transpose_b=params[3])
1420    op_export_test('_linalg_gemm2', M, [A, B], tmp_path)
1421
1422
1423@pytest.mark.parametrize('dtype', ['float32'])
1424@pytest.mark.parametrize('shape', [(3, 4, 5), (6, 7), (8,)])
1425def test_onnx_export_LogisticRegressionOutput(tmp_path, dtype, shape):
1426    x = mx.random.uniform(0, 1, shape, dtype=dtype)
1427    y = mx.nd.zeros(shape, dtype=dtype)
1428    M = def_model('LogisticRegressionOutput')
1429    op_export_test('LogisticRegressionOutput', M, [x, y], tmp_path)
1430
1431
1432@pytest.mark.parametrize('dtype', ['float32', 'float64'])
1433@pytest.mark.parametrize('shape', [(4, 5, 6), (6, 7), (3, 4, 5, 6, 7)])
1434def test_onnx_export_SoftmaxOutput(tmp_path, dtype, shape):
1435    x = mx.random.uniform(0, 1, shape, dtype=dtype)
1436    y = mx.nd.zeros(shape[:-1], dtype=dtype)
1437    M = def_model('SoftmaxOutput')
1438    op_export_test('SoftmaxOutput', M, [x, y], tmp_path)
1439
1440
1441# Due to ONNX limitation, L2Normalization only supports channel mode for now
1442@pytest.mark.parametrize('dtype', ['float32'])
1443@pytest.mark.parametrize('shape', [(3, 4, 5), (3, 4, 5, 6, 7)])
1444def test_onnx_export_L2Normalization(tmp_path, dtype, shape):
1445    x = mx.random.uniform(0, 1, shape, dtype=dtype)
1446    M = def_model('L2Normalization', mode='channel')
1447    op_export_test('L2Normalization', M, [x], tmp_path)
1448
1449
1450@pytest.mark.parametrize('dtype', ['float32'])
1451@pytest.mark.parametrize('shape', [(3, 4, 5), (3, 4, 5, 6, 7)])
1452@pytest.mark.parametrize('eps', [0.001, 0.00001])
1453def test_onnx_export_InstanceNorm(tmp_path, dtype, shape, eps):
1454    x = mx.random.uniform(0, 1, shape, dtype=dtype)
1455    gamma = mx.random.uniform(0, 1, shape[1:2], dtype=dtype)
1456    beta = mx.random.uniform(0, 1, shape[1:2], dtype=dtype)
1457    M = def_model('InstanceNorm', eps=eps)
1458    op_export_test('InstanceNorm', M, [x, gamma, beta], tmp_path)
1459
1460
1461# ONNXRuntime only supports 4-D inputs
1462@pytest.mark.parametrize('dtype', ['float32'])
1463@pytest.mark.parametrize('shape', [(4, 5, 6, 7)])
1464@pytest.mark.parametrize('alpha', [0.001, 0.00001])
1465@pytest.mark.parametrize('beta', [0.75, 0.8])
1466@pytest.mark.parametrize('knorm', [1, 2])
1467@pytest.mark.parametrize('nsize', [3, 5])
1468def test_onnx_export_LRN(tmp_path, dtype, shape, alpha, beta, knorm, nsize):
1469    x = mx.random.uniform(0, 1, shape, dtype=dtype)
1470    M = def_model('LRN', alpha=alpha, beta=beta, knorm=knorm, nsize=nsize)
1471    op_export_test('LRN', M, [x], tmp_path)
1472
1473
1474@pytest.mark.parametrize('dtype', ['float32'])
1475@pytest.mark.parametrize('shape', [(1, 3, 224, 224), (5, 6, 64, 64)])
1476@pytest.mark.parametrize('h_w', [(10, 10), (7, 11)])
1477@pytest.mark.parametrize('offset', [(7, 13), (10, 10)])
1478@pytest.mark.parametrize('shape2', [None, (10, 10, 16, 16)])
1479def test_onnx_export_Crop(tmp_path, dtype, shape, h_w, offset, shape2):
1480    x = mx.random.uniform(0, 1, shape, dtype=dtype)
1481    M = def_model('Crop', h_w=h_w, offset=offset, center_crop=False)
1482    if shape2 is not None:
1483        y = mx.random.uniform(0, 1, shape2, dtype=dtype)
1484        op_export_test('Crop', M, [x, y], tmp_path)
1485    else:
1486        op_export_test('Crop', M, [x], tmp_path)
1487
1488
1489@pytest.mark.parametrize('dtype', ['float16', 'float32'])
1490@pytest.mark.parametrize('shape', [(100,), (3, 4, 5), (6, 7)])
1491def test_onnx_export_reciprocal(tmp_path, dtype, shape):
1492    A = mx.nd.random.uniform(-100, 100, shape).astype(dtype)
1493    M = def_model('reciprocal')
1494    op_export_test('reciprocal', M, [A], tmp_path)
1495
1496
1497@pytest.mark.parametrize("dtype", ["float16", "float32", "float64", "int32", "int64"])
1498@pytest.mark.parametrize('shape', [(1, 3), (3, 4, 5)])
1499def test_onnx_export_power(tmp_path, shape, dtype):
1500    x = mx.nd.random.uniform(-5, 5, shape).astype(dtype)
1501    y = mx.nd.random.uniform(-10, 10, shape).astype(dtype)
1502    M = def_model('_internal._power')
1503    op_export_test('_internal._power', M, [x, y], tmp_path)
1504
1505@pytest.mark.parametrize("dtype", ["float16", "float32", "float64", "int32", "int64"])
1506@pytest.mark.parametrize('shape', [(1, 3), (3, 4, 5)])
1507def test_onnx_export_broadcast_power(tmp_path, shape, dtype):
1508    x = mx.nd.random.uniform(-5, 5, shape).astype(dtype)
1509    y = mx.nd.random.uniform(-10, 10, shape).astype(dtype)
1510    M = def_model('broadcast_power')
1511    op_export_test('broadcast_power', M, [x, y], tmp_path)
1512
1513
1514@pytest.mark.parametrize("dtype", ["float16", "float32", "float64"])
1515@pytest.mark.parametrize('shape', [(3, 4, 5), (6, 7), (8,)])
1516def test_onnx_export_sqrt(tmp_path, dtype, shape):
1517    A = mx.nd.random.uniform(-100, 100, shape).astype(dtype)
1518    M = def_model('sqrt')
1519    op_export_test('sqrt', M, [A], tmp_path)
1520
1521
1522@pytest.mark.parametrize("dtype", ["float16", "float32"])
1523@pytest.mark.parametrize("params", [[(1,4,2,3), 1], [(1,4,2,3), 2]])
1524def test_onnx_export_depth_to_space(tmp_path, dtype, params):
1525    shape, block_size = params
1526    M = def_model('depth_to_space', block_size=block_size)
1527    x = mx.nd.arange(0, np.prod(shape)).reshape(shape).astype(dtype)
1528    op_export_test('depth_to_space', M, [x], tmp_path)
1529
1530
1531@pytest.mark.parametrize("dtype", ["float16", "float32"])
1532@pytest.mark.parametrize("params", [[(1,4,2,3), 1], [(1,1,4,6),2]])
1533def test_onnx_export_space_to_depth(tmp_path, dtype, params):
1534    shape, block_size = params
1535    M = def_model('space_to_depth', block_size=block_size)
1536    x = mx.nd.arange(0, np.prod(shape)).reshape(shape).astype(dtype)
1537    op_export_test('space_to_depth', M, [x], tmp_path)
1538
1539
1540@pytest.mark.parametrize("dtype", ["float16", "float32", "float64", "int32", "int64"])
1541@pytest.mark.parametrize("shape", [(10,), (1,2,3), (4,5,6)])
1542def test_onnx_export_square(tmp_path, dtype, shape):
1543    M = def_model('square')
1544    x = mx.nd.arange(0, np.prod(shape)).reshape(shape).astype(dtype)
1545    op_export_test('square', M, [x], tmp_path)
1546
1547
1548@pytest.mark.parametrize("dtype", ["float16", "float32", "float64", "int32", "int64"])
1549@pytest.mark.parametrize("shape", [(10,), (1,2,3), (4,5,6)])
1550def test_onnx_export_shape_array(tmp_path, dtype, shape):
1551    M = def_model('shape_array')
1552    x = mx.nd.arange(0, np.prod(shape)).reshape(shape).astype(dtype)
1553    op_export_test('shape_array', M, [x], tmp_path)
1554
1555
1556@pytest.mark.parametrize("dtype", ["float16", "float32"])
1557@pytest.mark.parametrize("shape", [(10,), (1,2,3), (4,5,6)])
1558@pytest.mark.parametrize("alpha", [None, 0.1, 0.4567, 0.9])
1559@pytest.mark.parametrize("beta", [None, 0.1, 0.4567, 0.5, 0.9])
1560def test_onnx_export_hard_sigmoid(tmp_path, dtype, shape, alpha, beta):
1561    kwargs = { }
1562    if alpha is not None:
1563        kwargs['alpha'] = alpha
1564    if beta is not None:
1565        kwargs['beta'] = beta
1566    M = def_model('hard_sigmoid', **kwargs)
1567    x = mx.nd.arange(0, np.prod(shape)).reshape(shape).astype(dtype)
1568    op_export_test('hard_sigmoid', M, [x], tmp_path)
1569
1570
1571@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'int32', 'int64'])
1572@pytest.mark.parametrize("shape", [(10,), (1,2,3), (4,5,6)])
1573def test_onnx_export_broadcast_lesser(tmp_path, dtype, shape):
1574    M = def_model('broadcast_lesser')
1575    x = mx.nd.random.uniform(-100, 100, shape).astype(dtype)
1576    y = mx.nd.random.uniform(-100, 100, shape).astype(dtype)
1577    op_export_test('broadcast_lesser', M, [x, y], tmp_path)
1578
1579
1580@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'int32', 'int64'])
1581@pytest.mark.parametrize("shape", [(10,), (1,2,3), (4,5,6)])
1582def test_onnx_export_broadcast_greater(tmp_path, dtype, shape):
1583    M = def_model('broadcast_greater')
1584    x = mx.nd.random.uniform(-100, 100, shape).astype(dtype)
1585    y = mx.nd.random.uniform(-100, 100, shape).astype(dtype)
1586    op_export_test('broadcast_greater', M, [x, y], tmp_path)
1587
1588
1589@pytest.mark.parametrize('dtype', ['float16', 'float32'])
1590@pytest.mark.parametrize("shape", [(10,5), (1,2,3), (4,5,6)])
1591@pytest.mark.parametrize('axis', [None, 1])
1592def test_onnx_export_log_softmax(tmp_path, dtype, shape, axis):
1593    x = mx.nd.random.uniform(0, 1, shape, dtype=dtype)
1594    kwargs = {}
1595    if axis is not None:
1596        kwargs['axis'] = axis
1597    M = def_model('log_softmax', **kwargs)
1598    op_export_test('log_softmax', M, [x], tmp_path)
1599
1600
1601@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'int32', 'int64'])
1602@pytest.mark.parametrize("shape", [(10,), (2,3), (4,5,6)])
1603def test_onnx_export_broadcast_logical_and(tmp_path, dtype, shape):
1604    M = def_model('broadcast_logical_and')
1605    x = mx.nd.random.uniform(-1, 1, shape).astype(dtype)
1606    y = mx.nd.random.uniform(-1, 1, shape).astype(dtype)
1607    op_export_test('broadcast_logical_and', M, [x, y], tmp_path)
1608
1609
1610@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'int32', 'int64'])
1611@pytest.mark.parametrize("shape", [(10,), (2,3), (4,5,6)])
1612def test_onnx_export_broadcast_logical_or(tmp_path, dtype, shape):
1613    M = def_model('broadcast_logical_or')
1614    x = mx.nd.random.uniform(-1, 1, shape).astype(dtype)
1615    y = mx.nd.random.uniform(-1, 1, shape).astype(dtype)
1616    op_export_test('broadcast_logical_or', M, [x, y], tmp_path)
1617
1618
1619@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'int32', 'int64'])
1620@pytest.mark.parametrize("shape", [(10,), (2,3), (4,5,6)])
1621def test_onnx_export_broadcast_logical_xor(tmp_path, dtype, shape):
1622    M = def_model('broadcast_logical_xor')
1623    x = mx.nd.random.uniform(-1, 1, shape).astype(dtype)
1624    y = mx.nd.random.uniform(-1, 1, shape).astype(dtype)
1625    op_export_test('broadcast_logical_xor', M, [x, y], tmp_path)
1626
1627
1628@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'int32', 'int64'])
1629@pytest.mark.parametrize("shapes", [[(1,3),(2,3)], [(2,1,3,1),(2,8,3,9)], [(1,3,6),(5,3,6)]])
1630def test_onnx_export_broadcast_to(tmp_path, dtype, shapes):
1631    in_shape, to_shape = shapes
1632    M = def_model('broadcast_to', shape=to_shape)
1633    x = mx.nd.random.uniform(-100, 100, in_shape).astype(dtype)
1634    op_export_test('broadcast_to', M, [x], tmp_path)
1635
1636
1637# onnxruntime currently does not support int32
1638@pytest.mark.parametrize('dtype', ['float16', 'float32', 'int64'])
1639@pytest.mark.parametrize('shape', [(1,), (2, 3), (4, 5, 6)])
1640def test_onnx_export_clip(tmp_path, dtype, shape):
1641    A = mx.nd.random.uniform(-100, 100, shape).astype(dtype)
1642    a_min = mx.nd.min(A).astype('float32').asnumpy()[0] + 5
1643    a_max = mx.nd.max(A).astype('float32').asnumpy()[0] - 5
1644    print(a_min)
1645    M = def_model('clip', a_min=a_min, a_max=a_max)
1646    op_export_test('clip', M, [A], tmp_path)
1647
1648
1649@pytest.mark.parametrize('dtype', ['float16', 'float32', 'int32', 'int64'])
1650@pytest.mark.parametrize('shape', [(3, 4, 5), (6, 7), (8,)])
1651@pytest.mark.parametrize('func', [lambda x : x + np.random.rand(1)[0]*100,
1652                                  lambda x : x * np.random.rand(1)[0]*100,
1653                                  lambda x : x - np.random.rand(1)[0]*100,
1654                                  lambda x : np.random.rand(1)[0]*100 - x,
1655                                  lambda x : x / (np.random.rand(1)[0]*100 + 1),
1656                                  lambda x : np.random.rand(1)[0]*100 / x,
1657                                  lambda x : x ** np.random.rand(1)[0]*10,
1658                                 ])
1659def test_onnx_export_scalar_op(tmp_path, dtype, shape, func):
1660    A = mx.nd.random.uniform(1, 100, shape).astype(dtype)
1661    M = def_model_from_func(func)
1662    op_export_test('_scalar', M, [A], tmp_path)
1663
1664
1665@pytest.mark.parametrize('dtype', ['float16', 'float32', 'int32'])
1666@pytest.mark.parametrize('shape', [(1, 1, 1), (2, 3, 4), (5, 6, 7, 8)])
1667@pytest.mark.parametrize('axis', ['None', 0, 1, 2, -1, -2])
1668@pytest.mark.parametrize('keepdims', [True, False])
1669@pytest.mark.parametrize('op_name', ['argmax', 'argmin'])
1670def test_onnx_export_arg_max_min(tmp_path, dtype, shape, axis, keepdims, op_name):
1671    A = mx.nd.random.uniform(-100, 100, shape).astype(dtype)
1672    M = def_model(op_name, axis=axis, keepdims=keepdims)
1673    op_export_test(op_name, M, [A], tmp_path)
1674
1675
1676# onnx max and min have issue comparing negative float16 values
1677@pytest.mark.parametrize('dtype', ['float16', 'float32', 'int32', 'int64'])
1678@pytest.mark.parametrize('shape', [[(2, 3), (2, 3)], [(5, 4), (5, 4)]])
1679@pytest.mark.parametrize('op_name', ['maximum', 'minimum'])
1680def test_onnx_export_maximum_minimum(tmp_path, dtype, shape, op_name):
1681    lhs = mx.nd.random.uniform(1, 100, shape[0]).astype(dtype)
1682    rhs = mx.nd.random.uniform(1, 100, shape[1]).astype(dtype)
1683    M = def_model(op_name)
1684    op_export_test(op_name, M, [lhs, rhs], tmp_path)
1685
1686
1687# onnx reduce ops do not support float64
1688@pytest.mark.parametrize('dtype', ['float16', 'float32','int32', 'int64'])
1689@pytest.mark.parametrize('shape', [(2, 3), (4, 5, 6)])
1690@pytest.mark.parametrize('axis', [None, 0, 1, -1, (0, 1)])
1691@pytest.mark.parametrize('keepdims', [True, False])
1692@pytest.mark.parametrize('op_name', ['max', 'min', 'mean', 'prod'])
1693def test_onnx_export_reduce_op(tmp_path, dtype, shape, axis, keepdims, op_name):
1694    if dtype != 'int64' or op_name != 'mean':
1695        # onnx ReduceMean does not support int 64
1696        x = mx.nd.random.uniform(1, 100, shape=shape).astype(dtype)
1697        M = def_model(op_name, axis=axis, keepdims=keepdims)
1698        op_export_test(op_name, M, [x], tmp_path)
1699
1700
1701@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'int32', 'int64'])
1702@pytest.mark.parametrize('shape', [(1,), (3, ), (4, 5), (3, 4, 5)])
1703@pytest.mark.parametrize('op_name', ['elemwise_add', 'elemwise_sub', 'elemwise_mul', 'elemwise_div'])
1704def test_onnx_export_elemwise_op(tmp_path, dtype, shape, op_name):
1705    x = mx.nd.random.uniform(1, 100, shape=shape).astype(dtype)
1706    y = mx.nd.random.uniform(1, 100, shape=shape).astype(dtype)
1707    M = def_model(op_name)
1708    op_export_test(op_name, M, [x, y], tmp_path)
1709
1710
1711@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'int32', 'int64'])
1712@pytest.mark.parametrize('shape', [[(3, 4), (3, 4)], [(3, 4), (3, 1)], [(3, 4), (4)]])
1713@pytest.mark.parametrize('op_name', ['broadcast_sub', 'broadcast_div'])
1714def test_onnx_export_broadcast_op(tmp_path, dtype, shape, op_name):
1715    x = mx.nd.random.uniform(1, 100, shape=shape[0]).astype(dtype)
1716    y = mx.nd.random.uniform(1, 100, shape=shape[1]).astype(dtype)
1717    M = def_model(op_name)
1718    op_export_test(op_name, M, [x, y], tmp_path)
1719
1720
1721@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'int32', 'int64'])
1722@pytest.mark.parametrize('shape', [(1,), (3, ), (4, 5), (3, 4, 5)])
1723def test_onnx_export_negative(tmp_path, dtype, shape):
1724    x = mx.nd.random.uniform(-100, 100, shape=shape).astype(dtype)
1725    M = def_model('negative')
1726    op_export_test('negative', M, [x], tmp_path)
1727
1728
1729@pytest.mark.parametrize('dtype', ['float16', 'float32'])
1730@pytest.mark.parametrize('shape', [(1,), (3, ), (4, 5), (3, 4, 5)])
1731def test_onnx_export_addn(tmp_path, dtype, shape):
1732    x = mx.nd.random.uniform(-100, 100, shape=shape).astype(dtype)
1733    M = def_model('add_n')
1734    op_export_test('add_n', M, [x], tmp_path)
1735
1736
1737@pytest.mark.parametrize('dtype', ['float16', 'float32'])
1738@pytest.mark.parametrize('shape', [(1,), (3, ), (4, 5), (3, 4, 5)])
1739@pytest.mark.parametrize('op_name', ['ceil', 'floor', 'log'])
1740def test_onnx_export_ufunc(tmp_path, dtype, shape, op_name):
1741    x = mx.nd.random.uniform(-100, 100, shape=shape).astype(dtype)
1742    M = def_model(op_name)
1743    op_export_test(op_name, M, [x], tmp_path)
1744
1745
1746@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'int32', 'int64'])
1747@pytest.mark.parametrize('shape_axis', [[(1, 1), None], [(3, 1, 2, 1), (None)], [(3, 1, 2, 1), (1)],
1748                            [(3, 1, 2, 1), (1, 3)]])
1749def test_onnx_export_squeeze(tmp_path, dtype, shape_axis):
1750    x = mx.nd.random.uniform(1, 100, shape=shape_axis[0]).astype(dtype)
1751    M = def_model('squeeze', axis=shape_axis[1])
1752    op_export_test('squeeze', M, [x], tmp_path)
1753
1754
1755@pytest.mark.parametrize("dtype", ["float16", "float32"])
1756@pytest.mark.parametrize("order", [1, 2])
1757@pytest.mark.parametrize("keepdims", [0, 1])
1758@pytest.mark.parametrize("axis", [None, 0, 1, 2, -1, (0, 2), (0, 1, 2)])
1759@pytest.mark.parametrize("shape", [(4, 5, 6), (3, 4, 5, 6)])
1760def test_onnx_export_norm(tmp_path, dtype, order, axis, shape, keepdims):
1761    kwargs = {}
1762    if order is not None:
1763        kwargs['ord'] = order
1764    if axis is not None:
1765        kwargs['axis'] = axis
1766    if keepdims is not None:
1767        kwargs['keepdims'] = keepdims
1768    M = def_model('norm', **kwargs)
1769    x = mx.random.normal(0, 10, shape).astype(dtype)
1770    op_export_test('norm', M, [x], tmp_path)
1771
1772
1773@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'int32', 'int64'])
1774@pytest.mark.parametrize("shape", [(10,), (2,3), (4,5,6)])
1775def test_onnx_export_logical_not(tmp_path, dtype, shape):
1776    M = def_model('logical_not')
1777    x = mx.nd.random.uniform(-1, 1, shape).astype(dtype)
1778    op_export_test('logical_not', M, [x], tmp_path)
1779
1780
1781@pytest.mark.parametrize("dtype", ["float16", "float32", "float64"])
1782@pytest.mark.parametrize("shape", [(10,), (1,2,3), (4,5,6)])
1783def test_onnx_export_random_uniform_like(tmp_path, dtype, shape):
1784    M = def_model('random.uniform_like')
1785    low = -10
1786    high = 10
1787    x = mx.nd.zeros(shape=shape).astype(dtype)
1788    def rand_check(out):
1789        for i in out:
1790            if i.any() < low or i.any() >= high:
1791                raise Exception("Invalid value")
1792        return np.zeros_like(out)
1793    def rand_check_nd(out):
1794        return rand_check(out.asnumpy())
1795    op_export_test('random.uniform_like', M, [x], tmp_path, mx_map=rand_check_nd, onnx_map=rand_check)
1796
1797
1798@pytest.mark.parametrize("dtype", ["float32", "float64"])
1799@pytest.mark.parametrize("shape", [(10,), (1,2,3), (4,5,6)])
1800def test_onnx_export_random_uniform(tmp_path, dtype, shape):
1801    low = -10
1802    high = 10
1803    M = def_model('random_uniform', low=low, high=high, shape=shape, dtype=dtype, dummy_input=True)
1804    x = mx.nd.array([1], dtype='float32')
1805    def rand_check(out):
1806        for i in out:
1807            if i.any() < low or i.any() >= high:
1808                raise Exception("Invalid value")
1809        return np.zeros_like(out)
1810    def rand_check_nd(out):
1811        return rand_check(out.asnumpy())
1812    op_export_test('random_uniform', M, [x], tmp_path, mx_map=rand_check_nd, onnx_map=rand_check, dummy_input=True)
1813
1814
1815@pytest.mark.parametrize("dtype", ["float32", "float64"])
1816@pytest.mark.parametrize("shape", [(10,), (1,2,3), (4,5,6)])
1817@pytest.mark.parametrize("loc", [None, 0, 1, 2])
1818@pytest.mark.parametrize("scale", [None, 1, 2])
1819def test_onnx_export_random_normal(tmp_path, dtype, loc, scale, shape):
1820    kwargs = {
1821        'dtype': dtype,
1822        'shape': shape,
1823        'dummy_input': True
1824    }
1825    if loc is not None:
1826        kwargs['loc'] = loc
1827    if scale is not None:
1828        kwargs['scale'] = scale
1829    M = def_model('random_normal', **kwargs)
1830    x = mx.nd.array([1], dtype='float32')
1831    def rand_check(out):
1832        return np.zeros_like(out)
1833    def rand_check_nd(out):
1834        return rand_check(out.asnumpy())
1835    op_export_test('random_normal', M, [x], tmp_path, mx_map=rand_check_nd, onnx_map=rand_check, dummy_input=True)
1836
1837
1838@pytest.mark.parametrize("dtype", ["float16", "float32"])
1839@pytest.mark.parametrize("spatial_scale", [0.7, 1.0])
1840def test_onnx_export_roi_pooling(tmp_path, dtype, spatial_scale):
1841    M = def_model('ROIPooling', pooled_size=(2,2), spatial_scale=spatial_scale)
1842    x = mx.nd.arange(start=0, stop=48, dtype=dtype).reshape((1,1,8,6))
1843    y = mx.nd.array([[0,0,0,4,4]], dtype=dtype)
1844    op_export_test('ROIPooling', M, [x, y], tmp_path)
1845
1846
1847@pytest.mark.parametrize("dtype", ["float16", "float32", "float64", "int32", "int64"])
1848@pytest.mark.parametrize("shape", [(1,2,3), (1,10)])
1849@pytest.mark.parametrize("axis", [None, 0, 1])
1850def test_onnx_export_rnn_param_concat(tmp_path, dtype, shape, axis):
1851    kwargs = {}
1852    if axis is not None:
1853        kwargs['dim'] = axis
1854    M = def_model('_internal._rnn_param_concat', **kwargs)
1855    x = mx.nd.random.uniform(-1, 1, shape).astype(dtype)
1856    y = mx.nd.random.uniform(-1, 1, shape).astype(dtype)
1857    op_export_test('_internal._rnn_param_concat', M, [x, y], tmp_path)
1858
1859
1860@pytest.mark.parametrize("dtype", ["float16", "float32", "float64", "int32", "int64"])
1861@pytest.mark.parametrize("shape", [(10,), (1,2,3), (4,5,6)])
1862def test_onnx_export_size_array(tmp_path, dtype, shape):
1863    M = def_model('size_array')
1864    x = mx.nd.random.uniform(-1, 1, shape).astype(dtype)
1865    op_export_test('size_array', M, [x], tmp_path)
1866
1867
1868@pytest.mark.parametrize("dtype", ["float16", "float32"])
1869@pytest.mark.parametrize("shape", [(1,5), (2,10), (4,5)])
1870@pytest.mark.parametrize("sample_shape", [(1), (2)])
1871def test_onnx_export_sample_multinomial(tmp_path, dtype, shape, sample_shape):
1872    kwargs = {}
1873    if sample_shape is not None:
1874        kwargs['shape'] = sample_shape
1875    M = def_model('sample_multinomial', **kwargs)
1876    a = mx.nd.random.uniform(0, 1, shape).astype(dtype)
1877    x = a/a.sum(axis=1, keepdims=1)
1878    def rand_check(out):
1879        return np.zeros_like(out)
1880    def rand_check_nd(out):
1881        return rand_check(out.asnumpy())
1882    op_export_test('sample_multinomial', M, [x], tmp_path, mx_map=rand_check_nd, onnx_map=rand_check)
1883
1884
1885@pytest.mark.parametrize("dtype", ['float32', 'int32', 'int64'])
1886@pytest.mark.parametrize('params', [((2, 4, 6), (1, ), 0, True),
1887                                    ((4, 5, 6), (2, 4), 1, False),
1888                                    ((4, 5, 6, 7), (0, 2, 4), 2, False),
1889                                    ((4, 5, 6, 7), 3, -2, False),
1890                                    ((2, 6, 8), 8, -1, True)])
1891def test_onnx_export_split_v2(tmp_path, dtype, params):
1892    from onnx.defs import onnx_opset_version
1893    if onnx_opset_version() < 13 and not isinstance(params[1], int):
1894        # opset12 only supports sections. indices is supported since opset13
1895        return
1896    M = def_model('split_v2', indices_or_sections=params[1], axis=params[2], squeeze_axis=params[3])
1897    x = mx.nd.random.uniform(0, 10, params[0]).astype(dtype)
1898    op_export_test('split_v2', M, [x], tmp_path)
1899