# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm, inspect, sys, traceback, numpy, pytest, types, os
from tvm.contrib import util
from tvm.hybrid import script
from tvm.hybrid.runtime import HYBRID_GLOBALS


@pytest.mark.skip
def run_and_check(func, args, var_dict={}, target='llvm', sch=None, outs=None):
    """Compile `func` (with the given schedule, if any), run it, and compare
    the result against the pure-Python (numpy) emulation of `func`; then
    build the corresponding HybridModule and return it with its args/outs."""
    def tvm_val_2_py_val(val):
        val = tvm.ir_pass.Substitute(val, var_dict)
        val = tvm.ir_pass.Simplify(val)
        assert isinstance(val, (tvm.expr.IntImm, tvm.expr.UIntImm))
        return val.value

    ctx = tvm.context(target, 0)
    op = None

    if sch is None:
        outs = func(*tuple(tvm.convert(i) if isinstance(i, list) else i for i in args))
        op = outs[0].op if isinstance(outs, list) else outs.op
        sch = tvm.create_schedule(op)
    else:
        assert outs is not None
        assert isinstance(outs, list)
        op = outs[0].op

    emu_args = []
    nd_args = []
    for i in args:
        if isinstance(i, tvm.tensor.Tensor):
            shape = [tvm_val_2_py_val(j) for j in i.shape]
            emu_args.append(numpy.random.randn(*shape).astype(i.dtype))
            nd_args.append(tvm.nd.array(emu_args[-1], ctx))
        elif isinstance(i, tvm.expr.Var):
            emu_args.append(tvm_val_2_py_val(i))
            nd_args.append(emu_args[-1])
        else:
            assert isinstance(i, list)
            emu_args.append(numpy.array(i))

    compile_args = [i for i in args if isinstance(i, (tvm.tensor.Tensor, tvm.expr.Var))] + \
                   (outs if isinstance(outs, list) else [outs])
    module = tvm.build(sch,
                       compile_args,
                       target=target)
    assert module

    out_tensors = []
    for i in range(op.num_outputs):
        output = op.output(i)
        shape = [tvm_val_2_py_val(j) for j in output.shape]
        nd_args.append(tvm.nd.array(numpy.zeros(shape).astype(output.dtype), ctx))
        out_tensors.append(nd_args[-1])

    ref_data = func(*emu_args)
    if isinstance(ref_data, numpy.ndarray):
        ref_data = [ref_data]

    module(*nd_args)

    for nd, np in zip(out_tensors, ref_data):
        tvm.testing.assert_allclose(nd.asnumpy(), np, rtol=1e-5, atol=1e-5)

    module_args = [i for i in args if isinstance(i, (tvm.tensor.Tensor, tvm.expr.Var))]
    module_outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    h_module = tvm.hybrid.build(sch, module_args, module_outs)

    return h_module, module_args, module_outs


@script
def outer_product(n, m, a, b):
    """This is a simple outer product.
    Actually this function is not required to be documented.
    I write this docstring to test skipping docstring functionality.
    """
    c = output_tensor((n, m), a.dtype)
    for i in range(n):
        for j in range(m):
            assert i < n and j < m, "index out of range!"
            c[i, j] = a[i] * b[j]
    return c

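# NOTE: a @script function is also runnable as plain Python for emulation:
# called with concrete values (numpy arrays and ints) it returns numpy
# results, which run_and_check above uses as reference data. An illustrative
# sketch (not part of the original tests):
#   np_a, np_b = numpy.random.randn(10), numpy.random.randn(10)
#   np_c = outer_product(10, 10, np_a, np_b)  # numpy ndarray, shape (10, 10)
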
# Test global function
# Test bridge between frontend and backend
def test_outer_product():
    n = tvm.var('n')
    m = tvm.var('m')
    a = tvm.placeholder((n, ), name='a')
    b = tvm.placeholder((m, ), name='b')

    try:
        c = outer_product(n, m, a, b)
        ir = c.op.body
    except IOError as err:
        assert sys.version_info[0] == 2 and str(err) == 'could not get source code'
        return

    # Check for i in (0, n)
    assert isinstance(ir, tvm.stmt.For)
    assert ir.loop_var.name == 'i'
    assert ir.min.value == 0
    assert ir.extent.name == 'n'
    ibody = ir.body
    assert isinstance(ibody, tvm.stmt.For)
    # Check for j in (0, m)
    assert ibody.loop_var.name == 'j'
    assert ibody.min.value == 0
    assert ibody.extent.name == 'm'
    # Check loop body
    jblock = ibody.body
    assert isinstance(jblock, tvm.stmt.Block)
    jbody = jblock.first
    assert isinstance(jbody, tvm.stmt.AssertStmt)
    assert isinstance(jbody.message, tvm.expr.StringImm)
    assert jbody.message.value == "index out of range!"
    jbody = jblock.rest
    assert isinstance(jbody, tvm.stmt.Provide)
    assert jbody.func.name == 'c'
    assert len(jbody.args) == 2
    assert jbody.args[0].name == 'i'
    assert jbody.args[1].name == 'j'
    assert isinstance(jbody.value, tvm.expr.Mul)
    mul = jbody.value
    assert isinstance(mul.a, tvm.expr.Call)
    assert mul.a.name == 'a'
    assert mul.b.name == 'b'

    func, ins, outs = run_and_check(outer_product, [n, m, a, b], {n: 99, m: 101})
    temp = util.tempdir()
    path = temp.relpath('%s.py' % func.name)
    func.save(path)
    func_ = tvm.hybrid.HybridModule()
    func_.load(path)
    run_and_check(func_, ins, {n: 99, m: 101}, outs=outs)

    # The hybrid runtime globals must not leak into the Python namespaces.
    for key, _ in HYBRID_GLOBALS.items():
        assert key not in globals().keys()
        assert key not in outer_product.__globals__.keys()


# Test local function
# Test allocation of local variable
def test_fanout():
    @script
    def fanout(n, a):
        three = 3.0
        b = output_tensor((a.shape[0] - 3, ), a.dtype)
        for i in range(a.shape[0] - 3):
            sigma = 0.0
            for j in range(3):
                sigma += a[i + j]
            sigma = sigma / three
            b[i] = sigma
        return b

    n = tvm.var('n')
    a = tvm.placeholder((n, ), 'float32', name='a')
    try:
        b = fanout(n, a)
        ir = b.op.body
    except IOError as err:
        assert sys.version_info[0] == 2 and str(err) == 'could not get source code'
        return

    # Check for i in (0, n-3)
    assert isinstance(ir, tvm.stmt.For)
    assert ir.loop_var.name == 'i'
    assert ir.min.value == 0
    assert tvm.ir_pass.Equal(ir.extent, n - 3)
    # Check loop body
    ibody = ir.body
    assert isinstance(ibody, tvm.stmt.AttrStmt)
    abody = ibody.body
    assert isinstance(abody, tvm.stmt.Realize)
    assert abody.bounds[0].min.value == 0
    assert abody.bounds[0].extent.value == 1
    assert abody.func.name == 'sigma'
    # Check i loop body
    rbody = abody.body
    assert isinstance(rbody.first, tvm.stmt.Provide)
    assert rbody.first.func.name == 'sigma'
    assert len(rbody.first.args) == 1
    assert rbody.first.args[0].value == 0
    # Check fanout loop
    jloop = rbody.rest.first
    assert jloop.loop_var.name == 'j'
    assert jloop.min.value == 0
    assert jloop.extent.value == 3
    jbody = jloop.body
    assert isinstance(jbody, tvm.stmt.Provide)
    assert len(jbody.args) == 1
    assert jbody.args[0].value == 0
    assert jbody.func.name == 'sigma'
    assert isinstance(jbody.value, tvm.expr.Add)
    value = jbody.value
    assert isinstance(value.a, tvm.expr.Call)
    assert value.a.name == 'sigma'
    assert len(value.a.args) == 1
    assert value.a.args[0].value == 0
    assert value.b.name == 'a'
    assert len(value.b.args) == 1
    assert tvm.ir_pass.Equal(value.b.args[0], ir.loop_var + jloop.loop_var)
    divide = rbody.rest.rest.first
    assert isinstance(divide, tvm.stmt.Provide)
    assert len(divide.args) == 1
    assert divide.args[0].value == 0
    value = divide.value
    assert isinstance(value, tvm.expr.Mul)
    assert value.a.name == 'sigma'
    assert len(value.a.args) == 1
    assert value.a.args[0].value == 0
    assert abs(value.b.value - (1 / 3.0)) < 1e-5
    write = rbody.rest.rest.rest
    assert isinstance(write, tvm.stmt.Provide)
    assert write.func.name == 'b'
    assert write.value.name == 'sigma'
    assert len(write.value.args) == 1
    assert write.value.args[0].value == 0

    func, ins, outs = run_and_check(fanout, [n, a], {n: 10})
    run_and_check(func, ins, {n: 10}, outs=outs)


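# Test loop type annotations: parallel/vectorize/unroll in hybrid script are
# expected to lower to the matching For loop types (For.Parallel,
# For.Vectorized, For.Unrolled), which the assertions below verify.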
def test_looptype():
    @script
    def looptype(a, b, c):
        d = output_tensor((16, ), 'int32')
        e = output_tensor((16, ), 'int32')
        f = output_tensor((16, ), 'int32')
        for i in parallel(16):
            d[i] = a[i]
        for j in vectorize(16):
            e[j] = b[j]
        for k in unroll(16):
            f[k] = c[k]
        return d, e, f

    a = tvm.placeholder((16, ), name='a', dtype='int32')
    b = tvm.placeholder((16, ), name='b', dtype='int32')
    c = tvm.placeholder((16, ), name='c', dtype='int32')
    try:
        d, e, f = looptype(a, b, c)
        ir = d.op.body
    except:
        return
    iloop = ir.first
    jloop = ir.rest.first
    kloop = ir.rest.rest
    assert iloop.for_type == tvm.stmt.For.Parallel
    assert jloop.for_type == tvm.stmt.For.Vectorized
    assert kloop.for_type == tvm.stmt.For.Unrolled

    func, ins, outs = run_and_check(looptype, [a, b, c])
    run_and_check(func, ins, outs=outs)


def test_if():
    @script
    def if_then_else(a):
        b = output_tensor((10, ), 'int32')
        c = output_tensor((10, ), 'int32')
        for i in range(10):
            if i % 2 == 0:
                c[i] = a[i]
            else:
                c[i] = b[i]
        for i in unroll(10):
            b[i] = -1 if i % 2 == 0 else 1
        return b, c

    a = tvm.placeholder((10, ), dtype='int32', name='a')

    func, ins, outs = run_and_check(if_then_else, [a])
    run_and_check(func, ins, outs=outs)

    @script
    def if_triple_condition(a):
        b = output_tensor((10, ), 'int32')
        for i in range(10):
            if 0 <= i < 5:
                b[i] = a[i]
            else:
                b[i] = a[i] + 1
        return b

    func, ins, outs = run_and_check(if_triple_condition, [a])
    run_and_check(func, ins, outs=outs)

    @script
    def if_and(a):
        b = output_tensor((10, ), 'int32')
        for i in range(10):
            if i >= 0 and i < 5:
                b[i] = a[i]
            else:
                b[i] = a[i] + 1
        return b

    func, ins, outs = run_and_check(if_and, [a])
    run_and_check(func, ins, outs=outs)


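# Test GPU thread binding: `for tx in bind('threadIdx.x', n)` attaches a loop
# to a CUDA thread axis from inside the script, the analogue of binding a
# schedule axis via sch[c].bind(axis, tvm.thread_axis('threadIdx.x')), which
# the `raw` kernel below does explicitly. Skipped when no GPU is present.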
def test_bind():
    if not tvm.gpu(0).exist:
        print('[Warning] No GPU found! Skip bind test!')
        return

    @script
    def vec_add(a, b):
        c = output_tensor((1000, ), 'float32')
        for tx in bind('threadIdx.x', 1000):
            c[tx] = a[tx] + b[tx]
        return c

    a = tvm.placeholder((1000, ), dtype='float32', name='a')
    b = tvm.placeholder((1000, ), dtype='float32', name='b')
    func, ins, outs = run_and_check(vec_add, [a, b], target='cuda')
    run_and_check(func, ins, outs=outs, target='cuda')

    @script
    def raw(a, b):
        c = output_tensor((1000, ), 'float32')
        for i in range(1000):
            c[i] = a[i] + b[i]
        return c

    c = raw(a, b)
    sch = tvm.create_schedule(c.op)
    x = tvm.thread_axis('threadIdx.x')
    sch[c].bind(c.op.axis[0], x)
    func, ins, outs = run_and_check(raw, [a, b], sch=sch, outs=[c], target='cuda')
    run_and_check(func, ins, outs=outs, target='cuda')

    @tvm.hybrid.script
    def foo(a):
        c = output_tensor((a.shape[0],), a.dtype)
        total = allocate((1,), a.dtype, 'local')
        len_i = a.shape[0]
        len_j = a.shape[1]
        for i in bind('threadIdx.x', len_i):
            total[0] = 0.
            for k in const_range(len_j):
                total[0] += a[i, k]
            c[i] = total[0]

        return c

    a = tvm.placeholder((8, 4), 'float32')
    c = foo(a)
    s = tvm.create_schedule(c.op)
    ir = tvm.lower(s, [a, c], simple_mode=True)
    assert not isinstance(ir, tvm.stmt.AttrStmt)
    func, ins, outs = run_and_check(foo, [a], target='cuda')
    run_and_check(func, ins, outs=outs, target='cuda')

    @tvm.hybrid.script
    def max_threads(a):
        b = output_tensor(a.shape, a.dtype)
        n = a.shape[0]
        m = max_num_threads(True)
        for i in bind('threadIdx.x', m):
            for j in bind('blockIdx.x', ceil_div(n, m)):
                if i * m + j < n:
                    b[i * m + j] = a[i * m + j] + a[i * m + j]
        return b

    a = tvm.placeholder((10000, ), 'float32')
    with tvm.target.create('cuda'):
        func, ins, outs = run_and_check(max_threads, [a], target='cuda')
        run_and_check(func, ins, outs=outs, target='cuda')


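# Test math intrinsics (sqrt, log, exp, sigmoid, power, tanh, min, max, and
# popcount for integers). The hybrid functions are also evaluated directly on
# numpy data to produce the reference values the compiled modules are checked
# against.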
def test_math_intrin():
    @script
    def intrin_real(a):
        b = output_tensor((8, ), 'float32')
        b[0] = sqrt(a[0])
        b[1] = log(a[1])
        b[2] = exp(a[2])
        b[3] = sigmoid(a[3])
        b[4] = power(a[4], a[5])
        b[5] = tanh(a[5])
        b[6] = min(a[4], a[5])
        b[7] = max(a[5], a[6])
        return b

    a8 = tvm.placeholder((8, ), dtype='float32', name='a')
    b8 = intrin_real(a8)
    sch = tvm.create_schedule(b8.op)
    func = tvm.build(sch, [a8, b8])
    assert func
    a = numpy.arange(2, 10).astype('float32')
    tvm_a = tvm.ndarray.array(a)
    tvm_b = tvm.ndarray.array(numpy.zeros((8, ), dtype='float32'))
    b = intrin_real(a)
    func(tvm_a, tvm_b)
    tvm.testing.assert_allclose(b, tvm_b.asnumpy(), rtol=1e-5)

    @script
    def intrin_int(a):
        b = output_tensor((1, ), 'int32')
        b[0] = popcount(a[0])
        return b

    a1 = tvm.placeholder((1, ), dtype='int32')
    b1 = intrin_int(a1)
    sch = tvm.create_schedule(b1.op)
    func = tvm.build(sch, [a1, b1])
    assert func
    a = numpy.array([114514]).astype('int32')
    tvm_a = tvm.ndarray.array(a)
    tvm_b = tvm.ndarray.array(numpy.array([0]).astype('int32'))
    b = intrin_int(a)
    func(tvm_a, tvm_b)
    assert tvm_b.asnumpy()[0] == b[0]


# Test non-canonical loops (ranges that do not start at zero)
def test_non_zero():
    @tvm.hybrid.script
    def blur(a):
        b = output_tensor((30, 30), 'float32')
        for i in range(2, 32):
            for j in range(2, 32):
                s = 0.0
                for di in range(3):
                    for dj in range(3):
                        s += a[i-di, j-dj]
                b[i-2, j-2] = s / 9.0
        return b

    a = tvm.placeholder((32, 32), 'float32', 'a')
    func, ins, outs = run_and_check(blur, [a])
    run_and_check(func, ins, outs=outs)

    @tvm.hybrid.script
    def triangle(a, b):
        c = output_tensor((10, 10), dtype='float32')
        for i in range(10):
            for j in range(i, 10):
                c[i, j] = a[i] * b[j]
        return c

    a = tvm.placeholder((10, ), dtype='float32', name='a')
    b = tvm.placeholder((10, ), dtype='float32', name='b')

    func, ins, outs = run_and_check(triangle, [a, b])
    run_and_check(func, ins, outs=outs)


def test_allocate():
    @tvm.hybrid.script
    def blur2d(a):
        b = output_tensor((30, 30), 'float32')
        for i in range(30):
            ha = allocate((3, 30), 'float32')
            for j in range(3):
                for k in range(30):
                    ha[j, k] = a[i+j, k] + a[i+j, k+1] + a[i+j, k+2]
            for j in range(30):
                b[i, j] = (ha[0, j] + ha[1, j] + ha[2, j]) / 9.0
        return b

    a = tvm.placeholder((32, 32), 'float32', 'a')
    b = blur2d(a)
    sch = tvm.create_schedule(b.op)
    func, ins, outs = run_and_check(blur2d, [a])
    run_and_check(func, ins, outs=outs)

    if tvm.gpu().exist:
        @tvm.hybrid.script
        def share_vec_add(a, b):
            c = output_tensor((256, ), 'float32')
            shared = allocate((256, ), 'float32', 'shared')
            for i in bind("threadIdx.x", 256):
                shared[i] = a[i]
            local = allocate((256, ), 'float32', 'local')
            for i in bind("threadIdx.x", 256):
                local[i] = b[i]
            for i in bind("threadIdx.x", 256):
                c[i] = shared[i] + local[i]
            return c

        a = tvm.placeholder((256, ), dtype='float32', name='a')
        b = tvm.placeholder((256, ), dtype='float32', name='b')
        c = share_vec_add(a, b)
        func, ins, outs = run_and_check(share_vec_add, [a, b], target='cuda')
        run_and_check(func, ins, outs=outs, target='cuda')
    else:
        print('[Warning] No GPU found! Skip shared mem test!')


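# Test mixing hybrid ops with regular tvm.compute stages: a hybrid op may
# consume the output of a compute op (upstream) or feed one (downstream),
# with both stages living in a single schedule.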
def test_upstream():
    @tvm.hybrid.script
    def upstream(a):
        b = output_tensor((20, ), 'float32')
        for i in range(20):
            b[i] = a[i] * i
        return b

    a = tvm.placeholder((20, ), 'float32')
    b = tvm.placeholder((20, ), 'float32')
    c = tvm.compute((20, ), lambda x: a[x] + b[x])
    d = upstream(c)
    sch = tvm.create_schedule([c.op, d.op])
    ir = tvm.lower(sch, [a, b, d], simple_mode=True)
    func = tvm.build(sch, [a, b, d])
    assert func

    a = numpy.random.randn(20).astype('float32')
    b = numpy.random.randn(20).astype('float32')
    ref = numpy.zeros((20, ), 'float32')
    for i in range(20):
        ref[i] = (a[i] + b[i]) * i

    tvm_a = tvm.nd.array(a)
    tvm_b = tvm.nd.array(b)
    tvm_d = tvm.nd.array(numpy.zeros((20, )).astype('float32'))

    func(tvm_a, tvm_b, tvm_d)
    tvm.testing.assert_allclose(tvm_d.asnumpy(), ref, 1e-5, 1e-5)


def test_downstream():
    @tvm.hybrid.script
    def downstream(a):
        b = output_tensor((20, ), 'float32')
        for i in range(20):
            b[i] = a[i] * i
        return b

    a = tvm.placeholder((20, ), 'float32')
    b = downstream(a)
    c = tvm.compute((20, ), lambda x: b[x] + 1.0)

    sch = tvm.create_schedule(c.op)
    module = tvm.build(sch, [a, c])
    assert module

    a = numpy.random.randn(20).astype('float32')
    ref = numpy.zeros((20, )).astype('float32')
    for i in range(20):
        ref[i] = (a[i] * i) + 1.0

    tvm_a = tvm.nd.array(a)
    tvm_c = tvm.nd.array(numpy.zeros((20, )).astype('float32'))
    module(tvm_a, tvm_c)
    tvm.testing.assert_allclose(tvm_c.asnumpy(), ref, 1e-5, 1e-5)


# Test passing a constant scalar (tvm.const) as an argument
def test_const_param():
    @tvm.hybrid.script
    def add_something(a, b):
        c = output_tensor((11, ), 'int32')
        for i in range(11):
            c[i] = a[i] + b
        return c

    a = tvm.placeholder((11, ), dtype='int32', name='a')
    b = tvm.const(11, 'int32')
    c = add_something(a, b)
    sch = tvm.create_schedule(c.op)
    module = tvm.build(sch, [a, c], 'llvm')
    assert module

    np_a = numpy.arange(11).astype('int32')
    np_b = 11
    np_c = numpy.zeros((11, )).astype('int32')

    nd_a = tvm.ndarray.array(np_a)
    nd_c = tvm.ndarray.array(numpy.zeros((11, )).astype('int32'))
    module(nd_a, nd_c)
    ref = add_something(np_a, 11)

    tvm.testing.assert_allclose(nd_c.asnumpy(), ref, 1e-5, 1e-5)


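# Test multi-output hybrid ops: kernel_a returns two tensors, and passing one
# of them into kernel_b must resolve to the correct value index of the op.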
def test_value_index():
    @tvm.hybrid.script
    def kernel_a(a):
        b = output_tensor((16, ), 'int32')
        c = output_tensor((4, 4), 'int32')
        for i in range(16):
            b[i] = a[i] + 2
            c[i // 4, i % 4] = a[i] + 1
        return b, c

    @tvm.hybrid.script
    def kernel_b(b, a):
        c = output_tensor((4, 4), 'int32')
        for i in range(4):
            for j in range(4):
                c[i, j] = a[i * 4 + j] * b[i, j]
        return c

    a = tvm.placeholder((16, ), 'int32')
    b, c = kernel_a(a)
    d = kernel_b(c, b)
    sch = tvm.create_schedule(d.op)
    module = tvm.build(sch, [a, d])
    assert module

    np_a = numpy.arange(16).astype('int32')
    np_b, np_c = kernel_a(np_a)
    ref = kernel_b(np_c, np_b)

    res = tvm.ndarray.array(numpy.zeros((4, 4)).astype('int32'))
    module(tvm.ndarray.array(np_a), res)
    tvm.testing.assert_allclose(res.asnumpy(), ref)


def test_func_call():
    @tvm.hybrid.script
    def foo(a, b):
        for i in range(len(a)):
            a[i] = i + 1.0
        for i in range(len(a)):
            b[i] = i + 1.0
        c = outer_product(10, 10, a, b)
        d = output_tensor(c.shape, c.dtype)
        for i in range(10):
            for j in range(10):
                d[i, j] = c[i, j] + i * j
        return d

    a = tvm.placeholder((10, ), name='a')
    b = tvm.placeholder((10, ), name='b')
    func, ins, outs = run_and_check(foo, [a, b])
    run_and_check(func, ins, outs=outs)


def test_bool():
    @tvm.hybrid.script
    def foo(a):
        b = output_tensor(a.shape, a.dtype)
        b[0] = 1.2
        for i in range(1, a.shape[0] - 1):
            if a[i] * a[i - 1] < a[i] or a[i] * a[i - 1] < a[i - 1] or i * a[i] == a[i]:
                b[i] = a[i]
            else:
                b[i] = 0.0
        return b

    a = tvm.placeholder((10, ), name='a')
    func, ins, outs = run_and_check(foo, [a])
    run_and_check(func, ins, outs=outs)


def test_const_range():
    @tvm.hybrid.script
    def foo(a, b):
        c = output_tensor(a.shape, a.dtype)
        d = output_tensor(a.shape, 'int32')

        for i in const_range(2):
            for j in const_range(5):
                c[i, j] = float32(int32(a[i, j]) + b[i, j])

        for i in const_range(len(b)):
            for j in const_range(len(b[0])):
                d[i, j] = int32(a[i, j] + b[i, j])

        return c, d

    a = tvm.placeholder((2, 5), name='a', dtype='float32')
    b = [[1, 2, 3, 4, 5], [5, 4, 3, 2, 1]]
    func, ins, outs = run_and_check(foo, [a, b])
    run_and_check(func, ins, outs=outs)

    @tvm.hybrid.script
    def goo(a, b):
        c = output_tensor(a.shape, a.dtype)
        len_b = len(b)
        for i in const_range(len_b * 2):
            if i < len_b:
                c[i] = a[i] + b[i]
            else:
                c[i - len_b] = a[i - len_b] + b[i - len_b]
        return c

    a = tvm.placeholder((5, ), name='a', dtype='int32')
    b = [1, 2, 3, 4, 5]
    c = goo(a, tvm.convert(b))
    sch = tvm.create_schedule(c.op)
    func, ins, outs = run_and_check(goo, [a, b])
    run_and_check(func, ins, outs=outs)

    @tvm.hybrid.script
    def hoo(a, b):
        c = output_tensor(a.shape, a.dtype)
        len_b = len(b)
        for i in range(a.shape[0]):
            for j in const_range(len(b)):
                d = a[i] * b[j]
                d += a[i] + b[j]
                c[i] = d
        return c

    a = tvm.placeholder((5, ), name='a', dtype='int32')
    b = [1, 2, 3, 4, 5]
    func, ins, outs = run_and_check(hoo, [a, b])
    run_and_check(func, ins, outs=outs)


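# Test that hybrid op outputs expose schedulable axes, so the usual schedule
# primitives (split, reorder, fuse, vectorize, parallel) all apply.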
def test_schedule():
    @script
    def outer_product(a, b):
        c = output_tensor((64, 64), a.dtype)
        for i in range(64):
            for j in range(64):
                c[i, j] = a[i] * b[j]
        return c

    a = tvm.placeholder((64,), name='a', dtype='float32')
    b = tvm.placeholder((64,), name='b', dtype='float32')
    c = outer_product(a, b)

    # Test perfect loop split
    # Test loop reorder
    # Test loop annotation
    sch = tvm.create_schedule(c.op)
    i, j = c.op.axis
    io, ii = sch[c].split(i, 4)
    sch[c].parallel(ii)
    jo, ji = sch[c].split(j, 4)
    joo, joi = sch[c].split(jo, 4)
    sch[c].vectorize(ji)
    sch[c].reorder(ii, io, joo, joi, ji)
    ir = tvm.lower(sch, [a, b, c], simple_mode=True)
    assert isinstance(ir, tvm.stmt.ProducerConsumer)
    ir = ir.body
    assert isinstance(ir, tvm.stmt.AttrStmt)
    ir = ir.body
    assert isinstance(ir, tvm.stmt.For)
    assert ir.loop_var.name == 'i.inner'
    ir = ir.body
    assert isinstance(ir, tvm.stmt.For)
    assert ir.loop_var.name == 'i.outer'
    ir = ir.body
    assert isinstance(ir, tvm.stmt.For)
    assert ir.loop_var.name == 'j.outer.outer'
    ir = ir.body
    assert isinstance(ir, tvm.stmt.For)
    assert ir.loop_var.name == 'j.outer.inner'
    ir = ir.body
    func, ins, outs = run_and_check(outer_product, [a, b], sch=sch, outs=[c])
    run_and_check(func, ins, outs=outs)

    # Test fuse
    sch = tvm.create_schedule(c.op)
    sch[c].fuse(c.op.axis[0], c.op.axis[1])
    ir = tvm.lower(sch, [a, b, c], simple_mode=True)
    assert isinstance(ir, tvm.stmt.ProducerConsumer)
    ir = ir.body
    assert isinstance(ir, tvm.stmt.AttrStmt)
    ir = ir.body
    assert isinstance(ir, tvm.stmt.For)
    assert ir.loop_var.name == 'i.j.fused'
    func, ins, outs = run_and_check(outer_product, [a, b], sch=sch, outs=[c])
    run_and_check(func, ins, outs=outs)

    # Test imperfect loop split
    sch = tvm.create_schedule(c.op)
    sch[c].split(c.op.axis[0], 3)
    ir = tvm.lower(sch, [a, b, c], simple_mode=True)
    func, ins, outs = run_and_check(outer_product, [a, b], sch=sch, outs=[c])
    run_and_check(func, ins, outs=outs)

    # Test loop binds


# Test capturing constant Python values from the enclosing scope
def test_capture():
    n = 8

    constant_tuple = (10, n)
    constant_list = [[1, 2], [3, n]]
    const_value = 1

    @tvm.hybrid.script
    def add_something(a):
        c = output_tensor((constant_tuple[1],), 'int32')
        for i in range(constant_tuple[1]):
            c[i] = a[i] + constant_list[1][const_value]
        return c

    a = tvm.placeholder((n, ), dtype='int32', name='a')

    func, ins, outs = run_and_check(add_something, [a])
    run_and_check(func, ins, outs=outs)


if __name__ == "__main__":
    test_outer_product()
    test_fanout()
    test_looptype()
    test_if()
    test_bind()
    test_math_intrin()
    test_non_zero()
    test_allocate()
    test_upstream()
    test_downstream()
    test_const_param()
    test_value_index()
    test_func_call()
    test_bool()
    test_const_range()
    test_schedule()
    test_capture()
    # TODO:
    # test_inplace()