# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=
"""Dataset container."""
__all__ = ['MNIST', 'FashionMNIST', 'CIFAR10', 'CIFAR100',
           'ImageRecordDataset', 'ImageFolderDataset']

import os
import gzip
import tarfile
import struct
import warnings
import numpy as np

from .. import dataset
from ...utils import download, check_sha1, _get_repo_file_url
from .... import nd, image, recordio, base
from .... import numpy as _mx_np  # pylint: disable=reimported
from ....util import is_np_array


class MNIST(dataset._DownloadedDataset):
    """MNIST handwritten digits dataset from http://yann.lecun.com/exdb/mnist

    Each sample is an image (in 3D NDArray) with shape (28, 28, 1).

    Parameters
    ----------
    root : str, default $MXNET_HOME/datasets/mnist
        Path to temp folder for storing data.
    train : bool, default True
        Whether to load the training or testing set.
    transform : function, default None
        A user defined callback that transforms each sample. For example::

            transform=lambda data, label: (data.astype(np.float32)/255, label)

    """
    def __init__(self, root=os.path.join(base.data_dir(), 'datasets', 'mnist'),
                 train=True, transform=None):
        self._train = train
        # (file name, expected SHA1) pairs for each split.
        self._train_data = ('train-images-idx3-ubyte.gz',
                            '6c95f4b05d2bf285e1bfb0e7960c31bd3b3f8a7d')
        self._train_label = ('train-labels-idx1-ubyte.gz',
                             '2a80914081dc54586dbdf242f9805a6b8d2a15fc')
        self._test_data = ('t10k-images-idx3-ubyte.gz',
                           'c3a25af1f52dad7f726cce8cacb138654b760d48')
        self._test_label = ('t10k-labels-idx1-ubyte.gz',
                            '763e7fa3757d93b0cdec073cef058b2004252c17')
        # Sub-path of the repository URL the files are fetched from.
        self._namespace = 'mnist'
        super(MNIST, self).__init__(root, transform)

    def _get_data(self):
        """Download (SHA1-checked) and decode the selected split into
        ``self._data`` and ``self._label``."""
        if self._train:
            data, label = self._train_data, self._train_label
        else:
            data, label = self._test_data, self._test_label

        namespace = 'gluon/dataset/'+self._namespace
        data_file = download(_get_repo_file_url(namespace, data[0]),
                             path=self._root,
                             sha1_hash=data[1])
        label_file = download(_get_repo_file_url(namespace, label[0]),
                              path=self._root,
                              sha1_hash=label[1])

        with gzip.open(label_file, 'rb') as fin:
            # Consume the 8-byte big-endian label-file header (not validated here).
            struct.unpack(">II", fin.read(8))
            label = np.frombuffer(fin.read(), dtype=np.uint8).astype(np.int32)
            if is_np_array():
                label = _mx_np.array(label, dtype=label.dtype)

        with gzip.open(data_file, 'rb') as fin:
            # Consume the 16-byte big-endian image-file header (not validated here).
            struct.unpack(">IIII", fin.read(16))
            data = np.frombuffer(fin.read(), dtype=np.uint8)
            # One 28x28 single-channel image per label.
            data = data.reshape(len(label), 28, 28, 1)

        array_fn = _mx_np.array if is_np_array() else nd.array
        self._data = array_fn(data, dtype=data.dtype)
        self._label = label


class FashionMNIST(MNIST):
    """A dataset of Zalando's article images consisting of fashion products,
    a drop-in replacement of the original MNIST dataset from
    https://github.com/zalandoresearch/fashion-mnist

    Each sample is an image (in 3D NDArray) with shape (28, 28, 1).

    Parameters
    ----------
    root : str, default $MXNET_HOME/datasets/fashion-mnist
        Path to temp folder for storing data.
    train : bool, default True
        Whether to load the training or testing set.
    transform : function, default None
        A user defined callback that transforms each sample. For example::

            transform=lambda data, label: (data.astype(np.float32)/255, label)

    """
    def __init__(self, root=os.path.join(base.data_dir(), 'datasets', 'fashion-mnist'),
                 train=True, transform=None):
        self._train = train
        self._train_data = ('train-images-idx3-ubyte.gz',
                            '0cf37b0d40ed5169c6b3aba31069a9770ac9043d')
        self._train_label = ('train-labels-idx1-ubyte.gz',
                             '236021d52f1e40852b06a4c3008d8de8aef1e40b')
        self._test_data = ('t10k-images-idx3-ubyte.gz',
                           '626ed6a7c06dd17c0eec72fa3be1740f146a2863')
        self._test_label = ('t10k-labels-idx1-ubyte.gz',
                            '17f9ab60e7257a1620f4ad76bbbaf857c3920701')
        self._namespace = 'fashion-mnist'
        # Deliberately skip MNIST.__init__, which would overwrite the file
        # tables above with the MNIST ones, and call the base initializer.
        super(MNIST, self).__init__(root, transform)  # pylint: disable=bad-super-call


class CIFAR10(dataset._DownloadedDataset):
    """CIFAR10 image classification dataset from https://www.cs.toronto.edu/~kriz/cifar.html

    Each sample is an image (in 3D NDArray) with shape (32, 32, 3).

    Parameters
    ----------
    root : str, default $MXNET_HOME/datasets/cifar10
        Path to temp folder for storing data.
    train : bool, default True
        Whether to load the training or testing set.
    transform : function, default None
        A user defined callback that transforms each sample. For example::

            transform=lambda data, label: (data.astype(np.float32)/255, label)

    """
    def __init__(self, root=os.path.join(base.data_dir(), 'datasets', 'cifar10'),
                 train=True, transform=None):
        self._train = train
        self._archive_file = ('cifar-10-binary.tar.gz', 'fab780a1e191a7eda0f345501ccd62d20f7ed891')
        self._train_data = [('data_batch_1.bin', 'aadd24acce27caa71bf4b10992e9e7b2d74c2540'),
                            ('data_batch_2.bin', 'c0ba65cce70568cd57b4e03e9ac8d2a5367c1795'),
                            ('data_batch_3.bin', '1dd00a74ab1d17a6e7d73e185b69dbf31242f295'),
                            ('data_batch_4.bin', 'aab85764eb3584312d3c7f65fd2fd016e36a258e'),
                            ('data_batch_5.bin', '26e2849e66a845b7f1e4614ae70f4889ae604628')]
        self._test_data = [('test_batch.bin', '67eb016db431130d61cd03c7ad570b013799c88c')]
        self._namespace = 'cifar10'
        super(CIFAR10, self).__init__(root, transform)

    def _read_batch(self, filename):
        """Read one binary batch file; return ``(images, labels)`` where images
        has shape (N, 32, 32, 3) uint8 and labels is int32."""
        with open(filename, 'rb') as fin:
            # Each record: 1 label byte followed by 3072 (= 3*32*32) pixel bytes.
            data = np.frombuffer(fin.read(), dtype=np.uint8).reshape(-1, 3072+1)

        # Pixels are stored channel-first; transpose to channel-last.
        return data[:, 1:].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1), \
               data[:, 0].astype(np.int32)

    def _get_data(self):
        """Ensure the batch files are present and valid, then load the
        selected split into ``self._data`` and ``self._label``."""
        # Re-fetch and re-extract the archive if any expected batch file is
        # missing or fails its SHA1 check.
        if any(not os.path.exists(path) or not check_sha1(path, sha1)
               for path, sha1 in ((os.path.join(self._root, name), sha1)
                                  for name, sha1 in self._train_data + self._test_data)):
            namespace = 'gluon/dataset/'+self._namespace
            filename = download(_get_repo_file_url(namespace, self._archive_file[0]),
                                path=self._root,
                                sha1_hash=self._archive_file[1])

            with tarfile.open(filename) as tar:
                # Use the safe 'data' extraction filter where the running Python
                # provides it (3.12+ and security backports); this also avoids
                # the DeprecationWarning emitted for unfiltered extraction.
                if hasattr(tarfile, 'data_filter'):
                    tar.extractall(self._root, filter='data')
                else:
                    tar.extractall(self._root)

        if self._train:
            data_files = self._train_data
        else:
            data_files = self._test_data
        data, label = zip(*(self._read_batch(os.path.join(self._root, name))
                            for name, _ in data_files))
        data = np.concatenate(data)
        label = np.concatenate(label)

        array_fn = _mx_np.array if is_np_array() else nd.array
        self._data = array_fn(data, dtype=data.dtype)
        self._label = array_fn(label, dtype=label.dtype) if is_np_array() else label


class CIFAR100(CIFAR10):
    """CIFAR100 image classification dataset from https://www.cs.toronto.edu/~kriz/cifar.html

    Each sample is an image (in 3D NDArray) with shape (32, 32, 3).

    Parameters
    ----------
    root : str, default $MXNET_HOME/datasets/cifar100
        Path to temp folder for storing data.
    fine_label : bool, default False
        Whether to load the fine-grained (100 classes) or coarse-grained (20 super-classes) labels.
    train : bool, default True
        Whether to load the training or testing set.
    transform : function, default None
        A user defined callback that transforms each sample. For example::

            transform=lambda data, label: (data.astype(np.float32)/255, label)

    """
    def __init__(self, root=os.path.join(base.data_dir(), 'datasets', 'cifar100'),
                 fine_label=False, train=True, transform=None):
        self._train = train
        self._archive_file = ('cifar-100-binary.tar.gz', 'a0bb982c76b83111308126cc779a992fa506b90b')
        self._train_data = [('train.bin', 'e207cd2e05b73b1393c74c7f5e7bea451d63e08e')]
        self._test_data = [('test.bin', '8fb6623e830365ff53cf14adec797474f5478006')]
        self._fine_label = fine_label
        self._namespace = 'cifar100'
        # Deliberately skip CIFAR10.__init__, which would overwrite the file
        # tables above with the CIFAR10 ones, and call the base initializer.
        super(CIFAR10, self).__init__(root, transform)  # pylint: disable=bad-super-call

    def _read_batch(self, filename):
        """Read one binary batch file; return ``(images, labels)`` where images
        has shape (N, 32, 32, 3) uint8 and labels is int32 (coarse or fine,
        depending on ``fine_label``)."""
        with open(filename, 'rb') as fin:
            # Each record: coarse label byte, fine label byte, 3072 pixel bytes.
            data = np.frombuffer(fin.read(), dtype=np.uint8).reshape(-1, 3072+2)

        # Column 0 holds the coarse label, column 1 the fine label; map any
        # truthy flag to column 1 so it can never index into pixel data.
        label_col = 1 if self._fine_label else 0
        return data[:, 2:].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1), \
               data[:, label_col].astype(np.int32)


class ImageRecordDataset(dataset.RecordFileDataset):
    """A dataset wrapping over a RecordIO file containing images.

    Each sample is an image and its corresponding label.

    Parameters
    ----------
    filename : str
        Path to rec file.
    flag : {0, 1}, default 1
        If 0, always convert images to greyscale.
        If 1, always convert images to colored (RGB).
    transform : function, default None
        A user defined callback that transforms each sample. For example::

            transform=lambda data, label: (data.astype(np.float32)/255, label)

    """
    def __init__(self, filename, flag=1, transform=None):
        super(ImageRecordDataset, self).__init__(filename)
        self._flag = flag
        self._transform = transform

    def __getitem__(self, idx):
        """Return the decoded image and header label at ``idx``, passed through
        ``transform`` when one was given."""
        record = super(ImageRecordDataset, self).__getitem__(idx)
        header, img = recordio.unpack(record)
        img = image.imdecode(img, self._flag)
        if self._transform is not None:
            return self._transform(img, header.label)
        return img, header.label


class ImageFolderDataset(dataset.Dataset):
    """A dataset for loading image files stored in a folder structure.

    like::

        root/car/0001.jpg
        root/car/xxxa.jpg
        root/car/yyyb.jpg
        root/bus/123.jpg
        root/bus/023.jpg
        root/bus/wwww.jpg

    Parameters
    ----------
    root : str
        Path to root directory.
    flag : {0, 1}, default 1
        If 0, always convert loaded images to greyscale (1 channel).
        If 1, always convert loaded images to colored (3 channels).
    transform : callable, default None
        A function that takes data and label and transforms them::

            transform = lambda data, label: (data.astype(np.float32)/255, label)

    Attributes
    ----------
    synsets : list
        List of class names. `synsets[i]` is the name for the integer label `i`.
    items : list of tuples
        List of all images in (filename, label) pairs.
    """
    def __init__(self, root, flag=1, transform=None):
        self._root = os.path.expanduser(root)
        self._flag = flag
        self._transform = transform
        self._exts = ['.jpg', '.jpeg', '.png']
        self._list_images(self._root)

    def _list_images(self, root):
        """Populate ``self.synsets`` and ``self.items`` by scanning ``root``.

        Each immediate sub-directory (in sorted order) becomes one class whose
        label is its index in ``synsets``. Non-directories at the top level and
        files with unsupported extensions are skipped with a warning.
        """
        self.synsets = []
        self.items = []

        for folder in sorted(os.listdir(root)):
            path = os.path.join(root, folder)
            if not os.path.isdir(path):
                # stacklevel=3: _list_images <- __init__ <- user's constructor call.
                warnings.warn('Ignoring %s, which is not a directory.'%path, stacklevel=3)
                continue
            label = len(self.synsets)
            self.synsets.append(folder)
            for filename in sorted(os.listdir(path)):
                filename = os.path.join(path, filename)
                ext = os.path.splitext(filename)[1]
                if ext.lower() not in self._exts:
                    warnings.warn('Ignoring %s of type %s. Only support %s'%(
                        filename, ext, ', '.join(self._exts)), stacklevel=3)
                    continue
                self.items.append((filename, label))

    def __getitem__(self, idx):
        """Load the image at ``idx`` and return ``(image, label)``, passed
        through ``transform`` when one was given."""
        img = image.imread(self.items[idx][0], self._flag)
        label = self.items[idx][1]
        if self._transform is not None:
            return self._transform(img, label)
        return img, label

    def __len__(self):
        return len(self.items)