1 /*************************************************************************** 2 * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht * 3 * Copyright (c) QuantStack * 4 * * 5 * Distributed under the terms of the BSD 3-Clause License. * 6 * * 7 * The full license is in the file LICENSE, distributed with this software. * 8 ****************************************************************************/ 9 10 #include "test_common_macros.hpp" 11 12 #include <cstddef> 13 14 #include "xtensor/xexpression.hpp" 15 #include "xtensor/xio.hpp" 16 #include "xtensor/xtensor.hpp" 17 #include "xtensor/xview.hpp" 18 19 struct field_expression_tag 20 { 21 }; 22 23 template<class D> 24 class field_expression : public xt::xexpression<D> { 25 public: 26 using expression_tag = field_expression_tag; 27 }; 28 29 template<class F, class... CT> 30 class field_function : public field_expression<field_function<F, CT...>> { 31 public: 32 using self_type = field_function<F, CT...>; 33 using functor_type = std::remove_reference_t<F>; 34 35 using expression_tag = field_expression_tag; 36 37 template<class Func, class... CTA, 38 class U = std::enable_if<!std::is_base_of<Func, self_type>::value>> field_function(Func && f,CTA &&...e)39 field_function(Func &&f, CTA &&... e) noexcept 40 : m_e(std::forward<CTA>(e)...), m_f(std::forward<Func>(f)) 41 {} 42 43 template<class... T> operator ()(const std::size_t begin,const std::size_t end) const44 auto operator()(const std::size_t begin, const std::size_t end) const 45 { 46 return evaluate(std::make_index_sequence<sizeof...(CT)>(), begin, end); 47 } 48 49 template<std::size_t... I, class... T> evaluate(std::index_sequence<I...>,T &&...t) const50 auto evaluate(std::index_sequence<I...>, T &&... t) const 51 { 52 return m_f( 53 std::get<I>(m_e).operator()(std::forward<T>(t)...)...); 54 } 55 56 private: 57 std::tuple<CT...> m_e; 58 functor_type m_f; 59 }; 60 61 namespace xt 62 { 63 namespace detail 64 { 65 template<class F, class... E> 66 struct select_xfunction_expression<field_expression_tag, F, E...> 67 { 68 using type = field_function<F, E...>; 69 }; 70 } 71 } 72 73 // using xt::operator+; 74 // using xt::operator-; 75 // using xt::operator*; 76 // using xt::operator/; 77 // using xt::operator%; 78 79 struct Field : public field_expression<Field> 80 { FieldField81 Field() : m_data(std::array<std::size_t, 1>{10}) 82 {} 83 operator ()Field84 auto operator()(const std::size_t begin, const std::size_t end) const 85 { 86 return xt::view(m_data, xt::range(begin, end)); 87 } 88 operator ()Field89 auto operator()(const std::size_t begin, const std::size_t end) 90 { 91 return xt::view(m_data, xt::range(begin, end)); 92 } 93 94 template<class E> operator =Field95 Field &operator=(const field_expression<E> &e) 96 { 97 (*this)(0, 5) = e.derived_cast()(0, 5); 98 return *this; 99 } 100 101 xt::xtensor<double, 1> m_data; 102 }; 103 TEST(xfunc_on_xexpression,field_expression)104 TEST(xfunc_on_xexpression, field_expression) 105 { 106 Field x, y; 107 xt::xtensor<double , 1> res{{20, 20, 20, 20, 20, 0, 0, 0, 0, 0}}; 108 x.m_data.fill(10); 109 y.m_data.fill(0); 110 111 y = x + x; 112 113 EXPECT_EQ(y.m_data, res); 114 } 115 TEST(xfunc_on_xexpression,copy_constructor)116 TEST(xfunc_on_xexpression, copy_constructor) 117 { 118 // Compilation test only 119 // checks that there is no ambiguity among xfunction constructors 120 xt::xtensor<double, 1> x{{1, 2}}, y{{3, 4}}; 121 xt::xtensor<double, 1> res{{4, 6}}; 122 123 auto expr = x + y; 124 decltype(expr) expr2{expr}; 125 126 EXPECT_EQ(xt::eval(expr2), res); 127 } 128