#include <cassert>
#include <chrono>
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <stdexcept>
#include <string>
#include <thread>

#include <CLI11.hpp>

#include "dnn.hpp"
5 
6 // Function: measure_time_taskflow
measure_time_taskflow(unsigned num_epochs,unsigned num_threads)7 std::chrono::milliseconds measure_time_taskflow(
8   unsigned num_epochs,
9   unsigned num_threads
10 ) {
11   auto dnn {build_dnn(num_epochs)};
12   auto t1 = std::chrono::high_resolution_clock::now();
13   run_taskflow(dnn, num_threads);
14   auto t2 = std::chrono::high_resolution_clock::now();
15   return std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1);
16 }
17 
18 // Function: measure_time_omp
measure_time_omp(unsigned num_epochs,unsigned num_threads)19 std::chrono::milliseconds measure_time_omp(
20   unsigned num_epochs,
21   unsigned num_threads
22 ) {
23   auto dnn {build_dnn(num_epochs)};
24   auto t1 = std::chrono::high_resolution_clock::now();
25   run_omp(dnn, num_threads);
26   auto t2 = std::chrono::high_resolution_clock::now();
27   return std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1);
28 }
29 
30 // Function: measure_time_tbb
measure_time_tbb(unsigned num_epochs,unsigned num_threads)31 std::chrono::milliseconds measure_time_tbb(
32   unsigned num_epochs,
33   unsigned num_threads
34 ) {
35   auto dnn {build_dnn(num_epochs)};
36   auto t1 = std::chrono::high_resolution_clock::now();
37   run_tbb(dnn, num_threads);
38   auto t2 = std::chrono::high_resolution_clock::now();
39   return std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1);
40 }
41 
// Procedure: mnist
mnist(const std::string & model,const unsigned min_epochs,const unsigned max_epochs,const unsigned num_threads,const unsigned num_rounds)43 void mnist(
44   const std::string& model,
45   const unsigned min_epochs,
46   const unsigned max_epochs,
47   const unsigned num_threads,
48   const unsigned num_rounds
49 ) {
50 
51   std::cout << std::setw(12) << "epochs"
52             << std::setw(12) << "runtime"
53             << std::endl;
54 
55   for(unsigned epochs=min_epochs; epochs <= max_epochs; epochs += 10) {
56 
57     double runtime  {0.0};
58 
59     for(unsigned i=0; i<num_rounds; i++) {
60 
61       if(model == "tf") {
62         runtime += measure_time_taskflow(epochs, num_threads).count();
63       }
64       else if(model == "tbb") {
65         runtime += measure_time_tbb(epochs, num_threads).count();
66       }
67       else if(model == "omp") {
68         runtime += measure_time_omp(epochs, num_threads).count();
69       }
70       else assert(false);
71 
72       std::cout << std::setw(12) << epochs
73                 << std::setw(12) << runtime / num_rounds / 1e3
74                 << std::endl;
75     }
76   }
77 
78 }
79 
80 // Function: main
main(int argc,char * argv[])81 int main(int argc, char *argv[]){
82 
83   CLI::App app{"DNN Training on MNIST Dataset"};
84 
85   unsigned num_threads {1};
86   app.add_option("-t,--num_threads", num_threads, "number of threads (default=1)");
87 
88   unsigned max_epochs {100};
89   app.add_option("-E,--max_epochs", max_epochs, "max number of epochs (default=100)");
90 
91   unsigned min_epochs {10};
92   app.add_option("-e,--min_epochs", min_epochs, "min number of epochs (default=10)");
93 
94   unsigned num_rounds {1};
95   app.add_option("-r,--num_rounds", num_rounds, "number of rounds (default=1)");
96 
97   std::string model = "tf";
98   app.add_option("-m,--model", model, "model name tbb|omp|tf (default=tf)")
99      ->check([] (const std::string& m) {
100         if(m != "tbb" && m != "omp" && m != "tf") {
101           return "model name should be \"tbb\", \"omp\", or \"tf\"";
102         }
103         return "";
104      });
105 
106   CLI11_PARSE(app, argc, argv);
107 
108   std::cout << "model=" << model << ' '
109             << "num_threads=" << num_threads << ' '
110             << "num_rounds=" << num_rounds << ' '
111             << "min_epochs=" << min_epochs << ' '
112             << "max_epochs=" << max_epochs << ' '
113             << std::endl;
114 
115   mnist(model, min_epochs, max_epochs, num_threads, num_rounds);
116 
117   return EXIT_SUCCESS;
118 }
119 
120 
121 
122 
123