# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for the TIR ``LoopPartition`` pass."""
import tvm
import tvm.testing
from tvm import te
import numpy

# PassContext config that lets LoopPartition also partition loops whose extent is constant.
_PARTITION_CONST_LOOP = {"tir.LoopPartition": {"partition_const_loop": True}}


def collect_visit(stmt, f):
    """Apply ``f`` to every node under ``stmt`` in post-order and collect the results.

    Parameters
    ----------
    stmt : tvm.tir.Stmt
        Root of the IR tree to traverse.
    f : callable
        Called once per visited node; its return value is recorded.

    Returns
    -------
    list
        ``f(node)`` for each visited node, in post-order visiting order.
    """
    ret = []
    tvm.tir.stmt_functor.post_order_visit(stmt, lambda x: ret.append(f(x)))
    return ret


def _contains(stmt, node_type):
    """Return True if any node under ``stmt`` is an instance of ``node_type``."""
    return any(collect_visit(stmt, lambda node: isinstance(node, node_type)))


def _partition_and_simplify(stmt, params=(), partition_const_loop=False):
    """Wrap ``stmt`` in a PrimFunc, run LoopPartition then Simplify, and return the body.

    Parameters
    ----------
    stmt : tvm.tir.Stmt
        Body of the PrimFunc to build.
    params : sequence
        Parameters of the PrimFunc (vars or buffers).
    partition_const_loop : bool
        When True, LoopPartition runs inside a PassContext that enables
        ``partition_const_loop``; otherwise no extra PassContext is entered.
        Simplify always runs outside that context.

    Returns
    -------
    tvm.tir.Stmt
        Body of ``main`` after LoopPartition and Simplify.
    """
    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc(list(params), stmt))
    if partition_const_loop:
        with tvm.transform.PassContext(config=_PARTITION_CONST_LOOP):
            mod = tvm.tir.transform.LoopPartition()(mod)
    else:
        mod = tvm.tir.transform.LoopPartition()(mod)
    return tvm.tir.transform.Simplify()(mod)["main"].body


def test_basic():
    """Split a symbolic-length loop by 4.

    After partitioning, the first segment has no IfThenElse and the second
    (tail) segment still has one.
    """
    n = te.size_var("n")
    A = te.placeholder((n,), name="A")
    B = te.placeholder((n,), name="B")

    T = te.compute((n,), lambda i: A[i] + B[i])
    s = te.create_schedule(T.op)
    xo, xi = s[T].split(T.op.axis[0], factor=4)

    bounds = tvm.te.schedule.InferBound(s)
    stmt = tvm.te.schedule.ScheduleOps(s, bounds)
    stmt = _partition_and_simplify(stmt, [n])

    assert not _contains(stmt.body.body[0], tvm.tir.IfThenElse)
    assert _contains(stmt.body.body[1], tvm.tir.IfThenElse)


def test_const_loop():
    """With partition_const_loop enabled, a 21-element loop split by 4 ends up with no IfThenElse."""
    n = 21
    A = te.placeholder((n,), name="A")
    B = te.placeholder((n,), name="B")

    T = te.compute((n,), lambda i: A[i] + B[i])
    s = te.create_schedule(T.op)
    xo, xi = s[T].split(T.op.axis[0], factor=4)

    bounds = tvm.te.schedule.InferBound(s)
    stmt = tvm.te.schedule.ScheduleOps(s, bounds)
    stmt = _partition_and_simplify(stmt, [], partition_const_loop=True)

    assert not _contains(stmt, tvm.tir.IfThenElse)


def test_multi_loop():
    """A likely() guard inside a triple loop nest is removed from the first partitioned segment."""
    ib = tvm.tir.ir_builder.create()
    m = te.size_var("m")
    n = te.size_var("n")
    with ib.for_range(0, 4, "i") as i:
        with ib.for_range(0, n, "j") as j:
            with ib.for_range(0, m, "k") as k:
                with ib.if_scope(ib.likely(i * m + j + k < n)):
                    ib.emit(tvm.tir.Evaluate(m))
                with ib.else_scope():
                    ib.emit(tvm.tir.Evaluate(n))
    stmt = ib.get()

    stmt = _partition_and_simplify(stmt, [n, m])

    assert not _contains(stmt.body[0], tvm.tir.IfThenElse)


def test_multi_if():
    """Two likely() guards in the same loop nest are both removed from the first segment."""
    ib = tvm.tir.ir_builder.create()
    m = te.size_var("m")
    n = te.size_var("n")
    with ib.for_range(0, 4, "i") as i:
        with ib.for_range(0, n, "j") as j:
            with ib.for_range(0, m, "k") as k:
                with ib.if_scope(ib.likely(i * m + j + k < n)):
                    ib.emit(tvm.tir.Evaluate(m))
                with ib.else_scope():
                    ib.emit(tvm.tir.Evaluate(n))
                with ib.if_scope(ib.likely(i * m + j - k < n)):
                    ib.emit(tvm.tir.Evaluate(m))
                with ib.else_scope():
                    ib.emit(tvm.tir.Evaluate(n))
    stmt = ib.get()

    stmt = _partition_and_simplify(stmt, [])

    assert not _contains(stmt.body[0], tvm.tir.IfThenElse)


def test_thread_axis():
    """Partitioning still applies when an inner split axis is bound to threadIdx.x."""
    m = te.size_var("m")
    l = te.size_var("l")
    A = te.placeholder((m, l), name="A")
    B = te.compute((m, l), lambda i, j: A[i, j] + 3, name="B")
    s = te.create_schedule(B.op)

    s[B].set_scope("shared")
    num_thread = 16
    xo, xi = s[B].split(B.op.axis[0], 32)
    xi0, xi1 = s[B].split(xi, nparts=num_thread)
    s[B].bind(xi0, te.thread_axis("threadIdx.x"))

    bounds = tvm.te.schedule.InferBound(s)
    stmt = tvm.te.schedule.ScheduleOps(s, bounds)
    stmt = _partition_and_simplify(stmt, [])

    assert not _contains(stmt.body.body[0], tvm.tir.IfThenElse)


def test_vectorize():
    """Check the partitioned condition and the vectorized then-branch.

    The condition does not reference the vectorized var, and the then-branch
    contains a Ramp (vector) expression.
    """
    n = te.size_var("n")
    A = te.placeholder((n,), name="A")
    B = te.placeholder((n,), name="B")
    bias = te.size_var("bias", dtype="float32")
    scale = te.size_var("scale", dtype="float32")
    C = te.compute(A.shape, lambda *i: A(*i) + B(*i) * scale + bias, name="C")
    # schedule
    s = te.create_schedule(C.op)
    # create iter vars and assign them tags.
    num_thread = 32
    bx, x = s[C].split(C.op.axis[0], factor=num_thread * 4)
    tx, x = s[C].split(x, nparts=num_thread)
    _, x = s[C].split(x, factor=4)
    s[C].bind(bx, te.thread_axis("blockIdx.x"))
    s[C].bind(tx, te.thread_axis("threadIdx.x"))
    s[C].vectorize(x)
    stmt = tvm.lower(s, [A, B], name="main")["main"].body
    body = stmt.body.body.body.body
    assert x.var.name not in str(body.condition)
    assert _contains(body.then_case, tvm.tir.Ramp)


def test_condition():
    """A likely() condition inside a Select is partitioned: the first segment has no Select."""
    ib = tvm.tir.ir_builder.create()
    m = te.size_var("m")
    n = te.size_var("n")
    with ib.for_range(0, tvm.tir.truncdiv(n + 3, 4), "i") as i:
        with ib.for_range(0, 4, "j") as j:
            ib.emit(tvm.tir.Evaluate(tvm.tir.Select(ib.likely(i * 4 + j < n), m, n)))
    stmt = ib.get()

    stmt = _partition_and_simplify(stmt, [m, n])

    assert not _contains(stmt[0], tvm.tir.Select)


def test_condition_EQ():
    """An equality likely() condition (i == 5) on a constant loop is removed from the first segment."""
    ib = tvm.tir.ir_builder.create()
    m = te.size_var("m")
    n = te.size_var("n")
    with ib.for_range(0, 10, "i") as i:
        ib.emit(tvm.tir.Evaluate(tvm.tir.Select(ib.likely(tvm.tir.EQ(i, 5)), m, n)))
    stmt = ib.get()

    stmt = _partition_and_simplify(stmt, [m, n], partition_const_loop=True)

    assert not _contains(stmt[0], tvm.tir.Select)


def test_thread_axis2():
    """After lowering, the extent of the first inner loop does not reference threadIdx."""
    n = tvm.runtime.convert(4096)
    m = te.size_var("m")
    A = te.placeholder((n,), name="A")
    B = te.placeholder((n,), name="B")
    C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")
    s = te.create_schedule(C.op)
    num_thread = 32
    bx, x = s[C].split(C.op.axis[0], factor=32)
    tx, x = s[C].split(x, nparts=num_thread)
    _, x = s[C].split(x, factor=m)
    s[C].bind(bx, te.thread_axis("blockIdx.x"))
    s[C].bind(tx, te.thread_axis("threadIdx.x"))
    stmt = tvm.lower(s, [A, B], name="main")["main"].body
    for_body = stmt.body.body.body.body[0]
    assert "threadIdx" not in str(for_body.extent)


def test_everything_during_deduction():
    """A guard whose deduced range is "everything" is kept as an IfThenElse."""
    m = te.size_var("m")
    n = te.size_var("n")
    ib = tvm.tir.ir_builder.create()
    with ib.for_range(0, n, "i") as i:
        with ib.for_range(0, 32, "j") as j:
            with ib.if_scope(ib.likely(tvm.tir.truncdiv(i, j) < m)):
                # this guard will produce everything during deduction
                ib.emit(tvm.tir.Evaluate(m))
    stmt = ib.get()

    stmt = _partition_and_simplify(stmt, [m, n])

    assert isinstance(stmt.body.body, tvm.tir.IfThenElse)


def test_single_likely():
    """A constant 60-element loop split by 16 has no IfThenElse after const-loop partitioning."""
    n = 60
    A = te.placeholder((n,), name="A")
    B = te.placeholder((n,), name="B")

    T = te.compute((n,), lambda i: A[i] + B[i])
    s = te.create_schedule(T.op)
    x = T.op.axis[0]
    xo, xi = s[T].split(x, factor=16)

    bounds = tvm.te.schedule.InferBound(s)
    stmt = tvm.te.schedule.ScheduleOps(s, bounds)
    stmt = _partition_and_simplify(stmt, [], partition_const_loop=True)

    assert not _contains(stmt, tvm.tir.IfThenElse)


def test_multi_likely():
    """Partition a 94x62 compute with both axes split by 16 and reordered.

    After const-loop partitioning, no IfThenElse remains.
    """
    n = 94
    m = 62
    A = te.placeholder((n, m), name="A")
    B = te.placeholder((n, m), name="B")

    T = te.compute((n, m), lambda i, j: A[i, j] + B[i, j])
    s = te.create_schedule(T.op)
    x, y = T.op.axis
    xo, xi = s[T].split(x, factor=16)
    yo, yi = s[T].split(y, factor=16)
    s[T].reorder(xo, yo, xi, yi)

    bounds = tvm.te.schedule.InferBound(s)
    stmt = tvm.te.schedule.ScheduleOps(s, bounds)
    stmt = _partition_and_simplify(stmt, [], partition_const_loop=True)

    assert not _contains(stmt, tvm.tir.IfThenElse)


def test_oneD_pool():
    """Fully partition three 1-D pooling-style loop nests.

    Each nest has nested likely() guards on ``ow`` and ``kw``; after
    const-loop partitioning no IfThenElse remains.
    """
    m = te.size_var("m")
    ib = tvm.tir.ir_builder.create()
    data = ib.pointer("float32", name="data")
    out = ib.pointer("float32", name="out")
    with ib.for_range(0, 16, "ow") as ow:
        with ib.for_range(0, 3, "kw") as kw:
            with ib.if_scope(ib.likely(ow > 0)):
                with ib.if_scope(ib.likely(ow < 15)):
                    out[ow] = tvm.te.max(out[ow], data[ow + kw - 1])
    with ib.for_range(0, 16, "ow") as ow:
        with ib.for_range(0, 3, "kw") as kw:
            with ib.if_scope(ib.likely(ow < 1)):
                with ib.if_scope(ib.likely(kw > 0)):
                    out[ow] = tvm.te.max(out[ow], data[ow + kw - 1])
    with ib.for_range(0, 16, "ow") as ow:
        with ib.for_range(0, 3, "kw") as kw:
            with ib.if_scope(ib.likely(ow > 14)):
                with ib.if_scope(ib.likely(kw < 2)):
                    out[ow] = tvm.te.max(out[ow], data[ow + kw - 1])

    stmt = ib.get()

    stmt = _partition_and_simplify(stmt, [m, data, out], partition_const_loop=True)

    assert not _contains(stmt, tvm.tir.IfThenElse)


def test_cce_loop_1():
    """Partition a stencil-style store over an 11 x 160 loop nest.

    The store is guarded by ``i * 160 + j < 1600``; after const-loop
    partitioning no IfThenElse remains.
    """
    ib = tvm.tir.ir_builder.create()
    dtype = "float16"
    n = 514
    m = 514
    Ab = tvm.tir.decl_buffer((n * m,), dtype, name="A")
    A = ib.buffer_ptr(Ab)
    Bb = tvm.tir.decl_buffer((n * m,), dtype, name="B")
    B = ib.buffer_ptr(Bb)
    # 11 * 160 = 1760 flattened iterations, of which only indices < 1600 pass the guard.
    with ib.for_range(0, 11, name="i") as i:
        with ib.for_range(0, 160, name="j") as j:
            with ib.if_scope(ib.likely(((i * 160) + j) < 1600)):
                A[(i + 1) * m + j + 1] = (
                    B[(i) * m + j + 1] + B[(i + 1) * m + j + 1] + B[(i + 2) * m + j + 1]
                )
    stmt = ib.get()

    stmt = _partition_and_simplify(stmt, [Ab, Bb], partition_const_loop=True)

    assert not _contains(stmt, tvm.tir.IfThenElse)


def test_cce_loop_2():
    """Partition a tiled loop whose last tile is clamped by an if/else on ``head + tile > length``.

    After const-loop partitioning no IfThenElse remains.
    """
    ib = tvm.tir.ir_builder.create()
    length = 112
    tile = 32
    loop = (length + tile - 1) // tile
    with ib.for_range(0, loop, "i") as i:
        head = i * tile
        with ib.if_scope(ib.likely(head + tile > length)):
            tail = length
            ib.emit(tvm.tir.call_extern("float32", "cce_intrisic", head, tail))
        with ib.else_scope():
            tail = head + tile
            ib.emit(tvm.tir.call_extern("float32", "cce_intrisic", head, tail))

    stmt = ib.get()

    stmt = _partition_and_simplify(stmt, [], partition_const_loop=True)

    assert not _contains(stmt, tvm.tir.IfThenElse)


def test_cce_loop_3():
    """Partition a 9998 x 4 loop nest guarded by ``i * 4 + j < 39991``.

    After const-loop partitioning no IfThenElse remains.
    """
    ib = tvm.tir.ir_builder.create()
    loop1 = 4
    loop2 = 9998
    tile = 39991
    with ib.for_range(0, loop2, "i") as i:
        with ib.for_range(0, loop1, "j") as j:
            head1 = i
            head2 = j
            with ib.if_scope(ib.likely(head1 * loop1 + head2 < tile)):
                ib.emit(tvm.tir.call_extern("float16", "cce_intrisic", head1))

    stmt = ib.get()

    stmt = _partition_and_simplify(stmt, [], partition_const_loop=True)

    assert not _contains(stmt, tvm.tir.IfThenElse)


def test_conv_tiling():
    """Tile the 62x62 output of a 3x3 conv2d by 16x16.

    After const-loop partitioning no IfThenElse remains.
    """
    HSTR = WSTR = 1
    in_channel = 128
    kernel_height = kernel_width = 3
    out_channel = 64
    batch_size = 1
    in_height = in_width = 64
    out_height = out_width = in_height - kernel_height + 1
    data = te.placeholder((batch_size, in_channel, in_height, in_width), name="data")
    kernel = te.placeholder((kernel_height, kernel_width, in_channel, out_channel), name="kernel")
    ic = te.reduce_axis((0, in_channel), name="ic")
    kh = te.reduce_axis((0, kernel_height), name="kh")
    kw = te.reduce_axis((0, kernel_width), name="kw")
    conv = te.compute(
        (batch_size, out_channel, out_height, out_width),
        lambda n, oc, oh, ow: te.sum(
            data[n, ic, oh * HSTR + kh, ow * WSTR + kw] * kernel[kh, kw, ic, oc], axis=[ic, kh, kw]
        ),
        name="conv2d",
    )
    s = te.create_schedule(conv.op)

    n, oc, oh, ow = conv.op.axis
    oho, owo, ohi, owi = s[conv].tile(oh, ow, 16, 16)
    bounds = tvm.te.schedule.InferBound(s)
    stmt = tvm.te.schedule.ScheduleOps(s, bounds)

    stmt = _partition_and_simplify(stmt, [], partition_const_loop=True)

    assert not _contains(stmt, tvm.tir.IfThenElse)


def test_multilevel_splitting_with_indivisble_factors():
    """Split a 130-element relu by 8 and then by 16, and unroll the innermost loop.

    Lowering with const-loop partitioning must yield exactly 10 Max nodes.
    """
    from tvm import topi

    A = te.placeholder((130,), dtype="float32")
    B = topi.nn.relu(A)
    s = te.create_schedule(B.op)
    (y,) = s[B].op.axis
    (yo, yi) = s[B].split(y, factor=8)
    (yoo, yoi) = s[B].split(yo, factor=16)
    s[B].reorder(yoo, yoi, yi)
    s[B].unroll(yi)

    # Partitioning is expected to separate the full tiles from the tail. Presumably the
    # 10 Max nodes are 8 in the unrolled main body plus 2 in the unrolled 2-element tail
    # (130 = 16 * 8 + 2). NOTE(review): confirm against the pass's output.
    with tvm.transform.PassContext(config=_PARTITION_CONST_LOOP):
        lowered_body = tvm.lower(s, [A, B], name="x")["x"].body

    def visit_stmt(op):
        return isinstance(op, tvm.tir.Max)

    num_max = collect_visit(lowered_body, visit_stmt)
    assert num_max.count(True) == 10


def test_double_splitting_with_indivisible_factors():
    """Split C by 10 and D by 32 (48 elements), with C computed at D's outer loop.

    After const-loop partitioning the lowered function has no IfThenElse, and
    the built function still copies A into both C and D.
    """
    m = 48
    dtype = "float32"
    A = te.placeholder((m,), name="A", dtype=dtype)
    C = te.compute((m,), lambda i: A[i], name="C")
    D = te.compute((m,), lambda i: C[i], name="D")

    s = te.create_schedule(D.op)
    co, ci = s[C].split(C.op.axis[0], factor=10)
    do, di = s[D].split(D.op.axis[0], 32)
    s[C].compute_at(s[D], do)

    target = "llvm"
    with tvm.transform.PassContext(config=_PARTITION_CONST_LOOP):
        f = tvm.lower(s, [A, C, D], name="fadd1", simple_mode=False)
        func = tvm.build(f, target=target)

    top_produce = f["fadd1"].body
    assert not _contains(top_produce, tvm.tir.IfThenElse)

    # check functional correctness of generated code
    ctx = tvm.context(target, 0)
    a = tvm.nd.array(numpy.ones(m).astype(dtype), ctx)
    c = tvm.nd.array(numpy.zeros(m).astype(dtype), ctx)
    d = tvm.nd.array(numpy.zeros(m).astype(dtype), ctx)
    func(a, c, d)
    tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy(), rtol=1e-5)
    tvm.testing.assert_allclose(d.asnumpy(), a.asnumpy(), rtol=1e-5)


def test_simple_rfactor():
    """Check that LoopPartition changes an rfactor-ed reduction.

    The reduction is over 68 elements (split by 16, so it has a tail).
    With const-loop partitioning, the result differs structurally from the
    merely simplified statement.
    """
    K = 16 * 4 + 4
    k = te.reduce_axis((0, K), "k")

    A = te.placeholder((1, K), name="A")

    B = te.compute((1,), lambda b: te.sum(A[b, k], axis=k), name="B")

    s = te.create_schedule(B.op)
    ko, _ = s[B].split(s[B].op.reduce_axis[0], 16)
    # rfactor rewrites the schedule in place; the returned tensor is not needed here.
    s.rfactor(B, ko, 0)

    s.normalize()
    bounds = tvm.te.schedule.InferBound(s)
    stmt1 = tvm.te.schedule.ScheduleOps(s, bounds)

    mod1 = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt1))
    stmt1 = tvm.tir.transform.Simplify()(mod1)["main"].body

    with tvm.transform.PassContext(config=_PARTITION_CONST_LOOP):
        mod2 = tvm.tir.transform.LoopPartition()(mod1)
        stmt2 = tvm.tir.transform.Simplify()(mod2)["main"].body

    # make sure loop partition actually did something
    assert not tvm.ir.structural_equal(stmt1.body, stmt2.body)


if __name__ == "__main__":
    test_basic()
    test_const_loop()
    test_multi_loop()
    test_multi_if()
    test_thread_axis()
    test_vectorize()
    test_condition()
    test_condition_EQ()
    test_thread_axis2()
    test_everything_during_deduction()
    test_single_likely()
    test_multi_likely()
    test_oneD_pool()
    test_cce_loop_1()
    test_cce_loop_2()
    test_cce_loop_3()
    test_conv_tiling()
    test_double_splitting_with_indivisible_factors()
    test_multilevel_splitting_with_indivisble_factors()
    test_simple_rfactor()