/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*
 * This example demonstrates an image classification workflow with pre-trained models using the MXNet C++ API.
 * The example performs the following tasks:
 * 1. Load the pre-trained model.
 * 2. Load the parameters of the pre-trained model.
 * 3. Load the inference dataset and create a new ImageRecordIter.
 * 4. Run the forward pass and obtain throughput & accuracy.
 */
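
/*
 * Example invocation (illustrative only; the symbol/params/record file names below are
 * placeholders and must be replaced with real files):
 *
 *   ./imagenet_inference --symbol_file ./model/resnet50_v1-symbol.json \
 *                        --params_file ./model/resnet50_v1-0000.params \
 *                        --dataset ./data/val_256_q90.rec \
 *                        --batch_size 64 --num_inference_batches 500
 *
 * Appending --benchmark runs the forward pass on randomly generated dummy data instead of
 * the .rec dataset, in which case --params_file and --dataset may be omitted.
 */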
#ifndef _WIN32
#include <sys/time.h>
#endif
#include <fstream>
#include <iostream>
#include <map>
#include <chrono>
#include <string>
#include <vector>
#include <random>
#include <type_traits>
#include <opencv2/opencv.hpp>
#include "mxnet/c_api.h"
#include "mxnet/tuple.h"
#include "mxnet-cpp/MxNetCpp.h"
#include "mxnet-cpp/initializer.h"

using namespace mxnet::cpp;

double ms_now() {
  double ret;
#ifdef _WIN32
  auto timePoint = std::chrono::high_resolution_clock::now().time_since_epoch();
  ret = std::chrono::duration<double, std::milli>(timePoint).count();
#else
  struct timeval time;
  gettimeofday(&time, nullptr);
  ret = 1e+3 * time.tv_sec + 1e-3 * time.tv_usec;
#endif
  return ret;
}


// Data type flags for NDArray, aligned with the definitions in mshadow/base.h
enum TypeFlag {
  kFloat32 = 0,
  kFloat64 = 1,
  kFloat16 = 2,
  kUint8 = 3,
  kInt32 = 4,
  kInt8  = 5,
  kInt64 = 6,
};

/*
 * class Predictor
 *
 * This class encapsulates the functionality to load the model, prepare the dataset and run the forward pass.
 */

class Predictor {
 public:
    Predictor() {}
    Predictor(const std::string& model_json_file,
              const std::string& model_params_file,
              const Shape& input_shape,
              bool use_gpu,
              bool enable_tensorrt,
              const std::string& dataset,
              const int data_nthreads,
              const std::string& data_layer_type,
              const std::vector<float>& rgb_mean,
              const std::vector<float>& rgb_std,
              int shuffle_chunk_seed,
              int seed, bool benchmark);
    void BenchmarkScore(int num_inference_batches);
    void Score(int num_skipped_batches, int num_inference_batches);
    ~Predictor();

 private:
    bool CreateImageRecordIter();
    bool AdvanceDataIter(int skipped_batches);
    void LoadModel(const std::string& model_json_file);
    void LoadParameters(const std::string& model_parameters_file);
    void SplitParamMap(const std::map<std::string, NDArray> &paramMap,
        std::map<std::string, NDArray> *argParamInTargetContext,
        std::map<std::string, NDArray> *auxParamInTargetContext,
        Context targetContext);
    void ConvertParamMapToTargetContext(const std::map<std::string, NDArray> &paramMap,
        std::map<std::string, NDArray> *paramMapInTargetContext,
        Context targetContext);
    void InitParameters();

    inline bool FileExists(const std::string &name) {
      std::ifstream fhandle(name.c_str());
      return fhandle.good();
    }
    int GetDataLayerType();

    std::map<std::string, NDArray> args_map_;
    std::map<std::string, NDArray> aux_map_;
    Symbol net_;
    // Initialize the pointers so the destructor is safe even for a default-constructed object.
    Executor *executor_ = nullptr;
    Shape input_shape_;
    Context global_ctx_ = Context::cpu();

    MXDataIter *val_iter_ = nullptr;
    bool use_gpu_;
    bool enable_tensorrt_;
    std::string dataset_;
    int data_nthreads_;
    std::string data_layer_type_;
    std::vector<float> rgb_mean_;
    std::vector<float> rgb_std_;
    int shuffle_chunk_seed_;
    int seed_;
    bool benchmark_;
};
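
/*
 * A minimal usage sketch for the class above (argument values are illustrative only;
 * the file names are placeholders, not files shipped with this example):
 *
 *   Predictor predictor("resnet50-symbol.json", "resnet50-0000.params",
 *                       Shape(64, 3, 224, 224),   // batch_size, channels, height, width
 *                       false,                    // use_gpu
 *                       false,                    // enable_tensorrt
 *                       "val.rec",                // dataset (.rec file)
 *                       60,                       // data_nthreads
 *                       "float32",                // data_layer_type
 *                       {0.0f, 0.0f, 0.0f},       // rgb_mean
 *                       {1.0f, 1.0f, 1.0f},       // rgb_std
 *                       3982304, 48564309,        // shuffle_chunk_seed, seed
 *                       false);                   // benchmark
 *   predictor.Score(0, 100);   // skip 0 batches, score 100 batches
 */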


/*
 * The constructor takes the following parameters as input:
 * 1. model_json_file:  The model in a json-formatted file.
 * 2. model_params_file: File containing the model parameters.
 * 3. input_shape: Shape of the input data to the model, in the format
 *                 Shape(batch_size, number_of_channels, height, width).
 *                 The input images will be resized to (height x width) before running the inference.
 * 4. use_gpu: whether to run the inference on the GPU.
 * 5. enable_tensorrt: whether to enable TensorRT.
 * 6. dataset: data file (.rec) to be used for inference.
 * 7. data_nthreads: number of threads for data loading.
 * 8. data_layer_type: data type for the data layer.
 * 9. rgb_mean: mean values to be subtracted from the R/G/B channels.
 * 10. rgb_std: standard deviations of the R/G/B channels.
 * 11. shuffle_chunk_seed: shuffling chunk seed.
 * 12. seed: shuffling seed.
 * 13. benchmark: use dummy data for inference.
 *
 * The constructor will:
 *  1. Create an ImageRecordIter based on the given dataset file.
 *  2. Load the model and parameter files.
 *  3. Infer and construct NDArrays according to the input arguments and create an executor.
 */
Predictor::Predictor(const std::string& model_json_file,
                     const std::string& model_params_file,
                     const Shape& input_shape,
                     bool use_gpu,
                     bool enable_tensorrt,
                     const std::string& dataset,
                     const int data_nthreads,
                     const std::string& data_layer_type,
                     const std::vector<float>& rgb_mean,
                     const std::vector<float>& rgb_std,
                     int shuffle_chunk_seed,
                     int seed, bool benchmark)
    : input_shape_(input_shape),
      use_gpu_(use_gpu),
      enable_tensorrt_(enable_tensorrt),
      dataset_(dataset),
      data_nthreads_(data_nthreads),
      data_layer_type_(data_layer_type),
      rgb_mean_(rgb_mean),
      rgb_std_(rgb_std),
      shuffle_chunk_seed_(shuffle_chunk_seed),
      seed_(seed),
      benchmark_(benchmark) {
  if (use_gpu) {
    global_ctx_ = Context::gpu();
  }

  // initialize the data iterator
  if (!benchmark_ && !CreateImageRecordIter()) {
    LG << "Error: failed to create ImageRecordIter";
    throw std::runtime_error("ImageRecordIter cannot be created");
  }

  // Load the model
  LoadModel(model_json_file);
  // Initialize the parameters.
  // If benchmark is enabled and model_params_file is empty, randomly initialize the parameters;
  // otherwise, load them from the parameters file.
  if (benchmark_ && model_params_file.empty()) {
    InitParameters();
  } else {
    LoadParameters(model_params_file);
  }

  int dtype = GetDataLayerType();
  if (dtype == -1) {
    throw std::runtime_error("Unsupported data layer type...");
  }
  args_map_["data"] = NDArray(input_shape_, global_ctx_, false, dtype);
  Shape label_shape(input_shape_[0]);
  args_map_["softmax_label"] = NDArray(label_shape, global_ctx_, false);
  std::vector<NDArray> arg_arrays;
  std::vector<NDArray> grad_arrays;
  std::vector<OpReqType> grad_reqs;
  std::vector<NDArray> aux_arrays;

  // infer and create ndarrays according to the given input ndarrays.
  net_.InferExecutorArrays(global_ctx_, &arg_arrays, &grad_arrays, &grad_reqs,
                           &aux_arrays, args_map_, std::map<std::string, NDArray>(),
                           std::map<std::string, OpReqType>(), aux_map_);
  for (auto& i : grad_reqs) i = OpReqType::kNullOp;

  // Create an executor after binding the model to input parameters.
  executor_ = new Executor(net_, global_ctx_, arg_arrays, grad_arrays, grad_reqs, aux_arrays);
}

/*
 * The following function is used to get the data layer type for input data.
 */
int Predictor::GetDataLayerType() {
  int ret_type = -1;
  if (data_layer_type_ == "float32") {
    ret_type = kFloat32;
  } else if (data_layer_type_ == "int8") {
    ret_type = kInt8;
  } else if (data_layer_type_ == "uint8") {
    ret_type = kUint8;
  } else {
    LG << "Unsupported data layer type " << data_layer_type_ << "..."
       << "Please use one of {float32, int8, uint8}";
  }
  return ret_type;
}
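
// For example, GetDataLayerType() returns kInt8 (5) when data_layer_type_ is "int8";
// the constructor then passes that value as the dtype of the "data" NDArray, and
// CreateImageRecordIter() forwards the same string via the iterator's "dtype" parameter.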

/*
 * Create a new ImageRecordIter according to the given parameters.
 */
bool Predictor::CreateImageRecordIter() {
  val_iter_ = new MXDataIter("ImageRecordIter");
  if (!FileExists(dataset_)) {
    LG << "Error: " << dataset_ << " must be provided";
    return false;
  }

  std::vector<index_t> shape_vec;
  for (index_t i = 1; i < input_shape_.ndim(); i++)
    shape_vec.push_back(input_shape_[i]);
  mxnet::TShape data_shape(shape_vec.begin(), shape_vec.end());

  // set image record parser parameters
  val_iter_->SetParam("path_imgrec", dataset_);
  val_iter_->SetParam("label_width", 1);
  val_iter_->SetParam("data_shape", data_shape);
  val_iter_->SetParam("preprocess_threads", data_nthreads_);
  val_iter_->SetParam("shuffle_chunk_seed", shuffle_chunk_seed_);

  // set batch parameters
  val_iter_->SetParam("batch_size", input_shape_[0]);

  // image record parameters
  val_iter_->SetParam("shuffle", true);
  val_iter_->SetParam("seed", seed_);

  // set normalization parameters
  val_iter_->SetParam("mean_r", rgb_mean_[0]);
  val_iter_->SetParam("mean_g", rgb_mean_[1]);
  val_iter_->SetParam("mean_b", rgb_mean_[2]);
  val_iter_->SetParam("std_r", rgb_std_[0]);
  val_iter_->SetParam("std_g", rgb_std_[1]);
  val_iter_->SetParam("std_b", rgb_std_[2]);

  // set prefetcher parameters
  if (use_gpu_) {
    val_iter_->SetParam("ctx", "gpu");
  } else {
    val_iter_->SetParam("ctx", "cpu");
  }
  val_iter_->SetParam("dtype", data_layer_type_);

  val_iter_->CreateDataIter();
  return true;
}

/*
 * The following function loads the model from the json file.
 */
void Predictor::LoadModel(const std::string& model_json_file) {
  if (!FileExists(model_json_file)) {
    LG << "Model file " << model_json_file << " does not exist";
    throw std::runtime_error("Model file does not exist");
  }
  LG << "Loading the model from " << model_json_file << std::endl;
  net_ = Symbol::Load(model_json_file);
  if (enable_tensorrt_) {
    net_ = net_.GetBackendSymbol("TensorRT");
  }
}

/*
 * The following function loads the model parameters.
 */
void Predictor::LoadParameters(const std::string& model_parameters_file) {
  if (!FileExists(model_parameters_file)) {
    LG << "Parameter file " << model_parameters_file << " does not exist";
    throw std::runtime_error("Model parameters file does not exist");
  }
  LG << "Loading the model parameters from " << model_parameters_file << std::endl;
  std::map<std::string, NDArray> parameters;
  NDArray::Load(model_parameters_file, 0, &parameters);
  if (enable_tensorrt_) {
    std::map<std::string, NDArray> intermediate_args_map;
    std::map<std::string, NDArray> intermediate_aux_map;
    SplitParamMap(parameters, &intermediate_args_map, &intermediate_aux_map, Context::cpu());
    contrib::InitTensorRTParams(net_, &intermediate_args_map, &intermediate_aux_map);
    ConvertParamMapToTargetContext(intermediate_args_map, &args_map_, global_ctx_);
    ConvertParamMapToTargetContext(intermediate_aux_map, &aux_map_, global_ctx_);
  } else {
    SplitParamMap(parameters, &args_map_, &aux_map_, global_ctx_);
  }
  /* WaitAll is needed when we copy data between the GPU and main memory. */
  NDArray::WaitAll();
}

/*
 * The following function splits the loaded parameter map into arg and aux
 * parameter maps in the target context.
 */
void Predictor::SplitParamMap(const std::map<std::string, NDArray> &paramMap,
    std::map<std::string, NDArray> *argParamInTargetContext,
    std::map<std::string, NDArray> *auxParamInTargetContext,
    Context targetContext) {
  for (const auto& pair : paramMap) {
    std::string type = pair.first.substr(0, 4);
    std::string name = pair.first.substr(4);
    if (type == "arg:") {
      (*argParamInTargetContext)[name] = pair.second.Copy(targetContext);
    } else if (type == "aux:") {
      (*auxParamInTargetContext)[name] = pair.second.Copy(targetContext);
    }
  }
}

/*
 * The following function copies the parameter map to the target context.
 */
void Predictor::ConvertParamMapToTargetContext(const std::map<std::string, NDArray> &paramMap,
    std::map<std::string, NDArray> *paramMapInTargetContext,
    Context targetContext) {
  for (const auto& pair : paramMap) {
    (*paramMapInTargetContext)[pair.first] = pair.second.Copy(targetContext);
  }
}

/*
 * The following function randomly initializes the parameters when benchmark_ is true.
 */
void Predictor::InitParameters() {
  std::vector<mx_uint> data_shape;
  for (index_t i = 0; i < input_shape_.ndim(); i++) {
    data_shape.push_back(input_shape_[i]);
  }

  std::map<std::string, std::vector<mx_uint> > arg_shapes;
  std::vector<std::vector<mx_uint> > aux_shapes, in_shapes, out_shapes;
  arg_shapes["data"] = data_shape;
  net_.InferShape(arg_shapes, &in_shapes, &aux_shapes, &out_shapes);

  // initializer to call
  Xavier xavier(Xavier::uniform, Xavier::avg, 2.0f);

  auto arg_name_list = net_.ListArguments();
  for (index_t i = 0; i < in_shapes.size(); i++) {
    const auto &shape = in_shapes[i];
    const auto &arg_name = arg_name_list[i];
    int paramType = kFloat32;
    if (Initializer::StringEndWith(arg_name, "weight_quantize") ||
        Initializer::StringEndWith(arg_name, "bias_quantize")) {
      paramType = kInt8;
    }
    NDArray tmp_arr(shape, global_ctx_, false, paramType);
    xavier(arg_name, &tmp_arr);
    args_map_[arg_name] = tmp_arr.Copy(global_ctx_);
  }

  auto aux_name_list = net_.ListAuxiliaryStates();
  for (index_t i = 0; i < aux_shapes.size(); i++) {
    const auto &shape = aux_shapes[i];
    const auto &aux_name = aux_name_list[i];
    NDArray tmp_arr(shape, global_ctx_, false);
    xavier(aux_name, &tmp_arr);
    aux_map_[aux_name] = tmp_arr.Copy(global_ctx_);
  }
  /* WaitAll is needed when we copy data between the GPU and main memory. */
  NDArray::WaitAll();
}

/*
 * The following function runs the forward pass on the model
 * and uses dummy data for benchmarking.
 */
void Predictor::BenchmarkScore(int num_inference_batches) {
  // Create dummy data
  std::vector<float> dummy_data(input_shape_.Size());
  std::default_random_engine generator;
  std::uniform_real_distribution<float> val(0.0f, 1.0f);
  for (size_t i = 0; i < static_cast<size_t>(input_shape_.Size()); ++i) {
    dummy_data[i] = static_cast<float>(val(generator));
  }
  executor_->arg_dict()["data"].SyncCopyFromCPU(
        dummy_data.data(),
        input_shape_.Size());
  NDArray::WaitAll();

  LG << "Running the forward pass on the model to evaluate the performance..";

  // warm up.
  for (int i = 0; i < 5; i++) {
    executor_->Forward(false);
    NDArray::WaitAll();
  }

  // Run the forward pass.
  double ms = ms_now();
  for (int i = 0; i < num_inference_batches; i++) {
    executor_->Forward(false);
    NDArray::WaitAll();
  }
  ms = ms_now() - ms;
  LG << " benchmark completed!";
  LG << " batch size: " << input_shape_[0] << " num batch: " << num_inference_batches
     << " throughput: " << 1000.0 * input_shape_[0] * num_inference_batches / ms
     << " imgs/s latency: " << ms / input_shape_[0] / num_inference_batches << " ms";
}

/*
 * Skip the first batches of the data iterator.
 *
 * \param skipped_batches number of batches to skip
 */
bool Predictor::AdvanceDataIter(int skipped_batches) {
  assert(skipped_batches >= 0);
  if (skipped_batches == 0) return true;
  int skipped_count = 0;
  while (val_iter_->Next()) {
    if (++skipped_count >= skipped_batches) break;
  }
  if (skipped_count != skipped_batches) return false;
  return true;
}

/*
 * The following function runs the forward pass on the model
 * and uses real data to measure accuracy and performance.
 */
void Predictor::Score(int num_skipped_batches, int num_inference_batches) {
  // Create metrics
  Accuracy val_acc;

  val_iter_->Reset();
  val_acc.Reset();
  int nBatch = 0;

  if (!AdvanceDataIter(num_skipped_batches)) {
    LG << "The number of skipped batches should be less than the total number of batches!";
    return;
  }

  double ms = ms_now();
  while (val_iter_->Next()) {
    auto data_batch = val_iter_->GetDataBatch();
    data_batch.data.CopyTo(&args_map_["data"]);
    data_batch.label.CopyTo(&args_map_["softmax_label"]);
    NDArray::WaitAll();

    // run the forward pass
    executor_->Forward(false);
    NDArray::WaitAll();
    val_acc.Update(data_batch.label, executor_->outputs[0]);

    if (++nBatch >= num_inference_batches) {
      break;
    }
  }
  ms = ms_now() - ms;
  auto args_name = net_.ListArguments();
  LG << "INFO:" << "Dataset for inference: " << dataset_;
  LG << "INFO:" << "label_name = " << args_name[args_name.size()-1];
  LG << "INFO:" << "rgb_mean: " << "(" << rgb_mean_[0] << ", " << rgb_mean_[1]
     << ", " << rgb_mean_[2] << ")";
  LG << "INFO:" << "rgb_std: " << "(" << rgb_std_[0] << ", " << rgb_std_[1]
     << ", " << rgb_std_[2] << ")";
  LG << "INFO:" << "Image shape: " << "(" << input_shape_[1] << ", "
     << input_shape_[2] << ", " << input_shape_[3] << ")";
  LG << "INFO:" << "Finished inference with: " << nBatch * input_shape_[0]
     << " images ";
  LG << "INFO:" << "Batch size = " << input_shape_[0] << " for inference";
  LG << "INFO:" << "Accuracy: " << val_acc.Get();
  LG << "INFO:" << "Throughput: " << (1000.0 * nBatch * input_shape_[0] / ms)
     << " images per second";
}

Predictor::~Predictor() {
  if (executor_) {
    delete executor_;
  }
  if (!benchmark_ && val_iter_) {
    delete val_iter_;
  }
  MXNotifyShutdown();
}

/*
 * Convert a string of whitespace-separated numbers into a vector.
 */
template<typename T>
std::vector<T> createVectorFromString(const std::string& input_string) {
  std::vector<T> dst_vec;
  char *p_next;
  T elem;
  bool bFloat = std::is_same<T, float>::value;
  if (!bFloat) {
    elem = strtol(input_string.c_str(), &p_next, 10);
  } else {
    elem = strtof(input_string.c_str(), &p_next);
  }

  dst_vec.push_back(elem);
  while (*p_next) {
    if (!bFloat) {
      elem = strtol(p_next, &p_next, 10);
    } else {
      elem = strtof(p_next, &p_next);
    }
    dst_vec.push_back(elem);
  }
  return dst_vec;
}
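
// For example, createVectorFromString<float>("0 0 0") yields {0.0f, 0.0f, 0.0f}, and
// createVectorFromString<index_t>("3 224 224") yields {3, 224, 224}; this is how the
// --rgb_mean, --rgb_std and --input_shape command-line strings below are parsed.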

void printUsage() {
    std::cout << "Usage:" << std::endl;
    std::cout << "imagenet_inference --symbol_file <model symbol file in json format>" << std::endl
              << "--params_file <model params file> " << std::endl
              << "--dataset <dataset used to run inference> " << std::endl
              << "--data_nthreads <default: 60> " << std::endl
              << "--input_shape <shape of the input image, e.g. \"3 224 224\"> " << std::endl
              << "--rgb_mean <mean value to be subtracted on each R/G/B channel, e.g. \"0 0 0\">"
              << std::endl
              << "--rgb_std <standard deviation of each R/G/B channel, e.g. \"1 1 1\"> " << std::endl
              << "--batch_size <number of images per batch> " << std::endl
              << "--num_skipped_batches <number of batches to skip before inference> " << std::endl
              << "--num_inference_batches <number of batches used for inference> " << std::endl
              << "--data_layer_type <default: \"float32\", "
              << "choices: [\"float32\", \"int8\", \"uint8\"]>" << std::endl
              << "--gpu <whether to run inference on GPU, default: false>" << std::endl
              << "--enableTRT <whether to run inference with TensorRT, "
              << "default: false>" << std::endl
              << "--benchmark <whether to use dummy data to run inference, default: false>"
              << std::endl;
}

int main(int argc, char** argv) {
  std::string model_file_json;
  std::string model_file_params;
  std::string dataset("");
  std::string input_rgb_mean("0 0 0");
  std::string input_rgb_std("1 1 1");
  bool use_gpu = false;
  bool enable_tensorrt = false;
  bool benchmark = false;
  int batch_size = 64;
  int num_skipped_batches = 0;
  int num_inference_batches = 100;
  std::string data_layer_type("float32");
  std::string input_shape("3 224 224");
  int seed = 48564309;
  int shuffle_chunk_seed = 3982304;
  int data_nthreads = 60;

  int index = 1;
  while (index < argc) {
    if (strcmp("--symbol_file", argv[index]) == 0) {
      index++;
      model_file_json = (index < argc ? argv[index] : "");
    } else if (strcmp("--params_file", argv[index]) == 0) {
      index++;
      model_file_params = (index < argc ? argv[index] : "");
    } else if (strcmp("--dataset", argv[index]) == 0) {
      index++;
      dataset = (index < argc ? argv[index] : dataset);
    } else if (strcmp("--data_nthreads", argv[index]) == 0) {
      index++;
      // guard against a missing value so we never read past argv
      data_nthreads = (index < argc ? strtol(argv[index], nullptr, 10) : data_nthreads);
    } else if (strcmp("--input_shape", argv[index]) == 0) {
      index++;
      input_shape = (index < argc ? argv[index] : input_shape);
    } else if (strcmp("--rgb_mean", argv[index]) == 0) {
      index++;
      input_rgb_mean = (index < argc ? argv[index] : input_rgb_mean);
    } else if (strcmp("--rgb_std", argv[index]) == 0) {
      index++;
      input_rgb_std = (index < argc ? argv[index] : input_rgb_std);
    } else if (strcmp("--batch_size", argv[index]) == 0) {
      index++;
      batch_size = (index < argc ? strtol(argv[index], nullptr, 10) : batch_size);
    } else if (strcmp("--num_skipped_batches", argv[index]) == 0) {
      index++;
      num_skipped_batches = (index < argc ? strtol(argv[index], nullptr, 10) : num_skipped_batches);
    } else if (strcmp("--num_inference_batches", argv[index]) == 0) {
      index++;
      num_inference_batches = (index < argc ? strtol(argv[index], nullptr, 10) : num_inference_batches);
    } else if (strcmp("--data_layer_type", argv[index]) == 0) {
      index++;
      data_layer_type = (index < argc ? argv[index] : data_layer_type);
    } else if (strcmp("--gpu", argv[index]) == 0) {
      use_gpu = true;
    } else if (strcmp("--enableTRT", argv[index]) == 0) {
      use_gpu = true;
      enable_tensorrt = true;
    } else if (strcmp("--benchmark", argv[index]) == 0) {
      benchmark = true;
    } else if (strcmp("--help", argv[index]) == 0) {
      printUsage();
      return 0;
    }
    index++;
  }

  if (model_file_json.empty()
      || (!benchmark && model_file_params.empty())
      || (enable_tensorrt && model_file_params.empty())) {
    LG << "ERROR: Model details such as the symbol and param files are not specified";
    printUsage();
    return 1;
  }
  std::vector<index_t> input_dimensions = createVectorFromString<index_t>(input_shape);
  input_dimensions.insert(input_dimensions.begin(), batch_size);
  Shape input_data_shape(input_dimensions);

  std::vector<float> rgb_mean = createVectorFromString<float>(input_rgb_mean);
  std::vector<float> rgb_std = createVectorFromString<float>(input_rgb_std);

  // Initialize the predictor object
  Predictor predict(model_file_json, model_file_params, input_data_shape, use_gpu, enable_tensorrt,
                    dataset, data_nthreads, data_layer_type, rgb_mean, rgb_std, shuffle_chunk_seed,
                    seed, benchmark);

  if (benchmark) {
    predict.BenchmarkScore(num_inference_batches);
  } else {
    predict.Score(num_skipped_batches, num_inference_batches);
  }
  return 0;
}