1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17"""Test group effect"""
18import tvm
19from tvm import te
20
21
def test_scan_group():
    """Check group membership and attachment for the update stages of a scan.

    The scan update is the chain ``s_update1 -> s_update2 -> s_update3``.
    The test verifies three things:

    * those stages start out in a shared group;
    * a user-created group over ``s_update1``/``s_update2`` nests inside
      that scan group and can be attached to ``s_update3``;
    * attaching a grouped stage to a stage outside its group is rejected
      with ``tvm.error.TVMError``.
    """
    m = te.size_var("m")
    n = te.size_var("n")
    x = te.compute((m, n), lambda i, j: tvm.tir.const(1, "float32"), name="x")
    s_state = te.placeholder((m, n))
    s_init = te.compute((1, n), lambda _, i: x[0, i])

    s_update1 = te.compute((m, n), lambda t, i: s_state[t - 1, i] + x[t, i])
    s_update2 = te.compute((m, n), lambda t, i: s_update1[t, i] + 1)
    s_update3 = te.compute((m, n), lambda t, i: s_update2[t, i] + 1)
    res = tvm.te.scan(s_init, s_update3, s_state, inputs=x)

    s = te.create_schedule(res.op)
    # The scan's update stages already belong to one common group.
    assert s[s_update1].group is not None
    assert s[s_update2].group == s[s_update1].group
    # Attaching a stage to another stage of the same group is valid.
    s[s_update1].compute_at(s[s_update2], s_update2.op.axis[1])
    # Create a new group over s_update2 and s_update1. It becomes a child of
    # the group that s_update3 belongs to.
    g2 = s.create_group(outputs=s_update2, inputs=[s_state, x])
    assert g2.group is not None
    assert g2.group == s[s_update3].group
    assert s[s_update2].group == g2
    assert s[s_update1].group == g2
    g2.compute_at(s[s_update3], s_update3.op.axis[1])
    assert g2.attach_stage == s[s_update3]
    # Attaching a grouped stage to a stage outside its group must fail.
    # Keep only the expected-to-raise call inside `try`. Signal a missing
    # error via `else` so the check is not an `assert` that `-O` would strip.
    try:
        s[s_update2].compute_at(s[s_init], s_init.op.axis[0])
    except tvm.error.TVMError:
        pass
    else:
        raise AssertionError(
            "compute_at to a stage outside the group should raise TVMError"
        )
53
54
def test_compute_group():
    """Group a stage together with its input stage and attach the group.

    Pipeline: ``x -> x1 -> x2``. A group whose output is ``x1`` and that is
    created with ``include_inputs=True`` should own both ``x`` and ``x1``.
    It can then be attached to an axis of ``x2``.
    """
    rows, cols = te.size_var("m"), te.size_var("n")
    src = te.compute((rows, cols), lambda i, j: tvm.tir.const(1, "float32"), name="x")
    mid = te.compute(src.shape, lambda *idx: src(*idx) + 1, name="x1")
    out = te.compute(src.shape, lambda *idx: mid(*idx) + 2, name="x2")
    sch = te.create_schedule(out.op)

    grp = sch.create_group(outputs=mid, inputs=src, include_inputs=True)
    # include_inputs=True pulls the input stage into the group as well.
    for member in (mid, src):
        assert sch[member].group == grp

    grp.compute_at(sch[out], out.op.axis[1])
    assert grp.attach_stage == sch[out]
    assert grp.num_child_stages == 2
68
69
def test_nest_group():
    """Create two overlapping groups and check that they nest.

    Pipeline: ``x -> x1 -> x2``. The inner group holds only ``x1``. The
    outer group, created with ``include_inputs=True``, holds ``x`` plus the
    inner group, so its child count is 2 while the inner group's is 1.
    """
    rows, cols = te.size_var("m"), te.size_var("n")
    src = te.compute((rows, cols), lambda i, j: tvm.tir.const(1, "float32"), name="x")
    mid = te.compute(src.shape, lambda *idx: src(*idx) + 1, name="x1")
    out = te.compute(src.shape, lambda *idx: mid(*idx) + 2, name="x2")
    sch = te.create_schedule(out.op)

    inner = sch.create_group(outputs=mid, inputs=src)
    outer = sch.create_group(outputs=mid, inputs=src, include_inputs=True)
    assert set(sch.groups) == {inner, outer}

    # x sits directly in the outer group; x1 stays in the inner group,
    # which itself becomes a child of the outer group.
    assert sch[src].group == outer
    assert sch[mid].group == inner
    assert inner.group == outer

    assert outer.num_child_stages == 2
    assert inner.num_child_stages == 1
85
86
if __name__ == "__main__":
    # Run the tests directly (without pytest), in the same order as before.
    for _test_fn in (test_nest_group, test_compute_group, test_scan_group):
        _test_fn()
91