# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import tvm.testing
from tvm import te


def test_copy2d():
    """Run InjectCopyIntrin on a plain 2-D copy and check the buffers given to the callback.

    ``B[i, j] = A[i, j]`` over a symbolic ``(m, l)`` shape is tagged with the
    ``"memcpy"`` pragma on its outermost axis.  The callback asserts the
    strides of ``dst``, the leading stride of ``src`` and the shape of ``src``.
    """
    m = te.var("m")
    l = te.var("l")
    A = te.placeholder((m, l), name="A")
    B = te.compute((m, l), lambda i, j: A[i, j], name="B")
    s = te.create_schedule(B.op)
    s[B].pragma(B.op.axis[0], "memcpy")
    bounds = tvm.te.schedule.InferBound(s)
    stmt = tvm.te.schedule.ScheduleOps(s, bounds)
    func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)
    mod = tvm.IRModule.from_expr(func)
    mod = tvm.tir.transform.StorageFlatten(64)(mod)

    def cb(src, dst, pad_before, pad_after, pad_value):
        # Both buffers are expected to have row stride ``l``; dst has unit column stride.
        assert dst.strides[0] == l
        assert dst.strides[1].value == 1
        assert src.strides[0] == l
        assert tuple(src.shape) == (m, l)
        return tvm.tir.Evaluate(0)

    stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body


def test_copy_pad():
    """Run InjectCopyIntrin on a row-padded copy and check the padding given to the callback.

    ``B`` has shape ``(m + 2, l)``: rows ``1 .. m`` are read from ``A[i - 1, j]``
    and every other row is the constant ``1.0``.  The callback asserts a source
    element offset of 0, padding of ``(1, 0)`` before and after, and a pad
    value of ``1.0``.
    """
    m = te.var("m")
    l = te.var("l")
    A = te.placeholder((m, l), name="A")
    B = te.compute(
        (m + 2, l),
        lambda i, j: tvm.tir.if_then_else(tvm.tir.all(i >= 1, i < m + 1), A[i - 1, j], 1.0),
        name="B",
    )
    s = te.create_schedule(B.op)
    s[B].pragma(B.op.axis[0], "memcpy")
    bounds = tvm.te.schedule.InferBound(s)
    stmt = tvm.te.schedule.ScheduleOps(s, bounds)

    func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)
    mod = tvm.IRModule.from_expr(func)
    mod = tvm.tir.transform.StorageFlatten(64)(mod)

    def cb(src, dst, pad_before, pad_after, pad_value):
        tvm.testing.assert_prim_expr_equal(src.elem_offset, 0)
        # One padded row on each side of the first axis, none on the second.
        assert pad_before[0].value == 1
        assert pad_before[1].value == 0
        assert pad_after[0].value == 1
        assert pad_after[1].value == 0
        assert pad_value.value == 1.0
        return tvm.tir.Evaluate(0)

    stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body


def test_single_point_test():
    """Run InjectCopyIntrin on a one-element copy and check offsets and strides.

    ``B[i] = A[i]`` over shape ``(1,)``.  The callback asserts that both
    buffers have element offset 0 and stride 1.
    """
    A = te.placeholder((1,), name="A")
    B = te.compute((1,), lambda i: A[i], name="B")
    s = te.create_schedule(B.op)
    s[B].pragma(B.op.axis[0], "memcpy")
    bounds = tvm.te.schedule.InferBound(s)
    stmt = tvm.te.schedule.ScheduleOps(s, bounds)

    func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)
    mod = tvm.IRModule.from_expr(func)
    mod = tvm.tir.transform.StorageFlatten(64)(mod)

    def cb(src, dst, pad_before, pad_after, pad_value):
        tvm.testing.assert_prim_expr_equal(src.elem_offset, 0)
        tvm.testing.assert_prim_expr_equal(dst.elem_offset, 0)
        tvm.testing.assert_prim_expr_equal(src.strides[0], 1)
        tvm.testing.assert_prim_expr_equal(dst.strides[0], 1)
        return tvm.tir.Evaluate(0)

    stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body


def test_copy_pad_split():
    """Run InjectCopyIntrin on a padded copy computed inside a split loop.

    ``Apad`` pads the 12-element ``A`` with one ``0.0`` on each side, and ``B``
    sums three neighbouring ``Apad`` elements.  ``B`` is split by a factor of 4
    and ``Apad`` is computed at the outer loop ``xo``, so the offsets and
    padding checked in the callback are expressions in ``xo``.
    """
    m = 4 * 3
    A = te.placeholder((m,), name="A")
    Apad = te.compute(
        (m + 2,), lambda i: tvm.tir.if_then_else(tvm.tir.all(i >= 1, i <= m), A[i - 1], 0.0), "Apad"
    )
    B = te.compute((m,), lambda i: Apad[i] + Apad[i + 1] + Apad[i + 2])
    s = te.create_schedule(B.op)
    xo, xi = s[B].split(B.op.axis[0], factor=4)
    s[Apad].compute_at(s[B], xo)
    s[Apad].pragma(s[Apad].op.axis[0], "memcpy")
    bounds = tvm.te.schedule.InferBound(s)
    stmt = tvm.te.schedule.ScheduleOps(s, bounds)

    func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)
    mod = tvm.IRModule.from_expr(func)
    mod = tvm.tir.transform.StorageFlatten(64)(mod._move())
    mod = tvm.tir.transform.Simplify()(mod._move())

    def cb(src, dst, pad_before, pad_after, pad_value):
        assert dst.elem_offset.value == 0
        tvm.testing.assert_prim_expr_equal(src.elem_offset, tvm.te.max(xo * 4, 1) - 1)

        # Reference padding on each side, written in terms of the outer loop variable.
        rpad_before = tvm.te.max(1 - xo * 4, 0)
        rpad_after = tvm.te.max(xo * 4 - 7, 0)
        tvm.testing.assert_prim_expr_equal(pad_before[0], rpad_before)
        tvm.testing.assert_prim_expr_equal(pad_after[0], rpad_after)
        # The 6-element window minus whatever padding the two expressions above give.
        tvm.testing.assert_prim_expr_equal(src.shape[0], 6 - rpad_before - rpad_after)
        return tvm.tir.Evaluate(0)

    stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body


if __name__ == "__main__":
    test_copy2d()
    test_copy_pad()
    test_copy_pad_split()
    test_single_point_test()