1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18import mxnet as mx
19import numpy as np
20import scipy.ndimage
21from mxnet.test_utils import *
22from common import assertRaises, with_seed
23import shutil
24import tempfile
25import unittest
26
27from nose.tools import raises
28
29
def _get_data(url, dirname):
    """Fetch a tar archive from *url* into *dirname* and extract it there.

    Returns the list of paths (joined onto *dirname*) of the regular files
    contained in the archive. Extraction is skipped only when every one of
    those files already exists, so an interrupted earlier extraction is
    repaired instead of silently leaving files missing.

    NOTE(review): ``TarFile.extractall`` trusts the member paths stored in the
    archive; that is only acceptable because *url* is a fixed, trusted source.
    """
    import os
    import tarfile
    download(url, dirname=dirname, overwrite=False)
    fname = os.path.join(dirname, url.split('/')[-1])
    # Context manager: the archive is closed even if extraction raises.
    with tarfile.open(fname) as tar:
        source_images = [os.path.join(dirname, member.name)
                         for member in tar.getmembers() if member.isfile()]
        already_extracted = bool(source_images) and all(
            os.path.isfile(path) for path in source_images)
        if not already_extracted:
            tar.extractall(path=dirname)
    return source_images
41
42def _generate_objects():
43    num = np.random.randint(1, 10)
44    xy = np.random.rand(num, 2)
45    wh = np.random.rand(num, 2) / 2
46    left = (xy[:, 0] - wh[:, 0])[:, np.newaxis]
47    right = (xy[:, 0] + wh[:, 0])[:, np.newaxis]
48    top = (xy[:, 1] - wh[:, 1])[:, np.newaxis]
49    bot = (xy[:, 1] + wh[:, 1])[:, np.newaxis]
50    boxes = np.maximum(0., np.minimum(1., np.hstack((left, top, right, bot))))
51    cid = np.random.randint(0, 20, size=num)
52    label = np.hstack((cid[:, np.newaxis], boxes)).ravel().tolist()
53    return [2, 5] + label
54
def _count_batches(data_iter):
    """Exhaust *data_iter* and return how many batches it produced."""
    return sum(1 for _ in data_iter)


def _test_imageiter_last_batch(imageiter_list, assert_data_shape):
    """Check batch shapes and ``last_batch_handle`` behavior of image iterators.

    Parameters
    ----------
    imageiter_list : list
        Five iterators over the same images, in the order the callers build
        them: [0] default handling (its batch shape is checked),
        [1] batch size 3 with ``'discard'``, [2] batch size 3 with ``'pad'``,
        [3] batch size 3 with ``'roll_over'``, [4] batch size 3 with
        ``shuffle=True`` and ``'pad'``.
    assert_data_shape : tuple
        Expected ``batch.data[0].shape`` for every batch of iterator [0].

    The hard-coded counts (5 full batches of 3 plus a single leftover image)
    correspond to a 16-image fixture; they must be updated if the image set
    changes.
    """
    # [0]: every batch has the expected shape, over several reset epochs.
    test_iter = imageiter_list[0]
    for _ in range(3):
        for batch in test_iter:
            assert batch.data[0].shape == assert_data_shape
        test_iter.reset()

    # [1] discard: the incomplete trailing batch is dropped.
    assert _count_batches(imageiter_list[1]) == 5

    # [2] pad: the trailing batch is completed with samples from the start of
    # the data, so its last two samples equal the first two of batch 0.
    test_iter = imageiter_list[2]
    num_batches = 0
    for batch in test_iter:
        if num_batches == 0:
            first_batch_head = batch.data[0][:2]
        if num_batches == 5:
            last_batch_tail = batch.data[0][1:]
        num_batches += 1
    assert num_batches == 6
    assert np.array_equal(first_batch_head.asnumpy(), last_batch_tail.asnumpy())

    # [3] roll_over: samples left over at the end of an epoch are carried into
    # the first batch of the next epoch.
    test_iter = imageiter_list[3]
    num_batches = 0
    for batch in test_iter:
        if num_batches == 0:
            first_image = batch.data[0][0]
        num_batches += 1
    assert num_batches == 5
    test_iter.reset()
    first_batch_roll_over = test_iter.next()
    # One carried-over sample comes first, then the first image of the epoch.
    assert np.array_equal(
        first_batch_roll_over.data[0][1].asnumpy(), first_image.asnumpy())
    assert first_batch_roll_over.pad == 2

    # The iterator must keep working after repeated resets with roll_over.
    for _ in test_iter:
        pass
    test_iter.reset()
    first_batch_roll_over_twice = test_iter.next()
    # Now two carried-over samples precede the first image.
    assert np.array_equal(
        first_batch_roll_over_twice.data[0][2].asnumpy(), first_image.asnumpy())
    assert first_batch_roll_over_twice.pad == 1
    # next() was already called once; the third epoch has 6 batches in total.
    assert 1 + _count_batches(test_iter) == 6

    # [4] shuffle: sanity check that a full pass runs without error.
    _count_batches(imageiter_list[4])
110
111
class TestImage(unittest.TestCase):
    """Tests for the ``mx.image`` module.

    Covers decoding (``imread``/``imdecode``), resizing, border padding, color
    normalization, augmenters, rotation and the ``ImageIter``/``ImageDetIter``
    iterators. Most tests loop over a small set of real images that the
    class-level fixture downloads into a temporary directory.
    """
    IMAGES_URL = "http://data.mxnet.io/data/test_images.tar.gz"
    # Paths of the extracted test images; populated by setUpClass.
    IMAGES = []
    # Temporary directory holding the images and any generated .lst files;
    # created by setUpClass and removed by tearDownClass.
    IMAGES_DIR = None

    @classmethod
    def setUpClass(cls):
        """Download and extract the test images into a new temporary directory."""
        cls.IMAGES_DIR = tempfile.mkdtemp()
        try:
            cls.IMAGES = _get_data(cls.IMAGES_URL, cls.IMAGES_DIR)
        except Exception:
            # tearDownClass does not run when class setup fails, so remove the
            # temporary directory here before propagating the error.
            shutil.rmtree(cls.IMAGES_DIR, ignore_errors=True)
            cls.IMAGES_DIR = None
            raise
        print("Loaded {} images".format(len(cls.IMAGES)))

    @classmethod
    def tearDownClass(cls):
        """Delete the temporary image directory, if one was created."""
        if cls.IMAGES_DIR:
            print("cleanup {}".format(cls.IMAGES_DIR))
            shutil.rmtree(cls.IMAGES_DIR)

    # The fixtures used to be named only ``setupClass``/``teardownClass``,
    # which nose recognises but unittest and pytest do not; under those runners
    # IMAGES stayed empty and the per-image tests checked nothing. The legacy
    # names are kept as aliases (nose calls the fixture once either way).
    setupClass = setUpClass
    teardownClass = tearDownClass

    @raises(mx.base.MXNetError)
    def test_imread_not_found(self):
        """Reading a non-existent file raises MXNetError."""
        mx.img.image.imread("/139810923jadjsajlskd.___adskj/blah.jpg")

    def test_imread_vs_imdecode(self):
        """imread(path) and imdecode(file bytes) decode to identical pixels."""
        for img in TestImage.IMAGES:
            with open(img, 'rb') as fp:
                str_image = fp.read()
            # Request the same channel order from both decoders, and assert the
            # comparison: a bare ``same(...)`` call discards its result.
            image = mx.image.imdecode(str_image, to_rgb=0)
            image_read = mx.img.image.imread(img, to_rgb=0)
            assert same(image.asnumpy(), image_read.asnumpy())

    def test_imdecode(self):
        """imdecode(to_rgb=0) matches OpenCV's cv2.imread output."""
        try:
            import cv2
        except ImportError:
            raise unittest.SkipTest("Unable to import cv2.")
        for img in TestImage.IMAGES:
            with open(img, 'rb') as fp:
                str_image = fp.read()
                image = mx.image.imdecode(str_image, to_rgb=0)
            cv_image = cv2.imread(img)
            assert_almost_equal(image.asnumpy(), cv_image)

    def test_imdecode_bytearray(self):
        """imdecode accepts a bytearray buffer as well as bytes."""
        try:
            import cv2
        except ImportError:
            # Report a skip, like the other cv2-based tests, instead of passing.
            raise unittest.SkipTest("Unable to import cv2.")
        for img in TestImage.IMAGES:
            with open(img, 'rb') as fp:
                str_image = bytearray(fp.read())
                image = mx.image.imdecode(str_image, to_rgb=0)
            cv_image = cv2.imread(img)
            assert_almost_equal(image.asnumpy(), cv_image)

    @raises(mx.base.MXNetError)
    def test_imdecode_empty_buffer(self):
        """Decoding an empty buffer raises MXNetError."""
        mx.image.imdecode(b'', to_rgb=0)

    @raises(mx.base.MXNetError)
    def test_imdecode_invalid_image(self):
        """Decoding bytes that are not an image raises MXNetError."""
        mx.image.imdecode(b'clearly not image content')

    def test_scale_down(self):
        """scale_down shrinks the requested size to fit the source size,
        keeping the requested aspect ratio."""
        assert mx.image.scale_down((640, 480), (720, 120)) == (640, 106)
        assert mx.image.scale_down((360, 1000), (480, 500)) == (360, 375)
        assert mx.image.scale_down((300, 400), (0, 0)) == (0, 0)

    @with_seed()
    def test_resize_short(self):
        """resize_short matches cv2.resize for interpolation modes 0 and 1."""
        try:
            import cv2
        except ImportError:
            raise unittest.SkipTest("Unable to import cv2")
        for img in TestImage.IMAGES:
            cv_img = cv2.imread(img)
            # Reverse the channel axis of the OpenCV image before handing it
            # to mxnet; outputs are reversed back before comparison.
            mx_img = mx.nd.array(cv_img[:, :, (2, 1, 0)])
            h, w, _ = cv_img.shape
            for _ in range(3):
                new_size = np.random.randint(1, 1000)
                # The shorter side becomes new_size; the longer side scales
                # proportionally (integer division).
                if h > w:
                    new_h, new_w = new_size * h // w, new_size
                else:
                    new_h, new_w = new_size, new_size * w // h
                # Only modes 0 and 1 are compared: area-based/lanczos results
                # don't appear to match cv2.
                for interp in range(0, 2):
                    cv_resized = cv2.resize(cv_img, (new_w, new_h), interpolation=interp)
                    mx_resized = mx.image.resize_short(mx_img, new_size, interp)
                    assert_almost_equal(mx_resized.asnumpy()[:, :, (2, 1, 0)], cv_resized, atol=3)

    @with_seed()
    def test_imresize(self):
        """imresize matches cv2.resize, both returning a new array and via ``out``."""
        try:
            import cv2
        except ImportError:
            raise unittest.SkipTest("Unable to import cv2")
        for img in TestImage.IMAGES:
            cv_img = cv2.imread(img)
            mx_img = mx.nd.array(cv_img[:, :, (2, 1, 0)])
            new_h = np.random.randint(1, 1000)
            new_w = np.random.randint(1, 1000)
            for interp_val in range(0, 2):
                cv_resized = cv2.resize(cv_img, (new_w, new_h), interpolation=interp_val)
                mx_resized = mx.image.imresize(mx_img, new_w, new_h, interp=interp_val)
                assert_almost_equal(mx_resized.asnumpy()[:, :, (2, 1, 0)], cv_resized, atol=3)
                out_img = mx.nd.zeros((new_h, new_w, 3), dtype=mx_img.dtype)
                mx.image.imresize(mx_img, new_w, new_h, interp=interp_val, out=out_img)
                assert_almost_equal(out_img.asnumpy()[:, :, (2, 1, 0)], cv_resized, atol=3)

    @with_seed()
    def test_color_normalize(self):
        """color_normalize computes (src - mean) / std per channel."""
        for _ in range(10):
            mean = np.random.rand(3) * 255
            std = np.random.rand(3) + 1
            width = np.random.randint(100, 500)
            height = np.random.randint(100, 500)
            src = np.random.rand(height, width, 3) * 255.
            mx_result = mx.image.color_normalize(mx.nd.array(src),
                                                 mx.nd.array(mean), mx.nd.array(std))
            assert_almost_equal(mx_result.asnumpy(), (src - mean) / std, atol=1e-3)

    @with_seed()
    def test_imageiter(self):
        """ImageIter from an in-memory list and from a .lst file, for several
        label dtypes and every last_batch_handle mode."""
        import os
        im_list = [[np.random.randint(0, 5), x] for x in TestImage.IMAGES]
        # Write the list file into the fixture's temp dir rather than a
        # cwd-relative ./data directory that may not exist.
        fname = os.path.join(TestImage.IMAGES_DIR, 'test_imageiter.lst')
        file_list = ['\t'.join([str(k), str(np.random.randint(0, 5)), x])
                     for k, x in enumerate(TestImage.IMAGES)]
        with open(fname, 'w') as f:
            for line in file_list:
                f.write(line + '\n')

        for dtype in ['int32', 'float32', 'int64', 'float64']:
            for test in ['imglist', 'path_imglist']:
                iter_kwargs = dict(
                    label_width=1,
                    imglist=im_list if test == 'imglist' else None,
                    path_imglist=fname if test == 'path_imglist' else None,
                    path_root='',
                    dtype=dtype)
                imageiter_list = [
                    mx.image.ImageIter(2, (3, 224, 224), **iter_kwargs),
                    mx.image.ImageIter(3, (3, 224, 224), last_batch_handle='discard',
                                       **iter_kwargs),
                    mx.image.ImageIter(3, (3, 224, 224), last_batch_handle='pad',
                                       **iter_kwargs),
                    mx.image.ImageIter(3, (3, 224, 224), last_batch_handle='roll_over',
                                       **iter_kwargs),
                    mx.image.ImageIter(3, (3, 224, 224), shuffle=True,
                                       last_batch_handle='pad', **iter_kwargs),
                ]
                _test_imageiter_last_batch(imageiter_list, (2, 3, 224, 224))

    @with_seed()
    def test_copyMakeBorder(self):
        """copyMakeBorder matches cv2.copyMakeBorder for border types 0-4,
        both returning a new array and via ``out``."""
        try:
            import cv2
        except ImportError:
            raise unittest.SkipTest("Unable to import cv2")
        for img in TestImage.IMAGES:
            cv_img = cv2.imread(img)
            mx_img = mx.nd.array(cv_img)
            top = np.random.randint(1, 10)
            bot = np.random.randint(1, 10)
            left = np.random.randint(1, 10)
            right = np.random.randint(1, 10)
            new_h, new_w, _ = mx_img.shape
            new_h += top + bot
            new_w += left + right
            val = [np.random.randint(1, 255)] * 3
            for type_val in range(0, 5):
                cv_border = cv2.copyMakeBorder(cv_img, top, bot, left, right,
                                               borderType=type_val, value=val)
                mx_border = mx.image.copyMakeBorder(mx_img, top, bot, left, right,
                                                    type=type_val, values=val)
                assert_almost_equal(mx_border.asnumpy(), cv_border)
                out_img = mx.nd.zeros((new_h, new_w, 3), dtype=mx_img.dtype)
                mx.image.copyMakeBorder(mx_img, top, bot, left, right,
                                        type=type_val, values=val, out=out_img)
                assert_almost_equal(out_img.asnumpy(), cv_border)

    @with_seed()
    def test_augmenters(self):
        """ColorNormalizeAug output is checked; the other augmenters are only
        run through ImageIter to make sure they execute."""
        mean = np.random.rand(3) * 255
        std = np.random.rand(3) + 1
        width = np.random.randint(100, 500)
        height = np.random.randint(100, 500)
        src = np.random.rand(height, width, 3) * 255.
        # Mixed inputs on purpose: mean as an NDArray, std as a numpy array.
        color_norm_aug = mx.image.ColorNormalizeAug(mean=mx.nd.array(mean), std=std)
        out_image = color_norm_aug(mx.nd.array(src))
        assert_almost_equal(out_image.asnumpy(), (src - mean) / std, atol=1e-3)

        # only test if all augmenters will work
        # TODO(Joshua Zhang): verify the augmenter outputs
        im_list = [[0, x] for x in TestImage.IMAGES]
        test_iter = mx.image.ImageIter(
            2, (3, 224, 224), label_width=1, imglist=im_list,
            resize=640, rand_crop=True, rand_resize=True, rand_mirror=True, mean=True,
            std=np.array([1.1, 1.03, 1.05]), brightness=0.1, contrast=0.1, saturation=0.1,
            hue=0.1, pca_noise=0.1, rand_gray=0.2, inter_method=10, path_root='', shuffle=True)
        for _ in test_iter:
            pass

    @with_seed()
    def test_image_detiter(self):
        """ImageDetIter: multi-epoch iteration, label-shape sync, uneven final
        batch, and .lst file input with every last_batch_handle mode."""
        import os
        im_list = [_generate_objects() + [x] for x in TestImage.IMAGES]
        det_iter = mx.image.ImageDetIter(2, (3, 300, 300), imglist=im_list, path_root='')
        # Reset after every epoch (as _test_imageiter_last_batch does);
        # otherwise epochs 2 and 3 would iterate an already-exhausted iterator.
        for _ in range(3):
            for _ in det_iter:
                pass
            det_iter.reset()
        val_iter = mx.image.ImageDetIter(2, (3, 300, 300), imglist=im_list, path_root='')
        det_iter = val_iter.sync_label_shape(det_iter)
        assert det_iter.data_shape == val_iter.data_shape
        assert det_iter.label_shape == val_iter.label_shape

        # test batch_size is not divisible by number of images
        det_iter = mx.image.ImageDetIter(4, (3, 300, 300), imglist=im_list, path_root='')
        for _ in det_iter:
            pass

        # .lst file input with last-batch handling; the file goes into the
        # fixture's temp dir instead of a cwd-relative ./data directory.
        fname = os.path.join(TestImage.IMAGES_DIR, 'test_imagedetiter.lst')
        im_list = [[k] + _generate_objects() + [x] for k, x in enumerate(TestImage.IMAGES)]
        with open(fname, 'w') as f:
            for line in im_list:
                f.write('\t'.join(str(field) for field in line) + '\n')

        det_kwargs = dict(path_imglist=fname, path_root='')
        imageiter_list = [
            mx.image.ImageDetIter(2, (3, 400, 400), **det_kwargs),
            mx.image.ImageDetIter(3, (3, 400, 400), last_batch_handle='discard',
                                  **det_kwargs),
            mx.image.ImageDetIter(3, (3, 400, 400), last_batch_handle='pad',
                                  **det_kwargs),
            mx.image.ImageDetIter(3, (3, 400, 400), last_batch_handle='roll_over',
                                  **det_kwargs),
            mx.image.ImageDetIter(3, (3, 400, 400), shuffle=True,
                                  last_batch_handle='pad', **det_kwargs),
        ]
        _test_imageiter_last_batch(imageiter_list, (2, 3, 400, 400))

    @with_seed()
    def test_det_augmenters(self):
        """Run every detection augmenter through ImageDetIter to make sure it executes."""
        # only test if all augmenters will work
        # TODO(Joshua Zhang): verify the augmenter outputs
        im_list = [_generate_objects() + [x] for x in TestImage.IMAGES]
        det_iter = mx.image.ImageDetIter(
            2, (3, 300, 300), imglist=im_list, path_root='',
            resize=640, rand_crop=1, rand_pad=1, rand_gray=0.1, rand_mirror=True, mean=True,
            std=np.array([1.1, 1.03, 1.05]), brightness=0.1, contrast=0.1, saturation=0.1,
            pca_noise=0.1, hue=0.1, inter_method=10, min_object_covered=0.5,
            aspect_ratio_range=(0.2, 5), area_range=(0.1, 4.0), min_eject_coverage=0.5,
            max_attempts=50)
        for _ in det_iter:
            pass

    @with_seed()
    def test_random_size_crop(self):
        """random_size_crop keeps the crop's width/height ratio within ``ratio``
        (plus a small tolerance) unless the result is the center crop."""
        width = np.random.randint(100, 500)
        height = np.random.randint(100, 500)
        src = np.random.rand(height, width, 3) * 255.
        ratio = (0.75, 1)
        epsilon = 0.05
        out, (x0, y0, new_w, new_h) = mx.image.random_size_crop(
            mx.nd.array(src), size=(width, height), area=0.08, ratio=ratio)
        _, pts = mx.image.center_crop(mx.nd.array(src), size=(width, height))
        # Presumably random_size_crop falls back to a center crop when it cannot
        # satisfy its constraints; that fallback's ratio is not checked.
        if (x0, y0, new_w, new_h) != pts:
            aspect = float(new_w) / new_h
            assert ratio[0] - epsilon <= aspect <= ratio[1] + epsilon, \
                'ratio of new width and height out of bounds: {}/{}={}'.format(new_w, new_h, aspect)

    @with_seed()
    def test_imrotate(self):
        """imrotate matches scipy.ndimage.rotate in the image interior, runs in
        all allowed zoom modes, and rejects invalid argument combinations."""
        # test correctness on a smooth gradient image
        xlin = np.expand_dims(np.linspace(0, 0.5, 30), axis=1)
        ylin = np.expand_dims(np.linspace(0, 0.5, 60), axis=0)
        np_img = np.expand_dims(xlin + ylin, axis=2)
        nd_img = mx.nd.array(np_img.transpose((2, 0, 1)))  # convert to CHW
        rot_angle = 6
        nd_rot = mx.image.imrotate(src=nd_img, rotation_degrees=rot_angle,
                                   zoom_in=False, zoom_out=False)
        npnd_rot = nd_rot.asnumpy().transpose((1, 2, 0))
        scipy_rot = scipy.ndimage.rotate(np_img, rot_angle, axes=(1, 0), reshape=False,
                                         order=1, mode='constant', prefilter=False)
        # cannot compare the edges (where image ends) because of different behavior
        assert_almost_equal(scipy_rot[10:20, 20:40, :], npnd_rot[10:20, 20:40, :])

        # every allowed zoom mode must execute without raising
        zoom_modes = [(False, False), (False, True), (True, False)]
        # batch mode: one rotation angle per image
        img_in = mx.nd.random.uniform(0, 1, (5, 3, 30, 60), dtype=np.float32)
        nd_rots = mx.nd.array([1, 2, 3, 4, 5], dtype=np.float32)
        for zoom_in, zoom_out in zoom_modes:
            mx.image.imrotate(src=img_in, rotation_degrees=nd_rots,
                              zoom_in=zoom_in, zoom_out=zoom_out)
        # single image mode: scalar rotation angle
        nd_rots = 11
        img_in = mx.nd.random.uniform(0, 1, (3, 30, 60), dtype=np.float32)
        for zoom_in, zoom_out in zoom_modes:
            mx.image.imrotate(src=img_in, rotation_degrees=nd_rots,
                              zoom_in=zoom_in, zoom_out=zoom_out)

        # zoom_in=zoom_out=True is rejected in batch mode ...
        img_in = mx.nd.random.uniform(0, 1, (5, 3, 30, 60), dtype=np.float32)
        nd_rots = mx.nd.array([1, 2, 3, 4, 5], dtype=np.float32)
        self.assertRaises(ValueError, mx.image.imrotate, src=img_in,
                          rotation_degrees=nd_rots, zoom_in=True, zoom_out=True)
        # ... and in single image mode
        img_in = mx.nd.random.uniform(0, 1, (3, 30, 60), dtype=np.float32)
        nd_rots = 11
        self.assertRaises(ValueError, mx.image.imrotate, src=img_in,
                          rotation_degrees=nd_rots, zoom_in=True, zoom_out=True)

        # batch of images with a scalar rotation: each image matches scipy
        img_in = mx.nd.stack(nd_img, nd_img, nd_img)
        out = mx.image.imrotate(src=img_in, rotation_degrees=6,
                                zoom_in=False, zoom_out=False)
        for img in out:
            img = img.asnumpy().transpose((1, 2, 0))
            assert_almost_equal(scipy_rot[10:20, 20:40, :], img[10:20, 20:40, :])

        # a single image with a vector of rotations is a TypeError
        img_in = mx.nd.random.uniform(0, 1, (3, 30, 60), dtype=np.float32)
        nd_rots = mx.nd.array([1, 2, 3, 4, 5], dtype=np.float32)
        self.assertRaises(TypeError, mx.image.imrotate, src=img_in,
                          rotation_degrees=nd_rots, zoom_in=False, zoom_out=False)

    @with_seed()
    def test_random_rotate(self):
        """random_rotate preserves the shape of single images and batches."""
        angle_limits = [-5., 5.]
        src_single_image = mx.nd.random.uniform(0, 1, (3, 30, 60),
                                                dtype=np.float32)
        out_single_image = mx.image.random_rotate(src_single_image,
                                                  angle_limits)
        self.assertEqual(out_single_image.shape, (3, 30, 60))
        src_batch_image = mx.nd.stack(src_single_image,
                                      src_single_image,
                                      src_single_image)
        out_batch_image = mx.image.random_rotate(src_batch_image,
                                                 angle_limits)
        self.assertEqual(out_batch_image.shape, (3, 3, 30, 60))
453
454
if __name__ == '__main__':
    # Running this file directly hands collection and execution to nose.
    import nose as _nose
    _nose.runmodule()
458