1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18import sys
19import os
20import numpy as np
21import mxnet as mx
22from mxnet.test_utils import *
23curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
24sys.path.insert(0, os.path.join(curr_path, '../unittest'))
25from common import setup_module, with_seed, teardown
26from mxnet.gluon import utils
27import tarfile
28
def _get_model():
    """Fetch and unpack the Inception-v3 checkpoint if it is not present yet.

    The file ``model/Inception-7-symbol.json`` (relative to the current working
    directory) serves as the marker that the archive was already extracted.
    When it is missing, the tarball is downloaded and ``inception-v3.tar.gz``
    is extracted into the current working directory.
    """
    # Nothing to do when the symbol file from a previous extraction exists.
    if os.path.exists('model/Inception-7-symbol.json'):
        return
    download('http://data.mxnet.io/models/imagenet/inception-v3.tar.gz')
    with tarfile.open(name="inception-v3.tar.gz", mode="r:gz") as archive:
        archive.extractall()
34
def _dump_images(shape):
    """Build the ``.npy`` image batch from the files in ``data/test_images/``.

    Every file in the directory (visited in sorted name order) is read,
    cropped to a centered square whose side equals the image's shorter
    dimension, and resized to ``shape``. The resulting stack is cast to
    float32, transposed with axes ``(0, 3, 1, 2)``, shifted by -128 and saved
    to ``data/test_images_<shape[0]>_<shape[1]>.npy``.

    Parameters
    ----------
    shape : tuple of two ints
        Target size handed to ``skimage.transform.resize``; it is also
        formatted into the output file name.
    """
    import skimage.io
    import skimage.transform

    def _center_square(image):
        # Keep the largest centered square of the first two axes.
        side = min(image.shape[:2])
        top = int((image.shape[0] - side) / 2)
        left = int((image.shape[1] - side) / 2)
        return image[top : top + side, left : left + side]

    source_dir = 'data/test_images/'
    resized = [
        skimage.transform.resize(
            _center_square(skimage.io.imread(source_dir + file_name)), shape)
        for file_name in sorted(os.listdir(source_dir))
    ]
    batch = np.asarray(resized, dtype=np.float32).transpose((0, 3, 1, 2)) - 128
    np.save('data/test_images_%d_%d.npy' % shape, batch)
49
def _get_data(shape):
    """Download the pre-built test image batch and the Inception-v3 reference dump.

    Both files are fetched through ``utils.download`` into the ``data/``
    directory, each with its expected SHA-1 hash supplied.

    Parameters
    ----------
    shape : tuple of two ints
        Image size; formatted into the name of the image batch file.
    """
    # (url, local path, expected sha1) for every required artifact.
    artifacts = (
        ("http://data.mxnet.io/data/test_images_%d_%d.npy" % shape,
         "data/test_images_%d_%d.npy" % shape,
         "355e15800642286e7fe607d87c38aeeab085b0cc"),
        ("http://data.mxnet.io/data/inception-v3-dump.npz",
         'data/inception-v3-dump.npz',
         "91807dfdbd336eb3b265dd62c2408882462752b9"),
    )
    for url, local_path, sha1 in artifacts:
        utils.download(url, path=local_path, sha1_hash=sha1)
59
@with_seed()
def test_consistency(dump=False):
    """Run Inception-v3 on ``mx.gpu(0)`` and ``mx.cpu(0)`` and compare the outputs.

    The checkpoint and test data are fetched first. In normal mode the stored
    dump ``data/inception-v3-dump.npz`` is loaded and passed to
    ``check_consistency`` as ``ground_truth``. In dump mode the image batch is
    rebuilt from ``data/test_images/`` and the value returned by
    ``check_consistency`` is written back to ``data/inception-v3-dump.npz``.

    Parameters
    ----------
    dump : bool, optional
        When True, regenerate the image batch and overwrite the reference dump
        instead of comparing against it.
    """
    shape = (299, 299)
    _get_model()
    _get_data(shape)
    if dump:
        _dump_images(shape)
        gt = None
    else:
        # np.load on an .npz returns a lazy NpzFile that keeps the archive open;
        # convert every array while it is open and close the handle afterwards.
        with np.load('data/inception-v3-dump.npz') as reference:
            gt = {n: mx.nd.array(a) for n, a in reference.items()}
    data = np.load('data/test_images_%d_%d.npy'%shape)
    sym, arg_params, aux_params = mx.model.load_checkpoint('model/Inception-7', 1)
    arg_params['data'] = data
    # Labels are random integers in [1, 1000), one per image in the batch.
    arg_params['softmax_label'] = np.random.randint(low=1, high=1000, size=(data.shape[0],))
    ctx_list = [{'ctx': mx.gpu(0), 'data': data.shape, 'type_dict': {'data': data.dtype}},
                {'ctx': mx.cpu(0), 'data': data.shape, 'type_dict': {'data': data.dtype}}]
    # NOTE(review): raise_on_err=False presumably makes mismatches get reported
    # rather than raised, so this test may not fail on inconsistency -- confirm.
    gt = check_consistency(sym, ctx_list, arg_params=arg_params, aux_params=aux_params,
                           rtol=1e-3, atol=1e-3, grad_req='null', raise_on_err=False, ground_truth=gt)
    if dump:
        np.savez('data/inception-v3-dump.npz', **{n: a.asnumpy() for n, a in gt.items()})
80
# Script entry point: run the check against the stored reference dump.
if __name__ == '__main__':
    test_consistency(dump=False)
83