1# Licensed to the Apache Software Foundation (ASF) under one 2# or more contributor license agreements. See the NOTICE file 3# distributed with this work for additional information 4# regarding copyright ownership. The ASF licenses this file 5# to you under the Apache License, Version 2.0 (the 6# "License"); you may not use this file except in compliance 7# with the License. You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, 12# software distributed under the License is distributed on an 13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14# KIND, either express or implied. See the License for the 15# specific language governing permissions and limitations 16# under the License. 17"""Unit tests for the Bring Your Own Datatype framework. 18 19TODO(@gussmith23 @hypercubestart) link to documentation""" 20import tvm 21import tvm.topi.testing 22import numpy as np 23import pytest 24from numpy.random import MT19937, RandomState, SeedSequence 25from tvm import relay 26from tvm.relay.testing.layers import batch_norm_infer 27from tvm.target.datatype import ( 28 register, 29 register_min_func, 30 register_op, 31 create_lower_func, 32 lower_ite, 33 lower_call_pure_extern, 34 create_min_lower_func, 35) 36from tvm.tir.op import call_pure_extern 37 38# note: we can't use relay.testing models because params are randomly initialized, 39# which lead the output to have the same values 40# get mobilenet model from Gluon CV 41# because: https://discuss.tvm.apache.org/t/mobilenet-intermediate-values-are-0/7812 42def get_mobilenet(): 43 dshape = (1, 3, 224, 224) 44 from mxnet.gluon.model_zoo.vision import get_model 45 46 block = get_model("mobilenet0.25", pretrained=True) 47 shape_dict = {"data": dshape} 48 return relay.frontend.from_mxnet(block, shape_dict) 49 50 51# use real image instead of random data for end-to-end model training 52# or else output would all be around the same value 53def get_cat_image(dimensions): 54 from tvm.contrib.download import download_testdata 55 from PIL import Image 56 57 url = "https://gist.githubusercontent.com/zhreshold/bcda4716699ac97ea44f791c24310193/raw/fa7ef0e9c9a5daea686d6473a62aacd1a5885849/cat.png" 58 dst = "cat.png" 59 real_dst = download_testdata(url, dst, module="data") 60 img = Image.open(real_dst).resize(dimensions) 61 # CoreML's standard model image format is BGR 62 img_bgr = np.array(img)[:, :, ::-1] 63 img = np.transpose(img_bgr, (2, 0, 1))[np.newaxis, :] 64 return np.asarray(img, dtype="float32") 65 66 67# we use a random seed to generate input_data 68# to guarantee stable tests 69rs = RandomState(MT19937(SeedSequence(123456789))) 70 71 72def convert_ndarray(dst_dtype, array): 73 """Converts NDArray(s) into the specified datatype""" 74 x = relay.var("x", shape=array.shape, dtype=str(array.dtype)) 75 cast = relay.Function([x], x.astype(dst_dtype)) 76 with tvm.transform.PassContext(config={"tir.disable_vectorize": True}): 77 return relay.create_executor("graph").evaluate(cast)(array) 78 79 80def change_dtype(src, dst, module, params): 81 """Convert constants and functions in module from src type to dst type. 82 Returns changed module and converted params of type dst_type. 83 """ 84 module = relay.frontend.ChangeDatatype(src, dst)(module) 85 module = relay.transform.InferType()(module) 86 params = {k: convert_ndarray(dst, v) for k, v in params.items()} 87 return module, params 88 89 90def compare(module, input, src_dtype, dst_dtype, rtol, atol, params={}, target="llvm"): 91 module = relay.transform.SimplifyInference()(module) 92 ex = relay.create_executor("graph", mod=module) 93 94 correct = ex.evaluate()(*input, **params) 95 module, converted_params = change_dtype(src_dtype, dst_dtype, module, params) 96 ex = relay.create_executor("graph", mod=module, target=target) 97 # converts all inputs to dst_dtype 98 x_converted = [convert_ndarray(dst_dtype, arr) for arr in input] 99 100 # Vectorization is not implemented with custom datatypes 101 with tvm.transform.PassContext(config={"tir.disable_vectorize": True}): 102 maybe_correct = ex.evaluate()(*x_converted, **converted_params) 103 # currently this only works for comparing single output 104 maybe_correct_converted = convert_ndarray(src_dtype, maybe_correct) 105 np.testing.assert_allclose( 106 maybe_correct_converted.asnumpy(), correct.asnumpy(), rtol=rtol, atol=atol 107 ) 108 109 110def setup_myfloat(): 111 """Set up tests for myfloat (a custom datatype that under the hood is float) 112 113 Currently, this registers some custom datatypes using the Bring Your 114 Own Datatypes framework. 115 """ 116 117 # To use datatype operations in an external library, you should first load 118 # the library containing the datatype implementation: 119 # CDLL("libposit.so", RTLD_GLOBAL) 120 # In this case, the datatype library we are using is built right into TVM, 121 # so we do not need to explicitly load any library. 122 123 # You can pick a code for your datatype arbitrarily, as long as it is 124 # greater than 128 and has not already been chosen. 125 register("myfloat", 131) 126 127 register_op( 128 create_lower_func({(32, 32): "FloatToCustom32"}), "Cast", "llvm", "float", "myfloat" 129 ) 130 register_op( 131 create_lower_func({(32, 32): "Custom32ToFloat"}), "Cast", "llvm", "myfloat", "float" 132 ) 133 register_op(create_lower_func({32: "Custom32Add"}), "Add", "llvm", "myfloat") 134 register_op( 135 create_lower_func( 136 { 137 32: "Custom32Sub", 138 } 139 ), 140 "Sub", 141 "llvm", 142 "myfloat", 143 ) 144 register_op(create_lower_func({32: "Custom32Mul"}), "Mul", "llvm", "myfloat") 145 register_op( 146 create_lower_func( 147 { 148 32: "FloatToCustom32", 149 } 150 ), 151 "FloatImm", 152 "llvm", 153 "myfloat", 154 ) 155 register_op( 156 create_lower_func( 157 { 158 32: "Custom32Div", 159 } 160 ), 161 "Div", 162 "llvm", 163 "myfloat", 164 ) 165 register_op(create_lower_func({32: "Custom32Max"}), "Max", "llvm", "myfloat") 166 register_op( 167 create_lower_func({32: "Custom32Sqrt"}), 168 "Call", 169 "llvm", 170 "myfloat", 171 intrinsic_name="tir.sqrt", 172 ) 173 register_op( 174 create_lower_func({32: "Custom32Exp"}), "Call", "llvm", "myfloat", intrinsic_name="tir.exp" 175 ) 176 register_op( 177 create_lower_func({32: "Custom32Log"}), "Call", "llvm", "myfloat", intrinsic_name="tir.log" 178 ) 179 register_op( 180 create_lower_func({32: "Custom32Sigmoid"}), 181 "Call", 182 "llvm", 183 "myfloat", 184 intrinsic_name="tir.sigmoid", 185 ) 186 register_op( 187 create_lower_func({32: "Custom32Tanh"}), 188 "Call", 189 "llvm", 190 "myfloat", 191 intrinsic_name="tir.tanh", 192 ) 193 register_op(lower_ite, "Call", "llvm", "myfloat", intrinsic_name="tir.if_then_else") 194 register_op( 195 lower_call_pure_extern, "Call", "llvm", "myfloat", intrinsic_name="tir.call_pure_extern" 196 ) 197 198 register_min_func(create_min_lower_func({32: "MinCustom32"}, "myfloat"), "myfloat") 199 200 201def setup_posites2(): 202 """Set up tests for posites2 203 Currently, this registers some custom datatypes using the Bring Your 204 Own Datatypes framework. 205 """ 206 207 # To use datatype operations in an external library, you should first load 208 # the library containing the datatype implementation: 209 # CDLL("libposit.so", RTLD_GLOBAL) 210 # In this case, the datatype library we are using is built right into TVM, 211 # so we do not need to explicitly load any library. 212 213 # You can pick a code for your datatype arbitrarily, as long as it is 214 # greater than 128 and has not already been chosen. 215 216 register("posites2", 132) 217 218 register_op( 219 create_lower_func( 220 { 221 (32, 32): "FloatToPosit32es2", 222 (32, 16): "FloatToPosit16es2", 223 (32, 8): "FloatToPosit8es2", 224 } 225 ), 226 "Cast", 227 "llvm", 228 "float", 229 "posites2", 230 ) 231 register_op( 232 create_lower_func( 233 { 234 (32, 32): "Posit32es2ToFloat", 235 (16, 32): "Posit16es2ToFloat", 236 (8, 32): "Posit8es2ToFloat", 237 } 238 ), 239 "Cast", 240 "llvm", 241 "posites2", 242 "float", 243 ) 244 register_op( 245 create_lower_func({32: "Posit32es2Add", 16: "Posit16es2Add", 8: "Posit8es2Add"}), 246 "Add", 247 "llvm", 248 "posites2", 249 ) 250 register_op( 251 create_lower_func({32: "Posit32es2Sub", 16: "Posit16es2Sub", 8: "Posit8es2Sub"}), 252 "Sub", 253 "llvm", 254 "posites2", 255 ) 256 register_op( 257 create_lower_func( 258 {32: "FloatToPosit32es2", 16: "FloatToPosit16es2", 8: "FloatToPosit8es2"} 259 ), 260 "FloatImm", 261 "llvm", 262 "posites2", 263 ) 264 register_op( 265 create_lower_func({32: "Posit32es2Mul", 16: "Posit16es2Mul", 8: "Posit8es2Mul"}), 266 "Mul", 267 "llvm", 268 "posites2", 269 ) 270 register_op( 271 create_lower_func({32: "Posit32es2Div", 16: "Posit16es2Div", 8: "Posit8es2Div"}), 272 "Div", 273 "llvm", 274 "posites2", 275 ) 276 register_op( 277 create_lower_func({32: "Posit32es2Max", 16: "Posit16es2Max", 8: "Posit8es2Max"}), 278 "Max", 279 "llvm", 280 "posites2", 281 ) 282 register_op( 283 create_lower_func({32: "Posit32es2Sqrt", 16: "Posit16es2Sqrt", 8: "Posit8es2Sqrt"}), 284 "Call", 285 "llvm", 286 "posites2", 287 intrinsic_name="tir.sqrt", 288 ) 289 register_op(lower_ite, "Call", "llvm", "posites2", intrinsic_name="tir.if_then_else") 290 register_op( 291 lower_call_pure_extern, "Call", "llvm", "posites2", intrinsic_name="tir.call_pure_extern" 292 ) 293 register_op( 294 create_lower_func({32: "Posit32es2Exp", 16: "Posit16es2Exp", 8: "Posit8es2Exp"}), 295 "Call", 296 "llvm", 297 "posites2", 298 intrinsic_name="tir.exp", 299 ) 300 register_op( 301 create_lower_func({32: "Posit32es2Log", 16: "Posit16es2Log", 8: "Posit8es2Log"}), 302 "Call", 303 "llvm", 304 "posites2", 305 intrinsic_name="tir.log", 306 ) 307 register_op( 308 create_lower_func( 309 {32: "Posit32es2Sigmoid", 16: "Posit16es2Sigmoid", 8: "Posit8es2Sigmoid"} 310 ), 311 "Call", 312 "llvm", 313 "posites2", 314 intrinsic_name="tir.sigmoid", 315 ) 316 register_op( 317 create_lower_func({32: "Posit32es2Tanh", 16: "Posit16es2Tanh", 8: "Posit8es2Tanh"}), 318 "Call", 319 "llvm", 320 "posites2", 321 intrinsic_name="tir.tanh", 322 ) 323 324 register_min_func( 325 create_min_lower_func( 326 {32: "MinPosit32es2", 16: "MinPosit16es2", 8: "MinPosit8es2"}, "posites2" 327 ), 328 "posites2", 329 ) 330 331 332def run_ops(src_dtype, dst_dtype, rtol=1e-7, atol=1e-7): 333 """Run the same op, but with two different datatypes""" 334 # used for unary ops, first shape in binary ops 335 shape1 = (5, 10, 5) 336 # second shape for binary ops 337 shape2 = (5,) 338 339 def check_unary_op(op, src_dtype, dst_dtype, shape): 340 t1 = relay.TensorType(shape, src_dtype) 341 x = relay.var("x", t1) 342 z = op(x) 343 x_data = rs.rand(*shape).astype(t1.dtype) 344 345 module = tvm.IRModule.from_expr(relay.Function([x], z)) 346 347 compare(module, (x_data,), src_dtype, dst_dtype, rtol, atol) 348 349 # test unary ops 350 for op in [ 351 relay.nn.softmax, 352 tvm.relay.log, 353 tvm.relay.exp, 354 tvm.relay.sqrt, 355 tvm.relay.rsqrt, 356 tvm.relay.sigmoid, 357 tvm.relay.tanh, 358 relay.nn.relu, 359 relay.nn.batch_flatten, 360 ]: 361 check_unary_op(op, src_dtype, dst_dtype, shape1) 362 363 # test unary ops over 4d data 364 for op in [relay.nn.max_pool2d, relay.nn.avg_pool2d, relay.nn.global_avg_pool2d]: 365 shape_2d = (3, 32, 32, 32) 366 check_unary_op(op, src_dtype, dst_dtype, shape_2d) 367 368 def check_binary_op(opfunc, src_dtype, dst_dtype): 369 t1 = relay.TensorType(shape1, src_dtype) 370 t2 = relay.TensorType(shape2, src_dtype) 371 x = relay.var("x", t1) 372 y = relay.var("y", t2) 373 z = opfunc(x, y) 374 x_data = rs.rand(*shape1).astype(t1.dtype) 375 y_data = rs.rand(*shape2).astype(t2.dtype) 376 module = tvm.IRModule.from_expr(relay.Function([x, y], z)) 377 378 compare(module, (x_data, y_data), src_dtype, dst_dtype, rtol, atol) 379 380 for op in [ 381 relay.add, 382 relay.subtract, 383 relay.divide, 384 relay.multiply, 385 ]: 386 check_binary_op(op, src_dtype, dst_dtype) 387 388 # we would like to test tvm_if_then_else 389 # but Relay.IfNode is not lowered to this intrinsic, 390 # so to keep our tests consistent with relay, we decide to not unit test 391 # Note: tvm_if_then_else is tested as part of the mobile_net model 392 393 394def run_model(get_workload, input, src_dtype, dst_dtype, rtol=1e-4, atol=1e-4): 395 module, params = get_workload() 396 397 # we don't generate random data here 398 # because then the output data would all be around the same value 399 compare(module, input, src_dtype, dst_dtype, rtol, atol, params) 400 401 402def run_conv2d(src_dtype, dst_dtype, rtol=1e-7, atol=1e-4): 403 def run_test_conv2d( 404 src_dtype, 405 dst_dtype, 406 scale, 407 dshape, 408 kshape, 409 padding=(1, 1), 410 groups=1, 411 dilation=(1, 1), 412 **attrs, 413 ): 414 x = relay.var("x", shape=dshape, dtype=src_dtype) 415 w = relay.var("w", shape=kshape, dtype=src_dtype) 416 y = relay.nn.conv2d(x, w, padding=padding, dilation=dilation, groups=groups, **attrs) 417 module = tvm.IRModule.from_expr(relay.Function([x, w], y)) 418 data = rs.uniform(-scale, scale, size=dshape).astype(src_dtype) 419 kernel = rs.uniform(-scale, scale, size=kshape).astype(src_dtype) 420 421 compare(module, (data, kernel), src_dtype, dst_dtype, rtol, atol) 422 423 # depthwise conv2d 424 dshape = (1, 32, 18, 18) 425 kshape = (32, 1, 3, 3) 426 run_test_conv2d( 427 src_dtype, 428 dst_dtype, 429 1, 430 dshape, 431 kshape, 432 padding=(1, 1), 433 channels=32, 434 groups=32, 435 kernel_size=(3, 3), 436 ) 437 438 # CUDA is disabled for 'direct' schedule: 439 # https://github.com/dmlc/tvm/pull/3070#issuecomment-486597553 440 # group conv2d 441 dshape = (1, 32, 18, 18) 442 kshape = (32, 4, 3, 3) 443 run_test_conv2d( 444 src_dtype, 445 dst_dtype, 446 1, 447 dshape, 448 kshape, 449 padding=(1, 1), 450 channels=32, 451 groups=8, 452 kernel_size=(3, 3), 453 ) 454 # also group conv2d 455 dshape = (1, 32, 18, 18) 456 kshape = (64, 1, 3, 3) 457 run_test_conv2d( 458 src_dtype, 459 dst_dtype, 460 1, 461 dshape, 462 kshape, 463 padding=(1, 1), 464 channels=64, 465 groups=32, 466 kernel_size=(3, 3), 467 ) 468 469 # normal conv2d 470 dshape = (1, 3, 224, 224) 471 kshape = (10, 3, 3, 3) 472 run_test_conv2d( 473 src_dtype, dst_dtype, 1, dshape, kshape, padding=(1, 1), channels=10, kernel_size=(3, 3) 474 ) 475 476 # dilated conv2d 477 dshape = (1, 3, 18, 18) 478 kshape = (10, 3, 3, 3) 479 run_test_conv2d( 480 src_dtype, 481 dst_dtype, 482 1, 483 dshape, 484 kshape, 485 padding=(1, 1), 486 channels=10, 487 kernel_size=(3, 3), 488 dilation=(3, 3), 489 ) 490 491 492def run_batchnorm(src_dtype, dst_dtype, rtol=1e-6, atol=1e-6): 493 shape = (3, 32, 32) 494 t = relay.TensorType(shape, src_dtype) 495 x = relay.var("x", t) 496 bn = batch_norm_infer(data=x, epsilon=2e-5, scale=False, name="bn_x") 497 f = relay.Function(relay.analysis.free_vars(bn), bn) 498 499 x_data = rs.rand(*shape).astype(t.dtype) 500 module = tvm.IRModule.from_expr(f) 501 502 zero_data = np.zeros((32), "float32") 503 compare( 504 module, 505 (x_data, zero_data, zero_data, zero_data, zero_data), 506 src_dtype, 507 dst_dtype, 508 rtol, 509 atol, 510 ) 511 512 513def test_myfloat(): 514 setup_myfloat() 515 run_ops("float32", "custom[myfloat]32", rtol=1e-6, atol=1e-6) 516 run_conv2d("float32", "custom[myfloat]32", rtol=1e-6, atol=1e-6) 517 run_batchnorm("float32", "custom[myfloat]32", rtol=1e-6, atol=1e-6) 518 519 # mxnet python package not available 520 # run_model(get_mobilenet, (get_cat_image((224, 224)), ), 521 # 'float32', 522 # 'custom[myfloat]32') 523 524 525def _has_posit(): 526 return tvm.support.libinfo()["USE_BYODT_POSIT"] == "ON" 527 528 529@pytest.mark.skipif(not _has_posit(), reason="compiled with USE_BYODT_POSIT flag OFF") 530def test_posites2(): 531 setup_posites2() 532 run_ops("float32", "custom[posites2]8", rtol=1, atol=1) 533 run_ops("float32", "custom[posites2]16", rtol=0.01, atol=1) 534 run_ops("float32", "custom[posites2]32", rtol=1e-6, atol=1e-6) 535 536 run_conv2d("float32", "custom[posites2]8", rtol=1, atol=1) 537 run_conv2d("float32", "custom[posites2]16", rtol=0.01, atol=1) 538 run_conv2d("float32", "custom[posites2]32") 539 540 run_batchnorm("float32", "custom[posites2]8", rtol=1, atol=1) 541 run_batchnorm("float32", "custom[posites2]16", rtol=0.01, atol=1) 542 run_batchnorm("float32", "custom[posites2]32", rtol=1e-4, atol=1e-4) 543 # Expected posit8 might be faster, but it's not. 544 # run_model(get_mobilenet, (get_cat_image((224, 224)), ), 'float32', 'custom[posit8]8') 545 # run_model(get_mobilenet, (get_cat_image((224, 224)), ), 'float32', 'custom[posit32]32') 546 # run_model(get_inception, (get_cat_image((229, 229)), ), 'float32', 'custom[posit32]32') 547 # run_model(get_resnet, (get_cat_image((224, 224)), ), 'float32', 'custom[posit32]32') 548 549 # can't run cifar-10 sizes because dimensions 550 # don't match pretrained weights 551 552 # runs on the order of minutes... 553 # run_model(get_inception, (get_cat_image((229, 229)), ), 554 # 'float32', 555 # 'custom[posites2]32') 556 # run_model(get_resnet, (get_cat_image((224, 224)), ), 557 # 'float32', 558 # 'custom[posites2]32') 559 560 561if __name__ == "__main__": 562 pytest.main([__file__]) 563