# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""Ethos-N integration end-to-end network tests"""

import pytest

pytest.importorskip("tflite")
pytest.importorskip("tensorflow")

from tvm import relay
from tvm.relay.op.contrib.ethosn import ethosn_available
from tvm.contrib import download
import tvm.relay.testing.tf as tf_testing
import tflite.Model
from . import infrastructure as tei


def _get_tflite_model(tflite_model_path, inputs_dict, dtype):
    """Load a TFLite flatbuffer from disk and convert it to Relay.

    Parameters
    ----------
    tflite_model_path : str
        Path to the ``.tflite`` model file.
    inputs_dict : dict
        Mapping of input tensor name to input shape.
    dtype : str
        Data type used for every input (e.g. "uint8").

    Returns
    -------
    tuple
        The ``(mod, params)`` pair produced by ``relay.frontend.from_tflite``.
    """
    with open(tflite_model_path, "rb") as f:
        tflite_model_buffer = f.read()

    # Depending on the installed tflite package, GetRootAsModel is reachable
    # either via the nested Model class or at module level; try both.
    try:
        tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buffer, 0)
    except AttributeError:
        tflite_model = tflite.Model.GetRootAsModel(tflite_model_buffer, 0)

    # Shapes are passed through unchanged; every input gets the same dtype.
    shape_dict = dict(inputs_dict)
    dtype_dict = {name: dtype for name in inputs_dict}

    return relay.frontend.from_tflite(
        tflite_model,
        shape_dict=shape_dict,
        dtype_dict=dtype_dict,
    )


def _test_image_network(
    model_url,
    model_sub_path,
    input_dict,
    compile_hash,
    output_count,
    host_ops=0,
    npu_partitions=1,
    run=False,
):
    """Test an image network.

    The test is skipped (rather than silently passing) when Ethos-N support
    is not available.

    Parameters
    ----------
    model_url : str
        The URL to the model.
    model_sub_path : str
        The name of the model file.
    input_dict : dict
        The input dict, mapping input name to input shape.
    compile_hash : str, set
        The compile hash(es) to check the compilation output against.
    output_count : int
        The expected number of outputs.
    host_ops : int
        The expected number of host operators.
    npu_partitions : int
        The expected number of Ethos-N partitions.
    run : bool
        Whether or not to try running the network. If hardware isn't
        available, the run will still take place but with a mocked
        inference function, so the results will be incorrect. This is
        therefore just to test the runtime flow is working rather than
        to check the correctness/accuracy.

    """
    if not ethosn_available():
        # Report the test as skipped instead of letting it pass vacuously.
        pytest.skip("Ethos-N support is not available")

    def get_model():
        # Archive URLs (.tgz/.zip) go through the TF testing helper, which is
        # given the name of the model file inside the archive; any other URL
        # is downloaded directly as the model file.
        if model_url.endswith(("tgz", "zip")):
            model_path = tf_testing.get_workload_official(
                model_url,
                model_sub_path,
            )
        else:
            model_path = download.download_testdata(
                model_url,
                model_sub_path,
            )
        return _get_tflite_model(model_path, input_dict, "uint8")

    mod, params = get_model()
    m = tei.build(mod, params, npu=True, expected_host_ops=host_ops, npu_partitions=npu_partitions)
    tei.assert_lib_hash(m.get_lib(), compile_hash)
    if run:
        # Input images are only needed for running, so build them here.
        # Indices 1 and 2 of each shape are passed as the image dimensions
        # (assumes NHWC shapes such as (1, 224, 224, 3) - confirm for new models).
        inputs = {
            name: tei.get_real_image(shape[1], shape[2]) for name, shape in input_dict.items()
        }
        tei.run(m, inputs, output_count, npu=True)


def test_mobilenet_v1():
    # If this test is failing due to a hash mismatch, please notify @mbaret and
    # @Leo-arm. The hash is there to catch any changes in the behaviour of the
    # codegen, which could come about from either a change in Support Library
    # version or a change in the Ethos-N codegen. To update this requires running
    # on hardware that isn't available in CI.
    _test_image_network(
        model_url="https://storage.googleapis.com/download.tensorflow.org/"
        "models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz",
        model_sub_path="mobilenet_v1_1.0_224_quant.tflite",
        input_dict={"input": (1, 224, 224, 3)},
        compile_hash="81637c89339201a07dc96e3b5dbf836a",
        output_count=1,
        host_ops=3,
        npu_partitions=1,
        run=True,
    )


def test_inception_v3():
    # If this test is failing due to a hash mismatch, please notify @mbaret and
    # @Leo-arm. The hash is there to catch any changes in the behaviour of the
    # codegen, which could come about from either a change in Support Library
    # version or a change in the Ethos-N codegen. To update this requires running
    # on hardware that isn't available in CI.
    _test_image_network(
        model_url="https://storage.googleapis.com/download.tensorflow.org/"
        "models/tflite_11_05_08/inception_v3_quant.tgz",
        model_sub_path="inception_v3_quant.tflite",
        input_dict={"input": (1, 299, 299, 3)},
        compile_hash="de0e175af610ebd45ccb03d170dc9664",
        output_count=1,
        host_ops=0,
        npu_partitions=1,
    )


def test_inception_v4():
    # If this test is failing due to a hash mismatch, please notify @mbaret and
    # @Leo-arm. The hash is there to catch any changes in the behaviour of the
    # codegen, which could come about from either a change in Support Library
    # version or a change in the Ethos-N codegen. To update this requires running
    # on hardware that isn't available in CI.
    _test_image_network(
        model_url="https://storage.googleapis.com/download.tensorflow.org/"
        "models/inception_v4_299_quant_20181026.tgz",
        model_sub_path="inception_v4_299_quant.tflite",
        input_dict={"input": (1, 299, 299, 3)},
        compile_hash="06bf6cb56344f3904bcb108e54edfe87",
        output_count=1,
        host_ops=3,
        npu_partitions=1,
    )


def test_ssd_mobilenet_v1():
    # If this test is failing due to a hash mismatch, please notify @mbaret and
    # @Leo-arm. The hash is there to catch any changes in the behaviour of the
    # codegen, which could come about from either a change in Support Library
    # version or a change in the Ethos-N codegen. To update this requires running
    # on hardware that isn't available in CI.
    # Two hashes are accepted for this network (compile_hash may be a set).
    _test_image_network(
        model_url="https://storage.googleapis.com/download.tensorflow.org/"
        "models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip",
        model_sub_path="detect.tflite",
        input_dict={"normalized_input_image_tensor": (1, 300, 300, 3)},
        compile_hash={"29aec6b184b09454b4323271aadf89b1", "6211d96103880b016baa85e638abddef"},
        output_count=4,
        host_ops=28,
        npu_partitions=2,
    )