1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17import tvm
18import tvm.testing
19from tvm import te
20
21
def test_copy2d():
    """Check the buffers InjectCopyIntrin reports for a plain 2-D copy.

    Builds ``B[i, j] = A[i, j]`` over a symbolic ``(m, l)`` shape, tags the
    outer axis of ``B`` with the ``memcpy`` pragma, applies ``StorageFlatten``
    and then runs ``InjectCopyIntrin`` with a callback that inspects the
    source and destination buffers.
    """
    m = te.var("m")
    l = te.var("l")
    A = te.placeholder((m, l), name="A")
    B = te.compute((m, l), lambda i, j: A[i, j], name="B")
    s = te.create_schedule(B.op)
    s[B].pragma(B.op.axis[0], "memcpy")
    bounds = tvm.te.schedule.InferBound(s)
    stmt = tvm.te.schedule.ScheduleOps(s, bounds)
    func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)
    mod = tvm.IRModule.from_expr(func)
    mod = tvm.tir.transform.StorageFlatten(64)(mod)

    # Every check lives inside the callback, so record each invocation:
    # without this the test would pass silently if the pass never called it.
    cb_calls = []

    def cb(src, dst, pad_before, pad_after, pad_value):
        cb_calls.append(1)
        assert dst.strides[0] == l
        assert dst.strides[1].value == 1
        assert src.strides[0] == l
        assert tuple(src.shape) == (m, l)
        return tvm.tir.Evaluate(0)

    tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)
    assert cb_calls, "InjectCopyIntrin never invoked the copy callback"
43
44
def test_copy_pad():
    """Check the padding InjectCopyIntrin reports for a padded 2-D copy.

    ``B`` has shape ``(m + 2, l)``: rows ``1 <= i < m + 1`` read
    ``A[i - 1, j]`` and every other row is filled with ``1.0``.  The callback
    checks the source offset, the padding before and after on both axes, and
    the pad value.
    """
    m = te.var("m")
    l = te.var("l")
    A = te.placeholder((m, l), name="A")
    B = te.compute(
        (m + 2, l),
        lambda i, j: tvm.tir.if_then_else(tvm.tir.all(i >= 1, i < m + 1), A[i - 1, j], 1.0),
        name="B",
    )
    s = te.create_schedule(B.op)
    s[B].pragma(B.op.axis[0], "memcpy")
    bounds = tvm.te.schedule.InferBound(s)
    stmt = tvm.te.schedule.ScheduleOps(s, bounds)

    func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)
    mod = tvm.IRModule.from_expr(func)
    mod = tvm.tir.transform.StorageFlatten(64)(mod)

    # Every check lives inside the callback, so record each invocation:
    # without this the test would pass silently if the pass never called it.
    cb_calls = []

    def cb(src, dst, pad_before, pad_after, pad_value):
        cb_calls.append(1)
        tvm.testing.assert_prim_expr_equal(src.elem_offset, 0)
        assert pad_before[0].value == 1
        assert pad_before[1].value == 0
        assert pad_after[0].value == 1
        assert pad_after[1].value == 0
        assert pad_value.value == 1.0
        return tvm.tir.Evaluate(0)

    tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)
    assert cb_calls, "InjectCopyIntrin never invoked the copy callback"
73
74
def test_single_point_test():
    """Check InjectCopyIntrin on a copy between two one-element tensors.

    ``A`` and ``B`` both have shape ``(1,)``.  The callback checks that the
    source and destination buffers have element offset 0 and stride 1.
    """
    A = te.placeholder((1,), name="A")
    B = te.compute((1,), lambda i: A[i], name="B")
    s = te.create_schedule(B.op)
    s[B].pragma(B.op.axis[0], "memcpy")
    bounds = tvm.te.schedule.InferBound(s)
    stmt = tvm.te.schedule.ScheduleOps(s, bounds)

    func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)
    mod = tvm.IRModule.from_expr(func)
    mod = tvm.tir.transform.StorageFlatten(64)(mod)

    # Every check lives inside the callback, so record each invocation:
    # without this the test would pass silently if the pass never called it.
    cb_calls = []

    def cb(src, dst, pad_before, pad_after, pad_value):
        cb_calls.append(1)
        tvm.testing.assert_prim_expr_equal(src.elem_offset, 0)
        tvm.testing.assert_prim_expr_equal(dst.elem_offset, 0)
        tvm.testing.assert_prim_expr_equal(src.strides[0], 1)
        tvm.testing.assert_prim_expr_equal(dst.strides[0], 1)
        return tvm.tir.Evaluate(0)

    tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)
    assert cb_calls, "InjectCopyIntrin never invoked the copy callback"
95
96
def test_copy_pad_split():
    """Check InjectCopyIntrin on a padded copy computed inside a split loop.

    ``Apad`` is ``A`` (length 12) with one zero element on each side, and
    ``B[i]`` sums three consecutive ``Apad`` elements.  ``B`` is split by a
    factor of 4 and ``Apad``, tagged with the ``memcpy`` pragma, is computed
    at the outer loop ``xo``.  The callback checks the offsets, padding and
    source extent against expressions written in terms of ``xo``.
    """
    m = 4 * 3
    A = te.placeholder((m,), name="A")
    Apad = te.compute(
        (m + 2,), lambda i: tvm.tir.if_then_else(tvm.tir.all(i >= 1, i <= m), A[i - 1], 0.0), "Apad"
    )
    B = te.compute((m,), lambda i: Apad[i] + Apad[i + 1] + Apad[i + 2])
    s = te.create_schedule(B.op)
    xo, xi = s[B].split(B.op.axis[0], factor=4)
    s[Apad].compute_at(s[B], xo)
    s[Apad].pragma(s[Apad].op.axis[0], "memcpy")
    bounds = tvm.te.schedule.InferBound(s)
    stmt = tvm.te.schedule.ScheduleOps(s, bounds)

    func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)
    mod = tvm.IRModule.from_expr(func)
    mod = tvm.tir.transform.StorageFlatten(64)(mod._move())
    mod = tvm.tir.transform.Simplify()(mod._move())

    # Every check lives inside the callback, so record each invocation:
    # without this the test would pass silently if the pass never called it.
    cb_calls = []

    def cb(src, dst, pad_before, pad_after, pad_value):
        cb_calls.append(1)
        assert dst.elem_offset.value == 0
        tvm.testing.assert_prim_expr_equal(src.elem_offset, tvm.te.max(xo * 4, 1) - 1)

        # Expected padding on each side, as a function of the outer index xo.
        rpad_before = tvm.te.max(1 - xo * 4, 0)
        rpad_after = tvm.te.max(xo * 4 - 7, 0)
        tvm.testing.assert_prim_expr_equal(pad_before[0], rpad_before)
        tvm.testing.assert_prim_expr_equal(pad_after[0], rpad_after)
        tvm.testing.assert_prim_expr_equal(src.shape[0], 6 - rpad_before - rpad_after)
        return tvm.tir.Evaluate(0)

    tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)
    assert cb_calls, "InjectCopyIntrin never invoked the copy callback"
128
129
if __name__ == "__main__":
    # Run each test in turn when this file is executed directly as a script.
    for _run_test in (
        test_copy2d,
        test_copy_pad,
        test_copy_pad_split,
        test_single_point_test,
    ):
        _run_test()
135