1 /*************************************************************************** 2 * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht * 3 * Copyright (c) QuantStack * 4 * * 5 * Distributed under the terms of the BSD 3-Clause License. * 6 * * 7 * The full license is in the file LICENSE, distributed with this software. * 8 ****************************************************************************/ 9 10 #ifndef XTENSOR_GENERATOR_HPP 11 #define XTENSOR_GENERATOR_HPP 12 13 #include <algorithm> 14 #include <cstddef> 15 #include <numeric> 16 #include <tuple> 17 #include <type_traits> 18 #include <utility> 19 20 #include <xtl/xsequence.hpp> 21 22 #include "xaccessible.hpp" 23 #include "xexpression.hpp" 24 #include "xiterable.hpp" 25 #include "xstrides.hpp" 26 #include "xutils.hpp" 27 #include "xstrided_view.hpp" 28 29 namespace xt 30 { 31 32 /************************ 33 * xgenerator extension * 34 ************************/ 35 36 namespace extension 37 { 38 template <class Tag, class F, class R, class S> 39 struct xgenerator_base_impl; 40 41 template <class F, class R, class S> 42 struct xgenerator_base_impl<xtensor_expression_tag, F, R, S> 43 { 44 using type = xtensor_empty_base; 45 }; 46 47 template <class F, class R, class S> 48 struct xgenerator_base : xgenerator_base_impl<xexpression_tag_t<R>, F, R, S> 49 { 50 }; 51 52 template <class F, class R, class S> 53 using xgenerator_base_t = typename xgenerator_base<F, R, S>::type; 54 } 55 56 /************** 57 * xgenerator * 58 **************/ 59 60 template <class F, class R, class S> 61 class xgenerator; 62 63 template <class C, class R, class S> 64 struct xiterable_inner_types<xgenerator<C, R, S>> 65 { 66 using inner_shape_type = S; 67 using const_stepper = xindexed_stepper<xgenerator<C, R, S>, true>; 68 using stepper = const_stepper; 69 }; 70 71 template <class C, class R, class S> 72 struct xcontainer_inner_types<xgenerator<C, R, S>> 73 { 74 using reference = R; 75 using const_reference = R; 76 using size_type = std::size_t; 77 }; 78 79 /** 80 * @class xgenerator 81 * @brief Multidimensional function operating on indices. 82 * 83 * The xgenerator class implements a multidimensional function, 84 * generating a value from the supplied indices. 85 * 86 * @tparam F the function type 87 * @tparam R the return type of the function 88 * @tparam S the shape type of the generator 89 */ 90 template <class F, class R, class S> 91 class xgenerator : public xsharable_expression<xgenerator<F, R, S>>, 92 public xconst_iterable<xgenerator<F, R, S>>, 93 public xconst_accessible<xgenerator<F, R, S>>, 94 public extension::xgenerator_base_t<F, R, S> 95 { 96 public: 97 98 using self_type = xgenerator<F, R, S>; 99 using functor_type = typename std::remove_reference<F>::type; 100 101 using accessible_base = xconst_accessible<self_type>; 102 using extension_base = extension::xgenerator_base_t<F, R, S>; 103 using expression_tag = typename extension_base::expression_tag; 104 105 using inner_types = xcontainer_inner_types<self_type>; 106 using value_type = R; 107 using reference = typename inner_types::reference; 108 using const_reference = typename inner_types::const_reference; 109 using pointer = value_type*; 110 using const_pointer = const value_type*; 111 using size_type = typename inner_types::size_type; 112 using difference_type = std::ptrdiff_t; 113 114 using iterable_base = xconst_iterable<self_type>; 115 using inner_shape_type = typename iterable_base::inner_shape_type; 116 using shape_type = inner_shape_type; 117 118 using stepper = typename iterable_base::stepper; 119 using const_stepper = typename iterable_base::const_stepper; 120 121 using bool_load_type = xt::bool_load_type<R>; 122 123 static constexpr layout_type static_layout = layout_type::dynamic; 124 static constexpr bool contiguous_layout = false; 125 126 template <class Func> 127 xgenerator(Func&& f, const S& shape) noexcept; 128 129 const inner_shape_type& shape() const noexcept; 130 layout_type layout() const noexcept; 131 bool is_contiguous() const noexcept; 132 using accessible_base::shape; 133 134 template <class... Args> 135 const_reference operator()(Args... args) const; 136 template <class... Args> 137 const_reference unchecked(Args... args) const; 138 139 template <class It> 140 const_reference element(It first, It last) const; 141 142 template <class O> 143 bool broadcast_shape(O& shape, bool reuse_cache = false) const; 144 145 template <class O> 146 bool has_linear_assign(const O& /*strides*/) const noexcept; 147 148 template <class O> 149 const_stepper stepper_begin(const O& shape) const noexcept; 150 template <class O> 151 const_stepper stepper_end(const O& shape, layout_type) const noexcept; 152 153 template <class E, class FE = F, class = std::enable_if_t<has_assign_to<E, FE>::value>> 154 void assign_to(xexpression<E>& e) const noexcept; 155 156 const functor_type& functor() const noexcept; 157 158 template <class OR, class OF> 159 using rebind_t = xgenerator<OF, OR, S>; 160 161 template <class OR, class OF> 162 rebind_t<OR, OF> build_generator(OF&& func) const; 163 164 template <class O = xt::dynamic_shape<typename shape_type::value_type>> 165 auto reshape(O&& shape) const &; 166 167 template <class O = xt::dynamic_shape<typename shape_type::value_type>> 168 auto reshape(O&& shape) &&; 169 170 template <class T> 171 auto reshape(std::initializer_list<T> shape) const &; 172 173 template <class T> 174 auto reshape(std::initializer_list<T> shape) &&; 175 176 private: 177 178 template <class O> 179 decltype(auto) compute_shape(O&& shape, std::false_type /*signed*/) const; 180 181 template <class O> 182 auto compute_shape(O&& shape, std::true_type /*signed*/) const; 183 184 template <class T> 185 auto compute_shape(std::initializer_list<T> shape) const; 186 187 template <std::size_t dim> 188 void adapt_index() const; 189 190 template <std::size_t dim, class I, class... Args> 191 void adapt_index(I& arg, Args&... args) const; 192 193 functor_type m_f; 194 inner_shape_type m_shape; 195 }; 196 197 /***************************** 198 * xgenerator implementation * 199 *****************************/ 200 201 /** 202 * @name Constructor 203 */ 204 //@{ 205 /** 206 * Constructs an xgenerator applying the specified function over the 207 * given shape. 208 * @param f the function to apply 209 * @param shape the shape of the xgenerator 210 */ 211 template <class F, class R, class S> 212 template <class Func> xgenerator(Func && f,const S & shape)213 inline xgenerator<F, R, S>::xgenerator(Func&& f, const S& shape) noexcept 214 : m_f(std::forward<Func>(f)), m_shape(shape) 215 { 216 } 217 //@} 218 219 /** 220 * @name Size and shape 221 */ 222 //@{ 223 /** 224 * Returns the shape of the xgenerator. 225 */ 226 template <class F, class R, class S> shape() const227 inline auto xgenerator<F, R, S>::shape() const noexcept -> const inner_shape_type& 228 { 229 return m_shape; 230 } 231 232 template <class F, class R, class S> layout() const233 inline layout_type xgenerator<F, R, S>::layout() const noexcept 234 { 235 return static_layout; 236 } 237 238 template <class F, class R, class S> is_contiguous() const239 inline bool xgenerator<F, R, S>::is_contiguous() const noexcept 240 { 241 return false; 242 } 243 244 //@} 245 246 /** 247 * @name Data 248 */ 249 /** 250 * Returns the evaluated element at the specified position in the function. 251 * @param args a list of indices specifying the position in the function. Indices 252 * must be unsigned integers, the number of indices should be equal or greater than 253 * the number of dimensions of the function. 254 */ 255 template <class F, class R, class S> 256 template <class... Args> operator ()(Args...args) const257 inline auto xgenerator<F, R, S>::operator()(Args... args) const -> const_reference 258 { 259 XTENSOR_TRY(check_index(shape(), args...)); 260 adapt_index<0>(args...); 261 return m_f(args...); 262 } 263 264 /** 265 * Returns a constant reference to the element at the specified position in the expression. 266 * @param args a list of indices specifying the position in the expression. Indices 267 * must be unsigned integers, the number of indices must be equal to the number of 268 * dimensions of the expression, else the behavior is undefined. 269 * 270 * @warning This method is meant for performance, for expressions with a dynamic 271 * number of dimensions (i.e. not known at compile time). Since it may have 272 * undefined behavior (see parameters), operator() should be prefered whenever 273 * it is possible. 274 * @warning This method is NOT compatible with broadcasting, meaning the following 275 * code has undefined behavior: 276 * \code{.cpp} 277 * xt::xarray<double> a = {{0, 1}, {2, 3}}; 278 * xt::xarray<double> b = {0, 1}; 279 * auto fd = a + b; 280 * double res = fd.uncheked(0, 1); 281 * \endcode 282 */ 283 template <class F, class R, class S> 284 template <class... Args> unchecked(Args...args) const285 inline auto xgenerator<F, R, S>::unchecked(Args... args) const -> const_reference 286 { 287 return m_f(args...); 288 } 289 290 /** 291 * Returns a constant reference to the element at the specified position in the function. 292 * @param first iterator starting the sequence of indices 293 * @param last iterator ending the sequence of indices 294 * The number of indices in the sequence should be equal to or greater 295 * than the number of dimensions of the container. 296 */ 297 template <class F, class R, class S> 298 template <class It> element(It first,It last) const299 inline auto xgenerator<F, R, S>::element(It first, It last) const -> const_reference 300 { 301 using bounded_iterator = xbounded_iterator<It, typename shape_type::const_iterator>; 302 XTENSOR_TRY(check_element_index(shape(), first, last)); 303 return m_f.element(bounded_iterator(first, shape().cbegin()), bounded_iterator(last, shape().cend())); 304 } 305 //@} 306 307 /** 308 * @name Broadcasting 309 */ 310 //@{ 311 /** 312 * Broadcast the shape of the function to the specified parameter. 313 * @param shape the result shape 314 * @param reuse_cache parameter for internal optimization 315 * @return a boolean indicating whether the broadcasting is trivial 316 */ 317 template <class F, class R, class S> 318 template <class O> broadcast_shape(O & shape,bool) const319 inline bool xgenerator<F, R, S>::broadcast_shape(O& shape, bool) const 320 { 321 return xt::broadcast_shape(m_shape, shape); 322 } 323 324 /** 325 * Checks whether the xgenerator can be linearly assigned to an expression 326 * with the specified strides. 327 * @return a boolean indicating whether a linear assign is possible 328 */ 329 template <class F, class R, class S> 330 template <class O> has_linear_assign(const O &) const331 inline bool xgenerator<F, R, S>::has_linear_assign(const O& /*strides*/) const noexcept 332 { 333 return false; 334 } 335 //@} 336 337 template <class F, class R, class S> 338 template <class O> stepper_begin(const O & shape) const339 inline auto xgenerator<F, R, S>::stepper_begin(const O& shape) const noexcept -> const_stepper 340 { 341 size_type offset = shape.size() - this->dimension(); 342 return const_stepper(this, offset); 343 } 344 345 template <class F, class R, class S> 346 template <class O> stepper_end(const O & shape,layout_type) const347 inline auto xgenerator<F, R, S>::stepper_end(const O& shape, layout_type) const noexcept -> const_stepper 348 { 349 size_type offset = shape.size() - this->dimension(); 350 return const_stepper(this, offset, true); 351 } 352 353 template <class F, class R, class S> 354 template <class E, class, class> assign_to(xexpression<E> & e) const355 inline void xgenerator<F, R, S>::assign_to(xexpression<E>& e) const noexcept 356 { 357 e.derived_cast().resize(m_shape); 358 m_f.assign_to(e); 359 } 360 361 template <class F, class R, class S> functor() const362 inline auto xgenerator<F, R, S>::functor() const noexcept -> const functor_type& 363 { 364 return m_f; 365 } 366 367 template <class F, class R, class S> 368 template <class OR, class OF> build_generator(OF && func) const369 inline auto xgenerator<F, R, S>::build_generator(OF&& func) const -> rebind_t<OR, OF> 370 { 371 return rebind_t<OR, OF>(std::move(func), shape_type(m_shape)); 372 } 373 374 /** 375 * Reshapes the generator and keeps old elements. The `shape` argument can have one of its value 376 * equal to `-1`, in this case the value is inferred from the number of elements in the generator 377 * and the remaining values in the `shape`. 378 * \code{.cpp} 379 * auto a = xt::arange<double>(50).reshape({-1, 10}); 380 * //a.shape() is {5, 10} 381 * \endcode 382 * @param shape the new shape (has to have same number of elements as the original genrator) 383 */ 384 template <class F, class R, class S> 385 template <class O> reshape(O && shape) const386 inline auto xgenerator<F, R, S>::reshape(O&& shape) const & 387 { 388 return reshape_view(*this, compute_shape(shape, xtl::is_signed<typename std::decay_t<O>::value_type>())); 389 } 390 391 template <class F, class R, class S> 392 template <class O> reshape(O && shape)393 inline auto xgenerator<F, R, S>::reshape(O&& shape) && 394 { 395 return reshape_view(std::move(*this), compute_shape(shape, xtl::is_signed<typename std::decay_t<O>::value_type>())); 396 } 397 398 template <class F, class R, class S> 399 template <class T> reshape(std::initializer_list<T> shape) const400 inline auto xgenerator<F, R, S>::reshape(std::initializer_list<T> shape) const & 401 { 402 return reshape_view(*this, compute_shape(shape)); 403 } 404 405 template <class F, class R, class S> 406 template <class T> reshape(std::initializer_list<T> shape)407 inline auto xgenerator<F, R, S>::reshape(std::initializer_list<T> shape) && 408 { 409 return reshape_view(std::move(*this), compute_shape(shape)); 410 } 411 412 template <class F, class R, class S> 413 template <class O> compute_shape(O && shape,std::false_type) const414 inline decltype(auto) xgenerator<F, R, S>::compute_shape(O&& shape, std::false_type) const 415 { 416 return xtl::forward_sequence<xt::dynamic_shape<typename shape_type::value_type>, O>(shape); 417 } 418 419 template <class F, class R, class S> 420 template <class O> compute_shape(O && shape,std::true_type) const421 inline auto xgenerator<F, R, S>::compute_shape(O&& shape, std::true_type) const 422 { 423 using vtype = typename shape_type::value_type; 424 xt::dynamic_shape<vtype> sh(shape.size()); 425 using int_type = typename std::decay_t<O>::value_type; 426 int_type accumulator(1); 427 std::size_t neg_idx = 0; 428 std::size_t i = 0; 429 for(std::size_t j = 0; j != shape.size(); ++j, ++i) 430 { 431 auto dim = shape[j]; 432 if(dim < 0) 433 { 434 XTENSOR_ASSERT(dim == -1 && !neg_idx); 435 neg_idx = i; 436 } 437 else 438 { 439 sh[j] = static_cast<vtype>(dim); 440 } 441 accumulator *= dim; 442 } 443 if(accumulator < 0) 444 { 445 sh[neg_idx] = this->size() / static_cast<size_type>(std::make_unsigned_t<int_type>(std::abs(accumulator))); 446 } 447 return sh; 448 } 449 450 template <class F, class R, class S> 451 template <class T> compute_shape(std::initializer_list<T> shape) const452 inline auto xgenerator<F, R, S>::compute_shape(std::initializer_list<T> shape) const 453 { 454 using sh_type = xt::dynamic_shape<T>; 455 sh_type sh = xtl::make_sequence<sh_type>(shape.size()); 456 std::copy(shape.begin(), shape.end(), sh.begin()); 457 return compute_shape(std::move(sh), xtl::is_signed<T>()); 458 } 459 460 template <class F, class R, class S> 461 template <std::size_t dim> adapt_index() const462 inline void xgenerator<F, R, S>::adapt_index() const 463 { 464 } 465 466 template <class F, class R, class S> 467 template <std::size_t dim, class I, class... Args> adapt_index(I & arg,Args &...args) const468 inline void xgenerator<F, R, S>::adapt_index(I& arg, Args&... args) const 469 { 470 using tmp_value_type = typename decltype(m_shape)::value_type; 471 if (sizeof...(Args) + 1 > m_shape.size()) 472 { 473 adapt_index<dim>(args...); 474 } 475 else 476 { 477 if (static_cast<tmp_value_type>(arg) >= m_shape[dim] && m_shape[dim] == 1) 478 { 479 arg = 0; 480 } 481 adapt_index<dim + 1>(args...); 482 } 483 } 484 485 namespace detail 486 { 487 template <class Functor, class I, std::size_t L> make_xgenerator(Functor && f,const I (& shape)[L])488 inline auto make_xgenerator(Functor&& f, const I (&shape)[L]) noexcept 489 { 490 using shape_type = std::array<std::size_t, L>; 491 using type = xgenerator<Functor, typename Functor::value_type, shape_type>; 492 return type(std::forward<Functor>(f), xtl::forward_sequence<shape_type, decltype(shape)>(shape)); 493 } 494 495 template <class Functor, class S> make_xgenerator(Functor && f,S && shape)496 inline auto make_xgenerator(Functor&& f, S&& shape) noexcept 497 { 498 using type = xgenerator<Functor, typename Functor::value_type, std::decay_t<S>>; 499 return type(std::forward<Functor>(f), std::forward<S>(shape)); 500 } 501 } 502 } 503 504 #endif 505