1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file im2rec.cc
22  * \brief convert images into image recordio format
23  *  Image Record Format: zeropad[64bit] imid[64bit] img-binary-content
24  *  The 64bit zero pad was reserved for future purposes
25  *
26  *  Image List Format: unique-image-index label[s] path-to-image
27  * \sa dmlc/recordio.h
28  */
#include <cctype>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iomanip>
#include <memory>
#include <random>
#include <sstream>
#include <string>
#include <vector>
#include <dmlc/base.h>
#include <dmlc/io.h>
#include <dmlc/timer.h>
#include <dmlc/logging.h>
#include <dmlc/recordio.h>
#include <opencv2/opencv.hpp>
#include "../src/io/opencv_compatibility.h"
#include "../src/io/image_recordio.h"
/*!
 * \brief Resolve the interpolation code to pass to cv::resize for one image.
 *
 * Codes 0-4 are used as-is:
 *   0-CV_INTER_NN 1-CV_INTER_LINEAR 2-CV_INTER_CUBIC 3-CV_INTER_AREA 4-CV_INTER_LANCZOS4.
 * Two meta codes choose a concrete code per image:
 *   9  AUTO: cubic (2) when both dimensions grow, area (3) when both shrink,
 *      bilinear (1) otherwise.
 *   10 RAND: a uniformly random code in [0, 4] drawn from \a prnd.
 * Any other value is returned unchanged.
 * \param inter_method requested method code
 * \param old_width source width
 * \param old_height source height
 * \param new_width target width
 * \param new_height target height
 * \param prnd random engine; only advanced for code 10
 * \return the interpolation code to use
 */
int GetInterMethod(int inter_method, int old_width, int old_height, int new_width, int new_height, std::mt19937& prnd) {
  switch (inter_method) {
    case 9: {
      const bool grows_both = new_width > old_width && new_height > old_height;
      const bool shrinks_both = new_width < old_width && new_height < old_height;
      if (grows_both) {
        return 2;  // CV_INTER_CUBIC for enlarge
      }
      if (shrinks_both) {
        return 3;  // CV_INTER_AREA for shrink
      }
      return 1;  // CV_INTER_LINEAR for mixed / unchanged sizes
    }
    case 10: {
      std::uniform_int_distribution<size_t> pick_method(0, 4);
      return static_cast<int>(pick_method(prnd));
    }
    default:
      return inter_method;
  }
}
main(int argc,char * argv[])64 int main(int argc, char *argv[]) {
65   if (argc < 4) {
66     printf("Usage: <image.lst> <image_root_dir> <output.rec> [additional parameters in form key=value]\n"\
67            "Possible additional parameters:\n"\
68            "\tcolor=USE_COLOR[default=1] Force color (1), gray image (0) or keep source unchanged (-1).\n"\
69            "\tresize=newsize resize the shorter edge of image to the newsize, original images will be packed by default\n"\
70            "\tlabel_width=WIDTH[default=1] specify the label_width in the list, by default set to 1\n"\
71            "\tpack_label=PACK_LABEL[default=0] whether to also pack multi dimenional label in the record file\n"\
72            "\tnsplit=NSPLIT[default=1] used for part generation, logically split the image.list to NSPLIT parts by position\n"\
73            "\tpart=PART[default=0] used for part generation, pack the images from the specific part in image.list\n"\
74            "\tcenter_crop=CENTER_CROP[default=0] specify whether to crop the center image to make it square.\n"\
75            "\tquality=QUALITY[default=95] JPEG quality for encoding (1-100, default: 95) or PNG compression for encoding (1-9, default: 3).\n"\
76            "\tencoding=ENCODING[default='.jpg'] Encoding type. Can be '.jpg' or '.png'\n"\
77            "\tinter_method=INTER_METHOD[default=1] NN(0) BILINEAR(1) CUBIC(2) AREA(3) LANCZOS4(4) AUTO(9) RAND(10).\n"\
78            "\tunchanged=UNCHANGED[default=0] Keep the original image encoding, size and color. If set to 1, it will ignore the others parameters.\n");
79     return 0;
80   }
81   int label_width = 1;
82   int pack_label = 0;
83   int new_size = -1;
84   int nsplit = 1;
85   int partid = 0;
86   int center_crop = 0;
87   int quality = 95;
88   int color_mode = CV_LOAD_IMAGE_COLOR;
89   int unchanged = 0;
90   int inter_method = CV_INTER_LINEAR;
91   std::string encoding(".jpg");
92   for (int i = 4; i < argc; ++i) {
93     char key[128], val[128];
94     int effct_len = 0;
95 
96 #ifdef _MSC_VER
97     effct_len = sscanf_s(argv[i], "%[^=]=%s", key, sizeof(key), val, sizeof(val));
98 #else
99     effct_len = sscanf(argv[i], "%[^=]=%s", key, val);
100 #endif
101 
102     if (effct_len == 2) {
103       if (!strcmp(key, "resize")) new_size = atoi(val);
104       if (!strcmp(key, "label_width")) label_width = atoi(val);
105       if (!strcmp(key, "pack_label")) pack_label = atoi(val);
106       if (!strcmp(key, "nsplit")) nsplit = atoi(val);
107       if (!strcmp(key, "part")) partid = atoi(val);
108       if (!strcmp(key, "center_crop")) center_crop = atoi(val);
109       if (!strcmp(key, "quality")) quality = atoi(val);
110       if (!strcmp(key, "color")) color_mode = atoi(val);
111       if (!strcmp(key, "encoding")) encoding = std::string(val);
112       if (!strcmp(key, "unchanged")) unchanged = atoi(val);
113       if (!strcmp(key, "inter_method")) inter_method = atoi(val);
114     }
115   }
116   // Check parameters ranges
117   if (color_mode != -1 && color_mode != 0 && color_mode != 1) {
118     LOG(FATAL) << "Color mode must be -1, 0 or 1.";
119   }
120   if (encoding != std::string(".jpg") && encoding != std::string(".png")) {
121     LOG(FATAL) << "Encoding mode must be .jpg or .png.";
122   }
123   if (label_width <= 1 && pack_label) {
124     LOG(FATAL) << "pack_label can only be used when label_width > 1";
125   }
126   if (new_size > 0) {
127     LOG(INFO) << "New Image Size: Short Edge " << new_size;
128   } else {
129     LOG(INFO) << "Keep origin image size";
130   }
131   if (center_crop) {
132     LOG(INFO) << "Center cropping to square";
133   }
134   if (color_mode == 0) {
135     LOG(INFO) << "Use gray images";
136   }
137   if (color_mode == -1) {
138     LOG(INFO) << "Keep original color mode";
139   }
140   LOG(INFO) << "Encoding is " << encoding;
141 
142   if (encoding == std::string(".png") && quality > 9) {
143       quality = 3;
144   }
145   if (inter_method != 1) {
146       switch (inter_method) {
147         case 0:
148             LOG(INFO) << "Use inter_method CV_INTER_NN";
149             break;
150         case 2:
151             LOG(INFO) << "Use inter_method CV_INTER_CUBIC";
152             break;
153         case 3:
154             LOG(INFO) << "Use inter_method CV_INTER_AREA";
155             break;
156         case 4:
157             LOG(INFO) << "Use inter_method CV_INTER_LANCZOS4";
158             break;
159         case 9:
160             LOG(INFO) << "Use inter_method mod auto(cubic for enlarge, area for shrink)";
161             break;
162         case 10:
163             LOG(INFO) << "Use inter_method mod rand(nn/bilinear/cubic/area/lanczos4)";
164            break;
165         default:
166             LOG(INFO) << "Unkown inter_method";
167             return 0;
168       }
169   }
170   std::random_device rd;
171   std::mt19937 prnd(rd());
172   using namespace dmlc;
173   const static size_t kBufferSize = 1 << 20UL;
174   std::string root = argv[2];
175   mxnet::io::ImageRecordIO rec;
176   size_t imcnt = 0;
177   double tstart = dmlc::GetTime();
178   dmlc::InputSplit *flist = dmlc::InputSplit::
179       Create(argv[1], partid, nsplit, "text");
180   std::ostringstream os;
181   if (nsplit == 1) {
182     os << argv[3];
183   } else {
184     os << argv[3] << ".part" << std::setw(3) << std::setfill('0') << partid;
185   }
186   LOG(INFO) << "Write to output: " << os.str();
187   dmlc::Stream *fo = dmlc::Stream::Create(os.str().c_str(), "w");
188   LOG(INFO) << "Output: " << os.str();
189   dmlc::RecordIOWriter writer(fo);
190   std::string fname, path, blob;
191   std::vector<unsigned char> decode_buf;
192   std::vector<unsigned char> encode_buf;
193   std::vector<int> encode_params;
194   if (encoding == std::string(".png")) {
195       encode_params.push_back(CV_IMWRITE_PNG_COMPRESSION);
196       encode_params.push_back(quality);
197       LOG(INFO) << "PNG encoding compression: " << quality;
198   } else {
199       encode_params.push_back(CV_IMWRITE_JPEG_QUALITY);
200       encode_params.push_back(quality);
201       LOG(INFO) << "JPEG encoding quality: " << quality;
202   }
203   dmlc::InputSplit::Blob line;
204   std::vector<float> label_buf(label_width, 0.f);
205 
206   while (flist->NextRecord(&line)) {
207     std::string sline(static_cast<char*>(line.dptr), line.size);
208     std::istringstream is(sline);
209     if (!(is >> rec.header.image_id[0] >> rec.header.label)) continue;
210     label_buf[0] = rec.header.label;
211     for (int k = 1; k < label_width; ++k) {
212       CHECK(is >> label_buf[k])
213           << "Invalid ImageList, did you provide the correct label_width?";
214     }
215     if (pack_label) rec.header.flag = label_width;
216     rec.SaveHeader(&blob);
217     if (pack_label) {
218       size_t bsize = blob.size();
219       blob.resize(bsize + label_buf.size()*sizeof(float));
220       memcpy(BeginPtr(blob) + bsize,
221              BeginPtr(label_buf), label_buf.size()*sizeof(float));
222     }
223     CHECK(std::getline(is, fname));
224     // eliminate invalid chars in the end
225     while (fname.length() != 0 &&
226            (isspace(*fname.rbegin()) || !isprint(*fname.rbegin()))) {
227       fname.resize(fname.length() - 1);
228     }
229     // eliminate invalid chars in beginning.
230     const char *p = fname.c_str();
231     while (isspace(*p)) ++p;
232     path = root + p;
233     // use "r" is equal to rb in dmlc::Stream
234     dmlc::Stream *fi = dmlc::Stream::Create(path.c_str(), "r");
235     decode_buf.clear();
236     size_t imsize = 0;
237     while (true) {
238       decode_buf.resize(imsize + kBufferSize);
239       size_t nread = fi->Read(BeginPtr(decode_buf) + imsize, kBufferSize);
240       imsize += nread;
241       decode_buf.resize(imsize);
242       if (nread != kBufferSize) break;
243     }
244     delete fi;
245 
246 
247     if (unchanged != 1) {
248       cv::Mat img = cv::imdecode(decode_buf, color_mode);
249       CHECK(img.data != NULL) << "OpenCV decode fail:" << path;
250       cv::Mat res = img;
251       if (new_size > 0) {
252         if (center_crop) {
253           if (img.rows > img.cols) {
254             int margin = (img.rows - img.cols)/2;
255             img = img(cv::Range(margin, margin+img.cols), cv::Range(0, img.cols));
256           } else {
257             int margin = (img.cols - img.rows)/2;
258             img = img(cv::Range(0, img.rows), cv::Range(margin, margin + img.rows));
259           }
260         }
261         int interpolation_method = 1;
262         if (img.rows > img.cols) {
263             if (img.cols != new_size) {
264                 interpolation_method = GetInterMethod(inter_method, img.cols, img.rows, new_size, img.rows * new_size / img.cols, prnd);
265                 cv::resize(img, res, cv::Size(new_size, img.rows * new_size / img.cols), 0, 0, interpolation_method);
266             } else {
267                 res = img.clone();
268             }
269         } else {
270             if (img.rows != new_size) {
271                 interpolation_method = GetInterMethod(inter_method, img.cols, img.rows, new_size * img.cols / img.rows, new_size, prnd);
272                 cv::resize(img, res, cv::Size(new_size * img.cols / img.rows, new_size), 0, 0, interpolation_method);
273             } else {
274                 res = img.clone();
275             }
276         }
277       }
278       encode_buf.clear();
279       CHECK(cv::imencode(encoding, res, encode_buf, encode_params));
280 
281       // write buffer
282       size_t bsize = blob.size();
283       blob.resize(bsize + encode_buf.size());
284       memcpy(BeginPtr(blob) + bsize,
285              BeginPtr(encode_buf), encode_buf.size());
286     } else {
287       size_t bsize = blob.size();
288       blob.resize(bsize + decode_buf.size());
289       memcpy(BeginPtr(blob) + bsize,
290              BeginPtr(decode_buf), decode_buf.size());
291     }
292     writer.WriteRecord(BeginPtr(blob), blob.size());
293     // write header
294     ++imcnt;
295     if (imcnt % 1000 == 0) {
296       LOG(INFO) << imcnt << " images processed, " << GetTime() - tstart << " sec elapsed";
297     }
298   }
299   LOG(INFO) << "Total: " << imcnt << " images processed, " << GetTime() - tstart << " sec elapsed";
300   delete fo;
301   delete flist;
302   return 0;
303 }
304