1 /*************************************************************************** 2 * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht * 3 * Copyright (c) QuantStack * 4 * * 5 * Distributed under the terms of the BSD 3-Clause License. * 6 * * 7 * The full license is in the file LICENSE, distributed with this software. * 8 ****************************************************************************/ 9 10 #ifndef XTENSOR_XSHAPE_HPP 11 #define XTENSOR_XSHAPE_HPP 12 13 #include <algorithm> 14 #include <cassert> 15 #include <cstddef> 16 #include <cstdlib> 17 #include <cstring> 18 #include <initializer_list> 19 #include <iterator> 20 #include <memory> 21 22 #include "xlayout.hpp" 23 #include "xstorage.hpp" 24 #include "xtensor_forward.hpp" 25 26 namespace xt 27 { 28 template <class T> 29 using dynamic_shape = svector<T, 4>; 30 31 template <class T, std::size_t N> 32 using static_shape = std::array<T, N>; 33 34 template <std::size_t... X> 35 class fixed_shape; 36 37 using xindex = dynamic_shape<std::size_t>; 38 39 template <class S1, class S2> 40 bool same_shape(const S1& s1, const S2& s2) noexcept; 41 42 template <class U> 43 struct initializer_dimension; 44 45 template <class R, class T> 46 constexpr R shape(T t); 47 48 template<class R = std::size_t, class T, std::size_t N> 49 xt::static_shape<R, N> shape(const T(&aList)[N]); 50 51 template <class S> 52 struct static_dimension; 53 54 template <layout_type L, class S> 55 struct select_layout; 56 57 template <class... S> 58 struct promote_shape; 59 60 template <class... S> 61 struct promote_strides; 62 63 template <class S> 64 struct index_from_shape; 65 } 66 67 namespace xtl 68 { 69 namespace detail 70 { 71 template <class S> 72 struct sequence_builder; 73 74 template <std::size_t... I> 75 struct sequence_builder<xt::fixed_shape<I...>> 76 { 77 using sequence_type = xt::fixed_shape<I...>; 78 using value_type = typename sequence_type::value_type; 79 makextl::detail::sequence_builder80 inline static sequence_type make(std::size_t /*size*/) 81 { 82 return sequence_type{}; 83 } 84 makextl::detail::sequence_builder85 inline static sequence_type make(std::size_t /*size*/, value_type /*v*/) 86 { 87 return sequence_type{}; 88 } 89 }; 90 } 91 } 92 93 namespace xt 94 { 95 /************** 96 * same_shape * 97 **************/ 98 99 /** 100 * @ingroup same_shape 101 * @brief same_shape 102 * 103 * Check if two objects have the same shape. 104 * @param s1 an array 105 * @param s2 an array 106 * @return bool 107 */ 108 template <class S1, class S2> same_shape(const S1 & s1,const S2 & s2)109 inline bool same_shape(const S1& s1, const S2& s2) noexcept 110 { 111 return s1.size() == s2.size() && std::equal(s1.begin(), s1.end(), s2.begin()); 112 } 113 114 /************* 115 * has_shape * 116 *************/ 117 118 /** 119 * @ingroup has_shape 120 * @brief has_shape 121 * 122 * Check if an object has a certain shape. 123 * @param a an array 124 * @param shape the shape to test 125 * @return bool 126 */ 127 template <class E, class S> has_shape(const E & e,std::initializer_list<S> shape)128 inline bool has_shape(const E& e, std::initializer_list<S> shape) noexcept 129 { 130 return e.shape().size() == shape.size() && std::equal(e.shape().cbegin(), e.shape().cend(), shape.begin()); 131 } 132 133 /** 134 * @ingroup has_shape 135 * @brief has_shape 136 * 137 * Check if an object has a certain shape. 138 * @param a an array 139 * @param shape the shape to test 140 * @return bool 141 */ 142 template <class E, class S, class = typename std::enable_if_t<has_iterator_interface<S>::value>> has_shape(const E & e,const S & shape)143 inline bool has_shape(const E& e, const S& shape) 144 { 145 return e.shape().size() == shape.size() && std::equal(e.shape().cbegin(), e.shape().cend(), shape.begin()); 146 } 147 148 /************************* 149 * initializer_dimension * 150 *************************/ 151 152 namespace detail 153 { 154 template <class U> 155 struct initializer_depth_impl 156 { 157 static constexpr std::size_t value = 0; 158 }; 159 160 template <class T> 161 struct initializer_depth_impl<std::initializer_list<T>> 162 { 163 static constexpr std::size_t value = 1 + initializer_depth_impl<T>::value; 164 }; 165 } 166 167 template <class U> 168 struct initializer_dimension 169 { 170 static constexpr std::size_t value = detail::initializer_depth_impl<U>::value; 171 }; 172 173 /********************* 174 * initializer_shape * 175 *********************/ 176 177 namespace detail 178 { 179 template <std::size_t I> 180 struct initializer_shape_impl 181 { 182 template <class T> valuext::detail::initializer_shape_impl183 static constexpr std::size_t value(T t) 184 { 185 return t.size() == 0 ? 0 : initializer_shape_impl<I - 1>::value(*t.begin()); 186 } 187 }; 188 189 template <> 190 struct initializer_shape_impl<0> 191 { 192 template <class T> valuext::detail::initializer_shape_impl193 static constexpr std::size_t value(T t) 194 { 195 return t.size(); 196 } 197 }; 198 199 template <class R, class U, std::size_t... I> initializer_shape(U t,std::index_sequence<I...>)200 constexpr R initializer_shape(U t, std::index_sequence<I...>) 201 { 202 using size_type = typename R::value_type; 203 return {size_type(initializer_shape_impl<I>::value(t))...}; 204 } 205 } 206 207 template <class R, class T> shape(T t)208 constexpr R shape(T t) 209 { 210 return detail::initializer_shape<R, decltype(t)>(t, std::make_index_sequence<initializer_dimension<decltype(t)>::value>()); 211 } 212 213 /** @brief Generate an xt::static_shape of the given size. */ 214 template<class R, class T, std::size_t N> shape(const T (& list)[N])215 xt::static_shape<R, N> shape(const T(&list)[N]) { 216 xt::static_shape<R, N> shape; 217 std::copy(std::begin(list), std::end(list), std::begin(shape)); 218 return shape; 219 } 220 221 /******************** 222 * static_dimension * 223 ********************/ 224 225 namespace detail 226 { 227 template <class T, class E = void> 228 struct static_dimension_impl 229 { 230 static constexpr std::ptrdiff_t value = -1; 231 }; 232 233 template <class T> 234 struct static_dimension_impl<T, void_t<decltype(std::tuple_size<T>::value)>> 235 { 236 static constexpr std::ptrdiff_t value = static_cast<std::ptrdiff_t>(std::tuple_size<T>::value); 237 }; 238 } 239 240 template <class S> 241 struct static_dimension 242 { 243 static constexpr std::ptrdiff_t value = detail::static_dimension_impl<S>::value; 244 }; 245 246 /** 247 * Compute a layout based on a layout and a shape type. 248 * 249 * The main functionality of this function is that it reduces vectors to 250 * ``layout_type::any`` so that assigning a row major 1D container to another 251 * row_major container becomes free. 252 */ 253 template <layout_type L, class S> 254 struct select_layout 255 { 256 constexpr static std::ptrdiff_t static_dimension = xt::static_dimension<S>::value; 257 constexpr static bool is_any = static_dimension != -1 && static_dimension <= 1 && L != layout_type::dynamic; 258 constexpr static layout_type value = is_any ? layout_type::any : L; 259 }; 260 261 /************************************* 262 * promote_shape and promote_strides * 263 *************************************/ 264 265 namespace detail 266 { 267 template <class T1, class T2> imax(const T1 & a,const T2 & b)268 constexpr std::common_type_t<T1, T2> imax(const T1& a, const T2& b) 269 { 270 return a > b ? a : b; 271 } 272 273 // Variadic meta-function returning the maximal size of std::arrays. 274 template <class... T> 275 struct max_array_size; 276 277 template <> 278 struct max_array_size<> 279 { 280 static constexpr std::size_t value = 0; 281 }; 282 283 template <class T, class... Ts> 284 struct max_array_size<T, Ts...> : std::integral_constant<std::size_t, imax(std::tuple_size<T>::value, max_array_size<Ts...>::value)> 285 { 286 }; 287 288 // Broadcasting for fixed shapes 289 template <std::size_t IDX, std::size_t... X> 290 struct at 291 { 292 constexpr static std::size_t arr[sizeof...(X)] = {X...}; 293 constexpr static std::size_t value = (IDX < sizeof...(X)) ? arr[IDX] : 0; 294 }; 295 296 template <class S1, class S2> 297 struct broadcast_fixed_shape; 298 299 template <class IX, class A, class B> 300 struct broadcast_fixed_shape_impl; 301 302 template <std::size_t IX, class A, class B> 303 struct broadcast_fixed_shape_cmp_impl; 304 305 template <std::size_t JX, std::size_t... I, std::size_t... J> 306 struct broadcast_fixed_shape_cmp_impl<JX, fixed_shape<I...>, fixed_shape<J...>> 307 { 308 //We line the shapes up from the last index 309 //IX may underflow, thus being a very large number 310 static constexpr std::size_t IX = JX - (sizeof...(J) - sizeof...(I)); 311 312 //Out of bounds access gives value 0 313 static constexpr std::size_t I_v = at<IX, I...>::value; 314 static constexpr std::size_t J_v = at<JX, J...>::value; 315 316 // we're statically checking if the broadcast shapes are either one on either of them or equal 317 static_assert(!I_v || I_v == 1 || J_v == 1 || J_v == I_v, "broadcast shapes do not match."); 318 319 static constexpr std::size_t ordinate = (I_v > J_v) ? I_v : J_v; 320 static constexpr bool value = (I_v == J_v); 321 }; 322 323 template <std::size_t... JX, std::size_t... I, std::size_t... J> 324 struct broadcast_fixed_shape_impl<std::index_sequence<JX...>, fixed_shape<I...>, fixed_shape<J...>> 325 { 326 static_assert(sizeof... (J) >= sizeof... (I), "broadcast shapes do not match."); 327 328 using type = xt::fixed_shape<broadcast_fixed_shape_cmp_impl<JX, fixed_shape<I...>, fixed_shape<J...>>::ordinate...>; 329 static constexpr bool value = xtl::conjunction<broadcast_fixed_shape_cmp_impl<JX, fixed_shape<I...>, fixed_shape<J...>>...>::value; 330 }; 331 332 /* broadcast_fixed_shape<fixed_shape<I...>, fixed_shape<J...>> 333 * Just like a call to broadcast_shape(cont S1& input, S2& output), 334 * except that the result shape is alised as type, and the returned 335 * bool is the member value. Asserts on an illegal broadcast, including 336 * the case where pack I is strictly longer than pack J. */ 337 338 template <std::size_t... I, std::size_t... J> 339 struct broadcast_fixed_shape<fixed_shape<I...>, fixed_shape<J...>> 340 : broadcast_fixed_shape_impl<std::make_index_sequence<sizeof...(J)>, fixed_shape<I...>, fixed_shape<J...>> {}; 341 342 // Simple is_array and only_array meta-functions 343 template <class S> 344 struct is_array 345 { 346 static constexpr bool value = false; 347 }; 348 349 template <class T, std::size_t N> 350 struct is_array<std::array<T, N>> 351 { 352 static constexpr bool value = true; 353 }; 354 355 template <class S> 356 struct is_fixed : std::false_type 357 { 358 }; 359 360 template <std::size_t... N> 361 struct is_fixed<fixed_shape<N...>> 362 : std::true_type 363 { 364 }; 365 366 template <class S> 367 struct is_scalar_shape 368 { 369 static constexpr bool value = false; 370 }; 371 372 template <class T> 373 struct is_scalar_shape<std::array<T, 0>> 374 { 375 static constexpr bool value = true; 376 }; 377 378 template <class... S> 379 using only_array = xtl::conjunction<xtl::disjunction<is_array<S>, is_fixed<S>>...>; 380 381 // test that at least one argument is a fixed shape. If yes, then either argument has to be fixed or scalar 382 template <class... S> 383 using only_fixed = std::integral_constant<bool, xtl::disjunction<is_fixed<S>...>::value && 384 xtl::conjunction<xtl::disjunction<is_fixed<S>, is_scalar_shape<S>>...>::value>; 385 386 template <class... S> 387 using all_fixed = xtl::conjunction<is_fixed<S>...>; 388 389 // The promote_index meta-function returns std::vector<promoted_value_type> in the 390 // general case and an array of the promoted value type and maximal size if all 391 // arguments are of type std::array 392 393 template <class... S> 394 struct promote_array 395 { 396 using type = std::array<typename std::common_type<typename S::value_type...>::type, max_array_size<S...>::value>; 397 }; 398 399 template <> 400 struct promote_array<> 401 { 402 using type = std::array<std::size_t, 0>; 403 }; 404 405 template <class S> 406 struct filter_scalar 407 { 408 using type = S; 409 }; 410 411 template <class T> 412 struct filter_scalar<std::array<T, 0>> 413 { 414 using type = fixed_shape<1>; 415 }; 416 417 template <class S> 418 using filter_scalar_t = typename filter_scalar<S>::type; 419 420 template <class... S> 421 struct promote_fixed : promote_fixed<filter_scalar_t<S>...> {}; 422 423 template <std::size_t... I> 424 struct promote_fixed<fixed_shape<I...>> 425 { 426 using type = fixed_shape<I...>; 427 static constexpr bool value = true; 428 }; 429 430 template <std::size_t... I, std::size_t... J, class... S> 431 struct promote_fixed<fixed_shape<I...>, fixed_shape<J...>, S...> 432 { 433 private: 434 435 using intermediate = std::conditional_t< (sizeof... (I) > sizeof... (J)), 436 broadcast_fixed_shape<fixed_shape<J...>, fixed_shape<I...>>, 437 broadcast_fixed_shape<fixed_shape<I...>, fixed_shape<J...>>>; 438 using result = promote_fixed<typename intermediate::type, S...>; 439 440 public: 441 442 using type = typename result::type; 443 static constexpr bool value = xtl::conjunction<intermediate, result>::value; 444 }; 445 446 template <bool all_index, bool all_array, class... S> 447 struct select_promote_index; 448 449 template <class... S> 450 struct select_promote_index<true, true, S...> : promote_fixed<S...> {}; 451 452 template <> 453 struct select_promote_index<true, true> 454 { 455 // todo correct? used in xvectorize 456 using type = dynamic_shape<std::size_t>; 457 }; 458 459 template <class... S> 460 struct select_promote_index<false, true, S...> : promote_array<S...> {}; 461 462 template <class... S> 463 struct select_promote_index<false, false, S...> 464 { 465 using type = dynamic_shape<typename std::common_type<typename S::value_type...>::type>; 466 }; 467 468 template <class... S> 469 struct promote_index : select_promote_index<only_fixed<S...>::value, only_array<S...>::value, S...> {}; 470 471 template <class T> 472 struct index_from_shape_impl 473 { 474 using type = T; 475 }; 476 477 template <std::size_t... N> 478 struct index_from_shape_impl<fixed_shape<N...>> 479 { 480 using type = std::array<std::size_t, sizeof...(N)>; 481 }; 482 } 483 484 template <class... S> 485 struct promote_shape 486 { 487 using type = typename detail::promote_index<S...>::type; 488 }; 489 490 template <class... S> 491 using promote_shape_t = typename promote_shape<S...>::type; 492 493 template <class... S> 494 struct promote_strides 495 { 496 using type = typename detail::promote_index<S...>::type; 497 }; 498 499 template <class... S> 500 using promote_strides_t = typename promote_strides<S...>::type; 501 502 template <class S> 503 struct index_from_shape 504 { 505 using type = typename detail::index_from_shape_impl<S>::type; 506 }; 507 508 template <class S> 509 using index_from_shape_t = typename index_from_shape<S>::type; 510 511 /********************** 512 * filter_fixed_shape * 513 **********************/ 514 515 namespace detail 516 { 517 template <class S> 518 struct filter_fixed_shape_impl 519 { 520 using type = S; 521 }; 522 523 template <std::size_t... N> 524 struct filter_fixed_shape_impl<fixed_shape<N...>> 525 { 526 using type = std::array<std::size_t, sizeof...(N)>; 527 }; 528 } 529 530 template <class S> 531 struct filter_fixed_shape : detail::filter_fixed_shape_impl<S> 532 { 533 }; 534 535 template <class S> 536 using filter_fixed_shape_t = typename filter_fixed_shape<S>::type; 537 } 538 539 #endif 540