1 /***************************************************************************
2 * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht          *
3 * Copyright (c) QuantStack                                                 *
4 *                                                                          *
5 * Distributed under the terms of the BSD 3-Clause License.                 *
6 *                                                                          *
7 * The full license is in the file LICENSE, distributed with this software. *
8 ****************************************************************************/
13 #include <algorithm>
14 #include <cstddef>
15 #include <numeric>
16 #include <tuple>
17 #include <type_traits>
18 #include <utility>
20 #include <xtl/xsequence.hpp>
22 #include "xaccessible.hpp"
23 #include "xexpression.hpp"
24 #include "xiterable.hpp"
25 #include "xstrides.hpp"
26 #include "xutils.hpp"
27 #include "xstrided_view.hpp"
29 namespace xt
30 {
32     /************************
33      * xgenerator extension *
34      ************************/
36     namespace extension
37     {
38         template <class Tag, class F, class R, class S>
39         struct xgenerator_base_impl;
41         template <class F, class R, class S>
42         struct xgenerator_base_impl<xtensor_expression_tag, F, R, S>
43         {
44             using type = xtensor_empty_base;
45         };
47         template <class F, class R, class S>
48         struct xgenerator_base : xgenerator_base_impl<xexpression_tag_t<R>, F, R, S>
49         {
50         };
52         template <class F, class R, class S>
53         using xgenerator_base_t = typename xgenerator_base<F, R, S>::type;
54     }
56     /**************
57      * xgenerator *
58      **************/
60     template <class F, class R, class S>
61     class xgenerator;
63     template <class C, class R, class S>
64     struct xiterable_inner_types<xgenerator<C, R, S>>
65     {
66         using inner_shape_type = S;
67         using const_stepper = xindexed_stepper<xgenerator<C, R, S>, true>;
68         using stepper = const_stepper;
69     };
71     template <class C, class R, class S>
72     struct xcontainer_inner_types<xgenerator<C, R, S>>
73     {
74         using reference = R;
75         using const_reference = R;
76         using size_type = std::size_t;
77     };
79     /**
80      * @class xgenerator
81      * @brief Multidimensional function operating on indices.
82      *
83      * The xgenerator class implements a multidimensional function,
84      * generating a value from the supplied indices.
85      *
86      * @tparam F the function type
87      * @tparam R the return type of the function
88      * @tparam S the shape type of the generator
89      */
90     template <class F, class R, class S>
91     class xgenerator : public xsharable_expression<xgenerator<F, R, S>>,
92                        public xconst_iterable<xgenerator<F, R, S>>,
93                        public xconst_accessible<xgenerator<F, R, S>>,
94                        public extension::xgenerator_base_t<F, R, S>
95     {
96     public:
98         using self_type = xgenerator<F, R, S>;
99         using functor_type = typename std::remove_reference<F>::type;
101         using accessible_base = xconst_accessible<self_type>;
102         using extension_base = extension::xgenerator_base_t<F, R, S>;
103         using expression_tag = typename extension_base::expression_tag;
105         using inner_types = xcontainer_inner_types<self_type>;
106         using value_type = R;
107         using reference = typename inner_types::reference;
108         using const_reference = typename inner_types::const_reference;
109         using pointer = value_type*;
110         using const_pointer = const value_type*;
111         using size_type = typename inner_types::size_type;
112         using difference_type = std::ptrdiff_t;
114         using iterable_base = xconst_iterable<self_type>;
115         using inner_shape_type = typename iterable_base::inner_shape_type;
116         using shape_type = inner_shape_type;
118         using stepper = typename iterable_base::stepper;
119         using const_stepper = typename iterable_base::const_stepper;
121         using bool_load_type = xt::bool_load_type<R>;
123         static constexpr layout_type static_layout = layout_type::dynamic;
124         static constexpr bool contiguous_layout = false;
126         template <class Func>
127         xgenerator(Func&& f, const S& shape) noexcept;
129         const inner_shape_type& shape() const noexcept;
130         layout_type layout() const noexcept;
131         bool is_contiguous() const noexcept;
132         using accessible_base::shape;
134         template <class... Args>
135         const_reference operator()(Args... args) const;
136         template <class... Args>
137         const_reference unchecked(Args... args) const;
139         template <class It>
140         const_reference element(It first, It last) const;
142         template <class O>
143         bool broadcast_shape(O& shape, bool reuse_cache = false) const;
145         template <class O>
146         bool has_linear_assign(const O& /*strides*/) const noexcept;
148         template <class O>
149         const_stepper stepper_begin(const O& shape) const noexcept;
150         template <class O>
151         const_stepper stepper_end(const O& shape, layout_type) const noexcept;
153         template <class E, class FE = F, class = std::enable_if_t<has_assign_to<E, FE>::value>>
154         void assign_to(xexpression<E>& e) const noexcept;
156         const functor_type& functor() const noexcept;
158         template <class OR, class OF>
159         using rebind_t = xgenerator<OF, OR, S>;
161         template <class OR, class OF>
162         rebind_t<OR, OF> build_generator(OF&& func) const;
164         template <class O = xt::dynamic_shape<typename shape_type::value_type>>
165         auto reshape(O&& shape) const &;
167         template <class O = xt::dynamic_shape<typename shape_type::value_type>>
168         auto reshape(O&& shape) &&;
170         template <class T>
171         auto reshape(std::initializer_list<T> shape) const &;
173         template <class T>
174         auto reshape(std::initializer_list<T> shape) &&;
176     private:
178         template <class O>
179         decltype(auto) compute_shape(O&& shape, std::false_type /*signed*/) const;
181         template <class O>
182         auto compute_shape(O&& shape, std::true_type /*signed*/) const;
184         template <class T>
185         auto compute_shape(std::initializer_list<T> shape) const;
187         template <std::size_t dim>
188         void adapt_index() const;
190         template <std::size_t dim, class I, class... Args>
191         void adapt_index(I& arg, Args&... args) const;
193         functor_type m_f;
194         inner_shape_type m_shape;
195     };
197     /*****************************
198      * xgenerator implementation *
199      *****************************/
201     /**
202      * @name Constructor
203      */
204     //@{
205     /**
206      * Constructs an xgenerator applying the specified function over the
207      * given shape.
208      * @param f the function to apply
209      * @param shape the shape of the xgenerator
210      */
211     template <class F, class R, class S>
212     template <class Func>
xgenerator(Func && f,const S & shape)213     inline xgenerator<F, R, S>::xgenerator(Func&& f, const S& shape) noexcept
214         : m_f(std::forward<Func>(f)), m_shape(shape)
215     {
216     }
217     //@}
219     /**
220      * @name Size and shape
221      */
222     //@{
223     /**
224      * Returns the shape of the xgenerator.
225      */
226     template <class F, class R, class S>
shape() const227     inline auto xgenerator<F, R, S>::shape() const noexcept -> const inner_shape_type&
228     {
229         return m_shape;
230     }
232     template <class F, class R, class S>
layout() const233     inline layout_type xgenerator<F, R, S>::layout() const noexcept
234     {
235         return static_layout;
236     }
238     template <class F, class R, class S>
is_contiguous() const239     inline bool xgenerator<F, R, S>::is_contiguous() const noexcept
240     {
241         return false;
242     }
244     //@}
    /**
     * @name Data
     */
    //@{
249     /**
250      * Returns the evaluated element at the specified position in the function.
251      * @param args a list of indices specifying the position in the function. Indices
252      * must be unsigned integers, the number of indices should be equal or greater than
253      * the number of dimensions of the function.
254      */
255     template <class F, class R, class S>
256     template <class... Args>
operator ()(Args...args) const257     inline auto xgenerator<F, R, S>::operator()(Args... args) const -> const_reference
258     {
259         XTENSOR_TRY(check_index(shape(), args...));
260         adapt_index<0>(args...);
261         return m_f(args...);
262     }
264     /**
265      * Returns a constant reference to the element at the specified position in the expression.
266      * @param args a list of indices specifying the position in the expression. Indices
267      * must be unsigned integers, the number of indices must be equal to the number of
268      * dimensions of the expression, else the behavior is undefined.
269      *
270      * @warning This method is meant for performance, for expressions with a dynamic
271      * number of dimensions (i.e. not known at compile time). Since it may have
272      * undefined behavior (see parameters), operator() should be prefered whenever
273      * it is possible.
274      * @warning This method is NOT compatible with broadcasting, meaning the following
275      *  code has undefined behavior:
276      * \code{.cpp}
277      * xt::xarray<double> a = {{0, 1}, {2, 3}};
278      * xt::xarray<double> b = {0, 1};
279      * auto fd = a + b;
280      * double res = fd.uncheked(0, 1);
281      * \endcode
282      */
283     template <class F, class R, class S>
284     template <class... Args>
unchecked(Args...args) const285     inline auto xgenerator<F, R, S>::unchecked(Args... args) const -> const_reference
286     {
287         return m_f(args...);
288     }
290     /**
291      * Returns a constant reference to the element at the specified position in the function.
292      * @param first iterator starting the sequence of indices
293      * @param last iterator ending the sequence of indices
294      * The number of indices in the sequence should be equal to or greater
295      * than the number of dimensions of the container.
296      */
297     template <class F, class R, class S>
298     template <class It>
element(It first,It last) const299     inline auto xgenerator<F, R, S>::element(It first, It last) const -> const_reference
300     {
301         using bounded_iterator = xbounded_iterator<It, typename shape_type::const_iterator>;
302         XTENSOR_TRY(check_element_index(shape(), first, last));
303         return m_f.element(bounded_iterator(first, shape().cbegin()), bounded_iterator(last, shape().cend()));
304     }
305     //@}
307     /**
308      * @name Broadcasting
309      */
310     //@{
311     /**
312      * Broadcast the shape of the function to the specified parameter.
313      * @param shape the result shape
314      * @param reuse_cache parameter for internal optimization
315      * @return a boolean indicating whether the broadcasting is trivial
316      */
317     template <class F, class R, class S>
318     template <class O>
broadcast_shape(O & shape,bool) const319     inline bool xgenerator<F, R, S>::broadcast_shape(O& shape, bool) const
320     {
321         return xt::broadcast_shape(m_shape, shape);
322     }
324     /**
325      * Checks whether the xgenerator can be linearly assigned to an expression
326      * with the specified strides.
327      * @return a boolean indicating whether a linear assign is possible
328      */
329     template <class F, class R, class S>
330     template <class O>
has_linear_assign(const O &) const331     inline bool xgenerator<F, R, S>::has_linear_assign(const O& /*strides*/) const noexcept
332     {
333         return false;
334     }
335     //@}
337     template <class F, class R, class S>
338     template <class O>
stepper_begin(const O & shape) const339     inline auto xgenerator<F, R, S>::stepper_begin(const O& shape) const noexcept -> const_stepper
340     {
341         size_type offset = shape.size() - this->dimension();
342         return const_stepper(this, offset);
343     }
345     template <class F, class R, class S>
346     template <class O>
stepper_end(const O & shape,layout_type) const347     inline auto xgenerator<F, R, S>::stepper_end(const O& shape, layout_type) const noexcept -> const_stepper
348     {
349         size_type offset = shape.size() - this->dimension();
350         return const_stepper(this, offset, true);
351     }
353     template <class F, class R, class S>
354     template <class E, class, class>
assign_to(xexpression<E> & e) const355     inline void xgenerator<F, R, S>::assign_to(xexpression<E>& e) const noexcept
356     {
357         e.derived_cast().resize(m_shape);
358         m_f.assign_to(e);
359     }
361     template <class F, class R, class S>
functor() const362     inline auto xgenerator<F, R, S>::functor() const noexcept -> const functor_type&
363     {
364         return m_f;
365     }
367     template <class F, class R, class S>
368     template <class OR, class OF>
build_generator(OF && func) const369     inline auto xgenerator<F, R, S>::build_generator(OF&& func) const -> rebind_t<OR, OF>
370     {
371         return rebind_t<OR, OF>(std::move(func), shape_type(m_shape));
372     }
374     /**
375      * Reshapes the generator and keeps old elements. The `shape` argument can have one of its value
376      * equal to `-1`, in this case the value is inferred from the number of elements in the generator
377      * and the remaining values in the `shape`.
378      * \code{.cpp}
379      * auto a = xt::arange<double>(50).reshape({-1, 10});
380      * //a.shape() is {5, 10}
381      * \endcode
382      * @param shape the new shape (has to have same number of elements as the original genrator)
383      */
384     template <class F, class R, class S>
385     template <class O>
reshape(O && shape) const386     inline auto xgenerator<F, R, S>::reshape(O&& shape) const &
387     {
388         return reshape_view(*this, compute_shape(shape, xtl::is_signed<typename std::decay_t<O>::value_type>()));
389     }
391     template <class F, class R, class S>
392     template <class O>
reshape(O && shape)393     inline auto xgenerator<F, R, S>::reshape(O&& shape) &&
394     {
395         return reshape_view(std::move(*this), compute_shape(shape, xtl::is_signed<typename std::decay_t<O>::value_type>()));
396     }
398     template <class F, class R, class S>
399     template <class T>
reshape(std::initializer_list<T> shape) const400     inline auto xgenerator<F, R, S>::reshape(std::initializer_list<T> shape) const &
401     {
402         return reshape_view(*this, compute_shape(shape));
403     }
405     template <class F, class R, class S>
406     template <class T>
reshape(std::initializer_list<T> shape)407     inline auto xgenerator<F, R, S>::reshape(std::initializer_list<T> shape) &&
408     {
409         return reshape_view(std::move(*this), compute_shape(shape));
410     }
412     template <class F, class R, class S>
413     template <class O>
compute_shape(O && shape,std::false_type) const414     inline decltype(auto) xgenerator<F, R, S>::compute_shape(O&& shape, std::false_type) const
415     {
416         return xtl::forward_sequence<xt::dynamic_shape<typename shape_type::value_type>, O>(shape);
417     }
419     template <class F, class R, class S>
420     template <class O>
compute_shape(O && shape,std::true_type) const421     inline auto xgenerator<F, R, S>::compute_shape(O&& shape, std::true_type) const
422     {
423         using vtype = typename shape_type::value_type;
424         xt::dynamic_shape<vtype> sh(shape.size());
425         using int_type = typename std::decay_t<O>::value_type;
426         int_type accumulator(1);
427         std::size_t neg_idx = 0;
428         std::size_t i = 0;
429         for(std::size_t j = 0; j != shape.size(); ++j, ++i)
430         {
431             auto dim = shape[j];
432             if(dim < 0)
433             {
434                 XTENSOR_ASSERT(dim == -1 && !neg_idx);
435                 neg_idx = i;
436             }
437             else
438             {
439                 sh[j] = static_cast<vtype>(dim);
440             }
441             accumulator *= dim;
442         }
443         if(accumulator < 0)
444         {
445             sh[neg_idx] = this->size() / static_cast<size_type>(std::make_unsigned_t<int_type>(std::abs(accumulator)));
446         }
447         return sh;
448     }
450     template <class F, class R, class S>
451     template <class T>
compute_shape(std::initializer_list<T> shape) const452     inline auto xgenerator<F, R, S>::compute_shape(std::initializer_list<T> shape) const
453     {
454         using sh_type = xt::dynamic_shape<T>;
455         sh_type sh = xtl::make_sequence<sh_type>(shape.size());
456         std::copy(shape.begin(), shape.end(), sh.begin());
457         return compute_shape(std::move(sh), xtl::is_signed<T>());
458     }
460     template <class F, class R, class S>
461     template <std::size_t dim>
adapt_index() const462     inline void xgenerator<F, R, S>::adapt_index() const
463     {
464     }
466     template <class F, class R, class S>
467     template <std::size_t dim, class I, class... Args>
adapt_index(I & arg,Args &...args) const468     inline void xgenerator<F, R, S>::adapt_index(I& arg, Args&... args) const
469     {
470         using tmp_value_type = typename decltype(m_shape)::value_type;
471         if (sizeof...(Args) + 1 > m_shape.size())
472         {
473             adapt_index<dim>(args...);
474         }
475         else
476         {
477             if (static_cast<tmp_value_type>(arg) >= m_shape[dim] && m_shape[dim] == 1)
478             {
479                 arg = 0;
480             }
481             adapt_index<dim + 1>(args...);
482         }
483     }
485     namespace detail
486     {
487         template <class Functor, class I, std::size_t L>
make_xgenerator(Functor && f,const I (& shape)[L])488         inline auto make_xgenerator(Functor&& f, const I (&shape)[L]) noexcept
489         {
490             using shape_type = std::array<std::size_t, L>;
491             using type = xgenerator<Functor, typename Functor::value_type, shape_type>;
492             return type(std::forward<Functor>(f), xtl::forward_sequence<shape_type, decltype(shape)>(shape));
493         }
495         template <class Functor, class S>
make_xgenerator(Functor && f,S && shape)496         inline auto make_xgenerator(Functor&& f, S&& shape) noexcept
497         {
498             using type = xgenerator<Functor, typename Functor::value_type, std::decay_t<S>>;
499             return type(std::forward<Functor>(f), std::forward<S>(shape));
500         }
501     }
502 }
504 #endif