1 /*************************************************************************** 2 * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht * 3 * Copyright (c) QuantStack * 4 * * 5 * Distributed under the terms of the BSD 3-Clause License. * 6 * * 7 * The full license is in the file LICENSE, distributed with this software. * 8 ****************************************************************************/ 9 10 #ifndef XTENSOR_ZARRAY_WRAPPER_HPP 11 #define XTENSOR_ZARRAY_WRAPPER_HPP 12 13 #include "zarray_impl.hpp" 14 15 namespace xt 16 { 17 /****************** 18 * zarray_wrapper * 19 ******************/ 20 21 template <class CTE> 22 class zarray_wrapper : public ztyped_array<typename std::decay_t<CTE>::value_type> 23 { 24 public: 25 26 using self_type = zarray_wrapper; 27 using value_type = typename std::decay_t<CTE>::value_type; 28 using base_type = ztyped_array<value_type>; 29 using shape_type = typename base_type::shape_type; 30 using slice_vector = typename base_type::slice_vector; 31 32 template <class E> 33 zarray_wrapper(E&& e); 34 35 virtual ~zarray_wrapper() = default; 36 37 bool is_array() const override; 38 bool is_chunked() const override; 39 40 xarray<value_type>& get_array() override; 41 const xarray<value_type>& get_array() const override; 42 xarray<value_type> get_chunk(const slice_vector& slices) const override; 43 44 self_type* clone() const override; 45 std::ostream& print(std::ostream& out) const override; 46 47 zarray_impl* strided_view(slice_vector& slices) override; 48 49 const nlohmann::json& get_metadata() const override; 50 void set_metadata(const nlohmann::json& metadata) override; 51 std::size_t dimension() const override; 52 const shape_type& shape() const override; 53 void reshape(const shape_type&) override; 54 void reshape(shape_type&&) override; 55 void resize(const shape_type&) override; 56 void resize(shape_type&&) override; 57 bool broadcast_shape(shape_type& shape, bool reuse_cache = 0) const override; 58 59 private: 60 61 zarray_wrapper(const zarray_wrapper&) = default; 62 63 CTE m_array; 64 nlohmann::json m_metadata; 65 }; 66 67 /********************************* 68 * zarray_wrapper implementation * 69 *********************************/ 70 71 namespace detail 72 { 73 template <class T> 74 struct zarray_wrapper_helper 75 { get_arrayxt::detail::zarray_wrapper_helper76 static inline xarray<T>& get_array(xarray<T>& ar) 77 { 78 return ar; 79 } 80 get_arrayxt::detail::zarray_wrapper_helper81 static inline xarray<T>& get_array(const xarray<T>&) 82 { 83 throw std::runtime_error("Cannot return non const array from const array"); 84 } 85 86 template <class S> reshapext::detail::zarray_wrapper_helper87 static inline void reshape(xarray<T>& ar, S&& shape) 88 { 89 ar.reshape(std::forward<S>(shape)); 90 } 91 92 template <class S> reshapext::detail::zarray_wrapper_helper93 static inline void reshape(const xarray<T>&, S&&) 94 { 95 throw std::runtime_error("Cannot reshape const array"); 96 } 97 98 template <class S> resizext::detail::zarray_wrapper_helper99 static inline void resize(xarray<T>& ar, S&& shape) 100 { 101 ar.resize(std::forward<S>(shape)); 102 } 103 104 template <class S> resizext::detail::zarray_wrapper_helper105 static inline void resize(const xarray<T>&, S&&) 106 { 107 throw std::runtime_error("Cannot resize const array"); 108 } 109 }; 110 } 111 112 template <class CTE> 113 template <class E> zarray_wrapper(E && e)114 inline zarray_wrapper<CTE>::zarray_wrapper(E&& e) 115 : base_type() 116 , m_array(std::forward<E>(e)) 117 { 118 detail::set_data_type<value_type>(m_metadata); 119 } 120 121 template <class CTE> is_array() const122 bool zarray_wrapper<CTE>::is_array() const 123 { 124 return true; 125 } 126 127 template <class CTE> is_chunked() const128 bool zarray_wrapper<CTE>::is_chunked() const 129 { 130 return false; 131 } 132 133 template <class CTE> get_array()134 auto zarray_wrapper<CTE>::get_array() -> xarray<value_type>& 135 { 136 return detail::zarray_wrapper_helper<value_type>::get_array(m_array); 137 } 138 139 template <class CTE> get_array() const140 auto zarray_wrapper<CTE>::get_array() const -> const xarray<value_type>& 141 { 142 return m_array; 143 } 144 145 template <class CTE> get_chunk(const slice_vector & slices) const146 auto zarray_wrapper<CTE>::get_chunk(const slice_vector& slices) const -> xarray<value_type> 147 { 148 return xt::strided_view(m_array, slices); 149 } 150 151 template <class CTE> clone() const152 auto zarray_wrapper<CTE>::clone() const -> self_type* 153 { 154 return new self_type(*this); 155 } 156 157 template <class CTE> print(std::ostream & out) const158 std::ostream& zarray_wrapper<CTE>::print(std::ostream& out) const 159 { 160 return out << m_array; 161 } 162 163 template <class CTE> strided_view(slice_vector & slices)164 zarray_impl* zarray_wrapper<CTE>::strided_view(slice_vector& slices) 165 { 166 auto e = xt::strided_view(m_array, slices); 167 return detail::build_zarray(std::move(e)); 168 } 169 170 template <class CTE> get_metadata() const171 auto zarray_wrapper<CTE>::get_metadata() const -> const nlohmann::json& 172 { 173 return m_metadata; 174 } 175 176 template <class CTE> set_metadata(const nlohmann::json & metadata)177 void zarray_wrapper<CTE>::set_metadata(const nlohmann::json& metadata) 178 { 179 m_metadata = metadata; 180 } 181 182 template <class CTE> dimension() const183 std::size_t zarray_wrapper<CTE>::dimension() const 184 { 185 return m_array.dimension(); 186 } 187 188 template <class CTE> shape() const189 auto zarray_wrapper<CTE>::shape() const -> const shape_type& 190 { 191 return m_array.shape(); 192 } 193 194 template <class CTE> reshape(const shape_type & shape)195 void zarray_wrapper<CTE>::reshape(const shape_type& shape) 196 { 197 detail::zarray_wrapper_helper<value_type>::reshape(m_array, shape); 198 } 199 200 template <class CTE> reshape(shape_type && shape)201 void zarray_wrapper<CTE>::reshape(shape_type&& shape) 202 { 203 detail::zarray_wrapper_helper<value_type>::reshape(m_array, std::move(shape)); 204 } 205 206 template <class CTE> resize(const shape_type & shape)207 void zarray_wrapper<CTE>::resize(const shape_type& shape) 208 { 209 detail::zarray_wrapper_helper<value_type>::resize(m_array, shape); 210 } 211 212 template <class CTE> resize(shape_type && shape)213 void zarray_wrapper<CTE>::resize(shape_type&& shape) 214 { 215 detail::zarray_wrapper_helper<value_type>::resize(m_array, std::move(shape)); 216 } 217 218 template <class CTE> broadcast_shape(shape_type & shape,bool reuse_cache) const219 bool zarray_wrapper<CTE>::broadcast_shape(shape_type& shape, bool reuse_cache) const 220 { 221 return m_array.broadcast_shape(shape, reuse_cache); 222 } 223 } 224 225 #endif 226 227