# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""Helpers that fetch the MNIST / CIFAR-10 archives and build CIFAR-10 iterators."""

import glob
import os
import zipfile

import mxnet as mx


def _ensure_dir(data_dir):
    """Create ``data_dir`` (including missing parents) if it is not a directory."""
    if not os.path.isdir(data_dir):
        # os.makedirs instead of a shell "mkdir": no quoting/injection issues,
        # works for nested paths, and raises on failure instead of ignoring it.
        os.makedirs(data_dir)


def _download_and_extract(url, data_dir, zip_name):
    """Download the zip at ``url`` into ``data_dir``, unpack it there, delete the zip.

    The archive is removed even when extraction fails, so a corrupt download
    is not left behind on disk.
    """
    zippath = os.path.join(data_dir, zip_name)
    mx.test_utils.download(url, zippath)
    try:
        with zipfile.ZipFile(zippath, "r") as zf:
            zf.extractall(data_dir)
    finally:
        os.remove(zippath)


def get_mnist(data_dir):
    """Make sure the four MNIST idx files are present in ``data_dir``.

    If any of them is missing, ``mnist.zip`` is downloaded from data.mxnet.io
    and extracted into ``data_dir``. The directory is created when needed.
    The process working directory is left untouched.

    Parameters
    ----------
    data_dir : str
        Directory that should hold the MNIST files.
    """
    _ensure_dir(data_dir)
    required = (
        'train-images-idx3-ubyte',
        'train-labels-idx1-ubyte',
        't10k-images-idx3-ubyte',
        't10k-labels-idx1-ubyte',
    )
    if all(os.path.exists(os.path.join(data_dir, name)) for name in required):
        return
    _download_and_extract(
        "http://data.mxnet.io/mxnet/data/mnist.zip", data_dir, "mnist.zip")


def get_cifar10(data_dir):
    """Make sure ``train.rec`` and ``test.rec`` are present in ``data_dir``.

    If either is missing, ``cifar10.zip`` is downloaded from data.mxnet.io and
    extracted; everything inside the extracted ``cifar`` sub-directory is then
    moved up into ``data_dir`` and the emptied ``cifar`` directory is removed.
    The process working directory is left untouched.

    Parameters
    ----------
    data_dir : str
        Directory that should hold the CIFAR-10 record files.
    """
    _ensure_dir(data_dir)
    dirname = os.path.abspath(data_dir)
    if (os.path.exists(os.path.join(dirname, 'train.rec')) and
            os.path.exists(os.path.join(dirname, 'test.rec'))):
        return
    _download_and_extract(
        "http://data.mxnet.io/mxnet/data/cifar10.zip", dirname, "cifar10.zip")
    # Flatten: move the contents of the extracted "cifar" folder into data_dir.
    extracted = os.path.join(dirname, "cifar")
    for path in glob.glob(os.path.join(extracted, "*")):
        os.rename(path, os.path.join(dirname, os.path.basename(path)))
    os.rmdir(extracted)


def get_cifar10_iterator(args, kv):
    """Build the CIFAR-10 training and validation ``ImageRecordIter`` pair.

    Parameters
    ----------
    args : object
        Must expose ``data_dir`` (str) and ``batch_size``.
    kv : object
        Must expose ``num_workers`` and ``rank``; they are passed as
        ``num_parts`` / ``part_index`` to both iterators.

    Returns
    -------
    tuple
        ``(train, val)`` iterators reading ``train.rec`` / ``test.rec`` and
        ``mean.bin`` from the data directory. The training iterator enables
        random crop and random mirror; the validation iterator disables both.
    """
    data_shape = (3, 28, 28)
    data_dir = args.data_dir
    if os.name == "nt" and data_dir.endswith("/"):
        # On Windows turn a trailing "/" into "\\". Only do so when there
        # actually is a trailing slash: unconditionally replacing the last
        # character would corrupt a path such as "cifar10".
        data_dir = data_dir[:-1] + "\\"
    if '://' not in args.data_dir:
        # No URL scheme in the path, so treat it as a local directory and
        # make sure the record files exist there.
        get_cifar10(data_dir)

    train = mx.io.ImageRecordIter(
        path_imgrec=os.path.join(data_dir, "train.rec"),
        mean_img=os.path.join(data_dir, "mean.bin"),
        data_shape=data_shape,
        batch_size=args.batch_size,
        rand_crop=True,
        rand_mirror=True,
        num_parts=kv.num_workers,
        part_index=kv.rank)

    val = mx.io.ImageRecordIter(
        path_imgrec=os.path.join(data_dir, "test.rec"),
        mean_img=os.path.join(data_dir, "mean.bin"),
        rand_crop=False,
        rand_mirror=False,
        data_shape=data_shape,
        batch_size=args.batch_size,
        num_parts=kv.num_workers,
        part_index=kv.rank)

    return (train, val)