1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18# pylint: disable=missing-docstring
19from __future__ import print_function
20
21import numpy as np
22import mxnet as mx
23
24
def get_mnist(seed=1234, scale=5):
    """Gets the full MNIST dataset (train + test), shuffled and flattened.

    Args:
        seed: Value passed to ``np.random.seed`` so the shuffle order is
            deterministic across runs. Note: this re-seeds NumPy's *global*
            RNG as a side effect.
        scale: Multiplier applied to the flattened float32 pixel values.

    Returns:
        A tuple ``(X, Y)``. ``X`` is a float32 array of shape
        ``(num_samples, num_features)``, with each sample flattened to 1-D.
        ``Y`` holds the matching labels. Both are reordered by the same
        random permutation.
    """
    # Seeds the global NumPy RNG. This is kept for backward compatibility,
    # because callers may rely on later np.random calls also being
    # deterministic.
    np.random.seed(seed)
    mnist_data = mx.test_utils.get_mnist()
    X = np.concatenate([mnist_data['train_data'], mnist_data['test_data']])
    Y = np.concatenate([mnist_data['train_label'], mnist_data['test_label']])
    # Shuffle samples and labels with one shared permutation so pairs stay aligned.
    p = np.random.permutation(X.shape[0])
    # Flatten every sample to a 1-D feature vector, cast to float32, then rescale.
    X = X[p].reshape((X.shape[0], -1)).astype(np.float32) * scale
    Y = Y[p]
    return X, Y
36