1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17import os
18import pytest
19import tarfile
20
21import numpy as np
22
23from PIL import Image
24
25from tvm.driver import tvmc
26
27from tvm.contrib.download import download_testdata
28
29# Support functions
30
31
def download_and_untar(model_url, model_sub_path, temp_dir):
    """Download a model file and, if it is a gzipped tarball, extract it.

    The file at ``model_url`` is fetched through ``download_testdata`` into
    the ``tvmc`` test-data module, using the URL's basename as the cached
    file name. When that file name ends in ``gz`` (which covers ``.tgz`` as
    well as ``.tar.gz``), it is opened as a tar archive and fully extracted
    into ``temp_dir``.

    Parameters
    ----------
    model_url : str
        URL of the file to download.
    model_sub_path : str
        Path, relative to ``temp_dir``, of the file of interest inside the
        extracted archive.
    temp_dir : str or os.PathLike
        Directory the archive is extracted into.

    Returns
    -------
    str
        ``os.path.join(temp_dir, model_sub_path)``. Existence of this path is
        not checked; for downloads that are not ``gz`` archives nothing is
        written into ``temp_dir`` by this function.
    """
    model_tar_name = os.path.basename(model_url)
    model_path = download_testdata(model_url, model_tar_name, module=["tvmc"])

    # "tgz" also ends with "gz", so a single suffix check covers both forms.
    if model_path.endswith("gz"):
        # The context manager guarantees the archive is closed even if
        # extraction raises part-way through.
        with tarfile.open(model_path) as tar:
            tar.extractall(path=temp_dir)

    return os.path.join(temp_dir, model_sub_path)
42
43
def get_sample_compiled_module(target_dir):
    """Support function that compiles the quantized MobileNet V1 TFLite model.

    The model archive is downloaded and unpacked into ``target_dir`` by
    ``download_and_untar``. The extracted ``.tflite`` file is then compiled
    for the ``llvm`` target with ``tvmc.compiler.compile_model``.

    Parameters
    ----------
    target_dir : str or os.PathLike
        Directory that receives the extracted model files.

    Returns
    -------
    The value returned by ``tvmc.compiler.compile_model``. Callers in this
    file unpack it as a 4-tuple.
    """
    tf_models_root = "https://storage.googleapis.com/download.tensorflow.org/models"
    archive = "mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz"
    archive_url = "/".join((tf_models_root, archive))

    tflite_path = download_and_untar(
        archive_url,
        "mobilenet_v1_1.0_224_quant.tflite",
        temp_dir=target_dir,
    )
    compiled = tvmc.compiler.compile_model(tflite_path, target="llvm")
    return compiled
55
56
57# PyTest fixtures
58
59
@pytest.fixture(scope="session")
def tflite_mobilenet_v1_1_quant(tmpdir_factory):
    """Session fixture: path to the quantized MobileNet V1 (1.0, 224) ``.tflite`` file.

    The TensorFlow-hosted ``.tgz`` archive is downloaded and extracted into a
    fresh ``data`` temporary directory via ``download_and_untar``.
    """
    archive_url = "/".join(
        (
            "https://storage.googleapis.com/download.tensorflow.org/models",
            "mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz",
        )
    )
    extract_dir = tmpdir_factory.mktemp("data")
    tflite_path = download_and_untar(
        archive_url,
        "mobilenet_v1_1.0_224_quant.tflite",
        temp_dir=extract_dir,
    )
    return tflite_path
71
72
@pytest.fixture(scope="session")
def pb_mobilenet_v1_1_quant(tmpdir_factory):
    """Session fixture: path to a MobileNet V1 (1.0, 224) frozen TensorFlow graph.

    Downloads ``mobilenet_v1_1.0_224.tgz``, extracts it into a fresh ``data``
    temporary directory, and returns the path of
    ``mobilenet_v1_1.0_224_frozen.pb`` inside it.

    NOTE(review): despite the ``_quant`` suffix in the fixture name, the
    archive fetched here has no ``_quant`` in its name. Confirm whether a
    quantized graph was intended.
    """
    archive_url = "/".join(
        (
            "https://storage.googleapis.com/download.tensorflow.org/models",
            "mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz",
        )
    )
    extract_dir = tmpdir_factory.mktemp("data")
    frozen_graph_path = download_and_untar(
        archive_url,
        "mobilenet_v1_1.0_224_frozen.pb",
        temp_dir=extract_dir,
    )
    return frozen_graph_path
84
85
@pytest.fixture(scope="session")
def keras_resnet50(tmpdir_factory):
    """Session fixture: path to a Keras ResNet50 model saved as ``resnet50.h5``.

    Builds ``tensorflow.keras.applications.resnet50.ResNet50`` with
    ``weights="imagenet"``, a ``(224, 224, 3)`` input shape and 1000 classes.
    It then saves the model into a fresh ``data`` temporary directory.

    Returns an empty string when TensorFlow cannot be imported.
    """
    try:
        from tensorflow.keras.applications.resnet50 import ResNet50
    except ImportError:
        # Not all environments provide TensorFlow, so skip this fixture
        # if that is the case.
        return ""

    # os.path.join instead of a hard-coded "/" separator keeps the path portable.
    model_file_name = os.path.join(tmpdir_factory.mktemp("data"), "resnet50.h5")
    model = ResNet50(include_top=True, weights="imagenet", input_shape=(224, 224, 3), classes=1000)
    model.save(model_file_name)

    return model_file_name
100
101
@pytest.fixture(scope="session")
def onnx_resnet50():
    """Session fixture: local path of ``resnet50-v2-7.onnx``.

    The model is fetched from the ``onnx/models`` GitHub repository through
    ``download_testdata`` and cached under the ``tvmc`` test-data module.
    """
    onnx_file = "resnet50-v2-7.onnx"
    source_url = "/".join(
        (
            "https://github.com/onnx/models/raw/master/vision/classification/resnet/model",
            onnx_file,
        )
    )
    return download_testdata(source_url, onnx_file, module=["tvmc"])
111
112
@pytest.fixture(scope="session")
def tflite_compiled_module_as_tarfile(tmpdir_factory):
    """Session fixture: path to ``mock.tar`` containing a compiled TFLite MobileNet.

    The module comes from ``get_sample_compiled_module`` and is written with
    ``tvmc.compiler.save_module``. When the ``tflite`` package cannot be
    imported, a notice is printed and an empty string is returned.
    """
    # Some CI environments lack TFLite. Instead of pytest.importorskip, probe
    # the import directly and hand back an empty path, so this fixture does
    # not crash the tests that depend on it.
    try:
        import tflite  # noqa: F401  (imported only to check availability)
    except ImportError:
        print("Cannot import tflite, which is required by tflite_compiled_module_as_tarfile.")
        return ""

    out_dir = tmpdir_factory.mktemp("data")
    graph, lib, params, _ = get_sample_compiled_module(out_dir)

    tar_path = os.path.join(out_dir, "mock.tar")
    tvmc.compiler.save_module(tar_path, graph, lib, params)
    return tar_path
134
135
@pytest.fixture(scope="session")
def imagenet_cat(tmpdir_factory):
    """Session fixture: path to ``imagenet_cat.npz`` holding a resized cat image.

    Downloads the cat picture from the ``dmlc/mxnet.js`` repository and
    resizes it to 224x224. It is then converted to ``float32``, given a
    leading batch axis of size 1, and saved under the key ``input`` in a
    fresh ``data`` temporary directory.
    The channel count follows the downloaded image's mode; presumably RGB
    (3 channels), but that is not enforced here.
    """
    tmpdir_name = tmpdir_factory.mktemp("data")
    cat_file_name = "imagenet_cat.npz"

    cat_url = "https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true"
    image_path = download_testdata(cat_url, "inputs", module=["tvmc"])
    # Image.open opens the file lazily and keeps the handle until the image
    # is closed. resize() loads the pixels into a new independent image, so
    # the source file can be closed right after.
    with Image.open(image_path) as image:
        resized_image = image.resize((224, 224))
    image_data = np.asarray(resized_image).astype("float32")
    # Prepend a batch dimension of size 1.
    image_data = np.expand_dims(image_data, axis=0)

    cat_file_full_path = os.path.join(tmpdir_name, cat_file_name)
    np.savez(cat_file_full_path, input=image_data)

    return cat_file_full_path
151