1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
"""Test that the type checker correctly computes types
   for expressions.
"""
20from tvm import relay
21from tvm.relay import op, transform, analysis
22from tvm.relay.analysis import assert_alpha_equal
23
24
def run_infer_type(expr, mod=None):
    """Run Relay type inference on ``expr`` and return the typed result.

    Parameters
    ----------
    expr : relay.Expr or relay.GlobalVar
        Expression to type check. A non-function expression is wrapped in a
        function over its free variables before inference.
    mod : relay.Module, optional
        Module to run inference in. When omitted, a fresh module is built from
        ``expr``. When given and ``expr`` is not a GlobalVar, ``expr`` (or its
        wrapper function) is installed as ``mod["main"]``, so the caller's
        module is modified.

    Returns
    -------
    relay.Expr
        The inferred function when ``expr`` is a Function, the global function
        named by ``expr`` when it is a GlobalVar, and otherwise the body of
        the inferred wrapper function.
    """
    # Compare against None explicitly: a supplied module must never be
    # discarded just because it happens to evaluate as falsy.
    if mod is None:
        mod = relay.Module.from_expr(expr)
        mod = transform.InferType()(mod)
        entry = mod["main"]
        return entry if isinstance(expr, relay.Function) else entry.body

    if isinstance(expr, relay.GlobalVar):
        # The function is already in the module; look it up by name.
        gv = expr.name_hint
    else:
        func = expr
        if not isinstance(expr, relay.Function):
            # Wrap a free-standing expression so it can be a module entry.
            func = relay.Function(analysis.free_vars(expr), expr)
        mod["main"] = func
        gv = "main"
    mod = transform.InferType()(mod)

    if isinstance(expr, (relay.GlobalVar, relay.Function)):
        return mod[gv]
    return mod[gv].body
45
46
def assert_has_type(expr, typ, mod=None):
    """Type check ``expr`` and raise if its checked type differs from ``typ``.

    Parameters
    ----------
    expr : relay.Expr
        Expression to type check.
    typ : relay.Type
        Expected checked type.
    mod : relay.Module, optional
        Module to type check in. When omitted, a fresh empty module is created
        for this call.

    Raises
    ------
    RuntimeError
        If the inferred type is not equal to ``typ``.
    """
    if mod is None:
        # Build a new module per call: run_infer_type installs "main" into the
        # module it receives, so a default evaluated once at definition time
        # would be shared and mutated across calls.
        mod = relay.Module({})
    checked_type = run_infer_type(expr, mod).checked_type
    if checked_type != typ:
        raise RuntimeError("Type mismatch %s vs %s" % (
            checked_type, typ))
53
54
55# initializes simple ADT for tests
def initialize_box_adt(mod):
    """Register the one-constructor polymorphic ADT ``box`` in ``mod``.

    The registered type is ``box[tv]`` with a single constructor
    ``constructor(tv)``.

    Returns
    -------
    tuple
        ``(box, constructor)``: the GlobalTypeVar naming the ADT and its only
        constructor.
    """
    box_type = relay.GlobalTypeVar('box')
    elem_type = relay.TypeVar('tv')
    box_ctor = relay.Constructor('constructor', [elem_type], box_type)
    mod[box_type] = relay.TypeData(box_type, [elem_type], [box_ctor])
    return box_type, box_ctor
63
64
def test_monomorphic_let():
    """Program: let %x = 1.0 (float64); %x"""
    builder = relay.ScopeBuilder()
    bound = builder.let('x', relay.const(1.0, "float64"))
    builder.ret(bound)
    checked = run_infer_type(builder.get())
    expected = relay.scalar_type("float64")
    assert checked.checked_type == expected
72
73
def test_single_op():
    """Program: fn (%x: float32) { log(%x) }"""
    scalar_ty = relay.TensorType([], dtype='float32')
    x = relay.var('x', shape=[])
    assert_has_type(relay.Function([x], op.log(x)),
                    relay.FuncType([scalar_ty], scalar_ty))
80
81
def test_add_broadcast_op():
    """
    Program:
        fn (%x: Tensor[(10, 4), float32], %y: Tensor[(5, 10, 1), float32])
            -> Tensor[(5, 10, 4), float32] {
            %x + %y
        }
    """
    x_shape, y_shape, out_shape = (10, 4), (5, 10, 1), (5, 10, 4)
    x = relay.var('x', shape=x_shape)
    y = relay.var('y', shape=y_shape)
    func = relay.Function([x, y], x + y)

    arg_types = [relay.TensorType(shape, 'float32')
                 for shape in (x_shape, y_shape)]
    ret_type = relay.TensorType(out_shape, 'float32')
    assert_has_type(func, relay.FuncType(arg_types, ret_type))
99
100
def test_dual_op():
    """Program:
       fn (%x : Tensor[(10, 10), float32]) {
         let %t1 = log(%x);
         let %t2 = add(%t1, %x);
         %t2
       }
    """
    tensor_ty = relay.TensorType((10, 10), "float32")
    x = relay.var("x", tensor_ty)
    builder = relay.ScopeBuilder()
    log_x = builder.let("t1", relay.log(x))
    total = builder.let("t2", relay.add(log_x, x))
    builder.ret(total)
    checked = run_infer_type(relay.Function([x], builder.get()))
    assert checked.checked_type == relay.FuncType([tensor_ty], tensor_ty)
118
119
def test_decl():
    """Program:
       def @f(%x : Tensor[(10, 10), float32]) {
           log(%x)
       }
    """
    tensor_ty = relay.TensorType((10, 10))
    x = relay.var("x", tensor_ty)
    checked = run_infer_type(relay.Function([x], relay.log(x)))
    assert checked.checked_type == relay.FuncType([tensor_ty], tensor_ty)
131
132
def test_recursion():
    """
    Program:
       def @f(%n: int32, %data: float32) -> float32 {
          if (%n == 0) {
              %data
          } else {
              @f(%n - 1, log(%data))
          }
       }
    """
    i32 = relay.scalar_type("int32")
    f32 = relay.scalar_type("float32")
    f = relay.GlobalVar("f")
    n = relay.var("n", i32)
    data = relay.var("data", f32)

    builder = relay.ScopeBuilder()
    with builder.if_scope(relay.equal(n, relay.const(0, i32))):
        builder.ret(data)
    with builder.else_scope():
        n_minus_one = relay.subtract(n, relay.const(1, i32))
        builder.ret(f(n_minus_one, relay.log(data)))

    mod = relay.Module()
    mod[f] = relay.Function([n, data], builder.get())
    # The recursive call site must be printed with its inferred type.
    assert "@f(%1, %2) /* ty=float32 */" in mod.astext()
    assert mod[f].checked_type == relay.FuncType([i32, f32], f32)
159
160
def test_incomplete_call():
    """An unannotated function parameter gets its type from how it is called."""
    int_ty = relay.scalar_type('int32')
    x = relay.var('x', int_ty)
    callee = relay.var('f')
    func = relay.Function([x, callee], relay.Call(callee, [x]), int_ty)

    checked = run_infer_type(func)
    callee_ty = relay.FuncType([int_ty], int_ty)
    assert checked.checked_type == relay.FuncType([int_ty, callee_ty], int_ty)
170
171
def test_higher_order_argument():
    """A polymorphic identity is accepted as a higher-order function argument."""
    tv_a = relay.TypeVar('a')
    arg = relay.Var('x', tv_a)
    identity = relay.Function([arg], arg, tv_a, [tv_a])

    tv_b = relay.TypeVar('b')
    fn_param = relay.Var('f', relay.FuncType([tv_b], tv_b))
    val_param = relay.Var('y', tv_b)
    apply_fn = relay.Function([fn_param, val_param], fn_param(val_param),
                              tv_b, [tv_b])

    # identity has its own type parameter, yet it must still be accepted
    # where a value of type fn(b) -> b is expected.
    call = apply_fn(identity, relay.const(0, 'int32'))
    assert run_infer_type(call).checked_type == relay.scalar_type('int32')
189
190
def test_higher_order_return():
    """A nullary function may return a polymorphic function."""
    tv_a = relay.TypeVar('a')
    arg = relay.Var('x', tv_a)
    identity = relay.Function([arg], arg, tv_a, [tv_a])

    tv_b = relay.TypeVar('b')
    id_ty = relay.FuncType([tv_b], tv_b)
    nested_id = relay.Function([], identity, id_ty, [tv_b])

    checked = run_infer_type(nested_id)
    assert checked.checked_type == relay.FuncType([], id_ty, [tv_b])
201
202
def test_higher_order_nested():
    """An If choosing between two identity functions checks as fn(b) -> b."""
    tv_a = relay.TypeVar('a')
    arg = relay.Var('x', tv_a)
    identity = relay.Function([arg], arg, tv_a, [tv_a])

    choice_t = relay.FuncType([], relay.scalar_type('bool'))
    chooser = relay.Var('f', choice_t)

    tv_b = relay.TypeVar('b')
    z = relay.Var('z')
    branch = relay.If(chooser(), identity, relay.Function([z], z))
    top = relay.Function([chooser], branch, relay.FuncType([tv_b], tv_b),
                         [tv_b])

    checked = run_infer_type(top)
    expected = relay.FuncType([choice_t], relay.FuncType([tv_b], tv_b), [tv_b])
    assert checked.checked_type == expected
222
223
def test_tuple():
    """A tuple of two tensors gets a two-field TupleType."""
    tensor_ty = relay.TensorType((10,))
    x = relay.var("x", tensor_ty)
    checked = run_infer_type(relay.Tuple([x, x]))
    assert checked.checked_type == relay.TupleType([tensor_ty, tensor_ty])
229
230
def test_ref():
    """RefCreate, RefRead and RefWrite get ref, payload and unit types."""
    float_ty = relay.scalar_type("float32")
    x = relay.var("x", "float32")
    y = relay.var("y", "float32")
    ref = relay.RefCreate(x)
    cases = [
        (ref, relay.RefType(float_ty)),
        (relay.RefRead(ref), float_ty),
        (relay.RefWrite(ref, y), relay.TupleType([])),
    ]
    for expr, expected in cases:
        assert run_infer_type(expr).checked_type == expected
241
242
def test_free_expr():
    """Type check a free (non-function) expression and check var identity.

    NOTE(review): this test is disabled by the early ``return`` below; every
    statement after it is unreachable and never runs. The reason it was
    disabled is not recorded here -- presumably the ``vid`` identity check
    does not hold after inference. Confirm before re-enabling.
    """
    return  # test disabled; the body below is unreachable
    x = relay.var("x", "float32")
    y = relay.add(x, x)
    yy = run_infer_type(y)
    assert yy.checked_type == relay.scalar_type("float32")
    assert x.vid.same_as(yy.args[0].vid)
250
251
def test_type_args():
    """The inferred add call records both operand tensor types in type_args."""
    x = relay.var("x", shape=(10, 10))
    y = relay.var("y", shape=(1, 10))
    call = run_infer_type(relay.add(x, y))
    type_args = call.type_args

    expected_shapes = [(10, 10), (1, 10)]
    assert len(type_args) == len(expected_shapes)
    for index, shape in enumerate(expected_shapes):
        arg_ty = type_args[index]
        assert arg_ty.dtype == "float32"
        for axis, extent in enumerate(shape):
            assert arg_ty.shape[axis].value == extent
267
268
def test_global_var_recursion():
    """A function calling itself via its GlobalVar checks as fn(float32) -> float32."""
    float_ty = relay.scalar_type('float32')
    main_gv = relay.GlobalVar("main")
    x = relay.var('x', shape=[])
    mod = relay.Module({})
    mod[main_gv] = relay.Function([x], relay.Call(main_gv, [x]), float_ty)

    checked = run_infer_type(main_gv, mod)
    assert checked.checked_type == relay.FuncType([float_ty], float_ty)
280
281
def test_equal():
    """Comparing an int32 scalar with a constant yields a bool scalar."""
    i = relay.var('i', shape=[], dtype='int32')
    func = relay.Function([i], op.equal(i, relay.const(0, dtype='int32')))
    checked = run_infer_type(func)

    int_ty = relay.scalar_type('int32')
    bool_ty = relay.scalar_type('bool')
    assert checked.checked_type == relay.FuncType([int_ty], bool_ty)
289
290
def test_constructor_type():
    """A function wrapping the box constructor checks as fn<a>(a) -> box[a]."""
    mod = relay.Module()
    box, constructor = initialize_box_adt(mod)

    tv_a = relay.TypeVar('a')
    arg = relay.Var('x', tv_a)
    wrapper = relay.Function([arg], constructor(arg), box(tv_a), [tv_a])
    checked = run_infer_type(wrapper, mod)
    assert checked.checked_type == relay.FuncType([tv_a], box(tv_a), [tv_a])
300
301
def test_constructor_call():
    """Applying the box constructor instantiates box with the argument's type."""
    mod = relay.Module()
    box, constructor = initialize_box_adt(mod)

    unit_checked = run_infer_type(constructor(relay.Tuple([])), mod)
    const_checked = run_infer_type(constructor(relay.const(0, 'float32')), mod)
    assert unit_checked.checked_type == box(relay.TupleType([]))
    assert const_checked.checked_type == box(relay.TensorType((), 'float32'))
313
314
def test_adt_match():
    """A match on a box value checks to its clauses' result type."""
    mod = relay.Module()
    _, constructor = initialize_box_adt(mod)

    payload = relay.Var('v', relay.TensorType((), 'float32'))
    box_pattern = relay.PatternConstructor(constructor,
                                           [relay.PatternVar(payload)])
    clauses = [
        relay.Clause(box_pattern, relay.Tuple([])),
        # Redundant after the clause above, but must not affect typing.
        relay.Clause(relay.PatternWildcard(), relay.Tuple([])),
    ]
    match = relay.Match(constructor(relay.const(0, 'float32')), clauses)

    checked = run_infer_type(match, mod)
    assert checked.checked_type == relay.TupleType([])
331
332
def test_adt_match_type_annotations():
    """The parameter type is inferred from the pattern variable's annotation."""
    mod = relay.Module()
    _, constructor = initialize_box_adt(mod)

    # The only type annotation is on the pattern variable; that alone should
    # be enough to infer the type of the unannotated parameter x.
    tensor_ty = relay.TensorType((2, 2), 'float32')
    x = relay.Var('x')
    bound = relay.Var('mv', tensor_ty)
    pattern = relay.PatternConstructor(constructor, [relay.PatternVar(bound)])
    match = relay.Match(constructor(x),
                        [relay.Clause(pattern, relay.Tuple([]))])

    checked = run_infer_type(relay.Function([x], match), mod)
    expected = relay.FuncType([tensor_ty], relay.TupleType([]))
    assert checked.checked_type == expected
351
352
def test_let_polymorphism():
    """A let-bound polymorphic identity can be applied at two different types."""
    # Named id_var rather than id so the builtin is not shadowed; the Relay
    # variable itself keeps the name "id".
    id_var = relay.Var("id")
    xt = relay.TypeVar("xt")
    x = relay.Var("x", xt)
    uses = relay.Tuple([id_var(relay.const(1)), id_var(relay.Tuple([]))])
    let_expr = relay.Let(id_var, relay.Function([x], x, xt, [xt]), uses)
    checked = run_infer_type(let_expr)
    int32 = relay.TensorType((), "int32")
    assert_alpha_equal(checked.checked_type,
                       relay.TupleType([int32, relay.TupleType([])]))
362
363
if __name__ == "__main__":
    # Run every test in this file exactly once, in definition order.
    test_monomorphic_let()
    test_single_op()
    test_add_broadcast_op()
    test_dual_op()
    test_decl()
    test_recursion()
    test_incomplete_call()
    test_higher_order_argument()
    test_higher_order_return()
    test_higher_order_nested()
    test_tuple()
    test_ref()
    test_free_expr()
    test_type_args()
    test_global_var_recursion()
    test_equal()
    test_constructor_type()
    test_constructor_call()
    test_adt_match()
    test_adt_match_type_annotations()
    test_let_polymorphism()
382