1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17import tvm
18from tvm import te
19
20
def test_bound_tile_mod():
    """Check bound inference on a tiled copy that uses cached writes.

    Builds the 2-D element-wise copy ``C[n, m] = A[n, m]``, applies a schedule
    combining ``cache_write``, ``tile``, ``split``, ``fuse``,
    ``storage_align`` and ``compute_at``, runs ``InferBound`` and asserts that
    the third axis of the stage at index 2 has extent 16.  When the extent
    differs, the lowered IR is printed before the assertion fails.
    """

    def build_copy(m_tiles, n_tiles, tile, elem_type):
        """Create the copy computation together with a fresh schedule.

        Returns ``(schedule, A, C)``; both tensors have shape
        ``(n_tiles * tile, m_tiles * tile)``.
        """
        cols = m_tiles * tile
        rows = n_tiles * tile

        src = tvm.te.placeholder((rows, cols), name="A", dtype=elem_type)
        # The lambda argument names (n, m) are deliberately left as they were.
        dst = tvm.te.compute((rows, cols), lambda n, m: src[n, m], name="C")
        return tvm.te.create_schedule(dst.op), src, dst

    def apply_schedule(sched, tile, pad, _src, out):
        """Apply the schedule primitives under test and return ``sched``.

        ``_src`` is accepted for call-site symmetry but is not used.
        """
        # The returned tensor is not needed here; the call is kept for its
        # effect on the schedule.
        sched.cache_write(out, "local")

        # First pass: axes are taken from the tensor's own op.
        row, col = out.op.axis
        _, _, row_in, col_in = sched[out].tile(row, col, tile, tile)
        _, row_in_lo = sched[out].split(row_in, 2)
        sched[out].fuse(row_in_lo, col_in)
        shared = sched.cache_write(out, "shared")
        _, _, shared_row_in, _ = shared.op.axis
        sched[shared].storage_align(shared_row_in, tile * 2, pad)

        # Second pass: axes are read from the stage's op, not from ``out.op``.
        row, col = sched[out].op.axis
        _, _, row_in, _ = sched[out].tile(row, col, tile, tile)
        sched[out].set_scope("global")
        attach_axis, _ = sched[out].split(row_in, 32)
        sched[shared].compute_at(sched[out], attach_axis)

        return sched

    sched, src, dst = build_copy(2, 2, 128, "float16")
    sched = apply_schedule(sched, 128, 8, src, dst)
    inferred = tvm.te.schedule.InferBound(sched)
    probe_axis = sched.stages[2].op.axis[2]
    extent_ok = inferred[probe_axis].extent == 16
    if not extent_ok:
        # Dump the lowered IR so a failing run shows what was generated.
        print(tvm.lower(sched, [src, dst], simple_mode=True))
    assert extent_ok
59
60
if __name__ == "__main__":
    # Allow running this file directly as a script, without a test runner.
    test_bound_tile_mod()
63