# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests that drive ``tvm.ir_pass.InjectCopyIntrin`` with a checking callback.

Each test builds a small schedule, tags one axis with the ``"memcpy"`` pragma,
flattens storage, and then runs the pass with a callback that asserts on the
``src``/``dst`` buffers and padding arguments it receives.
"""
import tvm


def _flatten_copy_stmt(sched, src_tensor, dst_tensor):
    """Lower ``sched`` to a storage-flattened statement.

    Runs bound inference and ``ScheduleOps``, declares one buffer per tensor
    (named ``'A'`` and ``'B'``), and applies ``StorageFlatten`` with the
    value 64 as its third argument.
    """
    inferred = tvm.schedule.InferBound(sched)
    body = tvm.schedule.ScheduleOps(sched, inferred)
    src_buf = tvm.decl_buffer(src_tensor.shape, src_tensor.dtype, name='A')
    dst_buf = tvm.decl_buffer(dst_tensor.shape, dst_tensor.dtype, name='B')
    binds = {src_tensor: src_buf, dst_tensor: dst_buf}
    return tvm.ir_pass.StorageFlatten(body, binds, 64)


def test_copy2d():
    """Plain 2-D copy ``B[i, j] = A[i, j]`` over symbolic extents."""
    rows = tvm.var('m')
    cols = tvm.var('l')
    A = tvm.placeholder((rows, cols), name='A')
    B = tvm.compute((rows, cols), lambda i, j: A[i, j], name='B')
    sched = tvm.create_schedule(B.op)
    sched[B].pragma(B.op.axis[0], "memcpy")
    flat = _flatten_copy_stmt(sched, A, B)

    def check_copy(src, dst, pad_before, pad_after, pad_value):
        # Leading stride of both buffers is the inner extent; dst is
        # unit-stride on its last axis.
        assert dst.strides[0] == cols
        assert dst.strides[1].value == 1
        assert src.strides[0] == cols
        assert tuple(src.shape) == (rows, cols)
        return tvm.make.Evaluate(0)

    tvm.ir_pass.InjectCopyIntrin(flat, "memcpy", check_copy)


def test_copy_pad():
    """Copy into a ``(m + 2, l)`` output with one padded row on each side.

    Rows ``1 <= i < m + 1`` read ``A[i - 1, j]``; all other rows are 1.0.
    """
    rows = tvm.var('m')
    cols = tvm.var('l')
    A = tvm.placeholder((rows, cols), name='A')
    B = tvm.compute(
        (rows + 2, cols),
        lambda i, j: tvm.if_then_else(
            tvm.all(i >= 1, i < rows + 1), A[i - 1, j], 1.0),
        name='B')
    sched = tvm.create_schedule(B.op)
    sched[B].pragma(B.op.axis[0], "memcpy")
    flat = _flatten_copy_stmt(sched, A, B)

    def check_padded_copy(src, dst, pad_before, pad_after, pad_value):
        assert tvm.ir_pass.Simplify(src.elem_offset).value == 0
        # One element of padding on axis 0, none on axis 1, on both sides.
        assert pad_before[0].value == 1
        assert pad_before[1].value == 0
        assert pad_after[0].value == 1
        assert pad_after[1].value == 0
        assert pad_value.value == 1.0
        return tvm.make.Evaluate(0)

    tvm.ir_pass.InjectCopyIntrin(flat, "memcpy", check_padded_copy)


def test_single_point_test():
    """Copy of a tensor with shape ``(1,)``."""
    A = tvm.placeholder((1,), name='A')
    B = tvm.compute((1,), lambda i: A[i], name='B')
    sched = tvm.create_schedule(B.op)
    sched[B].pragma(B.op.axis[0], "memcpy")
    flat = _flatten_copy_stmt(sched, A, B)

    def check_point_copy(src, dst, pad_before, pad_after, pad_value):
        simplify = tvm.ir_pass.Simplify
        assert simplify(src.elem_offset).value == 0
        assert simplify(dst.elem_offset).value == 0
        assert simplify(src.strides[0]).value == 1
        assert simplify(dst.strides[0]).value == 1
        return tvm.make.Evaluate(0)

    tvm.ir_pass.InjectCopyIntrin(flat, "memcpy", check_point_copy)


def assert_expr_equal(a, b):
    """Assert that ``a - b`` simplifies to a node whose ``value`` is 0."""
    difference = tvm.ir_pass.Simplify(a - b)
    assert difference.value == 0


def test_copy_pad_split():
    """Padded copy computed inside the outer loop of a split consumer.

    ``Apad`` pads ``A`` with 0.0 on both ends, ``B`` sums three neighbouring
    ``Apad`` elements, ``B``'s axis is split by a factor of 4, and ``Apad`` is
    computed at the outer loop and carries the ``"memcpy"`` pragma.
    """
    length = 4 * 3
    A = tvm.placeholder((length, ), name="A")
    Apad = tvm.compute(
        (length + 2,),
        lambda i: tvm.if_then_else(
            tvm.all(i >= 1, i <= length), A[i - 1], 0.0),
        "Apad")
    B = tvm.compute((length,), lambda i: Apad[i] + Apad[i + 1] + Apad[i + 2])
    sched = tvm.create_schedule(B.op)
    xo, _ = sched[B].split(B.op.axis[0], factor=4)
    sched[Apad].compute_at(sched[B], xo)
    sched[Apad].pragma(sched[Apad].op.axis[0], "memcpy")
    flat = _flatten_copy_stmt(sched, A, B)
    flat = tvm.ir_pass.Simplify(flat)
    flat = tvm.ir_pass.CanonicalSimplify(flat)

    def check_split_copy(src, dst, pad_before, pad_after, pad_value):
        assert dst.elem_offset.value == 0
        assert_expr_equal(src.elem_offset, tvm.max(xo * 4, 1) - 1)

        # Expected padding, expressed in terms of the outer loop variable.
        rpad_before = tvm.max(1 - xo * 4, 0)
        rpad_after = tvm.max(xo * 4 - 7, 0)
        assert_expr_equal(pad_before[0], rpad_before)
        assert_expr_equal(pad_after[0], rpad_after)
        assert_expr_equal(src.shape[0], 6 - rpad_before - rpad_after)
        return tvm.make.Evaluate(0)

    tvm.ir_pass.InjectCopyIntrin(flat, "memcpy", check_split_copy)


if __name__ == "__main__":
    test_copy2d()
    test_copy_pad()
    test_copy_pad_split()
    test_single_point_test()