1 /**
2  * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <assert.h>
17 #include <inttypes.h>
18 #include <png.h>
19 #include <stdio.h>
20 #include <stdlib.h>
21 
22 #include <cstring>
23 #include <string>
24 #include <vector>
25 
26 #include "lenet_mnist.h"
27 
28 /// This is an example demonstrating how to use auto-generated bundles and
29 /// create standalone executables that can perform neural network computations.
30 /// This example loads and runs the compiled lenet_mnist network model.
31 /// This example is using the static bundle API.
32 
33 #define DEFAULT_HEIGHT 28
34 #define DEFAULT_WIDTH 28
35 #define OUTPUT_LEN 10
36 
37 //===----------------------------------------------------------------------===//
38 //                   Image processing helpers
39 //===----------------------------------------------------------------------===//
/// Paths of the input PNG images. The list starts out empty and is appended to
/// by parseCommandLineOptions().
std::vector<std::string> inputImageFilenames{};
41 
/// Computes the row-major linear offset of element (x, y, z, w) inside a
/// four-dimensional tensor whose extents are stored in \p dims.
/// Only dims[1], dims[2] and dims[3] are read; no bounds checking is done.
/// \returns the flat index of the element at x,y,z,w.
size_t getXYZW(const size_t *dims, size_t x, size_t y, size_t z, size_t w) {
  // Stride of each axis, built from the innermost dimension outwards.
  const size_t strideZ = dims[3];
  const size_t strideY = dims[2] * strideZ;
  const size_t strideX = dims[1] * strideY;
  return x * strideX + y * strideY + z * strideZ + w;
}
47 
/// Computes the row-major linear offset of element (x, y, z) inside a
/// three-dimensional tensor whose extents are stored in \p dims.
/// Only dims[1] and dims[2] are read; no bounds checking is done.
/// \returns the flat index of the element at x,y,z.
size_t getXYZ(const size_t *dims, size_t x, size_t y, size_t z) {
  // Horner form of x * dims[1] * dims[2] + y * dims[2] + z.
  return (x * dims[1] + y) * dims[2] + z;
}
52 
53 /// Reads a PNG image from a file into a newly allocated memory block \p imageT
54 /// representing a WxHxNxC tensor and returns it. The client is responsible for
55 /// freeing the memory block.
readPngImage(const char * filename,std::pair<float,float> range,float * & imageT,size_t * imageDims)56 bool readPngImage(const char *filename, std::pair<float, float> range,
57                   float *&imageT, size_t *imageDims) {
58   unsigned char header[8];
59   // open file and test for it being a png.
60   FILE *fp = fopen(filename, "rb");
61   // Can't open the file.
62   if (!fp) {
63     return true;
64   }
65 
66   // Validate signature.
67   size_t fread_ret = fread(header, 1, 8, fp);
68   if (fread_ret != 8) {
69     return true;
70   }
71   if (png_sig_cmp(header, 0, 8)) {
72     return true;
73   }
74 
75   // Initialize stuff.
76   png_structp png_ptr =
77       png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr);
78   if (!png_ptr) {
79     return true;
80   }
81 
82   png_infop info_ptr = png_create_info_struct(png_ptr);
83   if (!info_ptr) {
84     return true;
85   }
86 
87   if (setjmp(png_jmpbuf(png_ptr))) {
88     return true;
89   }
90 
91   png_init_io(png_ptr, fp);
92   png_set_sig_bytes(png_ptr, 8);
93   png_read_info(png_ptr, info_ptr);
94 
95   size_t width = png_get_image_width(png_ptr, info_ptr);
96   size_t height = png_get_image_height(png_ptr, info_ptr);
97   int color_type = png_get_color_type(png_ptr, info_ptr);
98   int bit_depth = png_get_bit_depth(png_ptr, info_ptr);
99 
100   const bool isGray = color_type == PNG_COLOR_TYPE_GRAY;
101   const size_t numChannels = 1;
102 
103   (void)bit_depth;
104   assert(bit_depth == 8 && "Invalid image");
105   assert(isGray && "Invalid image");
106   (void)isGray;
107   bool hasAlpha = (color_type == PNG_COLOR_TYPE_RGB_ALPHA);
108 
109   int number_of_passes = png_set_interlace_handling(png_ptr);
110   (void)number_of_passes;
111   assert(number_of_passes == 1 && "Invalid image");
112 
113   png_read_update_info(png_ptr, info_ptr);
114 
115   // Error during image read.
116   if (setjmp(png_jmpbuf(png_ptr))) {
117     return true;
118   }
119 
120   auto *row_pointers = (png_bytep *)malloc(sizeof(png_bytep) * height);
121   for (size_t y = 0; y < height; y++) {
122     row_pointers[y] = (png_byte *)malloc(png_get_rowbytes(png_ptr, info_ptr));
123   }
124 
125   png_read_image(png_ptr, row_pointers);
126   png_read_end(png_ptr, info_ptr);
127 
128   imageDims[0] = width;
129   imageDims[1] = height;
130   imageDims[2] = numChannels;
131   imageT = static_cast<float *>(
132       calloc(1, width * height * numChannels * sizeof(float)));
133 
134   float scale = ((range.second - range.first) / 255.0);
135   float bias = range.first;
136 
137   for (size_t row_n = 0; row_n < height; row_n++) {
138     png_byte *row = row_pointers[row_n];
139     for (size_t col_n = 0; col_n < width; col_n++) {
140       png_byte *ptr =
141           &(row[col_n * (hasAlpha ? (numChannels + 1) : numChannels)]);
142       imageT[getXYZ(imageDims, row_n, col_n, 0)] = float(ptr[0]) * scale + bias;
143     }
144   }
145 
146   for (size_t y = 0; y < height; y++) {
147     free(row_pointers[y]);
148   }
149   free(row_pointers);
150   png_destroy_read_struct(&png_ptr, &info_ptr, (png_infopp)NULL);
151   fclose(fp);
152   printf("Loaded image: %s\n", filename);
153 
154   return false;
155 }
156 
157 /// Loads and normalizes all PNGs into a tensor memory block \p resultT in the
158 /// NCHW 1x28x28 format.
loadImagesAndPreprocess(const std::vector<std::string> & filenames,float * & resultT,size_t * resultDims)159 static void loadImagesAndPreprocess(const std::vector<std::string> &filenames,
160                                     float *&resultT, size_t *resultDims) {
161   assert(filenames.size() > 0 &&
162          "There must be at least one filename in filenames");
163   std::pair<float, float> range = std::make_pair(0., 1.0);
164   unsigned numImages = filenames.size();
165   // N x C x H x W
166   resultDims[0] = numImages;
167   resultDims[1] = 1;
168   resultDims[2] = DEFAULT_HEIGHT;
169   resultDims[3] = DEFAULT_WIDTH;
170   size_t resultSizeInBytes =
171       numImages * DEFAULT_HEIGHT * DEFAULT_WIDTH * sizeof(float);
172   resultT = static_cast<float *>(malloc(resultSizeInBytes));
173   // We iterate over all the png files, reading them all into our result tensor
174   // for processing
175   for (unsigned n = 0; n < numImages; n++) {
176     float *imageT{nullptr};
177     size_t dims[3];
178     bool loadSuccess = !readPngImage(filenames[n].c_str(), range, imageT, dims);
179     assert(loadSuccess && "Error reading input image.");
180     (void)loadSuccess;
181 
182     assert((dims[0] == DEFAULT_HEIGHT && dims[1] == DEFAULT_WIDTH) &&
183            "All images must have the same Height and Width");
184 
185     // Convert to BGR, as this is what NN is expecting.
186     for (unsigned y = 0; y < dims[1]; y++) {
187       for (unsigned x = 0; x < dims[0]; x++) {
188         resultT[getXYZW(resultDims, n, 0, x, y)] =
189             imageT[getXYZ(dims, x, y, 0)];
190       }
191     }
192   }
193   printf("Loaded images size in bytes is: %lu\n", resultSizeInBytes);
194 }
195 
196 /// Parse images file names into a vector.
parseCommandLineOptions(int argc,char ** argv)197 void parseCommandLineOptions(int argc, char **argv) {
198   int arg = 1;
199   while (arg < argc) {
200     inputImageFilenames.push_back(argv[arg++]);
201   }
202 }
203 
204 //===----------------------------------------------------------------------===//
205 //                 Wrapper code for executing a bundle
206 //===----------------------------------------------------------------------===//
207 /// Statically allocate memory for constant weights (model weights) and
208 /// initialize.
209 GLOW_MEM_ALIGN(LENET_MNIST_MEM_ALIGN)
210 uint8_t constantWeight[LENET_MNIST_CONSTANT_MEM_SIZE] = {
211 #include "lenet_mnist.weights.txt"
212 };
213 
214 /// Statically allocate memory for mutable weights (model input/output data).
215 GLOW_MEM_ALIGN(LENET_MNIST_MEM_ALIGN)
216 uint8_t mutableWeight[LENET_MNIST_MUTABLE_MEM_SIZE];
217 
218 /// Statically allocate memory for activations (model intermediate results).
219 GLOW_MEM_ALIGN(LENET_MNIST_MEM_ALIGN)
220 uint8_t activations[LENET_MNIST_ACTIVATIONS_MEM_SIZE];
221 
222 /// Bundle input data absolute address.
223 uint8_t *inputAddr = GLOW_GET_ADDR(mutableWeight, LENET_MNIST_data);
224 
225 /// Bundle output data absolute address.
226 uint8_t *outputAddr = GLOW_GET_ADDR(mutableWeight, LENET_MNIST_softmax);
227 
228 /// Copy the pre-processed images into the mutable region of the bundle.
initInputImages()229 static void initInputImages() {
230   size_t inputDims[4];
231   float *inputT{nullptr};
232   loadImagesAndPreprocess(inputImageFilenames, inputT, inputDims);
233   // Copy image data into the data input variable in the mutableWeightVars area.
234   size_t imageDataSizeInBytes =
235       inputDims[0] * inputDims[1] * inputDims[2] * inputDims[3] * sizeof(float);
236   printf("Copying image data into mutable weight vars: %lu bytes\n",
237          imageDataSizeInBytes);
238   memcpy(inputAddr, inputT, imageDataSizeInBytes);
239   free(inputT);
240 }
241 
242 /// Dump the result of the inference by looking at the results vector and
243 /// finding the index of the max element.
printResults()244 static void printResults() {
245   int maxIdx = 0;
246   float maxValue = 0;
247   float *results = (float *)(outputAddr);
248   for (int i = 0; i < OUTPUT_LEN; ++i) {
249     if (results[i] > maxValue) {
250       maxValue = results[i];
251       maxIdx = i;
252     }
253   }
254   printf("Result: %u\n", maxIdx);
255   printf("Confidence: %f\n", maxValue);
256 }
257 
main(int argc,char ** argv)258 int main(int argc, char **argv) {
259   parseCommandLineOptions(argc, argv);
260 
261   // Initialize input images.
262   initInputImages();
263 
264   // Perform the computation.
265   int errCode = lenet_mnist(constantWeight, mutableWeight, activations);
266   if (errCode != GLOW_SUCCESS) {
267     printf("Error running bundle: error code %d\n", errCode);
268   }
269 
270   // Print results.
271   printResults();
272 }
273