1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17import tvm
18import tvm.testing
19from tvm import te
20import numpy
21
22
def collect_visit(stmt, f):
    """Apply ``f`` to every node of ``stmt`` (post-order) and return the list of results."""
    results = []

    def _record(node):
        results.append(f(node))

    tvm.tir.stmt_functor.post_order_visit(stmt, _record)
    return results
27
28
def test_basic():
    """Symbolic-extent loop split by 4: after partitioning, the first part is
    free of if-guards while the second (tail) part keeps one."""
    n = te.size_var("n")
    A = te.placeholder((n,), name="A")
    B = te.placeholder((n,), name="B")
    T = te.compute((n,), lambda i: A[i] + B[i])

    sch = te.create_schedule(T.op)
    sch[T].split(T.op.axis[0], factor=4)

    bounds = tvm.te.schedule.InferBound(sch)
    stmt = tvm.te.schedule.ScheduleOps(sch, bounds)

    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt))
    mod = tvm.tir.transform.LoopPartition()(mod)
    stmt = tvm.tir.transform.Simplify()(mod)["main"].body

    is_ite = lambda node: isinstance(node, tvm.tir.IfThenElse)
    # main chunk: guard-free; tail chunk: guard survives
    assert not any(collect_visit(stmt.body.body[0], is_ite))
    assert any(collect_visit(stmt.body.body[1], is_ite))
47
48
def test_const_loop():
    """Constant-extent loop (21) split by 4: with partition_const_loop enabled
    the partition removes every if-guard."""
    extent = 21
    A = te.placeholder((extent,), name="A")
    B = te.placeholder((extent,), name="B")
    T = te.compute((extent,), lambda i: A[i] + B[i])

    sch = te.create_schedule(T.op)
    sch[T].split(T.op.axis[0], factor=4)

    bounds = tvm.te.schedule.InferBound(sch)
    stmt = tvm.te.schedule.ScheduleOps(sch, bounds)

    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
    with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}):
        mod = tvm.tir.transform.LoopPartition()(mod)
        stmt = tvm.tir.transform.Simplify()(mod)["main"].body

    assert not any(collect_visit(stmt, lambda node: isinstance(node, tvm.tir.IfThenElse)))
67
68
def test_multi_loop():
    """Triple nested loop with a single likely-guarded if/else: the first
    partitioned section contains no if_then_else."""
    builder = tvm.tir.ir_builder.create()
    m = te.size_var("m")
    n = te.size_var("n")
    with builder.for_range(0, 4, "i") as i:
        with builder.for_range(0, n, "j") as j:
            with builder.for_range(0, m, "k") as k:
                with builder.if_scope(builder.likely(i * m + j + k < n)):
                    builder.emit(tvm.tir.Evaluate(m))
                with builder.else_scope():
                    builder.emit(tvm.tir.Evaluate(n))
    stmt = builder.get()

    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n, m], stmt))
    mod = tvm.tir.transform.LoopPartition()(mod)
    stmt = tvm.tir.transform.Simplify()(mod)["main"].body

    assert not any(collect_visit(stmt.body[0], lambda node: isinstance(node, tvm.tir.IfThenElse)))
87
88
def test_multi_if():
    """Two likely-guarded if/else pairs in the same innermost loop body: the
    first partitioned section contains no if_then_else."""
    builder = tvm.tir.ir_builder.create()
    m = te.size_var("m")
    n = te.size_var("n")
    with builder.for_range(0, 4, "i") as i:
        with builder.for_range(0, n, "j") as j:
            with builder.for_range(0, m, "k") as k:
                with builder.if_scope(builder.likely(i * m + j + k < n)):
                    builder.emit(tvm.tir.Evaluate(m))
                with builder.else_scope():
                    builder.emit(tvm.tir.Evaluate(n))
                with builder.if_scope(builder.likely(i * m + j - k < n)):
                    builder.emit(tvm.tir.Evaluate(m))
                with builder.else_scope():
                    builder.emit(tvm.tir.Evaluate(n))
    stmt = builder.get()

    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
    mod = tvm.tir.transform.LoopPartition()(mod)
    stmt = tvm.tir.transform.Simplify()(mod)["main"].body

    assert not any(collect_visit(stmt.body[0], lambda node: isinstance(node, tvm.tir.IfThenElse)))
111
112
def test_thread_axis():
    """Partitioning with one split axis bound to threadIdx.x: the first
    partitioned section is guard-free."""
    m = te.size_var("m")
    l = te.size_var("l")
    A = te.placeholder((m, l), name="A")
    B = te.compute((m, l), lambda i, j: A[i, j] + 3, name="B")

    sch = te.create_schedule(B.op)
    sch[B].set_scope("shared")
    num_thread = 16
    _, inner = sch[B].split(B.op.axis[0], 32)
    inner_outer, _ = sch[B].split(inner, nparts=num_thread)
    sch[B].bind(inner_outer, te.thread_axis("threadIdx.x"))

    bounds = tvm.te.schedule.InferBound(sch)
    stmt = tvm.te.schedule.ScheduleOps(sch, bounds)

    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
    mod = tvm.tir.transform.LoopPartition()(mod)
    stmt = tvm.tir.transform.Simplify()(mod)["main"].body

    assert not any(
        collect_visit(stmt.body.body[0], lambda node: isinstance(node, tvm.tir.IfThenElse))
    )
134
135
def test_vectorize():
    """The vectorized innermost axis must not leak into the partition
    condition, and the then-branch must stay vectorized (contain a Ramp)."""
    n = te.size_var("n")
    A = te.placeholder((n,), name="A")
    B = te.placeholder((n,), name="B")
    bias = te.size_var("bias", dtype="float32")
    scale = te.size_var("scale", dtype="float32")
    C = te.compute(A.shape, lambda *i: A(*i) + B(*i) * scale + bias, name="C")

    # Schedule: block/thread split, then vectorize the innermost factor-4 axis.
    sch = te.create_schedule(C.op)
    num_thread = 32
    bx, x = sch[C].split(C.op.axis[0], factor=num_thread * 4)
    tx, x = sch[C].split(x, nparts=num_thread)
    _, x = sch[C].split(x, factor=4)
    sch[C].bind(bx, te.thread_axis("blockIdx.x"))
    sch[C].bind(tx, te.thread_axis("threadIdx.x"))
    sch[C].vectorize(x)

    stmt = tvm.lower(sch, [A, B], name="main")["main"].body
    body = stmt.body.body.body.body
    # The vectorized loop variable must not appear in the branch condition.
    assert x.var.name not in str(body.condition)
    assert any(collect_visit(body.then_case, lambda node: isinstance(node, tvm.tir.Ramp)))
157
158
def test_condition():
    """likely() used inside a Select expression: partitioning removes the
    Select from the first generated section."""
    builder = tvm.tir.ir_builder.create()
    m = te.size_var("m")
    n = te.size_var("n")
    with builder.for_range(0, tvm.tir.truncdiv(n + 3, 4), "i") as i:
        with builder.for_range(0, 4, "j") as j:
            builder.emit(tvm.tir.Evaluate(tvm.tir.Select(builder.likely(i * 4 + j < n), m, n)))
    stmt = builder.get()

    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([m, n], stmt))
    mod = tvm.tir.transform.LoopPartition()(mod)
    stmt = tvm.tir.transform.Simplify()(mod)["main"].body

    assert not any(collect_visit(stmt[0], lambda node: isinstance(node, tvm.tir.Select)))
173
174
def test_condition_EQ():
    """likely(i == 5) inside a Select over a constant loop: with
    partition_const_loop the Select is removed from the first section."""
    builder = tvm.tir.ir_builder.create()
    m = te.size_var("m")
    n = te.size_var("n")
    with builder.for_range(0, 10, "i") as i:
        builder.emit(tvm.tir.Evaluate(tvm.tir.Select(builder.likely(tvm.tir.EQ(i, 5)), m, n)))
    stmt = builder.get()

    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([m, n], stmt))
    with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}):
        mod = tvm.tir.transform.LoopPartition()(mod)
        stmt = tvm.tir.transform.Simplify()(mod)["main"].body

    assert not any(collect_visit(stmt[0], lambda node: isinstance(node, tvm.tir.Select)))
189
190
def test_thread_axis2():
    """With a symbolic factor on the innermost split, the partitioned main
    loop's extent must not reference threadIdx."""
    n = tvm.runtime.convert(4096)
    m = te.size_var("m")
    A = te.placeholder((n,), name="A")
    B = te.placeholder((n,), name="B")
    C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")

    sch = te.create_schedule(C.op)
    num_thread = 32
    bx, x = sch[C].split(C.op.axis[0], factor=32)
    tx, x = sch[C].split(x, nparts=num_thread)
    _, x = sch[C].split(x, factor=m)
    sch[C].bind(bx, te.thread_axis("blockIdx.x"))
    sch[C].bind(tx, te.thread_axis("threadIdx.x"))

    stmt = tvm.lower(sch, [A, B], name="main")["main"].body
    for_body = stmt.body.body.body.body[0]
    assert "threadIdx" not in str(for_body.extent)
207
208
def test_everything_during_deduction():
    """A guard whose deduction covers the whole iteration space (truncdiv by
    the loop variable) cannot be partitioned away: the if must survive."""
    m = te.size_var("m")
    n = te.size_var("n")
    builder = tvm.tir.ir_builder.create()
    with builder.for_range(0, n, "i") as i:
        with builder.for_range(0, 32, "j") as j:
            # this guard will produce everything during deduction
            with builder.if_scope(builder.likely(tvm.tir.truncdiv(i, j) < m)):
                builder.emit(tvm.tir.Evaluate(m))
    stmt = builder.get()

    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([m, n], stmt))
    mod = tvm.tir.transform.LoopPartition()(mod)
    stmt = tvm.tir.transform.Simplify()(mod)["main"].body

    assert isinstance(stmt.body.body, tvm.tir.IfThenElse)
225
226
def test_single_likely():
    """Single indivisible split (60 by 16): with partition_const_loop the
    partition removes every if-guard."""
    size = 60
    A = te.placeholder((size,), name="A")
    B = te.placeholder((size,), name="B")
    T = te.compute((size,), lambda i: A[i] + B[i])

    sch = te.create_schedule(T.op)
    sch[T].split(T.op.axis[0], factor=16)

    bounds = tvm.te.schedule.InferBound(sch)
    stmt = tvm.te.schedule.ScheduleOps(sch, bounds)
    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))

    with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}):
        mod = tvm.tir.transform.LoopPartition()(mod)
        stmt = tvm.tir.transform.Simplify()(mod)["main"].body

    assert not any(collect_visit(stmt, lambda node: isinstance(node, tvm.tir.IfThenElse)))
247
248
def test_multi_likely():
    """Two indivisible tiled splits (94x62 tiled 16x16): with
    partition_const_loop the partition removes every if-guard.

    Fix: dropped the InferBound/ScheduleOps pair that was computed before the
    splits were applied — its result was immediately discarded and recomputed.
    """
    n = 94
    m = 62
    A = te.placeholder((n, m), name="A")
    B = te.placeholder((n, m), name="B")

    T = te.compute((n, m), lambda i, j: A[i, j] + B[i, j])
    s = te.create_schedule(T.op)
    x, y = T.op.axis
    xo, xi = s[T].split(x, factor=16)
    yo, yi = s[T].split(y, factor=16)
    s[T].reorder(xo, yo, xi, yi)

    bounds = tvm.te.schedule.InferBound(s)
    stmt = tvm.te.schedule.ScheduleOps(s, bounds)

    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))

    with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}):
        mod = tvm.tir.transform.LoopPartition()(mod)
        stmt = tvm.tir.transform.Simplify()(mod)["main"].body

    assert not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse)))
274
275
def test_oneD_pool():
    """1-D max-pool written as three region loops with likely guards
    (interior, left edge, right edge): all guards are eliminated.

    Fix: the output pointer was also created with name="A", colliding with the
    input buffer's name in the printed IR; it now gets a distinct name. Also
    removed a commented-out placeholder line.
    """
    m = te.size_var("m")
    ib = tvm.tir.ir_builder.create()
    data = ib.pointer("float32", name="A")
    out = ib.pointer("float32", name="out")
    with ib.for_range(0, 16, "ow") as ow:
        with ib.for_range(0, 3, "kw") as kw:
            with ib.if_scope(ib.likely(ow > 0)):
                with ib.if_scope(ib.likely(ow < 15)):
                    out[ow] = tvm.te.max(out[ow], data[ow + kw - 1])
    with ib.for_range(0, 16, "ow") as ow:
        with ib.for_range(0, 3, "kw") as kw:
            with ib.if_scope(ib.likely(ow < 1)):
                with ib.if_scope(ib.likely(kw > 0)):
                    out[ow] = tvm.te.max(out[ow], data[ow + kw - 1])
    with ib.for_range(0, 16, "ow") as ow:
        with ib.for_range(0, 3, "kw") as kw:
            with ib.if_scope(ib.likely(ow > 14)):
                with ib.if_scope(ib.likely(kw < 2)):
                    out[ow] = tvm.te.max(out[ow], data[ow + kw - 1])

    stmt = ib.get()

    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([m, data, out], stmt))

    with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}):
        mod = tvm.tir.transform.LoopPartition()(mod)
        stmt = tvm.tir.transform.Simplify()(mod)["main"].body

    assert not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse)))
307
308
def test_cce_loop_1():
    """Likely guard over a flattened constant range (i*160 + j < 1600): with
    partition_const_loop the partition removes every if-guard.

    Fix: removed the unused `_A`/`_B` te.placeholder locals — only the
    declared buffers are used.
    """
    ib = tvm.tir.ir_builder.create()
    dtype = "float16"
    n = 514
    m = 514
    Ab = tvm.tir.decl_buffer((n * m,), dtype, name="A")
    A = ib.buffer_ptr(Ab)
    Bb = tvm.tir.decl_buffer((n * m,), dtype, name="B")
    B = ib.buffer_ptr(Bb)
    # for i in 0 to n-1:
    with ib.for_range(0, 11, name="i") as i:
        with ib.for_range(0, 160, name="j") as j:
            with ib.if_scope(ib.likely(((i * 160) + j) < 1600)):
                A[(i + 1) * m + j + 1] = (
                    B[(i) * m + j + 1] + B[(i + 1) * m + j + 1] + B[(i + 2) * m + j + 1]
                )
    stmt = ib.get()

    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
    with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}):
        mod = tvm.tir.transform.LoopPartition()(mod)
        stmt = tvm.tir.transform.Simplify()(mod)["main"].body

    assert not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse)))
335
336
def test_cce_loop_2():
    """Head/tail split of 112 by tile 32: the trailing likely-guard is fully
    resolved by partitioning.

    Fix: renamed the local `len` to `total` so it no longer shadows the
    builtin `len`.
    """
    ib = tvm.tir.ir_builder.create()
    total = 112
    tile = 32
    loop = (total + tile - 1) // tile
    with ib.for_range(0, loop, "i") as i:
        head = i * tile
        with ib.if_scope(ib.likely(head + tile > total)):
            # last (partial) tile
            tail = total
            ib.emit(tvm.tir.call_extern("float32", "cce_intrisic", head, tail))
        with ib.else_scope():
            # full tile
            tail = head + tile
            ib.emit(tvm.tir.call_extern("float32", "cce_intrisic", head, tail))

    stmt = ib.get()

    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
    with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}):
        mod = tvm.tir.transform.LoopPartition()(mod)
        stmt = tvm.tir.transform.Simplify()(mod)["main"].body

    assert not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse)))
359
360
def test_cce_loop_3():
    """Nested constant loops with a combined likely guard (i*4 + j < 39991):
    with partition_const_loop the partition removes every if-guard."""
    builder = tvm.tir.ir_builder.create()
    inner_extent = 4
    outer_extent = 9998
    limit = 39991
    with builder.for_range(0, outer_extent, "i") as i:
        with builder.for_range(0, inner_extent, "j") as j:
            with builder.if_scope(builder.likely(i * inner_extent + j < limit)):
                builder.emit(tvm.tir.call_extern("float16", "cce_intrisic", i))

    stmt = builder.get()
    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))

    with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}):
        mod = tvm.tir.transform.LoopPartition()(mod)
        stmt = tvm.tir.transform.Simplify()(mod)["main"].body

    assert not any(collect_visit(stmt, lambda node: isinstance(node, tvm.tir.IfThenElse)))
381
382
def test_conv_tiling():
    """Tiling a conv2d output (62x62) by 16x16 — an indivisible tiling — with
    partition_const_loop removes every if-guard."""
    HSTR = WSTR = 1
    in_channel = 128
    kernel_height = kernel_width = 3
    out_channel = 64
    batch_size = 1
    in_height = in_width = 64
    out_height = out_width = in_height - kernel_height + 1

    data = te.placeholder((batch_size, in_channel, in_height, in_width), name="data")
    kernel = te.placeholder((kernel_height, kernel_width, in_channel, out_channel), name="kernel")
    ic = te.reduce_axis((0, in_channel), name="ic")
    kh = te.reduce_axis((0, kernel_height), name="kh")
    kw = te.reduce_axis((0, kernel_width), name="kw")
    conv = te.compute(
        (batch_size, out_channel, out_height, out_width),
        lambda n, oc, oh, ow: te.sum(
            data[n, ic, oh * HSTR + kh, ow * WSTR + kw] * kernel[kh, kw, ic, oc], axis=[ic, kh, kw]
        ),
        name="conv2d",
    )

    sch = te.create_schedule(conv.op)
    _, _, oh_axis, ow_axis = conv.op.axis
    sch[conv].tile(oh_axis, ow_axis, 16, 16)

    bounds = tvm.te.schedule.InferBound(sch)
    stmt = tvm.te.schedule.ScheduleOps(sch, bounds)

    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
    with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}):
        mod = tvm.tir.transform.LoopPartition()(mod)
        stmt = tvm.tir.transform.Simplify()(mod)["main"].body

    assert not any(collect_visit(stmt, lambda node: isinstance(node, tvm.tir.IfThenElse)))
416
417
def test_multilevel_splitting_with_indivisble_factors():
    """Two-level split of 130 (factors 8 then 16) with the inner axis
    unrolled: after loop partitioning the lowered body contains exactly 10
    Max nodes."""
    from tvm import topi

    A = te.placeholder((130,), dtype="float32")
    B = topi.nn.relu(A)

    sch = te.create_schedule(B.op)
    (y,) = sch[B].op.axis
    yo, yi = sch[B].split(y, factor=8)
    yoo, yoi = sch[B].split(yo, factor=16)
    sch[B].reorder(yoo, yoi, yi)
    sch[B].unroll(yi)

    with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}):
        lowered_body = tvm.lower(sch, [A, B], name="x")["x"].body
        max_hits = collect_visit(lowered_body, lambda op: isinstance(op, tvm.tir.Max))
        assert max_hits.count(True) == 10
439
440
def test_double_splitting_with_indivisible_factors():
    """compute_at combined with two indivisible splits must lower without any
    if_then_else, and the built function must still compute D = C = A."""
    m = 48
    dtype = "float32"
    A = te.placeholder((m,), name="A", dtype=dtype)
    C = te.compute((m,), lambda i: A[i], name="C")
    D = te.compute((m,), lambda i: C[i], name="D")

    s = te.create_schedule(D.op)
    co, ci = s[C].split(C.op.axis[0], factor=10)
    do, di = s[D].split(D.op.axis[0], 32)
    s[C].compute_at(s[D], do)

    target = "llvm"
    with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}):
        f = tvm.lower(s, [A, C, D], name="fadd1", simple_mode=False)
        func = tvm.build(f, target=target)

    # partitioning must have removed every guard from the lowered body
    top_produce = f["fadd1"].body
    assert not any(collect_visit(top_produce, lambda node: isinstance(node, tvm.tir.IfThenElse)))

    # check functional correctness of generated code
    ctx = tvm.context(target, 0)
    a = tvm.nd.array(numpy.ones(m).astype(dtype), ctx)
    c = tvm.nd.array(numpy.zeros(m).astype(dtype), ctx)
    d = tvm.nd.array(numpy.zeros(m).astype(dtype), ctx)
    func(a, c, d)
    tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy(), rtol=1e-5)
    tvm.testing.assert_allclose(d.asnumpy(), a.asnumpy(), rtol=1e-5)
484
485
def test_simple_rfactor():
    """rfactor over an indivisible reduction split (68 by 16): loop partition
    must actually transform the IR, so the partitioned body differs
    structurally from the unpartitioned one."""
    K = 16 * 4 + 4
    k = te.reduce_axis((0, K), "k")
    A = te.placeholder((1, K), name="A")
    B = te.compute((1,), lambda b: te.sum(A[b, k], axis=k), name="B")

    s = te.create_schedule(B.op)
    ko, _ = s[B].split(s[B].op.reduce_axis[0], 16)
    s.rfactor(B, ko, 0)
    s.normalize()

    bounds = tvm.te.schedule.InferBound(s)
    stmt1 = tvm.te.schedule.ScheduleOps(s, bounds)

    mod1 = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt1))
    stmt1 = tvm.tir.transform.Simplify()(mod1)["main"].body

    with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}):
        mod2 = tvm.tir.transform.LoopPartition()(mod1)
        stmt2 = tvm.tir.transform.Simplify()(mod2)["main"].body

    # make sure loop partition actually did something
    assert not tvm.ir.structural_equal(stmt1.body, stmt2.body)
511
512
if __name__ == "__main__":
    # Run every test in the same order as before.
    for _case in (
        test_basic,
        test_const_loop,
        test_multi_loop,
        test_multi_if,
        test_thread_axis,
        test_vectorize,
        test_condition,
        test_condition_EQ,
        test_thread_axis2,
        test_everything_during_deduction,
        test_single_likely,
        test_multi_likely,
        test_oneD_pool,
        test_cce_loop_1,
        test_cce_loop_2,
        test_cce_loop_3,
        test_conv_tiling,
        test_double_splitting_with_indivisible_factors,
        test_multilevel_splitting_with_indivisble_factors,
        test_simple_rfactor,
    ):
        _case()
534