1 //
2 //  Copyright (c) 2018-2019, Cem Bassoy, cem.bassoy@gmail.com
3 //
4 //  Distributed under the Boost Software License, Version 1.0. (See
5 //  accompanying file LICENSE_1_0.txt or copy at
6 //  http://www.boost.org/LICENSE_1_0.txt)
7 //
8 //  The authors gratefully acknowledge the support of
9 //  Fraunhofer IOSB, Ettlingen, Germany
10 //
11 
12 #ifndef _BOOST_UBLAS_TENSOR_EXPRESSIONS_EVALUATION_HPP_
13 #define _BOOST_UBLAS_TENSOR_EXPRESSIONS_EVALUATION_HPP_
14 
15 #include <type_traits>
16 #include <stdexcept>
17 
18 
// Forward declarations of the public tensor types referenced by the
// evaluation helpers in this header. Only the template signatures are
// needed here; no definitions are provided by this header.
namespace boost::numeric::ublas {

template<class value_t, class format_t, class storage_t>
class tensor;

template<class int_type>
class basic_extents;

} // namespace boost::numeric::ublas
28 
namespace boost::numeric::ublas::detail {

// Expression-template node types. Only their names are needed for the trait
// specialisations and function overloads that follow in this header.
template<class tensor_t, class derived_t>
struct tensor_expression;

template<class tensor_t, class left_t, class right_t, class op_t>
struct binary_tensor_expression;

template<class tensor_t, class operand_t, class op_t>
struct unary_tensor_expression;

} // namespace boost::numeric::ublas::detail
41 
42 namespace boost::numeric::ublas::detail {
43 
44 template<class T, class E>
45 struct has_tensor_types
46 { static constexpr bool value = false; };
47 
48 template<class T>
49 struct has_tensor_types<T,T>
50 { static constexpr bool value = true; };
51 
52 template<class T, class D>
53 struct has_tensor_types<T, tensor_expression<T,D>>
54 { static constexpr bool value = std::is_same<T,D>::value || has_tensor_types<T,D>::value; };
55 
56 
57 template<class T, class EL, class ER, class OP>
58 struct has_tensor_types<T, binary_tensor_expression<T,EL,ER,OP>>
59 { static constexpr bool value = std::is_same<T,EL>::value || std::is_same<T,ER>::value || has_tensor_types<T,EL>::value || has_tensor_types<T,ER>::value;  };
60 
61 template<class T, class E, class OP>
62 struct has_tensor_types<T, unary_tensor_expression<T,E,OP>>
63 { static constexpr bool value = std::is_same<T,E>::value || has_tensor_types<T,E>::value; };
64 
65 } // namespace boost::numeric::ublas::detail
66 
67 
68 namespace boost::numeric::ublas::detail {
69 
70 
71 
72 
73 
74 /** @brief Retrieves extents of the tensor
75  *
76 */
77 template<class T, class F, class A>
retrieve_extents(tensor<T,F,A> const & t)78 auto retrieve_extents(tensor<T,F,A> const& t)
79 {
80 	return t.extents();
81 }
82 
83 /** @brief Retrieves extents of the tensor expression
84  *
85  * @note tensor expression must be a binary tree with at least one tensor type
86  *
87  * @returns extents of the child expression if it is a tensor or extents of one child of its child.
88 */
89 template<class T, class D>
retrieve_extents(tensor_expression<T,D> const & expr)90 auto retrieve_extents(tensor_expression<T,D> const& expr)
91 {
92 	static_assert(detail::has_tensor_types<T,tensor_expression<T,D>>::value,
93 	              "Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors.");
94 
95 	auto const& cast_expr = static_cast<D const&>(expr);
96 
97 	if constexpr ( std::is_same<T,D>::value )
98 	    return cast_expr.extents();
99 	else
100 	return retrieve_extents(cast_expr);
101 }
102 
103 /** @brief Retrieves extents of the binary tensor expression
104  *
105  * @note tensor expression must be a binary tree with at least one tensor type
106  *
107  * @returns extents of the (left and if necessary then right) child expression if it is a tensor or extents of a child of its (left and if necessary then right) child.
108 */
109 template<class T, class EL, class ER, class OP>
retrieve_extents(binary_tensor_expression<T,EL,ER,OP> const & expr)110 auto retrieve_extents(binary_tensor_expression<T,EL,ER,OP> const& expr)
111 {
112 	static_assert(detail::has_tensor_types<T,binary_tensor_expression<T,EL,ER,OP>>::value,
113 	              "Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors.");
114 
115 	if constexpr ( std::is_same<T,EL>::value )
116 	    return expr.el.extents();
117 
118 	if constexpr ( std::is_same<T,ER>::value )
119 	    return expr.er.extents();
120 
121 	else if constexpr ( detail::has_tensor_types<T,EL>::value )
122 	    return retrieve_extents(expr.el);
123 
124 	else if constexpr ( detail::has_tensor_types<T,ER>::value  )
125 	    return retrieve_extents(expr.er);
126 }
127 
128 /** @brief Retrieves extents of the binary tensor expression
129  *
130  * @note tensor expression must be a binary tree with at least one tensor type
131  *
132  * @returns extents of the child expression if it is a tensor or extents of a child of its child.
133 */
134 template<class T, class E, class OP>
retrieve_extents(unary_tensor_expression<T,E,OP> const & expr)135 auto retrieve_extents(unary_tensor_expression<T,E,OP> const& expr)
136 {
137 
138 	static_assert(detail::has_tensor_types<T,unary_tensor_expression<T,E,OP>>::value,
139 	              "Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors.");
140 
141 	if constexpr ( std::is_same<T,E>::value )
142 	    return expr.e.extents();
143 
144 	else if constexpr ( detail::has_tensor_types<T,E>::value  )
145 	    return retrieve_extents(expr.e);
146 }
147 
148 } // namespace boost::numeric::ublas::detail
149 
150 
151 ///////////////
152 
153 namespace boost::numeric::ublas::detail {
154 
155 template<class T, class F, class A, class S>
all_extents_equal(tensor<T,F,A> const & t,basic_extents<S> const & extents)156 auto all_extents_equal(tensor<T,F,A> const& t, basic_extents<S> const& extents)
157 {
158 	return extents == t.extents();
159 }
160 
161 template<class T, class D, class S>
all_extents_equal(tensor_expression<T,D> const & expr,basic_extents<S> const & extents)162 auto all_extents_equal(tensor_expression<T,D> const& expr, basic_extents<S> const& extents)
163 {
164 	static_assert(detail::has_tensor_types<T,tensor_expression<T,D>>::value,
165 	              "Error in boost::numeric::ublas::detail::all_extents_equal: Expression to evaluate should contain tensors.");
166 	auto const& cast_expr = static_cast<D const&>(expr);
167 
168 
169 	if constexpr ( std::is_same<T,D>::value )
170 	    if( extents != cast_expr.extents() )
171 	    return false;
172 
173 	if constexpr ( detail::has_tensor_types<T,D>::value )
174 	    if ( !all_extents_equal(cast_expr, extents))
175 	    return false;
176 
177 	return true;
178 
179 }
180 
181 template<class T, class EL, class ER, class OP, class S>
all_extents_equal(binary_tensor_expression<T,EL,ER,OP> const & expr,basic_extents<S> const & extents)182 auto all_extents_equal(binary_tensor_expression<T,EL,ER,OP> const& expr, basic_extents<S> const& extents)
183 {
184 	static_assert(detail::has_tensor_types<T,binary_tensor_expression<T,EL,ER,OP>>::value,
185 	              "Error in boost::numeric::ublas::detail::all_extents_equal: Expression to evaluate should contain tensors.");
186 
187 	if constexpr ( std::is_same<T,EL>::value )
188 	    if(extents !=  expr.el.extents())
189 	    return false;
190 
191 	if constexpr ( std::is_same<T,ER>::value )
192 	    if(extents != expr.er.extents())
193 	    return false;
194 
195 	if constexpr ( detail::has_tensor_types<T,EL>::value )
196 	    if(!all_extents_equal(expr.el, extents))
197 	    return false;
198 
199 	if constexpr ( detail::has_tensor_types<T,ER>::value )
200 	    if(!all_extents_equal(expr.er, extents))
201 	    return false;
202 
203 	return true;
204 }
205 
206 
207 template<class T, class E, class OP, class S>
all_extents_equal(unary_tensor_expression<T,E,OP> const & expr,basic_extents<S> const & extents)208 auto all_extents_equal(unary_tensor_expression<T,E,OP> const& expr, basic_extents<S> const& extents)
209 {
210 
211 	static_assert(detail::has_tensor_types<T,unary_tensor_expression<T,E,OP>>::value,
212 	              "Error in boost::numeric::ublas::detail::all_extents_equal: Expression to evaluate should contain tensors.");
213 
214 	if constexpr ( std::is_same<T,E>::value )
215 	    if(extents != expr.e.extents())
216 	    return false;
217 
218 	if constexpr ( detail::has_tensor_types<T,E>::value )
219 	    if(!all_extents_equal(expr.e, extents))
220 	    return false;
221 
222 	return true;
223 }
224 
225 } // namespace boost::numeric::ublas::detail
226 
227 
228 namespace boost::numeric::ublas::detail {
229 
230 
231 /** @brief Evaluates expression for a tensor
232  *
233  * Assigns the results of the expression to the tensor.
234  *
235  * \note Checks if shape of the tensor matches those of all tensors within the expression.
236 */
237 template<class tensor_type, class derived_type>
eval(tensor_type & lhs,tensor_expression<tensor_type,derived_type> const & expr)238 void eval(tensor_type& lhs, tensor_expression<tensor_type, derived_type> const& expr)
239 {
240 	if constexpr (detail::has_tensor_types<tensor_type, tensor_expression<tensor_type,derived_type> >::value )
241 	    if(!detail::all_extents_equal(expr, lhs.extents() ))
242 	    throw std::runtime_error("Error in boost::numeric::ublas::tensor: expression contains tensors with different shapes.");
243 
244 #pragma omp parallel for
245 	for(auto i = 0u; i < lhs.size(); ++i)
246 		lhs(i) = expr()(i);
247 }
248 
249 /** @brief Evaluates expression for a tensor
250  *
251  * Applies a unary function to the results of the expressions before the assignment.
252  * Usually applied needed for unary operators such as A += C;
253  *
254  * \note Checks if shape of the tensor matches those of all tensors within the expression.
255 */
256 template<class tensor_type, class derived_type, class unary_fn>
eval(tensor_type & lhs,tensor_expression<tensor_type,derived_type> const & expr,unary_fn const fn)257 void eval(tensor_type& lhs, tensor_expression<tensor_type, derived_type> const& expr, unary_fn const fn)
258 {
259 
260 	if constexpr (detail::has_tensor_types< tensor_type, tensor_expression<tensor_type,derived_type> >::value )
261 	    if(!detail::all_extents_equal( expr, lhs.extents() ))
262 	    throw std::runtime_error("Error in boost::numeric::ublas::tensor: expression contains tensors with different shapes.");
263 
264 #pragma omp parallel for
265 	for(auto i = 0u; i < lhs.size(); ++i)
266 		fn(lhs(i), expr()(i));
267 }
268 
269 
270 
271 /** @brief Evaluates expression for a tensor
272  *
273  * Applies a unary function to the results of the expressions before the assignment.
274  * Usually applied needed for unary operators such as A += C;
275  *
276  * \note Checks if shape of the tensor matches those of all tensors within the expression.
277 */
278 template<class tensor_type, class unary_fn>
eval(tensor_type & lhs,unary_fn const fn)279 void eval(tensor_type& lhs, unary_fn const fn)
280 {
281 #pragma omp parallel for
282 	for(auto i = 0u; i < lhs.size(); ++i)
283 		fn(lhs(i));
284 }
285 
286 
287 }
288 #endif
289