1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17"""Unit tests for heterogeneous compilation and execution."""
18import json
19import numpy as np
20
21import tvm
22from tvm import relay
23from tvm.contrib import graph_runtime
24from tvm.relay.expr_functor import ExprMutator
25from tvm.relay import transform
26
27
def run_opt_pass(expr, passes):
    """Apply one or more Relay passes to ``expr`` and return the result.

    Parameters
    ----------
    expr
        Relay expression or function; it is wrapped into a module with
        ``relay.Module.from_expr``.
    passes
        A single pass or a list of passes. A single pass is promoted to a
        one-element list before being handed to ``transform.Sequential``.

    Returns
    -------
    The ``"main"`` entry of the module after the passes have run under a
    ``PassContext`` with ``opt_level=3``.
    """
    if not isinstance(passes, list):
        passes = [passes]
    module = relay.Module.from_expr(expr)
    pipeline = transform.Sequential(passes)
    with transform.PassContext(opt_level=3):
        module = pipeline(module)
    return module["main"]
35
36
def test_redundant_annotation():
    """Check rewriting when one ``add`` is wrapped by two identical annotations.

    Both ``on_device`` wrappers place the same ``add`` on context 2 while the
    fallback device type is that of context 1. The expected function holds
    one ``device_copy`` (context 2 -> context 1) in front of each subtract.
    """
    fallback_ctx = tvm.context(1)
    annotated_ctx = tvm.context(2)
    x = relay.var("x", shape=(3,))
    y = relay.var("y", shape=(3,))
    z = relay.var("z", shape=(3,))

    def build_annotated():
        total = relay.add(x, y)
        # Two separate wrappers around the very same expression.
        first = relay.annotation.on_device(total, annotated_ctx)
        second = relay.annotation.on_device(total, annotated_ctx)
        outputs = relay.Tuple([relay.subtract(first, z),
                               relay.subtract(second, z)])
        return run_opt_pass(
            relay.Function([x, y, z], outputs),
            transform.RewriteAnnotatedOps(fallback_ctx.device_type))

    def build_expected():
        total = relay.add(x, y)
        outputs = []
        # One copy per consumer of the shared ``add``.
        for _ in range(2):
            copied = relay.device_copy(total, annotated_ctx, fallback_ctx)
            outputs.append(relay.subtract(copied, z))
        return relay.Function([x, y, z], relay.Tuple(outputs))

    rewritten = build_annotated()
    reference = run_opt_pass(build_expected(), transform.InferType())
    assert relay.analysis.alpha_equal(rewritten, reference)
68
69
def test_annotate_expr():
    """Check rewriting of an annotated expression passed without a Function.

    ``add`` is annotated onto context 1 and ``subtract`` onto context 2; the
    expected expression has a ``device_copy`` from context 1 to context 2
    between them.
    """
    first_ctx = tvm.context(1)
    second_ctx = tvm.context(2)
    x = relay.var("x", shape=(3,))
    y = relay.var("y", shape=(3,))
    z = relay.var("z", shape=(3,))

    def build_annotated():
        total = relay.annotation.on_device(relay.add(x, y), first_ctx)
        diff = relay.annotation.on_device(relay.subtract(total, z),
                                          second_ctx)
        return run_opt_pass(
            diff, transform.RewriteAnnotatedOps(first_ctx.device_type))

    def build_expected():
        copied = relay.device_copy(relay.add(x, y), first_ctx, second_ctx)
        return relay.subtract(copied, z)

    rewritten = build_annotated()
    reference = run_opt_pass(build_expected(), transform.InferType())
    assert relay.analysis.graph_equal(rewritten, reference)
95
96
def test_annotate_all():
    """Check rewriting when every operator is annotated onto the same context.

    Both ``add`` and ``subtract`` are placed on context 2; the expected
    function is the plain graph with no ``device_copy`` nodes.
    """
    fallback_ctx = tvm.context(1)
    annotated_ctx = tvm.context(2)
    x = relay.var("x", shape=(3,))
    y = relay.var("y", shape=(3,))
    z = relay.var("z", shape=(3,))

    def build_annotated():
        total = relay.annotation.on_device(relay.add(x, y), annotated_ctx)
        diff = relay.annotation.on_device(relay.subtract(total, z),
                                          annotated_ctx)
        return run_opt_pass(
            relay.Function([x, y, z], diff),
            transform.RewriteAnnotatedOps(fallback_ctx.device_type))

    def build_expected():
        diff = relay.subtract(relay.add(x, y), z)
        return relay.Function([x, y, z], diff)

    rewritten = build_annotated()
    reference = run_opt_pass(build_expected(), transform.InferType())
    assert relay.analysis.graph_equal(rewritten, reference)
124
125
def test_annotate_none():
    """Check that rewriting an unannotated function leaves the graph as is.

    The same ``subtract(add(x, y), z)`` function is built twice: once run
    through ``RewriteAnnotatedOps`` and once through ``InferType`` only.
    """
    fallback_ctx = tvm.context(1)
    other_ctx = tvm.context(2)  # created as in the original; not used below
    x = relay.var("x", shape=(3,))
    y = relay.var("y", shape=(3,))
    z = relay.var("z", shape=(3,))

    def build_plain():
        diff = relay.subtract(relay.add(x, y), z)
        return relay.Function([x, y, z], diff)

    rewritten = run_opt_pass(
        build_plain(),
        transform.RewriteAnnotatedOps(fallback_ctx.device_type))
    reference = run_opt_pass(build_plain(), transform.InferType())
    assert relay.analysis.graph_equal(rewritten, reference)
150
151
def check_annotated_graph(annotated_func, expected_func):
    """Assert that both functions are alpha-equal after type inference.

    Each argument is run through ``InferType`` separately before the
    comparison with ``relay.analysis.alpha_equal``.
    """
    typed_annotated, typed_expected = (
        run_opt_pass(func, transform.InferType())
        for func in (annotated_func, expected_func))
    assert relay.analysis.alpha_equal(typed_annotated, typed_expected)
156
157
def test_conv_network():
    R"""Annotate a small conv2d network by hand and through an ExprMutator.

    The network is as follows:
             data1     data2
               |         |
             conv2d    conv2d
                \       /
                   add
                    |
                  conv2d

    Both annotation routes are compared against a hand-built function with
    explicit ``device_copy`` nodes; the manual route also checks the result
    of ``GraphPlanMemory``.
    """
    batch_size = 1
    dshape = (batch_size, 64, 56, 56)
    weight = relay.var("weight", shape=(64, 64, 3, 3))
    data1 = relay.var("data1", shape=dshape)
    data2 = relay.var("data2", shape=dshape)
    dev1 = tvm.context(1)
    dev2 = tvm.context(2)

    def conv3x3(inp):
        """Return the conv2d every layer of this network uses, applied to ``inp``."""
        return relay.nn.conv2d(
            inp,
            weight,
            channels=64,
            kernel_size=(3, 3),
            padding=(1, 1))

    def make_func(body):
        """Wrap ``body`` into a function over (data1, data2, weight)."""
        return relay.Function([data1, data2, weight], body)

    def original():
        # Plain network without any device annotation.
        lhs = conv3x3(data1)
        rhs = conv3x3(data2)
        out = conv3x3(relay.add(lhs, rhs))
        return run_opt_pass(
            make_func(out),
            transform.RewriteAnnotatedOps(tvm.context(3).device_type))

    def annotated():
        # conv2d ops go to dev2, the add goes to dev1.
        lhs = relay.annotation.on_device(conv3x3(data1), dev2)
        rhs = relay.annotation.on_device(conv3x3(data2), dev2)
        summed = relay.annotation.on_device(relay.add(lhs, rhs), dev1)
        out = relay.annotation.on_device(conv3x3(summed), dev2)
        return run_opt_pass(
            make_func(out),
            transform.RewriteAnnotatedOps(tvm.context(3).device_type))

    class ScheduleConv2d(ExprMutator):
        """Mutator that wraps every ``nn.conv2d`` call in an on_device annotation."""

        def __init__(self, device):
            self.device = device
            super().__init__()

        def visit_call(self, expr):
            visited = super().visit_call(expr)
            is_conv2d = expr.op == tvm.relay.op.get("nn.conv2d")
            return (relay.annotation.on_device(visited, self.device)
                    if is_conv2d else visited)

    def annotate_with_visitor(func):
        mutated = ScheduleConv2d(dev2).visit(func)
        return run_opt_pass(
            mutated, transform.RewriteAnnotatedOps(dev1.device_type))

    def expected():
        lhs = relay.device_copy(conv3x3(data1), dev2, dev1)
        rhs = relay.device_copy(conv3x3(data2), dev2, dev1)
        summed = relay.device_copy(relay.add(lhs, rhs), dev1, dev2)
        return make_func(conv3x3(summed))

    def check_storage_and_device_types():
        func = run_opt_pass(annotated(), [transform.RewriteAnnotatedOps(3),
                                          transform.FuseOps(2)])
        smap = relay.backend._backend.GraphPlanMemory(func)
        storage_ids = []
        device_types = []
        for _, storage_dev_type in smap.items():
            # Each entry is a pair: (storage ids, device types).
            assert len(storage_dev_type) == 2
            storage_ids.extend(sid.value for sid in storage_dev_type[0])
            device_types.extend(did.value for did in storage_dev_type[1])
        assert len(storage_ids) == 10
        assert len(set(storage_ids)) == 8
        assert len(set(device_types)) == 2
        assert set(device_types) == {1, 2}

    # Manual annotation: compare the graph, then inspect the memory plan.
    manual_func = annotated()
    manual_expected = expected()
    check_annotated_graph(manual_func, manual_expected)
    check_storage_and_device_types()

    # Visitor-driven annotation of the unannotated network.
    visitor_func = annotate_with_visitor(original())
    visitor_expected = expected()
    check_annotated_graph(visitor_func, visitor_expected)
310
311
def run_fusible_network(dev, tgt):
    R"""Annotate, build and execute a small fusible network on ``dev``.

    The network is as follows:
               x     y
                \   /
                 add
                /   \
             sqrt   log
                \   /
              subtract
                  |
                 exp

    Parameters
    ----------
    dev
        Device name passed to ``tvm.context`` and used as a key of the
        target dictionary.
    tgt
        Target associated with ``dev`` in the target dictionary.
    """
    x = relay.var("x", shape=(1, 10))
    y = relay.var("y", shape=(10, 10))
    x_data = np.random.rand(1, 10).astype('float32')
    y_data = np.random.rand(10, 10).astype('float32')
    # NumPy reference result for the whole network.
    summed = x_data + y_data
    ref_res = np.exp(np.subtract(np.sqrt(summed), np.log(summed)))

    def get_func():
        """Return the network without any annotation."""
        total = relay.add(x, y)
        diff = relay.subtract(relay.sqrt(total), relay.log(total))
        return relay.Function([x, y], relay.exp(diff))

    def build_and_run(target, device, func, fallback_device=None,
                      expected_index=None):
        """Build ``func``, check ``device_index`` if present, run and compare."""
        params = {"x": x_data, "y": y_data}
        config = {"opt_level": 1}
        if fallback_device:
            config["fallback_device"] = fallback_device
        with relay.build_config(**config):
            graph, lib, params = relay.build(func, target, params=params)
            contexts = [tvm.cpu(0), tvm.context(device)]
            attrs = json.loads(graph)["attrs"]
            if "device_index" in attrs:
                assert attrs["device_index"][1] == expected_index
            module = graph_runtime.create(graph, lib, contexts)
            module.set_input(**params)
            module.run()
            output = module.get_output(0).asnumpy()
            tvm.testing.assert_allclose(output, ref_res, rtol=1e-5, atol=1e-5)

    def check_fuse_log_add():
        """Only log and add are fused."""
        cpu_ctx = tvm.context("cpu")
        target = {"cpu": "llvm", dev: tgt}
        dev_ctx = tvm.context(dev)

        def annotated():
            total = relay.add(x, y)
            root = relay.annotation.on_device(relay.sqrt(total), dev_ctx)
            logged = relay.log(total)
            result = relay.exp(relay.subtract(root, logged))
            body = relay.annotation.on_device(result, dev_ctx)
            return run_opt_pass(
                relay.Function([x, y], body),
                transform.RewriteAnnotatedOps(cpu_ctx.device_type))

        def expected():
            total = relay.add(x, y)
            root = relay.sqrt(relay.device_copy(total, cpu_ctx, dev_ctx))
            logged = relay.log(total)
            diff = relay.subtract(
                relay.device_copy(root, dev_ctx, cpu_ctx), logged)
            result = relay.exp(relay.device_copy(diff, cpu_ctx, dev_ctx))
            return relay.Function([x, y], result)

        annotated_func = annotated()
        expected_func = expected()
        dev_idx = tvm.context(dev, 0).device_type
        expected_index = ([1] * 3 + [dev_idx] * 2
                          + [1] * 2 + [dev_idx] * 2)
        check_annotated_graph(annotated_func, expected_func)
        build_and_run(target, dev, annotated_func, cpu_ctx, expected_index)

    def check_fuse_all():
        """Fuse all operators."""
        cpu_ctx = tvm.context("cpu")
        target = {"cpu": "llvm", dev: tgt}
        dev_ctx = tvm.context(dev)

        def annotated():
            def on_dev(expr):
                return relay.annotation.on_device(expr, dev_ctx)

            total = on_dev(relay.add(x, y))
            root = on_dev(relay.sqrt(total))
            logged = on_dev(relay.log(total))
            diff = on_dev(relay.subtract(root, logged))
            body = on_dev(relay.exp(diff))
            return run_opt_pass(
                relay.Function([x, y], body),
                transform.RewriteAnnotatedOps(cpu_ctx.device_type))

        annotated_func = annotated()
        expected_func = get_func()
        check_annotated_graph(annotated_func, expected_func)
        build_and_run(target, dev, annotated_func, cpu_ctx)

    def check_fallback_exp():
        """Only ``exp`` is annotated (onto the CPU); the rest falls back to ``dev``."""
        cpu_ctx = tvm.context("cpu")
        target = {"cpu": "llvm", dev: tgt}
        dev_ctx = tvm.context(dev)

        def annotated():
            total = relay.add(x, y)
            diff = relay.subtract(relay.sqrt(total), relay.log(total))
            body = relay.annotation.on_device(relay.exp(diff), cpu_ctx)
            return run_opt_pass(
                relay.Function([x, y], body),
                transform.RewriteAnnotatedOps(dev_ctx.device_type))

        def expected():
            total = relay.add(x, y)
            diff = relay.subtract(relay.sqrt(total), relay.log(total))
            result = relay.exp(relay.device_copy(diff, dev_ctx, cpu_ctx))
            return relay.Function([x, y], result)

        annotated_func = annotated()
        expected_func = expected()
        dev_idx = tvm.context(dev, 0).device_type
        expected_index = [dev_idx] * 3 + [1] * 2
        check_annotated_graph(annotated_func, expected_func)
        build_and_run(target, dev, annotated_func, cpu_ctx, expected_index)

    def check_fallback_all_operators():
        """No annotation and no fallback device are supplied."""
        target = {dev: tgt, "cpu": "llvm"}
        annotated_func = get_func()
        expected_func = get_func()
        check_annotated_graph(annotated_func, expected_func)
        build_and_run(target, dev, annotated_func)

    check_fuse_log_add()
    check_fuse_all()
    check_fallback_exp()
    check_fallback_all_operators()
489
def run_unpropagatable_graph(dev, tgt):
    R"""Annotate, build and execute a graph mixing ``dev`` and CPU operators.

    The network is as follows:
            a     b  c     d
             \   /    \   /
              add      mul
                \      /
                subtract

    ``add`` and ``subtract`` are annotated onto ``dev`` while ``mul`` is
    annotated onto the CPU. The rewritten function is compared against a
    hand-built one with a single ``device_copy`` (CPU -> ``dev``) after
    ``mul``, then built, run, and checked against a NumPy reference.

    Parameters
    ----------
    dev
        Device name passed to ``tvm.context`` and used as a key of the
        target dictionary.
    tgt
        Target associated with ``dev`` in the target dictionary.
    """

    a = relay.var("a", shape=(10, 10))
    b = relay.var("b", shape=(10, 10))
    c = relay.var("c", shape=(10, 10))
    d = relay.var("d", shape=(10, 10))
    a_data = np.random.rand(10, 10).astype('float32')
    b_data = np.random.rand(10, 10).astype('float32')
    c_data = np.random.rand(10, 10).astype('float32')
    d_data = np.random.rand(10, 10).astype('float32')
    # NumPy reference result: (a + b) - (c * d).
    tmp_add = a_data + b_data
    tmp_mul = np.multiply(c_data, d_data)
    ref_res = np.subtract(tmp_add, tmp_mul)

    fallback_device = tvm.context("cpu")
    target = {"cpu": "llvm", dev: tgt}
    cpu_ctx = fallback_device
    dev_ctx = tvm.context(dev)

    def annotated():
        add = relay.add(a, b)
        _add = relay.annotation.on_device(add, dev_ctx)
        mul = relay.multiply(c, d)
        _mul = relay.annotation.on_device(mul, cpu_ctx)
        sub = relay.subtract(_add, _mul)
        _sub = relay.annotation.on_device(sub, dev_ctx)
        func = relay.Function([a, b, c, d], _sub)
        func = run_opt_pass(
            func, transform.RewriteAnnotatedOps(dev_ctx.device_type))
        return func

    def expected():
        add = relay.add(a, b)
        mul = relay.multiply(c, d)
        copy_mul_sub = relay.device_copy(mul, cpu_ctx, dev_ctx)
        sub = relay.subtract(add, copy_mul_sub)
        func = relay.Function([a, b, c, d], sub)
        return func

    annotated_func = annotated()
    expected_func = expected()
    # Expected device types are derived from the contexts instead of being
    # hard-coded, so the check does not assume one particular device type
    # for ``dev``.
    # Presumably the entries follow node order: a, b, add / c, d, mul /
    # device_copy, subtract -- confirm against the graph JSON.
    dev_idx = dev_ctx.device_type
    cpu_idx = cpu_ctx.device_type
    expected_index = [dev_idx] * 3 + [cpu_idx] * 3 + [dev_idx] * 2
    check_annotated_graph(annotated_func, expected_func)
    params = {"a": a_data, "b": b_data, "c": c_data, "d": d_data}
    config = {"opt_level": 0}
    config["fallback_device"] = fallback_device
    with relay.build_config(**config):
        graph, lib, params = relay.build(annotated_func, target, params=params)
        contexts = [tvm.cpu(0), tvm.context(dev)]
        graph_json = json.loads(graph)
        # "device_index" is only checked when the graph JSON carries it.
        if "device_index" in graph_json["attrs"]:
            device_index = graph_json["attrs"]["device_index"][1]
            assert device_index == expected_index
        mod = graph_runtime.create(graph, lib, contexts)
        mod.set_input(**params)
        mod.run()
        res = mod.get_output(0).asnumpy()
        tvm.testing.assert_allclose(res, ref_res, rtol=1e-5, atol=1e-5)
555
556
def test_check_run():
    """Run both executable networks for every enabled (device, target) pair.

    Pairs whose device is not enabled are skipped with a printed message.
    """
    configurations = [
        ("opencl", "opencl"),
        ("cuda", "cuda"),
        ("opencl", str(tvm.target.intel_graphics())),
    ]
    for dev, tgt in configurations:
        if tvm.module.enabled(dev):
            run_fusible_network(dev, tgt)
            run_unpropagatable_graph(dev, tgt)
        else:
            print("Skip test because %s is not enabled." % dev)
565
566
def test_tuple_get_item():
    """Check rewriting when an annotated tuple is consumed element-wise.

    The tuple produced by ``split`` is annotated onto the "cuda" context and
    its first two elements are subtracted; the expected function copies each
    used element from that context to the CPU. The test is skipped when
    "cuda" is not enabled.
    """
    dev = "cuda"
    if not tvm.module.enabled(dev):
        print("Skip test because %s is not enabled." % dev)
        return

    cpu_ctx = tvm.cpu(0)
    gpu_ctx = tvm.context(dev)

    def build_expected():
        x = relay.var("x", relay.ty.TensorType((3, 3, 4), "float32"))
        pieces = relay.op.split(x, 3)
        first = relay.device_copy(pieces[0], gpu_ctx, cpu_ctx)
        second = relay.device_copy(pieces[1], gpu_ctx, cpu_ctx)
        diff = first - second
        return relay.Function(relay.analysis.free_vars(diff), diff)

    def build_annotated():
        x = relay.var("x", relay.ty.TensorType((3, 3, 4), "float32"))
        raw_tuple = relay.op.split(x, 3).astuple()
        # Annotate the whole tuple, then re-wrap it so it can be indexed.
        on_gpu = relay.annotation.on_device(raw_tuple, gpu_ctx)
        pieces = relay.TupleWrapper(on_gpu, 3)
        diff = pieces[0] - pieces[1]
        return run_opt_pass(
            relay.Function(relay.analysis.free_vars(diff), diff),
            transform.RewriteAnnotatedOps(cpu_ctx.device_type))

    rewritten = build_annotated()
    reference = run_opt_pass(build_expected(), transform.InferType())
    assert relay.analysis.graph_equal(rewritten, reference)
600
601
if __name__ == "__main__":
    # Run every test of this module, in order, when executed as a script.
    for _test_case in (
            test_redundant_annotation,
            test_annotate_expr,
            test_annotate_all,
            test_annotate_none,
            test_conv_network,
            test_check_run,
            test_tuple_get_item,
    ):
        _test_case()
610