1 /***************************************************************************
2 * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht          *
3 * Copyright (c) QuantStack                                                 *
4 *                                                                          *
5 * Distributed under the terms of the BSD 3-Clause License.                 *
6 *                                                                          *
7 * The full license is in the file LICENSE, distributed with this software. *
8 ****************************************************************************/
9 
10 #ifndef XTENSOR_REDUCER_HPP
11 #define XTENSOR_REDUCER_HPP
12 
13 #include <algorithm>
14 #include <cstddef>
15 #include <initializer_list>
16 #include <iterator>
17 #include <stdexcept>
18 #include <tuple>
19 #include <type_traits>
20 #include <utility>
21 
22 #include <xtl/xfunctional.hpp>
23 #include <xtl/xsequence.hpp>
24 
25 #include "xaccessible.hpp"
26 #include "xbuilder.hpp"
27 #include "xeval.hpp"
28 #include "xexpression.hpp"
29 #include "xgenerator.hpp"
30 #include "xiterable.hpp"
31 #include "xtensor_config.hpp"
32 #include "xutils.hpp"
33 
34 namespace xt
35 {
36     template <template <class...> class A, class... AX, class X,
37               XTL_REQUIRES(is_evaluation_strategy<AX>..., is_evaluation_strategy<X>)>
operator |(const A<AX...> & args,const A<X> & rhs)38     auto operator|(const A<AX...>& args, const A<X>& rhs)
39     {
40         return std::tuple_cat(args, rhs);
41     }
42 
43     struct keep_dims_type : xt::detail::option_base {};
44     constexpr auto keep_dims = std::tuple<keep_dims_type>{};
45 
46     template <class T = double>
47     struct xinitial : xt::detail::option_base
48     {
xinitialxt::xinitial49         constexpr xinitial(T val)
50             : m_val(val)
51         {
52         }
53 
valuext::xinitial54         constexpr T value() const { return m_val; }
55         T m_val;
56     };
57 
58     template <class T>
initial(T val)59     constexpr auto initial(T val)
60     {
61         return std::make_tuple(xinitial<T>(val));
62     }
63 
64     template <std::ptrdiff_t I, class T, class Tuple>
65     struct tuple_idx_of_impl;
66 
67     template <std::ptrdiff_t I, class T>
68     struct tuple_idx_of_impl<I, T, std::tuple<>>
69     {
70         static constexpr std::ptrdiff_t value = -1;
71     };
72 
73     template <std::ptrdiff_t I, class T, class... Types>
74     struct tuple_idx_of_impl<I, T, std::tuple<T, Types...>>
75     {
76         static constexpr std::ptrdiff_t value = I;
77     };
78 
79     template <std::ptrdiff_t I, class T, class U, class... Types>
80     struct tuple_idx_of_impl<I, T, std::tuple<U, Types...>>
81     {
82         static constexpr std::ptrdiff_t value = tuple_idx_of_impl<I + 1, T, std::tuple<Types...>>::value;
83     };
84 
85     template <class S, class... X>
86     struct decay_all;
87 
88     template <template <class...> class S, class... X>
89     struct decay_all<S<X...>>
90     {
91         using type = S<std::decay_t<X>...>;
92     };
93 
94     template <class T, class Tuple>
95     struct tuple_idx_of
96     {
97         static constexpr std::ptrdiff_t value = tuple_idx_of_impl<0, std::decay_t<T>, typename decay_all<Tuple>::type>::value;
98     };
99 
100     template <class R, class T>
101     struct reducer_options
102     {
103         template <class X>
104         struct initial_tester : std::false_type {};
105 
106         template <class X>
107         struct initial_tester<xinitial<X>> : std::true_type {};
108 
109         // Workaround for Apple because tuple_cat is buggy!
110         template <class X>
111         struct initial_tester<const xinitial<X>> : std::true_type {};
112 
113         using d_t = std::decay_t<T>;
114 
115         static constexpr std::size_t initial_val_idx = xtl::mpl::find_if<initial_tester, d_t>::value;
116         reducer_options() = default;
117 
reducer_optionsxt::reducer_options118         reducer_options(const T& tpl)
119         {
120             xtl::mpl::static_if<initial_val_idx != std::tuple_size<T>::value>([this, &tpl](auto no_compile) {
121                     // use no_compile to prevent compilation if initial_val_idx is out of bounds!
122                     this->initial_value = no_compile(std::get<initial_val_idx != std::tuple_size<T>::value ? initial_val_idx : 0>(tpl)).value();
123                 },
124                 [](auto /*np_compile*/){}
125             );
126         }
127 
128         using evaluation_strategy = std::conditional_t<tuple_idx_of<xt::evaluation_strategy::immediate_type, d_t>::value != -1,
129                                                                     xt::evaluation_strategy::immediate_type,
130                                                                     xt::evaluation_strategy::lazy_type>;
131 
132         using keep_dims = std::conditional_t<tuple_idx_of<xt::keep_dims_type, d_t>::value != -1,
133                                              std::true_type,
134                                              std::false_type>;
135 
136         constexpr static bool has_initial_value = initial_val_idx != std::tuple_size<d_t>::value;
137 
138         R initial_value;
139 
140         template <class NR>
141         using rebind_t = reducer_options<NR, T>;
142 
143         template <class NR>
rebindxt::reducer_options144         auto rebind(NR initial, const reducer_options<R, T>&) const
145         {
146             reducer_options<NR, T> res;
147             res.initial_value = initial;
148             return res;
149         }
150     };
151 
152     template <class T>
153     struct is_reducer_options_impl : std::false_type
154     {
155     };
156 
157     template <class... X>
158     struct is_reducer_options_impl<std::tuple<X...>> : std::true_type
159     {
160     };
161 
162     template <class T>
163     struct is_reducer_options : is_reducer_options_impl<std::decay_t<T>>
164     {
165     };
166 
167     /**********
168      * reduce *
169      **********/
170 
#define DEFAULT_STRATEGY_REDUCERS std::tuple<evaluation_strategy::lazy_type>
    // NOTE(review): the macro above expands to a one-element option tuple holding
    // the lazy evaluation tag; its users are outside this part of the file.

    // Forward declaration. Used in this file through its nested ::type, which is
    // the shape type of a reduction given the input shape type, the axes type and
    // a keep-dims flag type.
    template <class Shape, class Axes, class KeepDims = std::false_type>
    struct xreducer_shape_type;

    // Forward declaration; defined elsewhere.
    template <class Shape1, class Shape2>
    struct fixed_xreducer_shape_type;
178 
179     namespace detail
180     {
181         template <class O, class RS, class R, class E, class AX>
shape_computation(RS & result_shape,R & result,E & expr,const AX & axes,std::enable_if_t<!detail::is_fixed<RS>::value,int>=0)182         inline void shape_computation(RS& result_shape, R& result, E& expr,
183                                       const AX& axes, std::enable_if_t<!detail::is_fixed<RS>::value, int> = 0)
184         {
185             if (typename O::keep_dims())
186             {
187                 resize_container(result_shape, expr.dimension());
188                 for (std::size_t i = 0; i < expr.dimension(); ++i)
189                 {
190                     if (std::find(axes.begin(), axes.end(), i) == axes.end())
191                     {
192                         // i not in axes!
193                         result_shape[i] = expr.shape()[i];
194                     }
195                     else
196                     {
197                         result_shape[i] = 1;
198                     }
199                 }
200             }
201             else
202             {
203                 resize_container(result_shape, expr.dimension() - axes.size());
204                 for (std::size_t i = 0, idx = 0; i < expr.dimension(); ++i)
205                 {
206                     if (std::find(axes.begin(), axes.end(), i) == axes.end())
207                     {
208                         // i not in axes!
209                         result_shape[idx] = expr.shape()[i];
210                         ++idx;
211                     }
212                 }
213             }
214             result.resize(result_shape, expr.layout());
215         }
216 
217         // skip shape computation if already done at compile time
218         template <class O, class RS, class R,class S, class AX>
shape_computation(RS &,R &,const S &,const AX &,std::enable_if_t<detail::is_fixed<RS>::value,int>=0)219         inline void shape_computation(RS&, R&, const S&,
220                                       const AX&, std::enable_if_t<detail::is_fixed<RS>::value, int> = 0)
221         {
222         }
223     }
224 
225     template <class F, class E, class R,
226               XTL_REQUIRES(std::is_convertible<typename E::value_type, typename R::value_type>)>
copy_to_reduced(F &,const E & e,R & result)227     inline void copy_to_reduced(F&, const E& e, R& result)
228     {
229         if (e.layout() == layout_type::row_major)
230         {
231             std::copy(e.template cbegin<layout_type::row_major>(),
232                       e.template cend<layout_type::row_major>(),
233                       result.data());
234         }
235         else
236         {
237             std::copy(e.template cbegin<layout_type::column_major>(),
238                       e.template cend<layout_type::column_major>(),
239                       result.data());
240         }
241     }
242 
243     template <class F, class E, class R,
244               XTL_REQUIRES(xtl::negation<std::is_convertible<typename E::value_type, typename R::value_type>>)>
copy_to_reduced(F & f,const E & e,R & result)245     inline void copy_to_reduced(F& f, const E& e, R& result)
246     {
247         if (e.layout() == layout_type::row_major)
248         {
249             std::transform(e.template cbegin<layout_type::row_major>(),
250                            e.template cend<layout_type::row_major>(),
251                            result.data(),
252                            f);
253         }
254         else
255         {
256             std::transform(e.template cbegin<layout_type::column_major>(),
257                            e.template cend<layout_type::column_major>(),
258                            result.data(),
259                            f);
260         }
261     }
262 
263     template <class F, class E, class X, class O>
reduce_immediate(F && f,E && e,X && axes,O && raw_options)264     inline auto reduce_immediate(F&& f, E&& e, X&& axes, O&& raw_options)
265     {
266         using reduce_functor_type = typename std::decay_t<F>::reduce_functor_type;
267         using init_functor_type = typename std::decay_t<F>::init_functor_type;
268         using expr_value_type = typename std::decay_t<E>::value_type;
269         using result_type = std::decay_t<decltype(std::declval<reduce_functor_type>()(std::declval<init_functor_type>()(), std::declval<expr_value_type>()))>;
270 
271         using options_t = reducer_options<result_type, std::decay_t<O>>;
272         options_t options(raw_options);
273 
274         using shape_type = typename xreducer_shape_type<typename std::decay_t<E>::shape_type, std::decay_t<X>, typename options_t::keep_dims>::type;
275         using result_container_type = typename detail::xtype_for_shape<shape_type>::template type<result_type, std::decay_t<E>::static_layout>;
276         result_container_type result;
277 
278         // retrieve functors from triple struct
279         auto reduce_fct = xt::get<0>(f);
280         auto init_fct = xt::get<1>(f);
281         auto merge_fct = xt::get<2>(f);
282 
283         if (axes.size() == 0)
284         {
285             result.resize(e.shape(), e.layout());
286             auto cpf = [&reduce_fct, &init_fct](const auto& v) {return reduce_fct(static_cast<result_type>(init_fct()), v); };
287             copy_to_reduced(cpf, e, result);
288             return result;
289         }
290 
291         shape_type result_shape{};
292         dynamic_shape<std::size_t> iter_shape = xtl::forward_sequence<dynamic_shape<std::size_t>, decltype(e.shape())>(e.shape());
293         dynamic_shape<std::size_t> iter_strides(e.dimension());
294 
295         // using std::less_equal is counter-intuitive, but as the standard says (24.4.5):
296         // A sequence is sorted with respect to a comparator comp if for any iterator i pointing to the sequence and any non-negative integer n
297         // such that i + n is a valid iterator pointing to an element of the sequence, comp(*(i + n), *i) == false.
298         // Therefore less_equal is required to detect duplicates.
299         if (!std::is_sorted(axes.cbegin(), axes.cend(), std::less_equal<>()))
300         {
301             XTENSOR_THROW(std::runtime_error, "Reducing axes should be sorted and should not contain duplicates");
302         }
303         if (axes.size() != 0 && axes[axes.size() - 1] > e.dimension() - 1)
304         {
305             XTENSOR_THROW(std::runtime_error,
306                           "Axis " + std::to_string(axes[axes.size() - 1]) +
307                           " out of bounds for reduction.");
308         }
309 
310         detail::shape_computation<options_t>(result_shape, result, e, axes);
311 
312         // Fast track for complete reduction
313         if (e.dimension() == axes.size())
314         {
315             result_type tmp = options_t::has_initial_value ? options.initial_value : init_fct();
316             result.data()[0] = std::accumulate(e.storage().begin(), e.storage().end(), tmp, reduce_fct);
317             return result;
318         }
319 
320         std::size_t leading_ax = axes[(e.layout() == layout_type::row_major) ? axes.size() - 1 : 0];
321         auto strides_finder = e.strides().begin() + static_cast<std::ptrdiff_t>(leading_ax);
322         // The computed strides contain "0" where the shape is 1 -- therefore find the next none-zero number
323         std::size_t inner_stride = static_cast<std::size_t>(*strides_finder);
324         auto iter_bound = e.layout() == layout_type::row_major ? e.strides().begin() : (e.strides().end() - 1);
325         while (inner_stride == 0 && strides_finder != iter_bound)
326         {
327             (e.layout() == layout_type::row_major) ? --strides_finder : ++strides_finder;
328             inner_stride = static_cast<std::size_t>(*strides_finder);
329         }
330 
331         if (inner_stride == 0)
332         {
333             auto cpf = [&reduce_fct, &init_fct](const auto& v) {return reduce_fct(static_cast<result_type>(init_fct()), v); };
334             copy_to_reduced(cpf, e, result);
335             return result;
336         }
337 
338         std::size_t inner_loop_size = static_cast<std::size_t>(inner_stride);
339         std::size_t outer_loop_size = e.shape()[leading_ax];
340 
341         // The following code merges reduction axes "at the end" (or the beginning for col_major)
342         // together by increasing the size of the outer loop where appropriate
343         auto merge_loops = [&outer_loop_size, &e](auto it, auto end) {
344             auto last_ax = *it;
345             ++it;
346             for (; it != end; ++it)
347             {
348                 // note that we check is_sorted, so this condition is valid
349                 if (std::abs(std::ptrdiff_t(*it) - std::ptrdiff_t(last_ax)) == 1)
350                 {
351                     last_ax = *it;
352                     outer_loop_size *= e.shape()[last_ax];
353                 }
354             }
355             return last_ax;
356         };
357 
358         for (std::size_t i = 0, idx = 0; i < e.dimension(); ++i)
359         {
360             if (std::find(axes.begin(), axes.end(), i) == axes.end())
361             {
362                 // i not in axes!
363                 iter_strides[i] = static_cast<std::size_t>(result.strides()[typename options_t::keep_dims() ? i : idx]);
364                 ++idx;
365             }
366         }
367 
368         if (e.layout() == layout_type::row_major)
369         {
370             std::size_t last_ax = merge_loops(axes.rbegin(), axes.rend());
371 
372             iter_shape.erase(iter_shape.begin() + std::ptrdiff_t(last_ax), iter_shape.end());
373             iter_strides.erase(iter_strides.begin() + std::ptrdiff_t(last_ax), iter_strides.end());
374         }
375         else if (e.layout() == layout_type::column_major)
376         {
377             // we got column_major here
378             std::size_t last_ax = merge_loops(axes.begin(), axes.end());
379 
380             // erasing the front vs the back
381             iter_shape.erase(iter_shape.begin(), iter_shape.begin() + std::ptrdiff_t(last_ax + 1));
382             iter_strides.erase(iter_strides.begin(), iter_strides.begin() + std::ptrdiff_t(last_ax + 1));
383 
384             // and reversing, to make it work with the same next_idx function
385             std::reverse(iter_shape.begin(), iter_shape.end());
386             std::reverse(iter_strides.begin(), iter_strides.end());
387         }
388         else
389         {
390             XTENSOR_THROW(std::runtime_error, "Layout not supported in immediate reduction.");
391         }
392 
393         xindex temp_idx(iter_shape.size());
394         auto next_idx = [&iter_shape, &iter_strides, &temp_idx]() {
395             std::size_t i = iter_shape.size();
396             for (; i > 0; --i)
397             {
398                 if (std::ptrdiff_t(temp_idx[i - 1]) >= std::ptrdiff_t(iter_shape[i - 1]) - 1)
399                 {
400                     temp_idx[i - 1] = 0;
401                 }
402                 else
403                 {
404                     temp_idx[i - 1]++;
405                     break;
406                 }
407             }
408 
409             return std::make_pair(i == 0,
410                                   std::inner_product(temp_idx.begin(), temp_idx.end(),
411                                                      iter_strides.begin(), std::ptrdiff_t(0)));
412         };
413 
414         auto begin = e.data();
415         auto out = result.data();
416         auto out_begin = result.data();
417 
418         std::ptrdiff_t next_stride = 0;
419 
420         std::pair<bool, std::ptrdiff_t> idx_res(false, 0);
421 
422         // Remark: eventually some modifications here to make conditions faster where merge + accumulate is the
423         // same function (e.g. check std::is_same<decltype(merge_fct), decltype(reduce_fct)>::value) ...
424 
425         auto merge_border = out;
426         bool merge = false;
427 
428         // TODO there could be some performance gain by removing merge checking
429         //      when axes.size() == 1 and even next_idx could be removed for something simpler (next_stride always the same)
430         //      best way to do this would be to create a function that takes (begin, out, outer_loop_size, inner_loop_size, next_idx_lambda)
431         // Decide if going about it row-wise or col-wise
432         if (inner_stride == 1)
433         {
434             while (idx_res.first != true)
435             {
436                 // for unknown reasons it's much faster to use a temporary variable and
437                 // std::accumulate here -- probably some cache behavior
438                 result_type tmp = init_fct();
439                 tmp = std::accumulate(begin , begin + outer_loop_size, tmp, reduce_fct);
440 
441                 // use merge function if necessary
442                 *out = merge ? merge_fct(*out, tmp) : tmp;
443 
444                 begin += outer_loop_size;
445 
446                 idx_res = next_idx();
447                 next_stride = idx_res.second;
448                 out = out_begin + next_stride;
449 
450                 if (out > merge_border)
451                 {
452                     // looped over once
453                     merge = false;
454                     merge_border = out;
455                 }
456                 else
457                 {
458                     merge = true;
459                 }
460             };
461         }
462         else
463         {
464             while (idx_res.first != true)
465             {
466                 std::transform(out, out + inner_loop_size, begin, out,
467                                [merge, &init_fct, &reduce_fct](auto&& v1, auto&& v2) {
468                                     return merge ?
469                                         reduce_fct(v1, v2) :
470                                         // cast because return type of identity function is not upcasted
471                                         reduce_fct(static_cast<result_type>(init_fct()), v2);
472                                });
473 
474                 begin += inner_stride;
475                 for (std::size_t i = 1; i < outer_loop_size; ++i)
476                 {
477                     std::transform(out, out + inner_loop_size, begin, out, reduce_fct);
478                     begin += inner_stride;
479                 }
480 
481                 idx_res = next_idx();
482                 next_stride = idx_res.second;
483                 out = out_begin + next_stride;
484 
485                 if (out > merge_border)
486                 {
487                     // looped over once
488                     merge = false;
489                     merge_border = out;
490                 }
491                 else
492                 {
493                     merge = true;
494                 }
495             };
496         }
497         if (options_t::has_initial_value)
498         {
499             std::transform(result.data(), result.data() + result.size(), result.data(),
500                            [&merge_fct, &options](auto&& v) { return merge_fct(v, options.initial_value); });
501         }
502         return result;
503     }
504 
505 
506     /*********************
507      * xreducer functors *
508      *********************/
509 
510     template <class T>
511     struct const_value
512     {
513         using value_type = T;
514 
515         constexpr const_value() = default;
516 
const_valuext::const_value517         constexpr const_value(T t)
518             : m_value(t)
519         {
520         }
521 
operator ()xt::const_value522         constexpr T operator()() const
523         {
524             return m_value;
525         }
526 
527         template <class NT>
528         using rebind_t = const_value<NT>;
529 
530         template <class NT>
531         const_value<NT> rebind() const;
532 
533         T m_value;
534     };
535 
536     namespace detail
537     {
538         template <class T, bool B>
539         struct evaluated_value_type
540         {
541             using type = T;
542         };
543 
544         template <class T>
545         struct evaluated_value_type<T, true>
546         {
547             using type = typename std::decay_t<decltype(xt::eval(std::declval<T>()))>;
548         };
549 
550         template <class T, bool B>
551         using evaluated_value_type_t = typename evaluated_value_type<T, B>::type;
552     }
553 
554     template <class REDUCE_FUNC, class INIT_FUNC = const_value<long int>, class MERGE_FUNC = REDUCE_FUNC>
555     struct xreducer_functors
556         : public std::tuple<REDUCE_FUNC, INIT_FUNC, MERGE_FUNC>
557     {
558         using self_type = xreducer_functors<REDUCE_FUNC, INIT_FUNC, MERGE_FUNC>;
559         using base_type = std::tuple<REDUCE_FUNC, INIT_FUNC, MERGE_FUNC>;
560         using reduce_functor_type = REDUCE_FUNC;
561         using init_functor_type = INIT_FUNC;
562         using merge_functor_type = MERGE_FUNC;
563         using init_value_type = typename init_functor_type::value_type;
564 
xreducer_functorsxt::xreducer_functors565         xreducer_functors()
566             : base_type()
567         {
568         }
569 
570         template <class RF>
xreducer_functorsxt::xreducer_functors571         xreducer_functors(RF&& reduce_func)
572             : base_type(std::forward<RF>(reduce_func), INIT_FUNC(), reduce_func)
573         {
574         }
575 
576         template <class RF, class IF>
xreducer_functorsxt::xreducer_functors577         xreducer_functors(RF&& reduce_func, IF&& init_func)
578             : base_type(std::forward<RF>(reduce_func), std::forward<IF>(init_func), reduce_func)
579         {
580         }
581 
582         template <class RF, class IF, class MF>
xreducer_functorsxt::xreducer_functors583         xreducer_functors(RF&& reduce_func, IF&& init_func, MF&& merge_func)
584             : base_type(std::forward<RF>(reduce_func), std::forward<IF>(init_func), std::forward<MF>(merge_func))
585         {
586         }
587 
get_reducext::xreducer_functors588         reduce_functor_type get_reduce() const
589         {
590             return std::get<0>(upcast());
591         }
592 
get_initxt::xreducer_functors593         init_functor_type get_init() const
594         {
595             return std::get<1>(upcast());
596         }
597 
get_mergext::xreducer_functors598         merge_functor_type get_merge() const
599         {
600             return std::get<2>(upcast());
601         }
602 
603         template<class NT>
604         using rebind_t = xreducer_functors<REDUCE_FUNC, const_value<NT>, MERGE_FUNC>;
605 
606         template<class NT>
rebindxt::xreducer_functors607         rebind_t<NT> rebind()
608         {
609             return make_xreducer_functor(get_reduce(), get_init().template rebind<NT>(), get_merge());
610         }
611 
612     private:
613 
614         // Workaround for clang-cl
upcastxt::xreducer_functors615         const base_type& upcast() const
616         {
617             return static_cast<const base_type&>(*this);
618         }
619     };
620 
621     template <class RF>
make_xreducer_functor(RF && reduce_func)622     auto make_xreducer_functor(RF&& reduce_func)
623     {
624         using reducer_type = xreducer_functors<std::remove_reference_t<RF>>;
625         return reducer_type(std::forward<RF>(reduce_func));
626     }
627 
628     template <class RF, class IF>
make_xreducer_functor(RF && reduce_func,IF && init_func)629     auto make_xreducer_functor(RF&& reduce_func, IF&& init_func)
630     {
631         using reducer_type = xreducer_functors<std::remove_reference_t<RF>, std::remove_reference_t<IF>>;
632         return reducer_type(std::forward<RF>(reduce_func), std::forward<IF>(init_func));
633     }
634 
635     template <class RF, class IF, class MF>
make_xreducer_functor(RF && reduce_func,IF && init_func,MF && merge_func)636     auto make_xreducer_functor(RF&& reduce_func, IF&& init_func, MF&& merge_func)
637     {
638         using reducer_type = xreducer_functors<std::remove_reference_t<RF>,
639                                                std::remove_reference_t<IF>,
640                                                std::remove_reference_t<MF>>;
641         return reducer_type(std::forward<RF>(reduce_func), std::forward<IF>(init_func), std::forward<MF>(merge_func));
642     }
643 
644     /**********************
645      * xreducer extension *
646      **********************/
647 
648     namespace extension
649     {
650         template <class Tag, class F, class CT, class X, class O>
651         struct xreducer_base_impl;
652 
653         template <class F, class CT, class X, class O>
654         struct xreducer_base_impl<xtensor_expression_tag, F, CT, X, O>
655         {
656             using type = xtensor_empty_base;
657         };
658 
659         template <class F, class CT, class X, class O>
660         struct xreducer_base
661             : xreducer_base_impl<xexpression_tag_t<CT>, F, CT, X, O>
662         {
663         };
664 
665         template <class F, class CT, class X, class O>
666         using xreducer_base_t = typename xreducer_base<F, CT, X, O>::type;
667     }
668 
669     /************
670      * xreducer *
671      ************/
672 
673     template <class F, class CT, class X, class O>
674     class xreducer;
675 
676     template <class F, class CT, class X, class O>
677     class xreducer_stepper;
678 
679     template <class F, class CT, class X, class O>
680     struct xiterable_inner_types<xreducer<F, CT, X, O>>
681     {
682         using xexpression_type = std::decay_t<CT>;
683         using inner_shape_type = typename xreducer_shape_type<typename xexpression_type::shape_type, std::decay_t<X>, typename O::keep_dims>::type;
684         using const_stepper = xreducer_stepper<F, CT, X, O>;
685         using stepper = const_stepper;
686     };
687 
688     template <class F, class CT, class X, class O>
689     struct xcontainer_inner_types<xreducer<F, CT, X, O>>
690     {
691         using xexpression_type = std::decay_t<CT>;
692         using reduce_functor_type = typename std::decay_t<F>::reduce_functor_type;
693         using init_functor_type = typename std::decay_t<F>::init_functor_type;
694         using merge_functor_type = typename std::decay_t<F>::merge_functor_type;
695         using substepper_type = typename xexpression_type::const_stepper;
696         using raw_value_type = std::decay_t<decltype(std::declval<reduce_functor_type>()(
697                                                             std::declval<init_functor_type>()(),
698                                                             *std::declval<substepper_type>())
699                                                     )>;
700         using value_type = typename detail::evaluated_value_type_t<raw_value_type, is_xexpression<raw_value_type>::value>;
701 
702         using reference = value_type;
703         using const_reference = value_type;
704         using size_type = typename xexpression_type::size_type;
705     };
706 
707     template <class T>
708     struct select_dim_mapping_type
709     {
710         using type = T;
711     };
712 
713     template <std::size_t... I>
714     struct select_dim_mapping_type<fixed_shape<I...>>
715     {
716         using type = std::array<std::size_t, sizeof...(I)>;
717     };
718 
719     /**
720      * @class xreducer
721      * @brief Reducing function operating over specified axes.
722      *
723      * The xreducer class implements an \ref xexpression applying
724      * a reducing function to an \ref xexpression over the specified
725      * axes.
726      *
     * @tparam F a tuple of functors (class \ref xreducer_functors or compatible)
     * @tparam CT the closure type of the \ref xexpression to reduce
     * @tparam X the list of axes
     * @tparam O the reducer options type; it must expose a \c keep_dims member type,
     *           as \ref reducer_options does
     *
     * The reducer's value_type is deduced from the result type of function
     * <tt>F::reduce_functor_type</tt> when called with the value returned by
     * <tt>F::init_functor_type</tt> and an element of the expression of type \c CT.
733      *
734      * @sa reduce
735      */
    template <class F, class CT, class X, class O>
    class xreducer : public xsharable_expression<xreducer<F, CT, X, O>>,
                     public xconst_iterable<xreducer<F, CT, X, O>>,
                     public xaccessible<xreducer<F, CT, X, O>>,
                     public extension::xreducer_base_t<F, CT, X, O>
    {
    public:

        using self_type = xreducer<F, CT, X, O>;
        using inner_types = xcontainer_inner_types<self_type>;

        // Functor types, as published by xcontainer_inner_types<self_type>.
        using reduce_functor_type = typename inner_types::reduce_functor_type;
        using init_functor_type = typename inner_types::init_functor_type;
        using merge_functor_type = typename inner_types::merge_functor_type;
        using xreducer_functors_type = xreducer_functors<reduce_functor_type, init_functor_type, merge_functor_type>;

        using xexpression_type = typename inner_types::xexpression_type;
        using axes_type = X;

        using extension_base = extension::xreducer_base_t<F, CT, X, O>;
        using expression_tag = typename extension_base::expression_tag;

        using substepper_type = typename inner_types::substepper_type;
        using value_type = typename inner_types::value_type;
        using reference = typename inner_types::reference;
        using const_reference = typename inner_types::const_reference;
        using pointer = value_type*;
        using const_pointer = const value_type*;

        using size_type = typename inner_types::size_type;
        using difference_type = typename xexpression_type::difference_type;

        using iterable_base = xconst_iterable<self_type>;
        using inner_shape_type = typename iterable_base::inner_shape_type;
        using shape_type = inner_shape_type;

        // Container type mapping each dimension of the result to a dimension of the
        // reduced expression (see m_dim_mapping).
        using dim_mapping_type = typename select_dim_mapping_type<inner_shape_type>::type;

        using stepper = typename iterable_base::stepper;
        using const_stepper = typename iterable_base::const_stepper;
        using bool_load_type = typename xexpression_type::bool_load_type;

        // A reducer never advertises a static or contiguous layout.
        static constexpr layout_type static_layout = layout_type::dynamic;
        static constexpr bool contiguous_layout = false;

        // Builds the reducer from a functors tuple, the expression to reduce,
        // the reduction axes and the options.
        template <class Func, class CTA, class AX, class OX>
        xreducer(Func&& func, CTA&& e, AX&& axes, OX&& options);

        const inner_shape_type& shape() const noexcept;
        layout_type layout() const noexcept;
        bool is_contiguous() const noexcept;

        // Element access by variadic indices; both overloads forward to element().
        template <class... Args>
        const_reference operator()(Args... args) const;
        template <class... Args>
        const_reference unchecked(Args... args) const;

        // Element access by a sequence of indices [first, last).
        template <class It>
        const_reference element(It first, It last) const;

        const xexpression_type& expression() const noexcept;

        template <class S>
        bool broadcast_shape(S& shape, bool reuse_cache = false) const;

        template <class S>
        bool has_linear_assign(const S& strides) const noexcept;

        template <class S>
        const_stepper stepper_begin(const S& shape) const noexcept;
        template <class S>
        const_stepper stepper_end(const S& shape, layout_type) const noexcept;

        // Reducer type over another expression type E, with the same axes type X and,
        // unless overridden, the same functors and options types.
        template <class E, class Func = F, class Opts = O>
        using rebind_t = xreducer<Func, E, X, Opts>;

        // Builds a reducer over \p e reusing this reducer's functors, axes and options.
        template <class E>
        rebind_t<E> build_reducer(E&& e) const;

        // Builds a reducer over \p e reusing only this reducer's axes.
        template <class E, class Func, class Opts>
        rebind_t<E, Func, Opts> build_reducer(E&& e, Func&& func, Opts&& opts) const;

        // Returns copies of the reduce, init and merge functors bundled together.
        xreducer_functors_type functors() const
        {
            return xreducer_functors_type(m_reduce, m_init, m_merge);  // TODO: understand why make_xreducer_functor is throwing an error
        }

        // Returns the options this reducer was built with.
        const O& options() const
        {
            return m_options;
        }
    private:

        // The data members are initialized in this order by the constructor;
        // m_shape and m_dim_mapping are sized from m_e and m_axes.
        CT m_e;                         // expression being reduced (closure type)
        reduce_functor_type m_reduce;
        init_functor_type m_init;
        merge_functor_type m_merge;
        axes_type m_axes;               // axes reduced over
        inner_shape_type m_shape;       // shape of the reduced result
        dim_mapping_type m_dim_mapping; // result dimension -> dimension of m_e
        O m_options;

        // The stepper reads the private members above.
        friend class xreducer_stepper<F, CT, X, O>;
    };
840 
841     /*************************
842      * reduce implementation *
843      *************************/
844 
845     namespace detail
846     {
847         template <class F, class E, class X, class O>
reduce_impl(F && f,E && e,X && axes,evaluation_strategy::lazy_type,O && options)848         inline auto reduce_impl(F&& f, E&& e, X&& axes, evaluation_strategy::lazy_type, O&& options)
849         {
850             decltype(auto) normalized_axes = normalize_axis(e, std::forward<X>(axes));
851 
852             using reduce_functor_type = typename std::decay_t<F>::reduce_functor_type;
853             using init_functor_type = typename std::decay_t<F>::init_functor_type;
854             using value_type = std::decay_t<decltype(std::declval<reduce_functor_type>()(
855                 std::declval<init_functor_type>()(),
856                 *std::declval<typename std::decay_t<E>::const_stepper>()))>;
857             using evaluated_value_type = evaluated_value_type_t<value_type, is_xexpression<value_type>::value>;
858 
859             using reducer_type = xreducer<F, const_xclosure_t<E>, xtl::const_closure_type_t<decltype(normalized_axes)>, reducer_options<evaluated_value_type, std::decay_t<O>>>;
860             return reducer_type(std::forward<F>(f), std::forward<E>(e), std::forward<decltype(normalized_axes)>(normalized_axes), std::forward<O>(options));
861         }
862 
863 
864         template <class F, class E, class X, class O>
reduce_impl(F && f,E && e,X && axes,evaluation_strategy::immediate_type,O && options)865         inline auto reduce_impl(F&& f, E&& e, X&& axes, evaluation_strategy::immediate_type, O&& options)
866         {
867             decltype(auto) normalized_axes = normalize_axis(e, std::forward<X>(axes));
868             return reduce_immediate(std::forward<F>(f),
869                                     eval(std::forward<E>(e)),
870                                     std::forward<decltype(normalized_axes)>(normalized_axes),
871                                     std::forward<O>(options)
872             );
873         }
874     }
875 
// Default options type of the reduce() overloads below: lazy evaluation only.
#define DEFAULT_STRATEGY_REDUCERS std::tuple<evaluation_strategy::lazy_type>
877 
    namespace detail
    {
        // Trait telling whether T is a specialization of xreducer_functors.
        // Primary template: any other type is not.
        template <class T>
        struct is_xreducer_functors_impl : std::false_type
        {
        };

        template <class RF, class IF, class MF>
        struct is_xreducer_functors_impl<xreducer_functors<RF, IF, MF>>
            : std::true_type
        {
        };

        // Same trait applied to the decayed type, so that references and cv-qualified
        // functors (as deduced from forwarding references) are recognized too.
        template <class T>
        using is_xreducer_functors = is_xreducer_functors_impl<std::decay_t<T>>;
    }
894 
895     /**
896      * @brief Returns an \ref xexpression applying the specified reducing
897      * function to an expression over the given axes.
898      *
899      * @param f the reducing function to apply.
900      * @param e the \ref xexpression to reduce.
901      * @param axes the list of axes.
902      * @param options evaluation strategy to use (lazy (default), or immediate)
903      *
904      * The returned expression either hold a const reference to \p e or a copy
905      * depending on whether \p e is an lvalue or an rvalue.
906      */
907 
908     template <class F, class E, class X, class EVS = DEFAULT_STRATEGY_REDUCERS,
909               XTL_REQUIRES(xtl::negation<is_reducer_options<X>>,
910                            detail::is_xreducer_functors<F>)>
reduce(F && f,E && e,X && axes,EVS && options=EVS ())911     inline auto reduce(F&& f, E&& e, X&& axes, EVS&& options = EVS())
912     {
913 
914         return detail::reduce_impl(std::forward<F>(f),
915                                    std::forward<E>(e),
916                                    std::forward<X>(axes),
917                                    typename reducer_options<int, EVS>::evaluation_strategy{},
918                                    std::forward<EVS>(options)
919         );
920     }
921 
    // Overload for a reducing function that is not an xreducer_functors: wraps \p f
    // with make_xreducer_functor and forwards to the functors overload above.
    template <class F, class E, class X, class EVS = DEFAULT_STRATEGY_REDUCERS,
              XTL_REQUIRES(xtl::negation<is_reducer_options<X>>,
                           xtl::negation<detail::is_xreducer_functors<F>>)>
    inline auto reduce(F&& f, E&& e, X&& axes, EVS&& options = EVS())
    {
        return reduce(make_xreducer_functor(std::forward<F>(f)), std::forward<E>(e),
                      std::forward<X>(axes), std::forward<EVS>(options));
    }
930 
931     template <class F, class E, class EVS = DEFAULT_STRATEGY_REDUCERS,
932               XTL_REQUIRES(is_reducer_options<EVS>,
933                            detail::is_xreducer_functors<F>)>
reduce(F && f,E && e,EVS && options=EVS ())934     inline auto reduce(F&& f, E&& e, EVS&& options = EVS())
935     {
936         xindex_type_t<typename std::decay_t<E>::shape_type> ar;
937         resize_container(ar, e.dimension());
938         std::iota(ar.begin(), ar.end(), 0);
939         return detail::reduce_impl(std::forward<F>(f),
940                                    std::forward<E>(e),
941                                    std::move(ar),
942                                    typename reducer_options<int, std::decay_t<EVS>>::evaluation_strategy{},
943                                    std::forward<EVS>(options)
944         );
945     }
946 
    // Overload without axes for a reducing function that is not an xreducer_functors:
    // wraps \p f with make_xreducer_functor and forwards to the functors overload.
    template <class F, class E, class EVS = DEFAULT_STRATEGY_REDUCERS,
              XTL_REQUIRES(is_reducer_options<EVS>,
                           xtl::negation<detail::is_xreducer_functors<F>>)>
    inline auto reduce(F&& f, E&& e, EVS&& options = EVS())
    {
        return reduce(make_xreducer_functor(std::forward<F>(f)), std::forward<E>(e), std::forward<EVS>(options));
    }
954 
955     template <class F, class E, class I, std::size_t N, class EVS = DEFAULT_STRATEGY_REDUCERS,
956               XTL_REQUIRES(detail::is_xreducer_functors<F>)>
reduce(F && f,E && e,const I (& axes)[N],EVS options=EVS ())957     inline auto reduce(F&& f, E&& e, const I (&axes)[N], EVS options = EVS())
958     {
959         using axes_type = std::array<std::size_t, N>;
960         auto ax = xt::forward_normalize<axes_type>(e, axes);
961         return detail::reduce_impl(std::forward<F>(f), std::forward<E>(e), std::move(ax),
962                                    typename reducer_options<int, EVS>::evaluation_strategy{},
963                                    options);
964     }
    // C-array axes overload for a reducing function that is not an xreducer_functors:
    // wraps \p f with make_xreducer_functor and forwards to the functors overload.
    template <class F, class E, class I, std::size_t N, class EVS = DEFAULT_STRATEGY_REDUCERS,
              XTL_REQUIRES(xtl::negation<detail::is_xreducer_functors<F>>)>
    inline auto reduce(F&& f, E&& e, const I (&axes)[N], EVS options = EVS())
    {
        return reduce(make_xreducer_functor(std::forward<F>(f)), std::forward<E>(e), axes, options);
    }
971 
972     /********************
973      * xreducer_stepper *
974      ********************/
975 
    // Stepper over an xreducer. It keeps a pointer to the reducer and a stepper over
    // the reduced expression.
    template <class F, class CT, class X, class O>
    class xreducer_stepper
    {
    public:

        using self_type = xreducer_stepper<F, CT, X, O>;
        using xreducer_type = xreducer<F, CT, X, O>;

        using value_type = typename xreducer_type::value_type;
        // Dereferencing yields a value, not a reference to stored data.
        using reference = typename xreducer_type::value_type;
        using pointer = typename xreducer_type::const_pointer;
        using size_type = typename xreducer_type::size_type;
        using difference_type = typename xreducer_type::difference_type;

        using xexpression_type = typename xreducer_type::xexpression_type;
        using substepper_type = typename xexpression_type::const_stepper;
        using shape_type = typename xreducer_type::shape_type;

        // \p offset is the number of leading dimensions of the iterated shape that the
        // reducer does not have (see xreducer::stepper_begin); when \p end is true the
        // stepper is moved to the end position for layout \p l.
        xreducer_stepper(const xreducer_type& red, size_type offset, bool end = false,
                         layout_type l = default_assignable_layout(xexpression_type::static_layout));

        reference operator*() const;

        void step(size_type dim);
        void step_back(size_type dim);
        void step(size_type dim, size_type n);
        void step_back(size_type dim, size_type n);
        void reset(size_type dim);
        void reset_back(size_type dim);

        void to_begin();
        void to_end(layout_type l);

    private:

        reference initial_value() const;
        reference aggregate(size_type dim) const;
        // Tag-dispatched on the keep_dims option.
        reference aggregate_impl(size_type dim, /*keep_dims=*/ std::false_type) const;
        reference aggregate_impl(size_type dim, /*keep_dims=*/ std::true_type) const;

        substepper_type get_substepper_begin() const;
        size_type get_dim(size_type dim) const noexcept;
        size_type shape(size_type i) const noexcept;
        size_type axis(size_type i) const noexcept;

        const xreducer_type* m_reducer;
        size_type m_offset;
        // mutable: presumably moved by the const aggregation members declared above.
        mutable substepper_type m_stepper;
    };
1025 
1026     /******************
1027      * xreducer utils *
1028      ******************/
1029 
1030     namespace detail
1031     {
        // Compile-time membership test: value is true if X equals at least one of I...
        // (and false for an empty pack, assuming xtl::disjunction<> is false like
        // std::disjunction<>).
        template <std::size_t X, std::size_t... I>
        struct in
        {
            constexpr static bool value = xtl::disjunction<std::integral_constant<bool, X == I>...>::value;
        };
1037 
        // Computes the fixed_shape of a reduction at compile time.
        // Z is the dimension currently visited (from the last one down to 0),
        // fixed_shape<I...> the shape of the reduced expression, fixed_shape<J...> the
        // reduced axes and fixed_shape<R...> the result accumulated so far.
        template <std::size_t Z, class S1, class S2, class R>
        struct fixed_xreducer_shape_type_impl;

        // Recursive step: a reduced dimension (Z in J...) is skipped, any other one has
        // its extent (detail::at<Z, I...>::value, presumably the Z-th value of I...)
        // prepended to the result; then dimension Z - 1 is visited.
        template <std::size_t Z, std::size_t... I, std::size_t... J, std::size_t... R>
        struct fixed_xreducer_shape_type_impl<Z, fixed_shape<I...>, fixed_shape<J...>, fixed_shape<R...>>
        {
            using type = std::conditional_t<in<Z, J...>::value,
                                            typename fixed_xreducer_shape_type_impl<Z - 1, fixed_shape<I...>, fixed_shape<J...>,
                                                                                    fixed_shape<R...>>::type,
                                            typename fixed_xreducer_shape_type_impl<Z - 1, fixed_shape<I...>, fixed_shape<J...>,
                                                                                    fixed_shape<detail::at<Z, I...>::value, R...>>::type>;
        };

        // End of the recursion: handles dimension 0 the same way and yields the result.
        template <std::size_t... I, std::size_t... J, std::size_t... R>
        struct fixed_xreducer_shape_type_impl<0, fixed_shape<I...>, fixed_shape<J...>, fixed_shape<R...>>
        {
            using type = std::conditional_t<in<0, J...>::value,
                                            fixed_shape<R...>,
                                            fixed_shape<detail::at<0, I...>::value, R...>>;
        };
1058 
1059         /***************************
1060          * helper for return types *
1061          ***************************/
1062 
        // Size type associated with a value type T; std::size_t unless specialized.
        template <class T>
        struct xreducer_size_type
        {
            using type = std::size_t;
        };

        // Shortcut for xreducer_size_type<T>::type.
        template <class T>
        using xreducer_size_type_t = typename xreducer_size_type<T>::type;
1071 
1072 
        // Temporary type associated with a value type T; T itself unless specialized.
        template <class T>
        struct xreducer_temporary_type
        {
            using type = T;
        };

        // Shortcut for xreducer_temporary_type<T>::type.
        template <class T>
        using xreducer_temporary_type_t = typename xreducer_temporary_type<T>::type;
1081 
1082         /********************************
1083          * Default const_value rebinder *
1084          ********************************/
1085 
        // Default conversion from const_value<T> to const_value<U>: the new object is
        // constructed from the stored value of the source.
        template <class T, class U>
        struct const_value_rebinder
        {
            static const_value<U> run(const const_value<T>& t)
            {
                return const_value<U>(t.m_value);
            }
        };
1094     }
1095 
1096     /*******************************************
1097      * Init functor const_value implementation *
1098      *******************************************/
1099 
    // Returns a const_value<NT> built from this object; the conversion is delegated
    // to detail::const_value_rebinder<T, NT>.
    template <class T>
    template <class NT>
    const_value<NT> const_value<T>::rebind() const
    {
        return detail::const_value_rebinder<T, NT>::run(*this);
    }
1106 
1107     /*****************************
1108      * fixed_xreducer_shape_type *
1109      *****************************/
1110 
    // Meta-function computing the fixed_shape obtained when the axes fixed_shape<J...>
    // of an expression of shape fixed_shape<I...> are reduced.
    template <class S1, class S2>
    struct fixed_xreducer_shape_type;

    template <std::size_t... I, std::size_t... J>
    struct fixed_xreducer_shape_type<fixed_shape<I...>, fixed_shape<J...>>
    {
        // The computation starts from the last dimension with an empty result.
        // NOTE(review): sizeof...(I) - 1 wraps around for an empty fixed_shape<>;
        // confirm this is never instantiated with a 0-dimensional shape.
        using type = typename detail::fixed_xreducer_shape_type_impl<sizeof...(I) - 1,
                                                                     fixed_shape<I...>,
                                                                     fixed_shape<J...>,
                                                                     fixed_shape<>>::type;
    };
1122 
1123 
    // Meta-function returning the shape type for an xreducer.
    // ST is the shape type of the reduced expression, X the axes type; the
    // specializations below match O against std::true_type / std::false_type
    // (presumably the keep_dims flag of the reducer options - confirm at the use site).
    // Generic case: promotion of the shape type and the decayed axes type.
    template <class ST, class X, class O>
    struct xreducer_shape_type
    {
        using type = promote_shape_t<ST, std::decay_t<X>>;
    };

    // Array shape and array axes, O = true_type: same number of dimensions as the input.
    template <class I1, std::size_t N1, class I2, std::size_t N2>
    struct xreducer_shape_type<std::array<I1, N1>, std::array<I2, N2>, std::true_type>
    {
        using type = std::array<I2, N1>;
    };

    // Array shape and array axes, O = false_type: one dimension less per axis.
    template <class I1, std::size_t N1, class I2, std::size_t N2>
    struct xreducer_shape_type<std::array<I1, N1>, std::array<I2, N2>, std::false_type>
    {
        using type = std::array<I2, N1 - N2>;
    };

    // Fixed shape and array axes, O = false_type: an empty fixed_shape when as many
    // axes as dimensions are given, otherwise an array with one dimension less per axis.
    template <std::size_t... I, class I2, std::size_t N2>
    struct xreducer_shape_type<fixed_shape<I...>, std::array<I2, N2>, std::false_type>
    {
        using type = std::conditional_t<sizeof...(I) == N2,
                                        fixed_shape<>,
                                        std::array<I2, sizeof...(I) - N2>>;
    };
1150 
1151     namespace detail
1152     {
1153         template <class S1, class S2> struct ixconcat;
1154 
1155         template<class T, T... I1, T... I2>
1156         struct ixconcat<std::integer_sequence<T, I1...>, std::integer_sequence<T, I2...>>
1157         {
1158             using type = std::integer_sequence<T, I1..., I2...>;
1159         };
1160 
1161         template <class T, T X, std::size_t N>
1162         struct repeat_integer_sequence
1163         {
1164             using type = typename ixconcat<std::integer_sequence<T, X>, typename repeat_integer_sequence<T, X, N - 1>::type>::type;
1165         };
1166 
1167         template <class T, T X>
1168         struct repeat_integer_sequence<T, X, 0>
1169         {
1170             using type = std::integer_sequence<T>;
1171         };
1172 
1173         template <class T, T X>
1174         struct repeat_integer_sequence<T, X, 2>
1175         {
1176             using type = std::integer_sequence<T, X, X>;
1177         };
1178 
1179         template <class T, T X>
1180         struct repeat_integer_sequence<T, X, 1>
1181         {
1182             using type = std::integer_sequence<T, X>;
1183         };
1184     }
1185 
    // Fixed shape and array axes, O = true_type: the result has as many dimensions
    // as the input.
    template <std::size_t... I, class I2, std::size_t N2>
    struct xreducer_shape_type<fixed_shape<I...>, std::array<I2, N2>, std::true_type>
    {
        // Turns an index_sequence into the fixed_shape holding the same values;
        // only used inside decltype below.
        template <std::size_t... X>
        constexpr static auto get_type(std::index_sequence<X...>) {
            return fixed_shape<X...>{};
        }

        // if all axes reduced: a fixed_shape made of N2 ones, otherwise an array with
        // one entry per input dimension
        using type = std::conditional_t<sizeof...(I) == N2,
                                        decltype(get_type(typename detail::repeat_integer_sequence<std::size_t, std::size_t(1), N2>::type{})),
                                        std::array<I2, sizeof...(I)>>;
    };
1199 
    // Fixed shape and fixed axes, for any O: the result shape is computed at compile
    // time by fixed_xreducer_shape_type.
    // NOTE(review): the previous comment here mentioned adding "A" to prevent
    // compilation in case nothing else matches; no such parameter exists in this
    // specialization, so that remark looks stale - confirm.
    template <std::size_t... I, std::size_t... J, class O>
    struct xreducer_shape_type<fixed_shape<I...>, fixed_shape<J...>, O>
    {
        using type = typename fixed_xreducer_shape_type<fixed_shape<I...>, fixed_shape<J...>>::type;
    };
1206 
1207     namespace detail
1208     {
1209         template <class S, class E, class X, class M>
shape_and_mapping_computation(S & shape,E & e,const X & axes,M & mapping,std::false_type)1210         inline void shape_and_mapping_computation(S& shape, E& e, const X& axes, M& mapping,
1211                                                   std::false_type)
1212         {
1213 
1214             auto first = e.shape().begin();
1215             auto last = e.shape().end();
1216             auto exclude_it = axes.begin();
1217 
1218             using value_type = typename S::value_type;
1219             using difference_type = typename S::difference_type;
1220             auto d_first = shape.begin();
1221             auto map_first = mapping.begin();
1222 
1223             auto iter = first;
1224             while (iter != last && exclude_it != axes.end())
1225             {
1226                 auto diff = std::distance(first, iter);
1227                 if (diff != difference_type(*exclude_it))
1228                 {
1229                     *d_first++ = *iter++;
1230                     *map_first++ = value_type(diff);
1231                 }
1232                 else
1233                 {
1234                     ++iter;
1235                     ++exclude_it;
1236                 }
1237             }
1238 
1239             auto diff = std::distance(first, iter);
1240             auto end  = std::distance(iter, last);
1241             std::iota(map_first, map_first + end, diff);
1242             std::copy(iter, last, d_first);
1243         }
1244 
1245         template <class S, class E, class X, class M>
shape_and_mapping_computation_keep_dim(S & shape,E & e,const X & axes,M & mapping,std::false_type)1246         inline void shape_and_mapping_computation_keep_dim(S& shape, E& e, const X& axes, M& mapping,
1247                                                            std::false_type)
1248         {
1249             for (std::size_t i = 0; i < e.dimension(); ++i)
1250             {
1251                 if (std::find(axes.cbegin(), axes.cend(), i) == axes.cend())
1252                 {
1253                     // i not in axes!
1254                     shape[i] = e.shape()[i];
1255                 }
1256                 else
1257                 {
1258                     shape[i] = 1;
1259                 }
1260             }
1261             std::iota(mapping.begin(), mapping.end(), 0);
1262         }
1263 
        // Overload selected when the last argument is std::true_type; the xreducer
        // constructor passes detail::is_fixed<shape_type>{} there. Intentionally does
        // nothing: shape and mapping are left untouched.
        template <class S, class E, class X, class M>
        inline void shape_and_mapping_computation(S&, E&, const X&, M&, std::true_type)
        {
        }
1268 
1269 
        // keep_dims counterpart of the no-op overload: selected when the last argument
        // is std::true_type (detail::is_fixed<shape_type>{} in the xreducer constructor)
        // and intentionally does nothing.
        template <class S, class E, class X, class M>
        inline void shape_and_mapping_computation_keep_dim(S&, E&, const X&, M&, std::true_type)
        {
        }
1274     }
1275 
1276     /***************************
1277      * xreducer implementation *
1278      ***************************/
1279 
1280     /**
1281      * @name Constructor
1282      */
1283     //@{
1284     /**
1285      * Constructs an xreducer expression applying the specified
1286      * function to the given expression over the given axes.
1287      *
1288      * @param func the function to apply
1289      * @param e the expression to reduce
1290      * @param axes the axes along which the reduction is performed
1291      */
1292     template <class F, class CT, class X, class O>
1293     template <class Func, class CTA, class AX, class OX>
xreducer(Func && func,CTA && e,AX && axes,OX && options)1294     inline xreducer<F, CT, X, O>::xreducer(Func&& func, CTA&& e, AX&& axes, OX&& options)
1295         : m_e(std::forward<CTA>(e))
1296         , m_reduce(xt::get<0>(func))
1297         , m_init(xt::get<1>(func))
1298         , m_merge(xt::get<2>(func))
1299         , m_axes(std::forward<AX>(axes))
1300         , m_shape(xtl::make_sequence<inner_shape_type>(typename O::keep_dims() ? m_e.dimension() : m_e.dimension() - m_axes.size(), 0))
1301         , m_dim_mapping(xtl::make_sequence<dim_mapping_type>(typename O::keep_dims() ? m_e.dimension() : m_e.dimension() - m_axes.size(), 0))
1302         , m_options(std::forward<OX>(options))
1303     {
1304         // using std::less_equal is counter-intuitive, but as the standard says (24.4.5):
1305         // A sequence is sorted with respect to a comparator comp if for any iterator i pointing to the sequence and any non-negative integer n
1306         // such that i + n is a valid iterator pointing to an element of the sequence, comp(*(i + n), *i) == false.
1307         // Therefore less_equal is required to detect duplicates.
1308         if (!std::is_sorted(m_axes.cbegin(), m_axes.cend(), std::less_equal<>()))
1309         {
1310             XTENSOR_THROW(std::runtime_error, "Reducing axes should be sorted and should not contain duplicates");
1311         }
1312         if (m_axes.size() != 0 && m_axes[m_axes.size() - 1] > m_e.dimension() - 1)
1313         {
1314             XTENSOR_THROW(std::runtime_error,
1315                           "Axis " + std::to_string(m_axes[m_axes.size() - 1]) +
1316                           " out of bounds for reduction.");
1317         }
1318 
1319         if (!typename O::keep_dims())
1320         {
1321             detail::shape_and_mapping_computation(m_shape, m_e, m_axes, m_dim_mapping, detail::is_fixed<shape_type>{});
1322         }
1323         else
1324         {
1325             detail::shape_and_mapping_computation_keep_dim(m_shape, m_e, m_axes, m_dim_mapping, detail::is_fixed<shape_type>{});
1326         }
1327     }
1328     //@}
1329 
    /**
     * @name Size and shape
     */
    //@{
1333 
    /**
     * Returns the shape of the expression, i.e. the shape of the reduced result
     * computed by the constructor.
     */
    template <class F, class CT, class X, class O>
    inline auto xreducer<F, CT, X, O>::shape() const noexcept -> const inner_shape_type&
    {
        return m_shape;
    }
1342 
    /**
     * Returns the layout_type of the expression. A reducer always reports its
     * static layout, which is layout_type::dynamic.
     */
    template <class F, class CT, class X, class O>
    inline layout_type xreducer<F, CT, X, O>::layout() const noexcept
    {
        return static_layout;
    }
1351 
    /**
     * Tells whether the expression is contiguous; always \c false for a reducer.
     */
    template <class F, class CT, class X, class O>
    inline bool xreducer<F, CT, X, O>::is_contiguous() const noexcept
    {
        return false;
    }
1357 
1358     //@}
1359 
    /**
     * @name Data
     */
    //@{
1363     /**
1364      * Returns a constant reference to the element at the specified position in the reducer.
1365      * @param args a list of indices specifying the position in the reducer. Indices
1366      * must be unsigned integers, the number of indices should be equal or greater than
1367      * the number of dimensions of the reducer.
1368      */
1369     template <class F, class CT, class X, class O>
1370     template <class... Args>
operator ()(Args...args) const1371     inline auto xreducer<F, CT, X, O>::operator()(Args... args) const -> const_reference
1372     {
1373         XTENSOR_TRY(check_index(shape(), args...));
1374         XTENSOR_CHECK_DIMENSION(shape(), args...);
1375         std::array<std::size_t, sizeof...(Args)> arg_array = {{static_cast<std::size_t>(args)...}};
1376         return element(arg_array.cbegin(), arg_array.cend());
1377     }
1378 
1379     /**
1380      * Returns a constant reference to the element at the specified position in the reducer.
1381      * @param args a list of indices specifying the position in the reducer. Indices
1382      * must be unsigned integers, the number of indices must be equal to the number of
1383      * dimensions of the reducer, else the behavior is undefined.
1384      *
1385      * @warning This method is meant for performance, for expressions with a dynamic
1386      * number of dimensions (i.e. not known at compile time). Since it may have
1387      * undefined behavior (see parameters), operator() should be prefered whenever
1388      * it is possible.
1389      * @warning This method is NOT compatible with broadcasting, meaning the following
1390      * code has undefined behavior:
1391      * \code{.cpp}
1392      * xt::xarray<double> a = {{0, 1}, {2, 3}};
1393      * xt::xarray<double> b = {0, 1};
1394      * auto fd = a + b;
1395      * double res = fd.uncheked(0, 1);
1396      * \endcode
1397      */
1398     template <class F, class CT, class X, class O>
1399     template <class... Args>
unchecked(Args...args) const1400     inline auto xreducer<F, CT, X, O>::unchecked(Args... args) const -> const_reference
1401     {
1402         std::array<std::size_t, sizeof...(Args)> arg_array = { { static_cast<std::size_t>(args)... } };
1403         return element(arg_array.cbegin(), arg_array.cend());
1404     }
1405 
1406     /**
1407      * Returns a constant reference to the element at the specified position in the reducer.
1408      * @param first iterator starting the sequence of indices
1409      * @param last iterator ending the sequence of indices
1410      * The number of indices in the sequence should be equal to or greater
1411      * than the number of dimensions of the reducer.
1412      */
1413     template <class F, class CT, class X, class O>
1414     template <class It>
element(It first,It last) const1415     inline auto xreducer<F, CT, X, O>::element(It first, It last) const -> const_reference
1416     {
1417         XTENSOR_TRY(check_element_index(shape(), first, last));
1418         auto stepper = const_stepper(*this, 0);
1419         if (first != last)
1420         {
1421             size_type dim = 0;
1422             // drop left most elements
1423             auto size = std::ptrdiff_t(this->dimension()) - std::distance(first, last);
1424             auto begin = first - size;
1425             while (begin != last)
1426             {
1427                 if (begin < first)
1428                 {
1429                     stepper.step(dim++, std::size_t(0));
1430                     begin++;
1431                 }
1432                 else
1433                 {
1434                     stepper.step(dim++, std::size_t(*begin++));
1435                 }
1436             }
1437         }
1438         return *stepper;
1439     }
1440 
    /**
     * Returns a constant reference to the underlying expression of the reducer,
     * i.e. the expression being reduced.
     */
    template <class F, class CT, class X, class O>
    inline auto xreducer<F, CT, X, O>::expression() const noexcept -> const xexpression_type&
    {
        return m_e;
    }
1449     //@}
1450 
1451     /**
1452      * @name Broadcasting
1453      */
1454     //@{
    /**
     * Broadcast the shape of the reducer to the specified parameter.
     * @param shape the result shape
     * @param reuse_cache parameter for internal optimization (not used by this
     * implementation)
     * @return a boolean indicating whether the broadcasting is trivial
     */
    template <class F, class CT, class X, class O>
    template <class S>
    inline bool xreducer<F, CT, X, O>::broadcast_shape(S& shape, bool) const
    {
        return xt::broadcast_shape(m_shape, shape);
    }
1467 
    /**
     * Checks whether the xreducer can be linearly assigned to an expression
     * with the specified strides.
     * @param strides the strides of the target expression (ignored)
     * @return a boolean indicating whether a linear assign is possible;
     * always \c false for a reducer
     */
    template <class F, class CT, class X, class O>
    template <class S>
    inline bool xreducer<F, CT, X, O>::has_linear_assign(const S& /*strides*/) const noexcept
    {
        return false;
    }
1479     //@}
1480 
1481     template <class F, class CT, class X, class O>
1482     template <class S>
stepper_begin(const S & shape) const1483     inline auto xreducer<F, CT, X, O>::stepper_begin(const S& shape) const noexcept -> const_stepper
1484     {
1485         size_type offset = shape.size() - this->dimension();
1486         return const_stepper(*this, offset);
1487     }
1488 
1489     template <class F, class CT, class X, class O>
1490     template <class S>
stepper_end(const S & shape,layout_type l) const1491     inline auto xreducer<F, CT, X, O>::stepper_end(const S& shape, layout_type l) const noexcept -> const_stepper
1492     {
1493         size_type offset = shape.size() - this->dimension();
1494         return const_stepper(*this, offset, true, l);
1495     }
1496 
1497     template <class F, class CT, class X, class O>
1498     template <class E>
build_reducer(E && e) const1499     inline auto xreducer<F, CT, X, O>::build_reducer(E&& e) const -> rebind_t<E>
1500     {
1501         return rebind_t<E>(std::make_tuple(m_reduce, m_init, m_merge), std::forward<E>(e), axes_type(m_axes), m_options);
1502     }
1503 
1504     template <class F, class CT, class X, class O>
1505     template <class E, class Func, class Opts>
build_reducer(E && e,Func && func,Opts && opts) const1506     inline auto xreducer<F, CT, X, O>::build_reducer(E&& e, Func&& func, Opts&& opts) const -> rebind_t<E, Func, Opts>
1507     {
1508         return rebind_t<E, Func, Opts>(std::forward<Func>(func), std::forward<E>(e), axes_type(m_axes), std::forward<Opts>(opts));
1509     }
1510 
1511     /***********************************
1512      * xreducer_stepper implementation *
1513      ***********************************/
1514 
1515     template <class F, class CT, class X, class O>
xreducer_stepper(const xreducer_type & red,size_type offset,bool end,layout_type l)1516     inline xreducer_stepper<F, CT, X, O>::xreducer_stepper(const xreducer_type& red, size_type offset, bool end, layout_type l)
1517         : m_reducer(&red), m_offset(offset),
1518           m_stepper(get_substepper_begin())
1519     {
1520         if (end)
1521         {
1522             to_end(l);
1523         }
1524     }
1525 
1526     template <class F, class CT, class X, class O>
operator *() const1527     inline auto xreducer_stepper<F, CT, X, O>::operator*() const -> reference
1528     {
1529         reference r = aggregate(0);
1530         return r;
1531     }
1532 
1533     template <class F, class CT, class X, class O>
step(size_type dim)1534     inline void xreducer_stepper<F, CT, X, O>::step(size_type dim)
1535     {
1536         if (dim >= m_offset)
1537         {
1538             m_stepper.step(get_dim(dim - m_offset));
1539         }
1540     }
1541 
1542     template <class F, class CT, class X, class O>
step_back(size_type dim)1543     inline void xreducer_stepper<F, CT, X, O>::step_back(size_type dim)
1544     {
1545         if (dim >= m_offset)
1546         {
1547             m_stepper.step_back(get_dim(dim - m_offset));
1548         }
1549     }
1550 
1551     template <class F, class CT, class X, class O>
step(size_type dim,size_type n)1552     inline void xreducer_stepper<F, CT, X, O>::step(size_type dim, size_type n)
1553     {
1554         if (dim >= m_offset)
1555         {
1556             m_stepper.step(get_dim(dim - m_offset), n);
1557         }
1558     }
1559 
1560     template <class F, class CT, class X, class O>
step_back(size_type dim,size_type n)1561     inline void xreducer_stepper<F, CT, X, O>::step_back(size_type dim, size_type n)
1562     {
1563         if (dim >= m_offset)
1564         {
1565             m_stepper.step_back(get_dim(dim - m_offset), n);
1566         }
1567     }
1568 
1569     template <class F, class CT, class X, class O>
reset(size_type dim)1570     inline void xreducer_stepper<F, CT, X, O>::reset(size_type dim)
1571     {
1572         if (dim >= m_offset)
1573         {
1574             // Because the reducer uses `reset` to reset the non-reducing axes,
1575             // we need to prevent that here for the KD case where.
1576             if (typename O::keep_dims() && std::binary_search(m_reducer->m_axes.begin(), m_reducer->m_axes.end(), dim - m_offset))
1577             {
1578                 // If keep dim activated, and dim is in the axes, do nothing!
1579                 return;
1580             }
1581             m_stepper.reset(get_dim(dim - m_offset));
1582         }
1583     }
1584 
1585     template <class F, class CT, class X, class O>
reset_back(size_type dim)1586     inline void xreducer_stepper<F, CT, X, O>::reset_back(size_type dim)
1587     {
1588         if (dim >= m_offset)
1589         {
1590             // Note that for *not* KD this is not going to do anything
1591             if (typename O::keep_dims() && std::binary_search(m_reducer->m_axes.begin(), m_reducer->m_axes.end(), dim - m_offset))
1592             {
1593                 // If keep dim activated, and dim is in the axes, do nothing!
1594                 return;
1595             }
1596             m_stepper.reset_back(get_dim(dim - m_offset));
1597         }
1598     }
1599 
1600     template <class F, class CT, class X, class O>
to_begin()1601     inline void xreducer_stepper<F, CT, X, O>::to_begin()
1602     {
1603         m_stepper.to_begin();
1604     }
1605 
1606     template <class F, class CT, class X, class O>
to_end(layout_type l)1607     inline void xreducer_stepper<F, CT, X, O>::to_end(layout_type l)
1608     {
1609         m_stepper.to_end(l);
1610     }
1611 
1612     template <class F, class CT, class X, class O>
initial_value() const1613     inline auto xreducer_stepper<F, CT, X, O>::initial_value() const -> reference
1614     {
1615         return O::has_initial_value ? m_reducer->m_options.initial_value : static_cast<reference>(m_reducer->m_init());
1616     }
1617 
1618     template <class F, class CT, class X, class O>
aggregate(size_type dim) const1619     inline auto xreducer_stepper<F, CT, X, O>::aggregate(size_type dim) const -> reference
1620     {
1621         reference res;
1622         if (m_reducer->m_e.size() == size_type(0))
1623         {
1624             res = initial_value();
1625         }
1626         else if (m_reducer->m_e.shape().empty() || m_reducer->m_axes.size() == 0)
1627         {
1628             res = m_reducer->m_reduce(initial_value(), *m_stepper);
1629         }
1630         else
1631         {
1632             res = aggregate_impl(dim, typename O::keep_dims());
1633             if (O::has_initial_value && dim == 0)
1634             {
1635                 res = m_reducer->m_merge(m_reducer->m_options.initial_value, res);
1636             }
1637         }
1638         return res;
1639     }
1640 
1641     template <class F, class CT, class X, class O>
aggregate_impl(size_type dim,std::false_type) const1642     inline auto xreducer_stepper<F, CT, X, O>::aggregate_impl(size_type dim, std::false_type) const -> reference
1643     {
1644         // reference can be std::array, hence the {} initializer
1645         reference res = {};
1646         size_type index = axis(dim);
1647         size_type size = shape(index);
1648         if (dim != m_reducer->m_axes.size() - 1)
1649         {
1650             res = aggregate_impl(dim + 1, typename O::keep_dims());
1651             for (size_type i = 1; i != size; ++i)
1652             {
1653                 m_stepper.step(index);
1654                 res = m_reducer->m_merge(res, aggregate_impl(dim + 1, typename O::keep_dims()));
1655             }
1656         }
1657         else
1658         {
1659             res = static_cast<reference>(m_reducer->m_init());
1660             for (size_type i = 0; i != size; ++i, m_stepper.step(index))
1661             {
1662                 res = m_reducer->m_reduce(res, *m_stepper);
1663             }
1664             m_stepper.step_back(index);
1665         }
1666         m_stepper.reset(index);
1667         return res;
1668     }
1669 
1670     template <class F, class CT, class X, class O>
aggregate_impl(size_type dim,std::true_type) const1671     inline auto xreducer_stepper<F, CT, X, O>::aggregate_impl(size_type dim, std::true_type) const -> reference
1672     {
1673         // reference can be std::array, hence the {} initializer
1674         reference res = {};
1675         auto ax_it = std::find(m_reducer->m_axes.begin(), m_reducer->m_axes.end(), dim);
1676         if (ax_it != m_reducer->m_axes.end())
1677         {
1678             size_type index = dim;
1679             size_type size = m_reducer->m_e.shape()[index];
1680             if (ax_it != m_reducer->m_axes.end() - 1 && size != 0)
1681             {
1682                 res = aggregate_impl(dim + 1, typename O::keep_dims());
1683                 for (size_type i = 1; i != size; ++i)
1684                 {
1685                     m_stepper.step(index);
1686                     res = m_reducer->m_merge(res, aggregate_impl(dim + 1, typename O::keep_dims()));
1687                 }
1688             }
1689             else
1690             {
1691                 res = m_reducer->m_init();
1692                 for (size_type i = 0; i != size; ++i, m_stepper.step(index))
1693                 {
1694                     res = m_reducer->m_reduce(res, *m_stepper);
1695                 }
1696                 m_stepper.step_back(index);
1697             }
1698             m_stepper.reset(index);
1699         }
1700         else
1701         {
1702             if  (dim < m_reducer->m_e.dimension())
1703             {
1704                 res = aggregate_impl(dim + 1, typename O::keep_dims());
1705             }
1706         }
1707         return res;
1708     }
1709 
1710 
1711     template <class F, class CT, class X, class O>
get_substepper_begin() const1712     inline auto xreducer_stepper<F, CT, X, O>::get_substepper_begin() const -> substepper_type
1713     {
1714         return m_reducer->m_e.stepper_begin(m_reducer->m_e.shape());
1715     }
1716 
1717     template <class F, class CT, class X, class O>
get_dim(size_type dim) const1718     inline auto xreducer_stepper<F, CT, X, O>::get_dim(size_type dim) const noexcept -> size_type
1719     {
1720         return m_reducer->m_dim_mapping[dim];
1721     }
1722 
1723     template <class F, class CT, class X, class O>
shape(size_type i) const1724     inline auto xreducer_stepper<F, CT, X, O>::shape(size_type i) const noexcept -> size_type
1725     {
1726         return m_reducer->m_e.shape()[i];
1727     }
1728 
1729     template <class F, class CT, class X, class O>
axis(size_type i) const1730     inline auto xreducer_stepper<F, CT, X, O>::axis(size_type i) const noexcept -> size_type
1731     {
1732         return m_reducer->m_axes[i];
1733     }
1734 }
1735 
1736 #endif
1737