1 //
2 // Copyright (c) 2018-2019, Cem Bassoy, cem.bassoy@gmail.com
3 //
4 // Distributed under the Boost Software License, Version 1.0. (See
5 // accompanying file LICENSE_1_0.txt or copy at
6 // http://www.boost.org/LICENSE_1_0.txt)
7 //
8 // The authors gratefully acknowledge the support of
9 // Fraunhofer IOSB, Ettlingen, Germany
10 //
11
12 #ifndef _BOOST_UBLAS_TENSOR_EXPRESSIONS_EVALUATION_HPP_
13 #define _BOOST_UBLAS_TENSOR_EXPRESSIONS_EVALUATION_HPP_
14
15 #include <type_traits>
16 #include <stdexcept>
17
18
19 namespace boost::numeric::ublas {
20
21 template<class element_type, class storage_format, class storage_type>
22 class tensor;
23
24 template<class size_type>
25 class basic_extents;
26
27 }
28
29 namespace boost::numeric::ublas::detail {
30
31 template<class T, class D>
32 struct tensor_expression;
33
34 template<class T, class EL, class ER, class OP>
35 struct binary_tensor_expression;
36
37 template<class T, class E, class OP>
38 struct unary_tensor_expression;
39
40 }
41
42 namespace boost::numeric::ublas::detail {
43
/** @brief Compile-time predicate telling whether the type E involves the tensor type T
 *
 * Primary template: used for every pair of types that matches no specialization; value is false.
 */
template<class T, class E>
struct has_tensor_types
{
    static constexpr bool value{false};
};

/** @brief Reflexive case: the inspected type is the tensor type T itself; value is true. */
template<class T>
struct has_tensor_types<T, T>
{
    static constexpr bool value{true};
};
51
52 template<class T, class D>
53 struct has_tensor_types<T, tensor_expression<T,D>>
54 { static constexpr bool value = std::is_same<T,D>::value || has_tensor_types<T,D>::value; };
55
56
57 template<class T, class EL, class ER, class OP>
58 struct has_tensor_types<T, binary_tensor_expression<T,EL,ER,OP>>
59 { static constexpr bool value = std::is_same<T,EL>::value || std::is_same<T,ER>::value || has_tensor_types<T,EL>::value || has_tensor_types<T,ER>::value; };
60
61 template<class T, class E, class OP>
62 struct has_tensor_types<T, unary_tensor_expression<T,E,OP>>
63 { static constexpr bool value = std::is_same<T,E>::value || has_tensor_types<T,E>::value; };
64
65 } // namespace boost::numeric::ublas::detail
66
67
68 namespace boost::numeric::ublas::detail {
69
70
71
72
73
74 /** @brief Retrieves extents of the tensor
75 *
76 */
77 template<class T, class F, class A>
retrieve_extents(tensor<T,F,A> const & t)78 auto retrieve_extents(tensor<T,F,A> const& t)
79 {
80 return t.extents();
81 }
82
83 /** @brief Retrieves extents of the tensor expression
84 *
85 * @note tensor expression must be a binary tree with at least one tensor type
86 *
87 * @returns extents of the child expression if it is a tensor or extents of one child of its child.
88 */
89 template<class T, class D>
retrieve_extents(tensor_expression<T,D> const & expr)90 auto retrieve_extents(tensor_expression<T,D> const& expr)
91 {
92 static_assert(detail::has_tensor_types<T,tensor_expression<T,D>>::value,
93 "Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors.");
94
95 auto const& cast_expr = static_cast<D const&>(expr);
96
97 if constexpr ( std::is_same<T,D>::value )
98 return cast_expr.extents();
99 else
100 return retrieve_extents(cast_expr);
101 }
102
103 /** @brief Retrieves extents of the binary tensor expression
104 *
105 * @note tensor expression must be a binary tree with at least one tensor type
106 *
107 * @returns extents of the (left and if necessary then right) child expression if it is a tensor or extents of a child of its (left and if necessary then right) child.
108 */
109 template<class T, class EL, class ER, class OP>
retrieve_extents(binary_tensor_expression<T,EL,ER,OP> const & expr)110 auto retrieve_extents(binary_tensor_expression<T,EL,ER,OP> const& expr)
111 {
112 static_assert(detail::has_tensor_types<T,binary_tensor_expression<T,EL,ER,OP>>::value,
113 "Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors.");
114
115 if constexpr ( std::is_same<T,EL>::value )
116 return expr.el.extents();
117
118 if constexpr ( std::is_same<T,ER>::value )
119 return expr.er.extents();
120
121 else if constexpr ( detail::has_tensor_types<T,EL>::value )
122 return retrieve_extents(expr.el);
123
124 else if constexpr ( detail::has_tensor_types<T,ER>::value )
125 return retrieve_extents(expr.er);
126 }
127
128 /** @brief Retrieves extents of the binary tensor expression
129 *
130 * @note tensor expression must be a binary tree with at least one tensor type
131 *
132 * @returns extents of the child expression if it is a tensor or extents of a child of its child.
133 */
134 template<class T, class E, class OP>
retrieve_extents(unary_tensor_expression<T,E,OP> const & expr)135 auto retrieve_extents(unary_tensor_expression<T,E,OP> const& expr)
136 {
137
138 static_assert(detail::has_tensor_types<T,unary_tensor_expression<T,E,OP>>::value,
139 "Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors.");
140
141 if constexpr ( std::is_same<T,E>::value )
142 return expr.e.extents();
143
144 else if constexpr ( detail::has_tensor_types<T,E>::value )
145 return retrieve_extents(expr.e);
146 }
147
148 } // namespace boost::numeric::ublas::detail
149
150
151 ///////////////
152
153 namespace boost::numeric::ublas::detail {
154
155 template<class T, class F, class A, class S>
all_extents_equal(tensor<T,F,A> const & t,basic_extents<S> const & extents)156 auto all_extents_equal(tensor<T,F,A> const& t, basic_extents<S> const& extents)
157 {
158 return extents == t.extents();
159 }
160
161 template<class T, class D, class S>
all_extents_equal(tensor_expression<T,D> const & expr,basic_extents<S> const & extents)162 auto all_extents_equal(tensor_expression<T,D> const& expr, basic_extents<S> const& extents)
163 {
164 static_assert(detail::has_tensor_types<T,tensor_expression<T,D>>::value,
165 "Error in boost::numeric::ublas::detail::all_extents_equal: Expression to evaluate should contain tensors.");
166 auto const& cast_expr = static_cast<D const&>(expr);
167
168
169 if constexpr ( std::is_same<T,D>::value )
170 if( extents != cast_expr.extents() )
171 return false;
172
173 if constexpr ( detail::has_tensor_types<T,D>::value )
174 if ( !all_extents_equal(cast_expr, extents))
175 return false;
176
177 return true;
178
179 }
180
181 template<class T, class EL, class ER, class OP, class S>
all_extents_equal(binary_tensor_expression<T,EL,ER,OP> const & expr,basic_extents<S> const & extents)182 auto all_extents_equal(binary_tensor_expression<T,EL,ER,OP> const& expr, basic_extents<S> const& extents)
183 {
184 static_assert(detail::has_tensor_types<T,binary_tensor_expression<T,EL,ER,OP>>::value,
185 "Error in boost::numeric::ublas::detail::all_extents_equal: Expression to evaluate should contain tensors.");
186
187 if constexpr ( std::is_same<T,EL>::value )
188 if(extents != expr.el.extents())
189 return false;
190
191 if constexpr ( std::is_same<T,ER>::value )
192 if(extents != expr.er.extents())
193 return false;
194
195 if constexpr ( detail::has_tensor_types<T,EL>::value )
196 if(!all_extents_equal(expr.el, extents))
197 return false;
198
199 if constexpr ( detail::has_tensor_types<T,ER>::value )
200 if(!all_extents_equal(expr.er, extents))
201 return false;
202
203 return true;
204 }
205
206
207 template<class T, class E, class OP, class S>
all_extents_equal(unary_tensor_expression<T,E,OP> const & expr,basic_extents<S> const & extents)208 auto all_extents_equal(unary_tensor_expression<T,E,OP> const& expr, basic_extents<S> const& extents)
209 {
210
211 static_assert(detail::has_tensor_types<T,unary_tensor_expression<T,E,OP>>::value,
212 "Error in boost::numeric::ublas::detail::all_extents_equal: Expression to evaluate should contain tensors.");
213
214 if constexpr ( std::is_same<T,E>::value )
215 if(extents != expr.e.extents())
216 return false;
217
218 if constexpr ( detail::has_tensor_types<T,E>::value )
219 if(!all_extents_equal(expr.e, extents))
220 return false;
221
222 return true;
223 }
224
225 } // namespace boost::numeric::ublas::detail
226
227
228 namespace boost::numeric::ublas::detail {
229
230
231 /** @brief Evaluates expression for a tensor
232 *
233 * Assigns the results of the expression to the tensor.
234 *
235 * \note Checks if shape of the tensor matches those of all tensors within the expression.
236 */
237 template<class tensor_type, class derived_type>
eval(tensor_type & lhs,tensor_expression<tensor_type,derived_type> const & expr)238 void eval(tensor_type& lhs, tensor_expression<tensor_type, derived_type> const& expr)
239 {
240 if constexpr (detail::has_tensor_types<tensor_type, tensor_expression<tensor_type,derived_type> >::value )
241 if(!detail::all_extents_equal(expr, lhs.extents() ))
242 throw std::runtime_error("Error in boost::numeric::ublas::tensor: expression contains tensors with different shapes.");
243
244 #pragma omp parallel for
245 for(auto i = 0u; i < lhs.size(); ++i)
246 lhs(i) = expr()(i);
247 }
248
249 /** @brief Evaluates expression for a tensor
250 *
251 * Applies a unary function to the results of the expressions before the assignment.
252 * Usually applied needed for unary operators such as A += C;
253 *
254 * \note Checks if shape of the tensor matches those of all tensors within the expression.
255 */
256 template<class tensor_type, class derived_type, class unary_fn>
eval(tensor_type & lhs,tensor_expression<tensor_type,derived_type> const & expr,unary_fn const fn)257 void eval(tensor_type& lhs, tensor_expression<tensor_type, derived_type> const& expr, unary_fn const fn)
258 {
259
260 if constexpr (detail::has_tensor_types< tensor_type, tensor_expression<tensor_type,derived_type> >::value )
261 if(!detail::all_extents_equal( expr, lhs.extents() ))
262 throw std::runtime_error("Error in boost::numeric::ublas::tensor: expression contains tensors with different shapes.");
263
264 #pragma omp parallel for
265 for(auto i = 0u; i < lhs.size(); ++i)
266 fn(lhs(i), expr()(i));
267 }
268
269
270
/** @brief Applies a unary function to every element of a tensor
 *
 * For every index i in [0, lhs.size()) the callable is invoked as fn(lhs(i)).
 * No expression is involved and no extents are checked.
 * The loop carries an OpenMP 'parallel for' pragma.
 *
 * @param[in,out] lhs tensor whose elements are passed to fn
 * @param[in]     fn  callable taking a single element of lhs
 */
template<class tensor_type, class unary_fn>
void eval(tensor_type& lhs, unary_fn const fn)
{
    // Index with the type lhs.size() returns. An 'unsigned int' index (0u) wraps
    // around before reaching a bound that does not fit into 'unsigned int',
    // so the loop would never terminate.
    using index_type = std::decay_t<decltype(lhs.size())>;

    #pragma omp parallel for
    for(index_type i = 0; i < lhs.size(); ++i)
        fn(lhs(i));
}
285
286
287 }
288 #endif
289