1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17import tvm
18
def test_copy2d():
    """InjectCopyIntrin on a plain 2-D copy ``B[i, j] = A[i, j]``.

    The ``memcpy`` pragma is placed on B's outer axis. The callback must see
    destination strides ``(l, 1)``, source row stride ``l`` and the full
    ``(m, l)`` source shape.
    """
    m = tvm.var('m')
    l = tvm.var('l')
    A = tvm.placeholder((m, l), name='A')
    B = tvm.compute((m, l), lambda i, j: A[i, j], name='B')
    s = tvm.create_schedule(B.op)
    s[B].pragma(B.op.axis[0], "memcpy")
    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)
    Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
    Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
    # The checks below only run if the pass actually calls ``cb``; record
    # each call so the test cannot pass without verifying anything.
    calls = []

    def cb(src, dst, pad_before, pad_after, pad_value):
        calls.append(1)
        assert dst.strides[0] == l
        assert dst.strides[1].value == 1
        assert src.strides[0] == l
        assert tuple(src.shape) == (m, l)
        return tvm.make.Evaluate(0)
    stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
    assert calls, "InjectCopyIntrin never invoked the memcpy callback"
38
def test_copy_pad():
    """InjectCopyIntrin on a copy that pads one row above and one below.

    ``B`` has shape ``(m + 2, l)``. Rows ``1 .. m`` are taken from ``A`` and
    the remaining rows are filled with ``1.0``. The callback must see:

    * a source starting at element offset 0;
    * one row of padding before and after along axis 0, and none along axis 1;
    * a pad value of ``1.0``.
    """
    m = tvm.var('m')
    l = tvm.var('l')
    A = tvm.placeholder((m, l), name='A')
    B = tvm.compute((m + 2, l), lambda i, j:
                    tvm.if_then_else(tvm.all(i >= 1, i < m + 1),
                                     A[i - 1, j], 1.0), name='B')
    s = tvm.create_schedule(B.op)
    s[B].pragma(B.op.axis[0], "memcpy")
    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)
    Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
    Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
    # The checks below only run if the pass actually calls ``cb``; record
    # each call so the test cannot pass without verifying anything.
    calls = []

    def cb(src, dst, pad_before, pad_after, pad_value):
        calls.append(1)
        assert tvm.ir_pass.Simplify(src.elem_offset).value == 0
        assert pad_before[0].value == 1
        assert pad_before[1].value == 0
        assert pad_after[0].value == 1
        assert pad_after[1].value == 0
        assert pad_value.value == 1.0
        return tvm.make.Evaluate(0)
    stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
    assert calls, "InjectCopyIntrin never invoked the memcpy callback"
62
def test_single_point_test():
    """InjectCopyIntrin on a single-element copy ``B[0] = A[0]``.

    Both buffers handed to the callback must start at element offset 0 and
    have unit stride.
    """
    A = tvm.placeholder((1,), name='A')
    B = tvm.compute((1,), lambda i:
                    A[i], name='B')
    s = tvm.create_schedule(B.op)
    s[B].pragma(B.op.axis[0], "memcpy")
    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)
    Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
    Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
    # The checks below only run if the pass actually calls ``cb``; record
    # each call so the test cannot pass without verifying anything.
    calls = []

    def cb(src, dst, pad_before, pad_after, pad_value):
        calls.append(1)
        assert tvm.ir_pass.Simplify(src.elem_offset).value == 0
        assert tvm.ir_pass.Simplify(dst.elem_offset).value == 0
        assert tvm.ir_pass.Simplify(src.strides[0]).value == 1
        assert tvm.ir_pass.Simplify(dst.strides[0]).value == 1
        return tvm.make.Evaluate(0)
    stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
    assert calls, "InjectCopyIntrin never invoked the memcpy callback"
81
def assert_expr_equal(a, b):
    """Assert that expressions ``a`` and ``b`` simplify to the same value.

    ``a - b`` is passed through ``tvm.ir_pass.Simplify``. The check passes
    only when the result has a ``value`` attribute equal to 0. If the result
    has no ``value`` attribute (the difference did not fold to a constant),
    the check fails with an assertion message showing both operands and the
    simplified difference, rather than raising an ``AttributeError``.
    """
    diff = tvm.ir_pass.Simplify(a - b)
    assert getattr(diff, "value", None) == 0, \
        "expressions differ: %s vs %s (simplified difference: %s)" % (a, b, diff)
84
def test_copy_pad_split():
    """InjectCopyIntrin on a padded copy computed once per tile of a split loop.

    ``Apad`` is ``A`` (12 elements) with one ``0.0`` on each side, and
    ``B[i] = Apad[i] + Apad[i + 1] + Apad[i + 2]``. B's axis is split by 4
    and ``Apad`` is computed at the outer loop ``xo``. Each tile of 4 outputs
    therefore needs the 6-element window starting at ``Apad[4*xo]``.
    ``Apad[k]`` reads ``A[k - 1]`` only for ``1 <= k <= 12``, which gives the
    expected values checked in the callback:

    * the source starts at ``A[max(4*xo, 1) - 1]``;
    * leading padding is ``max(1 - 4*xo, 0)``, i.e. one element on the
      first tile only;
    * trailing padding covers window indices past 12:
      ``max(4*xo + 5 - 12, 0) == max(4*xo - 7, 0)``;
    * the copied source extent is 6 minus both paddings.
    """
    m = 4 * 3
    A = tvm.placeholder((m, ), name="A")
    Apad = tvm.compute((m + 2,), lambda i:
                       tvm.if_then_else(tvm.all(i >= 1, i <= m),
                                        A[i - 1], 0.0), "Apad")
    B = tvm.compute((m,), lambda i: Apad[i] + Apad[i + 1] + Apad[i + 2])
    s = tvm.create_schedule(B.op)
    xo, xi = s[B].split(B.op.axis[0], factor=4)
    s[Apad].compute_at(s[B], xo)
    s[Apad].pragma(s[Apad].op.axis[0], "memcpy")
    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)
    Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
    Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
    stmt = tvm.ir_pass.Simplify(stmt)
    stmt = tvm.ir_pass.CanonicalSimplify(stmt)
    # The checks below only run if the pass actually calls ``cb``; record
    # each call so the test cannot pass without verifying anything.
    calls = []

    def cb(src, dst, pad_before, pad_after, pad_value):
        calls.append(1)
        assert dst.elem_offset.value == 0
        assert_expr_equal(src.elem_offset, tvm.max(xo * 4, 1) - 1)

        rpad_before = tvm.max(1 - xo * 4, 0)
        rpad_after = tvm.max(xo * 4 - 7, 0)
        assert_expr_equal(pad_before[0], rpad_before)
        assert_expr_equal(pad_after[0], rpad_after)
        assert_expr_equal(src.shape[0], 6 - rpad_before - rpad_after)
        return tvm.make.Evaluate(0)
    stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
    assert calls, "InjectCopyIntrin never invoked the memcpy callback"
114
115
if __name__ == "__main__":
    # Run every test in this module, in a fixed order, when executed directly.
    for _run_test in (test_copy2d,
                      test_copy_pad,
                      test_copy_pad_split,
                      test_single_point_test):
        _run_test()
121