1 /***************************************************************************
2 * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht          *
3 * Copyright (c) QuantStack                                                 *
4 *                                                                          *
5 * Distributed under the terms of the BSD 3-Clause License.                 *
6 *                                                                          *
7 * The full license is in the file LICENSE, distributed with this software. *
8 ****************************************************************************/
9 
10 #ifndef XTENSOR_ZARRAY_WRAPPER_HPP
11 #define XTENSOR_ZARRAY_WRAPPER_HPP
12 
13 #include "zarray_impl.hpp"
14 
15 namespace xt
16 {
17     /******************
18      * zarray_wrapper *
19      ******************/
20 
21     template <class CTE>
22     class zarray_wrapper : public ztyped_array<typename std::decay_t<CTE>::value_type>
23     {
24     public:
25 
26         using self_type = zarray_wrapper;
27         using value_type = typename std::decay_t<CTE>::value_type;
28         using base_type = ztyped_array<value_type>;
29         using shape_type = typename base_type::shape_type;
30         using slice_vector = typename base_type::slice_vector;
31 
32         template <class E>
33         zarray_wrapper(E&& e);
34 
35         virtual ~zarray_wrapper() = default;
36 
37         bool is_array() const override;
38         bool is_chunked() const override;
39 
40         xarray<value_type>& get_array() override;
41         const xarray<value_type>& get_array() const override;
42         xarray<value_type> get_chunk(const slice_vector& slices) const override;
43 
44         self_type* clone() const override;
45         std::ostream& print(std::ostream& out) const override;
46 
47         zarray_impl* strided_view(slice_vector& slices) override;
48 
49         const nlohmann::json& get_metadata() const override;
50         void set_metadata(const nlohmann::json& metadata) override;
51         std::size_t dimension() const override;
52         const shape_type& shape() const override;
53         void reshape(const shape_type&) override;
54         void reshape(shape_type&&) override;
55         void resize(const shape_type&) override;
56         void resize(shape_type&&) override;
57         bool broadcast_shape(shape_type& shape, bool reuse_cache = 0) const override;
58 
59     private:
60 
61         zarray_wrapper(const zarray_wrapper&) = default;
62 
63         CTE m_array;
64         nlohmann::json m_metadata;
65     };
66 
67     /*********************************
68      * zarray_wrapper implementation *
69      *********************************/
70 
71     namespace detail
72     {
73         template <class T>
74         struct zarray_wrapper_helper
75         {
get_arrayxt::detail::zarray_wrapper_helper76             static inline xarray<T>& get_array(xarray<T>& ar)
77             {
78                 return ar;
79             }
80 
get_arrayxt::detail::zarray_wrapper_helper81             static inline xarray<T>& get_array(const xarray<T>&)
82             {
83                 throw std::runtime_error("Cannot return non const array from const array");
84             }
85 
86             template <class S>
reshapext::detail::zarray_wrapper_helper87             static inline void reshape(xarray<T>& ar, S&& shape)
88             {
89                 ar.reshape(std::forward<S>(shape));
90             }
91 
92             template <class S>
reshapext::detail::zarray_wrapper_helper93             static inline void reshape(const xarray<T>&, S&&)
94             {
95                 throw std::runtime_error("Cannot reshape const array");
96             }
97 
98             template <class S>
resizext::detail::zarray_wrapper_helper99             static inline void resize(xarray<T>& ar, S&& shape)
100             {
101                 ar.resize(std::forward<S>(shape));
102             }
103 
104             template <class S>
resizext::detail::zarray_wrapper_helper105             static inline void resize(const xarray<T>&, S&&)
106             {
107                 throw std::runtime_error("Cannot resize const array");
108             }
109         };
110     }
111 
112     template <class CTE>
113     template <class E>
zarray_wrapper(E && e)114     inline zarray_wrapper<CTE>::zarray_wrapper(E&& e)
115         : base_type()
116         , m_array(std::forward<E>(e))
117     {
118         detail::set_data_type<value_type>(m_metadata);
119     }
120 
121     template <class CTE>
is_array() const122     bool zarray_wrapper<CTE>::is_array() const
123     {
124         return true;
125     }
126 
127     template <class CTE>
is_chunked() const128     bool zarray_wrapper<CTE>::is_chunked() const
129     {
130         return false;
131     }
132 
133     template <class CTE>
get_array()134     auto zarray_wrapper<CTE>::get_array() -> xarray<value_type>&
135     {
136         return detail::zarray_wrapper_helper<value_type>::get_array(m_array);
137     }
138 
139     template <class CTE>
get_array() const140     auto zarray_wrapper<CTE>::get_array() const -> const xarray<value_type>&
141     {
142         return m_array;
143     }
144 
145     template <class CTE>
get_chunk(const slice_vector & slices) const146     auto zarray_wrapper<CTE>::get_chunk(const slice_vector& slices) const -> xarray<value_type>
147     {
148         return xt::strided_view(m_array, slices);
149     }
150 
151     template <class CTE>
clone() const152     auto zarray_wrapper<CTE>::clone() const -> self_type*
153     {
154         return new self_type(*this);
155     }
156 
157     template <class CTE>
print(std::ostream & out) const158     std::ostream& zarray_wrapper<CTE>::print(std::ostream& out) const
159     {
160         return out << m_array;
161     }
162 
163     template <class CTE>
strided_view(slice_vector & slices)164     zarray_impl* zarray_wrapper<CTE>::strided_view(slice_vector& slices)
165     {
166         auto e = xt::strided_view(m_array, slices);
167         return detail::build_zarray(std::move(e));
168     }
169 
170     template <class CTE>
get_metadata() const171     auto zarray_wrapper<CTE>::get_metadata() const -> const nlohmann::json&
172     {
173         return m_metadata;
174     }
175 
176     template <class CTE>
set_metadata(const nlohmann::json & metadata)177     void zarray_wrapper<CTE>::set_metadata(const nlohmann::json& metadata)
178     {
179         m_metadata = metadata;
180     }
181 
182     template <class CTE>
dimension() const183     std::size_t zarray_wrapper<CTE>::dimension() const
184     {
185         return m_array.dimension();
186     }
187 
188     template <class CTE>
shape() const189     auto zarray_wrapper<CTE>::shape() const -> const shape_type&
190     {
191         return m_array.shape();
192     }
193 
194     template <class CTE>
reshape(const shape_type & shape)195     void zarray_wrapper<CTE>::reshape(const shape_type& shape)
196     {
197         detail::zarray_wrapper_helper<value_type>::reshape(m_array, shape);
198     }
199 
200     template <class CTE>
reshape(shape_type && shape)201     void zarray_wrapper<CTE>::reshape(shape_type&& shape)
202     {
203         detail::zarray_wrapper_helper<value_type>::reshape(m_array, std::move(shape));
204     }
205 
206     template <class CTE>
resize(const shape_type & shape)207     void zarray_wrapper<CTE>::resize(const shape_type& shape)
208     {
209         detail::zarray_wrapper_helper<value_type>::resize(m_array, shape);
210     }
211 
212     template <class CTE>
resize(shape_type && shape)213     void zarray_wrapper<CTE>::resize(shape_type&& shape)
214     {
215         detail::zarray_wrapper_helper<value_type>::resize(m_array, std::move(shape));
216     }
217 
218     template <class CTE>
broadcast_shape(shape_type & shape,bool reuse_cache) const219     bool zarray_wrapper<CTE>::broadcast_shape(shape_type& shape, bool reuse_cache) const
220     {
221         return m_array.broadcast_shape(shape, reuse_cache);
222     }
223 }
224 
225 #endif
226 
227