1# Licensed to the Apache Software Foundation (ASF) under one 2# or more contributor license agreements. See the NOTICE file 3# distributed with this work for additional information 4# regarding copyright ownership. The ASF licenses this file 5# to you under the Apache License, Version 2.0 (the 6# "License"); you may not use this file except in compliance 7# with the License. You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, 12# software distributed under the License is distributed on an 13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14# KIND, either express or implied. See the License for the 15# specific language governing permissions and limitations 16# under the License. 17 18# pylint: disable=missing-docstring 19from __future__ import print_function 20 21import numpy as np 22import mxnet as mx 23 24 25def get_mnist(): 26 """ Gets MNIST dataset """ 27 28 np.random.seed(1234) # set seed for deterministic ordering 29 mnist_data = mx.test_utils.get_mnist() 30 X = np.concatenate([mnist_data['train_data'], mnist_data['test_data']]) 31 Y = np.concatenate([mnist_data['train_label'], mnist_data['test_label']]) 32 p = np.random.permutation(X.shape[0]) 33 X = X[p].reshape((X.shape[0], -1)).astype(np.float32)*5 34 Y = Y[p] 35 return X, Y 36