1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18import os
19import mxnet as mx
20import zipfile
21
def get_mnist(data_dir):
    """Make sure the MNIST idx files are present in ``data_dir``.

    The directory is created if it does not exist. When any of the four
    expected files is missing, ``mnist.zip`` is fetched with
    ``mx.test_utils.download``, extracted into ``data_dir`` and the archive
    is deleted afterwards. The process working directory is never changed.

    Parameters
    ----------
    data_dir : str
        Directory that should contain the MNIST files.

    Raises
    ------
    OSError
        If ``data_dir`` cannot be created.
    """
    # Create the directory with os.makedirs rather than a shell "mkdir":
    # it handles nested paths and names with spaces or shell metacharacters.
    try:
        os.makedirs(data_dir)
    except OSError:
        # Already existing (possibly created concurrently) is fine.
        if not os.path.isdir(data_dir):
            raise

    # Work with absolute paths instead of chdir-ing into the directory, so
    # the caller's working directory is untouched even if something fails.
    target_dir = os.path.abspath(data_dir)
    required = ('train-images-idx3-ubyte',
                'train-labels-idx1-ubyte',
                't10k-images-idx3-ubyte',
                't10k-labels-idx1-ubyte')
    if all(os.path.exists(os.path.join(target_dir, name)) for name in required):
        return

    zippath = os.path.join(target_dir, "mnist.zip")
    mx.test_utils.download("http://data.mxnet.io/mxnet/data/mnist.zip", zippath)
    # The context manager closes the archive even when extraction fails.
    with zipfile.ZipFile(zippath, "r") as zf:
        zf.extractall(target_dir)
    os.remove(zippath)
38
def get_cifar10(data_dir):
    """Make sure the CIFAR-10 record files are present in ``data_dir``.

    The directory is created if it does not exist. When ``train.rec`` or
    ``test.rec`` is missing, ``cifar10.zip`` is fetched with
    ``mx.test_utils.download`` and extracted into ``data_dir``; the archive
    is then deleted and everything matched by ``cifar/*`` inside the
    extracted ``cifar`` sub-directory is moved up into ``data_dir`` before
    that sub-directory is removed. The process working directory is never
    changed.

    Parameters
    ----------
    data_dir : str
        Directory that should contain the CIFAR-10 files.

    Raises
    ------
    OSError
        If ``data_dir`` cannot be created, or a file cannot be moved.
    """
    import glob

    # Create the directory with os.makedirs rather than a shell "mkdir":
    # it handles nested paths and names with spaces or shell metacharacters.
    try:
        os.makedirs(data_dir)
    except OSError:
        # Already existing (possibly created concurrently) is fine.
        if not os.path.isdir(data_dir):
            raise

    # Absolute paths replace the chdir/restore dance, so a failed download
    # or extraction cannot leave the process inside data_dir.
    dirname = os.path.abspath(data_dir)
    if os.path.exists(os.path.join(dirname, 'train.rec')) and \
       os.path.exists(os.path.join(dirname, 'test.rec')):
        return

    zippath = os.path.join(dirname, "cifar10.zip")
    mx.test_utils.download("http://data.mxnet.io/mxnet/data/cifar10.zip", zippath)
    # The context manager closes the archive even when extraction fails.
    with zipfile.ZipFile(zippath, "r") as zf:
        zf.extractall(dirname)
    os.remove(zippath)

    # Flatten the extracted "cifar" folder into data_dir.
    extracted = os.path.join(dirname, "cifar")
    for path in glob.glob(os.path.join(extracted, "*")):
        os.rename(path, os.path.join(dirname, os.path.basename(path)))
    os.rmdir(extracted)
59
def get_cifar10_iterator(args, kv):
    """Build the CIFAR-10 training and validation ``mx.io.ImageRecordIter``.

    When ``args.data_dir`` is not a URL (contains no ``://``), the dataset
    is first fetched into it via ``get_cifar10``.

    Parameters
    ----------
    args : object
        Must expose ``data_dir`` (str) and ``batch_size``.
    kv : object
        Must expose ``num_workers`` and ``rank``; they are passed to the
        iterators as ``num_parts`` and ``part_index``.

    Returns
    -------
    tuple
        ``(train, val)``. The training iterator uses random crop and random
        mirror; the validation iterator disables both.
    """
    data_shape = (3, 28, 28)
    data_dir = args.data_dir
    if os.name == "nt":
        # Force a trailing backslash on Windows. Only strip an existing
        # trailing separator; unconditionally dropping the last character
        # would turn "cifar10" into "cifar1\".
        if data_dir[-1:] in ("/", "\\"):
            data_dir = data_dir[:-1]
        data_dir = data_dir + "\\"
    if '://' not in args.data_dir:
        get_cifar10(data_dir)

    mean_img = os.path.join(data_dir, "mean.bin")

    train = mx.io.ImageRecordIter(
        path_imgrec = os.path.join(data_dir, "train.rec"),
        mean_img    = mean_img,
        data_shape  = data_shape,
        batch_size  = args.batch_size,
        rand_crop   = True,
        rand_mirror = True,
        num_parts   = kv.num_workers,
        part_index  = kv.rank)

    val = mx.io.ImageRecordIter(
        path_imgrec = os.path.join(data_dir, "test.rec"),
        mean_img    = mean_img,
        rand_crop   = False,
        rand_mirror = False,
        data_shape  = data_shape,
        batch_size  = args.batch_size,
        num_parts   = kv.num_workers,
        part_index  = kv.rank)

    return (train, val)
89