# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import mxnet as mx
import numpy as np
import onnxruntime

from mxnet.test_utils import assert_almost_equal
from common import with_seed

import json
import os
import pytest
import shutil


@with_seed()
@pytest.mark.parametrize('model_name', ['roberta_24_1024_16', 'roberta_12_768_12'])
def test_roberta_inference_onnxruntime(tmp_path, model_name):
    tmp_path = str(tmp_path)
    try:
        import gluonnlp as nlp
        ctx = mx.cpu(0)

        dataset = 'openwebtext_ccnews_stories_books_cased'  # alternative: 'book_corpus_wiki_en_uncased'
        model, _ = nlp.model.get_model(
            name=model_name,
            ctx=ctx,
            pretrained=True,
            use_decoder=True,
            dataset_name=dataset)

        model.hybridize(static_alloc=False)

        batch = 2
        seq_length = 32
        num_masked_positions = 1
        inputs = mx.nd.random.uniform(0, 30522, shape=(batch, seq_length), dtype='float32', ctx=ctx)
        valid_length = mx.nd.array([seq_length] * batch, dtype='float32', ctx=ctx)
        masked_positions = mx.nd.random.uniform(0, seq_length, shape=(batch, num_masked_positions),
                                                dtype='float32', ctx=ctx).astype('int32')

        sequence_outputs, attention_outputs = model(inputs, valid_length, masked_positions)

        prefix = "%s/roberta" % tmp_path
        model.export(prefix)

        sym_file = "%s-symbol.json" % prefix
        params_file = "%s-0000.params" % prefix
        onnx_file = "%s.onnx" % prefix
        input_shapes = [(batch, seq_length), (batch,), (batch, num_masked_positions)]
        converted_model_path = mx.onnx.export_model(sym_file, params_file, input_shapes,
                                                    [np.float32, np.float32, np.int32],
                                                    onnx_file, verbose=True)

        sess_options = onnxruntime.SessionOptions()
        sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        sess = onnxruntime.InferenceSession(onnx_file, sess_options)

        in_tensors = [inputs, valid_length, masked_positions]
        input_dict = dict((sess.get_inputs()[i].name, in_tensors[i].asnumpy()) for i in range(len(in_tensors)))
        pred = sess.run(None, input_dict)

        assert_almost_equal(sequence_outputs, pred[0])
        assert_almost_equal(attention_outputs, pred[1])

    finally:
        shutil.rmtree(tmp_path)


@with_seed()
@pytest.mark.integration
@pytest.mark.parametrize('model', ['bert_12_768_12', 'bert_24_1024_16'])
def test_bert_inference_onnxruntime(tmp_path, model):
    tmp_path = str(tmp_path)
    try:
        import gluonnlp as nlp
        dataset = 'book_corpus_wiki_en_uncased'
        ctx = mx.cpu(0)
        model, vocab = nlp.model.get_model(
            name=model,
            ctx=ctx,
            dataset_name=dataset,
            pretrained=True,
            use_pooler=True,
            use_decoder=False,
            use_classifier=False)

        model.hybridize(static_alloc=True)

        batch = 5
        seq_length = 16
        # create synthetic test data
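        # token ids are drawn uniformly over [0, 30522), the uncased BERT vocabulary
        # size; float32 is used because the exported symbol and the ONNX conversion
        # below both declare float32 input types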
        inputs = mx.nd.random.uniform(0, 30522, shape=(batch, seq_length), dtype='float32')
        token_types = mx.nd.random.uniform(0, 2, shape=(batch, seq_length), dtype='float32')
        valid_length = mx.nd.array([seq_length] * batch, dtype='float32')

        seq_encoding, cls_encoding = model(inputs, token_types, valid_length)

        prefix = "%s/bert" % tmp_path
        model.export(prefix)
        sym_file = "%s-symbol.json" % prefix
        params_file = "%s-0000.params" % prefix
        onnx_file = "%s.onnx" % prefix

        input_shapes = [(batch, seq_length), (batch, seq_length), (batch,)]
        input_types = [np.float32, np.float32, np.float32]
        converted_model_path = mx.onnx.export_model(sym_file, params_file, input_shapes, input_types, onnx_file)

        # create onnxruntime session using the generated onnx file
        ses_opt = onnxruntime.SessionOptions()
        ses_opt.log_severity_level = 3
        session = onnxruntime.InferenceSession(onnx_file, ses_opt)
        onnx_inputs = [inputs, token_types, valid_length]
        input_dict = dict((session.get_inputs()[i].name, onnx_inputs[i].asnumpy()) for i in range(len(onnx_inputs)))
        pred_onx, cls_onx = session.run(None, input_dict)

        assert_almost_equal(seq_encoding, pred_onx, rtol=0.0001, atol=0.0025)
        assert_almost_equal(cls_encoding, cls_onx, rtol=0.0001, atol=0.0025)

    finally:
        shutil.rmtree(tmp_path)


@with_seed()
@pytest.mark.parametrize('model_name', ['distilbert_6_768_12'])
def test_distilbert_inference_onnxruntime(tmp_path, model_name):
    tmp_path = str(tmp_path)
    try:
        import gluonnlp as nlp
        dataset = 'distilbert_book_corpus_wiki_en_uncased'
        ctx = mx.cpu(0)
        model, _ = nlp.model.get_model(
            name=model_name,
            ctx=ctx,
            pretrained=True,
            dataset_name=dataset)

        model.hybridize(static_alloc=True)

        batch = 2
        seq_length = 32
        num_masked_positions = 1
        inputs = mx.nd.random.uniform(0, 30522, shape=(batch, seq_length), dtype='float32', ctx=ctx)
        valid_length = mx.nd.array([seq_length] * batch, dtype='float32', ctx=ctx)

        sequence_outputs = model(inputs, valid_length)

        prefix = "%s/distilbert" % tmp_path
        model.export(prefix)
        sym_file = "%s-symbol.json" % prefix
        params_file = "%s-0000.params" % prefix
        onnx_file = "%s.onnx" % prefix

        input_shapes = [(batch, seq_length), (batch,)]
        converted_model_path = mx.onnx.export_model(sym_file, params_file, input_shapes,
                                                    [np.float32, np.float32],
                                                    onnx_file, verbose=True)
        sess_options = onnxruntime.SessionOptions()
        sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        sess = onnxruntime.InferenceSession(onnx_file, sess_options)

        in_tensors = [inputs, valid_length]
        input_dict = dict((sess.get_inputs()[i].name, in_tensors[i].asnumpy()) for i in range(len(in_tensors)))
        pred = sess.run(None, input_dict)

        assert_almost_equal(sequence_outputs, pred[0])

    finally:
        shutil.rmtree(tmp_path)


@with_seed()
@pytest.mark.parametrize('model_name', [('standard_lstm_lm_200', 200), ('standard_lstm_lm_650', 650),
                                        ('standard_lstm_lm_1500', 1500)])
@pytest.mark.parametrize('seq_length', [64, 128])
def test_standard_rnn_lstm_pretrained_inference_onnxruntime(tmp_path, model_name, seq_length):
    try:
        import gluonnlp as nlp
        ctx = mx.cpu()
        dataset = 'wikitext-2'
        model, _ = nlp.model.get_model(
            name=model_name[0],
            ctx=ctx,
            pretrained=True,
            dataset_name=dataset,
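            # dropout is disabled so repeated forward passes are deterministic and
            # directly comparable to the ONNX Runtime results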
            dropout=0)
        model.hybridize()

        batch = 2
        num_hidden = model_name[1]
        num_layers = 2
        inputs = mx.nd.random.randint(0, 33278, shape=(seq_length, batch),
                                      ctx=ctx).astype('float32')
        begin_state = model.begin_state(func=mx.nd.random.uniform, low=0, high=1,
                                        batch_size=batch, dtype='float32', ctx=ctx)
        out, out_state = model(inputs, begin_state)

        prefix = "%s/standard_rnn_lstm" % tmp_path
        model.export(prefix)
        sym_file = "%s-symbol.json" % prefix
        params_file = "%s-0000.params" % prefix
        onnx_file = "%s.onnx" % prefix

        input_shapes = [(seq_length, batch), np.shape(begin_state[0]), np.shape(begin_state[1])]
        converted_model_path = mx.onnx.export_model(sym_file, params_file, input_shapes,
                                                    [np.float32, np.float32, np.float32],
                                                    onnx_file, verbose=True)
        sess_options = onnxruntime.SessionOptions()
        sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        sess = onnxruntime.InferenceSession(onnx_file, sess_options)

        in_tensors = [inputs, begin_state[0], begin_state[1]]
        input_dict = dict((sess.get_inputs()[i].name, in_tensors[i].asnumpy()) for i in range(len(in_tensors)))
        pred = sess.run(None, input_dict)

        assert_almost_equal(out, pred[2])
        assert_almost_equal(out_state[0], pred[0])
        assert_almost_equal(out_state[1], pred[1])

    finally:
        shutil.rmtree(tmp_path)


@with_seed()
@pytest.mark.integration
@pytest.mark.parametrize('model', ['bert_12_768_12'])
def test_dynamic_shape_bert_inference_onnxruntime(tmp_path, model):
    tmp_path = str(tmp_path)
    try:
        import gluonnlp as nlp
        dataset = 'book_corpus_wiki_en_uncased'
        ctx = mx.cpu(0)
        model, vocab = nlp.model.get_model(
            name=model,
            ctx=ctx,
            dataset_name=dataset,
            pretrained=True,
            use_pooler=True,
            use_decoder=False,
            num_layers=3,
            hparam_allow_override=True,
            use_classifier=False)

        model.hybridize(static_alloc=True)

        batch = 5
        seq_length = 16
        # create synthetic test data
        inputs = mx.nd.random.uniform(0, 30522, shape=(batch, seq_length), dtype='float32')
        token_types = mx.nd.random.uniform(0, 2, shape=(batch, seq_length), dtype='float32')
        valid_length = mx.nd.array([seq_length] * batch, dtype='float32')

        seq_encoding, cls_encoding = model(inputs, token_types, valid_length)

        prefix = "%s/bert" % tmp_path
        model.export(prefix)
        sym_file = "%s-symbol.json" % prefix
        params_file = "%s-0000.params" % prefix
        onnx_file = "%s.onnx" % prefix

        dynamic_input_shapes = [(None, seq_length), (None, seq_length), (None,)]
        input_shapes = [(batch, seq_length), (batch, seq_length), (batch,)]
        input_types = [np.float32, np.float32, np.float32]
        converted_model_path = mx.onnx.export_model(sym_file, params_file, input_shapes,
                                                    input_types, onnx_file,
                                                    dynamic=True,
                                                    dynamic_input_shapes=dynamic_input_shapes)

        # create onnxruntime session using the generated onnx file
        ses_opt = onnxruntime.SessionOptions()
        ses_opt.log_severity_level = 3
        session = onnxruntime.InferenceSession(onnx_file, ses_opt)

        # test on a different batch size
        batch = 7
        seq_length = 16
        inputs = mx.nd.random.uniform(0, 30522, shape=(batch, seq_length), dtype='float32')
        token_types = mx.nd.random.uniform(0, 2, shape=(batch, seq_length), dtype='float32')
        valid_length = mx.nd.array([seq_length] * batch, dtype='float32')
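
        # re-run the MXNet model at the new batch size to produce reference outputs;
        # the dynamically-shaped ONNX session should accept the same batch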
        seq_encoding, cls_encoding = model(inputs, token_types, valid_length)

        onnx_inputs = [inputs, token_types, valid_length]
        input_dict = dict((session.get_inputs()[i].name, onnx_inputs[i].asnumpy()) for i in range(len(onnx_inputs)))
        pred_onx, cls_onx = session.run(None, input_dict)

        assert_almost_equal(seq_encoding, pred_onx)
        assert_almost_equal(cls_encoding, cls_onx)

    finally:
        shutil.rmtree(tmp_path)


@with_seed()
@pytest.mark.parametrize('model_name', [('awd_lstm_lm_600', 600), ('awd_lstm_lm_1150', 1150)])
@pytest.mark.parametrize('seq_length', [16, 128, 256])
def test_awd_rnn_lstm_pretrained_inference_onnxruntime(tmp_path, model_name, seq_length):
    try:
        import gluonnlp as nlp
        ctx = mx.cpu()
        dataset = 'wikitext-2'
        model, _ = nlp.model.get_model(
            name=model_name[0],
            ctx=ctx,
            pretrained=True,
            dataset_name=dataset,
            dropout=0)
        model.hybridize()

        batch = 2
        num_hidden = model_name[1]
        num_layers = 2
        inputs = mx.nd.random.randint(0, 33278, shape=(seq_length, batch),
                                      ctx=ctx).astype('float32')
        begin_state = model.begin_state(func=mx.nd.random.uniform, low=0, high=1,
                                        batch_size=batch, dtype='float32', ctx=ctx)
        out, out_state = model(inputs, begin_state)

        prefix = "%s/awd_lstm" % tmp_path
        model.export(prefix)
        sym_file = "%s-symbol.json" % prefix
        params_file = "%s-0000.params" % prefix
        onnx_file = "%s.onnx" % prefix

        input_shapes = [(seq_length, batch),
                        np.shape(begin_state[0][0]), np.shape(begin_state[0][1]),
                        np.shape(begin_state[1][0]), np.shape(begin_state[1][1]),
                        np.shape(begin_state[2][0]), np.shape(begin_state[2][1])]
        input_types = [np.float32, np.float32, np.float32, np.float32, np.float32, np.float32,
                       np.float32]
        converted_model_path = mx.onnx.export_model(sym_file, params_file, input_shapes,
                                                    input_types, onnx_file, verbose=True)

        sess_options = onnxruntime.SessionOptions()
        sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        sess = onnxruntime.InferenceSession(onnx_file, sess_options)

        in_tensors = [inputs, begin_state[0][0], begin_state[0][1],
                      begin_state[1][0], begin_state[1][1],
                      begin_state[2][0], begin_state[2][1]]
        input_dict = dict((sess.get_inputs()[i].name, in_tensors[i].asnumpy()) for i in range(len(in_tensors)))
        pred = sess.run(None, input_dict)

        assert_almost_equal(out, pred[6])
        assert_almost_equal(out_state[0][0], pred[0])
        assert_almost_equal(out_state[0][1], pred[1])
        assert_almost_equal(out_state[1][0], pred[2])
        assert_almost_equal(out_state[1][1], pred[3])
        assert_almost_equal(out_state[2][0], pred[4])
        assert_almost_equal(out_state[2][1], pred[5])

    finally:
        shutil.rmtree(tmp_path)


@with_seed()
@pytest.mark.parametrize('model_name', ['ernie_12_768_12'])
def test_ernie_inference_onnxruntime(tmp_path, model_name):
    tmp_path = str(tmp_path)
    try:
        import gluonnlp as nlp
        dataset = 'baidu_ernie_uncased'
        ctx = mx.cpu(0)
        model, vocab = nlp.model.get_model(
            name=model_name,
            ctx=ctx,
            dataset_name=dataset,
            pretrained=True,
            use_pooler=True,
            use_decoder=False,
            num_layers=3,
            hparam_allow_override=True,
            use_classifier=False)

        model.hybridize(static_alloc=True)

        batch = 5
        seq_length = 16
        # create synthetic test data
        inputs = mx.nd.random.uniform(0, 17964, shape=(batch, seq_length), dtype='float32')
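        # token types are the 0/1 segment ids that BERT-style models expect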
        token_types = mx.nd.random.uniform(0, 2, shape=(batch, seq_length), dtype='float32')
        valid_length = mx.nd.array([seq_length] * batch, dtype='float32')

        seq_encoding, cls_encoding = model(inputs, token_types, valid_length)

        prefix = "%s/ernie" % tmp_path
        model.export(prefix)
        sym_file = "%s-symbol.json" % prefix
        params_file = "%s-0000.params" % prefix
        onnx_file = "%s.onnx" % prefix

        input_shapes = [(batch, seq_length), (batch, seq_length), (batch,)]
        input_types = [np.float32, np.float32, np.float32]
        converted_model_path = mx.onnx.export_model(sym_file, params_file, input_shapes,
                                                    input_types, onnx_file)

        # create onnxruntime session using the generated onnx file
        ses_opt = onnxruntime.SessionOptions()
        ses_opt.log_severity_level = 3
        session = onnxruntime.InferenceSession(onnx_file, ses_opt)

        onnx_inputs = [inputs, token_types, valid_length]
        input_dict = dict((session.get_inputs()[i].name, onnx_inputs[i].asnumpy()) for i in range(len(onnx_inputs)))
        pred_onx, cls_onx = session.run(None, input_dict)

        assert_almost_equal(seq_encoding, pred_onx)
        assert_almost_equal(cls_encoding, cls_onx)

    finally:
        shutil.rmtree(tmp_path)


@with_seed()
@pytest.mark.parametrize('model_name', ['transformer_en_de_512'])
def test_transformer_pretrained_inference_onnxruntime(tmp_path, model_name):
    tmp_path = str(tmp_path)
    try:
        import gluonnlp as nlp
        dataset = 'WMT2014'
        ctx = mx.cpu(0)
        model, _, _ = nlp.model.get_model(
            name=model_name,
            ctx=ctx,
            pretrained=True,
            dataset_name=dataset)

        model.hybridize(static_alloc=False)

        batch = 7
        seq_length = 16
        C_in = 512
        C_out = 512
        src = mx.nd.random.uniform(0, 36794, shape=(batch, seq_length), dtype='float32')
        step_input = mx.nd.random.uniform(0, 36794, shape=(batch,), dtype='float32')
        src_valid_length = mx.nd.array([seq_length] * batch, dtype='float32')

        encoder_outputs, encoder_additional_outputs = model.encode(src,
                                                                   valid_length=src_valid_length)

        decoder_states = model.decoder.init_state_from_encoder(encoder_outputs, src_valid_length)

        step_output, states, additional_outputs = model.decode_step(step_input, decoder_states)

        # skip export of 'decoder' as it is used for training only
        for component in ['encoder', 'one_step_ahead_decoder', 'src_embed', 'tgt_embed',
                          'tgt_proj']:
            prefix = "%s/%s" % (tmp_path, component)
            component = getattr(model, component)
            component.export(prefix)

        def export_to_onnx(prefix, input_shapes, input_types, **kwargs):
            sym_file = "%s-symbol.json" % prefix
            params_file = "%s-0000.params" % prefix
            onnx_file = "%s.onnx" % prefix
            return mx.onnx.export_model(sym_file, params_file, input_shapes, input_types,
                                        onnx_file, **kwargs)

        def onnx_runtime_predict(onnx_file, onnx_inputs):
            ses_opt = onnxruntime.SessionOptions()
            ses_opt.log_severity_level = 3
            session = onnxruntime.InferenceSession(onnx_file, ses_opt)
            input_dict = dict((session.get_inputs()[i].name, onnx_inputs[i].asnumpy())
                              for i in range(len(onnx_inputs)))
            return session.run(None, input_dict)

        def verify_encoder():
            inputs = mx.nd.random.uniform(-1, 1, shape=(batch, seq_length, C_in), dtype='float32')
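            # the encoder operates on already-embedded tokens, hence continuous
            # values in [-1, 1] with a trailing embedding dimension of C_in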
            valid_length = mx.nd.array([seq_length] * batch, dtype='float32')
            pred = model.encoder(inputs, valid_length=valid_length)

            prefix = "%s/encoder" % tmp_path
            input_shapes = [(batch, seq_length, C_in), (batch,)]
            input_types = [np.float32, np.float32]
            onnx_file = export_to_onnx(prefix, input_shapes, input_types)
            onnx_inputs = [inputs, valid_length]
            pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)

            assert_almost_equal(pred[0], pred_onx[0])

        def verify_src_embed():
            src = mx.nd.random.uniform(0, 36794, shape=(batch, seq_length), dtype='float32')
            pred = model.src_embed(src)

            prefix = "%s/src_embed" % tmp_path
            input_shapes = [(batch, seq_length)]
            input_types = [np.float32]
            onnx_file = export_to_onnx(prefix, input_shapes, input_types)
            onnx_inputs = [src]
            pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)

            assert_almost_equal(pred, pred_onx[0])

        def verify_tgt_embed():
            tgt = mx.nd.random.uniform(0, 36794, shape=(batch, seq_length), dtype='float32')
            pred = model.tgt_embed(tgt)

            prefix = "%s/tgt_embed" % tmp_path
            input_shapes = [(batch, seq_length)]
            input_types = [np.float32]
            onnx_file = export_to_onnx(prefix, input_shapes, input_types)
            onnx_inputs = [tgt]
            pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)

            assert_almost_equal(pred, pred_onx[0])

        def verify_tgt_proj():
            decoder_out = mx.nd.random.uniform(0, 512, shape=(batch, seq_length, C_out),
                                               dtype='float32')
            pred = model.tgt_proj(decoder_out)

            prefix = "%s/tgt_proj" % tmp_path
            input_shapes = [(batch, seq_length, C_out)]
            input_types = [np.float32]
            onnx_file = export_to_onnx(prefix, input_shapes, input_types)
            onnx_inputs = [decoder_out]
            pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)

            assert_almost_equal(pred, pred_onx[0], rtol=1.e-04, atol=1.5e-03)

        def verify_one_step_ahead_decoder():
            prefix = "%s/one_step_ahead_decoder" % tmp_path

            # the exported model expects its inputs in this permuted order
            perm = [2, 0, 1]
            input_shapes = [(batch, seq_length, C_in), (batch, seq_length, C_out),
                            (batch, seq_length)]
            input_shapes = [input_shapes[i] for i in perm]
            dynamic_input_shapes = [(batch, 'seq_length', C_in), (batch, 'seq_length', C_out),
                                    (batch, 'seq_length')]
            dynamic_input_shapes = [dynamic_input_shapes[i] for i in perm]
            input_types = [np.float32, np.float32, np.float32]
            # do a dynamic export
            onnx_file = export_to_onnx(prefix, input_shapes, input_types, dynamic=True,
                                       dynamic_input_shapes=dynamic_input_shapes)

            # step 0
            step_input = mx.nd.random.uniform(-1, 1, shape=(batch, C_in), dtype='float32')
            # mxnet
            pred, step_states, _ = model.one_step_ahead_decoder(step_input, decoder_states)
            # onnx
            # note that we need to expand the sequence axis just like in here:
            # https://github.com/dmlc/gluon-nlp/blob/v0.10.x/src/gluonnlp/model/transformer.py#L831
            input_onx = mx.nd.expand_dims(step_input, axis=1)
            onnx_inputs = [input_onx, decoder_states[0], decoder_states[1]]
            onnx_inputs = [onnx_inputs[i] for i in perm]
            pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)

            assert_almost_equal(pred, pred_onx[0])

            # step >= 1
            for i in range(20):
                step_input = mx.nd.random.uniform(-10*i, 10*i, shape=(batch, C_in), dtype='float32')
                # mxnet
                pred, step_states, _ = model.one_step_ahead_decoder(step_input, step_states)
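                # the MXNet decoder carries its state forward in step_states, whereas
                # the exported ONNX graph is stateless and re-reads the full input
                # history accumulated in input_onx below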
                # onnx
                # note that we need to concat the step_input with the previous inputs
                # just like in here:
                # https://github.com/dmlc/gluon-nlp/blob/v0.10.x/src/gluonnlp/model/transformer.py#L828
                input_onx = mx.nd.concat(input_onx, mx.nd.expand_dims(step_input, axis=1), dim=1)
                onnx_inputs = [input_onx, decoder_states[0], decoder_states[1]]
                onnx_inputs = [onnx_inputs[i] for i in perm]
                pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)

                assert_almost_equal(pred, pred_onx[0])

        verify_encoder()
        verify_src_embed()
        verify_tgt_embed()
        verify_tgt_proj()
        verify_one_step_ahead_decoder()

    finally:
        shutil.rmtree(tmp_path)


@with_seed()
@pytest.mark.parametrize('model_params', [('gpt2_117m', 24), ('gpt2_345m', 48)])
def test_gpt_pretrained_inference_onnxruntime(tmp_path, model_params):
    tmp_path = str(tmp_path)
    try:
        import gluonnlp as nlp
        import urllib.request
        from zipfile import ZipFile
        import importlib.util
        import sys

        url = 'https://nlp.gluon.ai/_downloads/77d227fbc8f1613e6802acc7253cc090/text_generation.zip'
        zip_file = os.path.join(tmp_path, 'text_generation.zip')
        urllib.request.urlretrieve(url, zip_file)

        with ZipFile(zip_file, 'r') as zipObj:
            zipObj.extractall(tmp_path)

        # load in the text_generation module, refer to:
        # https://github.com/dmlc/gluon-nlp/tree/v0.10.x/scripts/text_generation
        spec = importlib.util.spec_from_file_location(
            'text_generation',
            os.path.join(tmp_path, 'text_generation', '__init__.py'))
        mod = importlib.util.module_from_spec(spec)
        sys.modules[spec.name] = mod
        spec.loader.exec_module(mod)

        ctx = mx.cpu(0)
        model_name = model_params[0]
        dataset = 'openai_webtext'
        # get_model() is overridden in here:
        # https://github.com/dmlc/gluon-nlp/blob/v0.10.x/scripts/text_generation/model/__init__.py#L23
        model, _ = mod.model.get_model(
            name=model_name,
            ctx=ctx,
            pretrained=True,
            dataset_name=dataset)

        model.hybridize()

        batch = 4
        seq_length = 64
        inputs = mx.nd.random.uniform(0, 50257, shape=(batch, seq_length), dtype='float32',
                                      ctx=ctx)

        pred = model(inputs)

        prefix = "%s/%s" % (tmp_path, model_name)
        model.export(prefix)
        sym_file = "%s-symbol.json" % prefix
        params_file = "%s-0000.params" % prefix
        onnx_file = "%s.onnx" % prefix

        input_shapes = [(batch, seq_length)]
        input_types = [np.float32]
        converted_model_path = mx.onnx.export_model(sym_file, params_file, input_shapes,
                                                    input_types, onnx_file)

        ses_opt = onnxruntime.SessionOptions()
        ses_opt.log_severity_level = 3
        session = onnxruntime.InferenceSession(onnx_file, ses_opt)
        onnx_inputs = [inputs]
        input_dict = dict((session.get_inputs()[i].name, onnx_inputs[i].asnumpy()) for i in range(len(onnx_inputs)))
        pred_onx = session.run(None, input_dict)

        # check output
        assert_almost_equal(pred[0], pred_onx[0])
        # check states
        num_states = model_params[1]
        for i in range(num_states):
            assert_almost_equal(pred[1][i], pred_onx[i+1])

    finally:
        shutil.rmtree(tmp_path)