# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for the Relay ``RemoveUnusedFunctions`` pass.

Each test builds an ``IRModule`` (usually populated with the Relay prelude),
runs the pass, and checks which global functions survive.
"""
import pytest
import tvm
from tvm import te
from tvm import relay
from tvm.relay import transform
from tvm.relay.prelude import Prelude


def _global_names(mod):
    """Return the set of ``name_hint`` strings of every global function in ``mod``."""
    return {gvar.name_hint for gvar, _ in mod.functions.items()}


def _third_element_expr(p):
    """Build ``hd(tl(tl([0, 1, 2, 3, 4])))`` from the prelude list helpers.

    The resulting expression calls the prelude globals ``hd`` and ``tl``.
    """
    lst = p.nil()
    # Prepend in reverse so the list reads [0, 1, 2, 3, 4] from its head.
    for i in [4, 3, 2, 1, 0]:
        lst = p.cons(relay.const(i), lst)
    return p.hd(p.tl(p.tl(lst)))


def test_remove_all_prelude_functions():
    """With nothing referenced from ``main``, every prelude function is removed."""
    mod = tvm.IRModule()
    # Constructing the Prelude registers its definitions into ``mod``; the
    # instance itself is not needed here.
    Prelude(mod)
    x = relay.var("x", shape=(1, 16))
    mod["main"] = relay.Function([x], x)
    mod = relay.transform.RemoveUnusedFunctions()(mod)
    assert _global_names(mod) == {"main"}


def test_remove_all_prelude_functions_but_referenced_functions():
    """A user global called from ``main`` is kept while the prelude is dropped."""
    mod = tvm.IRModule()
    # Populate ``mod`` with the prelude so there is something to remove.
    Prelude(mod)
    x = relay.var("x", shape=(1, 16))
    id_func = relay.Function([x], x)
    id_name = relay.GlobalVar("id_func")
    mod[id_name] = id_func

    mod["main"] = relay.Function([x], id_name(x))
    mod = relay.transform.RemoveUnusedFunctions()(mod)
    assert _global_names(mod) == {"id_func", "main"}


def test_keep_only_referenced_prelude_functions():
    """Only the prelude functions ``main`` actually calls (``hd``/``tl``) survive."""
    mod = tvm.IRModule()
    p = Prelude(mod)
    mod["main"] = relay.Function([], _third_element_expr(p))
    mod = relay.transform.RemoveUnusedFunctions()(mod)
    assert _global_names(mod) == {"tl", "hd", "main"}


def test_multiple_entry_functions():
    """Functions reachable from any of several named entry points are kept."""
    mod = tvm.IRModule()
    p = Prelude(mod)
    mod["main1"] = relay.Function([], _third_element_expr(p))

    x = relay.var("x", shape=(1, 16))
    id_func = relay.Function([x], x)
    id_name = relay.GlobalVar("id_func")
    mod[id_name] = id_func
    mod["main2"] = relay.Function([x], id_name(x))
    # Name both entry points explicitly rather than using the no-argument form.
    mod = relay.transform.RemoveUnusedFunctions(["main1", "main2"])(mod)
    assert _global_names(mod) == {"tl", "hd", "main2", "id_func", "main1"}


def test_globalvar_as_call_arg():
    """Globals reached through the int32 tensor-array prelude helpers are kept."""
    mod = tvm.IRModule()
    p = Prelude(mod)
    tensor_array = p.get_var("tensor_array", "int32")
    tensor1 = p.get_var("tensor1", "int32")
    write = p.get_var("tensor_array_write", "int32")
    stack = p.get_var("tensor_array_stack", "int32")
    v = relay.var("v")
    init_tensor_array = tensor_array(relay.const(3))
    tensor_array1 = write(init_tensor_array, relay.const(0), tensor1(v))
    tensor_array2 = stack(tensor_array1)
    mod["main"] = relay.Function([v], tensor_array2)
    mod = relay.transform.RemoveUnusedFunctions()(mod)
    # NOTE(review): per the test name, some of these helpers presumably pass a
    # GlobalVar as a call argument; the pass must follow such references too.
    assert "tensor_array_int32" in _global_names(mod)


def test_call_globalvar_without_args():
    """Globals selected by an ``If`` in callee position must not be removed."""

    def get_mod():
        mod = tvm.IRModule({})
        fn1 = relay.Function([], relay.const(1))
        fn2 = relay.Function([], relay.const(2))
        g1 = relay.GlobalVar("g1")
        g2 = relay.GlobalVar("g2")
        mod[g1] = fn1
        mod[g2] = fn2
        p = relay.var("p", "bool")
        # The callee is an If expression choosing between two GlobalVars,
        # invoked with zero arguments.
        mod["main"] = relay.Function([p], relay.Call(relay.If(p, g1, g2), []))
        return mod

    mod = get_mod()
    ref_mod = get_mod()
    mod = relay.transform.RemoveUnusedFunctions()(mod)
    # Every global is reachable from main, so the module must be unchanged.
    assert tvm.ir.structural_equal(mod, ref_mod, map_free_vars=True)


if __name__ == "__main__":
    # Pass this file explicitly: a bare pytest.main() takes its targets from
    # sys.argv and, with none given, collects from the current directory
    # instead of running only these tests.
    pytest.main([__file__])