1import os
2
3import numpy as np
4import imageio
5from skimage import data_dir
6from skimage.io.collection import ImageCollection, MultiImage, alphanumeric_key
7from skimage.io import reset_plugins
8
9from skimage._shared import testing
10from skimage._shared.testing import assert_equal, assert_allclose, TestCase
11
12
def test_string_split():
    """`alphanumeric_key` splits 'z23a' into text and integer pieces."""
    expected = ['z', 23, 'a']
    assert_equal(alphanumeric_key('z23a'), expected)
17
18
def test_string_sort():
    """Sorting with `alphanumeric_key` orders embedded numbers numerically."""
    unsorted_names = [
        'f9.10.png', 'f9.9.png', 'f10.10.png', 'f10.9.png',
        'e9.png', 'e10.png', 'em.png',
    ]
    # Expected "natural" order: 9 sorts before 10, and 'e9'/'e10' before 'em'.
    natural_order = [
        'e9.png', 'e10.png', 'em.png',
        'f9.9.png', 'f9.10.png', 'f10.9.png', 'f10.10.png',
    ]
    assert_equal(natural_order, sorted(unsorted_names, key=alphanumeric_key))
26
def test_imagecollection_input():
    """Test function for ImageCollection. The new behavior (implemented
    in 0.16) allows the `pattern` argument to accept a list of strings
    as the input.

    Notes
    -----
        If correct, `images` will contain three images, each of which
        loads as a NumPy array.
    """
    # Ensure that these images are part of the legacy datasets;
    # this means they will always be available in the user's install
    # regardless of the availability of pooch.
    pattern = [os.path.join(data_dir, pic)
               for pic in ['coffee.png',
                           'chessboard_GRAY.png',
                           'rocket.jpg']]
    images = ImageCollection(pattern)
    assert len(images) == 3
    # Beyond counting the entries, check that every listed file actually
    # loads into an array.
    for image in images:
        assert isinstance(image, np.ndarray)
45
46
class TestImageCollection(TestCase):
    """Tests for `ImageCollection` indexing, slicing, loading and stacking."""

    # Two images of different shapes: stacking them is expected to fail.
    pattern = [os.path.join(data_dir, pic)
               for pic in ['brick.png', 'color.png']]

    # Two images whose shapes match: they can be stacked into one array.
    pattern_matched = [os.path.join(data_dir, pic)
                       for pic in ['brick.png', 'moon.png']]

    def setUp(self):
        reset_plugins()
        # Generic image collection with images of different shapes.
        self.images = ImageCollection(self.pattern)
        # Image collection with images having shapes that match.
        self.images_matched = ImageCollection(self.pattern_matched)
        # Same images as a collection of frames
        self.frames_matched = MultiImage(self.pattern_matched)

    def test_len(self):
        """One entry per file in `pattern`."""
        assert len(self.images) == 2

    def test_getitem(self):
        """Every valid (incl. negative) index yields an array; others raise."""
        num = len(self.images)
        for i in range(-num, num):
            assert isinstance(self.images[i], np.ndarray)
        assert_allclose(self.images[0],
                        self.images[-num])

        # One past either end must raise IndexError.
        with testing.raises(IndexError):
            _ = self.images[num]
        with testing.raises(IndexError):
            _ = self.images[-num - 1]

    def test_slicing(self):
        """Slices return new ImageCollections holding the selected images."""
        assert type(self.images[:]) is ImageCollection
        assert len(self.images[:]) == 2
        assert len(self.images[:1]) == 1
        assert len(self.images[1:]) == 1
        assert_allclose(self.images[0], self.images[:1][0])
        assert_allclose(self.images[1], self.images[1:][0])
        assert_allclose(self.images[1], self.images[::-1][0])
        assert_allclose(self.images[0], self.images[::-1][1])

    def test_files_property(self):
        """`files` is a list and cannot be reassigned."""
        assert isinstance(self.images.files, list)

        with testing.raises(AttributeError):
            self.images.files = 'newfiles'

    def test_custom_load_func_w_kwarg(self):
        """Extra keyword arguments are forwarded to `load_func`."""
        load_pattern = os.path.join(data_dir, 'no_time_for_that_tiny.gif')

        def load_fn(f, step):
            # Close the reader once all frames are read so the file
            # handle is not leaked.
            with imageio.get_reader(f) as vid:
                seq = list(vid.iter_data())
            return seq[::step]

        ic = ImageCollection(load_pattern, load_func=load_fn, step=3)
        # Each file should map to one image (array).
        assert len(ic) == 1
        # GIF file has 24 frames, so 24 / 3 equals 8.
        assert len(ic[0]) == 8

    def test_custom_load_func(self):
        """A custom `load_func` receives each path from an os.pathsep string."""

        def load_fn(x):
            return x

        ic = ImageCollection(os.pathsep.join(self.pattern), load_func=load_fn)
        assert_equal(ic[0], self.pattern[0])

    def test_concatenate(self):
        """Stacking same-shaped images prepends an axis of length len(ic)."""
        array = self.images_matched.concatenate()
        # Derive the per-image shape from the collection being concatenated,
        # not from the unrelated mismatched collection.
        expected_shape = ((len(self.images_matched),)
                          + self.images_matched[0].shape)
        assert_equal(array.shape, expected_shape)

    def test_concatenate_mismatched_image_shapes(self):
        """Stacking images of different shapes raises ValueError."""
        with testing.raises(ValueError):
            self.images.concatenate()

    def test_multiimage_imagecollection(self):
        """MultiImage and ImageCollection yield the same images for the files."""
        assert_equal(self.images_matched[0], self.frames_matched[0])
        assert_equal(self.images_matched[1], self.frames_matched[1])
132