1 /*************************************************************************** 2 * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht * 3 * Copyright (c) QuantStack * 4 * * 5 * Distributed under the terms of the BSD 3-Clause License. * 6 * * 7 * The full license is in the file LICENSE, distributed with this software. * 8 ****************************************************************************/ 9 10 #ifndef XTENSOR_BROADCAST_HPP 11 #define XTENSOR_BROADCAST_HPP 12 13 #include <algorithm> 14 #include <array> 15 #include <cstddef> 16 #include <iterator> 17 #include <numeric> 18 #include <type_traits> 19 #include <utility> 20 21 #include <xtl/xsequence.hpp> 22 23 #include "xaccessible.hpp" 24 #include "xexpression.hpp" 25 #include "xiterable.hpp" 26 #include "xscalar.hpp" 27 #include "xstrides.hpp" 28 #include "xtensor_config.hpp" 29 #include "xutils.hpp" 30 31 namespace xt 32 { 33 34 /************* 35 * broadcast * 36 *************/ 37 38 template <class E, class S> 39 auto broadcast(E&& e, const S& s); 40 41 template <class E, class I, std::size_t L> 42 auto broadcast(E&& e, const I (&s)[L]); 43 44 /************************* 45 * xbroadcast extensions * 46 *************************/ 47 48 namespace extension 49 { 50 template <class Tag, class CT, class X> 51 struct xbroadcast_base_impl; 52 53 template <class CT, class X> 54 struct xbroadcast_base_impl<xtensor_expression_tag, CT, X> 55 { 56 using type = xtensor_empty_base; 57 }; 58 59 template <class CT, class X> 60 struct xbroadcast_base 61 : xbroadcast_base_impl<xexpression_tag_t<CT>, CT, X> 62 { 63 }; 64 65 template <class CT, class X> 66 using xbroadcast_base_t = typename xbroadcast_base<CT, X>::type; 67 } 68 69 /************** 70 * xbroadcast * 71 **************/ 72 73 template <class CT, class X> 74 class xbroadcast; 75 76 template <class CT, class X> 77 struct xiterable_inner_types<xbroadcast<CT, X>> 78 { 79 using xexpression_type = std::decay_t<CT>; 80 using inner_shape_type = promote_shape_t<typename xexpression_type::shape_type, X>; 81 using const_stepper = typename xexpression_type::const_stepper; 82 using stepper = const_stepper; 83 }; 84 85 template <class CT, class X> 86 struct xcontainer_inner_types<xbroadcast<CT, X>> 87 { 88 using xexpression_type = std::decay_t<CT>; 89 using reference = typename xexpression_type::const_reference; 90 using const_reference = typename xexpression_type::const_reference; 91 using size_type = typename xexpression_type::size_type; 92 }; 93 94 /***************************** 95 * linear_begin / linear_end * 96 *****************************/ 97 98 template <class CT, class X> linear_begin(xbroadcast<CT,X> & c)99 XTENSOR_CONSTEXPR_RETURN auto linear_begin(xbroadcast<CT, X>& c) noexcept 100 { 101 return linear_begin(c.expression()); 102 } 103 104 template <class CT, class X> linear_end(xbroadcast<CT,X> & c)105 XTENSOR_CONSTEXPR_RETURN auto linear_end(xbroadcast<CT, X>& c) noexcept 106 { 107 return linear_end(c.expression()); 108 } 109 110 template <class CT, class X> linear_begin(const xbroadcast<CT,X> & c)111 XTENSOR_CONSTEXPR_RETURN auto linear_begin(const xbroadcast<CT, X>& c) noexcept 112 { 113 return linear_begin(c.expression()); 114 } 115 116 template <class CT, class X> linear_end(const xbroadcast<CT,X> & c)117 XTENSOR_CONSTEXPR_RETURN auto linear_end(const xbroadcast<CT, X>& c) noexcept 118 { 119 return linear_end(c.expression()); 120 } 121 122 /** 123 * @class xbroadcast 124 * @brief Broadcasted xexpression to a specified shape. 125 * 126 * The xbroadcast class implements the broadcasting of an \ref xexpression 127 * to a specified shape. xbroadcast is not meant to be used directly, but 128 * only with the \ref broadcast helper functions. 129 * 130 * @tparam CT the closure type of the \ref xexpression to broadcast 131 * @tparam X the type of the specified shape. 132 * 133 * @sa broadcast 134 */ 135 template <class CT, class X> 136 class xbroadcast : public xsharable_expression<xbroadcast<CT, X>>, 137 public xconst_iterable<xbroadcast<CT, X>>, 138 public xconst_accessible<xbroadcast<CT, X>>, 139 public extension::xbroadcast_base_t<CT, X> 140 { 141 public: 142 143 using self_type = xbroadcast<CT, X>; 144 using xexpression_type = std::decay_t<CT>; 145 using accessible_base = xconst_accessible<self_type>; 146 using extension_base = extension::xbroadcast_base_t<CT, X>; 147 using expression_tag = typename extension_base::expression_tag; 148 149 using inner_types = xcontainer_inner_types<self_type>; 150 using value_type = typename xexpression_type::value_type; 151 using reference = typename inner_types::reference; 152 using const_reference = typename inner_types::const_reference; 153 using pointer = typename xexpression_type::const_pointer; 154 using const_pointer = typename xexpression_type::const_pointer; 155 using size_type = typename inner_types::size_type; 156 using difference_type = typename xexpression_type::difference_type; 157 158 using iterable_base = xconst_iterable<self_type>; 159 using inner_shape_type = typename iterable_base::inner_shape_type; 160 using shape_type = inner_shape_type; 161 162 using stepper = typename iterable_base::stepper; 163 using const_stepper = typename iterable_base::const_stepper; 164 165 using bool_load_type = typename xexpression_type::bool_load_type; 166 167 static constexpr layout_type static_layout = layout_type::dynamic; 168 static constexpr bool contiguous_layout = false; 169 170 template <class CTA, class S> 171 xbroadcast(CTA&& e, const S& s); 172 173 template <class CTA> 174 xbroadcast(CTA&& e, shape_type&& s); 175 176 using accessible_base::size; 177 const inner_shape_type& shape() const noexcept; 178 layout_type layout() const noexcept; 179 bool is_contiguous() const noexcept; 180 using accessible_base::shape; 181 182 template <class... Args> 183 const_reference operator()(Args... args) const; 184 185 template <class... Args> 186 const_reference unchecked(Args... args) const; 187 188 template <class It> 189 const_reference element(It first, It last) const; 190 191 const xexpression_type& expression() const noexcept; 192 193 template <class S> 194 bool broadcast_shape(S& shape, bool reuse_cache = false) const; 195 196 template <class S> 197 bool has_linear_assign(const S& strides) const noexcept; 198 199 template <class S> 200 const_stepper stepper_begin(const S& shape) const noexcept; 201 template <class S> 202 const_stepper stepper_end(const S& shape, layout_type l) const noexcept; 203 204 template <class E, class XCT = CT, class = std::enable_if_t<xt::is_xscalar<XCT>::value>> 205 void assign_to(xexpression<E>& e) const; 206 207 template <class E> 208 using rebind_t = xbroadcast<E, X>; 209 210 template <class E> 211 rebind_t<E> build_broadcast(E&& e) const; 212 213 private: 214 215 CT m_e; 216 inner_shape_type m_shape; 217 }; 218 219 /**************************** 220 * broadcast implementation * 221 ****************************/ 222 223 /** 224 * @brief Returns an \ref xexpression broadcasting the given expression to 225 * a specified shape. 226 * 227 * @tparam e the \ref xexpression to broadcast 228 * @tparam s the specified shape to broadcast. 229 * 230 * The returned expression either hold a const reference to \p e or a copy 231 * depending on whether \p e is an lvalue or an rvalue. 232 */ 233 template <class E, class S> broadcast(E && e,const S & s)234 inline auto broadcast(E&& e, const S& s) 235 { 236 using shape_type = filter_fixed_shape_t<std::decay_t<S>>; 237 using broadcast_type = xbroadcast<const_xclosure_t<E>, shape_type>; 238 return broadcast_type(std::forward<E>(e), xtl::forward_sequence<shape_type, decltype(s)>(s)); 239 } 240 241 template <class E, class I, std::size_t L> broadcast(E && e,const I (& s)[L])242 inline auto broadcast(E&& e, const I (&s)[L]) 243 { 244 using broadcast_type = xbroadcast<const_xclosure_t<E>, std::array<std::size_t, L>>; 245 using shape_type = typename broadcast_type::shape_type; 246 return broadcast_type(std::forward<E>(e), xtl::forward_sequence<shape_type, decltype(s)>(s)); 247 } 248 249 /***************************** 250 * xbroadcast implementation * 251 *****************************/ 252 253 /** 254 * @name Constructor 255 */ 256 //@{ 257 /** 258 * Constructs an xbroadcast expression broadcasting the specified 259 * \ref xexpression to the given shape 260 * 261 * @param e the expression to broadcast 262 * @param s the shape to apply 263 */ 264 template <class CT, class X> 265 template <class CTA, class S> xbroadcast(CTA && e,const S & s)266 inline xbroadcast<CT, X>::xbroadcast(CTA&& e, const S& s) 267 : m_e(std::forward<CTA>(e)) 268 { 269 if (s.size() < m_e.dimension()) 270 { 271 XTENSOR_THROW(xt::broadcast_error, "Broadcast shape has fewer elements than original expression."); 272 } 273 xt::resize_container(m_shape, s.size()); 274 std::copy(s.begin(), s.end(), m_shape.begin()); 275 xt::broadcast_shape(m_e.shape(), m_shape); 276 } 277 278 /** 279 * Constructs an xbroadcast expression broadcasting the specified 280 * \ref xexpression to the given shape 281 * 282 * @param e the expression to broadcast 283 * @param s the shape to apply 284 */ 285 template <class CT, class X> 286 template <class CTA> xbroadcast(CTA && e,shape_type && s)287 inline xbroadcast<CT, X>::xbroadcast(CTA&& e, shape_type&& s) 288 : m_e(std::forward<CTA>(e)), m_shape(std::move(s)) 289 { 290 xt::broadcast_shape(m_e.shape(), m_shape); 291 } 292 //@} 293 294 /** 295 * @name Size and shape 296 */ 297 //@{ 298 /** 299 * Returns the shape of the expression. 300 */ 301 template <class CT, class X> shape() const302 inline auto xbroadcast<CT, X>::shape() const noexcept -> const inner_shape_type& 303 { 304 return m_shape; 305 } 306 307 /** 308 * Returns the layout_type of the expression. 309 */ 310 template <class CT, class X> layout() const311 inline layout_type xbroadcast<CT, X>::layout() const noexcept 312 { 313 return m_e.layout(); 314 } 315 316 template <class CT, class X> is_contiguous() const317 inline bool xbroadcast<CT, X>::is_contiguous() const noexcept 318 { 319 return false; 320 } 321 322 //@} 323 324 /** 325 * @name Data 326 */ 327 //@{ 328 /** 329 * Returns a constant reference to the element at the specified position in the expression. 330 * @param args a list of indices specifying the position in the function. Indices 331 * must be unsigned integers, the number of indices should be equal or greater than 332 * the number of dimensions of the expression. 333 */ 334 template <class CT, class X> 335 template <class... Args> operator ()(Args...args) const336 inline auto xbroadcast<CT, X>::operator()(Args... args) const -> const_reference 337 { 338 return m_e(args...); 339 } 340 341 /** 342 * Returns a constant reference to the element at the specified position in the expression. 343 * @param args a list of indices specifying the position in the expression. Indices 344 * must be unsigned integers, the number of indices must be equal to the number of 345 * dimensions of the expression, else the behavior is undefined. 346 * 347 * @warning This method is meant for performance, for expressions with a dynamic 348 * number of dimensions (i.e. not known at compile time). Since it may have 349 * undefined behavior (see parameters), operator() should be prefered whenever 350 * it is possible. 351 * @warning This method is NOT compatible with broadcasting, meaning the following 352 * code has undefined behavior: 353 * \code{.cpp} 354 * xt::xarray<double> a = {{0, 1}, {2, 3}}; 355 * xt::xarray<double> b = {0, 1}; 356 * auto fd = a + b; 357 * double res = fd.uncheked(0, 1); 358 * \endcode 359 */ 360 template <class CT, class X> 361 template <class... Args> unchecked(Args...args) const362 inline auto xbroadcast<CT, X>::unchecked(Args... args) const -> const_reference 363 { 364 return this->operator()(args...); 365 } 366 367 /** 368 * Returns a constant reference to the element at the specified position in the expression. 369 * @param first iterator starting the sequence of indices 370 * @param last iterator ending the sequence of indices 371 * The number of indices in the sequence should be equal to or greater 372 * than the number of dimensions of the function. 373 */ 374 template <class CT, class X> 375 template <class It> element(It,It last) const376 inline auto xbroadcast<CT, X>::element(It, It last) const -> const_reference 377 { 378 return m_e.element(last - this->dimension(), last); 379 } 380 381 /** 382 * Returns a constant reference to the underlying expression of the broadcast expression. 383 */ 384 template <class CT, class X> expression() const385 inline auto xbroadcast<CT, X>::expression() const noexcept -> const xexpression_type& 386 { 387 return m_e; 388 } 389 //@} 390 391 /** 392 * @name Broadcasting 393 */ 394 //@{ 395 /** 396 * Broadcast the shape of the function to the specified parameter. 397 * @param shape the result shape 398 * @param reuse_cache parameter for internal optimization 399 * @return a boolean indicating whether the broadcasting is trivial 400 */ 401 template <class CT, class X> 402 template <class S> broadcast_shape(S & shape,bool) const403 inline bool xbroadcast<CT, X>::broadcast_shape(S& shape, bool) const 404 { 405 return xt::broadcast_shape(m_shape, shape); 406 } 407 408 /** 409 * Checks whether the xbroadcast can be linearly assigned to an expression 410 * with the specified strides. 411 * @return a boolean indicating whether a linear assign is possible 412 */ 413 template <class CT, class X> 414 template <class S> has_linear_assign(const S & strides) const415 inline bool xbroadcast<CT, X>::has_linear_assign(const S& strides) const noexcept 416 { 417 return this->dimension() == m_e.dimension() && 418 std::equal(m_shape.cbegin(), m_shape.cend(), m_e.shape().cbegin()) && 419 m_e.has_linear_assign(strides); 420 } 421 //@} 422 423 template <class CT, class X> 424 template <class S> stepper_begin(const S & shape) const425 inline auto xbroadcast<CT, X>::stepper_begin(const S& shape) const noexcept -> const_stepper 426 { 427 // Could check if (broadcastable(shape, m_shape) 428 return m_e.stepper_begin(shape); 429 } 430 431 template <class CT, class X> 432 template <class S> stepper_end(const S & shape,layout_type l) const433 inline auto xbroadcast<CT, X>::stepper_end(const S& shape, layout_type l) const noexcept -> const_stepper 434 { 435 // Could check if (broadcastable(shape, m_shape) 436 return m_e.stepper_end(shape, l); 437 } 438 439 template <class CT, class X> 440 template <class E, class XCT, class> assign_to(xexpression<E> & e) const441 inline void xbroadcast<CT, X>::assign_to(xexpression<E>& e) const 442 { 443 auto& ed = e.derived_cast(); 444 ed.resize(m_shape); 445 std::fill(ed.begin(), ed.end(), m_e()); 446 } 447 448 template <class CT, class X> 449 template <class E> build_broadcast(E && e) const450 inline auto xbroadcast<CT, X>::build_broadcast(E&& e) const -> rebind_t<E> 451 { 452 return rebind_t<E>(std::forward<E>(e), inner_shape_type(m_shape)); 453 } 454 } 455 456 #endif 457