1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17import numpy as np
18import tvm
19from tvm import relay
20from tvm.relay.testing import to_python, run_as_python
21from tvm.relay.prelude import Prelude
22from tvm.relay.backend.interpreter import TensorValue, TupleValue, RefValue, ConstructorValue
23
24# helper: uses a dummy let binding to sequence a list
25# of expressions: expr1; expr2; expr3, etc.
def seq(*exprs):
    """Chain Relay expressions so they evaluate in order: ``e1; e2; ...; en``.

    Each earlier expression is bound to a throwaway variable named ``'_'``
    by a ``relay.Let`` whose body is the following expression. The chain's
    value is therefore the value of the last expression.

    Raises IndexError when called with no expressions.
    """
    head, rest = exprs[0], exprs[1:]
    for nxt in rest:
        # discard head's value; only its evaluation (effects) matters
        head = relay.Let(relay.var('_'), head, nxt)
    return head
31
32
33# creates a dummy ADT for testing
def init_box_adt(mod):
    """Register a one-constructor polymorphic ADT ``box a = box(a)`` in ``mod``.

    Returns a ``(type_var, constructor)`` pair: the GlobalTypeVar naming the
    ADT and its single ``box`` constructor, which takes one field of type ``a``.
    """
    type_param = relay.TypeVar('a')
    box_type = relay.GlobalTypeVar('box')
    ctor = relay.Constructor('box', [type_param], box_type)
    mod[box_type] = relay.TypeData(box_type, [type_param], [ctor])
    return box_type, ctor
40
41
42# assert that the candidate is a TensorValue with value val
def assert_tensor_value(candidate, val):
    """Assert ``candidate`` is a TensorValue whose contents equal ``val``.

    ``val`` may be anything ``np.array`` accepts (scalar, nested list,
    ndarray). The comparison uses ``np.array_equal``, so shape and every
    element must match exactly.
    """
    assert isinstance(candidate, TensorValue)
    actual = candidate.asnumpy()
    expected = np.array(val)
    assert np.array_equal(actual, expected)
46
47
# assert that the candidate is a TupleValue with the indicated number of fields
def assert_tuple_value(candidate, fields):
    """Assert ``candidate`` is a TupleValue holding exactly ``fields`` elements."""
    assert isinstance(candidate, TupleValue)
    field_count = len(candidate.fields)
    assert field_count == fields
52
53
# assert that the candidate is a ConstructorValue with the appropriate constructor
# and number of fields
def assert_constructor_value(candidate, constructor, fields):
    """Assert ``candidate`` is a ConstructorValue built by ``constructor``.

    The constructor is identified by comparing ``tag`` attributes. The value
    must also carry exactly ``fields`` fields.
    """
    assert isinstance(candidate, ConstructorValue)
    same_ctor = candidate.tag == constructor.tag
    assert same_ctor
    field_count = len(candidate.fields)
    assert field_count == fields
60
61
def test_create_empty_tuple():
    """An empty Relay tuple converts to a zero-field TupleValue."""
    assert_tuple_value(run_as_python(relay.Tuple([])), 0)
66
67
def test_create_scalar():
    """A scalar constant round-trips through the Python converter."""
    result = run_as_python(relay.const(1))
    assert_tensor_value(result, 1)
72
73
def test_create_tensor():
    """A 2x2 tensor constant round-trips through the Python converter."""
    data = [[1, 1], [2, 2]]
    result = run_as_python(relay.const(data))
    assert_tensor_value(result, data)
78
79
def test_create_nested_tuple():
    """The tuple ``(1, 2, (3, 4))`` converts with its nesting preserved."""
    inner = relay.Tuple([relay.const(3), relay.const(4)])
    outer = relay.Tuple([relay.const(1), relay.const(2), inner])
    result = run_as_python(outer)

    assert_tuple_value(result, 3)
    for idx, expected in enumerate((1, 2)):
        assert_tensor_value(result.fields[idx], expected)

    nested = result.fields[2]
    assert_tuple_value(nested, 2)
    for idx, expected in enumerate((3, 4)):
        assert_tensor_value(nested.fields[idx], expected)
95
96
def test_tuple_get_item():
    """TupleGetItem extracts both top-level and nested tuple fields."""
    inner = relay.Tuple([relay.const(3), relay.const(4)])
    outer = relay.Tuple([relay.const(1), relay.const(2), inner])

    for idx in range(2):
        got = run_as_python(relay.TupleGetItem(outer, idx))
        assert_tensor_value(got, idx + 1)

    # project into the nested tuple stored at index 2
    for idx in range(2):
        nested_item = relay.TupleGetItem(relay.TupleGetItem(outer, 2), idx)
        assert_tensor_value(run_as_python(nested_item), idx + 3)
114
115
def test_create_let():
    """A let-bound value can be referenced more than once in the body."""
    bound = relay.Var('v')
    body = relay.Tuple([bound, bound])
    result = run_as_python(relay.Let(bound, relay.Tuple([]), body))
    assert_tuple_value(result, 2)
    for idx in (0, 1):
        assert_tuple_value(result.fields[idx], 0)
123
124
def test_create_ref():
    """RefCreate yields a RefValue that wraps the initial value."""
    result = run_as_python(relay.RefCreate(relay.Tuple([])))
    assert isinstance(result, RefValue)
    assert_tuple_value(result.value, 0)
130
131
def test_ref_read():
    """Reading a freshly created reference returns its initial value."""
    ref_var = relay.Var('v')
    expr = relay.Let(ref_var,
                     relay.RefCreate(relay.Tuple([])),
                     relay.RefRead(ref_var))
    assert_tuple_value(run_as_python(expr), 0)
137
138
def test_ref_write():
    """RefWrite evaluates to the empty tuple, and later reads see its effect."""
    ref_v = relay.Var('v')

    # the value of a write expression is always the empty tuple
    write_only = relay.Let(
        ref_v, relay.RefCreate(relay.Tuple([relay.const(1)])),
        relay.RefWrite(ref_v, relay.Tuple([relay.const(2)])))
    assert_tuple_value(run_as_python(write_only), 0)

    # copy v's contents into a second ref w, mutate v, then read both:
    # w must still hold the old value while v holds the new one
    ref_w = relay.Var('w')
    mutate_then_read = seq(
        relay.RefWrite(ref_v, relay.Tuple([relay.const(2)])),
        relay.Tuple([relay.RefRead(ref_w), relay.RefRead(ref_v)]))
    read_after_write = relay.Let(
        ref_v, relay.RefCreate(relay.Tuple([relay.const(1)])),
        relay.Let(ref_w, relay.RefCreate(relay.RefRead(ref_v)), mutate_then_read))

    pair = run_as_python(read_after_write)
    assert_tuple_value(pair, 2)
    old_box, new_box = pair.fields[0], pair.fields[1]
    assert_tuple_value(old_box, 1)
    assert_tuple_value(new_box, 1)
    assert_tensor_value(old_box.fields[0], 1)
    assert_tensor_value(new_box.fields[0], 2)
162
163
def test_if():
    """Only the branch selected by the condition runs its side effects."""
    ref = relay.Var('v')
    # each branch overwrites the ref and reads it back, so the result
    # reveals which branch actually executed
    on_true = seq(relay.RefWrite(ref, relay.const(1)), relay.RefRead(ref))
    on_false = seq(relay.RefWrite(ref, relay.const(2)), relay.RefRead(ref))

    def branch_on(flag):
        return relay.Let(ref, relay.RefCreate(relay.const(0)),
                         relay.If(relay.const(flag), on_true, on_false))

    assert_tensor_value(run_as_python(branch_on(True)), 1)
    assert_tensor_value(run_as_python(branch_on(False)), 2)
183
184
def test_local_function():
    """A let-bound identity function works on both tuples and tensors."""
    param = relay.Var('v')
    identity = relay.Function([param], param)
    fn_var = relay.Var('f')

    on_unit = run_as_python(relay.Let(fn_var, identity, fn_var(relay.Tuple([]))))
    assert_tuple_value(on_unit, 0)

    on_tensor = run_as_python(relay.Let(fn_var, identity, fn_var(relay.const(2))))
    assert_tensor_value(on_tensor, 2)
197
198
def test_global_function():
    """A polymorphic global identity function can be called at different types."""
    mod = relay.Module()
    ident = relay.GlobalVar('ident')
    tv = relay.TypeVar('a')
    arg = relay.Var('v', tv)
    mod[ident] = relay.Function([arg], arg, tv, [tv])

    scalar_result = run_as_python(ident(relay.const(1)), mod)
    assert_tensor_value(scalar_result, 1)

    pair_result = run_as_python(
        ident(relay.Tuple([relay.const(2), relay.const(2)])), mod)
    assert_tuple_value(pair_result, 2)
    for idx in (0, 1):
        assert_tensor_value(pair_result.fields[idx], 2)
216
217
def test_constructor():
    """The box constructor wraps both tensors and tuples."""
    mod = relay.Module()
    _, box_ctor = init_box_adt(mod)

    boxed_int = run_as_python(box_ctor(relay.const(1)), mod)
    assert_constructor_value(boxed_int, box_ctor, 1)
    assert_tensor_value(boxed_int.fields[0], 1)

    boxed_unit = run_as_python(box_ctor(relay.Tuple([])), mod)
    assert_constructor_value(boxed_unit, box_ctor, 1)
    assert_tuple_value(boxed_unit.fields[0], 0)
233
234
def test_match_wildcard():
    """A lone wildcard clause matches any constructed value."""
    mod = relay.Module()
    _, box_ctor = init_box_adt(mod)
    scrutinee = relay.Var('v')
    clauses = [relay.Clause(relay.PatternWildcard(), relay.const(1))]
    expr = relay.Let(scrutinee, box_ctor(relay.Tuple([])),
                     relay.Match(scrutinee, clauses))
    assert_tensor_value(run_as_python(expr, mod), 1)
247
248
def test_match_var():
    """A variable pattern binds the entire matched value."""
    mod = relay.Module()
    _, box_ctor = init_box_adt(mod)
    scrutinee = relay.Var('v')
    binder = relay.Var('w')
    clauses = [relay.Clause(relay.PatternVar(binder), binder)]
    expr = relay.Let(scrutinee, box_ctor(relay.const(1)),
                     relay.Match(scrutinee, clauses))

    result = run_as_python(expr, mod)
    assert_constructor_value(result, box_ctor, 1)
    assert_tensor_value(result.fields[0], 1)
263
264
def test_match_pattern():
    """A constructor pattern destructures the box and binds its field."""
    mod = relay.Module()
    _, box_ctor = init_box_adt(mod)
    scrutinee = relay.Var('v')
    binder = relay.Var('w')
    unbox = relay.PatternConstructor(box_ctor, [relay.PatternVar(binder)])
    expr = relay.Let(scrutinee, box_ctor(relay.const(1)),
                     relay.Match(scrutinee, [relay.Clause(unbox, binder)]))
    assert_tensor_value(run_as_python(expr, mod), 1)
277
278
def test_nested_match_pattern():
    """Constructor patterns nest: box(box(w)) binds the innermost field."""
    mod = relay.Module()
    _, box_ctor = init_box_adt(mod)
    scrutinee = relay.Var('v')
    binder = relay.Var('w')
    inner_pat = relay.PatternConstructor(box_ctor, [relay.PatternVar(binder)])
    outer_pat = relay.PatternConstructor(box_ctor, [inner_pat])
    double_box = box_ctor(box_ctor(relay.const(2)))
    expr = relay.Let(scrutinee, double_box,
                     relay.Match(scrutinee, [relay.Clause(outer_pat, binder)]))
    assert_tensor_value(run_as_python(expr, mod), 2)
295
def test_match_order():
    """Clauses are tried in order, so a leading wildcard shadows later clauses."""
    mod = relay.Module()
    _, box_ctor = init_box_adt(mod)
    scrutinee = relay.Var('v')
    binder = relay.Var('w')
    nested_pat = relay.PatternConstructor(
        box_ctor, [relay.PatternConstructor(box_ctor, [relay.PatternVar(binder)])])
    clauses = [
        # the wildcard comes first, so it must win even though the
        # nested pattern below would also match
        relay.Clause(relay.PatternWildcard(), relay.const(1)),
        relay.Clause(nested_pat, binder),
    ]
    expr = relay.Let(scrutinee, box_ctor(box_ctor(relay.const(2))),
                     relay.Match(scrutinee, clauses))
    assert_tensor_value(run_as_python(expr, mod), 1)
314
315
def test_local_recursion():
    """A let-bound function can call itself recursively (list copy)."""
    mod = relay.Module()
    prelude = Prelude(mod)

    lst = relay.Var('v')
    head = relay.Var('h')
    tail = relay.Var('t')
    copy_fn = relay.Var('f')

    # copy_fn rebuilds its input cell by cell, so it returns an equal list
    cons_clause = relay.Clause(
        relay.PatternConstructor(prelude.cons,
                                 [relay.PatternVar(head), relay.PatternVar(tail)]),
        prelude.cons(head, copy_fn(tail)))
    nil_clause = relay.Clause(relay.PatternConstructor(prelude.nil, []), prelude.nil())
    copy_impl = relay.Function([lst], relay.Match(lst, [cons_clause, nil_clause]))

    one_two_three = prelude.cons(relay.const(1),
                                 prelude.cons(relay.const(2),
                                              prelude.cons(relay.const(3), prelude.nil())))
    result = run_as_python(relay.Let(copy_fn, copy_impl, copy_fn(one_two_three)), mod)

    # walk the resulting cons list, checking each element in order
    node = result
    for expected in (1, 2, 3):
        assert_constructor_value(node, prelude.cons, 2)
        assert_tensor_value(node.fields[0], expected)
        node = node.fields[1]
    assert_constructor_value(node, prelude.nil, 0)
344
345
def test_global_recursion():
    """A global function recursing through its own GlobalVar copies lists."""
    mod = relay.Module()
    prelude = Prelude(mod)
    copy = relay.GlobalVar('copy')

    elem_t = relay.TypeVar('a')
    lst = relay.Var('v', prelude.l(elem_t))
    head = relay.Var('h')
    tail = relay.Var('t')
    # copies the given list; the recursive call goes through the GlobalVar
    cons_clause = relay.Clause(
        relay.PatternConstructor(prelude.cons,
                                 [relay.PatternVar(head), relay.PatternVar(tail)]),
        prelude.cons(head, copy(tail)))
    nil_clause = relay.Clause(relay.PatternConstructor(prelude.nil, []), prelude.nil())
    copy_def = relay.Function([lst], relay.Match(lst, [cons_clause, nil_clause]),
                              prelude.l(elem_t), [elem_t])
    mod[copy] = copy_def

    int_list = prelude.cons(relay.const(1), prelude.cons(relay.const(2), prelude.nil()))
    node = run_as_python(copy_def(int_list), mod)
    for expected in (1, 2):
        assert_constructor_value(node, prelude.cons, 2)
        assert_tensor_value(node.fields[0], expected)
        node = node.fields[1]
    assert_constructor_value(node, prelude.nil, 0)

    unit_list = prelude.cons(relay.Tuple([]), prelude.nil())
    copied = run_as_python(copy_def(unit_list), mod)
    assert_constructor_value(copied, prelude.cons, 2)
    assert_tuple_value(copied.fields[0], 0)
    assert_constructor_value(copied.fields[1], prelude.nil, 0)
376
377
def test_higher_order_call():
    """Functions can be passed as arguments, both anonymously and by name."""
    applier = relay.Var('h')
    fn_param = relay.Var('f')
    ignored = relay.Var('x')
    # h calls whatever function it receives on the empty tuple
    apply_to_unit = relay.Function([fn_param], fn_param(relay.Tuple([])))

    # anonymous function argument
    ho_anon = relay.Let(applier, apply_to_unit,
                        applier(relay.Function([ignored], relay.const(1))))
    assert_tensor_value(run_as_python(ho_anon), 1)

    # named (let-bound) function argument
    named = relay.Var('g')
    ho_named = relay.Let(applier, apply_to_unit,
                         relay.Let(named, relay.Function([ignored], relay.const(2)),
                                   applier(named)))
    assert_tensor_value(run_as_python(ho_named), 2)
396
397
def test_match_effect_exactly_once():
    """The match scrutinee is evaluated exactly once, even with several clauses."""
    mod = relay.Module()
    prelude = Prelude(mod)

    ref = relay.Var('r')
    # each evaluation of the scrutinee prepends an element to the list in ref,
    # so a single evaluation yields a one-element list
    push_and_read = seq(
        relay.RefWrite(ref, prelude.cons(relay.Tuple([]), relay.RefRead(ref))),
        relay.RefRead(ref))
    clauses = [
        relay.Clause(relay.PatternConstructor(prelude.nil, []), relay.const(0)),
        relay.Clause(
            relay.PatternConstructor(
                prelude.cons,
                [relay.PatternWildcard(), relay.PatternConstructor(prelude.nil, [])]),
            relay.const(1)),
        relay.Clause(relay.PatternWildcard(), relay.const(2)),
    ]
    expr = relay.Let(ref, relay.RefCreate(prelude.nil()),
                     relay.Match(push_and_read, clauses))
    assert_tensor_value(run_as_python(expr, mod), 1)
420
421
def test_arbitrary_let_nesting():
    """Let expressions may appear anywhere, e.g. as tuple fields or call arguments."""
    # something that is tricky to do in Python but comes naturally in Relay
    mod = relay.Module()
    prelude = Prelude(mod)
    x, r, y, z = (relay.Var(name) for name in ('x', 'r', 'y', 'z'))

    second_of_pair = relay.Let(x, relay.Tuple([relay.const(1), relay.const(2)]),
                               relay.TupleGetItem(x, 1))
    overwritten_ref = relay.Let(r, relay.RefCreate(relay.const(1)),
                                seq(relay.RefWrite(r, relay.const(3)), relay.RefRead(r)))
    let_in_arg = relay.Let(y, prelude.id(relay.Let(z, relay.const(4), z)), y)

    result = run_as_python(relay.Tuple([second_of_pair, overwritten_ref, let_in_arg]), mod)
    assert_tuple_value(result, 3)
    for idx, expected in enumerate((2, 3, 4)):
        assert_tensor_value(result.fields[idx], expected)
443
444
def test_ref_execution_order():
    """Side effects inside tuples and call arguments run left to right."""
    x = relay.Var('x')
    y = relay.Var('y')
    f = relay.Var('f')
    r = relay.Var('r')

    def write_then_read(value):
        # set r to `value`, then evaluate to r's freshly written contents
        return seq(relay.RefWrite(r, relay.const(value)), relay.RefRead(r))

    observations = relay.Tuple([
        relay.RefRead(r),    # initial value: 1
        write_then_read(2),
        write_then_read(3),
        # f returns its first argument: the write of 4 is seen by the
        # first argument, and the write of 5 (second argument) follows it
        f(write_then_read(4), write_then_read(5)),
        relay.RefRead(r),    # the last write wins: 5
    ])
    expr = relay.Let(f, relay.Function([x, y], x),
                     relay.Let(r, relay.RefCreate(relay.const(1)), observations))

    result = run_as_python(expr)
    assert_tuple_value(result, 5)
    for idx, expected in enumerate((1, 2, 3, 4, 5)):
        assert_tensor_value(result.fields[idx], expected)
483
484
def test_op_add():
    """A primitive operator call (add) is executed by the converted code."""
    total = run_as_python(relay.add(relay.const(1), relay.const(2)))
    assert_tensor_value(total, 3)
489
490
491# test an op with a tuple input
492# adapted from test_stack in test_op_level3
def test_op_stack():
    """relay.stack (an op taking a tuple input) matches ``np.stack``."""
    def verify_stack(dshapes, axis):
        # Draw integers from a wide range. Casting standard-normal samples to
        # int32 truncates almost every element to 0 or +/-1, which makes
        # identical inputs likely and lets stacking-order/axis mistakes slip by.
        x_data = [np.random.randint(-100, 100, size=shape).astype('int32')
                  for shape in dshapes]
        ref_res = np.stack(x_data, axis=axis)

        args = [relay.const(data) for data in x_data]
        call = relay.stack(relay.Tuple(args), axis)
        call_val = run_as_python(call)
        # int32 data, so exact equality is appropriate
        assert_tensor_value(call_val, ref_res)

    verify_stack([(2,), (2,), (2,)], -1)
    verify_stack([(2,), (2,), (2,)], 0)
    verify_stack([(2, 2, 4), (2, 2, 4), (2, 2, 4)], 1)
    verify_stack([(2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4)], -1)
509
510
511# test an op with a tuple output
512# adapted from test_split_infer_type in test_op_level3
513# and test_split in nnvm's test_top_level1
def test_split():
    """relay.split (an op producing a tuple) matches ``np.split``."""
    def verify_split(shape, indices_or_sections, axis=0):
        data = np.random.normal(size=shape).astype('float32')
        expected_parts = np.split(data, indices_or_sections, axis=axis)
        result = run_as_python(
            relay.split(relay.const(data), indices_or_sections, axis=axis))
        assert_tuple_value(result, len(expected_parts))
        for idx, part in enumerate(expected_parts):
            assert_tensor_value(result.fields[idx], part)

    cases = [
        ((2, 3), 2, 0),
        ((5, 3), [3], 0),
        ((5, 9, 3), [3, 4], 1),
        ((5, 5, 2, 2), 5, 1),
        ((5, 5, 2, 2), 5, 0),
    ]
    for shape, sections, axis in cases:
        verify_split(shape, sections, axis)
529
530
531# ensure we can generate code for batch_norm, since it requires simplify_inference
532# adapted from test_batchnorm in nnvm's test_top_level1
def test_batch_norm():
    """relay.nn.batch_norm (which needs simplify_inference) matches a NumPy reference."""
    def verify_batch_norm(shapes):
        # absolute values keep every input, including moving_var, non-negative
        # so the sqrt in the reference is well defined
        data = [np.absolute(np.random.normal(size=shape).astype('float32'))
                for shape in shapes]
        relay_args = [relay.const(arg) for arg in data]

        eps = 1e-5
        def reference(x, gamma, beta, moving_mean, moving_var):
            return (x - moving_mean) / np.sqrt(moving_var + eps) * gamma + beta
        ref_res = reference(*data)

        # only the first output of batch_norm is checked
        call = relay.nn.batch_norm(*relay_args, epsilon=eps)[0]
        call_val = run_as_python(call)

        # the results may differ slightly in floating point, so compare
        # approximately. np.testing is used directly because a plain
        # `import tvm` does not guarantee that tvm.testing is loaded.
        assert isinstance(call_val, TensorValue)
        np.testing.assert_allclose(call_val.asnumpy(), ref_res, atol=eps, rtol=eps)

    verify_batch_norm([(10, 20), (20,), (20,), (20,), (20,)])
    verify_batch_norm([(20, 10), (10,), (10,), (10,), (10,)])
    verify_batch_norm([(10, 50), (50,), (50,), (50,), (50,)])
    verify_batch_norm([(30, 40), (40,), (40,), (40,), (40,)])
556