1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17# pylint: disable=import-self, invalid-name, unused-argument
18"""
19TFLite testcases
20================
21This article is a test script to test TFLite operator with Relay.
22"""
23from __future__ import print_function
24from functools import partial
25import pytest
26import numpy as np
27import tvm
28from tvm import te
29from tvm import relay
30
31try:
32    import tensorflow.compat.v1 as tf
33
    # tensorflow.python.framework.ops itself is not part of TensorFlow's
    # public API: its precise contents may vary from one version to the
    # next, so alias the stable compat.v1 module instead (it provides the
    # helpers used below, e.g. ops.convert_to_tensor).
    import tensorflow.compat.v1 as ops
38except ImportError:
39    import tensorflow as tf
40    import tensorflow as ops
41from tensorflow.python.framework import constant_op
42
43from tensorflow.python.ops import math_ops
44from tensorflow.python.ops import nn_ops
45from tensorflow.python.ops import array_ops
46from tensorflow.python.ops import gen_array_ops
47from tensorflow.python.ops import nn_impl
48from tensorflow.python.ops import variables
49
50try:
51    from tensorflow import lite as interpreter_wrapper
52except ImportError:
53    from tensorflow.contrib import lite as interpreter_wrapper
54
55from tvm.contrib.download import download_testdata
import tvm.relay.testing.tf as tf_testing
import tvm.testing
57from packaging import version as package_version
58
59from PIL import Image
60import os
61
62#######################################################################
63# Generic run functions for TVM & TFLite
64# --------------------------------------
def convert_to_list(x):
    """Wrap a non-list value in a single-element list so inputs can be handled uniformly."""
66    if not isinstance(x, list):
67        x = [x]
68    return x
69
70
71#######################################################################
72# Get a real image for e2e testing
73# --------------------------------
74def get_real_image(im_height, im_width):
75    repo_base = "https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/"
76    img_name = "elephant-299.jpg"
77    image_url = os.path.join(repo_base, img_name)
78    img_path = download_testdata(image_url, img_name, module="data")
79    image = Image.open(img_path).resize((im_height, im_width))
80    x = np.array(image).astype("uint8")
81    data = np.reshape(x, (1, im_height, im_width, 3))
82    return data
83
84
85def pre_processed_image(height, width):
86    repo_base = "https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/"
87    img_name = "elephant-299.jpg"
88    image_url = os.path.join(repo_base, img_name)
89    img_path = download_testdata(image_url, img_name, module="data")
90    image = tf.io.read_file(img_path)
91    image = tf.image.decode_jpeg(image, channels=3)
92    with tf.name_scope("eval_image"):
93        if image.dtype != tf.float32:
94            image = tf.image.convert_image_dtype(image, dtype=tf.float32)
95        image = tf.image.central_crop(image, central_fraction=0.875)
96    # Resize the image to the specified height and width.
97    image = tf.image.resize(image, [height, width], align_corners=False)
98    image = tf.expand_dims(image, axis=0)
99    return image
100
101
102def get_real_image_object_detection(im_height, im_width):
103    repo_base = "https://github.com/dmlc/web-data/raw/master/gluoncv/detection/"
104    img_name = "street_small.jpg"
105    image_url = os.path.join(repo_base, img_name)
106    img_path = download_testdata(image_url, img_name, module="data")
107    image = Image.open(img_path).resize((im_height, im_width))
108    x = np.array(image).astype("uint8")
109    data = np.reshape(x, (1, im_height, im_width, 3))
110    return data
111
112
def vmobj_to_list(o):
    """Recursively flatten a TVM runtime object (NDArray, ADT or ConstructorValue) into a list."""
114    if isinstance(o, tvm.nd.NDArray):
115        return [o.asnumpy().tolist()]
116    elif isinstance(o, tvm.runtime.container.ADT):
117        result = []
118        for f in o:
119            result.extend(vmobj_to_list(f))
120        return result
121    elif isinstance(o, tvm.relay.backend.interpreter.ConstructorValue):
122        if o.constructor.name_hint == "Cons":
123            tl = vmobj_to_list(o.fields[1])
124            hd = vmobj_to_list(o.fields[0])
125            hd.extend(tl)
126            return hd
127        elif o.constructor.name_hint == "Nil":
128            return []
129        elif "tensor_nil" in o.constructor.name_hint:
130            return [0]
131        elif "tensor" in o.constructor.name_hint:
132            return [o.fields[0].asnumpy()]
133        else:
134            raise RuntimeError("Unknown object type: %s" % o.constructor.name_hint)
135    else:
136        raise RuntimeError("Unknown object type: %s" % type(o))
137
138
139def _quantize_keras_model(keras_model, representative_data_gen):
140    """Utility function to quantize a Keras model using TFLite converter."""
141    converter = interpreter_wrapper.TFLiteConverter.from_keras_model(keras_model)
142    converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
143    converter.representative_dataset = representative_data_gen
144    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
145    converter.inference_input_type = tf.uint8
146    converter.inference_output_type = tf.uint8
147    return converter.convert()
148
149
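# A minimal usage sketch (not run by the test suite) of how the helper above can be
# driven, assuming TF2 Keras is available; the model and generator names here are
# hypothetical.
def _example_quantize_tiny_dense_model():
    """Quantize a trivial one-layer Keras model via _quantize_keras_model."""
    tiny_model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])

    def tiny_representative_data_gen():
        # A couple of random batches are enough for the converter to calibrate.
        for _ in range(2):
            yield [np.random.uniform(0, 1, (1, 8)).astype("float32")]

    return _quantize_keras_model(tiny_model, tiny_representative_data_gen)

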
150def run_tvm_graph(
151    tflite_model_buf,
152    input_data,
153    input_node,
154    num_output=1,
155    target="llvm",
156    out_names=None,
157    mode="graph_runtime",
158):
159    """ Generic function to compile on relay and execute on tvm """
160    # TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1
161    try:
162        import tflite.Model
163
164        tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
165    except AttributeError:
166        import tflite
167
168        tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)
169    except ImportError:
170        raise ImportError("The tflite package must be installed")
171
172    input_data = convert_to_list(input_data)
173    input_node = convert_to_list(input_node)
174
175    shape_dict = {}
176    dtype_dict = {}
177    for i, e in enumerate(input_node):
178        shape_dict[e] = input_data[i].shape
179        dtype_dict[e] = input_data[i].dtype.name
180
181    mod, params = relay.frontend.from_tflite(
182        tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict
183    )
184
185    if mode in ["debug", "vm"]:
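        # Both the debug interpreter and the Relay VM are executed on CPU here.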
186        ex = relay.create_executor(mode, mod=mod, ctx=tvm.cpu(), target="llvm")
187        inputs = []
188        for param in mod["main"].params:
189            found = False
190            for i, n in enumerate(input_node):
191                if n == param.name_hint:
192                    found = True
193                    inputs.append(tvm.nd.array(input_data[i]))
194                    break
195            # Interpreter doesn't bind constants, so still need to find in params
196            if not found:
197                inputs.append(tvm.nd.array(params[param.name_hint]))
198        result = ex.evaluate()(*inputs)
199        return vmobj_to_list(result)
200    else:
201        with tvm.transform.PassContext(opt_level=3):
202            lib = relay.build(mod, target, params=params)
203
204        ctx = tvm.context(target, 0)
205        from tvm.contrib import graph_runtime
206
207        m = graph_runtime.GraphModule(lib["default"](ctx))
208        # set inputs
209        for i, e in enumerate(input_node):
210            m.set_input(e, tvm.nd.array(input_data[i].astype(input_data[i].dtype)))
211        # execute
212        m.run()
213        # get outputs
214        assert out_names is None or num_output == len(
215            out_names
216        ), "out_names: {} num_output: {}".format(out_names, num_output)
217        tvm_output_list = []
218        for i in range(0, num_output):
219            tvm_output = m.get_output(i)
220            tvm_output_list.append(tvm_output.asnumpy())
221        return tvm_output_list
222
223
224def run_tflite_graph(tflite_model_buf, input_data):
225    """ Generic function to execute TFLite """
226    input_data = convert_to_list(input_data)
227
228    interpreter = interpreter_wrapper.Interpreter(model_content=tflite_model_buf)
229    input_details = interpreter.get_input_details()
230    output_details = interpreter.get_output_details()
231
232    for i in range(len(input_details)):
233        interpreter.resize_tensor_input(input_details[i]["index"], input_data[i].shape)
234    interpreter.allocate_tensors()
235
236    # set input
237    assert len(input_data) == len(input_details)
238    for i in range(len(input_details)):
239        interpreter.set_tensor(input_details[i]["index"], input_data[i])
240
241    # Run
242    interpreter.invoke()
243
244    # get output
    tflite_output = []
246    for i in range(len(output_details)):
247        tflite_output.append(interpreter.get_tensor(output_details[i]["index"]))
248
249    return tflite_output
250
251
252def compare_tflite_with_tvm(
253    in_data,
254    in_name,
255    input_tensors,
256    output_tensors,
257    init_global_variables=False,
258    out_names=None,
259    quantized=False,
260    input_range=None,
261    mode="graph_runtime",
262    experimental_new_converter=False,
263):
264    """Generic function to generate and compare TFLite and TVM output"""
265    in_data = convert_to_list(in_data)
266    in_name = convert_to_list(in_name)
267    out_names = convert_to_list(out_names)
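    # Strip the ":0" tensor suffix (if present) so the entries can be used as graph node names.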
268    in_node = [0] * len(in_name)
269    for i in range(len(in_name)):
270        in_node[i] = in_name[i].split(":")[0] if ":" in in_name[i] else in_name[i]
271
272    with tf.Session() as sess:
273        if init_global_variables:
274            sess.run(variables.global_variables_initializer())
275        # convert to tflite model
276        converter = tf.lite.TFLiteConverter.from_session(sess, input_tensors, output_tensors)
277        converter.experimental_new_converter = experimental_new_converter
278        if quantized:
279            converter.inference_type = tf.lite.constants.QUANTIZED_UINT8
280            input_arrays = converter.get_input_arrays()
281            input_stats = {}
282            # calculate the mean and quantization scale for every input tensor,
283            # with respect to its fp32 input range, defined in fake_quant.
284            # s = 255/(fmax-fmin);  m = -fmin*s (the zero point)
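            # Worked example: with the (-100, 100) fp32 range used by the quantized tests in this
            # file, quant_scale = 255 / 200 = 1.275 and mean = -(-100) * 1.275 = 127.5, i.e. fp32
            # zero maps to the middle of the uint8 range.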
285            for i in input_arrays:
286                try:
287                    quant_scale = 255 / (input_range[i][1] - input_range[i][0])
288                except ZeroDivisionError:
289                    raise ZeroDivisionError(
290                        "Min and max of the input range for tensor " + i + " can't be equal"
291                    )
292                mean = -input_range[i][0] * quant_scale
293                input_stats[i] = (mean, quant_scale)
294            converter.quantized_input_stats = input_stats
295
296        tflite_model_buffer = converter.convert()
297        tflite_output = run_tflite_graph(tflite_model_buffer, in_data)
298
299        for device in ["llvm"]:
300            ctx = tvm.context(device, 0)
301            if not tvm.testing.device_enabled(device):
302                print("Skip because %s is not enabled" % device)
303                continue
304
305            tvm_output = run_tvm_graph(
306                tflite_model_buffer,
307                in_data,
308                in_node,
309                target=device,
310                num_output=len(out_names),
311                out_names=out_names,
312                mode=mode,
313            )
314
            # WARNING: for quantized tests the results can degenerate into values clipped to 0 or 255
            # if the output range of the operator is badly tuned, and such outputs would still pass
            # the tolerance check below (e.g. an all-255 TVM output against an all-255 TFLite output).
            # When adding a test, make sure the output tensors are not merely clipped values.
            # For reference see _test_elemwise_qnn_out_range().
318            if quantized:
319                for i in range(len(tflite_output)):
320                    # allow absolute tolerance of 1 in the quantized results
321                    tvm.testing.assert_allclose(tflite_output[i], tvm_output[i], atol=1, rtol=1e-5)
322            else:
323                for i in range(len(tflite_output)):
324                    tvm.testing.assert_allclose(
325                        tflite_output[i], tvm_output[i], atol=1e-5, rtol=1e-5
326                    )
327
328
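# A minimal sketch of how the comparison helper above is typically wired up when
# adding a new operator test; it mirrors the _test_* helpers below, and the
# "example_in" tensor name is arbitrary.
def _example_minimal_operator_test():
    data = np.arange(6.0, dtype=np.float32).reshape((2, 3))
    with tf.Graph().as_default():
        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name="example_in")
        out = math_ops.abs(in_data)
        compare_tflite_with_tvm([data], ["example_in:0"], [in_data], [out])

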
def with_fused_activation_function(input_tensor, fn_name):
    """Apply the named TFLite fused activation function to the given tensor."""
330    if fn_name is None or fn_name == "NONE":
331        return input_tensor
332    if fn_name == "RELU":
333        return nn_ops.relu(input_tensor)
334    if fn_name == "RELU6":
335        return nn_ops.relu6(input_tensor)
336    if fn_name == "RELU_N1_TO_1":
337        return math_ops.maximum(-1, math_ops.minimum(input_tensor, 1))
338    if fn_name == "TANH":
339        return math_ops.tanh(input_tensor)
340    raise AssertionError("Unknown fused_activation_function {}".format(fn_name))
341
342
343def _test_split(in_shape, axis, num_splits, dtype):
344    """internal split tester taking as parameters in_shape, number of tensors to split into
345    and dtype (data type)"""
346
347    np_data = np.random.uniform(-5, 5, size=in_shape).astype(dtype)
348    with tf.Graph().as_default():
349        in_data = array_ops.placeholder(shape=in_shape, dtype=dtype, name="in_data")
350        out = array_ops.split(in_data, num_splits, axis=axis)
351        num_splits = len(num_splits) if isinstance(num_splits, list) else num_splits
352        out_names = ["out_" + str(n) + ":0" for n in range(num_splits)]
353        compare_tflite_with_tvm([np_data], ["in_data"], [in_data], out, out_names=out_names)
354
355
356def test_forward_split():
357    """test split layer"""
358    # rank 1
359    _test_split((3,), 0, 1, "float32")
360    _test_split((3,), 0, 3, "float32")
361    _test_split((6,), 0, 3, "float32")
362    # rank 2
363    _test_split((6, 2), 0, 3, "float32")
364    _test_split((2, 6), 1, 6, "float32")
365    # rank 3
366    if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"):
367        _test_split((6, 2, 4), 0, 2, "int32")
368
369    _test_split((2, 6, 4), 1, 3, "float32")
370    _test_split((2, 4, 6), 2, 1, "float32")
371    # rank 4
372    _test_split((6, 1, 3, 5), 0, 3, "float32")
373    _test_split((1, 6, 3, 5), 1, 3, "float32")
374    _test_split((1, 3, 6, 5), 2, 3, "float32")
375    _test_split((1, 3, 5, 6), 3, 3, "float32")
376    # split along negative axis
377    _test_split((6, 1, 3, 5), -4, 3, "float32")
378    _test_split((1, 6, 3, 5), -3, 3, "float32")
379    _test_split((1, 3, 6, 5), -2, 3, "float32")
380    _test_split((1, 3, 5, 6), -1, 3, "float32")
381    # size_splits split
382    _test_split((6,), 0, [1, 2, 3], "float32")
383    _test_split((3, 6, 4), -2, [1, 4, 1], "float32")
384
385
386#######################################################################
387# slice
388# -----
389
390
391def _test_slice(data, begin, size):
392    """ One iteration of SLICE """
393    with tf.Graph().as_default():
394        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
395        out = array_ops.slice(in_data, begin, size)
396        compare_tflite_with_tvm(data, "Placeholder:0", [in_data], [out])
397
398
399def test_forward_slice():
400    """ SLICE """
401    _test_slice(np.arange(4, dtype=np.float32).reshape((4,)), begin=[0], size=[2])
402    _test_slice(np.arange(18, dtype=np.int32).reshape((3, 2, 3)), begin=[1, 0, 0], size=[1, 1, 3])
403    # tflite 1.13 outputs nonsense values if size[i] == -1
404    if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"):
405        _test_slice(np.arange(8, dtype=np.int32).reshape((2, 4)), begin=[0, 1], size=[-1, -1])
406        _test_slice(np.arange(5, dtype=np.int32).reshape((5,)), begin=[4], size=[-1])
407
408
409#######################################################################
410# Topk
411# ----
412def _test_topk(in_shape, k=1):
413    """ One iteration of TOPK """
414    data = np.random.uniform(size=in_shape).astype("float32")
415    with tf.Graph().as_default():
416        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
417        out = nn_ops.top_k(in_data, k, name="TopK")
418        compare_tflite_with_tvm(data, "Placeholder:0", [in_data], [out[0]])
419
420
421def test_forward_topk():
422    """ TOPK """
423    _test_topk((3,), 1)
424    _test_topk((3,), 3)
425    _test_topk((3, 5, 7), 3)
426    _test_topk((3, 5, 7), 3)
427
428
429#######################################################################
430# Gather
431# ------
432
433
434def _test_gather(dshape, indices, axis, dtype, quantized=False, oob=False, wrap_idx=False):
435    """ One iteration of Gather """
436    indices = np.asarray(indices).astype("int32")
437    data = np.random.uniform(1, 10, size=dshape)
438    data = data.astype(np.uint8) if quantized else data.astype(dtype)
439    with tf.Graph().as_default():
440        if wrap_idx:
441            in_name = "in_indices"
442            indices_expr = array_ops.placeholder(
443                shape=indices.shape, dtype=indices.dtype, name=in_name
444            )
445            in_tensor_name = [in_name + ":0"]
446            in_indices = [indices_expr]
447        else:
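            # Indices are baked into the graph as a constant, so there is no extra
            # runtime input to feed to the comparison helper.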
448            indices_expr = indices
449            indices = []
450            in_tensor_name = []
451            in_indices = []
452
453        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name="in_data")
454        if axis:
455            out = array_ops.gather(in_data, indices_expr, axis=axis)
456        else:
457            out = array_ops.gather(in_data, indices_expr)  # tflite conversion fails for None axis
458        input_range = {"in_data": (-100, 100)} if quantized else None
459        try:
460            compare_tflite_with_tvm(
461                [data] + indices,
462                ["in_data:0"] + in_tensor_name,
463                [in_data] + in_indices,
464                [out],
465                quantized=quantized,
466                input_range=input_range,
467            )
468        except ValueError as e:
469            if not oob:
470                raise e
471        except Exception as e:
472            raise e
473
474
475def test_forward_gather():
476    """ GATHER """
477    for quantized in [False, True]:
478        for wrap_idx in [False, True]:
            _test_gather((4,), [1], 0, "float32", quantized, wrap_idx=wrap_idx)
            _test_gather((4,), [1], None, "int32", quantized, wrap_idx=wrap_idx)
            _test_gather((1, 4), [0], 0, "int32", quantized, wrap_idx=wrap_idx)
            _test_gather((4,), [[[1, 0], [0, 1]]], 0, "float32", quantized, wrap_idx=wrap_idx)
            _test_gather((2, 2), [[[1, 0], [0, 1]]], 1, "int32", quantized, wrap_idx=wrap_idx)
            _test_gather((2, 2), [[[1, 0], [0, 1]]], None, "float32", quantized, wrap_idx=wrap_idx)
            _test_gather((3, 3, 3), [[[1, 0]]], 0, "int32", quantized, wrap_idx=wrap_idx)
            _test_gather((3, 3, 3), [[[1, 0]]], 2, "int32", quantized, wrap_idx=wrap_idx)
            _test_gather((4, 3, 5, 6), [[2, 1, 0, 0]], 0, "float32", quantized, wrap_idx=wrap_idx)
            _test_gather((3, 3, 3), [[[2, 1]]], -1, "int32", quantized, wrap_idx=wrap_idx)
489        # Out of boundary error cannot be tested with wrapped index
490        _test_gather((4,), [16], 0, "float32", quantized, oob=True)
491        _test_gather((1, 3, 3), [12], 0, "int32", quantized, oob=True)
492        _test_gather((1, 3, 3), [20], 1, "float32", quantized, oob=True)
493        _test_gather((1, 3, 3), [20, 20], 2, "float32", quantized, oob=True)
494
495
496#######################################################################
497# Gather_ND
498# ---------
499
500
501def _test_gather_nd(data, indices):
502    """ One iteration of GATHER_ND """
503    with tf.Graph().as_default():
504        in_data = tf.placeholder(shape=data.shape, dtype=data.dtype, name="data")
505        indices_data = tf.placeholder(shape=indices.shape, dtype=indices.dtype, name="indices")
506        out = tf.gather_nd(in_data, indices_data)
507
508        compare_tflite_with_tvm(
509            [data, indices], ["data:0", "indices:0"], [in_data, indices_data], [out]
510        )
511
512
513def test_forward_gather_nd():
514    """ GATHER_ND """
515    _test_gather_nd(
516        np.array([[[1.2, 2.0], [3.1, 4.1]], [[5.1, 6.1], [7.1, 8.1]]]).astype("float32"),
517        np.asarray([[0, 1], [1, 0]]).astype("int32"),
518    )
519    _test_gather_nd(
520        np.reshape(np.arange(30), [5, 6]).astype("int32"), np.asarray([[1, 2]]).astype("int32")
521    )
522    _test_gather_nd(
523        np.reshape(np.arange(12), [2, 3, 2]).astype("int32"),
524        np.asarray([[[0, 0], [0, 1]], [[1, 0], [1, 1]]]).astype("int32"),
525    )
526    _test_gather_nd(
527        np.reshape(np.arange(4), [4]).astype("float32"), np.asarray([1]).astype("int32")
528    )
529    _test_gather_nd(
530        np.reshape(np.arange(4), [1, 4]).astype("float32"), np.asarray([0]).astype("int32")
531    )
532    _test_gather_nd(
533        np.reshape(np.arange(4), [1, 4]).astype("float32"), np.asarray([0, 3]).astype("int32")
534    )
535
536
537#######################################################################
538# StridedSlice
539# ------------
540
541
542def _test_stridedslice(
543    ip_shape,
544    begin,
545    end,
546    stride,
547    dtype,
548    begin_mask=0,
549    end_mask=0,
550    new_axis_mask=0,
551    shrink_axis_mask=0,
552    ellipsis_mask=0,
553    quantized=False,
554):
555    """ One iteration of a Stridedslice """
556    data = np.random.uniform(size=ip_shape).astype(dtype)
557    data = data.astype(np.uint8) if quantized else data.astype(dtype)
558    with tf.Graph().as_default():
559        in_data = tf.placeholder(dtype, ip_shape, name="in_data")
560        out = array_ops.strided_slice(
561            in_data,
562            begin,
563            end,
564            stride,
565            begin_mask=begin_mask,
566            end_mask=end_mask,
567            new_axis_mask=new_axis_mask,
568            shrink_axis_mask=shrink_axis_mask,
569            ellipsis_mask=ellipsis_mask,
570        )
571        input_range = {"in_data": (-100, 100)} if quantized else None
572        compare_tflite_with_tvm(
573            [data], ["in_data:0"], [in_data], [out], quantized=quantized, input_range=input_range
574        )
575
576
577def test_forward_stridedslice():
578    """test StridedSlice"""
579    for quantized in [False, True]:
        _test_stridedslice((2,), [1], [1], [1], "float32", shrink_axis_mask=1, quantized=quantized)
581        _test_stridedslice(
582            (3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], "float32", quantized=quantized
583        )
584        _test_stridedslice(
585            (3, 4), [1, 0], [4, 4], [1, 1], "float32", shrink_axis_mask=0, quantized=quantized
586        )
587        _test_stridedslice(
588            (4, 4), [1, 0], [4, 4], [1, 1], "float32", shrink_axis_mask=2, quantized=quantized
589        )
590
591
592#######################################################################
593# transpose
594# ---------
595
596
597def _test_forward_transpose(ishape, axes=()):
598    data = np.random.uniform(size=ishape).astype(np.float32)
599
600    with tf.Graph().as_default():
601        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
602
603        if not axes:
604            out = array_ops.transpose(in_data)
605        else:
606            out = array_ops.transpose(in_data, axes)
607
608        compare_tflite_with_tvm(data, "Placeholder:0", [in_data], [out])
609
610
611def test_forward_transpose():
612    _test_forward_transpose((2, 2))
613    _test_forward_transpose((2, 3, 4))
614    _test_forward_transpose((7, 8, 8, 10))
615    _test_forward_transpose((2, 3, 4), (1, 2, 0))
616    _test_forward_transpose((2, 3, 4), (0, 1, 2))
617    _test_forward_transpose((2, 3, 4, 5), (3, 0, 1, 2))
618    _test_forward_transpose((2, 3, 4, 5), ())
619
620
621#######################################################################
622# Cast
623# ----
624
625
626def _test_cast(data, cast_dtype):
627    """ One iteration of CAST """
628    with tf.Graph().as_default():
629        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
630        out = math_ops.cast(in_data, cast_dtype)
631        compare_tflite_with_tvm(data, "Placeholder:0", [in_data], [out])
632
633
634def test_forward_cast():
635    """ CAST """
636    _test_cast(np.arange(6.0, dtype=np.float32).reshape((1, 6)), cast_dtype=tf.int32)
637    _test_cast(np.arange(6.0, dtype=np.float32).reshape((1, 6)), cast_dtype=tf.uint8)
    _test_cast(np.arange(6, dtype=np.int32).reshape((1, 6)), cast_dtype=tf.int64)
639
640
641#######################################################################
642# Batch Mat Mul
# -------------
644def _test_batch_matmul(A_shape, B_shape, dtype, adjoint_a=False, adjoint_b=False):
645    with tf.Graph().as_default():
646        A = array_ops.placeholder(shape=A_shape, dtype=dtype, name="A")
647        B = array_ops.placeholder(shape=B_shape, dtype=dtype, name="B")
648        result = math_ops.matmul(A, B, adjoint_a=adjoint_a, adjoint_b=adjoint_b, name="batchmatmul")
649
650        A_np = np.random.uniform(high=5.0, size=A_shape).astype(dtype)
651        B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype)
652        compare_tflite_with_tvm([A_np, B_np], [A.name, B.name], [A, B], [result])
653
654
655def test_forward_batch_matmul():
656    """ BATCH_MAT_MUL """
657    _test_batch_matmul((3, 5, 4), (3, 4, 5), "float32")
658    _test_batch_matmul((3, 5, 4), (3, 4, 5), "float32", True, True)
659    _test_batch_matmul((3, 5, 4), (3, 5, 4), "float32", True, False)
660    _test_batch_matmul((3, 5, 4), (3, 5, 4), "float32", False, True)
661    _test_batch_matmul((2, 3, 4, 5, 6), (2, 3, 4, 6, 5), "float32")
662
663
664#######################################################################
665# Tile
666# ----
667
668
669def _test_forward_tile(in_shape, reps, dtype):
670    data = np.random.uniform(-5, 5, size=in_shape).astype(dtype)
671
672    with tf.Graph().as_default():
673        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
674
675        out = array_ops.tile(in_data, reps)
676
677        compare_tflite_with_tvm(data, "Placeholder:0", [in_data], [out])
678
679
680def test_forward_tile():
681    _test_forward_tile((2,), (3,), "int32")
682    _test_forward_tile((2, 2), (2, 3), "float32")
683
684
685######################################################################
686# BatchToSpaceND
687# --------------
688
689
690def _test_batch_to_space_nd(input_shape, block_shape, crops, dtype="int32"):
691    data = np.random.uniform(0, 5, size=input_shape).astype(dtype)
692
693    with tf.Graph().as_default():
694        in_data = array_ops.placeholder(shape=input_shape, dtype=dtype)
695
696        out = array_ops.batch_to_space_nd(in_data, block_shape, crops)
697
698        compare_tflite_with_tvm(data, "Placeholder:0", [in_data], [out])
699
700
701def test_forward_batch_to_space_nd():
702    # test cases: https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d
703    _test_batch_to_space_nd(input_shape=[4, 1, 1, 1], block_shape=[2, 2], crops=[[0, 0], [0, 0]])
704
705    _test_batch_to_space_nd(input_shape=[4, 1, 1, 3], block_shape=[2, 2], crops=[[0, 0], [0, 0]])
706
707    _test_batch_to_space_nd(input_shape=[4, 2, 2, 1], block_shape=[2, 2], crops=[[0, 0], [0, 0]])
708
709
710######################################################################
711# SpaceToBatchND
712# --------------
713
714
715def _test_space_to_batch_nd(input_shape, block_shape, paddings, dtype="int32"):
716    data = np.random.uniform(0, 5, size=input_shape).astype(dtype)
717
718    with tf.Graph().as_default():
719        in_data = array_ops.placeholder(shape=input_shape, dtype=dtype)
720
721        out = array_ops.space_to_batch_nd(in_data, block_shape, paddings)
722
723        compare_tflite_with_tvm(data, "Placeholder:0", [in_data], [out])
724
725
726def test_forward_space_to_batch_nd():
727    # test cases: https://www.tensorflow.org/api_docs/python/tf/space_to_batch_nd
728    _test_space_to_batch_nd(input_shape=[1, 2, 2, 1], block_shape=[2, 2], paddings=[[0, 0], [0, 0]])
729
730    _test_space_to_batch_nd(input_shape=[1, 2, 2, 3], block_shape=[2, 2], paddings=[[0, 0], [0, 0]])
731
732    _test_space_to_batch_nd(input_shape=[1, 4, 4, 1], block_shape=[2, 2], paddings=[[0, 0], [0, 0]])
733
734    _test_space_to_batch_nd(input_shape=[2, 2, 4, 1], block_shape=[2, 2], paddings=[[0, 0], [2, 0]])
735
736
737#######################################################################
738# Pooling
739# -------
740def _test_pooling_iteration(input_shape, **kwargs):
741    """ One iteration of pool operation with given shapes and attributes """
742
743    x = -np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape) - 1
744
745    with tf.Graph().as_default():
746        in_data = array_ops.placeholder(shape=input_shape, dtype="float32")
747        out = nn_ops.pool(in_data, **kwargs)
748
749        compare_tflite_with_tvm(x, "Placeholder:0", [in_data], [out])
750
751
752def _test_pooling(input_shape, **kwargs):
753    _test_pooling_iteration(input_shape, **kwargs)
754
755
756def test_forward_pooling():
757    """ Pooling """
758
759    for pool_type in ["AVG", "MAX"]:
760        _test_pooling(
761            input_shape=[2, 9, 10, 2],
762            window_shape=[1, 1],
763            padding="SAME",
764            pooling_type=pool_type,
765            dilation_rate=[1, 1],
766            strides=[1, 1],
767        )
768
769        _test_pooling(
770            input_shape=[2, 10, 9, 2],
771            window_shape=[1, 1],
772            padding="SAME",
773            pooling_type=pool_type,
774            dilation_rate=[1, 1],
775            strides=[1, 1],
776        )
777
778        _test_pooling(
779            input_shape=[2, 9, 10, 2],
780            window_shape=[2, 1],
781            padding="SAME",
782            pooling_type=pool_type,
783            dilation_rate=[1, 1],
784            strides=[1, 1],
785        )
786
787        _test_pooling(
788            input_shape=[2, 10, 9, 2],
789            window_shape=[2, 3],
790            padding="SAME",
791            pooling_type=pool_type,
792            dilation_rate=[1, 1],
793            strides=[2, 1],
794        )
795
796
797def _test_l2_pool2d(input_shape, ksize, strides, padding, data_format, fused_func_name=None):
798    x = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape) - 1
799
800    with tf.Graph().as_default():
801        in_data = tf.placeholder(dtype=tf.float32, name="input", shape=input_shape)
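        # L2 pooling is expressed here as sqrt(avg_pool(square(x))), a pattern the TFLite
        # converter can fuse into a single L2_POOL_2D operator.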
802        out = tf.sqrt(
803            tf.nn.avg_pool(
804                tf.square(in_data),
805                ksize=ksize,
806                strides=strides,
807                padding=padding,
808                data_format=data_format,
809            )
810        )
811        out = with_fused_activation_function(out, fused_func_name)
812
813        compare_tflite_with_tvm(x, "input", [in_data], [out])
814
815
816def test_forward_l2_pool2d():
817    _test_l2_pool2d([1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], "SAME", "NHWC", "RELU6")
818    _test_l2_pool2d([2, 9, 10, 2], [1, 1, 1, 1], [1, 1, 1, 1], "SAME", "NHWC", "RELU6")
819    _test_l2_pool2d([2, 9, 10, 2], [1, 2, 1, 1], [1, 1, 1, 1], "SAME", "NHWC")
820    _test_l2_pool2d([2, 9, 10, 2], [1, 2, 1, 1], [1, 1, 2, 1], "SAME", "NHWC")
821    _test_l2_pool2d([1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], "VALID", "NHWC", "RELU")
822    _test_l2_pool2d([2, 9, 10, 2], [1, 1, 1, 1], [1, 1, 1, 1], "VALID", "NHWC")
823    _test_l2_pool2d([2, 9, 10, 2], [1, 2, 1, 1], [1, 1, 1, 1], "VALID", "NHWC")
824    _test_l2_pool2d([2, 9, 10, 2], [1, 2, 1, 1], [1, 1, 2, 1], "VALID", "NHWC", "RELU6")
825
826
827#######################################################################
828# Convolution
829# -----------
830
831
832def _test_tflite2_quantized_convolution(
833    input_shape, kernel_shape, dilations, strides, padding, data_format
834):
835    """ One iteration of TFLite2 quantized convolution with given shapes and attributes """
    data_format = "channels_last" if data_format == "NHWC" else "channels_first"
837    data = np.random.uniform(0, 1, input_shape).astype("float32")
838    kernel = np.random.uniform(0, 1, kernel_shape).astype("float32")
839
840    data_in = tf.keras.layers.Input(shape=data.shape[1:])
841    conv = tf.keras.layers.Conv2D(
842        filters=kernel_shape[3],
843        kernel_size=(kernel_shape[0], kernel_shape[1]),
844        strides=strides,
845        padding=padding,
846        data_format=data_format,
847        activation="relu",
848        use_bias=False,
849    )(data_in)
850    keras_model = tf.keras.models.Model(data_in, conv)
851    keras_model.layers[1].set_weights([kernel])
852
    # A representative dataset is needed so the converter can calibrate the
    # dynamic range of the activations for quantization
854    def representative_data_gen():
855        for i in range(1):
856            yield [data]
857
858    tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen)
859
860    tflite_output = run_tflite_graph(tflite_model_quant, data)
861    tvm_output = run_tvm_graph(tflite_model_quant, data, data_in.name.replace(":0", ""))
862    tvm.testing.assert_allclose(
863        np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-2, atol=1e-2
864    )
865
866
867def _test_tflite2_quantized_depthwise_convolution(
868    input_shape, kernel_shape, dilations, strides, padding, data_format, depth_multiplier
869):
870    """One iteration of TFLite2 quantized depthwise convolution with given shapes and attributes"""
871
    data_format = "channels_last" if data_format == "NHWC" else "channels_first"
873    data = np.random.uniform(0, 1, input_shape).astype("float32")
874    kernel = np.random.uniform(0, 1, kernel_shape).astype("float32")
875
876    data_in = tf.keras.layers.Input(shape=data.shape[1:])
877    conv = tf.keras.layers.DepthwiseConv2D(
878        kernel_size=(kernel_shape[0], kernel_shape[1]),
879        strides=strides,
880        padding=padding,
881        data_format=data_format,
882        activation="relu",
883        use_bias=False,
884        depth_multiplier=depth_multiplier,
885    )(data_in)
886    keras_model = tf.keras.models.Model(data_in, conv)
887    keras_model.layers[1].set_weights([kernel])
888
    # A representative dataset is needed so the converter can calibrate the
    # dynamic range of the activations for quantization
890    def representative_data_gen():
891        for i in range(1):
892            yield [data]
893
894    tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen)
895
896    tflite_output = run_tflite_graph(tflite_model_quant, data)
897    tvm_output = run_tvm_graph(tflite_model_quant, data, data_in.name.replace(":0", ""))
898    tvm.testing.assert_allclose(
899        np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-2, atol=1e-2
900    )
901
902
903def _test_convolution(
904    tensor_in_sizes,
905    filter_in_sizes,
906    dilations,
907    strides,
908    padding,
909    data_format,
910    is_depthwise=False,
911    quantized=False,
912):
913    """ One iteration of convolution with given shapes and attributes """
914
915    total_size_1 = 1
916    total_size_2 = 1
917    for s in tensor_in_sizes:
918        total_size_1 *= s
919    for s in filter_in_sizes:
920        total_size_2 *= s
921    # Initializes the input tensor with array containing incrementing
922    # numbers from 1.
923    if quantized:
924        data_array = np.random.uniform(0, 255, tensor_in_sizes).astype("uint8")
925        filter_array = np.random.uniform(0, 255, filter_in_sizes).astype("uint8")
926    else:
927        data_array = [f * 1.0 for f in range(1, total_size_1 + 1)]
928        filter_array = [f * 1.0 for f in range(1, total_size_2 + 1)]
929
930    with tf.Graph().as_default():
931        in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype="float32", name="in_data")
932        in_filter = constant_op.constant(
933            filter_array, shape=filter_in_sizes, dtype="float32", name="in_filter"
934        )
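        # Expand the 2-D strides/dilations to NHWC layout:
        # [1, stride_h, stride_w, 1] and [1, dilation_h, dilation_w, 1].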
935        strides = [1] + strides + [1]
936        dilations = [1] + dilations + [1]
937
938        if is_depthwise:
939            out = nn_ops.depthwise_conv2d_native(
940                in_data, in_filter, strides=strides, padding=padding, data_format=data_format
941            )
942        else:
943            out = nn_ops.conv2d(
944                in_data, in_filter, strides=strides, padding=padding, data_format=data_format
945            )
946
        if quantized:
            # Quantize the inputs and feed them to the convolution
            inq_data = tf.quantization.fake_quant_with_min_max_args(
                in_data, min=-100, max=100, name="inq_data"
            )
            inq_filter = tf.quantization.fake_quant_with_min_max_args(
                in_filter, min=-100, max=100, name="inq_filter"
            )
            if is_depthwise:
                out = nn_ops.depthwise_conv2d_native(
                    inq_data, inq_filter, strides=strides, padding=padding, data_format=data_format
                )
            else:
                out = nn_ops.conv2d(
                    inq_data, inq_filter, strides=strides, padding=padding, data_format=data_format
                )
            out = tf.quantization.fake_quant_with_min_max_args(
                out, min=-200, max=200, name="out"
            )

            # Set the input quantization range
            input_range = {"in_data": (-100, 100)}

            # Compare
            compare_tflite_with_tvm(
                data_array,
                "in_data",
                [in_data],
                [out],
                quantized=quantized,
                input_range=input_range,
            )
1002        else:
1003            data_array = np.reshape(data_array, tensor_in_sizes).astype("float32")
1004            compare_tflite_with_tvm(data_array, "in_data", [in_data], [out])
1005
1006
1007def test_forward_convolution():
1008    for quantized in [False, True]:
1009        _test_convolution(
1010            [4, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], "SAME", "NHWC", quantized=quantized
1011        )
1012        _test_convolution(
1013            [4, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], "VALID", "NHWC", quantized=quantized
1014        )
1015        _test_convolution(
1016            [4, 17, 17, 124], [1, 1, 124, 19], [1, 1], [1, 1], "SAME", "NHWC", quantized=quantized
1017        )
1018        _test_convolution(
1019            [4, 17, 17, 12], [3, 3, 12, 32], [1, 1], [2, 2], "VALID", "NHWC", quantized=quantized
1020        )
1021
1022        # depthwise convolution
1023        _test_convolution(
1024            [4, 8, 8, 176],
1025            [1, 1, 176, 1],
1026            [1, 1],
1027            [1, 1],
1028            "SAME",
1029            "NHWC",
1030            True,
1031            quantized=quantized,
1032        )
1033        _test_convolution(
1034            [4, 17, 17, 19],
1035            [3, 3, 19, 1],
1036            [1, 1],
1037            [2, 2],
1038            "VALID",
1039            "NHWC",
1040            True,
1041            quantized=quantized,
1042        )
1043        _test_convolution(
1044            [4, 17, 17, 124],
1045            [1, 1, 124, 1],
1046            [1, 1],
1047            [1, 1],
1048            "SAME",
1049            "NHWC",
1050            True,
1051            quantized=quantized,
1052        )
1053        _test_convolution(
1054            [4, 17, 17, 12],
1055            [3, 3, 12, 1],
1056            [1, 1],
1057            [2, 2],
1058            "VALID",
1059            "NHWC",
1060            True,
1061            quantized=quantized,
1062        )
1063        _test_convolution(
1064            [4, 17, 17, 12],
1065            [3, 3, 12, 2],
1066            [1, 1],
1067            [2, 2],
1068            "VALID",
1069            "NHWC",
1070            True,
1071            quantized=quantized,
1072        )
1073        # depthwise convolution with single input channel
1074        _test_convolution(
1075            [1, 76, 64, 1], [9, 5, 1, 96], [1, 1], [1, 1], "SAME", "NHWC", True, quantized=quantized
1076        )
1077
1078    # TFLite2 quantized convolution testing
1079    if package_version.parse(tf.VERSION) >= package_version.parse("2.1.0"):
1080        _test_tflite2_quantized_convolution(
1081            [1, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], "SAME", "NHWC"
1082        )
1083        _test_tflite2_quantized_convolution(
1084            [1, 17, 17, 12], [3, 3, 12, 32], [1, 1], [2, 2], "VALID", "NHWC"
1085        )
1086        _test_tflite2_quantized_convolution(
1087            [1, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], "VALID", "NHWC"
1088        )
1089        _test_tflite2_quantized_convolution(
1090            [1, 17, 17, 124], [1, 1, 124, 19], [1, 1], [1, 1], "SAME", "NHWC"
1091        )
1092
1093        # Disable as tests are flaky - https://github.com/apache/incubator-tvm/issues/6064
1094        # depthwise convolution
1095        # _test_tflite2_quantized_depthwise_convolution([1, 8, 8, 128], [1, 1, 128, 1], [1, 1], [1, 1],
1096        #                                               'SAME', 'NHWC', 1)
1097        # _test_tflite2_quantized_depthwise_convolution([1, 17, 17, 12], [3, 3, 12, 1], [1, 1], [2, 2],
1098        #                                               'VALID', 'NHWC', 1)
1099        # _test_tflite2_quantized_depthwise_convolution([1, 24, 24, 3], [7, 7, 3, 8], [1, 1], [2, 2],
1100        #                                               'SAME', 'NHWC', 8)
1101
1102
1103#######################################################################
1104# Transpose Convolution
1105# ---------------------
1106
1107
1108def _test_transpose_conv(tensor_in_sizes, filter_in_sizes, output_shape, strides, padding):
1109    """ One iteration of transpose convolution with given shapes and attributes """
1110
1111    total_size_1 = 1
1112    total_size_2 = 1
1113    for s in tensor_in_sizes:
1114        total_size_1 *= s
1115    for s in filter_in_sizes:
1116        total_size_2 *= s
1117    # Initializes the input tensor with array containing incrementing
1118    # numbers from 1.
1119    data_array = [f * 1.0 for f in range(1, total_size_1 + 1)]
1120    filter_array = [f * 1.0 for f in range(1, total_size_2 + 1)]
1121
1122    with tf.Graph().as_default():
1123        in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype="float32")
1124        in_filter = constant_op.constant(filter_array, shape=filter_in_sizes, dtype="float32")
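        # Expand the 2-D strides to NHWC layout: [1, stride_h, stride_w, 1].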
1125        strides = [1] + strides + [1]
1126        # in_filter layout is HWOI
1127        out = nn_ops.conv2d_transpose(
1128            in_data, in_filter, output_shape=output_shape, strides=strides, padding=padding
1129        )
1130        data_array = np.reshape(data_array, tensor_in_sizes).astype("float32")
1131        compare_tflite_with_tvm(data_array, "Placeholder:0", [in_data], [out])
1132
1133
1134def test_forward_transpose_conv():
1135    # kernel 3x3, padding VALID
1136    _test_transpose_conv([4, 32, 32, 16], [3, 3, 5, 16], [4, 34, 34, 5], [1, 1], "VALID")
1137    _test_transpose_conv([1, 32, 32, 16], [3, 3, 5, 16], [1, 65, 65, 5], [2, 2], "VALID")
1138    _test_transpose_conv([1, 32, 32, 16], [3, 3, 5, 16], [1, 65, 34, 5], [2, 1], "VALID")
1139
1140    # kernel 3x3, padding SAME
1141    _test_transpose_conv([4, 32, 32, 16], [3, 3, 5, 16], [4, 32, 32, 5], [1, 1], "SAME")
1142    _test_transpose_conv([1, 32, 32, 16], [3, 3, 5, 16], [1, 64, 64, 5], [2, 2], "SAME")
1143    _test_transpose_conv([1, 32, 32, 16], [3, 3, 5, 16], [1, 64, 32, 5], [2, 1], "SAME")
1144
1145    # kernel 2x2, padding VALID
1146    _test_transpose_conv([4, 32, 32, 16], [2, 2, 5, 16], [4, 33, 33, 5], [1, 1], "VALID")
1147    _test_transpose_conv([1, 32, 32, 16], [2, 2, 5, 16], [1, 64, 64, 5], [2, 2], "VALID")
1148    _test_transpose_conv([1, 32, 32, 16], [2, 2, 5, 16], [1, 64, 33, 5], [2, 1], "VALID")
1149
1150    # kernel 2x2, padding SAME
1151    _test_transpose_conv([4, 32, 32, 16], [2, 2, 5, 16], [4, 32, 32, 5], [1, 1], "SAME")
1152    _test_transpose_conv([1, 32, 32, 16], [2, 2, 5, 16], [1, 64, 64, 5], [2, 2], "SAME")
1153    _test_transpose_conv([1, 32, 32, 16], [2, 2, 5, 16], [1, 64, 32, 5], [2, 1], "SAME")
1154
1155    # kernel 1x1, padding VALID
1156    _test_transpose_conv([4, 32, 32, 16], [1, 1, 5, 16], [4, 32, 32, 5], [1, 1], "VALID")
1157    _test_transpose_conv([1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 63, 5], [2, 2], "VALID")
1158    _test_transpose_conv([1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 32, 5], [2, 1], "VALID")
1159
1160    # kernel 1x1, padding SAME
1161    _test_transpose_conv([4, 32, 32, 16], [1, 1, 5, 16], [4, 32, 32, 5], [1, 1], "SAME")
1162    _test_transpose_conv([1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 63, 5], [2, 2], "SAME")
1163    _test_transpose_conv([1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 32, 5], [2, 1], "SAME")
1164
1165
1166#######################################################################
1167# Reshape
1168# -------
1169
1170
1171def _test_reshape(data, out_shape, wrap_shape):
1172    """ One iteration of reshape operation with given data and out shape """
1173    with tf.Graph().as_default():
1174        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
1175
1176        out_shape = out_shape if not wrap_shape else np.array(out_shape, dtype=np.int32)
1177
1178        in_shape = (
1179            out_shape
1180            if not wrap_shape
1181            else array_ops.placeholder(
1182                shape=out_shape.shape, dtype=out_shape.dtype, name="Newshape"
1183            )
1184        )
1185
1186        out = array_ops.reshape(in_data, in_shape)
1187
1188        compare_tflite_with_tvm(
1189            [data, out_shape] if wrap_shape else [data],
1190            ["Placeholder:0", "Newshape:0"] if wrap_shape else ["Placeholder:0"],
1191            [in_data, in_shape] if wrap_shape else [in_data],
1192            [out],
1193            mode="vm",
1194        )
1195
1196
1197def test_forward_reshape():
1198    for wrap in [True, False]:
1199        _test_reshape(np.arange(6.0, dtype=np.float32), [2, 3], wrap)
1200        _test_reshape(np.arange(6), [-1, 2], wrap)
1201        _test_reshape(np.arange(6), [3, -1], wrap)
1202        _test_reshape(np.arange(6), [-1], wrap)
1203
1204
1205#######################################################################
1206# Resize
1207# ------
1208
1209
1210def _test_resize(tf_resize_op, data, align_corners):
1211    """ One iteration of Resize """
1212
1213    assert len(data) == 2
1214
1215    # Test with tensor and constant
1216    with tf.Graph().as_default():
1217        images_tensor = array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name="in")
1218        size = ops.convert_to_tensor(data[1], dtype=data[1].dtype)
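        # The target size is baked into the graph as a constant; only the image tensor is
        # fed at runtime.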
1219        out_tensor = tf_resize_op(images=images_tensor, size=size, align_corners=align_corners)
1220        compare_tflite_with_tvm([data[0]], ["in:0"], [images_tensor], [out_tensor])
1221
1222
1223def test_all_resize():
1224    """ Resize """
1225    data = [np.random.rand(1, 16, 16, 3).astype("float32"), np.array([8, 8], dtype=np.int32)]
1226    ### RESIZE_BILINEAR
1227    _test_resize(tf.image.resize_bilinear, data, align_corners=False)
1228    _test_resize(tf.image.resize_bilinear, data, align_corners=True)
1229    ### RESIZE_NEAREST_NEIGHBOR (was added in v1.13)
1230    # According to topi resize.h
1231    # Align corners not supported for nearest neighbour
1232    from tflite.BuiltinOperator import BuiltinOperator
1233
1234    if "RESIZE_NEAREST_NEIGHBOR" in dir(BuiltinOperator()):
1235        _test_resize(tf.image.resize_nearest_neighbor, data, align_corners=False)
1236
1237
1238#######################################################################
1239# Range
1240# -----
1241def _test_range(start, limit, delta):
1242    # tflite 1.13 convert method does not accept empty shapes
1243    if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"):
1244        tf.reset_default_graph()
1245        with tf.Graph().as_default():
1246            start_scalar, limit_scalar, delta_scalar = (
1247                tf.placeholder(dtype=start.dtype, shape=(), name="start"),
1248                tf.placeholder(dtype=limit.dtype, shape=(), name="limit"),
1249                tf.placeholder(dtype=delta.dtype, shape=(), name="delta"),
1250            )
1251
1252            out = tf.range(start_scalar, limit_scalar, delta_scalar, name="range")
1253
1254            compare_tflite_with_tvm(
1255                [start, limit, delta],
1256                ["start", "limit", "delta"],
1257                [start_scalar, limit_scalar, delta_scalar],
1258                [out],
1259                mode="vm",
1260                quantized=False,
1261            )
1262
1263
1264def _test_range_default():
1265    # tflite 1.13 convert method does not accept empty shapes
1266    if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"):
1267        tf.reset_default_graph()
1268        with tf.Graph().as_default():
1269            inputs = [
1270                tf.placeholder(dtype=tf.int32, shape=(), name="p1"),
1271                tf.placeholder(dtype=tf.int32, shape=(), name="p2"),
1272            ]
1273            outputs = [
1274                tf.range(start=inputs[0], limit=inputs[1]),  # use default delta
1275                tf.range(
1276                    start=inputs[1]
1277                ),  # use start as limit with 0 as the first item in the range
1278            ]
1279
1280            compare_tflite_with_tvm(
1281                [np.int32(1), np.int32(18)], ["p1", "p2"], inputs, outputs, mode="vm"
1282            )
1283
1284
1285def test_forward_range():
1286    _test_range(np.int32(1), np.int32(18), np.int32(3))
1287    _test_range(np.int32(1), np.int32(18), np.float32(3.1))  # increment is of type float
1288    _test_range(np.float32(1.0), np.int32(18), np.int32(3.1))  # start is of type float
1289    _test_range_default()
1290
1291
1292#######################################################################
1293# Shape
1294# -----
1295def test_forward_shape():
1296    # tflite 1.13 convert method does not accept empty shapes
1297    if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"):
1298        tf.reset_default_graph()
1299        with tf.Graph().as_default():
1300            data = np.array([1, 18, 3], dtype=np.int32)
1301            start = tf.placeholder(dtype=tf.int32, shape=[], name="start")
1302            limit = tf.placeholder(dtype=tf.int32, shape=[], name="limit")
1303            delta = tf.placeholder(dtype=tf.int32, shape=[], name="delta")
1304            r = tf.range(start, limit, delta, tf.int32, name="range")
1305            out = tf.shape(r, out_type=tf.dtypes.int32)
1306            compare_tflite_with_tvm(
1307                [x for x in np.nditer(data)],
1308                ["start", "limit", "delta"],
1309                [start, limit, delta],
1310                [out],
1311                mode="vm",
1312            )
1313
1314
1315#######################################################################
1316# Concatenation
1317# -------------
1318
1319
1320def _test_concatenation(data, axis):
1321    """ One iteration of concatenation """
1322
1323    assert len(data) >= 1
1324
1325    with tf.Graph().as_default():
1326        in_data = [
1327            array_ops.placeholder(shape=tensor.shape, dtype=tensor.dtype, name="in_{}".format(idx))
1328            for idx, tensor in enumerate(data)
1329        ]
1330        out = array_ops.concat(in_data, axis=axis)
1331        name = ["in_{}:0".format(idx) for idx in range(len(data))]
1332
1333        compare_tflite_with_tvm(data, name, in_data, [out])
1334
1335
def test_forward_concatenation():
    """ CONCATENATION """
1338    _test_concatenation([np.arange(6).reshape((1, 2, 1, 3)), np.arange(6).reshape((1, 2, 1, 3))], 1)
1339
1340    _test_concatenation([np.arange(6).reshape((3, 2)), np.arange(6).reshape((3, 2))], 1)
1341
1342    _test_concatenation(
1343        [
1344            np.arange(6).reshape((2, 1, 1, 3)),
1345            np.arange(6).reshape((2, 1, 1, 3)),
1346            np.arange(6).reshape((2, 1, 1, 3)),
1347        ],
1348        1,
1349    )
1350
1351
1352#######################################################################
1353# Unary elemwise
1354# --------------
1355
1356
1357def _test_unary_elemwise(math_op, data):
1358    """ One iteration of unary elemwise """
1359
1360    with tf.Graph().as_default():
1361        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name="in")
1362        out = math_op(in_data)
1363        compare_tflite_with_tvm(data, ["in:0"], [in_data], [out])
1364
1365
1366#######################################################################
1367# Abs
1368# ---
1369
1370
1371def _test_abs(data):
1372    """ One iteration of abs """
1373    return _test_unary_elemwise(math_ops.abs, data)
1374
1375
1376#######################################################################
1377# Ceil
1378# ----
1379
1380
1381def _test_ceil(data):
1382    """ One iteration of ceil """
1383    return _test_unary_elemwise(math_ops.ceil, data)
1384
1385
1386#######################################################################
1387# Floor
1388# -----
1389
1390
1391def _test_floor(data):
1392    """ One iteration of floor """
1393    return _test_unary_elemwise(math_ops.floor, data)
1394
1395
1396#######################################################################
1397# Round
1398# -----
1399
1400
1401def _test_round(data):
1402    """ One iteration of round """
1403    return _test_unary_elemwise(math_ops.round, data)
1404
1405
1406#######################################################################
1407# Exp
1408# ---
1409
1410
1411def _test_exp(data):
1412    """ One iteration of exp """
1413    return _test_unary_elemwise(math_ops.exp, data)
1414
1415
1416#######################################################################
1417# Log
1418# ---
1419
1420
1421def _test_log(data):
1422    """ One iteration of log """
1423    return _test_unary_elemwise(math_ops.log, data)
1424
1425
1426#######################################################################
1427# Sin
1428# ---
1429
1430
1431def _test_sin(data):
1432    """ One iteration of sin """
1433    return _test_unary_elemwise(math_ops.sin, data)
1434
1435
1436#######################################################################
1437# Cos
1438# ---
1439
1440
1441def _test_cos(data):
1442    """ One iteration of cos """
1443    return _test_unary_elemwise(math_ops.cos, data)
1444
1445
1446#######################################################################
1447# Tan
1448# ---
1449
1450
1451def _test_tan(data):
1452    """ One iteration of tan """
1453    return _test_unary_elemwise(math_ops.tan, data)
1454
1455
1456#######################################################################
1457# Sqrt
1458# ----
1459
1460
1461def _test_sqrt(data):
1462    """ One iteration of sqrt """
1463    return _test_unary_elemwise(math_ops.sqrt, data)
1464
1465
1466#######################################################################
1467# Rsqrt
1468# -----
1469
1470
1471def _test_rsqrt(data):
1472    """ One iteration of rsqrt """
1473    return _test_unary_elemwise(math_ops.rsqrt, data)
1474
1475
1476#######################################################################
1477# Neg
1478# ---
1479
1480
1481def _test_neg(data):
1482    """ One iteration of neg """
1483    return _test_unary_elemwise(math_ops.neg, data)
1484
1485
1486#######################################################################
1487# Square
1488# ------
1489
1490
1491def _test_square(data):
1492    """ One iteration of square """
1493    return _test_unary_elemwise(math_ops.square, data)
1494
1495
1496#######################################################################
1497# Elu
1498# ---
1499
1500
1501def _test_elu(data):
1502    """ One iteration of elu """
1503    return _test_unary_elemwise(nn_ops.elu, data)
1504
1505
1506def _test_forward_unary_elemwise(test_op):
1507    # functions that need positive input
1508    if test_op.__name__ in {"_test_log", "_test_sqrt", "_test_rsqrt"}:
1509        test_op(np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 3)))
1510    else:
1511        test_op(np.random.uniform(-10, 10, (3, 2)).astype(np.float32))
1512
1513
1514def test_all_unary_elemwise():
1515    _test_forward_unary_elemwise(_test_abs)
1516    _test_forward_unary_elemwise(_test_floor)
1517    _test_forward_unary_elemwise(_test_exp)
1518    _test_forward_unary_elemwise(_test_log)
1519    _test_forward_unary_elemwise(_test_sin)
1520    _test_forward_unary_elemwise(_test_sqrt)
1521    _test_forward_unary_elemwise(_test_rsqrt)
1522    _test_forward_unary_elemwise(_test_neg)
1523    _test_forward_unary_elemwise(_test_square)
1524    # ceil and cos come with TFLite 1.14.0.post1 fbs schema
1525    if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"):
1526        _test_forward_unary_elemwise(_test_ceil)
1527        _test_forward_unary_elemwise(_test_cos)
1528        _test_forward_unary_elemwise(_test_round)
        # This fails with TF and TFLite 1.15.2, so it could not have been tested
        # in CI or anywhere else. The failure mode is a backtrace from the
        # converter asking us to provide a custom Tan operator implementation.
1533        # _test_forward_unary_elemwise(_test_tan)
1534        _test_forward_unary_elemwise(_test_elu)
1535
1536
1537#######################################################################
1538# Element-wise
1539# ------------
1540
1541
1542def _test_elemwise(
1543    math_op,
1544    data,
1545    fused_activation_function=None,
1546    quantized=False,
1547    qnn_op=None,
1548    same_qnn_params=False,
1549):
1550    """ One iteration of elemwise """
1551
1552    assert len(data) == 2
1553
1554    def __test_elemwise(in_data):
        assert len(in_data) == 2
1556        if quantized:
1557            # set the fp32 output range with respect to the operation
1558            out_min, out_max = _test_elemwise_qnn_out_range(qnn_op)
1559            inq0_min, inq0_max = (-100, 100)
1560            inq1_min, inq1_max = (-50, 50)
1561
            # if requested, use the same quantization parameters
            # provided by _test_elemwise_qnn_out_range
1563            if same_qnn_params:
1564                inq0_min, inq0_max = (out_min, out_max)
1565                inq1_min, inq1_max = (out_min, out_max)
1566
1567            # fake_quant will keep the tensors in float32 until the conversion in the session
            inq_data = [
                tf.quantization.fake_quant_with_min_max_args(
                    in_data[0], min=out_min, max=out_max, name="inq_0"
                )
                if in_data[0] is not None
                else tf.quantization.fake_quant_with_min_max_args(
                    data[0], min=out_min, max=out_max, name="const_tensor0"
                ),
                tf.quantization.fake_quant_with_min_max_args(
                    in_data[1], min=out_min, max=out_max, name="inq_1"
                )
                if in_data[1] is not None
                else tf.quantization.fake_quant_with_min_max_args(
                    data[1], min=out_min, max=out_max, name="const_tensor1"
                ),
            ]
1584
            input_range = {
                x[1][0]: x[1][1]
                for x in zip(
                    in_data, (("inq_0", (inq0_min, inq0_max)), ("inq_1", (inq1_min, inq1_max)))
                )
                if x[0] is not None
            }
1592
1593            out = math_op(inq_data[0], inq_data[1])
1594            out = with_fused_activation_function(out, fused_activation_function)
1595            out = tf.quantization.fake_quant_with_min_max_args(
1596                out, min=out_min, max=out_max, name="out"
1597            )
1598
            # Note: same_qnn_params uses experimental_new_converter because the
            # TOCO converter fails on this pattern.
            compare_tflite_with_tvm(
                [x[1] for x in zip(in_data, data) if x[0] is not None],
                [x + ":0" for x in input_range.keys()],
                [x[1] for x in zip(in_data, inq_data) if x[0] is not None],
                [out],
                quantized=True,
                input_range=input_range,
                experimental_new_converter=same_qnn_params,
            )
1609        else:
            out = math_op(
                in_data[0]
                if in_data[0] is not None
                else ops.convert_to_tensor(data[0], dtype=data[0].dtype),
                in_data[1]
                if in_data[1] is not None
                else ops.convert_to_tensor(data[1], dtype=data[1].dtype),
            )
            out = with_fused_activation_function(out, fused_activation_function)
            compare_tflite_with_tvm(
                [x[1] for x in zip(in_data, data) if x[0] is not None],
                [x[1] for x in zip(in_data, ("in_0:0", "in_1:0")) if x[0] is not None],
                [x for x in in_data if x is not None],
                [out],
            )
1625
1626    # Test with two tensors
1627    with tf.Graph().as_default():
1628        __test_elemwise(
1629            in_data=[
1630                array_ops.placeholder(shape=data[0].shape, dtype="float32", name="in_0"),
1631                array_ops.placeholder(shape=data[1].shape, dtype="float32", name="in_1"),
1632            ]
1633        )
1634    # Test with tensor and constant
1635    with tf.Graph().as_default():
1636        __test_elemwise(
1637            in_data=[array_ops.placeholder(shape=data[0].shape, dtype="float32", name="in_0"), None]
1638        )
1639    # Test with constant and tensor
1640    with tf.Graph().as_default():
1641        __test_elemwise(
1642            in_data=[None, array_ops.placeholder(shape=data[1].shape, dtype="float32", name="in_1")]
1643        )
1644
1645
1646#######################################################################
1647# Add
1648# ---
1649
1650
1651def _test_add(data, fused_activation_function=None, quantized=False, qnn_op=None):
1652    """ One iteration of add """
1653    return _test_elemwise(math_ops.add, data, fused_activation_function, quantized, qnn_op)
1654
1655
1656#######################################################################
1657# Subtract
1658# --------
1659
1660
1661def _test_sub(data, fused_activation_function=None, quantized=False, qnn_op=None):
1662    """ One iteration of subtract """
1663    return _test_elemwise(math_ops.subtract, data, fused_activation_function, quantized, qnn_op)
1664
1665
1666#######################################################################
1667# Mul
1668# ---
1669
1670
1671def _test_mul(data, fused_activation_function=None, quantized=False, qnn_op=None):
1672    """ One iteration of mul """
1673    return _test_elemwise(math_ops.multiply, data, fused_activation_function, quantized, qnn_op)
1674
1675
1676#######################################################################
1677# Divide
1678# ------
1679
1680
1681def _test_div(data, fused_activation_function=None):
1682    """ One iteration of divide """
1683    return _test_elemwise(math_ops.divide, data, fused_activation_function)
1684
1685
1686#######################################################################
1687# Power
1688# -----
1689
1690
1691def _test_pow(data):
1692    """ One iteration of power """
1693    return _test_elemwise(math_ops.pow, data)
1694
1695
1696#######################################################################
1697# Maximum
1698# -------
1699
1700
1701def _test_maximum(data, fused_activation_function=None, quantized=False, qnn_op=None):
1702    """ One iteration of maximum """
1703    return _test_elemwise(
1704        math_ops.maximum, data, fused_activation_function, quantized, qnn_op, same_qnn_params=True
1705    )
1706
1707
1708#######################################################################
1709# Minimum
1710# -------
1711
1712
1713def _test_minimum(data, fused_activation_function=None, quantized=False, qnn_op=None):
1714    """ One iteration of minimum """
1715    return _test_elemwise(
1716        math_ops.minimum, data, fused_activation_function, quantized, qnn_op, same_qnn_params=True
1717    )
1718
1719
1720#######################################################################
1721# Greater
1722# -------
1723
1724
1725def _test_greater(data):
1726    """ One iteration of greater """
1727    return _test_elemwise(math_ops.greater, data)
1728
1729
1730#######################################################################
1731# Greater_equal
1732# -------------
1733
1734
1735def _test_greater_equal(data):
1736    """ One iteration of greater_equal """
1737    return _test_elemwise(math_ops.greater_equal, data)
1738
1739
1740#######################################################################
1741# Less
1742# ----
1743
1744
1745def _test_less(data):
1746    """ One iteration of less """
1747    return _test_elemwise(math_ops.less, data)
1748
1749
1750#######################################################################
1751# Less_equal
1752# ----------
1753
1754
1755def _test_less_equal(data):
1756    """ One iteration of less_equal """
1757    return _test_elemwise(math_ops.less_equal, data)
1758
1759
1760#######################################################################
1761# Equal
1762# -----
1763
1764
1765def _test_equal(data):
1766    """ One iteration of equal """
1767    return _test_elemwise(math_ops.equal, data)
1768
1769
1770#######################################################################
1771# Not_equal
1772# ---------
1773
1774
1775def _test_not_equal(data):
1776    """ One iteration of not_equal"""
1777    return _test_elemwise(math_ops.not_equal, data)
1778
1779
1780#######################################################################
1781# Squared_difference
1782# ------------------
1783
1784
1785def _test_squared_difference(data):
1786    """ One iteration of squared difference """
1787    return _test_elemwise(math_ops.squared_difference, data)
1788
1789
1790#######################################################################
1791# Floor_divide
1792# ------------
1793
1794
1795def _test_floor_divide(data):
1796    """ One iteration of floor_div"""
1797    return _test_elemwise(math_ops.floordiv, data)
1798
1799
1800#######################################################################
1801# Floor_mod
1802# ---------
1803
1804
1805def _test_floor_mod(data):
1806    """ One iteration of floor_mod"""
1807    return _test_elemwise(math_ops.floormod, data)
1808
1809
1810def _test_forward_elemwise(testop):
1811    """ Elewise"""
1812    testop(
1813        [
1814            np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3)),
1815            np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 1, 3)),
1816        ]
1817    )
1818    testop(
1819        [
1820            np.arange(6.0, dtype=np.float32).reshape((2, 1, 3)),
1821            np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 3)),
1822        ]
1823    )
1824    testop(
1825        [
1826            np.arange(3.0, dtype=np.float32).reshape((1, 3)),
1827            np.arange(1.0, 4.0, dtype=np.float32).reshape((1, 3)),
1828        ]
1829    )
1830
1831
1832def _test_forward_elemwise_quantized(testop):
1833    testop(
1834        [
1835            np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.uint8),
1836            np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.uint8),
1837        ],
1838        quantized=True,
1839        qnn_op=testop,
1840    )
1841
1842
1843def _test_elemwise_qnn_out_range(qnn_op):
    # set the fake_quant output range with respect to the input tensors' float32 ranges
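    # add/sub outputs can span the sum of the default input ranges ((-100, 100) + (-50, 50)),
    # and mul can span their product (100 * 50 = 5e3). maximum/minimum reuse their entries as
    # the input quantization parameters as well (see same_qnn_params in _test_elemwise).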
1845    qnn_out_range = {
1846        _test_add: (-150, 150),
1847        _test_sub: (-150, 150),
1848        _test_mul: (-5e3, 5e3),
1849        _test_maximum: (-112, 111),
1850        _test_minimum: (-128, 127),
1851    }
1852
1853    return qnn_out_range[qnn_op]
1854
1855
1856def test_all_elemwise():
1857    _test_forward_elemwise(_test_add)
1858    _test_forward_elemwise_quantized(_test_add)
1859    _test_forward_elemwise(partial(_test_add, fused_activation_function="RELU"))
    # This is broken with the TF 1.15.2 upgrade and hits a segfault that needs
    # further investigation.
1862    # _test_forward_elemwise(partial(_test_add, fused_activation_function="RELU6"))
1863    _test_forward_elemwise(_test_sub)
1864    _test_forward_elemwise_quantized(_test_sub)
1865    _test_forward_elemwise(partial(_test_sub, fused_activation_function="RELU"))
1866    _test_forward_elemwise(partial(_test_sub, fused_activation_function="RELU6"))
1867    _test_forward_elemwise(_test_mul)
1868    _test_forward_elemwise_quantized(_test_mul)
1869    _test_forward_elemwise(partial(_test_mul, fused_activation_function="RELU"))
1870    _test_forward_elemwise(partial(_test_mul, fused_activation_function="RELU6"))
1871    _test_forward_elemwise(_test_div)
1872    _test_forward_elemwise(partial(_test_div, fused_activation_function="RELU"))
1873    _test_forward_elemwise(partial(_test_div, fused_activation_function="RELU6"))
1874    _test_forward_elemwise(_test_pow)
1875    _test_forward_elemwise(_test_maximum)
1876    _test_forward_elemwise_quantized(_test_maximum)
1877    _test_forward_elemwise(_test_minimum)
1878    _test_forward_elemwise_quantized(_test_minimum)
1879    _test_forward_elemwise(_test_greater)
1880    _test_forward_elemwise(_test_squared_difference)
1881    _test_forward_elemwise(_test_greater_equal)
1882    _test_forward_elemwise(_test_less)
1883    _test_forward_elemwise(_test_less_equal)
1884    _test_forward_elemwise(_test_equal)
1885    _test_forward_elemwise(_test_not_equal)
1886    if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"):
1887        _test_forward_elemwise(_test_floor_divide)
1888        _test_forward_elemwise(_test_floor_mod)
1889
1890
1891#######################################################################
1892# AddN
1893# ----
1894
1895
1896def _test_forward_add_n(inputs):
1897    tf.reset_default_graph()
1898    with tf.Graph().as_default():
        temp = [tf.placeholder(shape=each.shape, dtype=each.dtype) for each in inputs]
        output = tf.add_n(temp)
        compare_tflite_with_tvm(
            list(inputs),
            [each.name for each in temp],
            temp,
            [output],
        )
1909
1910
1911def test_forward_add_n():
1912    if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"):
1913        x = np.random.randint(1, 100, size=(3, 3, 3), dtype=np.int32)
1914        y = np.random.randint(1, 100, size=(3, 3, 3), dtype=np.int32)
1915        z = np.random.randint(1, 100, size=(3, 3, 3), dtype=np.int32)
1916        m, n, o = x.astype(np.float32), y.astype(np.float32), z.astype(np.float32)
1917        in0 = x
1918        in1 = [x, y]
1919        in2 = (x, y, z)
1920        in3 = m
1921        in4 = [m, n]
1922        in5 = (m, n, o)
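        # in0/in3 are bare ndarrays: iterating over them yields their first-axis
        # slices, so add_n still sums three (3, 3) tensors in those cases.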
1923        _test_forward_add_n(in0)
1924        _test_forward_add_n(in1)
1925        _test_forward_add_n(in2)
1926        _test_forward_add_n(in3)
1927        _test_forward_add_n(in4)
1928        _test_forward_add_n(in5)
1929
1930
1931#######################################################################
1932# Logical operators
1933# -----------------
1934
1935
def _test_logical_binary(logical_bin_op, data):
    """ One iteration of logical binary operation """

1938    with tf.Graph().as_default():
1939        in_data = [
1940            array_ops.placeholder(shape=data[0].shape, dtype="bool", name="in_0"),
1941            array_ops.placeholder(shape=data[1].shape, dtype="bool", name="in_1"),
1942        ]
1943        if logical_bin_op == math_ops.logical_not:
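            # logical_not is unary, so stack it on a binary OR to keep two graph inputs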
1944            out = math_ops.logical_or(in_data[0], in_data[1], name="out1")
1945            out = logical_bin_op(out, name="out")
1946        else:
1947            out = logical_bin_op(in_data[0], in_data[1], name="out")
1948
1949        compare_tflite_with_tvm(data, ["in_0:0", "in_1:0"], in_data, [out])
1950
1951
1952def _test_forward_logical_and(data):
1953    """ One iteration of logical and """
1954    return _test_logical_binary(math_ops.logical_and, data)
1955
1956
1957def _test_forward_logical_or(data):
1958    """ One iteration of logical or """
1959    return _test_logical_binary(math_ops.logical_or, data)
1960
1961
1962def _test_forward_logical_not(data):
1963    """ One iteration of logical not """
1964    return _test_logical_binary(math_ops.logical_not, data)
1965
1966
1967def test_all_logical():
1968    data = [
1969        np.random.choice(a=[False, True], size=(2, 3, 4)).astype("bool"),
1970        np.random.choice(a=[False, True], size=(2, 3, 4)).astype("bool"),
1971    ]
    # boolean dtype is not supported by TFLite versions older than 1.15.0
1973    if package_version.parse(tf.VERSION) >= package_version.parse("1.15.0"):
1974        _test_forward_logical_and(data)
1975        _test_forward_logical_or(data)
1976        _test_forward_logical_not(data)
1977
1978
1979#######################################################################
1980# Zeros like
1981# ----------
1982
1983
1984def _test_zeros_like(data):
1985    """ One iteration of ZEROS LIKE """
1986    with tf.Graph().as_default():
1987        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
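        # the placeholder gets TF's default name, hence "Placeholder:0" below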
1988        out = gen_array_ops.zeros_like(in_data)
1989        compare_tflite_with_tvm(data, "Placeholder:0", [in_data], [out])
1990
1991
1992def test_forward_zeros_like():
1993    """ ZEROS LIKE """
1994    _test_zeros_like(np.arange(6.0, dtype=np.float32).reshape((1, 6)))
1995
1996
1997#######################################################################
1998# Fill
1999# ----
2000
2001
2002def _test_fill(dims, value_data, value_dtype):
2003    """ Use the fill op to create a tensor of value_data with constant dims."""
2004
2005    value_data = np.array(value_data, dtype=value_dtype)
    # The TFLite converter in TF 1.13 does not accept empty shapes
2007    if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"):
2008        with tf.Graph().as_default():
2009            value = array_ops.placeholder(dtype=value_dtype, name="value", shape=[])
2010            out = tf.fill(dims, value)
2011            compare_tflite_with_tvm([value_data], ["value"], [value], [out])
2012
2013    with tf.Graph().as_default():
2014        input1 = array_ops.placeholder(dtype=value_dtype, name="input1", shape=dims)
2015        # Fill op gets converted to static tensor during conversion
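        # adding a placeholder input below keeps a runtime input in the converted graph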
2016        out = tf.fill(dims, value_data)
2017        out1 = tf.add(out, input1)
2018        input1_data = np.random.uniform(0, 5, size=dims).astype(value_dtype)
2019        compare_tflite_with_tvm([input1_data], ["input1"], [input1], [out1])
2020
2021
2022def test_forward_fill():
2023    """ Test FILL op """
2024
2025    _test_fill((1, 2, 2, 4), 5, "int32")
2026    _test_fill((1, 2, 2, 4), 5, "float32")
2027    _test_fill((5,), 5, "int32")
2028
2029
2030#######################################################################
2031# Reduce
2032# ------
2033
2034
2035def _test_reduce(math_op, data, keep_dims=None):
2036    """ One iteration of reduce """
2037
2038    assert len(data) == 2
2039
2040    # Test with tensor and constant
2041    with tf.Graph().as_default():
2042        in_data = array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name="in")
2043        out = math_op(in_data, data[1], keep_dims)
2044        compare_tflite_with_tvm([data[0]], ["in:0"], [in_data], [out])
2045
2046
2047def _test_reduce_quantize(math_op, data, keep_dims=None):
2048    """ One iteration of reduce """
2049
2050    assert len(data) == 2
2051
2052    # Test with tensor and constant
2053    with tf.Graph().as_default():
2054        in_data = [array_ops.placeholder(shape=data[0].shape, dtype="float32", name="in")]
2055        inq_data = [
2056            tf.quantization.fake_quant_with_min_max_args(
2057                in_data[0], min=-100, max=100, name="inq_0"
2058            )
2059        ]
2060        input_range = {"inq_0": (-100, 100)}
2061        out = math_op(inq_data, data[1], keep_dims)
2062        out = tf.quantization.fake_quant_with_min_max_args(out, min=-200, max=200, name="out")
2063        compare_tflite_with_tvm(
2064            [data[0]], ["inq_0:0"], [inq_data[0]], [out], quantized=True, input_range=input_range
2065        )
2066
2067
2068#######################################################################
2069# Reduce_min
2070# ----------
2071
2072
2073def _test_reduce_min(data, keep_dims=None):
2074    """ One iteration of reduce_min """
2075    return _test_reduce(math_ops.reduce_min, data, keep_dims)
2076
2077
2078#######################################################################
2079# Reduce_max
2080# ----------
2081
2082
2083def _test_reduce_max(data, keep_dims=None):
2084    """ One iteration of reduce_max """
2085    return _test_reduce(math_ops.reduce_max, data, keep_dims)
2086
2087
2088#######################################################################
2089# Reduce_mean
2090# -----------
2091
2092
2093def _test_reduce_mean(data, keep_dims=None, quantized=False):
2094    """ One iteration of reduce_mean """
2095    if quantized:
2096        return _test_reduce_quantize(math_ops.reduce_mean, data, keep_dims)
2097    else:
2098        return _test_reduce(math_ops.reduce_mean, data, keep_dims)
2099
2100
2101#######################################################################
2102# Reduce_prod
2103# -----------
2104
2105
2106def _test_reduce_prod(data, keep_dims=None):
2107    """ One iteration of reduce_prod """
2108    return _test_reduce(math_ops.reduce_prod, data, keep_dims)
2109
2110
2111#######################################################################
2112# Reduce_sum
2113# -----------
2114
2115
2116def _test_reduce_sum(data, keep_dims=None):
2117    """ One iteration of reduce_sum """
2118    return _test_reduce(math_ops.reduce_sum, data, keep_dims)
2119
2120
2121#######################################################################
2122# Reduce_any
2123# ----------
2124
2125
2126def _test_reduce_any(data, keep_dims=None):
2127    """ One iteration of reduce_any """
2128    return _test_reduce(math_ops.reduce_any, data, keep_dims)
2129
2130
2131def _test_forward_reduce(testop, dtype="float32"):
2132    """ Reduce """
2133    if dtype == "bool":
2134        data0 = [np.random.choice(a=[False, True], size=(16, 16, 16, 16)).astype(dtype), None]
2135        data1 = [
2136            np.random.choice(a=[False, True], size=(16, 16, 16, 16)).astype(dtype),
2137            np.array([1, 2], dtype=np.int32),
2138        ]
2139    else:
2140        data0 = [np.random.rand(16, 16, 16, 16).astype(dtype), None]
2141        data1 = [np.random.rand(16, 16, 16, 16).astype(dtype), np.array([1, 2], dtype=np.int32)]
2142    testop(data0)
2143    testop(data0, keep_dims=False)
2144    testop(data0, keep_dims=True)
2145    testop(data1)
2146    testop(data1, keep_dims=False)
2147    testop(data1, keep_dims=True)
2148
2149
2150def _test_forward_reduce_quantized(testop):
2151    data0 = [
2152        np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.uint8),
2153        np.array([1, 2], dtype=np.int32),
2154    ]
2155    testop(data0, quantized=True)
2156    testop(data0, keep_dims=False, quantized=True)
2157    testop(data0, keep_dims=True, quantized=True)
2158
2159
2160def test_all_reduce():
2161    _test_forward_reduce(_test_reduce_min)
2162    _test_forward_reduce(_test_reduce_max)
2163    _test_forward_reduce(_test_reduce_mean)
2164    _test_forward_reduce_quantized(_test_reduce_mean)
2165    _test_forward_reduce(_test_reduce_prod)
2166    _test_forward_reduce(_test_reduce_sum)
2167    if package_version.parse(tf.VERSION) >= package_version.parse("1.15.0"):
2168        _test_forward_reduce(_test_reduce_any, dtype="bool")
2169
2170
2171#######################################################################
2172# Arg_min_max
2173# -----------
2174
2175
2176def _test_arg_min_max(math_op, data, axis, quantized=False):
2177    """ One iteration of arg_min_max"""
2178
2179    with tf.Graph().as_default():
2180        t_name = "in"
2181        in_data = array_ops.placeholder(shape=data.shape, dtype=np.float32, name=t_name)
2182        input_range = None
2183        qmin, qmax = -100, 102
2184        if quantized:
2185            inq_data = tf.quantization.fake_quant_with_min_max_args(
2186                in_data, min=qmin, max=qmax, name="q" + t_name
2187            )
2188            input_range = {inq_data.name.split(":")[0]: (qmin, qmax)}
2189            out = math_op(input=inq_data, axis=axis)
2190            compare_tflite_with_tvm(
2191                [data], [inq_data.name], [inq_data], [out], quantized=True, input_range=input_range
2192            )
2193        else:
2194            out = math_op(input=in_data, axis=axis)
2195            compare_tflite_with_tvm([data], [in_data.name], [in_data], [out])
2196
2197
2198def test_forward_arg_min_max():
2199    # test quantized
2200    for data in [np.array(np.random.uniform(-100, 100, (3, 4)), dtype=np.uint8)]:
2201        # There is no quantized version of ArgMin
2202        for axis in [None, 0, 1, -1]:
2203            _test_arg_min_max(math_ops.argmax, data, axis, True)
2204
2205    for data in [np.array(np.random.uniform(-100, 100, (3, 4)), dtype=np.float32)]:
2206        for axis in [None, 0, 1, -1]:
2207            _test_arg_min_max(math_ops.argmax, data, axis)
2208            _test_arg_min_max(math_ops.argmin, data, axis)
2209
2210
2211#######################################################################
2212# Select, Where
2213# -------------
2214
2215
2216def test_forward_select():
2217    with tf.Graph().as_default():
2218        with tf.Session() as sess:
2219            input1 = tf.placeholder(tf.int32, shape=[1, 4, 4, 3], name="input1")
2220            input2 = tf.placeholder(tf.int32, shape=[1, 4, 4, 3], name="input2")
2221            mask = input1 > input2
2222            out = tf.where(mask, input1 + 1, input2 * 2)
2223            in_data1 = np.random.uniform(0, 10, size=(1, 4, 4, 3)).astype("int32")
2224            in_data2 = np.random.uniform(0, 10, size=(1, 4, 4, 3)).astype("int32")
2225
2226            compare_tflite_with_tvm(
2227                [in_data1, in_data2], ["input1:0", "input2:0"], [input1, input2], [out]
2228            )
2229
2230
#######################################################################
# Squeeze
# -------
2233
2234
2235def _test_squeeze(data, squeeze_dims=None):
2236    """ One iteration of squeeze """
2237
2238    if squeeze_dims is None:
2239        squeeze_dims = []
2240
2241    with tf.Graph().as_default():
2242        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
2243
2244        if squeeze_dims:
2245            out = array_ops.squeeze(in_data, squeeze_dims)
2246        else:
2247            out = array_ops.squeeze(in_data)
2248
2249        compare_tflite_with_tvm(data, "Placeholder:0", [in_data], [out])
2250
2251
2252def test_forward_squeeze():
2253    """ Squeeze """
2254    _test_squeeze(np.arange(6).reshape((1, 2, 1, 3)), [0, 2])
2255    _test_squeeze(np.arange(6).reshape((2, 1, 3, 1)), [1, 3])
2256
2257
2258#######################################################################
2259# Quantize/DeQuantize
2260# -------------------
2261
2262
2263def _test_quantize_dequantize(data):
2264    """ One iteration of quantize and dequantize """
2265
2266    # Keras model to force TFLite converter to insert 2 TFLite quantize ops.
2267    # First TFLite quantize op converts float32 tensor to int8 tensor - Qnn quantize.
2268    # Second TFLite quantize op converts int8 tensor to int8 tensor - Qnn requantize.
2269    data_in = tf.keras.layers.Input(shape=data.shape[1:])
2270    relu = tf.keras.layers.ReLU()(data_in)
2271    add = tf.keras.layers.Add()([data_in, relu])
2272    concat = tf.keras.layers.Concatenate(axis=0)([relu, add])
2273    keras_model = tf.keras.models.Model(inputs=data_in, outputs=concat)
2274    input_name = data_in.name.split(":")[0]
2275
    # A representative dataset is needed so the converter can calibrate the
    # dynamic range of the activations and produce quantized values.
2277    def representative_data_gen():
2278        for i in range(1):
2279            yield [data]
2280
2281    tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen)
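    # _quantize_keras_model (a helper defined earlier in this script) returns the
    # post-training-quantized TFLite flatbuffer, calibrated with the representative dataset.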
2282
2283    tflite_output = run_tflite_graph(tflite_model_quant, data)
2284    tvm_output = run_tvm_graph(tflite_model_quant, data, input_name)
2285    tvm.testing.assert_allclose(
2286        np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-2
2287    )
2288
2289
2290def _test_quantize_dequantize_const(data):
2291    """ One iteration of quantize and dequantize """
2292
2293    # Keras model to force TFLite converter to insert 2 TFLite quantize ops.
2294    # First TFLite quantize op converts float32 tensor to int8 tensor - Qnn quantize.
2295    # Second TFLite quantize op converts int8 tensor to int8 tensor - Qnn requantize.
2296    data_in = tf.keras.layers.Input(shape=data.shape[1:])
2297    relu = tf.keras.layers.ReLU()(data_in)
2298    add = tf.keras.layers.Add()([data, relu])
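    # Note: the raw numpy `data` (rather than data_in) feeds the Add, giving it a constant operand.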
2299    concat = tf.keras.layers.Concatenate(axis=0)([relu, add])
2300    keras_model = tf.keras.models.Model(inputs=data_in, outputs=concat)
2301    input_name = data_in.name.split(":")[0]
2302
    # A representative dataset is needed so the converter can calibrate the
    # dynamic range of the activations and produce quantized values.
2304    def representative_data_gen():
2305        for i in range(1):
2306            yield [data]
2307
2308    tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen)
2309
2310    tflite_output = run_tflite_graph(tflite_model_quant, data)
2311    tvm_output = run_tvm_graph(tflite_model_quant, data, input_name)
2312    tvm.testing.assert_allclose(
2313        np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-2
2314    )
2315
2316
2317def test_forward_quantize_dequantize():
2318    """ Quantize Dequantize """
2319    data = np.random.uniform(0, 1, (1, 4, 4, 3)).astype("float32")
2320    if package_version.parse(tf.VERSION) >= package_version.parse("2.1.0"):
2321        _test_quantize_dequantize(data)
2322        _test_quantize_dequantize_const(data)
2323
2324
2325#######################################################################
2326# Pad
2327# ---
2328
2329
2330def _test_pad(data, mode="CONSTANT", quantized=False):
2331    """ One iteration of PAD """
2332
2333    assert len(data) == 2
2334
2335    # Test with tensor and constant
2336    with tf.Graph().as_default():
2337        in_data = [array_ops.placeholder(shape=data[0].shape, dtype="float32", name="in")]
2338
2339        if quantized:
2340            # fake_quant will keep the tensors in float32 until the conversion in the session
2341            input_range = {"inq_0": (-100, 100)}
2342            inq_data = [
2343                tf.quantization.fake_quant_with_min_max_args(
2344                    in_data[0], min=-100, max=100, name="inq_0"
2345                )
2346            ]
2347            out = array_ops.pad(
2348                inq_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype), mode=mode
2349            )
2350            compare_tflite_with_tvm(
2351                [data[0]], ["inq_0:0"], inq_data, [out], quantized=True, input_range=input_range
2352            )
2353        else:
2354            out = array_ops.pad(
2355                in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype), mode=mode
2356            )
2357            compare_tflite_with_tvm([data[0]], ["in:0"], in_data, [out])
2358
2359
2360def test_forward_pad():
2361    """ Pad """
2362    _test_pad(
2363        [
2364            np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 1, 3)),
2365            np.array([[1, 1], [2, 2], [1, 1], [2, 2]], dtype=np.int32),
2366        ]
2367    )
2368    _test_pad(
2369        [
2370            np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 3)),
2371            np.array([[2, 2], [1, 1], [1, 1]], dtype=np.int32),
2372        ]
2373    )
2374    _test_pad(
2375        [
2376            np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 3)),
2377            np.array([[1, 1], [2, 2]], dtype=np.int32),
2378        ]
2379    )
2380    _test_pad(
2381        [
2382            np.arange(1.0, 4.0, dtype=np.float32).reshape((1, 3)),
2383            np.array([[1, 1], [2, 2]], dtype=np.int32),
2384        ]
2385    )
2386    _test_pad(
2387        [
2388            np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 3)),
2389            np.array([[1, 1], [2, 2]], dtype=np.int32),
2390        ],
2391        mode="REFLECT",
2392    )
2393    _test_pad(
2394        [
2395            np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 3)),
2396            np.array([[1, 1], [2, 2]], dtype=np.int32),
2397        ],
2398        mode="SYMMETRIC",
2399    )
2400    _test_pad(
2401        [
2402            np.arange(0, 256, dtype=np.uint8).reshape((1, 256)),
2403            np.array([[1, 1], [2, 2]], dtype=np.int32),
2404        ],
2405        quantized=True,
2406    )
2407
2408
2409#######################################################################
2410# PADV2
2411# -----
2412
2413
2414def _test_padv2(data, mode="CONSTANT", quantized=False):
2415    """ One iteration of PADV2 """
2416
2417    assert len(data) == 2 or len(data) == 3
2418
2419    with_constant_values = len(data) == 3
2420
2421    # Test with tensor and constant
2422    with tf.Graph().as_default():
2423        in_data = [array_ops.placeholder(shape=data[0].shape, dtype="float32", name="in")]
2424
2425        if quantized:
2426            # fake_quant will keep the tensors in float32 until the conversion in the session
2427            input_range = {"inq_0": (-100, 100)}
2428            inq_data = [
2429                tf.quantization.fake_quant_with_min_max_args(
2430                    in_data[0], min=-100, max=100, name="inq_0"
2431                )
2432            ]
2433            if with_constant_values:
2434                in_constant_values = constant_op.constant(
2435                    data[2], shape=data[2].shape, dtype="float32", name="in_constant_values"
2436                )
2437                inq_constant_values = tf.quantization.fake_quant_with_min_max_args(
2438                    in_constant_values, min=-100, max=100, name="inq_constant_values"
2439                )
2440                out = array_ops.pad_v2(
2441                    inq_data[0],
2442                    ops.convert_to_tensor(data[1], dtype=data[1].dtype),
2443                    constant_values=inq_constant_values,
2444                    mode=mode,
2445                )
2446                out = tf.quantization.fake_quant_with_min_max_args(
2447                    out, min=-100, max=100, name="out"
2448                )
2449            else:
2450                out = array_ops.pad_v2(
2451                    inq_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype), mode=mode
2452                )
2453            compare_tflite_with_tvm(
2454                [data[0]], ["inq_0:0"], inq_data, [out], quantized=True, input_range=input_range
2455            )
2456        else:
2457            if with_constant_values:
2458                out = array_ops.pad_v2(
2459                    in_data[0],
2460                    ops.convert_to_tensor(data[1], dtype=data[1].dtype),
2461                    constant_values=ops.convert_to_tensor(data[2], dtype=data[2].dtype),
2462                    mode=mode,
2463                )
2464            else:
2465                out = array_ops.pad_v2(
2466                    in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype), mode=mode
2467                )
2468            compare_tflite_with_tvm([data[0]], ["in:0"], in_data, [out])
2469
2470
2471def test_forward_padv2():
2472    """ PADV2 """
2473    # Tests without Constant_values
2474    _test_padv2(
2475        [
2476            np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 1, 3)),
2477            np.array([[1, 1], [2, 2], [1, 1], [2, 2]], dtype=np.int32),
2478        ]
2479    )
2480    _test_padv2(
2481        [
2482            np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 3)),
2483            np.array([[2, 2], [1, 1], [1, 1]], dtype=np.int32),
2484        ]
2485    )
2486    _test_padv2(
2487        [
2488            np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 3)),
2489            np.array([[1, 1], [2, 2]], dtype=np.int32),
2490        ]
2491    )
2492    _test_padv2(
2493        [
2494            np.arange(1.0, 4.0, dtype=np.float32).reshape((1, 3)),
2495            np.array([[1, 1], [2, 2]], dtype=np.int32),
2496        ]
2497    )
2498    _test_padv2(
2499        [
2500            np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 3)),
2501            np.array([[1, 1], [2, 2]], dtype=np.int32),
2502        ],
2503        mode="REFLECT",
2504    )
2505    _test_padv2(
2506        [
2507            np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 3)),
2508            np.array([[1, 1], [2, 2]], dtype=np.int32),
2509        ],
2510        mode="SYMMETRIC",
2511    )
2512    _test_padv2(
2513        [
2514            np.arange(0, 256, dtype=np.uint8).reshape((1, 256)),
2515            np.array([[1, 1], [2, 2]], dtype=np.int32),
2516        ],
2517        quantized=True,
2518    )
2519
2520    # Tests with Constant_values
2521    _test_padv2(
2522        [
2523            np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 1, 3)),
2524            np.array([[1, 1], [2, 2], [1, 1], [2, 2]], dtype=np.int32),
2525            np.array([2], dtype=np.float32),
2526        ]
2527    )
2528    _test_padv2(
2529        [
2530            np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 3)),
2531            np.array([[2, 2], [1, 1], [1, 1]], dtype=np.int32),
2532            np.array([1], dtype=np.float32),
2533        ]
2534    )
2535    _test_padv2(
2536        [
2537            np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 3)),
2538            np.array([[1, 1], [2, 2]], dtype=np.int32),
2539            np.array([-1], dtype=np.float32),
2540        ]
2541    )
2542    _test_padv2(
2543        [
2544            np.arange(1.0, 4.0, dtype=np.float32).reshape((1, 3)),
2545            np.array([[1, 1], [2, 2]], dtype=np.int32),
2546            np.array([2], dtype=np.float32),
2547        ]
2548    )
2549    _test_padv2(
2550        [
2551            np.arange(0, 256, dtype=np.uint8).reshape((1, 256)),
2552            np.array([[1, 1], [2, 2]], dtype=np.int32),
2553            np.array([2], dtype=np.uint8),
2554        ],
2555        quantized=True,
2556    )
2557
    # The constant_values input can also be a scalar
2559    _test_padv2(
2560        [
2561            np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 1, 3)),
2562            np.array([[1, 1], [2, 2], [1, 1], [2, 2]], dtype=np.int32),
2563            np.float32(2),
2564        ]
2565    )
2566    _test_padv2(
2567        [
2568            np.arange(0, 256, dtype=np.uint8).reshape((1, 256)),
2569            np.array([[1, 1], [2, 2]], dtype=np.int32),
2570            np.uint8(10),
2571        ],
2572        quantized=True,
2573    )
2574
2575
2576#######################################################################
2577# EXPAND_DIMS
2578# -----------
2579
2580
2581def _test_expand_dims(input_shape, input_type, axis, quantized=False):
2582    """ One iteration of EXPAND_DIMS """
2583    with tf.Graph().as_default():
2584        axis = ops.convert_to_tensor(axis, dtype=axis.dtype)
2585
2586        if quantized:
            # input_type is ignored here because the quantized path requires uint8
2588            input = np.random.uniform(0, 256, input_shape).astype("uint8")
2589            in_input = tf.placeholder(dtype="float32", shape=input.shape, name="input")
2590
2591            input_range = {"q_input": (-100, 100)}
2592            inq_input = tf.quantization.fake_quant_with_min_max_args(
2593                in_input, min=-100, max=100, name="q_input"
2594            )
2595
2596            out = array_ops.expand_dims(inq_input, axis=axis)
2597            out = tf.quantization.fake_quant_with_min_max_args(out, min=-100, max=100, name="out")
2598
2599            compare_tflite_with_tvm(
2600                [input], ["q_input"], [inq_input], [out], quantized=True, input_range=input_range
2601            )
2602        else:
2603            input = np.random.uniform(-100, 100, input_shape).astype(input_type)
2604            in_input = tf.placeholder(dtype=input.dtype, shape=input.shape, name="input")
2605
2606            out = array_ops.expand_dims(in_input, axis=axis)
2607
2608            compare_tflite_with_tvm([input], ["input"], [in_input], [out])
2609
2610
2611def test_forward_expand_dims():
2612    """ EXPAND_DIMS """
2613    for quantized in [False, True]:
2614        _test_expand_dims((6, 2, 7, 5), "float32", np.int32(0), quantized=quantized)
2615        _test_expand_dims((1, 2, 3), "int32", np.int32(-2), quantized=quantized)
2616        _test_expand_dims((2, 4, 5), "float32", np.array([1], dtype=np.int32), quantized=quantized)
2617
2618
2619#######################################################################
2620# ONE_HOT
2621# -------
2622
2623
2624def _test_one_hot(indices, depth, on_value, off_value, axis=None):
2625    """ One iteration of One_Hot """
2626    with tf.Graph().as_default():
2627        in_indices = tf.placeholder(dtype=indices.dtype, shape=indices.shape, name="indices")
2628        in_depth = ops.convert_to_tensor(depth, dtype=depth.dtype)
2629        in_on_value = tf.placeholder(dtype=on_value.dtype, shape=on_value.shape, name="on_value")
2630        in_off_value = tf.placeholder(
2631            dtype=off_value.dtype, shape=off_value.shape, name="off_value"
2632        )
2633        if axis is not None:
2634            out = array_ops.one_hot(in_indices, in_depth, in_on_value, in_off_value, axis=axis)
2635        else:
2636            out = array_ops.one_hot(in_indices, in_depth, in_on_value, in_off_value)
2637        compare_tflite_with_tvm(
2638            [indices, on_value, off_value],
2639            ["indices", "on_value", "off_value"],
2640            [in_indices, in_on_value, in_off_value],
2641            [out],
2642        )
2643
2644
2645def test_forward_one_hot():
2646    """ One_Hot """
2647    _test_one_hot(np.int32(2), np.int32(8), np.int32(1), np.int32(0))
2648    _test_one_hot(np.int32(4), np.int32(8), np.float32(1), np.float32(0))
2649    _test_one_hot(np.array([1, 2, 3], dtype=np.int32), np.int32(8), np.int32(3), np.int32(-1))
2650    _test_one_hot(
2651        np.array([1, 2, 3], dtype=np.int32), np.int32(8), np.int32(3), np.int32(-1), axis=0
2652    )
2653
2654
2655#######################################################################
2656# Pack
2657# ----
2658
2659
2660def _test_pack(data, axis):
2661    """ One iteration of pack """
2662
2663    assert len(data) >= 1
2664
2665    with tf.Graph().as_default():
2666        in_data = [
2667            array_ops.placeholder(shape=tensor.shape, dtype=tensor.dtype, name="in_{}".format(idx))
2668            for idx, tensor in enumerate(data)
2669        ]
2670        out = array_ops.pack(in_data, axis=axis)
2671        name = ["in_{}:0".format(idx) for idx in range(len(data))]
2672
2673        compare_tflite_with_tvm(data, name, in_data, [out])
2674
2675
2676def test_forward_pack():
2677    """ Pack """
2678    _test_pack([np.arange(6).reshape((1, 2, 1, 3)), np.arange(6).reshape((1, 2, 1, 3))], 1)
2679
2680    _test_pack([np.arange(6).reshape((3, 2)), np.arange(6).reshape((3, 2))], 1)
2681
2682    _test_pack(
2683        [
2684            np.arange(6).reshape((2, 1, 1, 3)),
2685            np.arange(6).reshape((2, 1, 1, 3)),
2686            np.arange(6).reshape((2, 1, 1, 3)),
2687        ],
2688        1,
2689    )
2690
2691
2692#######################################################################
2693# Unpack
2694# ------
2695
2696
2697def _test_unpack(data, axis, num_unpacks):
2698    """ One iteration of UNPACK """
2699    with tf.Graph().as_default():
2700        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
2701        out = gen_array_ops.unpack(in_data, num=num_unpacks, axis=axis, name="unpack")
2702        out_names = ["out_" + str(n) + ":0" for n in range(num_unpacks)]
2703        compare_tflite_with_tvm([data], "Placeholder:0", [in_data], out, out_names=out_names)
2704
2705
2706def test_forward_unpack():
2707    """ UNPACK """
2708    _test_unpack(np.array(np.random.uniform(0, 5, (3, 1)), dtype=np.int32), axis=1, num_unpacks=1)
2709    _test_unpack(np.array(np.random.uniform(0, 5, (3, 4)), dtype=np.float32), axis=0, num_unpacks=3)
2710    # tflite 1.13 doesn't accept negative axis
2711    if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"):
2712        _test_unpack(
2713            np.array(np.random.uniform(0, 5, (3, 6)), dtype=np.int32), axis=-2, num_unpacks=3
2714        )
2715        _test_unpack(
2716            np.array(np.random.uniform(0, 5, (2, 3, 4)), dtype=np.int32), axis=-3, num_unpacks=2
2717        )
2718
2719
2720#######################################################################
2721# Local response normalization
2722# ----------------------------
2723
2724
2725def _test_local_response_normalization(data, depth_radius, bias, alpha, beta):
2726    """ One iteration of LOCAL_RESPONSE_NORMALIZATION """
2727    with tf.Graph().as_default():
2728        in_data = array_ops.placeholder(shape=data.shape, dtype="float32", name="in_0")
2729        out = nn_ops.local_response_normalization(
2730            in_data, depth_radius=depth_radius, bias=bias, alpha=alpha, beta=beta
2731        )
2732        compare_tflite_with_tvm(data, "in_0:0", [in_data], [out])
2733
2734
2735def test_forward_local_response_normalization():
2736    """ LOCAL_RESPONSE_NORMALIZATION """
2737    data = np.random.uniform(size=(1, 6, 4, 3)).astype("float32")
    # LOCAL_RESPONSE_NORMALIZATION comes with the TFLite >= 1.14.0 fbs schema
2739    if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"):
2740        _test_local_response_normalization(data, depth_radius=5, bias=1, alpha=1, beta=0.5)
2741
2742
2743#######################################################################
2744# L2 normalization
2745# ----------------
2746
2747
2748def _test_l2_normalization(data, axis, fused_activation_function=None):
2749    """ One iteration of L2_NORMALIZATION """
2750    with tf.Graph().as_default():
2751        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
2752        out = nn_impl.l2_normalize(in_data, axis)
2753        out = with_fused_activation_function(out, fused_activation_function)
2754        compare_tflite_with_tvm(data, "Placeholder:0", [in_data], [out])
2755
2756
2757def test_forward_l2_normalization():
2758    """ L2_NORMALIZATION """
2759    data = np.random.uniform(size=(3, 6, 4)).astype("float32")
2760    _test_l2_normalization(data, axis=2)
2761    _test_l2_normalization(data, axis=2, fused_activation_function="RELU")
2762
2763
2764#######################################################################
2765# Logistic
2766# --------
2767
2768
2769def _test_logistic(data, quantized=False):
2770    """ One iteration of LOGISTIC """
2771    with tf.Graph().as_default():
2772        in_data = array_ops.placeholder(shape=data.shape, dtype="float32", name="in_0")
2773
2774        if quantized:
2775            inq_data = tf.quantization.fake_quant_with_min_max_args(
2776                in_data, min=-5, max=5, name="inq_0"
2777            )
2778            input_range = {"inq_0": (-5, 5)}
2779            out = math_ops.sigmoid(inq_data)
2780            out = tf.quantization.fake_quant_with_min_max_args(out, min=0, max=1, name="out")
2781            compare_tflite_with_tvm(
2782                data, "inq_0:0", [inq_data], [out], quantized=True, input_range=input_range
2783            )
2784        else:
2785            out = math_ops.sigmoid(in_data)
2786            compare_tflite_with_tvm(data, "in_0:0", [in_data], [out])
2787
2788
2789def test_forward_logistic():
2790    """ LOGISTIC """
2791    _test_logistic(np.arange(6.0, dtype=np.float32).reshape((1, 6)))
2792    _test_logistic(np.random.uniform(0, 255, (3, 6)).astype(np.uint8), quantized=True)
2793
2794
2795#######################################################################
2796# Softmax
2797# -------
2798
2799
2800def _test_softmax(data):
2801    """ One iteration of softmax """
2802    with tf.Graph().as_default():
2803        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
2804        out = nn_ops.softmax(in_data)
2805        compare_tflite_with_tvm(data, "Placeholder:0", [in_data], [out])
2806
2807
2808def test_forward_softmax():
2809    """ Softmax """
2810    _test_softmax(np.arange(6.0, dtype=np.float32).reshape((1, 6)))
2811
2812
2813######################################################################
2814# Log_softmax
2815# -----------
2816
2817
2818def _test_log_softmax(data, quantized=False):
2819    """ One iteration of log_softmax """
2820    with tf.Graph().as_default():
2821        in_data = array_ops.placeholder(shape=data.shape, dtype="float32", name="in_0")
2822
2823        if quantized:
2824            inq_data = tf.quantization.fake_quant_with_min_max_args(
2825                in_data, min=-10, max=10, name="inq_0"
2826            )
2827            input_range = {"inq_0": (-10, 10)}
2828            # tflite log_softmax supports only the case when axis is not specified
2829            out = nn_ops.log_softmax(inq_data)
2830            out = tf.quantization.fake_quant_with_min_max_args(out, min=-20, max=0, name="out")
2831            compare_tflite_with_tvm(
2832                data, "inq_0:0", [inq_data], [out], quantized=True, input_range=input_range
2833            )
2834        else:
2835            out = nn_ops.log_softmax(in_data)
2836            compare_tflite_with_tvm(data, "in_0:0", [in_data], [out])
2837
2838
2839def test_forward_log_softmax():
2840    """ Log_softmax """
2841    _test_log_softmax(np.random.uniform(-10, 10, size=(3, 6)).astype(np.float32))
2842    _test_log_softmax(np.random.uniform(0, 255, (3, 6)).astype(np.uint8), quantized=True)
2843
2844
2845#######################################################################
2846# Tanh
2847# ----
2848
2849
2850def _test_tanh(data):
2851    """ One iteration of TANH """
2852    with tf.Graph().as_default():
2853        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
2854        out = math_ops.tanh(in_data)
2855        compare_tflite_with_tvm(data, "Placeholder:0", [in_data], [out])
2856
2857
2858def test_forward_tanh():
2859    """ TANH """
2860    _test_tanh(np.arange(6.0, dtype=np.float32).reshape((1, 6)))
2861
2862
2863#######################################################################
2864# ReLu
2865# ----
2866
2867
2868def _test_relu(data, quantized=False):
2869    """ One iteration of ReLU """
2870
2871    if quantized:
2872        if package_version.parse(tf.VERSION) < package_version.parse("2.1.0"):
2873            pytest.skip("Testcase requires tflite version >= 2.1.0")
2874        data_in = tf.keras.layers.Input(shape=data.shape[1:])
2875        relu = tf.keras.layers.ReLU()(data_in)
2876        keras_model = tf.keras.models.Model(inputs=data_in, outputs=relu)
2877        input_name = data_in.name.split(":")[0]
2878
        # A representative dataset is needed so the converter can calibrate the
        # dynamic range of the activations and produce quantized values.
2880        def representative_data_gen():
2881            for i in range(1):
2882                yield [data]
2883
2884        tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen)
2885
2886        tflite_output = run_tflite_graph(tflite_model_quant, data)
2887        tvm_output = run_tvm_graph(tflite_model_quant, data, input_name)
2888        tvm.testing.assert_allclose(
2889            np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-5
2890        )
2891    else:
2892        with tf.Graph().as_default():
2893            in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
2894            out = nn_ops.relu(in_data)
2895            compare_tflite_with_tvm(data, "Placeholder:0", [in_data], [out])
2896
2897
2898def test_forward_relu():
2899    """ ReLU """
2900    _test_relu(np.arange(6.0, dtype=np.float32).reshape((1, 6)))
2901    _test_relu(np.arange(6.0, dtype=np.float32).reshape((1, 6)), quantized=True)
2902
2903
2904#######################################################################
2905# ReLU6
2906# -----
2907
2908
2909def _test_relu6(data, quantized=False):
2910    """ One iteration of ReLU6 """
2911    with tf.Graph().as_default():
2912        in_data = array_ops.placeholder(shape=data.shape, dtype="float32", name="in_0")
2913
2914        if quantized:
2915            inq_data = tf.quantization.fake_quant_with_min_max_args(
2916                in_data, min=-10, max=10, name="inq_0"
2917            )
2918            input_range = {"inq_0": (-10, 10)}
2919            out = nn_ops.relu6(inq_data)
2920            out = tf.quantization.fake_quant_with_min_max_args(out, min=0, max=6, name="out")
2921            compare_tflite_with_tvm(
2922                data, "inq_0:0", [inq_data], [out], quantized=True, input_range=input_range
2923            )
2924        else:
2925            out = nn_ops.relu6(in_data)
2926            compare_tflite_with_tvm(data, "in_0:0", [in_data], [out])
2927
2928
2929def test_forward_relu6():
2930    """ ReLU6 """
2931    _test_relu6(np.random.uniform(-10, 10, size=(3, 6)).astype(np.float32))
2932    _test_relu6(np.random.uniform(0, 255, (3, 6)).astype(np.uint8), quantized=True)
2933
2934
2935#######################################################################
2936# Leaky_ReLU
2937# ----------
2938
2939
2940def _test_leaky_relu(data, alpha, quantized=False):
2941    """ One iteration of Leaky_ReLU """
2942    with tf.Graph().as_default():
2943        in_data = array_ops.placeholder(shape=data.shape, dtype="float32", name="in_0")
2944
2945        if quantized:
2946            inq_data = tf.quantization.fake_quant_with_min_max_args(
2947                in_data, min=-3, max=2, name="inq_0"
2948            )
2949            input_range = {"inq_0": (-3, 2)}
2950            out = nn_ops.leaky_relu(inq_data, alpha)
2951            out = tf.quantization.fake_quant_with_min_max_args(out, min=-3, max=2, name="out")
2952            compare_tflite_with_tvm(
2953                data, "inq_0:0", [inq_data], [out], quantized=True, input_range=input_range
2954            )
2955        else:
2956            out = nn_ops.leaky_relu(in_data, alpha)
2957            compare_tflite_with_tvm(data, "in_0:0", [in_data], [out])
2958
2959
2960def test_forward_leaky_relu():
2961    """ Leaky_ReLU """
2962    _test_leaky_relu(np.random.uniform(-5, 5, (1, 6)).astype(np.float32), alpha=0.2)
2963    if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"):
2964        _test_leaky_relu(
2965            np.random.uniform(0, 255, (2, 3)).astype(np.uint8), alpha=0.3, quantized=True
2966        )
2967
2968
2969#######################################################################
2970# ReLU_n1_to_1
2971# ------------
2972
2973
2974def _test_relu_n1_to_1(data, quantized=False):
2975    """ One iteration of ReLU_n1_to_1 """
2976    with tf.Graph().as_default():
2977        in_data = array_ops.placeholder(shape=data.shape, dtype="float32", name="in_0")
2978
2979        if quantized:
2980            inq_data = tf.quantization.fake_quant_with_min_max_args(
2981                in_data, min=-3, max=3, name="inq_0"
2982            )
2983            input_range = {"inq_0": (-3, 3)}
            # There is no such TF operation. The TFLite converter replaces this specific pattern with RELU_N1_TO_1
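            # i.e. RELU_N1_TO_1(x) = clip(x, -1.0, 1.0) = max(-1.0, min(x, 1.0))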
2985            out = math_ops.maximum(-1.0, math_ops.minimum(inq_data, 1.0))
2986            out = tf.quantization.fake_quant_with_min_max_args(out, min=-1, max=1, name="out")
2987            compare_tflite_with_tvm(
2988                data, "inq_0:0", [inq_data], [out], quantized=True, input_range=input_range
2989            )
2990        else:
2991            out = math_ops.maximum(-1.0, math_ops.minimum(in_data, 1.0))
2992            compare_tflite_with_tvm(data, "in_0:0", [in_data], [out])
2993
2994
2995def test_forward_relu_n1_to_1():
2996    """ ReLU_n1_to_1 """
2997    _test_relu_n1_to_1(np.random.uniform(-3, 3, (1, 6)).astype(np.float32))
2998    if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"):
2999        _test_relu_n1_to_1(np.random.uniform(0, 255, (3, 6)).astype(np.uint8), quantized=True)
3000
3001
3002#######################################################################
3003# PReLU
3004# -----
3005
3006
3007def _test_prelu(data, alpha):
3008    """ One iteration of PReLU """
3009    with tf.Graph().as_default():
3010        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
        # The TFLite converter replaces this specific pattern with PRELU
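        # prelu(x) = x if x > 0 else alpha * x, which equals relu(x) - alpha * relu(-x)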
3012        out = nn_ops.relu(in_data) + (-alpha * nn_ops.relu(-in_data))
3013        compare_tflite_with_tvm(data, "Placeholder:0", [in_data], [out])
3014
3015
3016def test_forward_prelu():
3017    """ PReLU """
3018    _test_prelu(
3019        np.random.uniform(-5, 5, size=(1, 32, 32, 3)).astype("float32"),
3020        np.full((3,), 0.2, dtype="float32"),
3021    )
3022    _test_prelu(
3023        np.random.uniform(-5, 5, size=(1, 32, 32, 3)).astype("float32"),
3024        np.full((1, 1, 3), 0.2, dtype="float32"),
3025    )
3026
3027
3028#######################################################################
3029# DepthToSpace
3030# ------------
3031
3032
3033def _test_depthtospace(data, block_size):
3034    """ One iteration of depth_to_space operation with given data and block size """
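    # depth_to_space rearranges data from the channel dimension into block_size x block_size
    # spatial blocks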
3035
3036    with tf.Graph().as_default():
3037        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
3038        out = array_ops.depth_to_space(in_data, block_size)
3039        compare_tflite_with_tvm(data, "Placeholder:0", [in_data], [out])
3040
3041
3042def test_forward_depthtospace():
    # DEPTH_TO_SPACE is only available in the TFLite >= 1.15.0 fbs schema
3044    if package_version.parse(tf.VERSION) >= package_version.parse("1.15.0"):
3045        _test_depthtospace(np.random.normal(size=[1, 32, 32, 4]).astype("float32"), 2)
3046        _test_depthtospace(np.random.normal(size=[1, 16, 8, 32]).astype("float32"), 4)
3047
3048
3049#######################################################################
3050# SpaceToDepth
3051# ------------
3052
3053
3054def _test_spacetodepth(data, block_size):
3055    """ One iteration of space_to_depth operation with given data and block size """
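    # space_to_depth is the inverse: block_size x block_size spatial blocks are folded back
    # into the channel dimension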
3056
3057    with tf.Graph().as_default():
3058        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
3059        out = array_ops.space_to_depth(in_data, block_size)
3060        compare_tflite_with_tvm(data, "Placeholder:0", [in_data], [out])
3061
3062
3063def test_forward_spacetodepth():
3064    _test_spacetodepth(np.random.normal(size=[1, 32, 32, 4]).astype("float32"), 2)
3065    _test_spacetodepth(np.random.normal(size=[1, 16, 8, 32]).astype("float32"), 4)
3066
3067
3068#######################################################################
3069# ReverseSequence
3070# ---------------
3071
3072
3073def _test_reverse_sequence(shape, dtype, seq_lengths, batch_axis, seq_axis):
3074    """ One iteration of reverse_sequence operation with given data and attributes """
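    # reverse_sequence reverses the first seq_lengths[i] elements along seq_axis for every
    # slice i taken along batch_axis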
3075
3076    data = np.random.uniform(0, 100, size=shape).astype(dtype)
3077    with tf.Graph().as_default():
3078        in_data = array_ops.placeholder(dtype=dtype, name="input", shape=shape)
3079        out = tf.reverse_sequence(
3080            in_data, seq_lengths=seq_lengths, batch_axis=batch_axis, seq_axis=seq_axis
3081        )
3082
3083        compare_tflite_with_tvm(data, "input", [in_data], [out])
3084
3085
3086def test_forward_reverse_sequence():
3087    if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"):
3088        _test_reverse_sequence([4, 3], "float32", [3, 2, 1], 1, 0)
3089        _test_reverse_sequence([4, 3], "float32", [3, 2, 1, 3], 0, 1)
3090        _test_reverse_sequence([2, 3, 3, 3], "float32", [2, 3, 2], 2, 1)
3091        _test_reverse_sequence([2, 4, 6, 4, 5], "float32", [5, 3], 0, 2)
3092        _test_reverse_sequence([2, 4, 6, 4, 5], "float32", [5, 3, 1, 4], 3, 2)
3093
3094
3095#######################################################################
3096# Sparse To Dense
3097# ---------------
3098def _test_sparse_to_dense(sparse_indices, sparse_values, default_value, output_shape):
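    # sparse_to_dense scatters sparse_values at sparse_indices into a dense tensor of shape
    # output_shape, filling the remaining positions with default_value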
    # The TFLite 1.13 converter does not accept empty shapes
3100    if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"):
3101        with tf.Graph().as_default():
3102            indices = tf.placeholder(
3103                shape=sparse_indices.shape, dtype=str(sparse_indices.dtype), name="indices"
3104            )
3105            values = tf.placeholder(
3106                shape=sparse_values.shape, dtype=str(sparse_values.dtype), name="values"
3107            )
3108            oshape = tf.constant(
3109                output_shape, shape=output_shape.shape, dtype=str(output_shape.dtype)
3110            )
3111
            if default_value is None:
3113                output = tf.sparse_to_dense(indices, oshape, values)
3114                compare_tflite_with_tvm(
3115                    [sparse_indices, sparse_values],
3116                    ["indices", "values"],
3117                    [indices, values],
3118                    [output],
3119                )
3120            else:
3121                dv = tf.placeholder(shape=(), dtype=str(default_value.dtype), name="default_value")
3122                output = tf.sparse_to_dense(indices, oshape, values, dv)
3123                compare_tflite_with_tvm(
3124                    [sparse_indices, sparse_values, default_value],
3125                    ["indices", "values", "default_value"],
3126                    [indices, values, dv],
3127                    [output],
3128                )
3129
3130
3131def test_forward_sparse_to_dense():
3132    """
    This scalar case works in tvm/topi/tensorflow, but the tflite converter breaks it:
3134    _test_sparse_to_dense(
3135        np.int32(1),
3136        np.int32(3),
3137        np.int32(0),
3138        np.array([5]).astype("int32")
3139    )
3140    """
3141    # vector
3142    _test_sparse_to_dense(
3143        np.array([0, 1, 4]).astype("int32"),
3144        np.array([3, 3, 3]).astype("int32"),
3145        np.int32(0),
3146        np.array([5]).astype("int32"),
3147    )
3148    # vector nXd
3149    _test_sparse_to_dense(
3150        np.array([[0, 0], [1, 2]]).astype("int32"),
3151        np.array([1, 2]).astype("int32"),
3152        np.int32(0),
3153        np.array([3, 4]).astype("int32"),
3154    )
3155    _test_sparse_to_dense(
3156        np.array([[0, 0, 0], [1, 2, 3]]).astype("int32"),
3157        np.array([1, 2]).astype("int32"),
3158        np.int32(4),
3159        np.array([2, 3, 4]).astype("int32"),
3160    )
3161    # floats
3162    _test_sparse_to_dense(
3163        np.array([0, 1, 4]).astype("int32"),
3164        np.array([3.1, 3.1, 3.1]).astype("float32"),
3165        np.float32(3.5),
3166        np.array([5]).astype("int32"),
3167    )
3168    # default value not specified
3169    _test_sparse_to_dense(
3170        np.array([0, 1, 4]).astype("int32"),
3171        np.array([3.1, 3.1, 3.1]).astype("float32"),
3172        None,
3173        np.array([5]).astype("int32"),
3174    )
3175
3176
3177#######################################################################
3178# Fully Connected
3179# ---------------
3180
3181
3182def _test_fully_connected(tensor_in_sizes, const_input, filter_in_sizes, bias_in_size=None):
3183    """ One iteration of fully connected """
3184
3185    total_size_1 = np.prod(tensor_in_sizes)
3186    total_size_2 = np.prod(filter_in_sizes)
3187
3188    assert (
3189        int(total_size_1 / tensor_in_sizes[0]) == filter_in_sizes[0]
3190    ), "input size and filter size are mismatched"
3191
    # Initialize the input tensor with an array containing numbers incrementing from 1.
3194    data_array = np.arange(1, total_size_1 + 1, dtype=np.float32)
3195    filter_array = np.arange(1, total_size_2 + 1, dtype=np.float32)
3196
3197    with tf.Graph().as_default():
3198        in_name = "input"
3199        in_data = (
3200            constant_op.constant(data_array, shape=tensor_in_sizes, dtype=np.float32, name=in_name)
3201            if const_input
3202            else array_ops.placeholder(shape=tensor_in_sizes, dtype=np.float32, name=in_name)
3203        )
3204
3205        in_filter = constant_op.constant(filter_array, shape=filter_in_sizes, dtype=np.float32)
3206
3207        # reshape N H W C into N H*W*C
3208        in_data_reshape = array_ops.reshape(in_data, [tensor_in_sizes[0], -1])
3209
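        # The converter is expected to fuse this matmul (and the optional bias_add below) into a
        # single FULLY_CONNECTED op of shape [N, H*W*C] x [H*W*C, out_units]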
3210        out = math_ops.mat_mul(in_data_reshape, in_filter)
3211
3212        # if we have bias
3213        if bias_in_size:
3214            assert bias_in_size[0] == filter_in_sizes[1], "bias and filter size are mismatched"
3215            bias_array = np.arange(1, bias_in_size[0] + 1, dtype=np.float32)
3216            in_bias = constant_op.constant(bias_array, shape=bias_in_size, dtype=np.float32)
3217            out = nn_ops.bias_add(out, in_bias)
3218
3219        data_array = np.reshape(data_array, tensor_in_sizes).astype(np.float32)
3220        compare_tflite_with_tvm(data_array, [] if const_input else in_data.name, [in_data], [out])
3221
3222
3223def test_forward_fully_connected():
3224    """ Fully Connected """
3225    for const_input in [False, True]:
3226        _test_fully_connected([1, 1, 1, 150], const_input, [150, 100])
3227        _test_fully_connected([1, 1, 1, 150], const_input, [150, 100], [100])
3228        _test_fully_connected([5, 1, 1, 150], const_input, [150, 100])
3229        _test_fully_connected([5, 1, 1, 150], const_input, [150, 100], [100])
3230
3231
3232#######################################################################
3233# REVERSE_V2
3234# ----------
3235
3236
3237def _test_reverse_v2(input_shape, axis, dtype):
3238    """ One iteration of REVERSE_V2 """
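    # REVERSE_V2 reverses the input tensor along the axes listed in the axis tensor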
3239    with tf.Graph().as_default():
3240        input = np.random.randint(0, 100, size=input_shape).astype(dtype)
3241        in_input = tf.placeholder(dtype=input.dtype, shape=input.shape, name="input")
3242        in_axis = ops.convert_to_tensor(axis, dtype=axis.dtype)
3243
3244        out = array_ops.reverse(in_input, in_axis)
3245
3246        compare_tflite_with_tvm([input], ["input"], [in_input], [out])
3247
3248
3249def test_forward_reverse_v2():
3250    """ REVERSE_V2 """
3251    for dtype in ["float32", "int32"]:
3252        _test_reverse_v2((5), np.array([0], dtype="int32"), dtype)
3253        _test_reverse_v2((5, 6, 4, 2), np.array([2], dtype="int32"), dtype)
3254
3255
3256#######################################################################
3257# MATRIX_SET_DIAG
3258# ---------------
3259
3260
3261def _test_matrix_set_diag(input_shape, input_type, quantized=False):
3262    """ One iteration of MATRIX_SET_DIAG """
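    # matrix_set_diag returns a copy of the input tensor with the main diagonals of the
    # innermost matrices replaced by the given diagonal values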
3263    with tf.Graph().as_default():
3264        diagonal_shape = list(input_shape[:-2])
3265        diagonal_shape.append(min(input_shape[-2], input_shape[-1]))
3266
3267        if quantized:
            # input_type is ignored here because quantized tests require uint8
3269            input = np.random.uniform(0, 256, input_shape).astype("uint8")
3270            in_input = tf.placeholder(dtype="float32", shape=input.shape, name="input")
3271            inq_input = tf.quantization.fake_quant_with_min_max_args(
3272                in_input, min=-100, max=100, name="q_input"
3273            )
3274
3275            diagonal = np.random.uniform(0, 256, diagonal_shape).astype("uint8")
3276            in_diagonal = tf.placeholder(dtype="float32", shape=diagonal.shape, name="diagonal")
3277            inq_diagonal = tf.quantization.fake_quant_with_min_max_args(
3278                in_diagonal, min=-100, max=100, name="q_diagonal"
3279            )
3280
3281            input_range = {"q_input": (-100, 100), "q_diagonal": (-100, 100)}
3282
3283            out = array_ops.matrix_set_diag(inq_input, inq_diagonal)
3284            out = tf.quantization.fake_quant_with_min_max_args(out, min=-100, max=100, name="out")
3285
3286            compare_tflite_with_tvm(
3287                [input, diagonal],
3288                ["q_input", "q_diagonal"],
3289                [inq_input, inq_diagonal],
3290                [out],
3291                quantized=True,
3292                input_range=input_range,
3293            )
3294        else:
3295            input = np.random.uniform(0, 100, input_shape).astype(input_type)
3296            diagonal = np.random.uniform(0, 100, diagonal_shape).astype(input_type)
3297
3298            in_input = tf.placeholder(dtype=input.dtype, shape=input.shape, name="input")
3299            in_diagonal = tf.placeholder(
3300                dtype=diagonal.dtype, shape=diagonal.shape, name="diagonal"
3301            )
3302
3303            out = array_ops.matrix_set_diag(in_input, in_diagonal)
3304
3305            compare_tflite_with_tvm(
3306                [input, diagonal], ["input", "diagonal"], [in_input, in_diagonal], [out]
3307            )
3308
3309
3310def test_forward_matrix_set_diag():
3311    """ MATRIX_SET_DIAG """
3312    for dtype in [np.float32, np.int32]:
3313        _test_matrix_set_diag((4, 4), dtype)
3314        _test_matrix_set_diag((5, 4, 3, 4), dtype)
3315        _test_matrix_set_diag((4, 4, 2), dtype)
3316
3317    _test_matrix_set_diag((4, 4), np.uint8, quantized=True)
3318    _test_matrix_set_diag((5, 4, 3, 4), np.uint8, quantized=True)
3319    _test_matrix_set_diag((4, 4, 2), np.uint8, quantized=True)
3320
3321
3322#######################################################################
3323# MATRIX_DIAG
3324# -----------
3325
3326
3327def _test_matrix_diag(diagonal_shape, dtype):
3328    """ One iteration of MATRIX_DIAG """
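    # matrix_diag builds (batched) square matrices with the given values on the main diagonal
    # and zeros elsewhere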
3329    with tf.Graph().as_default():
3330        diagonal = np.random.uniform(0, 100, diagonal_shape).astype(dtype)
3331        in_diagonal = tf.placeholder(dtype=diagonal.dtype, shape=diagonal.shape, name="diagonal")
3332
3333        out = array_ops.matrix_diag(in_diagonal)
3334
3335        compare_tflite_with_tvm(
3336            [diagonal], ["diagonal"], [in_diagonal], [out], experimental_new_converter=True
3337        )
3338
3339
3340def test_forward_matrix_diag():
3341    """ MATRIX_DIAG """
3342    for dtype in [np.float32, np.int32]:
3343        _test_matrix_diag((4), dtype)
3344        _test_matrix_diag((5, 4, 3), dtype)
3345        _test_matrix_diag((2, 3), dtype)
3346
3347
3348#######################################################################
3349# Custom Operators
3350# ----------------
3351
3352
3353def test_detection_postprocess():
3354    tf_model_file = tf_testing.get_workload_official(
3355        "http://download.tensorflow.org/models/object_detection/"
3356        "ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03.tar.gz",
3357        "ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03/tflite_graph.pb",
3358    )
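    # Convert only the post-processing part of the graph: feeding the raw box encodings and class
    # predictions as inputs means the converted model consists essentially of the custom
    # TFLite_Detection_PostProcess op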
3359    converter = tf.lite.TFLiteConverter.from_frozen_graph(
3360        tf_model_file,
3361        input_arrays=["raw_outputs/box_encodings", "raw_outputs/class_predictions"],
3362        output_arrays=[
3363            "TFLite_Detection_PostProcess",
3364            "TFLite_Detection_PostProcess:1",
3365            "TFLite_Detection_PostProcess:2",
3366            "TFLite_Detection_PostProcess:3",
3367        ],
3368        input_shapes={
3369            "raw_outputs/box_encodings": (1, 1917, 4),
3370            "raw_outputs/class_predictions": (1, 1917, 91),
3371        },
3372    )
3373    converter.allow_custom_ops = True
3374    converter.inference_type = tf.lite.constants.FLOAT
3375    tflite_model = converter.convert()
3376    np.random.seed(0)
3377    box_encodings = np.random.uniform(size=(1, 1917, 4)).astype("float32")
3378    class_predictions = np.random.uniform(size=(1, 1917, 91)).astype("float32")
3379    tflite_output = run_tflite_graph(tflite_model, [box_encodings, class_predictions])
3380    tvm_output = run_tvm_graph(
3381        tflite_model,
3382        [box_encodings, class_predictions],
3383        ["raw_outputs/box_encodings", "raw_outputs/class_predictions"],
3384        num_output=4,
3385    )
3386
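    # TFLite_Detection_PostProcess returns four tensors:
    # [0] bounding boxes, [1] class ids, [2] scores, [3] number of valid detections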
3387    # Check all output shapes are equal
3388    assert all(
3389        [
3390            tvm_tensor.shape == tflite_tensor.shape
3391            for (tvm_tensor, tflite_tensor) in zip(tvm_output, tflite_output)
3392        ]
3393    )
3394
3395    # Check valid count is the same
3396    assert tvm_output[3] == tflite_output[3]
3397    valid_count = tvm_output[3][0]
3398
    # For boxes that do not have any detections, TFLite puts random values. Therefore, we compare
    # the tflite and tvm tensors only for valid boxes.
3401    for i in range(0, valid_count):
3402        # Check bounding box co-ords
3403        tvm.testing.assert_allclose(
3404            np.squeeze(tvm_output[0][0][i]),
3405            np.squeeze(tflite_output[0][0][i]),
3406            rtol=1e-5,
3407            atol=1e-5,
3408        )
3409
3410        # Check the class
3411        # Stricter check to ensure class remains same
3412        np.testing.assert_equal(np.squeeze(tvm_output[1][0][i]), np.squeeze(tflite_output[1][0][i]))
3413
3414        # Check the score
3415        tvm.testing.assert_allclose(
3416            np.squeeze(tvm_output[2][0][i]),
3417            np.squeeze(tflite_output[2][0][i]),
3418            rtol=1e-5,
3419            atol=1e-5,
3420        )
3421
3422
3423#######################################################################
3424# Mobilenet
3425# ---------
3426
3427
3428def test_forward_mobilenet_v1():
3429    """Test the Mobilenet V1 TF Lite model."""
3430    # MobilenetV1
3431    tflite_model_file = tf_testing.get_workload_official(
3432        "http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz",
3433        "mobilenet_v1_1.0_224.tflite",
3434    )
3435    with open(tflite_model_file, "rb") as f:
3436        tflite_model_buf = f.read()
3437    data = np.random.uniform(size=(1, 224, 224, 3)).astype("float32")
3438    tflite_output = run_tflite_graph(tflite_model_buf, data)
3439    tvm_output = run_tvm_graph(tflite_model_buf, data, "input")
3440    tvm.testing.assert_allclose(
3441        np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-5
3442    )
3443
3444
3445def test_forward_mobilenet_v2():
3446    """Test the Mobilenet V2 TF Lite model."""
3447    # MobilenetV2
3448    tflite_model_file = tf_testing.get_workload_official(
3449        "http://download.tensorflow.org/models/tflite_11_05_08/mobilenet_v2_1.0_224.tgz",
3450        "mobilenet_v2_1.0_224.tflite",
3451    )
3452    with open(tflite_model_file, "rb") as f:
3453        tflite_model_buf = f.read()
3454    data = np.random.uniform(size=(1, 224, 224, 3)).astype("float32")
3455    tflite_output = run_tflite_graph(tflite_model_buf, data)
3456    tvm_output = run_tvm_graph(tflite_model_buf, data, "input")
3457    tvm.testing.assert_allclose(
3458        np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-5
3459    )
3460
3461
3462#######################################################################
3463# Mobilenet V3
3464# ------------
3465
3466
3467def test_forward_mobilenet_v3():
3468    """Test the Mobilenet V3 TF Lite model."""
    # Some ops used by MobilenetV3 are not supported before the tf 1.15 fbs schema
3470    if package_version.parse(tf.VERSION) < package_version.parse("1.15.0"):
3471        return
3472    tflite_model_file = tf_testing.get_workload_official(
3473        "https://storage.googleapis.com/mobilenet_v3/checkpoints/v3-large_224_1.0_float.tgz",
3474        "v3-large_224_1.0_float/v3-large_224_1.0_float.tflite",
3475    )
3476    with open(tflite_model_file, "rb") as f:
3477        tflite_model_buf = f.read()
3478    data = np.random.uniform(size=(1, 224, 224, 3)).astype("float32")
3479    tflite_output = run_tflite_graph(tflite_model_buf, data)
3480    tvm_output = run_tvm_graph(tflite_model_buf, data, "input")
3481    tvm.testing.assert_allclose(
3482        np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-5
3483    )
3484
3485
3486#######################################################################
3487# Inception
3488# ---------
3489
3490
3491def test_forward_inception_v3_net():
3492    """Test the Inception V3 TF Lite model."""
3493    # InceptionV3
3494    tflite_model_file = tf_testing.get_workload_official(
3495        "https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v3_2018_04_27.tgz",
3496        "inception_v3.tflite",
3497    )
3498    with open(tflite_model_file, "rb") as f:
3499        tflite_model_buf = f.read()
3500    data = np.random.uniform(size=(1, 299, 299, 3)).astype("float32")
3501    tflite_output = run_tflite_graph(tflite_model_buf, data)
3502    tvm_output = run_tvm_graph(tflite_model_buf, data, "input")
3503    tvm.testing.assert_allclose(
3504        np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-5
3505    )
3506
3507
3508def test_forward_inception_v4_net():
3509    """Test the Inception V4 TF Lite model."""
3510    # InceptionV4
3511    tflite_model_file = tf_testing.get_workload_official(
3512        "https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz",
3513        "inception_v4.tflite",
3514    )
3515    with open(tflite_model_file, "rb") as f:
3516        tflite_model_buf = f.read()
3517    data = np.random.uniform(size=(1, 299, 299, 3)).astype("float32")
3518    tflite_output = run_tflite_graph(tflite_model_buf, data)
3519    tvm_output = run_tvm_graph(tflite_model_buf, data, "input")
3520    tvm.testing.assert_allclose(
3521        np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-5
3522    )
3523
3524
3525def test_forward_inception_v4_net_batched():
3526    """Test the Inception V4 TF Lite model."""
3527    # InceptionV4
3528    tflite_model_file = tf_testing.get_workload_official(
3529        "https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz",
3530        "inception_v4.tflite",
3531    )
3532    with open(tflite_model_file, "rb") as f:
3533        tflite_model_buf = f.read()
3534    data = np.random.uniform(size=(4, 299, 299, 3)).astype("float32")
3535    tflite_output = run_tflite_graph(tflite_model_buf, data)
3536    tvm_output = run_tvm_graph(tflite_model_buf, data, "input")
3537    tvm.testing.assert_allclose(
3538        np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-5
3539    )
3540
3541
3542def test_forward_qnn_inception_v1_net():
3543    """Test the Quantized TFLite Inception model."""
3544    # InceptionV1
3545    tflite_model_file = tf_testing.get_workload_official(
3546        "https://storage.googleapis.com/download.tensorflow.org/models/inception_v1_224_quant_20181026.tgz",
3547        "inception_v1_224_quant.tflite",
3548    )
3549    with open(tflite_model_file, "rb") as f:
3550        tflite_model_buf = f.read()
3551
    # Test image. We check the labels because the requantize implementation differs between
    # TFLite and Relay, which causes the final output numbers to mismatch. So, we test accuracy
    # via labels. Also, we feed a real image instead of random inputs.
3555    data = get_real_image(224, 224)
3556
3557    tflite_output = run_tflite_graph(tflite_model_buf, data)
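    # Compare the top-3 predicted labels from TFLite and TVM rather than the raw scores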
3558    tflite_predictions = np.squeeze(tflite_output)
3559    tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
3560    tvm_output = run_tvm_graph(tflite_model_buf, data, "input")
3561    tvm_predictions = np.squeeze(tvm_output)
3562    tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
3563    tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)
3564
3565
3566def test_forward_qnn_mobilenet_v1_net():
3567    """Test the Quantized TFLite Mobilenet V1 model."""
3568    # MobilenetV1
3569    tflite_model_file = tf_testing.get_workload_official(
3570        "https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz",
3571        "mobilenet_v1_1.0_224_quant.tflite",
3572    )
3573    with open(tflite_model_file, "rb") as f:
3574        tflite_model_buf = f.read()
3575
    # Test image. We check the labels because the requantize implementation differs between
    # TFLite and Relay, which causes the final output numbers to mismatch. So, we test accuracy
    # via labels. Also, we feed a real image instead of random inputs.
3579    data = get_real_image(224, 224)
3580
3581    tflite_output = run_tflite_graph(tflite_model_buf, data)
3582    tflite_predictions = np.squeeze(tflite_output)
3583    tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
3584    tvm_output = run_tvm_graph(tflite_model_buf, data, "input")
3585    tvm_predictions = np.squeeze(tvm_output)
3586    tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
3587    tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)
3588
3589
3590def test_forward_qnn_mobilenet_v2_net():
3591    """Test the Quantized TFLite Mobilenet V2 model."""
3592    # MobilenetV2
3593    tflite_model_file = tf_testing.get_workload_official(
3594        "https://storage.googleapis.com/download.tensorflow.org/models/tflite_11_05_08/mobilenet_v2_1.0_224_quant.tgz",
3595        "mobilenet_v2_1.0_224_quant.tflite",
3596    )
3597    with open(tflite_model_file, "rb") as f:
3598        tflite_model_buf = f.read()
3599
    # Test image. We check the labels because the requantize implementation differs between
    # TFLite and Relay, which causes the final output numbers to mismatch. So, we test accuracy
    # via labels. Also, we feed a real image instead of random inputs.
3603    data = get_real_image(224, 224)
3604
3605    tflite_output = run_tflite_graph(tflite_model_buf, data)
3606    tflite_predictions = np.squeeze(tflite_output)
3607    tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
3608    tvm_output = run_tvm_graph(tflite_model_buf, data, "input")
3609    tvm_predictions = np.squeeze(tvm_output)
3610    tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
3611    tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)
3612
3613
3614#######################################################################
3615# Mobilenet V3 Quantized
3616# ----------------------
3617
3618
3619def test_forward_qnn_mobilenet_v3_net():
3620    """Test the Quantized TFLite Mobilenet V3 model."""
    # Some ops used by MobilenetV3 are not supported before the tf 1.15 fbs schema
3622    if package_version.parse(tf.VERSION) < package_version.parse("1.15.0"):
3623        pytest.skip("Unsupported in tflite < 1.15.0")
3624    else:
3625        pytest.skip("This segfaults with tensorflow 1.15.2 and above")
3626
3627    tflite_model_file = tf_testing.get_workload_official(
3628        "https://storage.googleapis.com/mobilenet_v3/checkpoints/v3-large_224_1.0_uint8.tgz",
3629        "v3-large_224_1.0_uint8/v3-large_224_1.0_uint8.tflite",
3630    )
3631    with open(tflite_model_file, "rb") as f:
3632        tflite_model_buf = f.read()
3633
    # Test image. We check the labels because the requantize implementation differs between
    # TFLite and Relay, which causes the final output numbers to mismatch. So, we test accuracy
    # via labels. Also, we feed a real image instead of random inputs.
3637    data = get_real_image(224, 224)
3638
3639    tflite_output = run_tflite_graph(tflite_model_buf, data)
3640    tflite_predictions = np.squeeze(tflite_output)
3641    tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
3642    tvm_output = run_tvm_graph(tflite_model_buf, data, "input")
3643    tvm_predictions = np.squeeze(tvm_output)
3644    tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
3645    tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)
3646
3647
3648def test_forward_tflite2_qnn_resnet50():
3649    """Test the Quantized TFLite version 2.1.0 Resnet50 model."""
3650    if package_version.parse(tf.VERSION) >= package_version.parse("2.1.0"):
3651        tflite_model_file = download_testdata(
3652            "https://raw.githubusercontent.com/dmlc/web-data/master/tensorflow/models/Quantized/resnet_50_quantized.tflite",
3653            "resnet_50_quantized.tflite",
3654        )
3655        with open(tflite_model_file, "rb") as f:
3656            tflite_model_buf = f.read()
3657
3658        data = pre_processed_image(224, 224)
3659
3660        tflite_output = run_tflite_graph(tflite_model_buf, data)
3661        tflite_predictions = np.squeeze(tflite_output)
3662        tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
3663        tvm_output = run_tvm_graph(tflite_model_buf, np.array(data), "input_1")
3664        tvm_predictions = np.squeeze(tvm_output)
3665        tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
3666        tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)
3667
3668
3669def test_forward_tflite2_qnn_inception_v1():
3670    """Test the Quantized TFLite version 2.1.0 Inception V1 model."""
3671    if package_version.parse(tf.VERSION) >= package_version.parse("2.1.0"):
3672        tflite_model_file = download_testdata(
3673            "https://raw.githubusercontent.com/dmlc/web-data/master/tensorflow/models/Quantized/inception_v1_quantized.tflite",
3674            "inception_v1_quantized.tflite",
3675        )
3676        with open(tflite_model_file, "rb") as f:
3677            tflite_model_buf = f.read()
3678
3679        data = pre_processed_image(224, 224)
3680
3681        tflite_output = run_tflite_graph(tflite_model_buf, data)
3682        tflite_predictions = np.squeeze(tflite_output)
3683        tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
3684        tvm_output = run_tvm_graph(tflite_model_buf, np.array(data), "input_1")
3685        tvm_predictions = np.squeeze(tvm_output)
3686        tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
3687        tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)
3688
3689
3690def test_forward_tflite2_qnn_mobilenet_v2():
3691    """Test the Quantized TFLite version 2.1.0 Mobilenet V2 model."""
3692    if package_version.parse(tf.VERSION) >= package_version.parse("2.1.0"):
3693        tflite_model_file = download_testdata(
3694            "https://raw.githubusercontent.com/dmlc/web-data/master/tensorflow/models/Quantized/mobilenet_v2_quantized.tflite",
3695            "mobilenet_v2_quantized.tflite",
3696        )
3697        with open(tflite_model_file, "rb") as f:
3698            tflite_model_buf = f.read()
3699
3700        data = pre_processed_image(224, 224)
3701
3702        tflite_output = run_tflite_graph(tflite_model_buf, data)
3703        tflite_predictions = np.squeeze(tflite_output)
3704        tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
3705        tvm_output = run_tvm_graph(tflite_model_buf, np.array(data), "input_1")
3706        tvm_predictions = np.squeeze(tvm_output)
3707        tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
3708        tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)
3709
3710
3711#######################################################################
3712# Quantized SSD Mobilenet
3713# -----------------------
3714
3715
3716def test_forward_qnn_coco_ssd_mobilenet_v1():
3717    """Test the quantized Coco SSD Mobilenet V1 TF Lite model."""
3718    pytest.skip(
3719        "LLVM bug - getExtendedVectorNumElements - "
3720        + "https://discuss.tvm.ai/t/segfault-in-llvm/3567. The workaround is to use a "
        + "specific target, for example, llvm -mcpu=core-avx2"
3722    )
3723
3724    tflite_model_file = tf_testing.get_workload_official(
3725        "https://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip",
3726        "detect.tflite",
3727    )
3728
3729    with open(tflite_model_file, "rb") as f:
3730        tflite_model_buf = f.read()
3731
3732    data = get_real_image_object_detection(300, 300)
3733    tflite_output = run_tflite_graph(tflite_model_buf, data)
3734    tvm_output = run_tvm_graph(
3735        tflite_model_buf, data, "normalized_input_image_tensor", num_output=4
3736    )
3737
3738    # Check all output shapes are equal
3739    assert all(
3740        [
3741            tvm_tensor.shape == tflite_tensor.shape
3742            for (tvm_tensor, tflite_tensor) in zip(tvm_output, tflite_output)
3743        ]
3744    )
3745
3746    # Check valid count is the same
3747    assert tvm_output[3] == tflite_output[3]
3748    valid_count = tvm_output[3][0]
3749
    # For boxes that do not have any detections, TFLite puts random values. Therefore, we compare
    # the tflite and tvm tensors only for valid boxes.
3752    for i in range(0, valid_count):
        # We compare only the bounding boxes whose prediction score is above 60%. This is typical in
        # end-to-end applications, where low prediction scores are discarded. It is also needed
        # because multiple low-score bounding boxes can share the same score, and TFLite and TVM may
        # order such boxes differently. Another reason for minor differences in low-score bounding
        # boxes is the difference between the TVM and TFLite requantize implementations.
3758        if tvm_output[2][0][i] > 0.6:
            # Check bounding box co-ords. The tolerances have to be relaxed from 1e-5 to 1e-2
            # because of differences in the requantize operator between TFLite and TVM.
3761            tvm.testing.assert_allclose(
3762                np.squeeze(tvm_output[0][0][i]),
3763                np.squeeze(tflite_output[0][0][i]),
3764                rtol=1e-2,
3765                atol=1e-2,
3766            )
3767
3768            # Check the class
3769            # Stricter check to ensure class remains same
3770            np.testing.assert_equal(
3771                np.squeeze(tvm_output[1][0][i]), np.squeeze(tflite_output[1][0][i])
3772            )
3773
3774            # Check the score
3775            tvm.testing.assert_allclose(
3776                np.squeeze(tvm_output[2][0][i]),
3777                np.squeeze(tflite_output[2][0][i]),
3778                rtol=1e-5,
3779                atol=1e-5,
3780            )
3781
3782
3783#######################################################################
3784# SSD Mobilenet
3785# -------------
3786
3787
3788def test_forward_coco_ssd_mobilenet_v1():
3789    """Test the FP32 Coco SSD Mobilenet V1 TF Lite model."""
3790    tflite_model_file = tf_testing.get_workload_official(
3791        "https://raw.githubusercontent.com/dmlc/web-data/master/tensorflow/models/object_detection/ssd_mobilenet_v1_coco_2018_01_28.tgz",
3792        "ssd_mobilenet_v1_coco_2018_01_28.tflite",
3793    )
3794
3795    with open(tflite_model_file, "rb") as f:
3796        tflite_model_buf = f.read()
3797
3798    np.random.seed(0)
3799    data = np.random.uniform(size=(1, 300, 300, 3)).astype("float32")
3800    tflite_output = run_tflite_graph(tflite_model_buf, data)
3801    tvm_output = run_tvm_graph(
3802        tflite_model_buf, data, "normalized_input_image_tensor", num_output=4
3803    )
3804
3805    # Check all output shapes are equal
3806    assert all(
3807        [
3808            tvm_tensor.shape == tflite_tensor.shape
3809            for (tvm_tensor, tflite_tensor) in zip(tvm_output, tflite_output)
3810        ]
3811    )
3812
3813    # Check valid count is the same
3814    assert tvm_output[3] == tflite_output[3]
3815    valid_count = tvm_output[3][0]
3816
    # For boxes that do not have any detections, TFLite puts random values. Therefore, we compare
    # the tflite and tvm tensors only for valid boxes.
3819    for i in range(0, valid_count):
3820        # Check bounding box co-ords
3821        tvm.testing.assert_allclose(
3822            np.squeeze(tvm_output[0][0][i]),
3823            np.squeeze(tflite_output[0][0][i]),
3824            rtol=1e-5,
3825            atol=1e-5,
3826        )
3827        # Check the class
3828        np.testing.assert_equal(np.squeeze(tvm_output[1][0][i]), np.squeeze(tflite_output[1][0][i]))
3829
3830        # Check the score
3831        tvm.testing.assert_allclose(
3832            np.squeeze(tvm_output[2][0][i]),
3833            np.squeeze(tflite_output[2][0][i]),
3834            rtol=1e-5,
3835            atol=1e-5,
3836        )
3837
3838
3839#######################################################################
3840# MediaPipe
# ---------
3842def test_forward_mediapipe_hand_landmark():
3843    """Test MediaPipe 2D hand landmark TF Lite model."""
    # MediaPipe 2D hand landmark TF Lite model
3845    tflite_model_file = download_testdata(
3846        "https://github.com/google/mediapipe/raw/v0.7.4/mediapipe/models/hand_landmark.tflite",
3847        "hand_landmark.tflite",
3848    )
3849    with open(tflite_model_file, "rb") as f:
3850        tflite_model_buf = f.read()
3851    data = np.random.uniform(size=(1, 256, 256, 3)).astype("float32")
3852    tflite_output = run_tflite_graph(tflite_model_buf, data)
3853    tvm_output = run_tvm_graph(tflite_model_buf, data, "input_1", num_output=2)
3854    for i in range(2):
3855        tvm.testing.assert_allclose(
3856            np.squeeze(tvm_output[i]), np.squeeze(tflite_output[i]), rtol=1e-5, atol=1e-5
3857        )
3858
3859
3860#######################################################################
3861# Main
3862# ----
3863if __name__ == "__main__":
3864    # BatchToSpaceND
3865    test_forward_batch_to_space_nd()
3866
3867    # SpaceToBatchND
3868    test_forward_space_to_batch_nd()
3869
3870    # Split
3871    test_forward_split()
3872
3873    # Transpose
3874    test_forward_transpose()
3875
3876    # Cast
3877    test_forward_cast()
3878
3879    # BatchMatMul
3880    test_forward_batch_matmul()
3881
3882    # Tile
3883    test_forward_tile()
3884
3885    # Query
3886    test_forward_shape()
3887
3888    # Transforms
3889    test_forward_concatenation()
3890    test_forward_pad()
3891    test_forward_pack()
3892    test_forward_unpack()
3893    test_forward_reshape()
3894    test_all_resize()
3895    test_forward_range()
3896    test_forward_squeeze()
3897    test_forward_slice()
3898    test_forward_topk()
3899    test_forward_gather()
3900    test_forward_gather_nd()
3901    test_forward_stridedslice()
3902    test_forward_depthtospace()
3903    test_forward_spacetodepth()
3904    test_forward_reverse_sequence()
3905    test_forward_sparse_to_dense()
3906    test_forward_select()
3907    test_forward_quantize_dequantize()
3908    test_forward_arg_min_max()
3909    test_forward_expand_dims()
3910    test_forward_reverse_v2()
3911    test_forward_matrix_set_diag()
3912    test_forward_matrix_diag()
3913
3914    # NN
3915    test_forward_convolution()
3916    test_forward_transpose_conv()
3917    test_forward_logistic()
3918    test_forward_pooling()
3919    test_forward_l2_pool2d()
3920    test_forward_softmax()
3921    test_forward_tanh()
3922    test_forward_relu()
3923    test_forward_relu6()
3924    test_forward_leaky_relu()
3925    test_forward_relu_n1_to_1()
3926    test_forward_log_softmax()
3927    test_forward_prelu()
3928    test_forward_fully_connected()
3929    test_forward_l2_normalization()
3930    test_forward_local_response_normalization()
3931
3932    # Elemwise
3933    test_all_elemwise()
3934    test_forward_add_n()
3935
3936    # Unary elemwise
3937    test_all_unary_elemwise()
3938    # Zeros Like
3939    test_forward_zeros_like()
3940
3941    # Fill
3942    test_forward_fill()
3943
3944    # Reduce
3945    test_all_reduce()
3946
3947    # Logical
3948    test_all_logical()
3949
3950    # Detection_PostProcess
3951    test_detection_postprocess()
3952
3953    # End to End
3954    test_forward_mobilenet_v1()
3955    test_forward_mobilenet_v2()
3956    test_forward_mobilenet_v3()
3957    test_forward_inception_v3_net()
3958    test_forward_inception_v4_net()
3959    test_forward_inception_v4_net_batched()
3960    test_forward_coco_ssd_mobilenet_v1()
3961    test_forward_mediapipe_hand_landmark()
3962
3963    # End to End quantized
3964    test_forward_qnn_inception_v1_net()
3965    test_forward_qnn_mobilenet_v1_net()
3966    test_forward_qnn_mobilenet_v2_net()
    # This also fails with a segmentation fault when run
    # with TFLite 1.15.2
3969    test_forward_qnn_mobilenet_v3_net()
3970    test_forward_qnn_coco_ssd_mobilenet_v1()
3971
3972    # TFLite 2.1.0 quantized tests
3973    test_forward_tflite2_qnn_resnet50()
3974    test_forward_tflite2_qnn_inception_v1()
3975    test_forward_tflite2_qnn_mobilenet_v2()
3976