1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18import mxnet as mx
19import numpy as np
20import gluoncv
21import onnxruntime
22
23from mxnet.test_utils import assert_almost_equal
24from common import with_seed
25
26import json
27import os
28import pytest
29import shutil
30
31
class GluonModel:
    """Wrap a pretrained GluonCV model and export it for ONNX conversion.

    On construction the model is fetched from the GluonCV model zoo with
    ``pretrained=True`` on ``mx.cpu(0)``, hybridized, run once on a zero
    tensor of ``input_shape``/``input_dtype`` and exported to
    ``<tmpdir>/<model_name>-symbol.json`` and
    ``<tmpdir>/<model_name>-0000.params``.
    """

    def __init__(self, model_name, input_shape, input_dtype, tmpdir):
        self.model_name = model_name
        self.input_shape = input_shape
        self.input_dtype = input_dtype
        # Common prefix for every file written for this model.
        self.modelpath = os.path.join(tmpdir, model_name)
        self.ctx = mx.cpu(0)
        self.get_model()
        self.export()

    def _symbol_file(self):
        """Return the path of the symbol file written by ``export``."""
        return self.modelpath + "-symbol.json"

    def _params_file(self):
        """Return the path of the epoch-0 params file written by ``export``."""
        return self.modelpath + "-0000.params"

    def _onnx_file(self):
        """Return the path the ONNX model is exported to."""
        return self.modelpath + ".onnx"

    def get_model(self):
        """Load the pretrained model onto ``self.ctx`` and hybridize it."""
        self.model = gluoncv.model_zoo.get_model(self.model_name, pretrained=True, ctx=self.ctx)
        self.model.hybridize()

    def export(self):
        """Run one forward pass on zeros, then export symbol/params at epoch 0.

        The forward pass comes first because a hybridized Gluon block needs a
        forward call before ``export`` can write its graph.
        """
        data = mx.nd.zeros(self.input_shape, dtype=self.input_dtype, ctx=self.ctx)
        self.model.forward(data)
        self.model.export(self.modelpath, 0)

    def export_onnx(self):
        """Convert the exported symbol/params files to ONNX with a static input
        shape and return the ONNX file path."""
        onnx_file = self._onnx_file()
        mx.onnx.export_model(self._symbol_file(), self._params_file(),
                             [self.input_shape], self.input_dtype, onnx_file)
        return onnx_file

    def export_onnx_dynamic(self, dynamic_input_shapes):
        """Convert to ONNX with ``dynamic=True`` and the given
        ``dynamic_input_shapes``; return the ONNX file path."""
        onnx_file = self._onnx_file()
        mx.onnx.export_model(self._symbol_file(), self._params_file(),
                             [self.input_shape], self.input_dtype, onnx_file, dynamic=True,
                             dynamic_input_shapes=dynamic_input_shapes)
        return onnx_file

    def export_onnx_argaux(self):
        """Convert to ONNX by passing an in-memory symbol and
        ``[arg_params, aux_params]`` instead of file paths.

        Returns:
            The ONNX file path.

        Raises:
            ValueError: if the symbol or params file written by ``export`` is
                missing.
        """
        onnx_file = self._onnx_file()
        if not (os.path.isfile(self._symbol_file()) and os.path.isfile(self._params_file())):
            raise ValueError("Symbol and params files provided are invalid")
        # ``export`` saved the checkpoint under prefix ``self.modelpath`` at
        # epoch 0, so load it back directly rather than re-parsing file names.
        sym, arg_params, aux_params = mx.model.load_checkpoint(self.modelpath, 0)
        params = [arg_params, aux_params]
        mx.onnx.export_model(sym, params, [self.input_shape], self.input_dtype, onnx_file)
        return onnx_file

    def predict(self, data):
        """Run the wrapped Gluon model on ``data`` and return its raw output."""
        return self.model(data)
90
91
@pytest.fixture(scope="session")
def obj_class_test_images(tmpdir_factory):
    """Session fixture: download the image-classification test images.

    Returns:
        List of local file paths, one per URL, in URL order.
    """
    from urllib.parse import urlparse
    data_dir = tmpdir_factory.mktemp("obj_class_data")
    image_urls = (
        'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/bikers.jpg',
        'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/car.jpg',
        'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/dancer.jpg',
        'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/duck.jpg',
        'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/fieldhockey.jpg',
        'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/flower.jpg',
        'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/runners.jpg',
        'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/shark.jpg',
        'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/soccer2.jpg',
        'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/tree.jpg',
    )
    downloaded = []
    for image_url in image_urls:
        # Keep the remote file name for the local copy.
        file_name = os.path.basename(urlparse(image_url).path)
        local_path = os.path.join(data_dir, file_name)
        mx.test_utils.download(image_url, fname=local_path)
        downloaded.append(local_path)
    return downloaded
114
@pytest.mark.onnx_cv_batch1
@pytest.mark.parametrize('model', [
    'alexnet',
    'cifar_resnet20_v1',
    'cifar_resnet56_v1',
    'cifar_resnet110_v1',
    'cifar_resnet20_v2',
    'cifar_resnet56_v2',
    'cifar_resnet110_v2',
    'cifar_wideresnet16_10',
    'cifar_wideresnet28_10',
    'cifar_wideresnet40_8',
    'cifar_resnext29_16x64d',
    'darknet53',
    'densenet121',
    'densenet161',
    'densenet169',
    'densenet201',
    'googlenet',
    'mobilenet1.0',
    'mobilenet0.75',
    'mobilenet0.5',
    'mobilenet0.25',
    'mobilenetv2_1.0',
    'mobilenetv2_0.75',
    'mobilenetv2_0.5',
    'mobilenetv2_0.25',
    pytest.param('mobilenetv3_large', marks=pytest.mark.integration),
    'mobilenetv3_small',
    'resnest14',
    'resnest26',
    'resnest50',
    'resnest101',
    pytest.param('resnest200', marks=pytest.mark.integration),
    'resnest269',
    'resnet18_v1',
    'resnet18_v1b_0.89',
    'resnet18_v2',
    'resnet34_v1',
    'resnet34_v2',
    'resnet50_v1',
    'resnet50_v1d_0.86',
    'resnet50_v1d_0.48',
    'resnet50_v1d_0.37',
    'resnet50_v1d_0.11',
    pytest.param('resnet50_v2', marks=pytest.mark.integration),
    'resnet101_v1',
    'resnet101_v1d_0.76',
    'resnet101_v1d_0.73',
    'resnet101_v2',
    'resnet152_v1',
    'resnet152_v2',
    'resnext50_32x4d',
    'resnext101_32x4d',
    'resnext101_64x4d',
    'senet_154',
    'se_resnext101_32x4d',
    'se_resnext101_64x4d',
    'se_resnext50_32x4d',
    'squeezenet1.0',
    'squeezenet1.1',
    'vgg11',
    'vgg11_bn',
    'vgg13',
    'vgg13_bn',
    'vgg16',
    'vgg16_bn',
    'vgg19',
    pytest.param('vgg19_bn', marks=pytest.mark.integration),
    'xception',
    'inceptionv3'
])
def test_obj_class_model_inference_onnxruntime(tmp_path, model, obj_class_test_images):
    """Check that a classification model exported to ONNX produces the same
    output in ONNX Runtime as in MXNet on every test image.

    ``resnet50_v2`` is exported through ``export_onnx_argaux`` to also cover
    the arg/aux-params export path; all other models use ``export_onnx``.
    """
    inlen = 299 if model == 'inceptionv3' else 224
    # Per-channel mean/std shaped (3, 1, 1) so they broadcast over a CHW image.
    # This replaces a per-channel Python loop and is created once, not per image.
    mean_vec = mx.nd.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
    stddev_vec = mx.nd.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))

    def normalize_image(imgfile):
        """Load ``imgfile``, resize it to ``inlen`` x ``inlen`` and return a
        normalized float32 batch of shape (1, 3, inlen, inlen)."""
        img_data = mx.image.imread(imgfile)
        img_data = mx.image.imresize(img_data, inlen, inlen)
        # HWC -> CHW
        img_data = img_data.transpose([2, 0, 1]).astype('float32')
        norm_img_data = (img_data / 255 - mean_vec) / stddev_vec
        return norm_img_data.reshape(1, 3, inlen, inlen).astype('float32')

    try:
        tmp_path = str(tmp_path)
        M = GluonModel(model, (1, 3, inlen, inlen), 'float32', tmp_path)
        if model == 'resnet50_v2':
            # testing export for arg/aux
            onnx_file = M.export_onnx_argaux()
        else:
            onnx_file = M.export_onnx()

        # create onnxruntime session using the generated onnx file
        ses_opt = onnxruntime.SessionOptions()
        ses_opt.log_severity_level = 3
        session = onnxruntime.InferenceSession(onnx_file, ses_opt)
        input_name = session.get_inputs()[0].name

        for img in obj_class_test_images:
            img_data = normalize_image(img)
            mx_result = M.predict(img_data)
            onnx_result = session.run([], {input_name: img_data.asnumpy()})[0]
            assert_almost_equal(mx_result, onnx_result)

    finally:
        shutil.rmtree(tmp_path)
223
224
@pytest.fixture(scope="session")
def obj_detection_test_images(tmpdir_factory):
    """Download the object-detection test images once per test session and
    return their local file paths, in the same order as the URL list."""
    from urllib.parse import urlparse
    target_dir = tmpdir_factory.mktemp("obj_det_data")

    def fetch(url):
        # Store each image as <target_dir>/<remote file name>.
        dest = os.path.join(target_dir, os.path.basename(urlparse(url).path))
        mx.test_utils.download(url, fname=dest)
        return dest

    urls = [
        'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/fieldhockey.jpg',
        'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/flower.jpg',
        'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/runners.jpg',
        'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/shark.jpg',
        'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/soccer2.jpg',
        'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/tree.jpg',
    ]
    local_paths = []
    for url in urls:
        local_paths.append(fetch(url))
    return local_paths
243
244
@pytest.mark.onnx_cv_batch2
@pytest.mark.parametrize('model', [
    'center_net_resnet18_v1b_voc',
    'center_net_resnet50_v1b_voc',
    pytest.param('center_net_resnet101_v1b_voc', marks=pytest.mark.integration),
    'center_net_resnet18_v1b_coco',
    'center_net_resnet50_v1b_coco',
    'center_net_resnet101_v1b_coco',
    'ssd_300_vgg16_atrous_voc',
    'ssd_512_vgg16_atrous_voc',
    'ssd_512_resnet50_v1_voc',
    'ssd_512_mobilenet1.0_voc',
    'faster_rcnn_resnet50_v1b_voc',
    'yolo3_darknet53_voc',
    'yolo3_mobilenet1.0_voc',
    'ssd_300_vgg16_atrous_coco',
    'ssd_512_vgg16_atrous_coco',
    'ssd_300_resnet34_v1b_coco',
    'ssd_512_resnet50_v1_coco',
    'ssd_512_mobilenet1.0_coco',
    'faster_rcnn_resnet50_v1b_coco',
    'faster_rcnn_resnet101_v1d_coco',
    'yolo3_darknet53_coco',
    'yolo3_mobilenet1.0_coco',
    'faster_rcnn_fpn_resnet50_v1b_coco',
    # These two models fail in nightly CI because of expected numerical differences
    # between the MXNet and ONNX NMS implementations. Their outputs look correct on
    # visual inspection; the bbox check needs to be rewritten before re-enabling them.
    #'faster_rcnn_fpn_resnet101_v1d_coco',
    #'mask_rcnn_fpn_resnet18_v1b_coco',
    'mask_rcnn_resnet18_v1b_coco',
    'mask_rcnn_resnet50_v1b_coco',
    'mask_rcnn_resnet101_v1d_coco',
    'mask_rcnn_fpn_resnet50_v1b_coco',
    'mask_rcnn_fpn_resnet101_v1d_coco',
])
def test_obj_detection_model_inference_onnxruntime(tmp_path, model, obj_detection_test_images):
    """Compare detections of an exported detection model between MXNet and
    ONNX Runtime on every test image.

    ``center_net_resnet`` models are compared element-wise. All other models are
    compared with a tolerant matcher, because NMS results may differ slightly.
    """
    def assert_obj_detection_result(mx_ids, mx_scores, mx_boxes,
                                    onnx_ids, onnx_scores, onnx_boxes,
                                    score_thresh=0.6, score_tol=0.0001, box_tol=0.01):
        """Assert that every ONNX detection scoring >= ``score_thresh`` has an
        MXNet counterpart with a score within ``score_tol``, the same class id
        and all four box coordinates within ``box_tol``.

        Both detection lists are assumed to be sorted by descending score; the
        early ``break`` statements below rely on that ordering.
        """
        # Convert the MXNet outputs to numpy once instead of once per comparison.
        mx_ids = mx_ids.asnumpy()
        mx_scores = mx_scores.asnumpy()
        mx_boxes = mx_boxes.asnumpy()

        for onnx_id_row, onnx_score_row, onnx_box in zip(onnx_ids, onnx_scores, onnx_boxes):
            onnx_id = onnx_id_row[0]
            onnx_score = onnx_score_row[0]
            if onnx_score < score_thresh:
                # Every remaining ONNX detection is below the threshold as well.
                break
            found_match = False
            for mx_id_row, mx_score_row, mx_box in zip(mx_ids, mx_scores, mx_boxes):
                mx_id = mx_id_row[0]
                mx_score = mx_score_row[0]
                # check score
                if onnx_score < mx_score - score_tol:
                    # This MXNet detection scores higher; keep looking further down.
                    continue
                if onnx_score > mx_score + score_tol:
                    # All later MXNet detections score even lower: no match possible.
                    break
                # check id and bounding box
                if onnx_id == mx_id and np.all(np.abs(mx_box[:4] - onnx_box[:4]) <= box_tol):
                    found_match = True
                    break
            assert found_match, 'match not found'

    def normalize_image(imgfile):
        """Load ``imgfile``, center-crop it to 512x512 and apply the GluonCV
        CenterNet test transform (``short=512``)."""
        img = mx.image.imread(imgfile)
        img, _ = mx.image.center_crop(img, size=(512, 512))
        img, _ = gluoncv.data.transforms.presets.center_net.transform_test(img, short=512)
        return img

    is_mask_rcnn = model.startswith('mask_rcnn')
    is_center_net = 'center_net_resnet' in model
    # Looser matching for models whose ONNX detections differ more from MXNet.
    if is_mask_rcnn:
        match_kwargs = dict(score_thresh=0.8, score_tol=0.05, box_tol=15)
    elif model.startswith('faster_rcnn_fpn'):
        match_kwargs = dict(score_thresh=0.8, score_tol=0.05, box_tol=30)
    else:
        match_kwargs = {}

    try:
        tmp_path = str(tmp_path)
        M = GluonModel(model, (1, 3, 512, 512), 'float32', tmp_path)
        onnx_file = M.export_onnx()
        # create onnxruntime session using the generated onnx file
        ses_opt = onnxruntime.SessionOptions()
        ses_opt.log_severity_level = 3
        session = onnxruntime.InferenceSession(onnx_file, ses_opt)
        input_name = session.get_inputs()[0].name

        for img in obj_detection_test_images:
            img_data = normalize_image(img)
            onnx_inputs = {input_name: img_data.asnumpy()}
            if is_mask_rcnn:
                # mask_rcnn models return a fourth output that is not compared.
                mx_class_ids, mx_scores, mx_boxes, _ = M.predict(img_data)
                onnx_class_ids, onnx_scores, onnx_boxes, _ = session.run([], onnx_inputs)
            else:
                mx_class_ids, mx_scores, mx_boxes = M.predict(img_data)
                if is_center_net:
                    # The ONNX graph of center_net_resnet models returns
                    # (scores, class ids, boxes).
                    onnx_scores, onnx_class_ids, onnx_boxes = session.run([], onnx_inputs)
                else:
                    onnx_class_ids, onnx_scores, onnx_boxes = session.run([], onnx_inputs)

            if is_center_net:
                assert_almost_equal(mx_class_ids, onnx_class_ids)
                assert_almost_equal(mx_scores, onnx_scores)
                assert_almost_equal(mx_boxes, onnx_boxes)
            else:
                assert_obj_detection_result(mx_class_ids[0], mx_scores[0], mx_boxes[0],
                                            onnx_class_ids[0], onnx_scores[0], onnx_boxes[0],
                                            **match_kwargs)

    finally:
        shutil.rmtree(tmp_path)
362
@pytest.fixture(scope="session")
def img_segmentation_test_images(tmpdir_factory):
    """Fetch the segmentation test images into a session-wide temp directory.

    Returns:
        List of downloaded file paths, one per URL, in URL order.
    """
    out_dir = tmpdir_factory.mktemp("img_seg_data")
    from urllib.parse import urlparse
    urls = [
        'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/bikers.jpg',
        'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/car.jpg',
        'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/dancer.jpg',
        'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/duck.jpg',
        'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/fieldhockey.jpg',
        'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/flower.jpg',
        'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/runners.jpg',
        'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/shark.jpg',
        'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/soccer2.jpg',
        'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/tree.jpg',
    ]
    result = []
    for link in urls:
        remote_name = os.path.basename(urlparse(link).path)
        dest = os.path.join(out_dir, remote_name)
        mx.test_utils.download(link, fname=dest)
        result.append(dest)
    return result
385
@pytest.mark.onnx_cv_batch2
@pytest.mark.parametrize('model', [
    'fcn_resnet50_ade',
    'fcn_resnet101_ade',
    'deeplab_resnet50_ade',
    'deeplab_resnet101_ade',
    'deeplab_resnest50_ade',
    'deeplab_resnest101_ade',
    # This model currently cannot be downloaded; skipped for now.
    # 'deeplab_resnest200_ade',
    'deeplab_resnest269_ade',
    'fcn_resnet101_coco',
    'deeplab_resnet101_coco',
    'fcn_resnet101_voc',
    'deeplab_resnet101_voc',
    'deeplab_resnet152_voc',
    pytest.param('deeplab_resnet50_citys', marks=pytest.mark.integration),
    'deeplab_resnet101_citys',
    'deeplab_v3b_plus_wideresnet_citys',
    'danet_resnet50_citys',
    'danet_resnet101_citys'
])
def test_img_segmentation_model_inference_onnxruntime(tmp_path, model, img_segmentation_test_images):
    """Check that every output of an exported segmentation model matches
    between MXNet and ONNX Runtime on each test image."""
    def normalize_image(imgfile):
        """Load ``imgfile`` as float32, center-crop it to 480x480 and apply the
        GluonCV segmentation test transform."""
        img = mx.image.imread(imgfile).astype('float32')
        img, _ = mx.image.center_crop(img, size=(480, 480))
        img = gluoncv.data.transforms.presets.segmentation.test_transform(img, mx.cpu(0))
        return img

    try:
        tmp_path = str(tmp_path)
        M = GluonModel(model, (1, 3, 480, 480), 'float32', tmp_path)
        onnx_file = M.export_onnx()
        # create onnxruntime session using the generated onnx file
        ses_opt = onnxruntime.SessionOptions()
        ses_opt.log_severity_level = 3
        session = onnxruntime.InferenceSession(onnx_file, ses_opt)
        input_name = session.get_inputs()[0].name

        for img in img_segmentation_test_images:
            img_data = normalize_image(img)
            mx_result = M.predict(img_data)
            onnx_result = session.run([], {input_name: img_data.asnumpy()})
            assert len(mx_result) == len(onnx_result), \
                'MXNet returned %d outputs, ONNX Runtime returned %d' % (len(mx_result), len(onnx_result))
            for i, onnx_out in enumerate(onnx_result):
                assert_almost_equal(mx_result[i], onnx_out)

    finally:
        shutil.rmtree(tmp_path)
436
437
@pytest.fixture(scope="session")
def pose_estimation_test_images(tmpdir_factory):
    """Session fixture that downloads the pose-estimation test images and
    returns their local paths (URL order preserved)."""
    from urllib.parse import urlparse
    base_dir = tmpdir_factory.mktemp("pose_est_data")
    sources = [
        'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/bikers.jpg',
        'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/dancer.jpg',
        'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/fieldhockey.jpg',
        'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/runners.jpg',
        'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/soccer2.jpg',
    ]

    def _local_name(src):
        # Local file keeps the last path component of the URL.
        return os.path.join(base_dir, os.path.basename(urlparse(src).path))

    saved = []
    for src in sources:
        target = _local_name(src)
        mx.test_utils.download(src, fname=target)
        saved.append(target)
    return saved
455
@pytest.mark.onnx_cv_batch1
@pytest.mark.parametrize('model', [
    'simple_pose_resnet18_v1b',
    'simple_pose_resnet50_v1b',
    'simple_pose_resnet50_v1d',
    'simple_pose_resnet101_v1b',
    'simple_pose_resnet101_v1d',
    'simple_pose_resnet152_v1b',
    'simple_pose_resnet152_v1d',
    'alpha_pose_resnet101_v1b_coco',
    'mobile_pose_resnet18_v1b',
    'mobile_pose_resnet50_v1b',
    pytest.param('mobile_pose_mobilenet1.0', marks=pytest.mark.integration),
    'mobile_pose_mobilenetv2_1.0',
    'mobile_pose_mobilenetv3_large',
    'mobile_pose_mobilenetv3_small',
])
def test_pose_estimation_model_inference_onnxruntime(tmp_path, model, pose_estimation_test_images):
    """Check that every output of an exported pose-estimation model matches
    between MXNet and ONNX Runtime on each test image."""
    def normalize_image(imgfile):
        """Load ``imgfile`` as float32, center-crop it to 512x512 and apply a
        GluonCV test transform."""
        img = mx.image.imread(imgfile).astype('float32')
        img, _ = mx.image.center_crop(img, size=(512, 512))
        # NOTE(review): this reuses the *segmentation* preset transform; presumably
        # fine because only MXNet-vs-ONNX agreement is checked, not accuracy — confirm.
        img = gluoncv.data.transforms.presets.segmentation.test_transform(img, mx.cpu(0))
        return img

    try:
        tmp_path = str(tmp_path)
        M = GluonModel(model, (1, 3, 512, 512), 'float32', tmp_path)
        onnx_file = M.export_onnx()
        # create onnxruntime session using the generated onnx file
        ses_opt = onnxruntime.SessionOptions()
        ses_opt.log_severity_level = 3
        session = onnxruntime.InferenceSession(onnx_file, ses_opt)
        input_name = session.get_inputs()[0].name

        for img in pose_estimation_test_images:
            img_data = normalize_image(img)
            mx_result = M.predict(img_data)
            onnx_result = session.run([], {input_name: img_data.asnumpy()})
            assert len(mx_result) == len(onnx_result), \
                'MXNet returned %d outputs, ONNX Runtime returned %d' % (len(mx_result), len(onnx_result))
            for i, onnx_out in enumerate(onnx_result):
                assert_almost_equal(mx_result[i], onnx_out)

    finally:
        shutil.rmtree(tmp_path)
500
@pytest.fixture(scope="session")
def act_recognition_test_data(tmpdir_factory):
    """Download the action-recognition ``.rec`` test files for the session.

    Returns:
        List of local file paths, in the same order as the URLs below.
    """
    from urllib.parse import urlparse
    rec_dir = tmpdir_factory.mktemp("act_rec_data")
    record_urls = [
        'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/actions/biking.rec',
        'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/actions/diving.rec',
        'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/actions/golfing.rec',
        'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/actions/sledding.rec',
    ]
    record_files = []
    for rec_url in record_urls:
        rec_path = os.path.join(rec_dir, os.path.basename(urlparse(rec_url).path))
        mx.test_utils.download(rec_url, fname=rec_path)
        record_files.append(rec_path)
    return record_files
517
@pytest.mark.onnx_cv_batch2
@pytest.mark.parametrize('model', [
    'inceptionv1_kinetics400',
    'resnet18_v1b_kinetics400',
    'resnet34_v1b_kinetics400',
    'resnet50_v1b_kinetics400',
    'resnet101_v1b_kinetics400',
    'resnet152_v1b_kinetics400',
    'resnet50_v1b_hmdb51',
    'resnet50_v1b_sthsthv2',
    'vgg16_ucf101',
    pytest.param('inceptionv3_kinetics400', marks=pytest.mark.integration),
    'inceptionv3_ucf101',
])
def test_action_recognition_model_inference_onnxruntime(tmp_path, model, act_recognition_test_data):
    """Compare MXNet and ONNX Runtime outputs of an exported action-recognition
    model on the first batch of images read from each test ``.rec`` file."""
    batch_size = 64
    # inceptionv3-based models use a 340x340 input; all others use 224x224.
    input_len = 340 if 'inceptionv3' in model else 224

    def load_video(filepath):
        """Return the first ``batch_size`` x 3 x ``input_len`` x ``input_len``
        batch read from the record file ``filepath``.

        Raises:
            ValueError: if the file yields no batch. Without this check the
                caller would fail later with an opaque ``AttributeError`` on ``None``.
        """
        iterator = mx.image.ImageIter(batch_size=batch_size, data_shape=(3, input_len, input_len),
                                      path_imgrec=filepath)
        batch = next(iter(iterator), None)
        if batch is None:
            raise ValueError("no data found in record file %s" % filepath)
        return batch.data[0]

    try:
        tmp_path = str(tmp_path)
        M = GluonModel(model, (batch_size, 3, input_len, input_len), 'float32', tmp_path)
        onnx_file = M.export_onnx()
        # create onnxruntime session using the generated onnx file
        ses_opt = onnxruntime.SessionOptions()
        ses_opt.log_severity_level = 3
        session = onnxruntime.InferenceSession(onnx_file, ses_opt)
        input_name = session.get_inputs()[0].name

        for video in act_recognition_test_data:
            data = load_video(video)
            mx_result = M.predict(data)
            onnx_result = session.run([], {input_name: data.asnumpy()})[0]
            assert_almost_equal(mx_result, onnx_result, rtol=0.001, atol=0.01)

    finally:
        shutil.rmtree(tmp_path)
561
562
@with_seed()
@pytest.mark.onnx_cv_batch1
@pytest.mark.integration
@pytest.mark.parametrize('model_name', ['mobilenet1.0', 'inceptionv3', 'darknet53', 'resnest14'])
def test_dynamic_shape_cv_inference_onnxruntime(tmp_path, model_name):
    """Export a model with a dynamic batch dimension and check that ONNX Runtime
    matches MXNet on a batch size (5) different from the export batch size (1)."""
    tmp_path = str(tmp_path)
    try:
        M = GluonModel(model_name, (1, 3, 512, 512), 'float32', tmp_path)
        # ``None`` leaves the batch axis unspecified so other batch sizes can be fed.
        dynamic_input_shapes = [(None, 3, 512, 512)]
        onnx_file = M.export_onnx_dynamic(dynamic_input_shapes)

        # create onnxruntime session using the generated onnx file
        ses_opt = onnxruntime.SessionOptions()
        ses_opt.log_severity_level = 3
        sess = onnxruntime.InferenceSession(onnx_file, ses_opt)

        # test on a different batch size
        x = mx.random.uniform(0, 10, (5, 3, 512, 512))
        in_tensors = [x]
        # Pair each graph input with its tensor in order.
        input_dict = {graph_input.name: tensor.asnumpy()
                      for graph_input, tensor in zip(sess.get_inputs(), in_tensors)}
        pred_on = sess.run(None, input_dict)

        pred_mx = M.predict(x)

        assert_almost_equal(pred_mx, pred_on[0])

    finally:
        shutil.rmtree(tmp_path)
591