# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import sys
import os
import numpy as np
import mxnet as mx
from mxnet.test_utils import *
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.insert(0, os.path.join(curr_path, '../unittest'))
from common import setup_module, with_seed, teardown
from mxnet.gluon import utils
import tarfile


def _get_model():
    """Download and unpack the Inception-v3 checkpoint if it is not present.

    The archive is fetched only when 'model/Inception-7-symbol.json' does
    not exist; it is then extracted into the current working directory.
    """
    if not os.path.exists('model/Inception-7-symbol.json'):
        download('http://data.mxnet.io/models/imagenet/inception-v3.tar.gz')
        # NOTE(review): the archive comes over plain HTTP with no checksum
        # and is unpacked with an unrestricted extractall(); consider
        # verifying a hash or using an extraction filter where available.
        with tarfile.open(name="inception-v3.tar.gz", mode="r:gz") as tf:
            tf.extractall()


def _dump_images(shape):
    """Build the test-image array from 'data/test_images/' and save it.

    Every file in the directory (in sorted name order) is read,
    center-cropped to a square whose side is its shorter edge, and resized
    to ``shape``. The stack is stored as float32 with the last axis moved
    to position 1 and 128 subtracted, in
    'data/test_images_<shape[0]>_<shape[1]>.npy'.

    Parameters
    ----------
    shape : tuple of int
        Two-element target size passed to ``skimage.transform.resize``.
    """
    import skimage.io
    import skimage.transform
    img_list = []
    for fname in sorted(os.listdir('data/test_images/')):
        img = skimage.io.imread('data/test_images/' + fname)
        # Center crop to a square with the length of the shorter side.
        short_edge = min(img.shape[:2])
        yy = int((img.shape[0] - short_edge) / 2)
        xx = int((img.shape[1] - short_edge) / 2)
        img = img[yy: yy + short_edge, xx: xx + short_edge]
        img = skimage.transform.resize(img, shape)
        img_list.append(img)
    # The 4-axis transpose presumes every image has three dimensions
    # (assumed height, width, channel -- TODO confirm for the test set).
    # NOTE(review): resize() may already rescale values to [0, 1]; confirm
    # that subtracting 128 is the intended normalization.
    imgs = np.asarray(img_list, dtype=np.float32).transpose((0, 3, 1, 2)) - 128
    np.save('data/test_images_%d_%d.npy' % shape, imgs)


def _get_data(shape):
    """Download the test images and the reference output dump into 'data/'.

    Both downloads are given a SHA-1 hash that is passed on to
    ``mxnet.gluon.utils.download``.

    Parameters
    ----------
    shape : tuple of int
        Two-element image size used to select the test-image file name.
    """
    hash_test_img = "355e15800642286e7fe607d87c38aeeab085b0cc"
    hash_inception_v3 = "91807dfdbd336eb3b265dd62c2408882462752b9"
    img_file = "test_images_%d_%d.npy" % (shape)
    utils.download("http://data.mxnet.io/data/" + img_file,
                   path="data/" + img_file,
                   sha1_hash=hash_test_img)
    utils.download("http://data.mxnet.io/data/inception-v3-dump.npz",
                   path='data/inception-v3-dump.npz',
                   sha1_hash=hash_inception_v3)


@with_seed()
def test_consistency(dump=False):
    """Run Inception-v3 forward on GPU and CPU through ``check_consistency``.

    Parameters
    ----------
    dump : bool
        When False (default), the arrays stored in
        'data/inception-v3-dump.npz' are passed as ``ground_truth``.
        When True, the test images are regenerated from 'data/test_images/',
        no ground truth is supplied, and the value returned by
        ``check_consistency`` is written back to that ``.npz`` file.
    """
    shape = (299, 299)
    _get_model()
    _get_data(shape)
    if dump:
        _dump_images(shape)
        gt = None
    else:
        # np.load on an .npz keeps the file open until closed; copy the
        # arrays into NDArrays inside the context so the handle is released.
        with np.load('data/inception-v3-dump.npz') as dump_file:
            gt = {n: mx.nd.array(a) for n, a in dump_file.items()}
    data = np.load('data/test_images_%d_%d.npy' % shape)
    sym, arg_params, aux_params = mx.model.load_checkpoint('model/Inception-7', 1)
    arg_params['data'] = data
    # grad_req is 'null' below, so the labels are random placeholders.
    arg_params['softmax_label'] = np.random.randint(low=1, high=1000, size=(data.shape[0],))
    ctx_list = [{'ctx': mx.gpu(0), 'data': data.shape, 'type_dict': {'data': data.dtype}},
                {'ctx': mx.cpu(0), 'data': data.shape, 'type_dict': {'data': data.dtype}}]
    gt = check_consistency(sym, ctx_list, arg_params=arg_params, aux_params=aux_params,
                           rtol=1e-3, atol=1e-3, grad_req='null', raise_on_err=False, ground_truth=gt)
    if dump:
        np.savez('data/inception-v3-dump.npz', **{n: a.asnumpy() for n, a in gt.items()})


if __name__ == '__main__':
    test_consistency(False)