1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17import pytest
18import tvm
19from tvm import te
20from tvm import relay
21from tvm.relay import transform
22from tvm.relay.prelude import Prelude
23
24
def test_remove_all_prelude_functions():
    """Only ``main`` survives when it references nothing from the prelude."""
    module = tvm.IRModule()
    # Building the Prelude adds its definitions to ``module``; none of them
    # are reachable from ``main``, so the pass should strip them all.
    Prelude(module)
    inp = relay.var("x", shape=(1, 16))
    module["main"] = relay.Function([inp], inp)
    module = relay.transform.RemoveUnusedFunctions()(module)
    names = {gvar.name_hint for gvar, _ in module.functions.items()}
    assert names == {"main"}
33
34
def test_remove_all_prelude_functions_but_referenced_functions():
    """A user global called from ``main`` is kept; the prelude is dropped."""
    module = tvm.IRModule()
    # Populate the module with prelude definitions that ``main`` never uses.
    Prelude(module)
    inp = relay.var("x", shape=(1, 16))
    identity_gv = relay.GlobalVar("id_func")
    module[identity_gv] = relay.Function([inp], inp)

    # ``main`` calls ``id_func``, which makes it reachable from the entry point.
    module["main"] = relay.Function([inp], identity_gv(inp))
    module = relay.transform.RemoveUnusedFunctions()(module)
    names = {gvar.name_hint for gvar, _ in module.functions.items()}
    assert names == {"id_func", "main"}
47
48
def test_keep_only_referenced_prelude_functions():
    """Prelude functions reachable from ``main`` (``hd`` and ``tl``) are kept."""
    module = tvm.IRModule()
    prelude = Prelude(module)
    # Cons 4, 3, 2, 1, 0 onto nil, producing the list [0, 1, 2, 3, 4].
    lst = prelude.nil()
    for value in range(4, -1, -1):
        lst = prelude.cons(relay.const(value), lst)
    # hd(tl(tl(lst))) selects the third element.
    third = prelude.hd(prelude.tl(prelude.tl(lst)))
    module["main"] = relay.Function([], third)
    module = relay.transform.RemoveUnusedFunctions()(module)
    names = {gvar.name_hint for gvar, _ in module.functions.items()}
    assert names == {"tl", "hd", "main"}
60
61
def test_multiple_entry_functions():
    """Every listed entry point and everything each one calls is preserved."""
    module = tvm.IRModule()
    prelude = Prelude(module)
    # First entry: uses prelude ``hd``/``tl`` on the list [0, 1, 2, 3, 4].
    lst = prelude.nil()
    for value in reversed(range(5)):
        lst = prelude.cons(relay.const(value), lst)
    module["main1"] = relay.Function([], prelude.hd(prelude.tl(prelude.tl(lst))))

    # Second entry: calls a user-defined identity global.
    inp = relay.var("x", shape=(1, 16))
    identity_gv = relay.GlobalVar("id_func")
    module[identity_gv] = relay.Function([inp], inp)
    module["main2"] = relay.Function([inp], identity_gv(inp))

    entries = ["main1", "main2"]
    module = relay.transform.RemoveUnusedFunctions(entries)(module)
    names = {gvar.name_hint for gvar, _ in module.functions.items()}
    assert names == {"tl", "hd", "main2", "id_func", "main1"}
79
80
def test_globalvar_as_call_arg():
    """``tensor_array_int32`` must survive when ``main`` builds, writes and stacks a tensor array."""
    module = tvm.IRModule()
    prelude = Prelude(module)
    dtype = "int32"
    make_array = prelude.get_var("tensor_array", dtype)
    make_tensor1 = prelude.get_var("tensor1", dtype)
    array_write = prelude.get_var("tensor_array_write", dtype)
    array_stack = prelude.get_var("tensor_array_stack", dtype)

    value_var = relay.var("v")
    # Arguments are evaluated left to right, matching the original build order:
    # allocate (size 3) -> write index 0 -> stack.
    stacked = array_stack(
        array_write(make_array(relay.const(3)), relay.const(0), make_tensor1(value_var))
    )
    module["main"] = relay.Function([value_var], stacked)
    module = relay.transform.RemoveUnusedFunctions()(module)
    names = {gvar.name_hint for gvar, _ in module.functions.items()}
    assert "tensor_array_int32" in names
96
97
def test_call_globalvar_without_args():
    """A nullary call whose callee is an ``If`` over two globals leaves the module unchanged."""

    def build_module():
        g1 = relay.GlobalVar("g1")
        g2 = relay.GlobalVar("g2")
        module = tvm.IRModule({})
        module[g1] = relay.Function([], relay.const(1))
        module[g2] = relay.Function([], relay.const(2))
        cond = relay.var("p", "bool")
        # The callee is an expression selecting g1 or g2, not a bare GlobalVar.
        callee = relay.If(cond, g1, g2)
        module["main"] = relay.Function([cond], relay.Call(callee, []))
        return module

    expected = build_module()
    module = relay.transform.RemoveUnusedFunctions()(build_module())
    assert tvm.ir.structural_equal(module, expected, map_free_vars=True)
115
116
if __name__ == "__main__":
    # Pass this file explicitly. A bare ``pytest.main()`` takes its targets
    # from the command line / current directory rather than this module.
    pytest.main([__file__])
119