1 /***************************************************************************
2 * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht          *
3 * Copyright (c) QuantStack                                                 *
4 *                                                                          *
5 * Distributed under the terms of the BSD 3-Clause License.                 *
6 *                                                                          *
7 * The full license is in the file LICENSE, distributed with this software. *
8 ****************************************************************************/
9 
10 #ifndef XTENSOR_XSHAPE_HPP
11 #define XTENSOR_XSHAPE_HPP
12 
13 #include <algorithm>
14 #include <cassert>
15 #include <cstddef>
16 #include <cstdlib>
17 #include <cstring>
18 #include <initializer_list>
19 #include <iterator>
20 #include <memory>
21 
22 #include "xlayout.hpp"
23 #include "xstorage.hpp"
24 #include "xtensor_forward.hpp"
25 
26 namespace xt
27 {
28     template <class T>
29     using dynamic_shape = svector<T, 4>;
30 
31     template <class T, std::size_t N>
32     using static_shape = std::array<T, N>;
33 
34     template <std::size_t... X>
35     class fixed_shape;
36 
37     using xindex = dynamic_shape<std::size_t>;
38 
39     template <class S1, class S2>
40     bool same_shape(const S1& s1, const S2& s2) noexcept;
41 
42     template <class U>
43     struct initializer_dimension;
44 
45     template <class R, class T>
46     constexpr R shape(T t);
47 
48     template<class R = std::size_t, class T, std::size_t N>
49     xt::static_shape<R, N> shape(const T(&aList)[N]);
50 
51     template <class S>
52     struct static_dimension;
53 
54     template <layout_type L, class S>
55     struct select_layout;
56 
57     template <class... S>
58     struct promote_shape;
59 
60     template <class... S>
61     struct promote_strides;
62 
63     template <class S>
64     struct index_from_shape;
65 }
66 
67 namespace xtl
68 {
69     namespace detail
70     {
71         template <class S>
72         struct sequence_builder;
73 
74         template <std::size_t... I>
75         struct sequence_builder<xt::fixed_shape<I...>>
76         {
77             using sequence_type = xt::fixed_shape<I...>;
78             using value_type = typename sequence_type::value_type;
79 
makextl::detail::sequence_builder80             inline static sequence_type make(std::size_t /*size*/)
81             {
82                 return sequence_type{};
83             }
84 
makextl::detail::sequence_builder85             inline static sequence_type make(std::size_t /*size*/, value_type /*v*/)
86             {
87                 return sequence_type{};
88             }
89         };
90     }
91 }
92 
93 namespace xt
94 {
95     /**************
96      * same_shape *
97      **************/
98 
/**
 * @ingroup same_shape
 * @brief same_shape
 *
 * Check whether two shapes are equal, i.e. have the same number of
 * dimensions and the same extent along every dimension.
 * @param s1 the first shape
 * @param s2 the second shape
 * @return true if @p s1 and @p s2 have the same size and elements
 */
template <class S1, class S2>
inline bool same_shape(const S1& s1, const S2& s2) noexcept
{
    if (s1.size() != s2.size())
    {
        return false;
    }
    return std::equal(s1.begin(), s1.end(), s2.begin());
}
113 
114     /*************
115      * has_shape *
116      *************/
117 
/**
 * @ingroup has_shape
 * @brief has_shape
 *
 * Check if an object has a certain shape.
 * @param e the object to test (anything exposing ``shape()``)
 * @param shape the expected shape, given as a braced list of extents
 * @return true if ``e.shape()`` has the same size and elements as @p shape
 */
template <class E, class S>
inline bool has_shape(const E& e, std::initializer_list<S> shape) noexcept
{
    // Evaluate shape() once: if it returns by value, calling it separately
    // for cbegin() and cend() would compare iterators of two distinct
    // temporaries.
    const auto& actual = e.shape();
    return actual.size() == shape.size() && std::equal(actual.cbegin(), actual.cend(), shape.begin());
}
132 
133     /**
134     * @ingroup has_shape
135     * @brief has_shape
136     *
137     * Check if an object has a certain shape.
138     * @param a an array
139     * @param shape the shape to test
140     * @return bool
141     */
142     template <class E, class S, class = typename std::enable_if_t<has_iterator_interface<S>::value>>
has_shape(const E & e,const S & shape)143     inline bool has_shape(const E& e, const S& shape)
144     {
145         return e.shape().size() == shape.size() && std::equal(e.shape().cbegin(), e.shape().cend(), shape.begin());
146     }
147 
148     /*************************
149      * initializer_dimension *
150      *************************/
151 
namespace detail
{
    // Number of std::initializer_list layers wrapping a type:
    // 0 for a non-list type, depth(T) + 1 for std::initializer_list<T>.
    template <class U>
    struct initializer_depth_impl : std::integral_constant<std::size_t, 0>
    {
    };

    template <class T>
    struct initializer_depth_impl<std::initializer_list<T>>
        : std::integral_constant<std::size_t, initializer_depth_impl<T>::value + 1>
    {
    };
}

/**
 * Number of dimensions described by a (possibly nested) initializer list
 * type U, i.e. its initializer_list nesting depth.
 */
template <class U>
struct initializer_dimension
{
    static constexpr std::size_t value = detail::initializer_depth_impl<U>::value;
};
172 
173     /*********************
174      * initializer_shape *
175      *********************/
176 
namespace detail
{
    // initializer_shape_impl<I>::value(t) is the extent of the nested
    // initializer list t along dimension I, found by following the first
    // element at each level. An empty list at some level yields 0 for every
    // deeper dimension.
    template <std::size_t I>
    struct initializer_shape_impl
    {
        template <class T>
        static constexpr std::size_t value(T t)
        {
            return t.size() != 0 ? initializer_shape_impl<I - 1>::value(*t.begin()) : 0;
        }
    };

    template <>
    struct initializer_shape_impl<0>
    {
        template <class T>
        static constexpr std::size_t value(T t)
        {
            return t.size();
        }
    };

    // Builds an R whose k-th entry is the extent of t along dimension I_k.
    template <class R, class U, std::size_t... I>
    constexpr R initializer_shape(U t, std::index_sequence<I...>)
    {
        return {static_cast<typename R::value_type>(initializer_shape_impl<I>::value(t))...};
    }
}
206 
207     template <class R, class T>
shape(T t)208     constexpr R shape(T t)
209     {
210         return detail::initializer_shape<R, decltype(t)>(t, std::make_index_sequence<initializer_dimension<decltype(t)>::value>());
211     }
212 
213     /** @brief Generate an xt::static_shape of the given size. */
214     template<class R, class T, std::size_t N>
shape(const T (& list)[N])215     xt::static_shape<R, N> shape(const T(&list)[N]) {
216         xt::static_shape<R, N> shape;
217         std::copy(std::begin(list), std::end(list), std::begin(shape));
218         return shape;
219     }
220 
221     /********************
222      * static_dimension *
223      ********************/
224 
225     namespace detail
226     {
227         template <class T, class E = void>
228         struct static_dimension_impl
229         {
230             static constexpr std::ptrdiff_t value = -1;
231         };
232 
233         template <class T>
234         struct static_dimension_impl<T, void_t<decltype(std::tuple_size<T>::value)>>
235         {
236             static constexpr std::ptrdiff_t value = static_cast<std::ptrdiff_t>(std::tuple_size<T>::value);
237         };
238     }
239 
240     template <class S>
241     struct static_dimension
242     {
243         static constexpr std::ptrdiff_t value = detail::static_dimension_impl<S>::value;
244     };
245 
246     /**
247      * Compute a layout based on a layout and a shape type.
248      *
249      * The main functionality of this function is that it reduces vectors to
250      * ``layout_type::any`` so that assigning a row major 1D container to another
251      * row_major container becomes free.
252      */
253     template <layout_type L, class S>
254     struct select_layout
255     {
256         constexpr static std::ptrdiff_t static_dimension = xt::static_dimension<S>::value;
257         constexpr static bool is_any = static_dimension != -1 && static_dimension <= 1 && L != layout_type::dynamic;
258         constexpr static layout_type value = is_any ? layout_type::any : L;
259     };
260 
261     /*************************************
262      * promote_shape and promote_strides *
263      *************************************/
264 
265     namespace detail
266     {
// Compile-time maximum of two values, returned in their common type.
template <class T1, class T2>
constexpr std::common_type_t<T1, T2> imax(const T1& a, const T2& b)
{
    return (a > b) ? a : b;
}

// max_array_size<T...>::value is the largest std::tuple_size among T...,
// or 0 for an empty pack.
template <class... T>
struct max_array_size;

template <>
struct max_array_size<> : std::integral_constant<std::size_t, 0>
{
};

template <class T, class... Ts>
struct max_array_size<T, Ts...>
    : std::integral_constant<std::size_t, imax(std::tuple_size<T>::value, max_array_size<Ts...>::value)>
{
};
287 
// Broadcasting for fixed shapes

// at<IDX, X...>::value is the IDX-th element of the pack X..., or 0 when
// IDX is out of range (including when the pack is empty).
template <std::size_t IDX, std::size_t... X>
struct at
{
    // The trailing 0 sentinel keeps the array non-empty for an empty pack
    // (zero-length arrays are ill-formed in standard C++, e.g. for
    // fixed_shape<>) and gives out-of-range indices a valid slot to read.
    constexpr static std::size_t arr[sizeof...(X) + 1] = {X..., 0};
    constexpr static std::size_t value = arr[(IDX < sizeof...(X)) ? IDX : sizeof...(X)];
};
295 
296         template <class S1, class S2>
297         struct broadcast_fixed_shape;
298 
299         template <class IX, class A, class B>
300         struct broadcast_fixed_shape_impl;
301 
302         template <std::size_t IX, class A, class B>
303         struct broadcast_fixed_shape_cmp_impl;
304 
305         template <std::size_t JX, std::size_t... I, std::size_t... J>
306         struct broadcast_fixed_shape_cmp_impl<JX, fixed_shape<I...>, fixed_shape<J...>>
307         {
308             //We line the shapes up from the last index
309             //IX may underflow, thus being a very large number
310             static constexpr std::size_t IX = JX - (sizeof...(J) - sizeof...(I));
311 
312             //Out of bounds access gives value 0
313             static constexpr std::size_t I_v = at<IX, I...>::value;
314             static constexpr std::size_t J_v = at<JX, J...>::value;
315 
316             // we're statically checking if the broadcast shapes are either one on either of them or equal
317             static_assert(!I_v ||  I_v == 1 || J_v == 1 || J_v == I_v, "broadcast shapes do not match.");
318 
319             static constexpr std::size_t ordinate = (I_v > J_v) ? I_v : J_v;
320             static constexpr bool value = (I_v == J_v);
321         };
322 
323         template <std::size_t... JX, std::size_t... I, std::size_t... J>
324         struct broadcast_fixed_shape_impl<std::index_sequence<JX...>, fixed_shape<I...>, fixed_shape<J...>>
325         {
326             static_assert(sizeof... (J) >= sizeof... (I), "broadcast shapes do not match.");
327 
328             using type = xt::fixed_shape<broadcast_fixed_shape_cmp_impl<JX, fixed_shape<I...>, fixed_shape<J...>>::ordinate...>;
329             static constexpr bool value = xtl::conjunction<broadcast_fixed_shape_cmp_impl<JX, fixed_shape<I...>, fixed_shape<J...>>...>::value;
330         };
331 
332         /* broadcast_fixed_shape<fixed_shape<I...>, fixed_shape<J...>>
333          * Just like a call to broadcast_shape(cont S1& input, S2& output),
334          * except that the result shape is alised as type, and the returned
335          * bool is the member value. Asserts on an illegal broadcast, including
336          * the case where pack I is strictly longer than pack J. */
337 
338         template <std::size_t... I, std::size_t... J>
339         struct broadcast_fixed_shape<fixed_shape<I...>, fixed_shape<J...>>
340             : broadcast_fixed_shape_impl<std::make_index_sequence<sizeof...(J)>, fixed_shape<I...>, fixed_shape<J...>> {};
341 
// Simple is_array and only_array meta-functions

// is_array<S>::value is true iff S is a std::array.
template <class S>
struct is_array : std::false_type
{
};

template <class T, std::size_t N>
struct is_array<std::array<T, N>> : std::true_type
{
};
354 
355         template <class S>
356         struct is_fixed : std::false_type
357         {
358         };
359 
360         template <std::size_t... N>
361         struct is_fixed<fixed_shape<N...>>
362             : std::true_type
363         {
364         };
365 
// is_scalar_shape<S>::value is true iff S is a zero-length std::array,
// i.e. a shape with no dimensions.
template <class S>
struct is_scalar_shape : std::integral_constant<bool, false>
{
};

template <class T>
struct is_scalar_shape<std::array<T, 0>> : std::integral_constant<bool, true>
{
};
377 
378         template <class... S>
379         using only_array = xtl::conjunction<xtl::disjunction<is_array<S>, is_fixed<S>>...>;
380 
381         // test that at least one argument is a fixed shape. If yes, then either argument has to be fixed or scalar
382         template <class... S>
383         using only_fixed = std::integral_constant<bool, xtl::disjunction<is_fixed<S>...>::value &&
384                                                         xtl::conjunction<xtl::disjunction<is_fixed<S>, is_scalar_shape<S>>...>::value>;
385 
386         template <class... S>
387         using all_fixed = xtl::conjunction<is_fixed<S>...>;
388 
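// The promote_index meta-function returns dynamic_shape<promoted_value_type> in the
// general case, an std::array of the promoted value type and maximal size if all
// arguments are of type std::array, and the broadcast fixed_shape if at least one
// argument is a fixed_shape and all others are fixed_shapes or scalar shapes.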
392 
393         template <class... S>
394         struct promote_array
395         {
396             using type = std::array<typename std::common_type<typename S::value_type...>::type, max_array_size<S...>::value>;
397         };
398 
399         template <>
400         struct promote_array<>
401         {
402             using type = std::array<std::size_t, 0>;
403         };
404 
405         template <class S>
406         struct filter_scalar
407         {
408             using type = S;
409         };
410 
411         template <class T>
412         struct filter_scalar<std::array<T, 0>>
413         {
414             using type = fixed_shape<1>;
415         };
416 
417         template <class S>
418         using filter_scalar_t = typename filter_scalar<S>::type;
419 
420         template <class... S>
421         struct promote_fixed : promote_fixed<filter_scalar_t<S>...> {};
422 
423         template <std::size_t... I>
424         struct promote_fixed<fixed_shape<I...>>
425         {
426             using type = fixed_shape<I...>;
427             static constexpr bool value = true;
428         };
429 
430         template <std::size_t... I, std::size_t... J, class... S>
431         struct promote_fixed<fixed_shape<I...>, fixed_shape<J...>, S...>
432         {
433         private:
434 
435             using intermediate = std::conditional_t< (sizeof... (I) > sizeof... (J)),
436                 broadcast_fixed_shape<fixed_shape<J...>, fixed_shape<I...>>,
437                 broadcast_fixed_shape<fixed_shape<I...>, fixed_shape<J...>>>;
438             using result = promote_fixed<typename intermediate::type, S...>;
439 
440         public:
441 
442             using type = typename result::type;
443             static constexpr bool value = xtl::conjunction<intermediate, result>::value;
444         };
445 
446         template <bool all_index, bool all_array, class... S>
447         struct select_promote_index;
448 
449         template <class... S>
450         struct select_promote_index<true, true, S...> : promote_fixed<S...> {};
451 
452         template <>
453         struct select_promote_index<true, true>
454         {
455             // todo correct? used in xvectorize
456             using type = dynamic_shape<std::size_t>;
457         };
458 
459         template <class... S>
460         struct select_promote_index<false, true, S...> : promote_array<S...> {};
461 
462         template <class... S>
463         struct select_promote_index<false, false, S...>
464         {
465             using type = dynamic_shape<typename std::common_type<typename S::value_type...>::type>;
466         };
467 
468         template <class... S>
469         struct promote_index : select_promote_index<only_fixed<S...>::value, only_array<S...>::value, S...> {};
470 
471         template <class T>
472         struct index_from_shape_impl
473         {
474             using type = T;
475         };
476 
477         template <std::size_t... N>
478         struct index_from_shape_impl<fixed_shape<N...>>
479         {
480             using type = std::array<std::size_t, sizeof...(N)>;
481         };
482     }
483 
484     template <class... S>
485     struct promote_shape
486     {
487         using type = typename detail::promote_index<S...>::type;
488     };
489 
490     template <class... S>
491     using promote_shape_t = typename promote_shape<S...>::type;
492 
493     template <class... S>
494     struct promote_strides
495     {
496         using type = typename detail::promote_index<S...>::type;
497     };
498 
499     template <class... S>
500     using promote_strides_t = typename promote_strides<S...>::type;
501 
502     template <class S>
503     struct index_from_shape
504     {
505         using type = typename detail::index_from_shape_impl<S>::type;
506     };
507 
508     template <class S>
509     using index_from_shape_t = typename index_from_shape<S>::type;
510 
511     /**********************
512      * filter_fixed_shape *
513      **********************/
514 
515     namespace detail
516     {
517         template <class S>
518         struct filter_fixed_shape_impl
519         {
520             using type = S;
521         };
522 
523         template <std::size_t... N>
524         struct filter_fixed_shape_impl<fixed_shape<N...>>
525         {
526             using type = std::array<std::size_t, sizeof...(N)>;
527         };
528     }
529 
530     template <class S>
531     struct filter_fixed_shape : detail::filter_fixed_shape_impl<S>
532     {
533     };
534 
535     template <class S>
536     using filter_fixed_shape_t = typename filter_fixed_shape<S>::type;
537 }
538 
539 #endif
540