1 /***************************************************************************
2 * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht          *
3 * Copyright (c) QuantStack                                                 *
4 *                                                                          *
5 * Distributed under the terms of the BSD 3-Clause License.                 *
6 *                                                                          *
7 * The full license is in the file LICENSE, distributed with this software. *
8 ****************************************************************************/
9 
10 #ifndef XTENSOR_BROADCAST_HPP
11 #define XTENSOR_BROADCAST_HPP
12 
13 #include <algorithm>
14 #include <array>
15 #include <cstddef>
16 #include <iterator>
17 #include <numeric>
18 #include <type_traits>
19 #include <utility>
20 
21 #include <xtl/xsequence.hpp>
22 
23 #include "xaccessible.hpp"
24 #include "xexpression.hpp"
25 #include "xiterable.hpp"
26 #include "xscalar.hpp"
27 #include "xstrides.hpp"
28 #include "xtensor_config.hpp"
29 #include "xutils.hpp"
30 
31 namespace xt
32 {
33 
34     /*************
35      * broadcast *
36      *************/
37 
    // Forward declaration: broadcasts the expression e to the shape given as
    // a generic sequence s. The definition appears further down in this file.
    template <class E, class S>
    auto broadcast(E&& e, const S& s);

    // Forward declaration: overload taking the shape as a C array of L
    // extents. The definition appears further down in this file.
    template <class E, class I, std::size_t L>
    auto broadcast(E&& e, const I (&s)[L]);
43 
44     /*************************
45      * xbroadcast extensions *
46      *************************/
47 
    namespace extension
    {
        // Primary template, declared but not defined here: maps an expression
        // tag to the extension base class used by xbroadcast. Only the
        // xtensor_expression_tag specialization below is provided in this file.
        template <class Tag, class CT, class X>
        struct xbroadcast_base_impl;

        // Expressions tagged with xtensor_expression_tag use xtensor_empty_base.
        template <class CT, class X>
        struct xbroadcast_base_impl<xtensor_expression_tag, CT, X>
        {
            using type = xtensor_empty_base;
        };

        // Selects the implementation from the expression tag of CT
        // (xexpression_tag_t<CT>).
        template <class CT, class X>
        struct xbroadcast_base
            : xbroadcast_base_impl<xexpression_tag_t<CT>, CT, X>
        {
        };

        // Shorthand for the base class type selected by xbroadcast_base.
        template <class CT, class X>
        using xbroadcast_base_t = typename xbroadcast_base<CT, X>::type;
    }
68 
69     /**************
70      * xbroadcast *
71      **************/
72 
    // Forward declaration: the inner-types specializations and the
    // linear_begin / linear_end overloads below refer to xbroadcast before the
    // class itself is defined.
    template <class CT, class X>
    class xbroadcast;
75 
    // Iteration-related types of xbroadcast<CT, X>.
    template <class CT, class X>
    struct xiterable_inner_types<xbroadcast<CT, X>>
    {
        // Type of the broadcast expression, with references and cv-qualifiers removed.
        using xexpression_type = std::decay_t<CT>;
        // Shape type obtained by promoting the expression's shape type with X.
        using inner_shape_type = promote_shape_t<typename xexpression_type::shape_type, X>;
        // Steppers are taken from the underlying expression; the mutable
        // stepper is an alias of the const one.
        using const_stepper = typename xexpression_type::const_stepper;
        using stepper = const_stepper;
    };
84 
    // Element-access related types of xbroadcast<CT, X>.
    template <class CT, class X>
    struct xcontainer_inner_types<xbroadcast<CT, X>>
    {
        // Type of the broadcast expression, with references and cv-qualifiers removed.
        using xexpression_type = std::decay_t<CT>;
        // Both reference types map to the const_reference of the underlying
        // expression.
        using reference = typename xexpression_type::const_reference;
        using const_reference = typename xexpression_type::const_reference;
        using size_type = typename xexpression_type::size_type;
    };
93 
94     /*****************************
95      * linear_begin / linear_end *
96      *****************************/
97 
98     template <class CT, class X>
linear_begin(xbroadcast<CT,X> & c)99     XTENSOR_CONSTEXPR_RETURN auto linear_begin(xbroadcast<CT, X>& c) noexcept
100     {
101         return linear_begin(c.expression());
102     }
103 
104     template <class CT, class X>
linear_end(xbroadcast<CT,X> & c)105     XTENSOR_CONSTEXPR_RETURN auto linear_end(xbroadcast<CT, X>& c) noexcept
106     {
107         return linear_end(c.expression());
108     }
109 
110     template <class CT, class X>
linear_begin(const xbroadcast<CT,X> & c)111     XTENSOR_CONSTEXPR_RETURN auto linear_begin(const xbroadcast<CT, X>& c) noexcept
112     {
113         return linear_begin(c.expression());
114     }
115 
116     template <class CT, class X>
linear_end(const xbroadcast<CT,X> & c)117     XTENSOR_CONSTEXPR_RETURN auto linear_end(const xbroadcast<CT, X>& c) noexcept
118     {
119         return linear_end(c.expression());
120     }
121 
    /**
     * @class xbroadcast
     * @brief Broadcasted xexpression to a specified shape.
     *
     * The xbroadcast class implements the broadcasting of an \ref xexpression
     * to a specified shape. xbroadcast is not meant to be used directly, but
     * only with the \ref broadcast helper functions.
     *
     * @tparam CT the closure type of the \ref xexpression to broadcast
     * @tparam X the type of the specified shape.
     *
     * @sa broadcast
     */
    template <class CT, class X>
    class xbroadcast : public xsharable_expression<xbroadcast<CT, X>>,
                       public xconst_iterable<xbroadcast<CT, X>>,
                       public xconst_accessible<xbroadcast<CT, X>>,
                       public extension::xbroadcast_base_t<CT, X>
    {
    public:

        using self_type = xbroadcast<CT, X>;
        using xexpression_type = std::decay_t<CT>;
        using accessible_base = xconst_accessible<self_type>;
        using extension_base = extension::xbroadcast_base_t<CT, X>;
        using expression_tag = typename extension_base::expression_tag;

        // Element types: everything is read-only, so the mutable and const
        // flavours of reference and pointer resolve to the const ones.
        using inner_types = xcontainer_inner_types<self_type>;
        using value_type = typename xexpression_type::value_type;
        using reference = typename inner_types::reference;
        using const_reference = typename inner_types::const_reference;
        using pointer = typename xexpression_type::const_pointer;
        using const_pointer = typename xexpression_type::const_pointer;
        using size_type = typename inner_types::size_type;
        using difference_type = typename xexpression_type::difference_type;

        // Shape types: the public shape_type is the promoted inner shape type.
        using iterable_base = xconst_iterable<self_type>;
        using inner_shape_type = typename iterable_base::inner_shape_type;
        using shape_type = inner_shape_type;

        using stepper = typename iterable_base::stepper;
        using const_stepper = typename iterable_base::const_stepper;

        using bool_load_type = typename xexpression_type::bool_load_type;

        // The layout is only known at runtime and the expression is never
        // reported as contiguous.
        static constexpr layout_type static_layout = layout_type::dynamic;
        static constexpr bool contiguous_layout = false;

        // Builds the broadcast from an arbitrary shape sequence (copied).
        template <class CTA, class S>
        xbroadcast(CTA&& e, const S& s);

        // Builds the broadcast by taking ownership of an rvalue shape_type.
        template <class CTA>
        xbroadcast(CTA&& e, shape_type&& s);

        using accessible_base::size;
        const inner_shape_type& shape() const noexcept;
        layout_type layout() const noexcept;
        bool is_contiguous() const noexcept;
        // Keeps the shape overloads of the accessible base visible next to
        // the shape() declared above.
        using accessible_base::shape;

        template <class... Args>
        const_reference operator()(Args... args) const;

        template <class... Args>
        const_reference unchecked(Args... args) const;

        template <class It>
        const_reference element(It first, It last) const;

        // Read-only access to the wrapped expression.
        const xexpression_type& expression() const noexcept;

        template <class S>
        bool broadcast_shape(S& shape, bool reuse_cache = false) const;

        template <class S>
        bool has_linear_assign(const S& strides) const noexcept;

        template <class S>
        const_stepper stepper_begin(const S& shape) const noexcept;
        template <class S>
        const_stepper stepper_end(const S& shape, layout_type l) const noexcept;

        // Only participates in overload resolution when CT is an xscalar
        // (SFINAE on xt::is_xscalar<XCT>).
        template <class E, class XCT = CT, class = std::enable_if_t<xt::is_xscalar<XCT>::value>>
        void assign_to(xexpression<E>& e) const;

        // xbroadcast over another closure type E with the same shape type X.
        template <class E>
        using rebind_t = xbroadcast<E, X>;

        template <class E>
        rebind_t<E> build_broadcast(E&& e) const;

    private:

        // Closure on the broadcast expression.
        CT m_e;
        // Shape of the broadcast result.
        inner_shape_type m_shape;
    };
218 
219     /****************************
220      * broadcast implementation *
221      ****************************/
222 
223     /**
224      * @brief Returns an \ref xexpression broadcasting the given expression to
225      * a specified shape.
226      *
227      * @tparam e the \ref xexpression to broadcast
228      * @tparam s the specified shape to broadcast.
229      *
230      * The returned expression either hold a const reference to \p e or a copy
231      * depending on whether \p e is an lvalue or an rvalue.
232      */
233     template <class E, class S>
broadcast(E && e,const S & s)234     inline auto broadcast(E&& e, const S& s)
235     {
236         using shape_type = filter_fixed_shape_t<std::decay_t<S>>;
237         using broadcast_type = xbroadcast<const_xclosure_t<E>, shape_type>;
238         return broadcast_type(std::forward<E>(e), xtl::forward_sequence<shape_type, decltype(s)>(s));
239     }
240 
241     template <class E, class I, std::size_t L>
broadcast(E && e,const I (& s)[L])242     inline auto broadcast(E&& e, const I (&s)[L])
243     {
244         using broadcast_type = xbroadcast<const_xclosure_t<E>, std::array<std::size_t, L>>;
245         using shape_type = typename broadcast_type::shape_type;
246         return broadcast_type(std::forward<E>(e), xtl::forward_sequence<shape_type, decltype(s)>(s));
247     }
248 
249     /*****************************
250      * xbroadcast implementation *
251      *****************************/
252 
253     /**
254      * @name Constructor
255      */
256     //@{
257     /**
258      * Constructs an xbroadcast expression broadcasting the specified
259      * \ref xexpression to the given shape
260      *
261      * @param e the expression to broadcast
262      * @param s the shape to apply
263      */
264     template <class CT, class X>
265     template <class CTA, class S>
xbroadcast(CTA && e,const S & s)266     inline xbroadcast<CT, X>::xbroadcast(CTA&& e, const S& s)
267         : m_e(std::forward<CTA>(e))
268     {
269         if (s.size() < m_e.dimension())
270         {
271             XTENSOR_THROW(xt::broadcast_error, "Broadcast shape has fewer elements than original expression.");
272         }
273         xt::resize_container(m_shape, s.size());
274         std::copy(s.begin(), s.end(), m_shape.begin());
275         xt::broadcast_shape(m_e.shape(), m_shape);
276     }
277 
278     /**
279      * Constructs an xbroadcast expression broadcasting the specified
280      * \ref xexpression to the given shape
281      *
282      * @param e the expression to broadcast
283      * @param s the shape to apply
284      */
285     template <class CT, class X>
286     template <class CTA>
xbroadcast(CTA && e,shape_type && s)287     inline xbroadcast<CT, X>::xbroadcast(CTA&& e, shape_type&& s)
288         : m_e(std::forward<CTA>(e)), m_shape(std::move(s))
289     {
290         xt::broadcast_shape(m_e.shape(), m_shape);
291     }
292     //@}
293 
294     /**
295      * @name Size and shape
296      */
297     //@{
298     /**
299      * Returns the shape of the expression.
300      */
301     template <class CT, class X>
shape() const302     inline auto xbroadcast<CT, X>::shape() const noexcept -> const inner_shape_type&
303     {
304         return m_shape;
305     }
306 
307     /**
308      * Returns the layout_type of the expression.
309      */
310     template <class CT, class X>
layout() const311     inline layout_type xbroadcast<CT, X>::layout() const noexcept
312     {
313         return m_e.layout();
314     }
315 
316     template <class CT, class X>
is_contiguous() const317     inline bool xbroadcast<CT, X>::is_contiguous() const noexcept
318     {
319         return false;
320     }
321 
322     //@}
323 
324     /**
325      * @name Data
326      */
327     //@{
328     /**
329      * Returns a constant reference to the element at the specified position in the expression.
330      * @param args a list of indices specifying the position in the function. Indices
331      * must be unsigned integers, the number of indices should be equal or greater than
332      * the number of dimensions of the expression.
333      */
334     template <class CT, class X>
335     template <class... Args>
operator ()(Args...args) const336     inline auto xbroadcast<CT, X>::operator()(Args... args) const -> const_reference
337     {
338         return m_e(args...);
339     }
340 
341     /**
342      * Returns a constant reference to the element at the specified position in the expression.
343      * @param args a list of indices specifying the position in the expression. Indices
344      * must be unsigned integers, the number of indices must be equal to the number of
345      * dimensions of the expression, else the behavior is undefined.
346      *
347      * @warning This method is meant for performance, for expressions with a dynamic
348      * number of dimensions (i.e. not known at compile time). Since it may have
349      * undefined behavior (see parameters), operator() should be prefered whenever
350      * it is possible.
351      * @warning This method is NOT compatible with broadcasting, meaning the following
352      * code has undefined behavior:
353      * \code{.cpp}
354      * xt::xarray<double> a = {{0, 1}, {2, 3}};
355      * xt::xarray<double> b = {0, 1};
356      * auto fd = a + b;
357      * double res = fd.uncheked(0, 1);
358      * \endcode
359      */
360     template <class CT, class X>
361     template <class... Args>
unchecked(Args...args) const362     inline auto xbroadcast<CT, X>::unchecked(Args... args) const -> const_reference
363     {
364         return this->operator()(args...);
365     }
366 
367     /**
368      * Returns a constant reference to the element at the specified position in the expression.
369      * @param first iterator starting the sequence of indices
370      * @param last iterator ending the sequence of indices
371      * The number of indices in the sequence should be equal to or greater
372      * than the number of dimensions of the function.
373      */
374     template <class CT, class X>
375     template <class It>
element(It,It last) const376     inline auto xbroadcast<CT, X>::element(It, It last) const -> const_reference
377     {
378         return m_e.element(last - this->dimension(), last);
379     }
380 
381     /**
382      * Returns a constant reference to the underlying expression of the broadcast expression.
383      */
384     template <class CT, class X>
expression() const385     inline auto xbroadcast<CT, X>::expression() const noexcept -> const xexpression_type&
386     {
387         return m_e;
388     }
389     //@}
390 
391     /**
392      * @name Broadcasting
393      */
394     //@{
395     /**
396      * Broadcast the shape of the function to the specified parameter.
397      * @param shape the result shape
398      * @param reuse_cache parameter for internal optimization
399      * @return a boolean indicating whether the broadcasting is trivial
400      */
401     template <class CT, class X>
402     template <class S>
broadcast_shape(S & shape,bool) const403     inline bool xbroadcast<CT, X>::broadcast_shape(S& shape, bool) const
404     {
405         return xt::broadcast_shape(m_shape, shape);
406     }
407 
408     /**
409      * Checks whether the xbroadcast can be linearly assigned to an expression
410      * with the specified strides.
411      * @return a boolean indicating whether a linear assign is possible
412      */
413     template <class CT, class X>
414     template <class S>
has_linear_assign(const S & strides) const415     inline bool xbroadcast<CT, X>::has_linear_assign(const S& strides) const noexcept
416     {
417         return this->dimension() == m_e.dimension() &&
418             std::equal(m_shape.cbegin(), m_shape.cend(), m_e.shape().cbegin()) &&
419             m_e.has_linear_assign(strides);
420     }
421     //@}
422 
423     template <class CT, class X>
424     template <class S>
stepper_begin(const S & shape) const425     inline auto xbroadcast<CT, X>::stepper_begin(const S& shape) const noexcept -> const_stepper
426     {
427         // Could check if (broadcastable(shape, m_shape)
428         return m_e.stepper_begin(shape);
429     }
430 
431     template <class CT, class X>
432     template <class S>
stepper_end(const S & shape,layout_type l) const433     inline auto xbroadcast<CT, X>::stepper_end(const S& shape, layout_type l) const noexcept -> const_stepper
434     {
435         // Could check if (broadcastable(shape, m_shape)
436         return m_e.stepper_end(shape, l);
437     }
438 
439     template <class CT, class X>
440     template <class E, class XCT, class>
assign_to(xexpression<E> & e) const441     inline void xbroadcast<CT, X>::assign_to(xexpression<E>& e) const
442     {
443         auto& ed = e.derived_cast();
444         ed.resize(m_shape);
445         std::fill(ed.begin(), ed.end(), m_e());
446     }
447 
448     template <class CT, class X>
449     template <class E>
build_broadcast(E && e) const450     inline auto xbroadcast<CT, X>::build_broadcast(E&& e) const -> rebind_t<E>
451     {
452         return rebind_t<E>(std::forward<E>(e), inner_shape_type(m_shape));
453     }
454 }
455 
456 #endif
457