1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18"""Ethos-N integration end-to-end network tests"""
19
20import pytest
21
22pytest.importorskip("tflite")
23pytest.importorskip("tensorflow")
24
25from tvm import relay
26from tvm.relay.op.contrib.ethosn import ethosn_available
27from tvm.contrib import download
28import tvm.relay.testing.tf as tf_testing
29import tflite.Model
30from . import infrastructure as tei
31
32
def _get_tflite_model(tflite_model_path, inputs_dict, dtype):
    """Read a TFLite model file and convert it with the Relay TFLite frontend.

    Parameters
    ----------
    tflite_model_path : str
        Path of the TFLite model file; it is opened in binary mode.
    inputs_dict : dict
        Maps each input name to the shape to use for that input.
    dtype : str
        The dtype assigned to every input named in ``inputs_dict``.

    Returns
    -------
    The value returned by ``relay.frontend.from_tflite`` for the parsed model.

    """
    with open(tflite_model_path, "rb") as model_file:
        raw_model = model_file.read()

    # Prefer the nested ``Model.Model`` entry point and fall back to the
    # module-level one when an AttributeError is raised (presumably to cope
    # with differently laid-out tflite packages).
    try:
        parsed_model = tflite.Model.Model.GetRootAsModel(raw_model, 0)
    except AttributeError:
        parsed_model = tflite.Model.GetRootAsModel(raw_model, 0)

    # Every input keeps its own shape but shares the single requested dtype.
    shapes = {name: inputs_dict[name] for name in inputs_dict}
    dtypes = dict.fromkeys(inputs_dict, dtype)

    return relay.frontend.from_tflite(parsed_model, shape_dict=shapes, dtype_dict=dtypes)
53
54
def _test_image_network(
    model_url,
    model_sub_path,
    input_dict,
    compile_hash,
    output_count,
    host_ops=0,
    npu_partitions=1,
    run=False,
):
    """Build an image network for Ethos-N and check the compiled output.

    The function returns immediately, without checking anything, when
    ``ethosn_available()`` is false.

    Parameters
    ----------
    model_url : str
        The URL the model is fetched from.
    model_sub_path : str
        The name of the model file.
    input_dict : dict
        Maps each input name to its shape. Elements 1 and 2 of every shape
        are passed to ``tei.get_real_image`` to create the input data.
    compile_hash : str, set
        The compile hash(es) the compilation output is checked against.
    output_count : int
        The expected number of outputs.
    host_ops : int
        The expected number of host operators.
    npu_partitions : int
        The expected number of Ethos-N partitions.
    run : bool
        Whether or not to also run the network. Without hardware the run
        still happens, but with a mocked inference function, so the results
        are incorrect. Running therefore only exercises the runtime flow; it
        does not check correctness/accuracy.

    """
    if not ethosn_available():
        return

    # One real image per network input.
    inputs = {
        name: tei.get_real_image(shape[1], shape[2]) for name, shape in input_dict.items()
    }

    # URLs ending in "tgz" or "zip" are fetched through the TF testing helper,
    # everything else through the generic test data downloader.
    if model_url.endswith(("tgz", "zip")):
        fetch = tf_testing.get_workload_official
    else:
        fetch = download.download_testdata
    model_path = fetch(model_url, model_sub_path)
    mod, params = _get_tflite_model(model_path, input_dict, "uint8")

    compiled = tei.build(
        mod,
        params,
        npu=True,
        expected_host_ops=host_ops,
        npu_partitions=npu_partitions,
    )
    tei.assert_lib_hash(compiled.get_lib(), compile_hash)
    if not run:
        return
    tei.run(compiled, inputs, output_count, npu=True)
117
118
def test_mobilenet_v1():
    # The hash guards against changes in codegen behaviour, which can be
    # caused by a different Support Library version or by a change to the
    # Ethos-N codegen itself. Updating it requires running on hardware that
    # isn't available in CI, so if this test fails on a hash mismatch please
    # notify @mbaret and @Leo-arm.
    url = (
        "https://storage.googleapis.com/download.tensorflow.org/"
        "models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz"
    )
    expected_hash = "81637c89339201a07dc96e3b5dbf836a"
    _test_image_network(
        model_url=url,
        model_sub_path="mobilenet_v1_1.0_224_quant.tflite",
        input_dict={"input": (1, 224, 224, 3)},
        compile_hash=expected_hash,
        output_count=1,
        host_ops=3,
        npu_partitions=1,
        run=True,
    )
136
137
def test_inception_v3():
    # The hash guards against changes in codegen behaviour, which can be
    # caused by a different Support Library version or by a change to the
    # Ethos-N codegen itself. Updating it requires running on hardware that
    # isn't available in CI, so if this test fails on a hash mismatch please
    # notify @mbaret and @Leo-arm.
    url = (
        "https://storage.googleapis.com/download.tensorflow.org/"
        "models/tflite_11_05_08/inception_v3_quant.tgz"
    )
    expected_hash = "de0e175af610ebd45ccb03d170dc9664"
    _test_image_network(
        model_url=url,
        model_sub_path="inception_v3_quant.tflite",
        input_dict={"input": (1, 299, 299, 3)},
        compile_hash=expected_hash,
        output_count=1,
        host_ops=0,
        npu_partitions=1,
    )
154
155
def test_inception_v4():
    # The hash guards against changes in codegen behaviour, which can be
    # caused by a different Support Library version or by a change to the
    # Ethos-N codegen itself. Updating it requires running on hardware that
    # isn't available in CI, so if this test fails on a hash mismatch please
    # notify @mbaret and @Leo-arm.
    url = (
        "https://storage.googleapis.com/download.tensorflow.org/"
        "models/inception_v4_299_quant_20181026.tgz"
    )
    expected_hash = "06bf6cb56344f3904bcb108e54edfe87"
    _test_image_network(
        model_url=url,
        model_sub_path="inception_v4_299_quant.tflite",
        input_dict={"input": (1, 299, 299, 3)},
        compile_hash=expected_hash,
        output_count=1,
        host_ops=3,
        npu_partitions=1,
    )
172
173
def test_ssd_mobilenet_v1():
    # The hashes guard against changes in codegen behaviour, which can be
    # caused by a different Support Library version or by a change to the
    # Ethos-N codegen itself. Updating them requires running on hardware that
    # isn't available in CI, so if this test fails on a hash mismatch please
    # notify @mbaret and @Leo-arm.
    url = (
        "https://storage.googleapis.com/download.tensorflow.org/"
        "models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip"
    )
    # A set of hashes is passed for this network rather than a single string.
    expected_hashes = {
        "29aec6b184b09454b4323271aadf89b1",
        "6211d96103880b016baa85e638abddef",
    }
    _test_image_network(
        model_url=url,
        model_sub_path="detect.tflite",
        input_dict={"normalized_input_image_tensor": (1, 300, 300, 3)},
        compile_hash=expected_hashes,
        output_count=4,
        host_ops=28,
        npu_partitions=2,
    )
190