# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import mxnet as mx
import numpy as np
import onnxruntime

from mxnet.test_utils import assert_almost_equal
from common import with_seed

import json
import os
import pytest
import shutil

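# Every test below follows the same export/inference round trip: run a pretrained
# GluonNLP model in MXNet, export it to ONNX with mx.onnx.export_model, load the
# result in onnxruntime, and compare the outputs numerically. A minimal sketch of
# that flow (variable names here are illustrative, not part of any test):
#
#   model.export(prefix)                     # writes prefix-symbol.json / prefix-0000.params
#   onnx_path = mx.onnx.export_model(sym_file, params_file, input_shapes,
#                                    input_types, onnx_file)
#   sess = onnxruntime.InferenceSession(onnx_path)
#   feeds = {inp.name: arr.asnumpy() for inp, arr in zip(sess.get_inputs(), in_tensors)}
#   outputs = sess.run(None, feeds)
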
@with_seed()
@pytest.mark.parametrize('model_name', ['roberta_24_1024_16', 'roberta_12_768_12'])
def test_roberta_inference_onnxruntime(tmp_path, model_name):
    tmp_path = str(tmp_path)
    try:
        import gluonnlp as nlp
        ctx = mx.cpu(0)

        dataset = 'openwebtext_ccnews_stories_books_cased'  # alternative: 'book_corpus_wiki_en_uncased'
        model, _ = nlp.model.get_model(
            name=model_name,
            ctx=ctx,
            pretrained=True,
            use_decoder=True,
            dataset_name=dataset)

        model.hybridize(static_alloc=False)

        batch = 2
        seq_length = 32
        num_masked_positions = 1
        inputs = mx.nd.random.uniform(0, 30522, shape=(batch, seq_length), dtype='float32', ctx=ctx)
        valid_length = mx.nd.array([seq_length] * batch, dtype='float32', ctx=ctx)
        masked_positions = mx.nd.random.uniform(0, 32, shape=(batch, num_masked_positions),
                                                dtype='float32', ctx=ctx).astype('int32')

        sequence_outputs, attention_outputs = model(inputs, valid_length, masked_positions)

        prefix = "%s/roberta" % tmp_path
        model.export(prefix)

        sym_file = "%s-symbol.json" % prefix
        params_file = "%s-0000.params" % prefix
        onnx_file = "%s.onnx" % prefix
        input_shapes = [(batch, seq_length), (batch,), (batch, num_masked_positions)]
        converted_model_path = mx.onnx.export_model(sym_file, params_file, input_shapes,
                                                    [np.float32, np.float32, np.int32],
                                                    onnx_file, verbose=True)
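        # mx.onnx.export_model writes the converted graph to onnx_file and returns its
        # path. The declared input types mirror the arrays fed to the onnxruntime
        # session below: float32 token ids and valid_length, int32 masked_positions.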

        sess_options = onnxruntime.SessionOptions()
        sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        sess = onnxruntime.InferenceSession(onnx_file, sess_options)

        in_tensors = [inputs, valid_length, masked_positions]
        input_dict = dict((sess.get_inputs()[i].name, in_tensors[i].asnumpy()) for i in range(len(in_tensors)))
        pred = sess.run(None, input_dict)

        assert_almost_equal(sequence_outputs, pred[0])
        assert_almost_equal(attention_outputs, pred[1])

    finally:
        shutil.rmtree(tmp_path)


@with_seed()
@pytest.mark.integration
@pytest.mark.parametrize('model', ['bert_12_768_12', 'bert_24_1024_16'])
def test_bert_inference_onnxruntime(tmp_path, model):
    tmp_path = str(tmp_path)
    try:
        import gluonnlp as nlp
        dataset = 'book_corpus_wiki_en_uncased'
        ctx = mx.cpu(0)
        model, vocab = nlp.model.get_model(
            name=model,
            ctx=ctx,
            dataset_name=dataset,
            pretrained=True,
            use_pooler=True,
            use_decoder=False,
            use_classifier=False)

        model.hybridize(static_alloc=True)

        batch = 5
        seq_length = 16
        # create synthetic test data
        inputs = mx.nd.random.uniform(0, 30522, shape=(batch, seq_length), dtype='float32')
        token_types = mx.nd.random.uniform(0, 2, shape=(batch, seq_length), dtype='float32')
        valid_length = mx.nd.array([seq_length] * batch, dtype='float32')
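        # 30522 is the size of the uncased BERT vocabulary, so the random ids stay in
        # range; the exact values do not matter since the MXNet and onnxruntime outputs
        # are only compared against each other.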

        seq_encoding, cls_encoding = model(inputs, token_types, valid_length)

        prefix = "%s/bert" % tmp_path
        model.export(prefix)
        sym_file = "%s-symbol.json" % prefix
        params_file = "%s-0000.params" % prefix
        onnx_file = "%s.onnx" % prefix

        input_shapes = [(batch, seq_length), (batch, seq_length), (batch,)]
        input_types = [np.float32, np.float32, np.float32]
        converted_model_path = mx.onnx.export_model(sym_file, params_file, input_shapes, input_types, onnx_file)

        # create onnxruntime session using the generated onnx file
        ses_opt = onnxruntime.SessionOptions()
        ses_opt.log_severity_level = 3
        session = onnxruntime.InferenceSession(onnx_file, ses_opt)
        onnx_inputs = [inputs, token_types, valid_length]
        input_dict = dict((session.get_inputs()[i].name, onnx_inputs[i].asnumpy()) for i in range(len(onnx_inputs)))
        pred_onx, cls_onx = session.run(None, input_dict)

        assert_almost_equal(seq_encoding, pred_onx, rtol=0.0001, atol=0.0025)
        assert_almost_equal(cls_encoding, cls_onx, rtol=0.0001, atol=0.0025)

    finally:
        shutil.rmtree(tmp_path)


@with_seed()
@pytest.mark.parametrize('model_name', ['distilbert_6_768_12'])
def test_distilbert_inference_onnxruntime(tmp_path, model_name):
    tmp_path = str(tmp_path)
    try:
        import gluonnlp as nlp
        dataset = 'distilbert_book_corpus_wiki_en_uncased'
        ctx = mx.cpu(0)
        model, _ = nlp.model.get_model(
            name=model_name,
            ctx=ctx,
            pretrained=True,
            dataset_name=dataset)

        model.hybridize(static_alloc=True)

        batch = 2
        seq_length = 32
        num_masked_positions = 1
        inputs = mx.nd.random.uniform(0, 30522, shape=(batch, seq_length), dtype='float32', ctx=ctx)
        valid_length = mx.nd.array([seq_length] * batch, dtype='float32', ctx=ctx)

        sequence_outputs = model(inputs, valid_length)
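        # Unlike the BERT tests above, this DistilBERT model takes only the token ids
        # and valid_length (no token-type input), so the exported graph has two inputs.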

        prefix = "%s/distilbert" % tmp_path
        model.export(prefix)
        sym_file = "%s-symbol.json" % prefix
        params_file = "%s-0000.params" % prefix
        onnx_file = "%s.onnx" % prefix

        input_shapes = [(batch, seq_length), (batch,)]
        converted_model_path = mx.onnx.export_model(sym_file, params_file, input_shapes,
                                                    [np.float32, np.float32],
                                                    onnx_file, verbose=True)
        sess_options = onnxruntime.SessionOptions()
        sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        sess = onnxruntime.InferenceSession(onnx_file, sess_options)

        in_tensors = [inputs, valid_length]
        input_dict = dict((sess.get_inputs()[i].name, in_tensors[i].asnumpy()) for i in range(len(in_tensors)))
        pred = sess.run(None, input_dict)

        assert_almost_equal(sequence_outputs, pred[0])

    finally:
        shutil.rmtree(tmp_path)


@with_seed()
@pytest.mark.parametrize('model_name', [('standard_lstm_lm_200', 200), ('standard_lstm_lm_650', 650),
                                        ('standard_lstm_lm_1500', 1500)])
@pytest.mark.parametrize('seq_length', [64, 128])
def test_standard_rnn_lstm_pretrained_inference_onnxruntime(tmp_path, model_name, seq_length):
    try:
        import gluonnlp as nlp
        ctx = mx.cpu()
        dataset = 'wikitext-2'
        model, _ = nlp.model.get_model(
            name=model_name[0],
            ctx=ctx,
            pretrained=True,
            dataset_name=dataset,
            dropout=0)
        model.hybridize()

        batch = 2
        num_hidden = model_name[1]
        num_layers = 2
        inputs = mx.nd.random.randint(0, 33278, shape=(seq_length, batch),
                                      ctx=ctx).astype('float32')
        begin_state = model.begin_state(func=mx.nd.random.uniform, low=0, high=1,
                                        batch_size=batch, dtype='float32', ctx=ctx)
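        # begin_state() returns the initial [h, c] pair for the stacked LSTM; each state
        # array becomes an extra graph input alongside the token ids when exporting.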
        out, out_state = model(inputs, begin_state)

        prefix = "%s/standard_rnn_lstm" % tmp_path
        model.export(prefix)
        sym_file = "%s-symbol.json" % prefix
        params_file = "%s-0000.params" % prefix
        onnx_file = "%s.onnx" % prefix

        input_shapes = [(seq_length, batch), np.shape(begin_state[0]), np.shape(begin_state[1])]
        converted_model_path = mx.onnx.export_model(sym_file, params_file, input_shapes,
                                                    [np.float32, np.float32, np.float32],
                                                    onnx_file, verbose=True)
        sess_options = onnxruntime.SessionOptions()
        sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        sess = onnxruntime.InferenceSession(onnx_file, sess_options)

        in_tensors = [inputs, begin_state[0], begin_state[1]]
        input_dict = dict((sess.get_inputs()[i].name, in_tensors[i].asnumpy()) for i in range(len(in_tensors)))
        pred = sess.run(None, input_dict)

        assert_almost_equal(out, pred[2])
        assert_almost_equal(out_state[0], pred[0])
        assert_almost_equal(out_state[1], pred[1])

    finally:
        shutil.rmtree(tmp_path)


@with_seed()
@pytest.mark.integration
@pytest.mark.parametrize('model', ['bert_12_768_12'])
def test_dynamic_shape_bert_inference_onnxruntime(tmp_path, model):
    tmp_path = str(tmp_path)
    try:
        import gluonnlp as nlp
        dataset = 'book_corpus_wiki_en_uncased'
        ctx = mx.cpu(0)
        model, vocab = nlp.model.get_model(
            name=model,
            ctx=ctx,
            dataset_name=dataset,
            pretrained=True,
            use_pooler=True,
            use_decoder=False,
            num_layers=3,
            hparam_allow_override=True,
            use_classifier=False)

        model.hybridize(static_alloc=True)

        batch = 5
        seq_length = 16
        # create synthetic test data
        inputs = mx.nd.random.uniform(0, 30522, shape=(batch, seq_length), dtype='float32')
        token_types = mx.nd.random.uniform(0, 2, shape=(batch, seq_length), dtype='float32')
        valid_length = mx.nd.array([seq_length] * batch, dtype='float32')

        seq_encoding, cls_encoding = model(inputs, token_types, valid_length)

        prefix = "%s/bert" % tmp_path
        model.export(prefix)
        sym_file = "%s-symbol.json" % prefix
        params_file = "%s-0000.params" % prefix
        onnx_file = "%s.onnx" % prefix

        dynamic_input_shapes = [(None, seq_length), (None, seq_length), (None,)]
        input_shapes = [(batch, seq_length), (batch, seq_length), (batch,)]
        input_types = [np.float32, np.float32, np.float32]
        converted_model_path = mx.onnx.export_model(sym_file, params_file, input_shapes,
                                                    input_types, onnx_file,
                                                    dynamic=True,
                                                    dynamic_input_shapes=dynamic_input_shapes)
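        # The None entries in dynamic_input_shapes mark those dimensions (the batch axis
        # here) as dynamic in the exported graph, which is what lets the session below
        # run with a different batch size than the one used for export.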

        # create onnxruntime session using the generated onnx file
        ses_opt = onnxruntime.SessionOptions()
        ses_opt.log_severity_level = 3
        session = onnxruntime.InferenceSession(onnx_file, ses_opt)

        # test on a different batch size
        batch = 7
        seq_length = 16
        inputs = mx.nd.random.uniform(0, 30522, shape=(batch, seq_length), dtype='float32')
        token_types = mx.nd.random.uniform(0, 2, shape=(batch, seq_length), dtype='float32')
        valid_length = mx.nd.array([seq_length] * batch, dtype='float32')

        seq_encoding, cls_encoding = model(inputs, token_types, valid_length)

        onnx_inputs = [inputs, token_types, valid_length]
        input_dict = dict((session.get_inputs()[i].name, onnx_inputs[i].asnumpy()) for i in range(len(onnx_inputs)))
        pred_onx, cls_onx = session.run(None, input_dict)

        assert_almost_equal(seq_encoding, pred_onx)
        assert_almost_equal(cls_encoding, cls_onx)

    finally:
        shutil.rmtree(tmp_path)


@with_seed()
@pytest.mark.parametrize('model_name', [('awd_lstm_lm_600', 600), ('awd_lstm_lm_1150', 1150)])
@pytest.mark.parametrize('seq_length', [16, 128, 256])
def test_awd_rnn_lstm_pretrained_inference_onnxruntime(tmp_path, model_name, seq_length):
    try:
        import gluonnlp as nlp
        ctx = mx.cpu()
        dataset = 'wikitext-2'
        model, _ = nlp.model.get_model(
            name=model_name[0],
            ctx=ctx,
            pretrained=True,
            dataset_name=dataset,
            dropout=0)
        model.hybridize()

        batch = 2
        num_hidden = model_name[1]
        num_layers = 2
        inputs = mx.nd.random.randint(0, 33278, shape=(seq_length, batch),
                                      ctx=ctx).astype('float32')
        begin_state = model.begin_state(func=mx.nd.random.uniform, low=0, high=1,
                                        batch_size=batch, dtype='float32', ctx=ctx)
        out, out_state = model(inputs, begin_state)

        prefix = "%s/awd_lstm" % tmp_path
        model.export(prefix)
        sym_file = "%s-symbol.json" % prefix
        params_file = "%s-0000.params" % prefix
        onnx_file = "%s.onnx" % prefix

        input_shapes = [(seq_length, batch),
                        np.shape(begin_state[0][0]), np.shape(begin_state[0][1]),
                        np.shape(begin_state[1][0]), np.shape(begin_state[1][1]),
                        np.shape(begin_state[2][0]), np.shape(begin_state[2][1])]
        input_types = [np.float32, np.float32, np.float32, np.float32, np.float32, np.float32,
                       np.float32]
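        # The AWD-LSTM keeps a separate (h, c) state pair for each of its 3 layers, hence
        # the seven input shapes/types above: the token ids plus six flattened state arrays.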
        converted_model_path = mx.onnx.export_model(sym_file, params_file, input_shapes,
                                                    input_types, onnx_file, verbose=True)

        sess_options = onnxruntime.SessionOptions()
        sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        sess = onnxruntime.InferenceSession(onnx_file, sess_options)

        in_tensors = [inputs, begin_state[0][0], begin_state[0][1],
                      begin_state[1][0], begin_state[1][1],
                      begin_state[2][0], begin_state[2][1]]
        input_dict = dict((sess.get_inputs()[i].name, in_tensors[i].asnumpy()) for i in range(len(in_tensors)))
        pred = sess.run(None, input_dict)

        assert_almost_equal(out, pred[6])
        assert_almost_equal(out_state[0][0], pred[0])
        assert_almost_equal(out_state[0][1], pred[1])
        assert_almost_equal(out_state[1][0], pred[2])
        assert_almost_equal(out_state[1][1], pred[3])
        assert_almost_equal(out_state[2][0], pred[4])
        assert_almost_equal(out_state[2][1], pred[5])

    finally:
        shutil.rmtree(tmp_path)


@with_seed()
@pytest.mark.parametrize('model_name', ['ernie_12_768_12'])
def test_ernie_inference_onnxruntime(tmp_path, model_name):
    tmp_path = str(tmp_path)
    try:
        import gluonnlp as nlp
        dataset = 'baidu_ernie_uncased'
        ctx = mx.cpu(0)
        model, vocab = nlp.model.get_model(
            name=model_name,
            ctx=ctx,
            dataset_name=dataset,
            pretrained=True,
            use_pooler=True,
            use_decoder=False,
            num_layers=3,
            hparam_allow_override=True,
            use_classifier=False)
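        # num_layers=3 with hparam_allow_override=True builds a truncated 3-layer version
        # of the pretrained encoder, which keeps the export and comparison below light.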

        model.hybridize(static_alloc=True)

        batch = 5
        seq_length = 16
        # create synthetic test data
        inputs = mx.nd.random.uniform(0, 17964, shape=(batch, seq_length), dtype='float32')
        token_types = mx.nd.random.uniform(0, 2, shape=(batch, seq_length), dtype='float32')
        valid_length = mx.nd.array([seq_length] * batch, dtype='float32')

        seq_encoding, cls_encoding = model(inputs, token_types, valid_length)

        prefix = "%s/ernie" % tmp_path
        model.export(prefix)
        sym_file = "%s-symbol.json" % prefix
        params_file = "%s-0000.params" % prefix
        onnx_file = "%s.onnx" % prefix

        input_shapes = [(batch, seq_length), (batch, seq_length), (batch,)]
        input_types = [np.float32, np.float32, np.float32]
        converted_model_path = mx.onnx.export_model(sym_file, params_file, input_shapes,
                                                    input_types, onnx_file)

        # create onnxruntime session using the generated onnx file
        ses_opt = onnxruntime.SessionOptions()
        ses_opt.log_severity_level = 3
        session = onnxruntime.InferenceSession(onnx_file, ses_opt)

        onnx_inputs = [inputs, token_types, valid_length]
        input_dict = dict((session.get_inputs()[i].name, onnx_inputs[i].asnumpy()) for i in range(len(onnx_inputs)))
        pred_onx, cls_onx = session.run(None, input_dict)

        assert_almost_equal(seq_encoding, pred_onx)
        assert_almost_equal(cls_encoding, cls_onx)

    finally:
        shutil.rmtree(tmp_path)


@with_seed()
@pytest.mark.parametrize('model_name', ['transformer_en_de_512'])
def test_transformer_pretrained_inference_onnxruntime(tmp_path, model_name):
    tmp_path = str(tmp_path)
    try:
        import gluonnlp as nlp
        dataset = 'WMT2014'
        ctx = mx.cpu(0)
        model, _, _ = nlp.model.get_model(
            name=model_name,
            ctx=ctx,
            pretrained=True,
            dataset_name=dataset)

        model.hybridize(static_alloc=False)

        batch = 7
        seq_length = 16
        C_in = 512
        C_out = 512
        src = mx.nd.random.uniform(0, 36794, shape=(batch, seq_length), dtype='float32')
        step_input = mx.nd.random.uniform(0, 36794, shape=(batch,), dtype='float32')
        src_valid_length = mx.nd.array([seq_length] * batch, dtype='float32')

        encoder_outputs, encoder_additional_outputs = model.encode(src,
                                                                   valid_length=src_valid_length)

        decoder_states = model.decoder.init_state_from_encoder(encoder_outputs, src_valid_length)

        step_output, states, additional_outputs = model.decode_step(step_input, decoder_states)

        # skip export of 'decoder' as it's used for training only
        for component in ['encoder', 'one_step_ahead_decoder', 'src_embed', 'tgt_embed',
                          'tgt_proj']:
            prefix = "%s/%s" % (tmp_path, component)
            component = getattr(model, component)
            component.export(prefix)

        def export_to_onnx(prefix, input_shapes, input_types, **kwargs):
            sym_file = "%s-symbol.json" % prefix
            params_file = "%s-0000.params" % prefix
            onnx_file = "%s.onnx" % prefix
            return mx.onnx.export_model(sym_file, params_file, input_shapes, input_types,
                                        onnx_file, **kwargs)

        def onnx_runtime_predict(onnx_file, onnx_inputs):
            ses_opt = onnxruntime.SessionOptions()
            ses_opt.log_severity_level = 3
            session = onnxruntime.InferenceSession(onnx_file, ses_opt)
            input_dict = dict((session.get_inputs()[i].name, onnx_inputs[i].asnumpy())
                              for i in range(len(onnx_inputs)))
            return session.run(None, input_dict)

        def verify_encoder():
            inputs = mx.nd.random.uniform(-1, 1, shape=(batch, seq_length, C_in), dtype='float32')
            valid_length = mx.nd.array([seq_length] * batch, dtype='float32')
            pred = model.encoder(inputs, valid_length=valid_length)

            prefix = "%s/encoder" % tmp_path
            input_shapes = [(batch, seq_length, C_in), (batch,)]
            input_types = [np.float32, np.float32]
            onnx_file = export_to_onnx(prefix, input_shapes, input_types)
            onnx_inputs = [inputs, valid_length]
            pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)

            assert_almost_equal(pred[0], pred_onx[0])

        def verify_src_embed():
            src = mx.nd.random.uniform(0, 36794, shape=(batch, seq_length), dtype='float32')
            pred = model.src_embed(src)

            prefix = "%s/src_embed" % tmp_path
            input_shapes = [(batch, seq_length)]
            input_types = [np.float32]
            onnx_file = export_to_onnx(prefix, input_shapes, input_types)
            onnx_inputs = [src]
            pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)

            assert_almost_equal(pred, pred_onx[0])

        def verify_tgt_embed():
            tgt = mx.nd.random.uniform(0, 36794, shape=(batch, seq_length), dtype='float32')
            pred = model.tgt_embed(tgt)

            prefix = "%s/tgt_embed" % tmp_path
            input_shapes = [(batch, seq_length)]
            input_types = [np.float32]
            onnx_file = export_to_onnx(prefix, input_shapes, input_types)
            onnx_inputs = [tgt]
            pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)

            assert_almost_equal(pred, pred_onx[0])

        def verify_tgt_proj():
            decoder_out = mx.nd.random.uniform(0, 512, shape=(batch, seq_length, C_out),
                                               dtype='float32')
            pred = model.tgt_proj(decoder_out)

            prefix = "%s/tgt_proj" % tmp_path
            input_shapes = [(batch, seq_length, C_out)]
            input_types = [np.float32]
            onnx_file = export_to_onnx(prefix, input_shapes, input_types)
            onnx_inputs = [decoder_out]
            pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)

            assert_almost_equal(pred, pred_onx[0], rtol=1.e-04, atol=1.5e-03)

        def verify_one_step_ahead_decoder():
            prefix = "%s/one_step_ahead_decoder" % tmp_path

            # permutation matching the input order of the exported graph
            perm = [2, 0, 1]
            input_shapes = [(batch, seq_length, C_in), (batch, seq_length, C_out),
                            (batch, seq_length)]
            input_shapes = [input_shapes[i] for i in perm]
            dynamic_input_shapes = [(batch, 'seq_length', C_in), (batch, 'seq_length', C_out),
                                    (batch, 'seq_length')]
            dynamic_input_shapes = [dynamic_input_shapes[i] for i in perm]
            input_types = [np.float32, np.float32, np.float32]
            # do a dynamic export
            onnx_file = export_to_onnx(prefix, input_shapes, input_types, dynamic=True,
                                       dynamic_input_shapes=dynamic_input_shapes)
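            # The decoder is exported with a dynamic sequence axis because the loop below
            # grows the decoded sequence by one step per iteration, so the same ONNX graph
            # has to accept inputs of increasing length.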

            # step 0
            step_input = mx.nd.random.uniform(-1, 1, shape=(batch, C_in), dtype='float32')
            # mxnet
            pred, step_states, _ = model.one_step_ahead_decoder(step_input, decoder_states)
            # onnx
            # note that we need to expand the sequence axis just like in here:
            # https://github.com/dmlc/gluon-nlp/blob/v0.10.x/src/gluonnlp/model/transformer.py#L831
            input_onx = mx.nd.expand_dims(step_input, axis=1)
            onnx_inputs = [input_onx, decoder_states[0], decoder_states[1]]
            onnx_inputs = [onnx_inputs[i] for i in perm]
            pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)

            assert_almost_equal(pred, pred_onx[0])

            # step >= 1
            for i in range(20):
                step_input = mx.nd.random.uniform(-10*i, 10*i, shape=(batch, C_in), dtype='float32')
                # mxnet
                pred, step_states, _ = model.one_step_ahead_decoder(step_input, step_states)
                # onnx
                # note that we need to concat the step_input with the previous inputs
                # just like in here:
                # https://github.com/dmlc/gluon-nlp/blob/v0.10.x/src/gluonnlp/model/transformer.py#L828
                input_onx = mx.nd.concat(input_onx, mx.nd.expand_dims(step_input, axis=1), dim=1)
                onnx_inputs = [input_onx, decoder_states[0], decoder_states[1]]
                onnx_inputs = [onnx_inputs[i] for i in perm]
                pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)

                assert_almost_equal(pred, pred_onx[0])

        verify_encoder()
        verify_src_embed()
        verify_tgt_embed()
        verify_tgt_proj()
        verify_one_step_ahead_decoder()

    finally:
        shutil.rmtree(tmp_path)


@with_seed()
@pytest.mark.parametrize('model_params', [('gpt2_117m', 24), ('gpt2_345m', 48)])
def test_gpt_pretrained_inference_onnxruntime(tmp_path, model_params):
    tmp_path = str(tmp_path)
    try:
        import gluonnlp as nlp
        import urllib.request
        from zipfile import ZipFile
        import importlib.util
        import sys

        url = 'https://nlp.gluon.ai/_downloads/77d227fbc8f1613e6802acc7253cc090/text_generation.zip'
        zip_file = os.path.join(tmp_path, 'text_generation.zip')
        urllib.request.urlretrieve(url, zip_file)

        with ZipFile(zip_file, 'r') as zipObj:
            zipObj.extractall(tmp_path)

        # load in the text_generation module, refer to:
        # https://github.com/dmlc/gluon-nlp/tree/v0.10.x/scripts/text_generation
        spec = importlib.util.spec_from_file_location(
            'text_generation',
            os.path.join(tmp_path, 'text_generation', '__init__.py'))
        mod = importlib.util.module_from_spec(spec)
        sys.modules[spec.name] = mod
        spec.loader.exec_module(mod)

        ctx = mx.cpu(0)
        model_name = model_params[0]
        dataset = 'openai_webtext'
        # get_model() is overridden in here:
        # https://github.com/dmlc/gluon-nlp/blob/v0.10.x/scripts/text_generation/model/__init__.py#L23
        model, _ = mod.model.get_model(
            name=model_name,
            ctx=ctx,
            pretrained=True,
            dataset_name=dataset)

        model.hybridize()

        batch = 4
        seq_length = 64
        inputs = mx.nd.random.uniform(0, 50257, shape=(batch, seq_length), dtype='float32',
                                      ctx=ctx)

        pred = model(inputs)
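        # The GPT-2 model returns the output logits together with a list of per-layer
        # state arrays; both are compared against the corresponding onnxruntime outputs
        # at the end of the test.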

        prefix = "%s/%s" % (tmp_path, model_name)
        model.export(prefix)
        sym_file = "%s-symbol.json" % prefix
        params_file = "%s-0000.params" % prefix
        onnx_file = "%s.onnx" % prefix

        input_shapes = [(batch, seq_length)]
        input_types = [np.float32]
        converted_model_path = mx.onnx.export_model(sym_file, params_file, input_shapes,
                                                    input_types, onnx_file)

        ses_opt = onnxruntime.SessionOptions()
        ses_opt.log_severity_level = 3
        session = onnxruntime.InferenceSession(onnx_file, ses_opt)
        onnx_inputs = [inputs]
        input_dict = dict((session.get_inputs()[i].name, onnx_inputs[i].asnumpy()) for i in range(len(onnx_inputs)))
        pred_onx = session.run(None, input_dict)

        # check output
        assert_almost_equal(pred[0], pred_onx[0])
        # check states
        num_states = model_params[1]
        for i in range(num_states):
            assert_almost_equal(pred[1][i], pred_onx[i+1])

    finally:
        shutil.rmtree(tmp_path)