# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test group effect: how ``Schedule.create_group`` groups stages."""
import tvm
from tvm import te


def test_scan_group():
    """Check stage grouping inside a ``te.scan`` body.

    The scan update is a chain ``s_update1 -> s_update2 -> s_update3``.
    The test checks that:

    * the chain's stages already share a group after ``create_schedule``;
    * a nested group can be created over ``s_update1``/``s_update2`` and
      attached to ``s_update3``;
    * attaching a stage to a stage outside its group raises ``TVMError``.
    """
    m = te.size_var("m")
    n = te.size_var("n")
    x = te.compute((m, n), lambda i, j: tvm.tir.const(1, "float32"), name="x")
    s_state = te.placeholder((m, n))
    s_init = te.compute((1, n), lambda _, i: x[0, i])

    s_update1 = te.compute((m, n), lambda t, i: s_state[t - 1, i] + x[t, i])
    s_update2 = te.compute((m, n), lambda t, i: s_update1[t, i] + 1)
    s_update3 = te.compute((m, n), lambda t, i: s_update2[t, i] + 1)
    res = tvm.te.scan(s_init, s_update3, s_state, inputs=x)

    s = te.create_schedule(res.op)
    # The scan-body stages are placed in a common group by create_schedule.
    assert s[s_update1].group is not None
    assert s[s_update2].group == s[s_update1].group
    # Assigning within the group is valid.
    s[s_update1].compute_at(s[s_update2], s_update2.op.axis[1])
    # Create a new group for [s_update2, s_update1]. It nests inside the
    # group that s_update3 belongs to.
    g2 = s.create_group(outputs=s_update2, inputs=[s_state, x])
    assert g2.group is not None
    assert g2.group == s[s_update3].group
    assert s[s_update2].group == g2
    assert s[s_update1].group == g2
    g2.compute_at(s[s_update3], s_update3.op.axis[1])
    assert g2.attach_stage == s[s_update3]
    try:
        # Computing at a stage outside the group is an error.
        s[s_update2].compute_at(s[s_init], s_init.op.axis[0])
    except tvm.error.TVMError:
        pass
    else:
        # Raise explicitly rather than `assert False`, which `python -O`
        # strips and which would make this negative check pass silently.
        raise AssertionError(
            "compute_at to a stage outside the group should raise TVMError"
        )


def test_compute_group():
    """Check that ``include_inputs=True`` pulls the input stage into the group.

    The group is then attached with ``compute_at`` as a single unit.
    """
    m = te.size_var("m")
    n = te.size_var("n")
    x = te.compute((m, n), lambda i, j: tvm.tir.const(1, "float32"), name="x")
    x1 = te.compute(x.shape, lambda *i: x(*i) + 1, name="x1")
    x2 = te.compute(x.shape, lambda *i: x1(*i) + 2, name="x2")
    s = te.create_schedule(x2.op)
    g = s.create_group(outputs=x1, inputs=x, include_inputs=True)
    # Both the output stage and the included input stage belong to g.
    assert s[x1].group == g
    assert s[x].group == g
    g.compute_at(s[x2], x2.op.axis[1])
    assert g.attach_stage == s[x2]
    assert g.num_child_stages == 2


def test_nest_group():
    """Check that two groups over the same outputs/inputs nest.

    The second group uses ``include_inputs=True``, so it contains the first
    group plus the input stage ``x``.
    """
    m = te.size_var("m")
    n = te.size_var("n")
    x = te.compute((m, n), lambda i, j: tvm.tir.const(1, "float32"), name="x")
    x1 = te.compute(x.shape, lambda *i: x(*i) + 1, name="x1")
    x2 = te.compute(x.shape, lambda *i: x1(*i) + 2, name="x2")
    s = te.create_schedule(x2.op)
    g1 = s.create_group(outputs=x1, inputs=x)
    g2 = s.create_group(outputs=x1, inputs=x, include_inputs=True)
    assert set(s.groups) == {g1, g2}
    assert s[x].group == g2
    assert s[x1].group == g1
    # g1 is nested inside g2.
    assert g1.group == g2
    # g2's children: the nested group g1 and stage x. g1's child: stage x1.
    assert g2.num_child_stages == 2
    assert g1.num_child_stages == 1


if __name__ == "__main__":
    test_nest_group()
    test_compute_group()
    test_scan_group()