1 /***************************************************************************
2 * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht          *
3 * Copyright (c) QuantStack                                                 *
4 *                                                                          *
5 * Distributed under the terms of the BSD 3-Clause License.                 *
6 *                                                                          *
7 * The full license is in the file LICENSE, distributed with this software. *
8 ****************************************************************************/
9 
10 #include "test_common_macros.hpp"
11 
12 #include <cstddef>
13 
14 #include "xtensor/xexpression.hpp"
15 #include "xtensor/xio.hpp"
16 #include "xtensor/xtensor.hpp"
17 #include "xtensor/xview.hpp"
18 
19     struct field_expression_tag
20     {
21     };
22 
23     template<class D>
24     class field_expression : public xt::xexpression<D> {
25     public:
26         using expression_tag = field_expression_tag;
27     };
28 
29     template<class F, class... CT>
30     class field_function : public field_expression<field_function<F, CT...>> {
31     public:
32         using self_type = field_function<F, CT...>;
33         using functor_type = std::remove_reference_t<F>;
34 
35         using expression_tag = field_expression_tag;
36 
37         template<class Func, class... CTA,
38                 class U = std::enable_if<!std::is_base_of<Func, self_type>::value>>
field_function(Func && f,CTA &&...e)39         field_function(Func &&f, CTA &&... e) noexcept
40             : m_e(std::forward<CTA>(e)...), m_f(std::forward<Func>(f))
41         {}
42 
43         template<class... T>
operator ()(const std::size_t begin,const std::size_t end) const44         auto operator()(const std::size_t begin, const std::size_t end) const
45         {
46             return evaluate(std::make_index_sequence<sizeof...(CT)>(), begin, end);
47         }
48 
49         template<std::size_t... I, class... T>
evaluate(std::index_sequence<I...>,T &&...t) const50         auto evaluate(std::index_sequence<I...>, T &&... t) const
51         {
52             return m_f(
53                 std::get<I>(m_e).operator()(std::forward<T>(t)...)...);
54         }
55 
56     private:
57         std::tuple<CT...> m_e;
58         functor_type m_f;
59     };
60 
61     namespace xt
62     {
63         namespace detail
64         {
65             template<class F, class... E>
66             struct select_xfunction_expression<field_expression_tag, F, E...>
67             {
68                 using type = field_function<F, E...>;
69             };
70         }
71     }
72 
73     // using xt::operator+;
74     // using xt::operator-;
75     // using xt::operator*;
76     // using xt::operator/;
77     // using xt::operator%;
78 
79     struct Field : public field_expression<Field>
80     {
FieldField81         Field() : m_data(std::array<std::size_t, 1>{10})
82         {}
83 
operator ()Field84         auto operator()(const std::size_t begin, const std::size_t end) const
85         {
86             return xt::view(m_data, xt::range(begin, end));
87         }
88 
operator ()Field89         auto operator()(const std::size_t begin, const std::size_t end)
90         {
91             return xt::view(m_data, xt::range(begin, end));
92         }
93 
94         template<class E>
operator =Field95         Field &operator=(const field_expression<E> &e)
96         {
97             (*this)(0, 5) = e.derived_cast()(0, 5);
98             return *this;
99         }
100 
101         xt::xtensor<double, 1> m_data;
102     };
103 
TEST(xfunc_on_xexpression,field_expression)104     TEST(xfunc_on_xexpression, field_expression)
105     {
106         Field x, y;
107         xt::xtensor<double , 1> res{{20, 20, 20, 20, 20, 0, 0, 0, 0, 0}};
108         x.m_data.fill(10);
109         y.m_data.fill(0);
110 
111         y = x + x;
112 
113         EXPECT_EQ(y.m_data, res);
114     }
115 
TEST(xfunc_on_xexpression,copy_constructor)116     TEST(xfunc_on_xexpression, copy_constructor)
117     {
118         // Compilation test only
119         // checks that there is no ambiguity among xfunction constructors
120         xt::xtensor<double, 1> x{{1, 2}}, y{{3, 4}};
121         xt::xtensor<double, 1> res{{4, 6}};
122 
123         auto expr = x + y;
124         decltype(expr) expr2{expr};
125 
126         EXPECT_EQ(xt::eval(expr2), res);
127     }
128