1 /*
2     Copyright (c) 2005-2021 Intel Corporation
3 
4     Licensed under the Apache License, Version 2.0 (the "License");
5     you may not use this file except in compliance with the License.
6     You may obtain a copy of the License at
7 
8         http://www.apache.org/licenses/LICENSE-2.0
9 
10     Unless required by applicable law or agreed to in writing, software
11     distributed under the License is distributed on an "AS IS" BASIS,
12     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13     See the License for the specific language governing permissions and
14     limitations under the License.
15 */
16 
17 #ifndef TBB_examples_utility_H
18 #define TBB_examples_utility_H
19 
20 #include <cassert>
21 #include <cstring>
22 #include <cstdlib>
23 
24 #include <utility>
25 #include <string>
26 #include <vector>
27 #include <map>
28 #include <set>
29 #include <algorithm>
30 #include <sstream>
31 #include <numeric>
32 #include <stdexcept>
33 #include <memory>
34 #include <iostream>
35 // TBB headers should not be used, as some examples may need to be built without TBB.
36 
37 namespace utility {
38 namespace internal {
39 
// TODO: add test cases
// Converts the text `s` into `result` by stream extraction (operator>>).
// Returns a reference to `result`.
// Throws std::invalid_argument when the extraction sets the stream's fail state.
template <class dest_type>
dest_type& string_to(std::string const& s, dest_type& result) {
    std::stringstream parser(s);
    parser >> result;
    if (parser.fail()) {
        throw std::invalid_argument("error converting string '" + s + "'");
    }
    return result;
}
50 
// Convenience overload: converts `s` into a freshly created dest_type and returns it
// by value. dest_type has to be default-constructible; conversion errors propagate
// from the two-argument overload.
template <class dest_type>
dest_type string_to(std::string const& s) {
    dest_type parsed;
    string_to(s, parsed);
    return parsed;
}
56 
// Type predicate: is_bool<T>::value() yields true only for T == bool.
// Primary template covers every other type.
template <typename /*any_type*/>
struct is_bool {
    static constexpr bool value() noexcept {
        return false;
    }
};
// Specialization selected when the type is exactly bool.
template <>
struct is_bool<bool> {
    static constexpr bool value() noexcept {
        return true;
    }
};
69 
// Abstract interface of a single command-line argument: it carries the argument's
// name and description and lets derived classes define how text is parsed and stored,
// how the current value is printed, and how the object is duplicated.
class type_base {
public:
    const std::string name;
    const std::string description;

    // Takes both strings by value and moves them into the members, so passing
    // temporaries does not cost an additional copy.
    type_base(std::string a_name, std::string a_description)
            : name(std::move(a_name)),
              description(std::move(a_description)) {}

    // Assignment is not supported (the members are const).
    type_base& operator=(const type_base&) = delete;

    // Parses `s` and stores the resulting value.
    virtual void parse_and_store(const std::string& s) = 0;
    // Returns a textual representation of the current value.
    virtual std::string value() const = 0;
    // Returns a newly allocated copy of this object.
    virtual std::unique_ptr<type_base> clone() const = 0;
    virtual ~type_base() {}
};
85 template <typename type>
86 class type_impl : public type_base {
87 private:
type_impl(const type_impl & src)88     type_impl(const type_impl& src)
89             : type_base(src.name, src.description),
90               target(src.target),
91               validating_function(src.validating_function) {}
92     type_impl& operator=(const type_impl&);
93     typedef bool (*validating_function_type)(const type&);
94     type& target;
95     validating_function_type validating_function;
96 
97 public:
type_impl(std::string a_name,std::string a_description,type & a_target,validating_function_type a_validating_function=nullptr)98     type_impl(std::string a_name,
99               std::string a_description,
100               type& a_target,
101               validating_function_type a_validating_function = nullptr)
102             : type_base(a_name, a_description),
103               target(a_target),
104               validating_function(a_validating_function){};
parse_and_store(const std::string & s)105     void parse_and_store(const std::string& s) /*override*/ {
106         try {
107             const bool is_bool = internal::is_bool<type>::value();
108             if (is_bool && s.empty()) {
109                 //to avoid directly assigning true
110                 //(as it will impose additional layer of indirection)
111                 //so, simply pass it as string
112                 internal::string_to("1", target);
113             }
114             else {
115                 internal::string_to(s, target);
116             }
117         }
118         catch (std::invalid_argument& e) {
119             std::stringstream str;
120             str << "'" << s << "' is incorrect input for argument '" << name << "'"
121                 << " (" << e.what() << ")";
122             throw std::invalid_argument(str.str());
123         }
124         if (validating_function) {
125             if (!((validating_function)(target))) {
126                 std::stringstream str;
127                 str << "'" << target << "' is invalid value for argument '" << name << "'";
128                 throw std::invalid_argument(str.str());
129             }
130         }
131     }
132     template <typename t>
is_null_c_str(t &)133     static bool is_null_c_str(t&) {
134         return false;
135     }
is_null_c_str(char * s)136     static bool is_null_c_str(char* s) {
137         return s == nullptr;
138     }
value() const139     std::string value() const /*override*/ {
140         std::stringstream str;
141         if (!is_null_c_str(target))
142             str << target;
143         return str.str();
144     }
clone() const145     std::unique_ptr<type_base> clone() const /*override*/ {
146         return std::unique_ptr<type_base>(new type_impl(*this));
147     }
148 };
149 
150 class argument {
151 private:
152     std::unique_ptr<type_base> p_type;
153     bool matched_;
154 
155 public:
argument(argument const & other)156     argument(argument const& other)
157             : p_type(other.p_type.get() ? (other.p_type->clone()).release() : nullptr),
158               matched_(other.matched_) {}
operator =(argument a)159     argument& operator=(argument a) {
160         this->swap(a);
161         return *this;
162     }
swap(argument & other)163     void swap(argument& other) {
164         std::swap(p_type, other.p_type);
165         std::swap(matched_, other.matched_);
166     }
167     template <class type>
argument(std::string a_name,std::string a_description,type & dest,bool (* a_validating_function)(const type &)=nullptr)168     argument(std::string a_name,
169              std::string a_description,
170              type& dest,
171              bool (*a_validating_function)(const type&) = nullptr)
172             : p_type(new type_impl<type>(a_name, a_description, dest, a_validating_function)),
173               matched_(false) {}
value() const174     std::string value() const {
175         return p_type->value();
176     }
name() const177     std::string name() const {
178         return p_type->name;
179     }
description() const180     std::string description() const {
181         return p_type->description;
182     }
parse_and_store(const std::string & s)183     void parse_and_store(const std::string& s) {
184         p_type->parse_and_store(s);
185         matched_ = true;
186     }
is_matched() const187     bool is_matched() const {
188         return matched_;
189     }
190 };
191 } // namespace internal
192 
193 class cli_argument_pack {
194     typedef std::map<std::string, internal::argument> args_map_type;
195     typedef std::vector<std::string> args_display_order_type;
196     typedef std::vector<std::string> positional_arg_names_type;
197 
198 private:
199     args_map_type args_map;
200     args_display_order_type args_display_order;
201     positional_arg_names_type positional_arg_names;
202     std::set<std::string> bool_args_names;
203 
204 private:
add_arg(internal::argument const & a)205     void add_arg(internal::argument const& a) {
206         std::pair<args_map_type::iterator, bool> result =
207             args_map.insert(std::make_pair(a.name(), a));
208         if (!result.second) {
209             throw std::invalid_argument("argument with name: '" + a.name() +
210                                         "' already registered");
211         }
212         args_display_order.push_back(a.name());
213     }
214 
215 public:
216     template <typename type>
arg(type & dest,std::string const & name,std::string const & description,bool (* validate)(const type &)=nullptr)217     cli_argument_pack& arg(type& dest,
218                            std::string const& name,
219                            std::string const& description,
220                            bool (*validate)(const type&) = nullptr) {
221         internal::argument a(name, description, dest, validate);
222         add_arg(a);
223         if (internal::is_bool<type>::value()) {
224             bool_args_names.insert(name);
225         }
226         return *this;
227     }
228 
229     //Positional means that argument name can be omitted in actual CL
230     //only key to match values for parameters with
231     template <typename type>
positional_arg(type & dest,std::string const & name,std::string const & description,bool (* validate)(const type &)=nullptr)232     cli_argument_pack& positional_arg(type& dest,
233                                       std::string const& name,
234                                       std::string const& description,
235                                       bool (*validate)(const type&) = nullptr) {
236         internal::argument a(name, description, dest, validate);
237         add_arg(a);
238         if (internal::is_bool<type>::value()) {
239             bool_args_names.insert(name);
240         }
241         positional_arg_names.push_back(name);
242         return *this;
243     }
244 
parse(std::size_t argc,char const * argv[])245     void parse(std::size_t argc, char const* argv[]) {
246         {
247             std::size_t current_positional_index = 0;
248             for (std::size_t j = 1; j < argc; j++) {
249                 internal::argument* pa = nullptr;
250                 std::string argument_value;
251 
252                 const char* const begin = argv[j];
253                 const char* const end = begin + std::strlen(argv[j]);
254 
255                 const char* const assign_sign = std::find(begin, end, '=');
256 
257                 struct throw_unknown_parameter {
258                     static void _(std::string const& location) {
259                         throw std::invalid_argument(std::string("unknown parameter starting at:'") +
260                                                     location + "'");
261                     }
262                 };
263                 //first try to interpret it like parameter=value string
264                 if (assign_sign != end) {
265                     std::string name_found = std::string(begin, assign_sign);
266                     args_map_type::iterator it = args_map.find(name_found);
267 
268                     if (it != args_map.end()) {
269                         pa = &((*it).second);
270                         argument_value = std::string(assign_sign + 1, end);
271                     }
272                     else {
273                         throw_unknown_parameter::_(argv[j]);
274                     }
275                 }
276                 //then see is it a named flag
277                 else {
278                     args_map_type::iterator it = args_map.find(argv[j]);
279                     if (it != args_map.end()) {
280                         pa = &((*it).second);
281                         argument_value = "";
282                     }
283                     //then try it as positional argument without name specified
284                     else if (current_positional_index < positional_arg_names.size()) {
285                         std::stringstream str(argv[j]);
286                         args_map_type::iterator found_positional_arg =
287                             args_map.find(positional_arg_names.at(current_positional_index));
288                         //TODO: probably use of smarter assert would help here
289                         assert(
290                             found_positional_arg !=
291                             args_map
292                                 .end() /*&&"positional_arg_names and args_map are out of sync"*/);
293                         if (found_positional_arg == args_map.end()) {
294                             throw std::logic_error(
295                                 "positional_arg_names and args_map are out of sync");
296                         }
297                         pa = &((*found_positional_arg).second);
298                         argument_value = argv[j];
299 
300                         current_positional_index++;
301                     }
302                     else {
303                         //TODO: add tc to check
304                         throw_unknown_parameter::_(argv[j]);
305                     }
306                 }
307                 assert(pa);
308                 if (pa->is_matched()) {
309                     throw std::invalid_argument(std::string("several values specified for: '") +
310                                                 pa->name() + "' argument");
311                 }
312                 pa->parse_and_store(argument_value);
313             }
314         }
315     }
usage_string(const std::string & binary_name) const316     std::string usage_string(const std::string& binary_name) const {
317         std::string command_line_params;
318         std::string summary_description;
319 
320         for (args_display_order_type::const_iterator it = args_display_order.begin();
321              it != args_display_order.end();
322              ++it) {
323             const bool is_bool = (0 != bool_args_names.count((*it)));
324             args_map_type::const_iterator argument_it = args_map.find(*it);
325             //TODO: probably use of smarter assert would help here
326             assert(argument_it !=
327                    args_map.end() /*&&"args_display_order and args_map are out of sync"*/);
328             if (argument_it == args_map.end()) {
329                 throw std::logic_error("args_display_order and args_map are out of sync");
330             }
331             const internal::argument& a = (*argument_it).second;
332             command_line_params += " [" + a.name() + (is_bool ? "" : "=value") + "]";
333             summary_description +=
334                 " " + a.name() + " - " + a.description() + " (" + a.value() + ")" + "\n";
335         }
336 
337         std::string positional_arg_cl;
338         for (positional_arg_names_type::const_iterator it = positional_arg_names.begin();
339              it != positional_arg_names.end();
340              ++it) {
341             positional_arg_cl += " [" + (*it);
342         }
343         for (std::size_t i = 0; i < positional_arg_names.size(); ++i) {
344             positional_arg_cl += "]";
345         }
346         command_line_params += positional_arg_cl;
347         std::stringstream str;
348         str << " Program usage is:"
349             << "\n"
350             << " " << binary_name << command_line_params << "\n"
351             << "\n"
352             << " where:"
353             << "\n"
354             << summary_description;
355         return str.str();
356     }
357 }; // class cli_argument_pack
358 
359 namespace internal {
// Returns true when `val`, converted to std::size_t, is a power of two (1, 2, 4, ...).
// Values below 1 (zero, negatives, fractions in (0,1), NaN) are rejected up front:
// zero is not a power of two, and accepting it lets a zero step count reach code that
// divides by it. The fractional part of larger floating-point values is truncated by
// the conversion.
template <typename T>
bool is_power_of_2(T val) {
    if (!(val >= T(1))) {
        return false;
    }
    const std::size_t intval = std::size_t(val);
    return (intval & (intval - 1)) == std::size_t(0);
}
// Additive step: returns previous + step, truncated to int.
// Declared inline because it is defined in a header; a non-inline definition would be
// duplicated in every translation unit that includes the header.
inline int step_function_plus(int previous, double step) {
    return static_cast<int>(previous + step);
}
// Multiplicative step: returns previous * multiply, truncated to int.
// Declared inline because it is defined in a header; a non-inline definition would be
// duplicated in every translation unit that includes the header.
inline int step_function_multiply(int previous, double multiply) {
    return static_cast<int>(previous * multiply);
}
// "Power-of-2 ladder": nsteps is the desired number of steps between any subsequent powers of 2.
// The actual step is the quotient of the nearest smaller power of 2 divided by that number (but at least 1).
// E.g., '1:32:#4' means 1,2,3,4,5,6,7,8,10,12,14,16,20,24,28,32
//
// nsteps is truncated to int and is expected to be a power of 2 (checked by assert).
// Throws std::invalid_argument when the truncated value is below 1, which would
// otherwise end in a division by zero.
// Declared inline because it is defined in a header.
inline int step_function_power2_ladder(int previous, double nsteps) {
    const int steps = int(nsteps);
    if (steps < 1) {
        throw std::invalid_argument("the argument of # should be a power of 2");
    }
    assert((steps & (steps - 1)) == 0); // must be a power of 2
    // The actual step is 1 until the value is twice as big as nsteps
    if (previous < 2 * steps)
        return previous + 1;
    // Find the largest power of 2 that does not exceed 'previous': keep doubling
    // while the candidate is still at most half of it. The candidate never exceeds
    // 'previous', so the doubling cannot overflow.
    int prev_power2 = 1;
    while (prev_power2 <= previous / 2)
        prev_power2 *= 2;
    assert((prev_power2 <= previous) && (previous / 2 < prev_power2));
    // The actual step value is the previous power of 2 divided by steps
    return previous + (prev_power2 / steps);
}
// Signature shared by all step functions: (previous value, step argument) -> next value.
using step_function_ptr_type = int (*)(int, double);
395 
396 struct step_function_descriptor {
397     char mnemonic;
398     step_function_ptr_type function;
399 
400 public:
step_function_descriptorutility::internal::step_function_descriptor401     step_function_descriptor(char a_mnemonic, step_function_ptr_type a_function)
402             : mnemonic(a_mnemonic),
403               function(a_function) {}
404 
405 private:
406     void operator=(step_function_descriptor const&);
407 };
408 step_function_descriptor step_function_descriptors[] = {
409     step_function_descriptor('*', step_function_multiply),
410     step_function_descriptor('+', step_function_plus),
411     step_function_descriptor('#', step_function_power2_ladder)
412 };
413 
// Returns the number of elements of a built-in array; the size is deduced from the
// array type, the array itself is not read.
template <typename Element, std::size_t Extent>
constexpr std::size_t array_length(const Element (&)[Extent]) noexcept {
    return Extent;
}
418 
419 struct thread_range_step {
420     step_function_ptr_type step_function;
421     double step_function_argument;
422 
thread_range_steputility::internal::thread_range_step423     thread_range_step(step_function_ptr_type step_function_, double step_function_argument_)
424             : step_function(step_function_),
425               step_function_argument(step_function_argument_) {
426         if (!step_function_)
427             throw std::invalid_argument(
428                 "step_function for thread range step should not be nullptr");
429     }
operator ()utility::internal::thread_range_step430     int operator()(int previous) const {
431         assert(0 <= previous); // test 0<=first and loop discipline
432         const int ret = step_function(previous, step_function_argument);
433         assert(previous < ret);
434         return ret;
435     }
operator >>(std::istream & input_stream,thread_range_step & step)436     friend std::istream& operator>>(std::istream& input_stream, thread_range_step& step) {
437         char function_char;
438         double function_argument;
439         input_stream >> function_char >> function_argument;
440         std::size_t i = 0;
441         while ((i < array_length(step_function_descriptors)) &&
442                (step_function_descriptors[i].mnemonic != function_char))
443             ++i;
444         if (i >= array_length(step_function_descriptors)) {
445             throw std::invalid_argument("unknown step function mnemonic: " +
446                                         std::string(1, function_char));
447         }
448         else if ((function_char == '#') && !is_power_of_2(function_argument)) {
449             throw std::invalid_argument("the argument of # should be a power of 2");
450         }
451         step.step_function = step_function_descriptors[i].function;
452         step.step_function_argument = function_argument;
453         return input_stream;
454     }
455 };
456 } // namespace internal
457 
458 struct thread_number_range {
459     int (*auto_number_of_threads)();
460     int first; // 0<=first (0 can be used as a special value)
461     int last; // first<=last
462 
463     ::utility::internal::thread_range_step step;
464 
thread_number_rangeutility::thread_number_range465     thread_number_range(
466         int (*auto_number_of_threads_)(),
467         int low_ = 1,
468         int high_ = -1,
469         ::utility::internal::thread_range_step step_ =
470             ::utility::internal::thread_range_step(::utility::internal::step_function_power2_ladder,
471                                                    4))
472             : auto_number_of_threads(auto_number_of_threads_),
473               first(low_),
474               last((high_ > -1) ? high_ : auto_number_of_threads_()),
475               step(step_) {
476         if (first < 0) {
477             throw std::invalid_argument("negative value not allowed");
478         }
479         if (first > last) {
480             throw std::invalid_argument("decreasing sequence not allowed");
481         }
482     }
operator >>(std::istream & i,thread_number_range & range)483     friend std::istream& operator>>(std::istream& i, thread_number_range& range) {
484         try {
485             std::string s;
486             i >> s;
487             struct string_to_number_of_threads {
488                 int auto_value;
489                 string_to_number_of_threads(int auto_value_) : auto_value(auto_value_) {}
490                 int operator()(const std::string& value) const {
491                     return (value == "auto") ? auto_value : internal::string_to<int>(value);
492                 }
493             };
494             string_to_number_of_threads string_to_number_of_threads(range.auto_number_of_threads());
495             int low, high;
496             std::size_t colon = s.find(':');
497             if (colon == std::string::npos) {
498                 low = high = string_to_number_of_threads(s);
499             }
500             else {
501                 //it is a range
502                 std::size_t second_colon = s.find(':', colon + 1);
503 
504                 low = string_to_number_of_threads(std::string(s, 0, colon)); //not copying the colon
505                 high = string_to_number_of_threads(
506                     std::string(s, colon + 1, second_colon - (colon + 1))); //not copying the colons
507                 if (second_colon != std::string::npos) {
508                     internal::string_to(std::string(s, second_colon + 1), range.step);
509                 }
510             }
511             range = thread_number_range(range.auto_number_of_threads, low, high, range.step);
512         }
513         catch (std::invalid_argument&) {
514             i.setstate(std::ios::failbit);
515             throw;
516         }
517         return i;
518     }
operator <<(std::ostream & o,thread_number_range const & range)519     friend std::ostream& operator<<(std::ostream& o, thread_number_range const& range) {
520         using namespace internal;
521         std::size_t i = 0;
522         for (; i < array_length(step_function_descriptors) &&
523                step_function_descriptors[i].function != range.step.step_function;
524              ++i) {
525         }
526         if (i >= array_length(step_function_descriptors)) {
527             throw std::invalid_argument("unknown step function for thread range");
528         }
529         o << range.first << ":" << range.last << ":" << step_function_descriptors[i].mnemonic
530           << range.step.step_function_argument;
531         return o;
532     }
533 }; // struct thread_number_range
// Help text describing the textual form of a thread_number_range.
// The pointer itself is const: besides preventing accidental reassignment, a const
// namespace-scope variable is not reported by GCC's -Wunused-variable in C++ when a
// translation unit includes this header without using the text.
static const char* const thread_number_range_desc =
    "number of threads to use; a range of the form low[:high[:(+|*|#)step]],"
    "\n\twhere low and optional high are non-negative integers or 'auto' for the default choice,"
    "\n\tand optional step expression specifies how thread numbers are chosen within the range.";
539 
// Writes "elapsed time : <seconds> seconds" followed by a newline to std::cout.
inline void report_elapsed_time(double seconds) {
    std::ostream& out = std::cout;
    out << "elapsed time : ";
    out << seconds;
    out << " seconds" << '\n';
}
544 
// Writes the line "skip" to std::cout.
inline void report_skipped() {
    std::ostream& out = std::cout;
    out << "skip" << '\n';
}
549 
parse_cli_arguments(int argc,const char * argv[],utility::cli_argument_pack cli_pack)550 inline void parse_cli_arguments(int argc, const char* argv[], utility::cli_argument_pack cli_pack) {
551     bool show_help = false;
552     cli_pack.arg(show_help, "-h", "show this message");
553 
554     bool invalid_input = false;
555     try {
556         cli_pack.parse(argc, argv);
557     }
558     catch (std::exception& e) {
559         std::cerr << "error occurred while parsing command line."
560                   << "\n"
561                   << "error text: " << e.what() << "\n"
562                   << std::flush;
563         invalid_input = true;
564     }
565     if (show_help || invalid_input) {
566         std::cout << cli_pack.usage_string(argv[0]) << std::flush;
567         std::exit(0);
568     }
569 }
parse_cli_arguments(int argc,char * argv[],utility::cli_argument_pack cli_pack)570 inline void parse_cli_arguments(int argc, char* argv[], utility::cli_argument_pack cli_pack) {
571     parse_cli_arguments(argc, const_cast<const char**>(argv), cli_pack);
572 }
573 } // namespace utility
574 
575 #endif /* TBB_examples_utility_H */
576