1 /** @addtogroup expressions
2  *  @{
3  */
4 /*
5   Copyright (C) 2016 D Levin (https://www.kfrlib.com)
6   This file is part of KFR
7 
8   KFR is free software: you can redistribute it and/or modify
9   it under the terms of the GNU General Public License as published by
10   the Free Software Foundation, either version 2 of the License, or
11   (at your option) any later version.
12 
13   KFR is distributed in the hope that it will be useful,
14   but WITHOUT ANY WARRANTY; without even the implied warranty of
15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16   GNU General Public License for more details.
17 
18   You should have received a copy of the GNU General Public License
19   along with KFR.
20 
21   If GPL is not suitable for your project, you must purchase a commercial license to use KFR.
22   Buying a commercial license is mandatory as soon as you develop commercial activities without
23   disclosing the source code of your own applications.
24   See https://www.kfrlib.com for details.
25  */
26 #pragma once
27 
28 #include "../simd/platform.hpp"
29 #include "../simd/shuffle.hpp"
30 #include "../simd/types.hpp"
31 #include "../simd/vec.hpp"
32 
33 #include <tuple>
34 #ifdef KFR_STD_COMPLEX
35 #include <complex>
36 #endif
37 
38 CMT_PRAGMA_GNU(GCC diagnostic push)
39 CMT_PRAGMA_GNU(GCC diagnostic ignored "-Wshadow")
40 CMT_PRAGMA_GNU(GCC diagnostic ignored "-Wparentheses")
41 
42 namespace kfr
43 {
44 
#ifdef KFR_STD_COMPLEX

/// @brief With KFR_STD_COMPLEX defined, kfr::complex is an alias for std::complex
template <typename T>
using complex = std::complex<T>;

#else
#ifndef KFR_CUSTOM_COMPLEX

/// @brief Forward declaration of the complex template, used when neither KFR_STD_COMPLEX
/// nor KFR_CUSTOM_COMPLEX is defined. The definition is not part of this file.
template <typename>
struct complex;
#endif
#endif
57 
/// @brief Number of pointer slots stored in coutput_context and cinput_context
constexpr size_t inout_context_size = 16;

/// @brief Context object handed to output expressions: a fixed-size array of pconstvoid slots
struct coutput_context
{
    pconstvoid data[inout_context_size];
};

/// @brief Context object handed to input expressions: a fixed-size array of pconstvoid slots
struct cinput_context
{
    pconstvoid data[inout_context_size];
};

/// Pointer-to-const context types used as the context parameter of expression calls
using coutput_t = const coutput_context*;
using cinput_t  = const cinput_context*;

/// Default contexts: null pointers (no context)
constexpr cinput_t cinput   = nullptr;
constexpr coutput_t coutput = nullptr;

/// @brief Sentinel size (the largest size_t value) that marks an expression without a finite size
constexpr size_t infinite_size = static_cast<size_t>(-1);
77 
size_add(size_t x,size_t y)78 CMT_INTRINSIC constexpr size_t size_add(size_t x, size_t y)
79 {
80     return (x == infinite_size || y == infinite_size) ? infinite_size : x + y;
81 }
82 
size_sub(size_t x,size_t y)83 CMT_INTRINSIC constexpr size_t size_sub(size_t x, size_t y)
84 {
85     return (x == infinite_size || y == infinite_size) ? infinite_size : (x > y ? x - y : 0);
86 }
87 
size_min(size_t x)88 CMT_INTRINSIC constexpr size_t size_min(size_t x) CMT_NOEXCEPT { return x; }
89 
90 template <typename... Ts>
size_min(size_t x,size_t y,Ts...rest)91 CMT_INTRINSIC constexpr size_t size_min(size_t x, size_t y, Ts... rest) CMT_NOEXCEPT
92 {
93     return size_min(x < y ? x : y, rest...);
94 }
95 
96 /// @brief Base class of all input expressoins
97 struct input_expression
98 {
sizekfr::input_expression99     KFR_MEM_INTRINSIC constexpr static size_t size() CMT_NOEXCEPT { return infinite_size; }
100 
101     constexpr static bool is_incremental = false;
102 
begin_blockkfr::input_expression103     KFR_MEM_INTRINSIC constexpr void begin_block(cinput_t, size_t) const {}
end_blockkfr::input_expression104     KFR_MEM_INTRINSIC constexpr void end_block(cinput_t, size_t) const {}
105 };
106 
107 /// @brief Base class of all output expressoins
108 struct output_expression
109 {
sizekfr::output_expression110     KFR_MEM_INTRINSIC constexpr static size_t size() CMT_NOEXCEPT { return infinite_size; }
111 
112     constexpr static bool is_incremental = false;
113 
begin_blockkfr::output_expression114     KFR_MEM_INTRINSIC constexpr void begin_block(coutput_t, size_t) const {}
end_blockkfr::output_expression115     KFR_MEM_INTRINSIC constexpr void end_block(coutput_t, size_t) const {}
116 };
117 
/// @brief Check if the type argument is an input expression, i.e. its decayed type derives
/// from input_expression
template <typename E>
constexpr inline bool is_input_expression = is_base_of<input_expression, decay<E>>;
121 
/// @brief Check if at least one of the type arguments is an input expression
/// @note This is an OR-fold: it is true as soon as any decayed argument type derives from
/// input_expression, and false for an empty argument list.
template <typename... Es>
constexpr inline bool is_input_expressions = (is_base_of<input_expression, decay<Es>> || ...);
125 
/// @brief Check if the type argument is an output expression, i.e. its decayed type derives
/// from output_expression
template <typename E>
constexpr inline bool is_output_expression = is_base_of<output_expression, decay<E>>;
129 
/// @brief Check if at least one of the type arguments is an output expression
/// @note This is an OR-fold: it is true as soon as any decayed argument type derives from
/// output_expression, and false for an empty argument list.
template <typename... Es>
constexpr inline bool is_output_expressions = (is_base_of<output_expression, decay<Es>> || ...);
133 
/// @brief Check if the type argument is a number or a vector of numbers
/// (tests is_number on deep_subtype<T>)
template <typename T>
constexpr inline bool is_numeric = is_number<deep_subtype<T>>;
137 
/// @brief Check if all of the type arguments are numbers or vectors of numbers
/// @note This is an AND-fold over is_numeric; it is true for an empty argument list.
template <typename... Ts>
constexpr inline bool is_numeric_args = (is_numeric<Ts> && ...);
141 
142 inline namespace CMT_ARCH_NAME
143 {
144 
145 #ifdef KFR_TESTING
146 namespace internal
147 {
148 template <typename T, size_t N, typename Fn>
get_fn_value(size_t index,Fn && fn)149 inline vec<T, N> get_fn_value(size_t index, Fn&& fn)
150 {
151     return apply(fn, enumerate<size_t, N>() + index);
152 }
153 } // namespace internal
154 
155 template <typename E, typename Fn>
test_expression(const E & expr,size_t size,Fn && fn,const char * expression=nullptr)156 void test_expression(const E& expr, size_t size, Fn&& fn, const char* expression = nullptr)
157 {
158     using T                  = value_type_of<E>;
159     ::testo::test_case* test = ::testo::active_test();
160     auto&& c                 = ::testo::make_comparison();
161     test->check(c <= expr.size() == size, expression);
162     if (expr.size() != size)
163         return;
164     size                     = size_min(size, 200);
165     constexpr size_t maxsize = 2 + ilog2(vector_width<T> * 2);
166     size_t g                 = 1;
167     for (size_t i = 0; i < size;)
168     {
169         const size_t next_size = std::min(prev_poweroftwo(size - i), g);
170         g *= 2;
171         if (g > (1 << (maxsize - 1)))
172             g = 1;
173 
174         cswitch(csize<1> << csizeseq<maxsize>, next_size, [&](auto x) {
175             constexpr size_t nsize = val_of(decltype(x)());
176             ::testo::scope s(as_string("i = ", i, " width = ", nsize));
177             test->check(c <= get_elements(expr, cinput, i, vec_shape<T, nsize>()) ==
178                             internal::get_fn_value<T, nsize>(i, fn),
179                         expression);
180         });
181         i += next_size;
182     }
183 }
/// Calls ::kfr::test_expression and passes the stringified expression as the check description
#define TESTO_CHECK_EXPRESSION(expr, size, ...) ::kfr::test_expression(expr, size, __VA_ARGS__, #expr)

/// Short alias, available unless TESTO_NO_SHORT_MACROS is defined
#ifndef TESTO_NO_SHORT_MACROS
#define CHECK_EXPRESSION TESTO_CHECK_EXPRESSION
#endif
189 #endif
190 
191 namespace internal
192 {
193 
194 template <typename T, typename Fn>
195 struct expression_lambda : input_expression
196 {
197     using value_type = T;
expression_lambdakfr::CMT_ARCH_NAME::internal::expression_lambda198     KFR_MEM_INTRINSIC expression_lambda(Fn&& fn) : fn(std::move(fn)) {}
199 
200     template <size_t N, KFR_ENABLE_IF(N&& is_callable<Fn, cinput_t, size_t, vec_shape<T, N>>)>
get_elements(const expression_lambda & self,cinput_t cinput,size_t index,vec_shape<T,N> y)201     KFR_INTRINSIC friend vec<T, N> get_elements(const expression_lambda& self, cinput_t cinput, size_t index,
202                                                 vec_shape<T, N> y)
203     {
204         return self.fn(cinput, index, y);
205     }
206 
207     template <size_t N, KFR_ENABLE_IF(N&& is_callable<Fn, size_t>)>
get_elements(const expression_lambda & self,cinput_t,size_t index,vec_shape<T,N>)208     KFR_INTRINSIC friend vec<T, N> get_elements(const expression_lambda& self, cinput_t, size_t index,
209                                                 vec_shape<T, N>)
210     {
211         return apply(self.fn, enumerate<size_t, N>() + index);
212     }
213     template <size_t N, KFR_ENABLE_IF(N&& is_callable<Fn>)>
get_elements(const expression_lambda & self,cinput_t,size_t,vec_shape<T,N>)214     KFR_INTRINSIC friend vec<T, N> get_elements(const expression_lambda& self, cinput_t, size_t,
215                                                 vec_shape<T, N>)
216     {
217         return apply<N>(self.fn);
218     }
219 
220     Fn fn;
221 };
222 } // namespace internal
223 
224 template <typename T, typename Fn>
lambda(Fn && fn)225 internal::expression_lambda<T, decay<Fn>> lambda(Fn&& fn)
226 {
227     return internal::expression_lambda<T, decay<Fn>>(std::move(fn));
228 }
229 
230 namespace internal
231 {
232 
233 template <typename T, typename = void>
234 struct is_infinite_impl : std::false_type
235 {
236 };
237 
238 template <typename T>
239 struct is_infinite_impl<T, void_t<decltype(T::size())>>
240     : std::integral_constant<bool, T::size() == infinite_size>
241 {
242 };
243 } // namespace internal
244 
/// @brief True if T::size() can be called without an object and returns infinite_size
/// (see internal::is_infinite_impl); false otherwise
template <typename T>
constexpr inline bool is_infinite = internal::is_infinite_impl<T>::value;
247 
248 namespace internal
249 {
250 
251 template <typename... Args>
252 struct expression_with_arguments : input_expression
253 {
sizekfr::CMT_ARCH_NAME::internal::expression_with_arguments254     KFR_MEM_INTRINSIC constexpr size_t size() const CMT_NOEXCEPT
255     {
256         return size_impl(indicesfor_t<Args...>());
257     }
258 
259     constexpr static size_t count = sizeof...(Args);
260     expression_with_arguments()   = delete;
expression_with_argumentskfr::CMT_ARCH_NAME::internal::expression_with_arguments261     constexpr expression_with_arguments(Args&&... args) CMT_NOEXCEPT : args(std::forward<Args>(args)...) {}
262 
begin_blockkfr::CMT_ARCH_NAME::internal::expression_with_arguments263     KFR_MEM_INTRINSIC void begin_block(cinput_t cinput, size_t size) const
264     {
265         begin_block_impl(cinput, size, indicesfor_t<Args...>());
266     }
end_blockkfr::CMT_ARCH_NAME::internal::expression_with_arguments267     KFR_MEM_INTRINSIC void end_block(cinput_t cinput, size_t size) const
268     {
269         end_block_impl(cinput, size, indicesfor_t<Args...>());
270     }
271 
272     std::tuple<Args...> args;
273 
274 protected:
275     template <size_t... indices>
size_implkfr::CMT_ARCH_NAME::internal::expression_with_arguments276     KFR_MEM_INTRINSIC constexpr size_t size_impl(csizes_t<indices...>) const CMT_NOEXCEPT
277     {
278         return size_min(std::get<indices>(this->args).size()...);
279     }
280 
281     template <typename Fn, typename T, size_t N>
callkfr::CMT_ARCH_NAME::internal::expression_with_arguments282     KFR_MEM_INTRINSIC vec<T, N> call(cinput_t cinput, Fn&& fn, size_t index, vec_shape<T, N> x) const
283     {
284         return call_impl(cinput, std::forward<Fn>(fn), indicesfor_t<Args...>(), index, x);
285     }
286     template <size_t ArgIndex, typename U, size_t N,
287               typename T = value_type_of<typename details::get_nth_type<ArgIndex, Args...>::type>>
argumentkfr::CMT_ARCH_NAME::internal::expression_with_arguments288     KFR_MEM_INTRINSIC vec<U, N> argument(cinput_t cinput, csize_t<ArgIndex>, size_t index,
289                                          vec_shape<U, N>) const
290     {
291         static_assert(ArgIndex < count, "Incorrect ArgIndex");
292         return static_cast<vec<U, N>>(
293             get_elements(std::get<ArgIndex>(this->args), cinput, index, vec_shape<T, N>()));
294     }
295     template <typename U, size_t N,
296               typename T = value_type_of<typename details::get_nth_type<0, Args...>::type>>
argument_firstkfr::CMT_ARCH_NAME::internal::expression_with_arguments297     KFR_MEM_INTRINSIC vec<U, N> argument_first(cinput_t cinput, size_t index, vec_shape<U, N>) const
298     {
299         return static_cast<vec<U, N>>(
300             get_elements(std::get<0>(this->args), cinput, index, vec_shape<T, N>()));
301     }
302 
303 private:
304     template <typename Fn, typename T, size_t N, size_t... indices>
call_implkfr::CMT_ARCH_NAME::internal::expression_with_arguments305     KFR_MEM_INTRINSIC vec<T, N> call_impl(cinput_t cinput, Fn&& fn, csizes_t<indices...>, size_t index,
306                                           vec_shape<T, N>) const
307     {
308         return fn(get_elements(std::get<indices>(this->args), cinput, index,
309                                vec_shape<value_type_of<Args>, N>())...);
310     }
311     template <size_t... indices>
begin_block_implkfr::CMT_ARCH_NAME::internal::expression_with_arguments312     KFR_MEM_INTRINSIC void begin_block_impl(cinput_t cinput, size_t size, csizes_t<indices...>) const
313     {
314         swallow{ (std::get<indices>(args).begin_block(cinput, size), 0)... };
315     }
316     template <size_t... indices>
end_block_implkfr::CMT_ARCH_NAME::internal::expression_with_arguments317     KFR_MEM_INTRINSIC void end_block_impl(cinput_t cinput, size_t size, csizes_t<indices...>) const
318     {
319         swallow{ (std::get<indices>(args).end_block(cinput, size), 0)... };
320     }
321 };
322 
323 template <typename T>
324 struct expression_scalar : input_expression
325 {
326     using value_type    = T;
327     expression_scalar() = delete;
expression_scalarkfr::CMT_ARCH_NAME::internal::expression_scalar328     constexpr expression_scalar(const T& val) CMT_NOEXCEPT : val(val) {}
329     T val;
330 
331     template <size_t N>
get_elements(const expression_scalar & self,cinput_t,size_t,vec_shape<T,N>)332     friend KFR_INTRINSIC vec<T, N> get_elements(const expression_scalar& self, cinput_t, size_t,
333                                                 vec_shape<T, N>)
334     {
335         return broadcast<N>(self.val);
336     }
337 };
338 
/// @brief Maps an argument type to the type stored inside an expression.
/// Primary template: T1 is the decayed type and T2 the type as passed (see the arg alias).
/// The argument is kept as T2 and its element type is taken from T1::value_type.
template <typename T1, typename T2, typename = void>
struct arg_impl
{
    using type       = T2;
    using value_type = typename T1::value_type;
};

/// Specialization selected when is_vec_element<T1> holds: the value is wrapped in
/// expression_scalar<T1> and T1 itself is the element type.
template <typename T1, typename T2>
struct arg_impl<T1, T2, void_t<enable_if<is_vec_element<T1>>>>
{
    using type       = expression_scalar<T1>;
    using value_type = T1;
};
352 
/// @brief Storage type for an argument of type T: T itself, or expression_scalar<decay<T>>
/// when the decayed type satisfies is_vec_element (see arg_impl)
template <typename T>
using arg = typename internal::arg_impl<decay<T>, T>::type;

/// @brief Element type associated with an argument of type T (see arg_impl)
template <typename T>
using arg_type = typename internal::arg_impl<decay<T>, T>::value_type;
358 
/// @brief Deduces the element type produced by Fn: the value_type of the result of invoking
/// Fn with one-element vectors of each argument's element type
template <typename Fn, typename... Args>
struct function_value_type
{
    using type = typename invoke_result<Fn, vec<arg_type<Args>, 1>...>::value_type;
};
364 
365 template <typename Fn, typename... Args>
366 struct expression_function : expression_with_arguments<arg<Args>...>
367 {
368     using value_type = typename function_value_type<Fn, Args...>::type;
369     // subtype<decltype(std::declval<Fn>()(std::declval<vec<value_type_of<arg<Args>>, 1>>()...))>;
370     using T = value_type;
371 
expression_functionkfr::CMT_ARCH_NAME::internal::expression_function372     expression_function(Fn&& fn, arg<Args>&&... args) CMT_NOEXCEPT
373         : expression_with_arguments<arg<Args>...>(std::forward<arg<Args>>(args)...),
374           fn(std::forward<Fn>(fn))
375     {
376     }
expression_functionkfr::CMT_ARCH_NAME::internal::expression_function377     expression_function(const Fn& fn, arg<Args>&&... args) CMT_NOEXCEPT
378         : expression_with_arguments<arg<Args>...>(std::forward<arg<Args>>(args)...),
379           fn(fn)
380     {
381     }
382     template <size_t N>
get_elements(const expression_function & self,cinput_t cinput,size_t index,vec_shape<T,N> x)383     friend KFR_INTRINSIC vec<T, N> get_elements(const expression_function& self, cinput_t cinput,
384                                                 size_t index, vec_shape<T, N> x)
385     {
386         return self.call(cinput, self.fn, index, x);
387     }
388 
get_fnkfr::CMT_ARCH_NAME::internal::expression_function389     const Fn& get_fn() const CMT_NOEXCEPT { return fn; }
390 
391 protected:
392     Fn fn;
393 };
394 } // namespace internal
395 
396 template <typename A>
e(A && a)397 CMT_INTRINSIC internal::arg<A> e(A&& a)
398 {
399     return internal::arg<A>(std::forward<A>(a));
400 }
401 
402 template <typename T>
scalar(const T & val)403 CMT_INTRINSIC internal::expression_scalar<T> scalar(const T& val)
404 {
405     return internal::expression_scalar<T>(val);
406 }
407 
408 template <typename Fn, typename... Args>
bind_expression(Fn && fn,Args &&...args)409 CMT_INTRINSIC internal::expression_function<decay<Fn>, Args...> bind_expression(Fn&& fn, Args&&... args)
410 {
411     return internal::expression_function<decay<Fn>, Args...>(std::forward<Fn>(fn),
412                                                              std::forward<Args>(args)...);
413 }
414 /**
415  * @brief Construct a new expression using the same function as in @c e and new arguments
416  * @param e an expression
417  * @param args new arguments for the function
418  */
419 template <typename Fn, typename... OldArgs, typename... NewArgs>
rebind(const internal::expression_function<Fn,OldArgs...> & e,NewArgs &&...args)420 CMT_INTRINSIC internal::expression_function<Fn, NewArgs...> rebind(
421     const internal::expression_function<Fn, OldArgs...>& e, NewArgs&&... args)
422 {
423     return internal::expression_function<Fn, NewArgs...>(e.get_fn(), std::forward<NewArgs>(args)...);
424 }
425 
426 template <size_t width = 0, typename OutputExpr, typename InputExpr, size_t groupsize = 1,
427           typename Tvec = vec<value_type_of<InputExpr>, 1>>
process(OutputExpr && out,const InputExpr & in,size_t start=0,size_t size=infinite_size,coutput_t coutput=nullptr,cinput_t cinput=nullptr,csize_t<groupsize>=csize_t<groupsize> ())428 CMT_INTRINSIC static size_t process(OutputExpr&& out, const InputExpr& in, size_t start = 0,
429                                     size_t size = infinite_size, coutput_t coutput = nullptr,
430                                     cinput_t cinput = nullptr, csize_t<groupsize> = csize_t<groupsize>())
431 {
432     using Tin = value_type_of<InputExpr>;
433     static_assert(is_output_expression<OutputExpr>, "OutFn must be an expression");
434     static_assert(is_input_expression<InputExpr>, "Fn must be an expression");
435 
436     size = size_sub(size_min(out.size(), in.size(), size_add(size, start)), start);
437     if (size == 0 || size == infinite_size)
438         return size;
439     out.begin_block(coutput, size);
440     in.begin_block(cinput, size);
441 
442 #ifdef NDEBUG
443     constexpr size_t w = width == 0 ? maximum_vector_size<Tin> : width;
444 #else
445     constexpr size_t w = width == 0 ? vector_width<Tin> : width;
446 #endif
447 
448     static_assert(w > 0 && is_poweroftwo(w), "");
449 
450     size_t i = start;
451 
452     CMT_LOOP_NOUNROLL
453     for (; i < start + size / w * w; i += w)
454         out(coutput, i, get_elements(in, cinput, i, vec_shape<Tin, w>()));
455     CMT_LOOP_NOUNROLL
456     for (; i < start + size / groupsize * groupsize; i += groupsize)
457         out(coutput, i, get_elements(in, cinput, i, vec_shape<Tin, groupsize>()));
458 
459     in.end_block(cinput, size);
460     out.end_block(coutput, size);
461     return size;
462 }
463 
464 template <typename T>
465 struct input_expression_base : input_expression
466 {
~input_expression_basekfr::CMT_ARCH_NAME::input_expression_base467     virtual ~input_expression_base() {}
468     virtual T input(size_t index) const = 0;
469     template <typename U, size_t N>
get_elements(const input_expression_base & self,cinput_t,size_t index,vec_shape<U,N>)470     friend KFR_INTRINSIC vec<U, N> get_elements(const input_expression_base& self, cinput_t, size_t index,
471                                                 vec_shape<U, N>)
472     {
473         vec<U, N> out;
474         for (size_t i = 0; i < N; i++)
475             out[i] = static_cast<U>(self.input(index + i));
476         return out;
477     }
478 };
479 
480 template <typename T>
481 struct output_expression_base : output_expression
482 {
~output_expression_basekfr::CMT_ARCH_NAME::output_expression_base483     virtual ~output_expression_base() {}
484     virtual void output(size_t index, const T& value) = 0;
485 
486     template <typename U, size_t N>
operator ()kfr::CMT_ARCH_NAME::output_expression_base487     KFR_MEM_INTRINSIC void operator()(coutput_t, size_t index, const vec<U, N>& value)
488     {
489         for (size_t i = 0; i < N; i++)
490             output(index + i, static_cast<T>(value[i]));
491     }
492 };
493 
494 template <typename E1, typename E2, KFR_ENABLE_IF(is_input_expressions<E1, E2>)>
interleave(E1 && x,E2 && y)495 CMT_INTRINSIC internal::expression_function<fn::interleave, E1, E2> interleave(E1&& x, E2&& y)
496 {
497     return { fn::interleave(), std::forward<E1>(x), std::forward<E2>(y) };
498 }
499 } // namespace CMT_ARCH_NAME
500 } // namespace kfr
501 
502 CMT_PRAGMA_GNU(GCC diagnostic pop)
503