1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18# coding: utf-8
19# pylint: disable=
20"""Dataset container."""
# Public dataset classes exported by ``from ... import *``.
__all__ = [
    'MNIST',
    'FashionMNIST',
    'CIFAR10',
    'CIFAR100',
    'ImageRecordDataset',
    'ImageFolderDataset',
]
23
24import os
25import gzip
26import tarfile
27import struct
28import warnings
29import numpy as np
30
31from .. import dataset
32from ...utils import download, check_sha1, _get_repo_file_url
33from .... import nd, image, recordio, base
34from .... import numpy as _mx_np  # pylint: disable=reimported
35from ....util import is_np_array
36
37
class MNIST(dataset._DownloadedDataset):
    """MNIST handwritten digits dataset from http://yann.lecun.com/exdb/mnist

    Each sample is an image (in 3D NDArray) with shape (28, 28, 1). Images are
    stored as uint8 and labels as int32. The gzip files are fetched from the
    Gluon dataset repository, with their expected SHA-1 hashes passed to
    ``download``.

    Parameters
    ----------
    root : str, default $MXNET_HOME/datasets/mnist
        Path to temp folder for storing data.
    train : bool, default True
        Whether to load the training or testing set.
    transform : function, default None
        A user defined callback that transforms each sample. For example::

            transform=lambda data, label: (data.astype(np.float32)/255, label)

    """
    def __init__(self, root=os.path.join(base.data_dir(), 'datasets', 'mnist'),
                 train=True, transform=None):
        # (file name, sha1) pairs; must be set before the base __init__ runs,
        # since subclasses and _get_data read them.
        self._train_data, self._train_label = (
            ('train-images-idx3-ubyte.gz', '6c95f4b05d2bf285e1bfb0e7960c31bd3b3f8a7d'),
            ('train-labels-idx1-ubyte.gz', '2a80914081dc54586dbdf242f9805a6b8d2a15fc'))
        self._test_data, self._test_label = (
            ('t10k-images-idx3-ubyte.gz', 'c3a25af1f52dad7f726cce8cacb138654b760d48'),
            ('t10k-labels-idx1-ubyte.gz', '763e7fa3757d93b0cdec073cef058b2004252c17'))
        self._namespace = 'mnist'
        self._train = train
        super(MNIST, self).__init__(root, transform)

    def _get_data(self):
        """Download the selected split and load it into ``self._data`` / ``self._label``."""
        image_spec, label_spec = ((self._train_data, self._train_label) if self._train
                                  else (self._test_data, self._test_label))
        repo_namespace = 'gluon/dataset/' + self._namespace

        def fetch(spec):
            # spec is a (file name, expected sha1) pair.
            file_name, sha1 = spec
            return download(_get_repo_file_url(repo_namespace, file_name),
                            path=self._root,
                            sha1_hash=sha1)

        image_path = fetch(image_spec)
        label_path = fetch(label_spec)

        with gzip.open(label_path, 'rb') as stream:
            # Skip the big-endian header: two uint32 fields for the label file.
            struct.unpack(">II", stream.read(8))
            labels = np.frombuffer(stream.read(), dtype=np.uint8).astype(np.int32)
            if is_np_array():
                labels = _mx_np.array(labels, dtype=labels.dtype)

        with gzip.open(image_path, 'rb') as stream:
            # Skip the big-endian header: four uint32 fields for the image file.
            struct.unpack(">IIII", stream.read(16))
            pixels = np.frombuffer(stream.read(), dtype=np.uint8).reshape(len(labels), 28, 28, 1)

        to_array = _mx_np.array if is_np_array() else nd.array
        self._data = to_array(pixels, dtype=pixels.dtype)
        self._label = labels
97
98
class FashionMNIST(MNIST):
    """A dataset of Zalando's article images consisting of fashion products,
    a drop-in replacement of the original MNIST dataset from
    https://github.com/zalandoresearch/fashion-mnist

    Each sample is an image (in 3D NDArray) with shape (28, 28, 1). Loading is
    inherited from :class:`MNIST`; only the file table and namespace differ.

    Parameters
    ----------
    root : str, default $MXNET_HOME/datasets/fashion-mnist
        Path to temp folder for storing data.
    train : bool, default True
        Whether to load the training or testing set.
    transform : function, default None
        A user defined callback that transforms each sample. For example::

            transform=lambda data, label: (data.astype(np.float32)/255, label)

    """
    def __init__(self, root=os.path.join(base.data_dir(), 'datasets', 'fashion-mnist'),
                 train=True, transform=None):
        self._namespace = 'fashion-mnist'
        self._train = train
        (self._train_data, self._train_label,
         self._test_data, self._test_label) = (
             ('train-images-idx3-ubyte.gz', '0cf37b0d40ed5169c6b3aba31069a9770ac9043d'),
             ('train-labels-idx1-ubyte.gz', '236021d52f1e40852b06a4c3008d8de8aef1e40b'),
             ('t10k-images-idx3-ubyte.gz', '626ed6a7c06dd17c0eec72fa3be1740f146a2863'),
             ('t10k-labels-idx1-ubyte.gz', '17f9ab60e7257a1620f4ad76bbbaf857c3920701'))
        # Skip MNIST.__init__, which would overwrite the file table above with
        # the MNIST one; go straight to the base class initializer.
        super(MNIST, self).__init__(root, transform) # pylint: disable=bad-super-call
131
132
class CIFAR10(dataset._DownloadedDataset):
    """CIFAR10 image classification dataset from https://www.cs.toronto.edu/~kriz/cifar.html

    Each sample is an image (in 3D NDArray) with shape (32, 32, 3). Images are
    uint8 and labels int32.

    Parameters
    ----------
    root : str, default $MXNET_HOME/datasets/cifar10
        Path to temp folder for storing data.
    train : bool, default True
        Whether to load the training or testing set.
    transform : function, default None
        A user defined callback that transforms each sample. For example::

            transform=lambda data, label: (data.astype(np.float32)/255, label)

    """
    def __init__(self, root=os.path.join(base.data_dir(), 'datasets', 'cifar10'),
                 train=True, transform=None):
        self._train = train
        self._archive_file = ('cifar-10-binary.tar.gz', 'fab780a1e191a7eda0f345501ccd62d20f7ed891')
        self._train_data = [('data_batch_1.bin', 'aadd24acce27caa71bf4b10992e9e7b2d74c2540'),
                            ('data_batch_2.bin', 'c0ba65cce70568cd57b4e03e9ac8d2a5367c1795'),
                            ('data_batch_3.bin', '1dd00a74ab1d17a6e7d73e185b69dbf31242f295'),
                            ('data_batch_4.bin', 'aab85764eb3584312d3c7f65fd2fd016e36a258e'),
                            ('data_batch_5.bin', '26e2849e66a845b7f1e4614ae70f4889ae604628')]
        self._test_data = [('test_batch.bin', '67eb016db431130d61cd03c7ad570b013799c88c')]
        self._namespace = 'cifar10'
        super(CIFAR10, self).__init__(root, transform)

    def _read_batch(self, filename):
        """Read one binary batch file into ``(images, labels)``.

        Each record is 1 + 3072 bytes: a label byte followed by pixels in
        (3, 32, 32) order, which are transposed here to (32, 32, 3).
        """
        with open(filename, 'rb') as fin:
            data = np.frombuffer(fin.read(), dtype=np.uint8).reshape(-1, 3072+1)

        return data[:, 1:].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1), \
               data[:, 0].astype(np.int32)

    def _get_data(self):
        """Ensure the batch files are present, then load the selected split.

        If any expected batch file (train or test) is missing or fails its
        SHA-1 check, the archive is downloaded and extracted into ``self._root``.
        """
        def _is_intact(name, sha1):
            path = os.path.join(self._root, name)
            return os.path.exists(path) and check_sha1(path, sha1)

        if not all(_is_intact(name, sha1)
                   for name, sha1 in self._train_data + self._test_data):
            namespace = 'gluon/dataset/'+self._namespace
            filename = download(_get_repo_file_url(namespace, self._archive_file[0]),
                                path=self._root,
                                sha1_hash=self._archive_file[1])

            with tarfile.open(filename) as tar:
                if hasattr(tarfile, 'data_filter'):
                    # PEP 706 extraction filters: reject absolute paths and
                    # members/links that would escape self._root. Passing the
                    # filter also avoids the DeprecationWarning newer Pythons
                    # emit for unfiltered extraction.
                    tar.extractall(self._root, filter='data')
                else:
                    # Older interpreters without extraction filters.
                    tar.extractall(self._root)

        data_files = self._train_data if self._train else self._test_data
        data, label = zip(*(self._read_batch(os.path.join(self._root, name))
                            for name, _ in data_files))
        data = np.concatenate(data)
        label = np.concatenate(label)

        array_fn = _mx_np.array if is_np_array() else nd.array
        self._data = array_fn(data, dtype=data.dtype)
        self._label = array_fn(label, dtype=label.dtype) if is_np_array() else label
194
195
class CIFAR100(CIFAR10):
    """CIFAR100 image classification dataset from https://www.cs.toronto.edu/~kriz/cifar.html

    Each sample is an image (in 3D NDArray) with shape (32, 32, 3).

    Parameters
    ----------
    root : str, default $MXNET_HOME/datasets/cifar100
        Path to temp folder for storing data.
    fine_label : bool, default False
        Whether to load the fine-grained (100 classes) or coarse-grained (20 super-classes) labels.
        Any truthy value selects the fine-grained labels.
    train : bool, default True
        Whether to load the training or testing set.
    transform : function, default None
        A user defined callback that transforms each sample. For example::

            transform=lambda data, label: (data.astype(np.float32)/255, label)

    """
    def __init__(self, root=os.path.join(base.data_dir(), 'datasets', 'cifar100'),
                 fine_label=False, train=True, transform=None):
        self._train = train
        self._archive_file = ('cifar-100-binary.tar.gz', 'a0bb982c76b83111308126cc779a992fa506b90b')
        self._train_data = [('train.bin', 'e207cd2e05b73b1393c74c7f5e7bea451d63e08e')]
        self._test_data = [('test.bin', '8fb6623e830365ff53cf14adec797474f5478006')]
        self._fine_label = fine_label
        self._namespace = 'cifar100'
        # Deliberately skip CIFAR10.__init__, which would install the CIFAR-10
        # archive and batch-file table over the values set above.
        super(CIFAR10, self).__init__(root, transform) # pylint: disable=bad-super-call

    def _read_batch(self, filename):
        """Read one binary batch file into ``(images, labels)``.

        Each record is 3072 + 2 bytes: byte 0 is the coarse label, byte 1 the
        fine label, and the remaining 3072 bytes are pixels in (3, 32, 32)
        order, transposed here to (32, 32, 3).
        """
        with open(filename, 'rb') as fin:
            data = np.frombuffer(fin.read(), dtype=np.uint8).reshape(-1, 3072+2)

        # Map fine_label to column 0 or 1 explicitly: integer arithmetic on a
        # truthy non-bool (e.g. 2) would otherwise index a pixel column.
        label_col = 1 if self._fine_label else 0
        images = data[:, 2:].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
        return images, data[:, label_col].astype(np.int32)
231
232
class ImageRecordDataset(dataset.RecordFileDataset):
    """Dataset over a RecordIO file whose records hold packed images.

    Indexing returns the decoded image together with the label stored in the
    record header.

    Parameters
    ----------
    filename : str
        Path to rec file.
    flag : {0, 1}, default 1
        If 0, always convert images to greyscale. \
        If 1, always convert images to colored (RGB).
    transform : function, default None
        A user defined callback that transforms each sample. For example::

            transform=lambda data, label: (data.astype(np.float32)/255, label)

    """
    def __init__(self, filename, flag=1, transform=None):
        super(ImageRecordDataset, self).__init__(filename)
        self._transform = transform
        self._flag = flag

    def __getitem__(self, idx):
        raw_record = super(ImageRecordDataset, self).__getitem__(idx)
        header, payload = recordio.unpack(raw_record)
        decoded = image.imdecode(payload, self._flag)
        if self._transform is None:
            return decoded, header.label
        return self._transform(decoded, header.label)
262
263
class ImageFolderDataset(dataset.Dataset):
    """A dataset for loading image files stored in a folder structure.

    like::

        root/car/0001.jpg
        root/car/xxxa.jpg
        root/car/yyyb.jpg
        root/bus/123.jpg
        root/bus/023.jpg
        root/bus/wwww.jpg

    Each immediate sub-directory of ``root`` is one class. Labels are assigned
    in sorted order of the sub-directory names. Non-directory entries directly
    under ``root`` are skipped with a warning. So are files whose extension,
    compared case-insensitively, is not one of ``.jpg``, ``.jpeg`` or ``.png``.

    Parameters
    ----------
    root : str
        Path to root directory. ``~`` is expanded.
    flag : {0, 1}, default 1
        If 0, always convert loaded images to greyscale (1 channel).
        If 1, always convert loaded images to colored (3 channels).
    transform : callable, default None
        A function that takes data and label and transforms them::

            transform = lambda data, label: (data.astype(np.float32)/255, label)

    Attributes
    ----------
    synsets : list
        List of class names. `synsets[i]` is the name for the integer label `i`
    items : list of tuples
        List of all images in (filename, label) pairs.
    """
    def __init__(self, root, flag=1, transform=None):
        self._root = os.path.expanduser(root)
        self._flag = flag
        self._transform = transform
        self._exts = ['.jpg', '.jpeg', '.png']
        self._list_images(self._root)

    def _list_images(self, root):
        """Scan ``root`` and populate ``self.synsets`` and ``self.items``."""
        self.synsets = []
        self.items = []

        for folder in sorted(os.listdir(root)):
            path = os.path.join(root, folder)
            if not os.path.isdir(path):
                # stacklevel=3 attributes the warning to the code that built the
                # dataset (caller -> __init__ -> _list_images -> warn).
                warnings.warn('Ignoring %s, which is not a directory.'%path, stacklevel=3)
                continue
            label = len(self.synsets)
            self.synsets.append(folder)
            for filename in sorted(os.listdir(path)):
                filename = os.path.join(path, filename)
                ext = os.path.splitext(filename)[1]
                if ext.lower() not in self._exts:
                    # Same call depth as the warning above, so the same stacklevel.
                    warnings.warn('Ignoring %s of type %s. Only support %s'%(
                        filename, ext, ', '.join(self._exts)), stacklevel=3)
                    continue
                self.items.append((filename, label))

    def __getitem__(self, idx):
        """Read image ``idx`` from disk and return ``(img, label)``, transformed if requested."""
        img = image.imread(self.items[idx][0], self._flag)
        label = self.items[idx][1]
        if self._transform is not None:
            return self._transform(img, label)
        return img, label

    def __len__(self):
        return len(self.items)
331