# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import os
import tarfile

import numpy as np
import pytest

from PIL import Image

from tvm.driver import tvmc

from tvm.contrib.download import download_testdata

# Support functions

# Base URL shared by the TensorFlow-hosted MobileNet archives used below.
_TF_MODELS_BASE_URL = "https://storage.googleapis.com/download.tensorflow.org/models"


def _safe_extractall(tar, path):
    """Extract every member of ``tar`` into ``path``.

    Member names that would resolve outside ``path`` (absolute paths or
    ``..`` components) are rejected before anything is written, guarding
    against path-traversal entries in downloaded archives. Only member
    names are checked, not the targets of symlink/hardlink members.

    Raises
    ------
    RuntimeError
        If a member would be extracted outside ``path``.
    """
    base = os.path.realpath(path)
    for member in tar.getmembers():
        destination = os.path.realpath(os.path.join(base, member.name))
        if os.path.commonpath([base, destination]) != base:
            raise RuntimeError(
                f"Refusing to extract {member.name!r}: it resolves outside {base}"
            )
    tar.extractall(path=path)


def download_and_untar(model_url, model_sub_path, temp_dir):
    """Download ``model_url`` and unpack it into ``temp_dir`` if it is gzipped.

    Parameters
    ----------
    model_url : str
        URL of the file to fetch; its basename is used as the local file name
        passed to ``download_testdata``.
    model_sub_path : str
        Path, relative to ``temp_dir``, of the file the caller wants.
    temp_dir : str or os.PathLike
        Directory the archive is extracted into.

    Returns
    -------
    str
        ``os.path.join(temp_dir, model_sub_path)``. This path is returned even
        when nothing was extracted, so callers should not assume it exists.
    """
    model_tar_name = os.path.basename(model_url)
    model_path = download_testdata(model_url, model_tar_name, module=["tvmc"])

    # A "gz" suffix covers both ".tgz" and ".tar.gz" archives.
    if model_path.endswith("gz"):
        # The context manager closes the archive even if extraction fails.
        with tarfile.open(model_path) as tar:
            _safe_extractall(tar, temp_dir)

    return os.path.join(temp_dir, model_sub_path)


def _download_tflite_mobilenet_v1_1_quant(target_dir):
    """Download the quantized MobileNet v1 (1.0, 224) TFLite model into ``target_dir``.

    Returns the path of the extracted ``.tflite`` file.
    """
    return download_and_untar(
        f"{_TF_MODELS_BASE_URL}/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz",
        "mobilenet_v1_1.0_224_quant.tflite",
        temp_dir=target_dir,
    )


def get_sample_compiled_module(target_dir):
    """Support function that returns a TFLite compiled module.

    Downloads the quantized MobileNet v1 TFLite model into ``target_dir`` and
    returns whatever ``tvmc.compiler.compile_model`` produces for the
    ``llvm`` target.
    """
    model_file = _download_tflite_mobilenet_v1_1_quant(target_dir)
    return tvmc.compiler.compile_model(model_file, target="llvm")


# PyTest fixtures


@pytest.fixture(scope="session")
def tflite_mobilenet_v1_1_quant(tmpdir_factory):
    """Path to the quantized MobileNet v1 (1.0, 224) TFLite model."""
    return _download_tflite_mobilenet_v1_1_quant(tmpdir_factory.mktemp("data"))


@pytest.fixture(scope="session")
def pb_mobilenet_v1_1_quant(tmpdir_factory):
    """Path to the MobileNet v1 (1.0, 224) frozen TensorFlow graph (``.pb``).

    NOTE(review): unlike the TFLite archive, the downloaded archive name has
    no ``_quant`` suffix, so this may be the non-quantized model; the fixture
    name is kept as-is for compatibility with existing tests.
    """
    return download_and_untar(
        f"{_TF_MODELS_BASE_URL}/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz",
        "mobilenet_v1_1.0_224_frozen.pb",
        temp_dir=tmpdir_factory.mktemp("data"),
    )


@pytest.fixture(scope="session")
def keras_resnet50(tmpdir_factory):
    """Path to a Keras ResNet50 (ImageNet weights) saved as ``resnet50.h5``.

    Returns an empty string when TensorFlow cannot be imported.
    """
    try:
        from tensorflow.keras.applications.resnet50 import ResNet50
    except ImportError:
        # Not all environments provide TensorFlow, so skip this fixture
        # if that is the case.
        return ""

    model_file_name = os.path.join(tmpdir_factory.mktemp("data"), "resnet50.h5")
    model = ResNet50(include_top=True, weights="imagenet", input_shape=(224, 224, 3), classes=1000)
    model.save(model_file_name)

    return model_file_name


@pytest.fixture(scope="session")
def onnx_resnet50():
    """Path to the ResNet50 v2 (opset 7) ONNX model from the ONNX model zoo."""
    base_url = "https://github.com/onnx/models/raw/master/vision/classification/resnet/model"
    file_to_download = "resnet50-v2-7.onnx"
    return download_testdata(f"{base_url}/{file_to_download}", file_to_download, module=["tvmc"])


@pytest.fixture(scope="session")
def tflite_compiled_module_as_tarfile(tmpdir_factory):
    """Path to ``mock.tar``, the sample TFLite model compiled and saved by tvmc.

    Returns an empty string when the ``tflite`` package cannot be imported.
    """
    # Not all CI environments will have TFLite installed, so this fixture
    # must not crash the tests that rely on it. Rather than raising, it
    # returns an empty path when the import fails (the same convention as
    # the keras_resnet50 fixture).
    try:
        import tflite  # pylint: disable=unused-import
    except ImportError:
        print("Cannot import tflite, which is required by tflite_compiled_module_as_tarfile.")
        return ""

    target_dir = tmpdir_factory.mktemp("data")
    graph, lib, params, _ = get_sample_compiled_module(target_dir)

    module_file = os.path.join(target_dir, "mock.tar")
    tvmc.compiler.save_module(module_file, graph, lib, params)

    return module_file


@pytest.fixture(scope="session")
def imagenet_cat(tmpdir_factory):
    """Path to an ``.npz`` file holding the sample cat image under the key ``input``.

    The image is resized to 224x224, cast to float32 and given a leading
    axis of size 1 via ``np.expand_dims``.
    """
    tmpdir_name = tmpdir_factory.mktemp("data")
    cat_file_name = "imagenet_cat.npz"

    cat_url = "https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true"
    image_path = download_testdata(cat_url, "inputs", module=["tvmc"])
    # Close the source image file once the resized in-memory copy exists.
    with Image.open(image_path) as image:
        resized_image = image.resize((224, 224))
    image_data = np.asarray(resized_image).astype("float32")
    image_data = np.expand_dims(image_data, axis=0)

    cat_file_full_path = os.path.join(tmpdir_name, cat_file_name)
    np.savez(cat_file_full_path, input=image_data)

    return cat_file_full_path