1 //
2 //  MnistDataset.cpp
3 //  MNN
4 //
5 //  Created by MNN on 2019/11/15.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #include "MnistDataset.hpp"
10 #include <string.h>
11 #include <fstream>
12 #include <string>
13 namespace MNN {
14 namespace Train {
15 
// Adapted from the PyTorch C++ frontend MNIST dataset (mnist.cpp):
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/api/src/data/datasets/mnist.cpp
// Number of examples each split's header is expected to announce.
constexpr int32_t kTrainSize = 60000;
constexpr int32_t kTestSize  = 10000;

// First header word expected in the image file and in the label file.
constexpr int32_t kImageMagicNumber  = 2051;
constexpr int32_t kTargetMagicNumber = 2049;

// Height and width, in pixels, expected in the image file header.
constexpr int32_t kImageRows    = 28;
constexpr int32_t kImageColumns = 28;

// File names that are appended to the dataset root directory.
const char* kTrainImagesFilename  = "train-images-idx3-ubyte";
const char* kTrainTargetsFilename = "train-labels-idx1-ubyte";
const char* kTestImagesFilename   = "t10k-images-idx3-ubyte";
const char* kTestTargetsFilename  = "t10k-labels-idx1-ubyte";
28 
// Reports whether the host stores the least significant byte of a 32-bit
// word at the lowest address.
bool check_is_little_endian() {
    const uint32_t probe = 0x00000001u;
    // Inspect the byte at the lowest address of the probe word.
    const auto* lowestByte = reinterpret_cast<const uint8_t*>(&probe);
    return *lowestByte == 1;
}
33 
// Returns `value` with its four bytes in reverse order.
// Each byte is shifted to its mirrored position first and masked afterwards;
// the outermost two shifts need no mask because the shift itself discards
// every other byte.
constexpr uint32_t flip_endianness(uint32_t value) {
    return (value << 24u) | ((value << 8u) & 0x00ff0000u) | ((value >> 8u) & 0x0000ff00u) | (value >> 24u);
}
38 
// Reads the next four bytes from `stream` and interprets them as one unsigned
// 32-bit integer stored most-significant byte first.
//
// The bytes are combined explicitly, so the result does not depend on the
// byte order of the host.
//
// Returns 0 when fewer than four bytes could be read; in that case the stream
// is left in a failed state, which the caller can inspect.
uint32_t read_int32(std::ifstream& stream) {
    unsigned char bytes[4] = {0, 0, 0, 0};
    stream.read(reinterpret_cast<char*>(bytes), sizeof bytes);
    if (!stream) {
        // A short or failed read must not leak partially filled data.
        return 0;
    }
    return (static_cast<uint32_t>(bytes[0]) << 24u) | (static_cast<uint32_t>(bytes[1]) << 16u) |
           (static_cast<uint32_t>(bytes[2]) << 8u) | static_cast<uint32_t>(bytes[3]);
}
45 
expect_int32(std::ifstream & stream,uint32_t expected)46 uint32_t expect_int32(std::ifstream& stream, uint32_t expected) {
47     const auto value = read_int32(stream);
48     // clang-format off
49     MNN_ASSERT(value == expected);
50     // clang-format on
51     return value;
52 }
53 
// Joins directory `head` and file name `tail` with exactly one '/' between
// them when `head` does not already end in one.
//
// An empty `head` yields `tail` unchanged. Calling back() on an empty string
// is undefined behaviour, so that case is excluded before the last character
// is inspected.
std::string join_paths(std::string head, const std::string& tail) {
    if (!head.empty() && head.back() != '/') {
        head.push_back('/');
    }
    head += tail;
    return head;
}
61 
read_images(const std::string & root,bool train)62 VARP read_images(const std::string& root, bool train) {
63     const auto path = join_paths(root, train ? kTrainImagesFilename : kTestImagesFilename);
64     std::ifstream images(path, std::ios::binary);
65     if (!images.is_open()) {
66         MNN_PRINT("Error opening images file at %s", path.c_str());
67         MNN_ASSERT(false);
68     }
69 
70     const auto count = train ? kTrainSize : kTestSize;
71 
72     // From http://yann.lecun.com/exdb/mnist/
73     expect_int32(images, kImageMagicNumber);
74     expect_int32(images, count);
75     expect_int32(images, kImageRows);
76     expect_int32(images, kImageColumns);
77 
78     std::vector<int> dims = {count, 1, kImageRows, kImageColumns};
79     int length            = 1;
80     for (int i = 0; i < dims.size(); ++i) {
81         length *= dims[i];
82     }
83     auto data = _Input(dims, NCHW, halide_type_of<uint8_t>());
84     images.read(reinterpret_cast<char*>(data->writeMap<uint8_t>()), length);
85     return data;
86 }
87 
read_targets(const std::string & root,bool train)88 VARP read_targets(const std::string& root, bool train) {
89     const auto path = join_paths(root, train ? kTrainTargetsFilename : kTestTargetsFilename);
90     std::ifstream targets(path, std::ios::binary);
91     if (!targets.is_open()) {
92         MNN_PRINT("Error opening images file at %s", path.c_str());
93         MNN_ASSERT(false);
94     }
95 
96     const auto count = train ? kTrainSize : kTestSize;
97 
98     expect_int32(targets, kTargetMagicNumber);
99     expect_int32(targets, count);
100 
101     std::vector<int> dims = {count};
102     int length            = 1;
103     for (int i = 0; i < dims.size(); ++i) {
104         length *= dims[i];
105     }
106     auto labels = _Input(dims, NCHW, halide_type_of<uint8_t>());
107     targets.read(reinterpret_cast<char*>(labels->writeMap<uint8_t>()), length);
108 
109     return labels;
110 }
111 
MnistDataset(const std::string root,Mode mode)112 MnistDataset::MnistDataset(const std::string root, Mode mode)
113     : mImages(read_images(root, mode == Mode::TRAIN)), mLabels(read_targets(root, mode == Mode::TRAIN)) {
114     mImagePtr  = mImages->readMap<uint8_t>();
115     mLabelsPtr = mLabels->readMap<uint8_t>();
116 }
117 
get(size_t index)118 Example MnistDataset::get(size_t index) {
119     auto data  = _Input({1, kImageRows, kImageColumns}, NCHW, halide_type_of<uint8_t>());
120     auto label = _Input({}, NCHW, halide_type_of<uint8_t>());
121 
122     auto dataPtr = mImagePtr + index * kImageRows * kImageColumns;
123     ::memcpy(data->writeMap<uint8_t>(), dataPtr, kImageRows * kImageColumns);
124 
125     auto labelPtr = mLabelsPtr + index;
126     ::memcpy(label->writeMap<uint8_t>(), labelPtr, 1);
127 
128     auto returnIndex = _Const(index);
129     // return the index for test
130     return {{data, returnIndex}, {label}};
131 }
132 
size()133 size_t MnistDataset::size() {
134     return mImages->getInfo()->dim[0];
135 }
136 
images()137 const VARP MnistDataset::images() {
138     return mImages;
139 }
140 
labels()141 const VARP MnistDataset::labels() {
142     return mLabels;
143 }
144 
create(const std::string path,Mode mode)145 DatasetPtr MnistDataset::create(const std::string path, Mode mode) {
146     DatasetPtr res;
147     res.mDataset.reset(new MnistDataset(path, mode));
148     return res;
149 }
150 }
151 }
152