1 //
2 // MnistDataset.cpp
3 // MNN
4 //
5 // Created by MNN on 2019/11/15.
6 // Copyright © 2018, Alibaba Group Holding Limited
7 //
8
9 #include "MnistDataset.hpp"
10 #include <string.h>
11 #include <fstream>
12 #include <string>
13 namespace MNN {
14 namespace Train {
15
// Referenced from the PyTorch C++ frontend mnist.cpp:
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/api/src/data/datasets/mnist.cpp
//
// All constants are constexpr: namespace-scope const objects have internal
// linkage, so these names stay private to this translation unit. (The file
// names used to be plain `const char*`, i.e. mutable, externally visible globals.)

// Number of samples expected in the training and test splits.
constexpr int32_t kTrainSize = 60000;
constexpr int32_t kTestSize = 10000;
// First 32-bit header field of the image file and of the label file.
constexpr int32_t kImageMagicNumber = 2051;
constexpr int32_t kTargetMagicNumber = 2049;
// Spatial size of every image (checked against the image-file header).
constexpr int32_t kImageRows = 28;
constexpr int32_t kImageColumns = 28;
// File names looked up inside the dataset root directory.
constexpr const char* kTrainImagesFilename = "train-images-idx3-ubyte";
constexpr const char* kTrainTargetsFilename = "train-labels-idx1-ubyte";
constexpr const char* kTestImagesFilename = "t10k-images-idx3-ubyte";
constexpr const char* kTestTargetsFilename = "t10k-labels-idx1-ubyte";
28
// Returns true when the host stores the least-significant byte of a
// multi-byte integer at the lowest address (little-endian).
bool check_is_little_endian() {
    const uint32_t probe = 1;
    uint8_t lowestAddressByte = 0;
    // Copy out the byte at the lowest address of `probe`.
    ::memcpy(&lowestAddressByte, &probe, 1);
    return lowestAddressByte == 1;
}
33
// Reverses the byte order of a 32-bit value (e.g. 0x11223344 -> 0x44332211).
constexpr uint32_t flip_endianness(uint32_t value) {
    // Shift each byte into its mirrored position, masking off bytes that
    // would otherwise land in the wrong slot.
    return (value << 24u) | ((value << 8u) & 0x00ff0000u) | ((value >> 8u) & 0x0000ff00u) | (value >> 24u);
}
38
read_int32(std::ifstream & stream)39 uint32_t read_int32(std::ifstream& stream) {
40 static const bool is_little_endian = check_is_little_endian();
41 uint32_t value;
42 stream.read(reinterpret_cast<char*>(&value), sizeof value);
43 return is_little_endian ? flip_endianness(value) : value;
44 }
45
expect_int32(std::ifstream & stream,uint32_t expected)46 uint32_t expect_int32(std::ifstream& stream, uint32_t expected) {
47 const auto value = read_int32(stream);
48 // clang-format off
49 MNN_ASSERT(value == expected);
50 // clang-format on
51 return value;
52 }
53
// Joins a directory `head` and a file name `tail`, inserting '/' only when
// `head` does not already end with one. An empty `head` yields `tail` as-is
// (calling back() on an empty string would be undefined behavior).
std::string join_paths(std::string head, const std::string& tail) {
    if (head.empty()) {
        return tail;
    }
    if (head.back() != '/') {
        head.push_back('/');
    }
    head += tail;
    return head;
}
61
read_images(const std::string & root,bool train)62 VARP read_images(const std::string& root, bool train) {
63 const auto path = join_paths(root, train ? kTrainImagesFilename : kTestImagesFilename);
64 std::ifstream images(path, std::ios::binary);
65 if (!images.is_open()) {
66 MNN_PRINT("Error opening images file at %s", path.c_str());
67 MNN_ASSERT(false);
68 }
69
70 const auto count = train ? kTrainSize : kTestSize;
71
72 // From http://yann.lecun.com/exdb/mnist/
73 expect_int32(images, kImageMagicNumber);
74 expect_int32(images, count);
75 expect_int32(images, kImageRows);
76 expect_int32(images, kImageColumns);
77
78 std::vector<int> dims = {count, 1, kImageRows, kImageColumns};
79 int length = 1;
80 for (int i = 0; i < dims.size(); ++i) {
81 length *= dims[i];
82 }
83 auto data = _Input(dims, NCHW, halide_type_of<uint8_t>());
84 images.read(reinterpret_cast<char*>(data->writeMap<uint8_t>()), length);
85 return data;
86 }
87
read_targets(const std::string & root,bool train)88 VARP read_targets(const std::string& root, bool train) {
89 const auto path = join_paths(root, train ? kTrainTargetsFilename : kTestTargetsFilename);
90 std::ifstream targets(path, std::ios::binary);
91 if (!targets.is_open()) {
92 MNN_PRINT("Error opening images file at %s", path.c_str());
93 MNN_ASSERT(false);
94 }
95
96 const auto count = train ? kTrainSize : kTestSize;
97
98 expect_int32(targets, kTargetMagicNumber);
99 expect_int32(targets, count);
100
101 std::vector<int> dims = {count};
102 int length = 1;
103 for (int i = 0; i < dims.size(); ++i) {
104 length *= dims[i];
105 }
106 auto labels = _Input(dims, NCHW, halide_type_of<uint8_t>());
107 targets.read(reinterpret_cast<char*>(labels->writeMap<uint8_t>()), length);
108
109 return labels;
110 }
111
MnistDataset(const std::string root,Mode mode)112 MnistDataset::MnistDataset(const std::string root, Mode mode)
113 : mImages(read_images(root, mode == Mode::TRAIN)), mLabels(read_targets(root, mode == Mode::TRAIN)) {
114 mImagePtr = mImages->readMap<uint8_t>();
115 mLabelsPtr = mLabels->readMap<uint8_t>();
116 }
117
get(size_t index)118 Example MnistDataset::get(size_t index) {
119 auto data = _Input({1, kImageRows, kImageColumns}, NCHW, halide_type_of<uint8_t>());
120 auto label = _Input({}, NCHW, halide_type_of<uint8_t>());
121
122 auto dataPtr = mImagePtr + index * kImageRows * kImageColumns;
123 ::memcpy(data->writeMap<uint8_t>(), dataPtr, kImageRows * kImageColumns);
124
125 auto labelPtr = mLabelsPtr + index;
126 ::memcpy(label->writeMap<uint8_t>(), labelPtr, 1);
127
128 auto returnIndex = _Const(index);
129 // return the index for test
130 return {{data, returnIndex}, {label}};
131 }
132
size()133 size_t MnistDataset::size() {
134 return mImages->getInfo()->dim[0];
135 }
136
images()137 const VARP MnistDataset::images() {
138 return mImages;
139 }
140
labels()141 const VARP MnistDataset::labels() {
142 return mLabels;
143 }
144
create(const std::string path,Mode mode)145 DatasetPtr MnistDataset::create(const std::string path, Mode mode) {
146 DatasetPtr res;
147 res.mDataset.reset(new MnistDataset(path, mode));
148 return res;
149 }
150 }
151 }
152