1 /**
2  * @file arma_traits.hpp
3  * @author Ryan Curtin
4  * @author Marcus Edel
5  *
6  * Some traits used for template metaprogramming (SFINAE) with Armadillo types.
7  *
8  * ensmallen is free software; you may redistribute it and/or modify it under
9  * the terms of the 3-clause BSD license.  You should have received a copy of
10  * the 3-clause BSD license along with ensmallen.  If not, see
11  * http://www.opensource.org/licenses/BSD-3-Clause for more information.
12  */
13 #ifndef ENSMALLEN_UTILITY_ARMA_TRAITS_HPP
14 #define ENSMALLEN_UTILITY_ARMA_TRAITS_HPP
15 
16 namespace ens {
17 
18 // Structs have public members by default (that's why they are chosen over
19 // classes).
20 
21 /**
22  * If value == true, then MatType is some sort of Armadillo vector or subview.
23  * You might use this struct like this:
24  *
25  * @code
26  * // Only accepts VecTypes that are actually Armadillo vector types.
27  * template<typename MatType>
28  * void Function(const MatType& argumentA,
29  *               typename std::enable_if_t<IsArmaType<MatType>::value>* = 0);
30  * @endcode
31  *
32  * The use of the enable_if_t object allows the compiler to instantiate
33  * Function() only if VecType is one of the Armadillo vector types.  It has a
34  * default argument because it isn't meant to be used in either the function
35  * call or the function body.
36  */
37 template<typename MatType>
38 struct IsArmaType
39 {
40   const static bool value = false;
41 };
42 
43 // Commenting out the first template per case, because
44 // Visual Studio doesn't like this instantiaion pattern (error C2910).
45 // template<>
46 template<typename eT>
47 struct IsArmaType<arma::Col<eT> >
48 {
49   const static bool value = true;
50 };
51 
52 // template<>
53 template<typename eT>
54 struct IsArmaType<arma::SpCol<eT> >
55 {
56   const static bool value = true;
57 };
58 
59 // template<>
60 template<typename eT>
61 struct IsArmaType<arma::Row<eT> >
62 {
63   const static bool value = true;
64 };
65 
66 // template<>
67 template<typename eT>
68 struct IsArmaType<arma::SpRow<eT> >
69 {
70   const static bool value = true;
71 };
72 
73 // template<>
74 template<typename eT>
75 struct IsArmaType<arma::subview<eT> >
76 {
77   const static bool value = true;
78 };
79 
80 // template<>
81 template<typename eT>
82 struct IsArmaType<arma::subview_col<eT> >
83 {
84   const static bool value = true;
85 };
86 
87 // template<>
88 template<typename eT>
89 struct IsArmaType<arma::subview_row<eT> >
90 {
91   const static bool value = true;
92 };
93 
94 // template<>
95 template<typename eT>
96 struct IsArmaType<arma::SpSubview<eT> >
97 {
98   const static bool value = true;
99 };
100 
101 
102 #if ((ARMA_VERSION_MAJOR >= 10) || \
103     ((ARMA_VERSION_MAJOR == 9) && (ARMA_VERSION_MINOR >= 869)))
104 
105   // Armadillo 9.869+ has SpSubview_col and SpSubview_row
106 
107   template<typename eT>
108   struct IsArmaType<arma::SpSubview_col<eT> >
109   {
110     const static bool value = true;
111   };
112 
113   template<typename eT>
114   struct IsArmaType<arma::SpSubview_row<eT> >
115   {
116     const static bool value = true;
117   };
118 
119 #endif
120 
121 
122 // template<>
123 template<typename eT>
124 struct IsArmaType<arma::Mat<eT> >
125 {
126   const static bool value = true;
127 };
128 
129 // template<>
130 template<typename eT>
131 struct IsArmaType<arma::SpMat<eT> >
132 {
133   const static bool value = true;
134 };
135 
136 // template<>
137 template<typename eT>
138 struct IsArmaType<arma::Cube<eT> >
139 {
140   const static bool value = true;
141 };
142 
143 // template<>
144 template<typename eT>
145 struct IsArmaType<arma::subview_cube<eT> >
146 {
147   const static bool value = true;
148 };
149 
150 
/**
 * Compile-time lookup of the type at zero-based position N within the type
 * pack T...; for example, tuple_element<1, int, char>::type is char.  Unlike
 * std::tuple_element, this indexes a bare parameter pack rather than a
 * std::tuple type.
 *
 * The primary template is only declared, never defined.  The two partial
 * specializations below strip one type off the front of the pack per step.
 * An index that is negative or not less than the pack size eventually reaches
 * the undefined primary template and therefore fails to compile.
 */
template<int N, typename... T>
struct tuple_element;

// Base case: index 0 names the first type of the pack.  This specialization is
// more specialized than the recursive one below, so it is chosen when N == 0.
template<typename T0, typename... T>
struct tuple_element<0, T0, T...>
{
  using type = T0;
};

// Recursive case: discard the leading type and look up index N - 1 in the
// remaining pack.
template<int N, typename T0, typename... T>
struct tuple_element<N, T0, T...>
{
  using type = typename tuple_element<N - 1, T...>::type;
};
162 
163 } // namespace ens
164 
165 #endif
166