1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17import numpy as np
18
19from coremltools.models.neural_network import NeuralNetworkBuilder
20from coremltools.models import datatypes
21
22import tvm
23from tvm import te
24from tvm.contrib import graph_runtime
25from tvm import topi
26import tvm.topi.testing
27from tvm import relay
28from tvm.topi.testing import conv2d_nchw_python
29
30import coremltools as cm
31import model_zoo
32import tvm.testing
33
34
def get_tvm_output(
    func, x, params, target, ctx, out_shape=(1, 1000), input_name="image", dtype="float32"
):
    """Compile ``func`` for ``target`` and run it once on input ``x``.

    ``x`` is cast to ``dtype`` and bound to the graph input ``input_name``.
    Output 0 is copied into a buffer of shape ``out_shape`` / ``dtype`` and
    returned as a NumPy array.
    """
    with tvm.transform.PassContext(opt_level=3):
        compiled = relay.build(func, target, params=params)
    module = graph_runtime.GraphModule(compiled["default"](ctx))
    module.set_input(input_name, tvm.nd.array(x.astype(dtype)))
    module.run()
    result = module.get_output(0, tvm.nd.empty(out_shape, dtype))
    return result.asnumpy()
47
48
def run_model_checkonly(model_file, model_name="", input_name="image"):
    """Smoke-test a CoreML model file on every enabled TVM target.

    The model is fed the cat image from ``model_zoo`` and the arg-max class
    id is printed for each target. Nothing is asserted.
    """
    coreml_model = cm.models.MLModel(model_file)
    image = model_zoo.get_cat_image()
    shapes = {input_name: image.shape}
    for target, ctx in tvm.testing.enabled_targets():
        # Some Relay passes change operators on the fly, so convert the model
        # afresh for each target instead of reusing one Relay module.
        mod, params = relay.frontend.from_coreml(coreml_model, shapes)
        prediction = get_tvm_output(mod["main"], image, params, target, ctx)
        print(target, ctx, model_name, "prediction id: ", np.argmax(prediction.flat))
59
60
@tvm.testing.uses_gpu
def test_mobilenet_checkonly():
    """Smoke-test the model-zoo MobileNet CoreML model."""
    run_model_checkonly(model_zoo.get_mobilenet(), "mobilenet")
65
66
@tvm.testing.uses_gpu
def test_resnet50_checkonly():
    """Smoke-test the model-zoo ResNet-50 CoreML model."""
    run_model_checkonly(model_zoo.get_resnet50(), "resnet50")
71
72
def run_tvm_graph(
    coreml_model, target, ctx, input_data, input_name, output_shape, output_dtype="float32"
):
    """Generic function to compile on relay and execute on tvm.

    Parameters
    ----------
    coreml_model : coremltools.models.MLModel
        Model converted with ``relay.frontend.from_coreml``.
    target, ctx
        Compilation target and device context, as yielded by
        ``tvm.testing.enabled_targets()``.
    input_data : numpy.ndarray or list of numpy.ndarray
        Input value(s). A list is paired positionally with ``input_name``.
    input_name : str or list of str
        Graph input name(s); a list when ``input_data`` is a list.
    output_shape : tuple, list of tuples, or None
        A non-empty list requests one output per entry. A falsy value
        returns output 0 with whatever shape the runtime allocates.
    output_dtype : str or list of str
        Output dtype(s). A single string is reused for every output when
        ``output_shape`` is a list.

    Returns
    -------
    numpy.ndarray, or a list of them when ``output_shape`` is a non-empty list.
    """
    # Pair each input array with its graph input name.
    if isinstance(input_data, list):
        named_inputs = list(zip(input_name, input_data))
    else:
        named_inputs = [(input_name, input_data)]

    # from_coreml is given only the input shapes.
    shape_dict = {name: data.shape for name, data in named_inputs}
    mod, params = relay.frontend.from_coreml(coreml_model, shape_dict)
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target, params=params)

    m = graph_runtime.GraphModule(lib["default"](ctx))
    for name, data in named_inputs:
        m.set_input(name, tvm.nd.array(data))
    m.run()

    if not output_shape:
        # No shape requested: let the runtime allocate output 0 itself.
        return m.get_output(0).asnumpy()
    if isinstance(output_shape, list):
        # Previously a single dtype string with a list of shapes fell through
        # to the single-output path and crashed in tvm.nd.empty; broadcast it.
        if isinstance(output_dtype, list):
            dtypes = output_dtype
        else:
            dtypes = [output_dtype] * len(output_shape)
        return [
            m.get_output(i, tvm.nd.empty(shape, dtype)).asnumpy()
            for i, (shape, dtype) in enumerate(zip(output_shape, dtypes))
        ]
    return m.get_output(0, tvm.nd.empty(output_shape, output_dtype)).asnumpy()
116
117
def verify_AddLayerParams(input_dim, alpha=2):
    """Check a two-input CoreML elementwise ADD layer with scalar ``alpha``.

    The reference is ``input1 + input2 + alpha``.
    """
    dtype = "float32"
    lhs = np.random.uniform(size=input_dim).astype(dtype)
    rhs = np.random.uniform(size=input_dim).astype(dtype)
    expected = np.add(lhs, rhs) + alpha

    builder = NeuralNetworkBuilder(
        [("input1", datatypes.Array(*input_dim)), ("input2", datatypes.Array(*input_dim))],
        [("output", datatypes.Array(*expected.shape))],
    )
    builder.add_elementwise(
        name="Add", alpha=alpha, input_names=["input1", "input2"], output_name="output", mode="ADD"
    )
    model = cm.models.MLModel(builder.spec)

    for target, ctx in tvm.testing.enabled_targets():
        actual = run_tvm_graph(
            model, target, ctx, [lhs, rhs], ["input1", "input2"], expected.shape, dtype
        )
        tvm.testing.assert_allclose(actual, expected, rtol=1e-5)
137
138
@tvm.testing.uses_gpu
def test_forward_AddLayerParams():
    """Exercise the ADD layer over a few shapes and alpha offsets."""
    for shape, alpha in [((1, 2, 2), 0), ((1, 2, 2), 1), ((1, 3, 3), 2)]:
        verify_AddLayerParams(shape, alpha)
144
145
def verify_MultiplyLayerParams(input_dim, alpha):
    """Check a two-input CoreML elementwise MULTIPLY layer with scalar ``alpha``.

    The reference is ``input1 * input2 * alpha``.
    """
    dtype = "float32"
    lhs = np.random.uniform(size=input_dim).astype(dtype)
    rhs = np.random.uniform(size=input_dim).astype(dtype)
    expected = np.multiply(lhs, rhs) * alpha

    builder = NeuralNetworkBuilder(
        [("input1", datatypes.Array(*input_dim)), ("input2", datatypes.Array(*input_dim))],
        [("output", datatypes.Array(*expected.shape))],
    )
    builder.add_elementwise(
        name="Mul",
        alpha=alpha,
        input_names=["input1", "input2"],
        output_name="output",
        mode="MULTIPLY",
    )
    model = cm.models.MLModel(builder.spec)

    for target, ctx in tvm.testing.enabled_targets():
        actual = run_tvm_graph(
            model, target, ctx, [lhs, rhs], ["input1", "input2"], expected.shape, dtype
        )
        tvm.testing.assert_allclose(actual, expected, rtol=1e-5)
169
170
@tvm.testing.uses_gpu
def test_forward_MultiplyLayerParams():
    """Exercise the MULTIPLY layer over a few shapes and alpha factors."""
    for shape, alpha in [((1, 2, 2), 0), ((1, 2, 2), 1), ((1, 3, 3), 2)]:
        verify_MultiplyLayerParams(shape, alpha)
176
177
def verify_ConcatLayerParams(input1_dim, input2_dim):
    """Check a CoreML CONCAT layer against ``np.concatenate`` along axis 1."""
    dtype = "float32"
    first = np.random.uniform(size=input1_dim).astype(dtype)
    second = np.random.uniform(size=input2_dim).astype(dtype)
    expected = np.concatenate((first, second), axis=1)

    builder = NeuralNetworkBuilder(
        [("input1", datatypes.Array(*input1_dim)), ("input2", datatypes.Array(*input2_dim))],
        [("output", datatypes.Array(*expected.shape))],
    )
    builder.add_elementwise(
        name="Concate", input_names=["input1", "input2"], output_name="output", mode="CONCAT"
    )
    model = cm.models.MLModel(builder.spec)

    for target, ctx in tvm.testing.enabled_targets():
        actual = run_tvm_graph(
            model, target, ctx, [first, second], ["input1", "input2"], expected.shape, dtype
        )
        tvm.testing.assert_allclose(actual, expected, rtol=1e-5)
197
198
@tvm.testing.uses_gpu
def test_forward_ConcatLayerParams():
    """Exercise CONCAT with inputs that differ only in axis 1."""
    shape_pairs = [
        ((1, 1, 2, 2), (1, 2, 2, 2)),
        ((1, 2, 4, 4), (1, 3, 4, 4)),
    ]
    for first_shape, second_shape in shape_pairs:
        verify_ConcatLayerParams(first_shape, second_shape)
203
204
def verify_UpsampleLayerParams(input_dim, scale, mode):
    """Check a CoreML upsample layer against the TOPI reference implementations.

    ``mode == "NN"`` is compared with ``upsampling_python``. Any other mode is
    compared with ``bilinear_resize_python`` in NCHW layout, where
    ``input_dim[2]`` / ``input_dim[3]`` are scaled by ``scale``. The input is
    a tensor of ones.
    """
    dtype = "float32"
    a_np = np.full(input_dim, 1, dtype=dtype)

    if mode != "NN":
        target_hw = (input_dim[2] * scale, input_dim[3] * scale)
        b_np = tvm.topi.testing.bilinear_resize_python(a_np, target_hw, "NCHW")
    else:
        b_np = tvm.topi.testing.upsampling_python(a_np, (scale, scale))

    builder = NeuralNetworkBuilder(
        [("input", datatypes.Array(*input_dim))],
        [("output", datatypes.Array(*b_np.shape))],
    )
    builder.add_upsample(
        name="Upsample",
        scaling_factor_h=scale,
        scaling_factor_w=scale,
        mode=mode,
        input_name="input",
        output_name="output",
    )
    model = cm.models.MLModel(builder.spec)

    for target, ctx in tvm.testing.enabled_targets():
        actual = run_tvm_graph(model, target, ctx, a_np, "input", b_np.shape, dtype)
        tvm.testing.assert_allclose(actual, b_np, rtol=1e-5)
232
233
@tvm.testing.uses_gpu
def test_forward_UpsampleLayerParams():
    """Exercise nearest-neighbour and bilinear upsampling."""
    for shape, scale, mode in [((1, 16, 32, 32), 2, "NN"), ((1, 4, 6, 6), 3, "BILINEAR")]:
        verify_UpsampleLayerParams(shape, scale, mode)
238
239
def verify_l2_normalize(input_dim, eps):
    """Check a CoreML L2-normalize layer against ``l2_normalize_python(x, eps, 1)``."""
    dtype = "float32"
    data = np.random.uniform(size=input_dim).astype(dtype)
    expected = tvm.topi.testing.l2_normalize_python(data, eps, 1)

    builder = NeuralNetworkBuilder(
        [("input", datatypes.Array(*input_dim))],
        [("output", datatypes.Array(*expected.shape))],
    )
    builder.add_l2_normalize(name="L2", epsilon=eps, input_name="input", output_name="output")
    model = cm.models.MLModel(builder.spec)

    for target, ctx in tvm.testing.enabled_targets():
        actual = run_tvm_graph(model, target, ctx, data, "input", expected.shape, dtype)
        tvm.testing.assert_allclose(actual, expected, rtol=1e-5)
255
256
@tvm.testing.uses_gpu
def test_forward_l2_normalize():
    """Exercise L2 normalization on a 4-D input."""
    verify_l2_normalize(input_dim=(1, 3, 20, 20), eps=0.001)
260
261
def verify_lrn(input_dim, size, bias, alpha, beta):
    """Check a CoreML LRN layer against ``lrn_python`` applied over axis 1.

    ``bias`` is passed to CoreML as ``k`` and ``size`` as ``local_size``.
    """
    dtype = "float32"
    channel_axis = 1
    data = np.random.uniform(size=input_dim).astype(dtype)
    expected = tvm.topi.testing.lrn_python(data, size, channel_axis, bias, alpha, beta)

    builder = NeuralNetworkBuilder(
        [("input", datatypes.Array(*input_dim))],
        [("output", datatypes.Array(*expected.shape))],
    )
    builder.add_lrn(
        name="LRN",
        input_name="input",
        output_name="output",
        alpha=alpha,
        beta=beta,
        k=bias,
        local_size=size,
    )
    model = cm.models.MLModel(builder.spec)

    for target, ctx in tvm.testing.enabled_targets():
        actual = run_tvm_graph(model, target, ctx, data, "input", expected.shape, dtype)
        tvm.testing.assert_allclose(actual, expected, rtol=1e-5)
285
286
@tvm.testing.uses_gpu
def test_forward_lrn():
    """Exercise LRN with a window of 3 channels."""
    verify_lrn((1, 3, 10, 20), size=3, bias=1.0, alpha=1.0, beta=0.5)
290
291
def verify_average(input_dim1, input_dim2, axis=0):
    """Check a two-input CoreML elementwise AVE layer, including broadcasting.

    The inputs may have different but broadcast-compatible shapes, e.g.
    ``(3, 20, 20)`` and ``(1, 3, 20, 20)``. ``axis`` is the reduction axis over
    the two stacked (broadcast) inputs; the default 0 averages them elementwise.
    """
    dtype = "float32"

    a_np1 = np.random.uniform(size=input_dim1).astype(dtype)
    a_np2 = np.random.uniform(size=input_dim2).astype(dtype)

    # Broadcast before stacking: stacking arrays of different shapes creates a
    # ragged array, which NumPy >= 1.24 rejects with a ValueError.
    b_np = np.mean(np.broadcast_arrays(a_np1, a_np2), axis=axis)

    inputs = [("input1", datatypes.Array(*input_dim1)), ("input2", datatypes.Array(*input_dim2))]
    output = [("output", datatypes.Array(*b_np.shape))]
    builder = NeuralNetworkBuilder(inputs, output)
    builder.add_elementwise(
        name="MEAN", input_names=["input1", "input2"], output_name="output", mode="AVE"
    )
    model = cm.models.MLModel(builder.spec)
    for target, ctx in tvm.testing.enabled_targets():
        out = run_tvm_graph(
            model, target, ctx, [a_np1, a_np2], ["input1", "input2"], b_np.shape, dtype
        )
        tvm.testing.assert_allclose(out, b_np, rtol=1e-5)
312
313
@tvm.testing.uses_gpu
def test_forward_average():
    """Exercise AVE with equal shapes and with lower-rank first inputs that broadcast."""
    for first_shape in [(1, 3, 20, 20), (3, 20, 20), (20, 20)]:
        verify_average(first_shape, (1, 3, 20, 20))
319
320
def verify_max(input_dim):
    """Check a three-input CoreML elementwise MAX layer against ``np.max`` over the inputs."""
    dtype = "float32"
    names = ["input1", "input2", "input3"]
    arrays = [np.random.uniform(size=input_dim).astype(dtype) for _ in names]
    expected = np.max(arrays, axis=0)

    builder = NeuralNetworkBuilder(
        [(name, datatypes.Array(*input_dim)) for name in names],
        [("output", datatypes.Array(*expected.shape))],
    )
    builder.add_elementwise(
        name="Max", input_names=list(names), output_name="output", mode="MAX"
    )
    model = cm.models.MLModel(builder.spec)

    for target, ctx in tvm.testing.enabled_targets():
        actual = run_tvm_graph(model, target, ctx, arrays, list(names), expected.shape, dtype)
        tvm.testing.assert_allclose(actual, expected, rtol=1e-5)
352
353
@tvm.testing.uses_gpu
def test_forward_max():
    """Exercise MAX on 4-D and 2-D inputs."""
    for shape in [(1, 3, 20, 20), (20, 20)]:
        verify_max(shape)
358
359
def verify_min(input_dim):
    """Check a three-input CoreML elementwise MIN layer against ``np.min`` over the inputs."""
    dtype = "float32"
    names = ["input1", "input2", "input3"]
    arrays = [np.random.uniform(size=input_dim).astype(dtype) for _ in names]
    expected = np.min(arrays, axis=0)

    builder = NeuralNetworkBuilder(
        [(name, datatypes.Array(*input_dim)) for name in names],
        [("output", datatypes.Array(*expected.shape))],
    )
    builder.add_elementwise(
        name="Min", input_names=list(names), output_name="output", mode="MIN"
    )
    model = cm.models.MLModel(builder.spec)

    for target, ctx in tvm.testing.enabled_targets():
        actual = run_tvm_graph(model, target, ctx, arrays, list(names), expected.shape, dtype)
        tvm.testing.assert_allclose(actual, expected, rtol=1e-5)
391
392
@tvm.testing.uses_gpu
def test_forward_min():
    """Exercise MIN on 4-D and 2-D inputs."""
    for shape in [(1, 3, 20, 20), (20, 20)]:
        verify_min(shape)
397
398
def verify_unary_sqrt(input_dim):
    """Check the CoreML unary ``sqrt`` layer against ``np.sqrt``."""
    dtype = "float32"
    data = np.random.uniform(size=input_dim).astype(dtype)
    expected = np.sqrt(data)

    builder = NeuralNetworkBuilder(
        [("input", datatypes.Array(*input_dim))],
        [("output", datatypes.Array(*expected.shape))],
    )
    builder.add_unary(name="sqrt", input_name="input", output_name="output", mode="sqrt")
    model = cm.models.MLModel(builder.spec)

    for target, ctx in tvm.testing.enabled_targets():
        actual = run_tvm_graph(model, target, ctx, [data], ["input"], expected.shape, dtype)
        tvm.testing.assert_allclose(actual, expected, rtol=1e-5)
414
415
def verify_unary_rsqrt(input_dim, epsilon=0):
    """Check the CoreML unary ``rsqrt`` layer against ``1 / sqrt(x + epsilon)``."""
    dtype = "float32"
    data = np.random.uniform(size=input_dim).astype(dtype)
    expected = 1 / np.sqrt(data + epsilon)

    builder = NeuralNetworkBuilder(
        [("input", datatypes.Array(*input_dim))],
        [("output", datatypes.Array(*expected.shape))],
    )
    builder.add_unary(
        name="rsqrt", input_name="input", output_name="output", mode="rsqrt", epsilon=epsilon
    )
    model = cm.models.MLModel(builder.spec)

    for target, ctx in tvm.testing.enabled_targets():
        actual = run_tvm_graph(model, target, ctx, [data], ["input"], expected.shape, dtype)
        tvm.testing.assert_allclose(actual, expected, rtol=1e-5)
433
434
def verify_unary_inverse(input_dim, epsilon=0):
    """Check the CoreML unary ``inverse`` layer against ``1 / (x + epsilon)``."""
    dtype = "float32"
    data = np.random.uniform(size=input_dim).astype(dtype)
    expected = 1 / (data + epsilon)

    builder = NeuralNetworkBuilder(
        [("input", datatypes.Array(*input_dim))],
        [("output", datatypes.Array(*expected.shape))],
    )
    builder.add_unary(
        name="inverse", input_name="input", output_name="output", mode="inverse", epsilon=epsilon
    )
    model = cm.models.MLModel(builder.spec)

    for target, ctx in tvm.testing.enabled_targets():
        actual = run_tvm_graph(model, target, ctx, [data], ["input"], expected.shape, dtype)
        tvm.testing.assert_allclose(actual, expected, rtol=1e-5)
452
453
def verify_unary_power(input_dim, alpha):
    """Check the CoreML unary ``power`` layer against ``np.power(x, alpha)``."""
    dtype = "float32"
    data = np.random.uniform(size=input_dim).astype(dtype)
    expected = np.power(data, alpha)

    builder = NeuralNetworkBuilder(
        [("input", datatypes.Array(*input_dim))],
        [("output", datatypes.Array(*expected.shape))],
    )
    builder.add_unary(
        name="power", input_name="input", output_name="output", mode="power", alpha=alpha
    )
    model = cm.models.MLModel(builder.spec)

    for target, ctx in tvm.testing.enabled_targets():
        actual = run_tvm_graph(model, target, ctx, [data], ["input"], expected.shape, dtype)
        tvm.testing.assert_allclose(actual, expected, rtol=1e-5)
471
472
def verify_unary_exp(input_dim):
    """Check the CoreML unary ``exp`` layer against ``np.exp``."""
    dtype = "float32"
    data = np.random.uniform(size=input_dim).astype(dtype)
    expected = np.exp(data)

    builder = NeuralNetworkBuilder(
        [("input", datatypes.Array(*input_dim))],
        [("output", datatypes.Array(*expected.shape))],
    )
    builder.add_unary(name="exp", input_name="input", output_name="output", mode="exp")
    model = cm.models.MLModel(builder.spec)

    for target, ctx in tvm.testing.enabled_targets():
        actual = run_tvm_graph(model, target, ctx, [data], ["input"], expected.shape, dtype)
        tvm.testing.assert_allclose(actual, expected, rtol=1e-5)
488
489
def verify_unary_log(input_dim):
    """Check the CoreML unary ``log`` layer against ``np.log``."""
    dtype = "float32"
    data = np.random.uniform(size=input_dim).astype(dtype)
    expected = np.log(data)

    builder = NeuralNetworkBuilder(
        [("input", datatypes.Array(*input_dim))],
        [("output", datatypes.Array(*expected.shape))],
    )
    builder.add_unary(name="log", input_name="input", output_name="output", mode="log")
    model = cm.models.MLModel(builder.spec)

    for target, ctx in tvm.testing.enabled_targets():
        actual = run_tvm_graph(model, target, ctx, [data], ["input"], expected.shape, dtype)
        tvm.testing.assert_allclose(actual, expected, rtol=1e-5)
505
506
def verify_unary_abs(input_dim):
    """Check the CoreML unary ``abs`` layer on values drawn from [-100, 100)."""
    dtype = "float32"
    data = np.random.uniform(-100.0, 100.0, size=input_dim).astype(dtype)
    expected = np.abs(data)

    builder = NeuralNetworkBuilder(
        [("input", datatypes.Array(*input_dim))],
        [("output", datatypes.Array(*expected.shape))],
    )
    builder.add_unary(name="abs", input_name="input", output_name="output", mode="abs")
    model = cm.models.MLModel(builder.spec)

    for target, ctx in tvm.testing.enabled_targets():
        actual = run_tvm_graph(model, target, ctx, [data], ["input"], expected.shape, dtype)
        tvm.testing.assert_allclose(actual, expected, rtol=1e-5)
522
523
def verify_unary_threshold(input_dim, alpha):
    """Check the CoreML unary ``threshold`` layer against ``np.maximum(x, alpha)``."""
    dtype = "float32"
    data = np.random.uniform(-100.0, 100.0, size=input_dim).astype(dtype)
    expected = np.maximum(data, alpha)

    builder = NeuralNetworkBuilder(
        [("input", datatypes.Array(*input_dim))],
        [("output", datatypes.Array(*expected.shape))],
    )
    builder.add_unary(
        name="threshold", input_name="input", output_name="output", mode="threshold", alpha=alpha
    )
    model = cm.models.MLModel(builder.spec)

    for target, ctx in tvm.testing.enabled_targets():
        actual = run_tvm_graph(model, target, ctx, [data], ["input"], expected.shape, dtype)
        tvm.testing.assert_allclose(actual, expected, rtol=1e-5)
541
542
@tvm.testing.uses_gpu
def test_forward_unary():
    """Run every CoreML unary-op check on a (1, 3, 20, 20) input."""
    shape = (1, 3, 20, 20)
    verify_unary_sqrt(shape)
    # rsqrt and inverse: default epsilon first, then a small positive one.
    for check in (verify_unary_rsqrt, verify_unary_inverse):
        check(shape)
        check(shape, epsilon=1e-6)
    for exponent in (0.5, 4):
        verify_unary_power(shape, alpha=exponent)
    for check in (verify_unary_exp, verify_unary_log, verify_unary_abs):
        check(shape)
    for floor in (-6.0, 5.0):
        verify_unary_threshold(shape, alpha=floor)
557
558
@tvm.testing.uses_gpu
def test_forward_reduce():
    """Check CoreML reduce layers for every axis / mode combination exercised here."""
    from enum import Enum

    class ReduceAxis(Enum):
        CHW = 0
        HW = 1
        C = 2
        H = 3
        W = 4

    # CoreML names reduce axes after the trailing C, H, W dims; map each to the
    # equivalent NumPy axis counted from the end.
    np_axes = {
        ReduceAxis.CHW: (-3, -2, -1),
        ReduceAxis.HW: (-2, -1),
        ReduceAxis.C: -3,
        ReduceAxis.H: -2,
        ReduceAxis.W: -1,
    }

    def _verify_reduce(input_dim, mode, axis, ref_func, dtype="float32"):
        """Build a one-layer reduce model and compare it with ``ref_func``.

        ``dtype`` is the dtype of the reference result and of the TVM output;
        the model input is always float32.
        """
        print(input_dim, mode, axis)
        # Always draw float inputs: casting uniform [0, 1) samples to an
        # integer dtype (as was done for argmax) turns every element into 0,
        # which made the argmax check vacuous.
        a_np = np.random.uniform(size=input_dim).astype("float32")
        np_axis = np_axes[axis]

        if ref_func == np.argmax:
            # np.argmax drops the reduced axis; the model output keeps it (size 1).
            ref_val = np.expand_dims(ref_func(a_np, np_axis), np_axis).astype(dtype)
        else:
            ref_val = ref_func(a_np, np_axis, keepdims=True)

        inputs = [("input", datatypes.Array(*input_dim))]
        output = [("output", datatypes.Array(*ref_val.shape))]
        builder = NeuralNetworkBuilder(inputs, output)
        builder.add_reduce(
            name=mode, input_name="input", output_name="output", axis=axis.name, mode=mode
        )

        model = cm.models.MLModel(builder.spec)
        for target, ctx in tvm.testing.enabled_targets():
            out = run_tvm_graph(model, target, ctx, [a_np], ["input"], ref_val.shape, dtype)
            tvm.testing.assert_allclose(out, ref_val, rtol=1e-5, atol=1e-5)

    dshapes = [[10, 10], [1, 10, 10], [1, 3, 10, 10]]
    for dshape in dshapes:
        for axis in ReduceAxis:
            if len(dshape) < 3 and axis in [ReduceAxis.CHW, ReduceAxis.C]:
                # input must have rank at least 3
                continue
            _verify_reduce(dshape, "sum", axis, np.sum)
            _verify_reduce(dshape, "avg", axis, np.mean)
            _verify_reduce(dshape, "prod", axis, np.prod)
            _verify_reduce(dshape, "min", axis, np.min)
            _verify_reduce(dshape, "max", axis, np.max)
            if axis in [ReduceAxis.C, ReduceAxis.H, ReduceAxis.W]:
                # For mode ArgMax, axis must be [-1] or [-2] or [-3]
                _verify_reduce(dshape, "argmax", axis, np.argmax, dtype="int32")
617
618
def verify_reshape(input_dim, target_shape, mode):
    """Check a CoreML reshape layer (with the given ``mode``) against ``np.reshape``."""
    dtype = "float32"
    data = np.random.uniform(-100.0, 100.0, size=input_dim).astype(dtype)
    expected = np.reshape(data, target_shape)

    builder = NeuralNetworkBuilder(
        [("input", datatypes.Array(*input_dim))],
        [("output", datatypes.Array(*expected.shape))],
    )
    builder.add_reshape(
        name="reshape",
        input_name="input",
        output_name="output",
        target_shape=target_shape,
        mode=mode,
    )
    model = cm.models.MLModel(builder.spec)

    for target, ctx in tvm.testing.enabled_targets():
        actual = run_tvm_graph(model, target, ctx, [data], ["input"], expected.shape, dtype)
        tvm.testing.assert_allclose(actual, expected, rtol=1e-5)
640
641
@tvm.testing.uses_gpu
def test_forward_reshape():
    """Exercise CoreML reshape in modes 0 and 1, with and without a rank change.

    Marked ``uses_gpu`` like the other tests here, since ``verify_reshape``
    also runs on every target from ``tvm.testing.enabled_targets()``.
    """
    for mode in [0, 1]:
        verify_reshape((20,), (1, 2, 2, 5), mode)
        verify_reshape((1, 3, 20, 20), (1, 12, 10, 10), mode)
646
647
def verify_split(input_dim, nOutputs):
    """Check a CoreML split layer against ``np.split(x, nOutputs, axis=-3)``.

    Each piece becomes a separate model output named ``output0``, ``output1``, ...
    """
    dtype = "float32"
    data = np.random.uniform(-100.0, 100.0, size=input_dim).astype(dtype)
    pieces = np.split(data, nOutputs, axis=-3)

    inputs = [("input", datatypes.Array(*input_dim))]
    output_names = ["output" + str(idx) for idx in range(len(pieces))]
    output_shapes = [piece.shape for piece in pieces]
    outputs = [
        (name, datatypes.Array(*shape)) for name, shape in zip(output_names, output_shapes)
    ]

    builder = NeuralNetworkBuilder(inputs, outputs)
    builder.add_split(name="split", input_name="input", output_names=output_names)
    model = cm.models.MLModel(builder.spec)

    output_dtypes = [dtype] * len(output_shapes)
    for target, ctx in tvm.testing.enabled_targets():
        actual = run_tvm_graph(
            model, target, ctx, [data], ["input"], output_shapes, output_dtypes
        )
        tvm.testing.assert_allclose(actual, pieces, rtol=1e-5)
674
675
@tvm.testing.uses_gpu
def test_forward_split():
    """Exercise CoreML split into 2 and 3 equal parts along axis -3.

    Marked ``uses_gpu`` like the other tests here, since ``verify_split``
    also runs on every target from ``tvm.testing.enabled_targets()``.
    """
    verify_split((1, 4, 4, 4), 2)
    verify_split((1, 3, 30, 20), 3)
695
696
def verify_image_scaler(input_dim, blue_bias=0.0, green_bias=0.0, red_bias=0.0, image_scale=1.0):
    """Check CoreML image preprocessing (global scale plus per-channel bias).

    ``input1`` is declared as a BGR image with the given preprocessing;
    ``input2`` is raw. Both are fed the same CHW array and summed by an ADD
    layer, so the reference is ``x + (image_scale * x + bias[channel])``.
    """
    dtype = "float32"
    a_np = np.random.uniform(size=input_dim).astype(dtype)
    # make sure it is valid image format CHW.
    assert len(a_np.shape) == 3 and a_np.shape[0] == 3

    preprocessed = np.zeros(a_np.shape, dtype=dtype)
    for channel, bias in enumerate((blue_bias, green_bias, red_bias)):
        preprocessed[channel, :, :] = image_scale * a_np[channel, :, :] + bias
    expected = np.add(a_np, preprocessed)

    builder = NeuralNetworkBuilder(
        [("input1", datatypes.Array(*input_dim)), ("input2", datatypes.Array(*input_dim))],
        [("output", datatypes.Array(*expected.shape))],
    )
    builder.set_pre_processing_parameters(
        image_input_names=["input1"],
        is_bgr=True,
        blue_bias=blue_bias,
        green_bias=green_bias,
        red_bias=red_bias,
        image_scale=image_scale,
    )
    # A single ADD layer (tested above) keeps the CoreML model valid.
    builder.add_elementwise(
        name="add", input_names=["input1", "input2"], output_name="output", alpha=0, mode="ADD"
    )
    model = cm.models.MLModel(builder.spec)

    for target, ctx in tvm.testing.enabled_targets():
        actual = run_tvm_graph(
            model, target, ctx, [a_np, a_np], ["input1", "input2"], expected.shape, dtype
        )
        tvm.testing.assert_allclose(actual, expected, rtol=1e-5)
729
730
@tvm.testing.uses_gpu
def test_forward_image_scaler():
    """Exercise image preprocessing with scale only, then with scale and per-channel biases."""
    image_shape = (3, 224, 224)
    verify_image_scaler(image_shape, image_scale=0.17)
    verify_image_scaler(
        image_shape,
        image_scale=0.379,
        blue_bias=-1.7669800519943237,
        green_bias=-1.985260009765625,
        red_bias=-2.102560043334961,
    )
741
742
def verify_convolution(input_dim, filter, padding):
    """Check a CoreML 2-D convolution (stride 1, no bias) against ``conv2d_nchw_python``.

    Parameters
    ----------
    input_dim : tuple
        NCHW input shape; CoreML is given only the CHW part.
    filter : tuple
        Weight shape ``(OC, IC, KH, KW)``. ``IC`` is not used: the weights are
        generated with the input's channel count ``C`` (``groups=1``).
    padding : str
        ``"VALID"`` or ``"SAME"``; lower-cased for CoreML's ``border_mode``.
    """
    dtype = "float32"
    _, C, H, W = input_dim
    OC, _, KH, KW = filter
    a_np = np.random.uniform(size=input_dim).astype(dtype)
    w_np = np.random.uniform(size=(OC, C, KH, KW)).astype(dtype)
    # Reorder OIHW weights to the (KH, KW, C_in, C_out) layout CoreML expects.
    w_np_cm = np.transpose(w_np, axes=(2, 3, 1, 0))
    b_np = conv2d_nchw_python(a_np, w_np, [1, 1], padding)
    inputs = [("input1", datatypes.Array(C, H, W))]
    output = [("output", datatypes.Array(*b_np.shape))]
    builder = NeuralNetworkBuilder(inputs, output)
    builder.add_convolution(
        name="conv",
        # Must equal the weights' input-channel dim (C // groups, groups=1);
        # it was hard-coded to 3, which broke any non-3-channel input.
        kernel_channels=C,
        output_channels=OC,
        height=KH,
        width=KW,
        stride_height=1,
        stride_width=1,
        border_mode=padding.lower(),
        groups=1,
        W=w_np_cm,
        b=None,
        has_bias=False,
        is_deconv=False,
        input_name="input1",
        output_name="output",
    )
    model = cm.models.MLModel(builder.spec)
    for target, ctx in tvm.testing.enabled_targets():
        out = run_tvm_graph(model, target, ctx, [a_np], ["input1"], output_shape=None)
        tvm.testing.assert_allclose(out, b_np, rtol=1e-5)
775
776
@tvm.testing.uses_gpu
def test_forward_convolution():
    """Exercise a 3x3, 32-filter convolution with VALID and SAME padding."""
    for pad_mode in ("VALID", "SAME"):
        verify_convolution((1, 3, 224, 224), filter=(32, 3, 3, 3), padding=pad_mode)
781
782
if __name__ == "__main__":
    # Run every test in the same order as before when executed as a script.
    for _test_fn in (
        test_forward_AddLayerParams,
        test_forward_ConcatLayerParams,
        test_forward_MultiplyLayerParams,
        test_forward_UpsampleLayerParams,
        test_forward_l2_normalize,
        test_forward_lrn,
        test_forward_average,
        test_forward_max,
        test_forward_min,
        test_forward_unary,
        test_forward_reduce,
        test_forward_reshape,
        test_forward_split,
        test_mobilenet_checkonly,
        test_resnet50_checkonly,
        test_forward_image_scaler,
        test_forward_convolution,
    ):
        _test_fn()
801