1 /**
2  * @file core/hpt/deduce_hp_types.hpp
3  * @author Kirill Mishchenko
4  *
5  * Tools to deduce types of hyper-parameters from types of arguments in the
6  * Optimize method in HyperParameterTuner.
7  *
8  * mlpack is free software; you may redistribute it and/or modify it under the
9  * terms of the 3-clause BSD license.  You should have received a copy of the
10  * 3-clause BSD license along with mlpack.  If not, see
11  * http://www.opensource.org/licenses/BSD-3-Clause for more information.
12  */
13 #ifndef MLPACK_CORE_HPT_DEDUCE_HP_TYPES_HPP
14 #define MLPACK_CORE_HPT_DEDUCE_HP_TYPES_HPP
15 
16 #include <mlpack/core/hpt/fixed.hpp>
17 
18 namespace mlpack {
19 namespace hpt {
20 
/**
 * A type function for deducing types of hyper-parameters from types of
 * arguments in the Optimize method in HyperParameterTuner.
 *
 * All argument types are first placed into Args and then processed one by one
 * by the partial specializations of this template. The deduced hyper-parameter
 * types are accumulated in the template parameters of the nested struct
 * ResultHolder. By the end Args is empty, while ResultHolder holds the tuple
 * type of hyper-parameters.
 *
 * This primary template handles that end phase, when Args is empty (all
 * argument types have been processed).
 */
template<typename... Args>
struct DeduceHyperParameterTypes
{
  /**
   * Terminal step of the recursion: the hyper-parameter types accumulated by
   * the previous steps (HPTypes) are packed into a std::tuple.
   */
  template<typename... HPTypes>
  struct ResultHolder
  {
    using TupleType = std::tuple<HPTypes...>;
  };

  /**
   * Tuple of hyper-parameter types when no argument types were supplied at
   * all. Every non-empty specialization exposes TupleType, so defining it here
   * as well keeps DeduceHyperParameterTypes<>::TupleType (and hence
   * TupleOfHyperParameters<>) well-formed; it is std::tuple<>.
   */
  using TupleType = typename ResultHolder<>::TupleType;
};
42 
43 /**
44  * Defining DeduceHyperParameterTypes for the case when not all argument types
45  * have been processed, and the next one (T) is a collection type or an
46  * arithmetic type.
47  */
48 template<typename T, typename... Args>
49 struct DeduceHyperParameterTypes<T, Args...>
50 {
51   /**
52    * A type function to deduce the result hyper-parameter type for ArgumentType.
53    */
54   template<typename ArgumentType,
55            bool IsArithmetic = std::is_arithmetic<ArgumentType>::value>
56   struct ResultHPType;
57 
58   template<typename ArithmeticType>
59   struct ResultHPType<ArithmeticType, true>
60   {
61     using Type = ArithmeticType;
62   };
63 
64   /**
65    * A type function to check whether Type is a collection type (for that it
66    * should define value_type).
67    */
68   template<typename Type>
69   struct IsCollectionType
70   {
71     using Yes = char[1];
72     using No = char[2];
73 
74     template<typename TypeToCheck>
75     static Yes& Check(typename TypeToCheck::value_type*);
76     template<typename>
77     static No& Check(...);
78 
79     static const bool value  =
80       sizeof(decltype(Check<Type>(0))) == sizeof(Yes);
81   };
82 
83   template<typename CollectionType>
84   struct ResultHPType<CollectionType, false>
85   {
86     static_assert(IsCollectionType<CollectionType>::value,
87         "One of the passed arguments is neither of an arithmetic type, nor of "
88         "a collection type, nor fixed with the Fixed function.");
89 
90     using Type = typename CollectionType::value_type;
91   };
92 
93   template<typename... HPTypes>
94   struct ResultHolder
95   {
96     using TupleType = typename DeduceHyperParameterTypes<Args...>::template
97         ResultHolder<HPTypes..., typename ResultHPType<T>::Type>::TupleType;
98   };
99 
100   using TupleType = typename ResultHolder<>::TupleType;
101 };
102 
103 /**
104  * Defining DeduceHyperParameterTypes for the case when not all argument types
105  * have been processed, and the next one is the type of an argument that should
106  * be fixed.
107  */
108 template<typename T, typename... Args>
109 struct DeduceHyperParameterTypes<PreFixedArg<T>, Args...>
110 {
111   template<typename... HPTypes>
112   struct ResultHolder
113   {
114     using TupleType = typename DeduceHyperParameterTypes<Args...>::template
115         ResultHolder<HPTypes...>::TupleType;
116   };
117 
118   using TupleType = typename ResultHolder<>::TupleType;
119 };
120 
121 /**
122  * A short alias for deducing types of hyper-parameters from types of arguments
123  * in the Optimize method in HyperParameterTuner.
124  */
125 template<typename... Args>
126 using TupleOfHyperParameters =
127     typename DeduceHyperParameterTypes<Args...>::TupleType;
128 
129 } // namespace hpt
130 } // namespace mlpack
131 
132 #endif
133