# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test that type checker correctly computes types
   for expressions.
"""
from tvm import relay
from tvm.relay import op, transform, analysis
from tvm.relay.analysis import assert_alpha_equal


def run_infer_type(expr, mod=None):
    """Run the InferType pass and return the checked form of ``expr``.

    Parameters
    ----------
    expr
        A Relay expression, ``relay.Function`` or ``relay.GlobalVar``.
    mod
        Optional module to check ``expr`` in.  When it is falsy, a new
        module is built from ``expr`` with ``relay.Module.from_expr``.

    Returns
    -------
    Without a module: the ``"main"`` entry of the inferred module when
    ``expr`` is a ``relay.Function``, otherwise the body of that entry.

    With a module: the entry named ``expr.name_hint`` for a
    ``relay.GlobalVar``, the ``"main"`` entry for a ``relay.Function``,
    and otherwise the body of ``"main"``.

    Notes
    -----
    When a module is given and ``expr`` is not a ``relay.GlobalVar``,
    ``mod["main"]`` is overwritten with ``expr`` (wrapped in a function
    over its free variables if it is not already a function).
    """
    if not mod:
        mod = relay.Module.from_expr(expr)
        mod = transform.InferType()(mod)
        entry = mod["main"]
        return entry if isinstance(expr, relay.Function) else entry.body

    if isinstance(expr, relay.GlobalVar):
        gv = expr.name_hint
    else:
        func = expr
        if not isinstance(expr, relay.Function):
            # Close the expression over its free variables so that it can
            # be stored in the module as a function.
            func = relay.Function(analysis.free_vars(expr), expr)
        mod["main"] = func
        gv = "main"
    mod = transform.InferType()(mod)

    if isinstance(expr, (relay.GlobalVar, relay.Function)):
        return mod[gv]
    return mod[gv].body


def assert_has_type(expr, typ, mod=None):
    """Raise ``RuntimeError`` unless ``expr`` type checks to ``typ``.

    Parameters
    ----------
    expr
        The Relay expression to check.
    typ
        The expected type; compared to the inferred type with ``!=``.
    mod
        Optional module to check ``expr`` in.  A new empty module is
        created for each call when omitted.

    Raises
    ------
    RuntimeError
        If the inferred type differs from ``typ``.
    """
    # A module built in the signature would be created once and shared by
    # every call; run_infer_type writes mod["main"], so that state would
    # leak from one test into the next.
    if mod is None:
        mod = relay.module.Module({})
    checked_expr = run_infer_type(expr, mod)
    checked_type = checked_expr.checked_type
    if checked_type != typ:
        raise RuntimeError("Type mismatch %s vs %s" % (
            checked_type, typ))


def initialize_box_adt(mod):
    """Register a simple one-constructor ADT named ``box`` in ``mod``.

    The ADT has a single type parameter and a single constructor taking
    one field of that type.

    Returns
    -------
    A ``(box, constructor)`` tuple: the global type variable and the
    constructor.
    """
    box = relay.GlobalTypeVar('box')
    tv = relay.TypeVar('tv')
    constructor = relay.Constructor('constructor', [tv], box)
    data = relay.TypeData(box, [tv], [constructor])
    mod[box] = data
    return (box, constructor)


def test_monomorphic_let():
    "Program: let %x = 1; %x"
    sb = relay.ScopeBuilder()
    x = sb.let('x', relay.const(1.0, "float64"))
    sb.ret(x)
    xchecked = run_infer_type(sb.get())
    assert xchecked.checked_type == relay.scalar_type("float64")


def test_single_op():
    "Program: fn (%x : float32) { log(%x) }"
    x = relay.var('x', shape=[])
    func = relay.Function([x], op.log(x))
    ttype = relay.TensorType([], dtype='float32')
    assert_has_type(func, relay.FuncType([ttype], ttype))


def test_add_broadcast_op():
    """
    Program:
        fn (%x: Tensor[(10, 4), float32], %y: Tensor[(5, 10, 1), float32])
            -> Tensor[(5, 10, 4), float32] {
            %x + %y
        }
    """
    x = relay.var('x', shape=(10, 4))
    y = relay.var('y', shape=(5, 10, 1))
    z = x + y
    func = relay.Function([x, y], z)
    t1 = relay.TensorType((10, 4), 'float32')
    t2 = relay.TensorType((5, 10, 1), 'float32')
    t3 = relay.TensorType((5, 10, 4), 'float32')
    expected_ty = relay.FuncType([t1, t2], t3)
    assert_has_type(func, expected_ty)


def test_dual_op():
    """Program:
       fn (%x : Tensor[(10, 10), float32]) {
         let %t1 = log(%x);
         let %t2 = add(%t1, %x);
         %t2
       }
    """
    tp = relay.TensorType((10, 10), "float32")
    x = relay.var("x", tp)
    sb = relay.ScopeBuilder()
    t1 = sb.let("t1", relay.log(x))
    t2 = sb.let("t2", relay.add(t1, x))
    sb.ret(t2)
    f = relay.Function([x], sb.get())
    fchecked = run_infer_type(f)
    assert fchecked.checked_type == relay.FuncType([tp], tp)


def test_decl():
    """Program:
       def @f(%x : Tensor[(10, 10), float32]) {
           log(%x)
       }
    """
    tp = relay.TensorType((10, 10))
    x = relay.var("x", tp)
    f = relay.Function([x], relay.log(x))
    fchecked = run_infer_type(f)
    assert fchecked.checked_type == relay.FuncType([tp], tp)


def test_recursion():
    """
    Program:
       def @f(%n: int32, %data: float32) -> float32 {
          if (%n == 0) {
              %data
          } else {
              @f(%n - 1, log(%data))
          }
       }
    """
    sb = relay.ScopeBuilder()
    f = relay.GlobalVar("f")
    ti32 = relay.scalar_type("int32")
    tf32 = relay.scalar_type("float32")
    n = relay.var("n", ti32)
    data = relay.var("data", tf32)

    with sb.if_scope(relay.equal(n, relay.const(0, ti32))):
        sb.ret(data)
    with sb.else_scope():
        sb.ret(f(relay.subtract(n, relay.const(1, ti32)), relay.log(data)))
    mod = relay.Module()
    mod[f] = relay.Function([n, data], sb.get())
    assert "@f(%1, %2) /* ty=float32 */" in mod.astext()
    assert mod[f].checked_type == relay.FuncType([ti32, tf32], tf32)


def test_incomplete_call():
    """The type of an unannotated function parameter is inferred from its call."""
    tt = relay.scalar_type('int32')
    x = relay.var('x', tt)
    f = relay.var('f')
    func = relay.Function([x, f], relay.Call(f, [x]), tt)

    ft = run_infer_type(func)
    f_type = relay.FuncType([tt], tt)
    assert ft.checked_type == relay.FuncType([tt, f_type], tt)


def test_higher_order_argument():
    """A polymorphic identity function can be passed to a higher-order function."""
    a = relay.TypeVar('a')
    x = relay.Var('x', a)
    id_func = relay.Function([x], x, a, [a])

    b = relay.TypeVar('b')
    f = relay.Var('f', relay.FuncType([b], b))
    y = relay.Var('y', b)
    ho_func = relay.Function([f, y], f(y), b, [b])

    # id func should be an acceptable argument to the higher-order
    # function even though id_func takes a type parameter
    ho_call = ho_func(id_func, relay.const(0, 'int32'))

    hc = run_infer_type(ho_call)
    expected = relay.scalar_type('int32')
    assert hc.checked_type == expected


def test_higher_order_return():
    """A polymorphic identity function can be returned from another function."""
    a = relay.TypeVar('a')
    x = relay.Var('x', a)
    id_func = relay.Function([x], x, a, [a])

    b = relay.TypeVar('b')
    nested_id = relay.Function([], id_func, relay.FuncType([b], b), [b])

    ft = run_infer_type(nested_id)
    assert ft.checked_type == relay.FuncType([], relay.FuncType([b], b), [b])


def test_higher_order_nested():
    """Both branches of an ``If`` may be functions of the declared return type."""
    a = relay.TypeVar('a')
    x = relay.Var('x', a)
    id_func = relay.Function([x], x, a, [a])

    choice_t = relay.FuncType([], relay.scalar_type('bool'))
    f = relay.Var('f', choice_t)

    b = relay.TypeVar('b')
    z = relay.Var('z')
    top = relay.Function(
        [f],
        relay.If(f(), id_func, relay.Function([z], z)),
        relay.FuncType([b], b),
        [b])

    expected = relay.FuncType([choice_t], relay.FuncType([b], b), [b])
    ft = run_infer_type(top)
    assert ft.checked_type == expected


def test_tuple():
    """A tuple of two tensors has the corresponding tuple type."""
    tp = relay.TensorType((10,))
    x = relay.var("x", tp)
    res = relay.Tuple([x, x])
    assert (run_infer_type(res).checked_type == relay.TupleType([tp, tp]))


def test_ref():
    """Check the types of reference creation, read and write."""
    x = relay.var("x", "float32")
    y = relay.var("y", "float32")
    r = relay.RefCreate(x)
    st = relay.scalar_type("float32")
    assert run_infer_type(r).checked_type == relay.RefType(st)
    g = relay.RefRead(r)
    assert run_infer_type(g).checked_type == st
    w = relay.RefWrite(r, y)
    assert run_infer_type(w).checked_type == relay.TupleType([])


def test_free_expr():
    """Type check an expression with a free variable (currently disabled)."""
    # NOTE(review): the bare return below turns this test into a no-op, so
    # none of the following statements run.  Presumably it was disabled on
    # purpose -- confirm before re-enabling or deleting the dead code.
    return
    x = relay.var("x", "float32")
    y = relay.add(x, x)
    yy = run_infer_type(y)
    assert yy.checked_type == relay.scalar_type("float32")
    assert x.vid.same_as(yy.args[0].vid)


def test_type_args():
    """The inferred call to a broadcasting ``add`` records both argument types."""
    x = relay.var("x", shape=(10, 10))
    y = relay.var("y", shape=(1, 10))
    z = relay.add(x, y)
    ty_z = run_infer_type(z)
    ty_args = ty_z.type_args
    assert len(ty_args) == 2
    assert ty_args[0].dtype == "float32"
    assert ty_args[1].dtype == "float32"
    sh1 = ty_args[0].shape
    sh2 = ty_args[1].shape
    assert sh1[0].value == 10
    assert sh1[1].value == 10
    assert sh2[0].value == 1
    assert sh2[1].value == 10


def test_global_var_recursion():
    """A global function that calls itself through its own GlobalVar type checks."""
    mod = relay.Module({})
    gv = relay.GlobalVar("main")
    x = relay.var('x', shape=[])
    tt = relay.scalar_type('float32')

    func = relay.Function([x], relay.Call(gv, [x]), tt)
    mod[gv] = func

    ft = run_infer_type(gv, mod)
    assert ft.checked_type == relay.FuncType([tt], tt)


def test_equal():
    """``equal`` on int32 scalars yields a bool scalar."""
    i = relay.var('i', shape=[], dtype='int32')
    eq = op.equal(i, relay.const(0, dtype='int32'))
    func = relay.Function([i], eq)
    ft = run_infer_type(func)

    assert ft.checked_type == relay.FuncType([relay.scalar_type('int32')], relay.scalar_type('bool'))


def test_constructor_type():
    """A function wrapping the box constructor is polymorphic in its field type."""
    mod = relay.Module()
    box, constructor = initialize_box_adt(mod)

    a = relay.TypeVar('a')
    x = relay.Var('x', a)
    ct = run_infer_type(relay.Function([x], constructor(x), box(a), [a]), mod)
    expected = relay.FuncType([a], box(a), [a])
    assert ct.checked_type == expected


def test_constructor_call():
    """Calling the box constructor instantiates the ADT at the argument type."""
    mod = relay.Module()
    box, constructor = initialize_box_adt(mod)

    box_unit = constructor(relay.Tuple([]))
    box_constant = constructor(relay.const(0, 'float32'))

    ut = run_infer_type(box_unit, mod)
    ct = run_infer_type(box_constant, mod)
    assert ut.checked_type == box(relay.TupleType([]))
    assert ct.checked_type == box(relay.TensorType((), 'float32'))


def test_adt_match():
    """A match over a box value has the type shared by its clause bodies."""
    mod = relay.Module()
    _, constructor = initialize_box_adt(mod)

    v = relay.Var('v', relay.TensorType((), 'float32'))
    match = relay.Match(constructor(relay.const(0, 'float32')),
                        [relay.Clause(
                            relay.PatternConstructor(constructor,
                                                     [relay.PatternVar(v)]),
                            relay.Tuple([])),
                         # redundant but shouldn't matter to typechecking
                         relay.Clause(relay.PatternWildcard(),
                                      relay.Tuple([]))])

    mt = run_infer_type(match, mod)
    assert mt.checked_type == relay.TupleType([])


def test_adt_match_type_annotations():
    """An annotation on a pattern variable determines the matched value's type."""
    mod = relay.Module()
    _, constructor = initialize_box_adt(mod)

    # the only type annotation is inside the match pattern var
    # but that should be enough info
    tt = relay.TensorType((2, 2), 'float32')
    x = relay.Var('x')
    mv = relay.Var('mv', tt)
    match = relay.Match(constructor(x),
                        [relay.Clause(
                            relay.PatternConstructor(constructor,
                                                     [relay.PatternVar(mv)]),
                            relay.Tuple([]))])

    func = relay.Function([x], match)
    ft = run_infer_type(func, mod)
    assert ft.checked_type == relay.FuncType([tt], relay.TupleType([]))


def test_let_polymorphism():
    """A let-bound polymorphic function can be used at two different types."""
    # Named id_var so the builtin id() is not shadowed; the Relay variable
    # itself is still called "id".
    id_var = relay.Var("id")
    xt = relay.TypeVar("xt")
    x = relay.Var("x", xt)
    body = relay.Tuple([id_var(relay.const(1)), id_var(relay.Tuple([]))])
    body = relay.Let(id_var, relay.Function([x], x, xt, [xt]), body)
    body = run_infer_type(body)
    int32 = relay.TensorType((), "int32")
    assert_alpha_equal(body.checked_type, relay.TupleType([int32, relay.TupleType([])]))


if __name__ == "__main__":
    # Run every test in this file exactly once, in definition order.
    test_monomorphic_let()
    test_single_op()
    test_add_broadcast_op()
    test_dual_op()
    test_decl()
    test_recursion()
    test_incomplete_call()
    test_higher_order_argument()
    test_higher_order_return()
    test_higher_order_nested()
    test_tuple()
    test_ref()
    test_free_expr()
    test_type_args()
    test_global_var_recursion()
    test_equal()
    test_constructor_type()
    test_constructor_call()
    test_adt_match()
    test_adt_match_type_annotations()
    test_let_polymorphism()