1 // Copyright 2020, 2021 Francesco Biscani (bluescarni@gmail.com), Dario Izzo (dario.izzo@gmail.com)
2 //
3 // This file is part of the heyoka library.
4 //
5 // This Source Code Form is subject to the terms of the Mozilla
6 // Public License v. 2.0. If a copy of the MPL was not distributed
7 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
8 
9 #include <heyoka/config.hpp>
10 
11 #include <algorithm>
12 #include <cassert>
13 #include <cstddef>
14 #include <cstdint>
15 #include <functional>
16 #include <initializer_list>
17 #include <ostream>
18 #include <stdexcept>
19 #include <string>
20 #include <type_traits>
21 #include <unordered_map>
22 #include <utility>
23 #include <variant>
24 #include <vector>
25 
26 #include <fmt/format.h>
27 
28 #include <llvm/IR/BasicBlock.h>
29 #include <llvm/IR/DerivedTypes.h>
30 #include <llvm/IR/Function.h>
31 #include <llvm/IR/IRBuilder.h>
32 #include <llvm/IR/LLVMContext.h>
33 #include <llvm/IR/Module.h>
34 #include <llvm/IR/Type.h>
35 #include <llvm/IR/Value.h>
36 
37 #if defined(HEYOKA_HAVE_REAL128)
38 
39 #include <mp++/real128.hpp>
40 
41 #endif
42 
43 #include <heyoka/detail/fwd_decl.hpp>
44 #include <heyoka/detail/llvm_helpers.hpp>
45 #include <heyoka/detail/string_conv.hpp>
46 #include <heyoka/expression.hpp>
47 #include <heyoka/func.hpp>
48 #include <heyoka/llvm_state.hpp>
49 #include <heyoka/math/binary_op.hpp>
50 #include <heyoka/number.hpp>
51 #include <heyoka/s11n.hpp>
52 #include <heyoka/taylor.hpp>
53 #include <heyoka/variable.hpp>
54 
55 #if defined(_MSC_VER) && !defined(__clang__)
56 
57 // NOTE: MSVC has issues with the other "using"
58 // statement form.
59 using namespace fmt::literals;
60 
61 #else
62 
63 using fmt::literals::operator""_format;
64 
65 #endif
66 
67 namespace heyoka
68 {
69 
70 namespace detail
71 {
72 
binary_op()73 binary_op::binary_op() : binary_op(type::add, 0_dbl, 0_dbl) {}
74 
binary_op(type t,expression a,expression b)75 binary_op::binary_op(type t, expression a, expression b)
76     : func_base("binary_op", std::vector{std::move(a), std::move(b)}), m_type(t)
77 {
78     assert(m_type >= type::add && m_type <= type::div);
79 }
80 
extra_equal_to(const func & f) const81 bool binary_op::extra_equal_to(const func &f) const
82 {
83     // NOTE: this should be ensured by the
84     // implementation of func's equality operator.
85     assert(f.extract<binary_op>() == f.get_ptr());
86 
87     return static_cast<const binary_op *>(f.get_ptr())->m_type == m_type;
88 }
89 
extra_hash() const90 std::size_t binary_op::extra_hash() const
91 {
92     return std::hash<type>{}(m_type);
93 }
94 
to_stream(std::ostream & os) const95 void binary_op::to_stream(std::ostream &os) const
96 {
97     assert(args().size() == 2u);
98     assert(m_type >= type::add && m_type <= type::div);
99 
100     os << '(' << lhs() << ' ';
101 
102     switch (m_type) {
103         case type::add:
104             os << '+';
105             break;
106         case type::sub:
107             os << '-';
108             break;
109         case type::mul:
110             os << '*';
111             break;
112         default:
113             os << '/';
114             break;
115     }
116 
117     os << ' ' << rhs() << ')';
118 }
119 
op() const120 binary_op::type binary_op::op() const
121 {
122     return m_type;
123 }
124 
lhs() const125 const expression &binary_op::lhs() const
126 {
127     assert(args().size() == 2u);
128     return args()[0];
129 }
130 
rhs() const131 const expression &binary_op::rhs() const
132 {
133     assert(args().size() == 2u);
134     return args()[1];
135 }
136 
137 template <typename T>
diff_impl(std::unordered_map<const void *,expression> & func_map,const T & x) const138 expression binary_op::diff_impl(std::unordered_map<const void *, expression> &func_map, const T &x) const
139 {
140     assert(args().size() == 2u);
141     assert(m_type >= type::add && m_type <= type::div);
142 
143     switch (m_type) {
144         case type::add:
145             return detail::diff(func_map, lhs(), x) + detail::diff(func_map, rhs(), x);
146         case type::sub:
147             return detail::diff(func_map, lhs(), x) - detail::diff(func_map, rhs(), x);
148         case type::mul:
149             return detail::diff(func_map, lhs(), x) * rhs() + lhs() * detail::diff(func_map, rhs(), x);
150         default:
151             return (detail::diff(func_map, lhs(), x) * rhs() - lhs() * detail::diff(func_map, rhs(), x))
152                    / (rhs() * rhs());
153     }
154 }
155 
diff(std::unordered_map<const void *,expression> & func_map,const std::string & s) const156 expression binary_op::diff(std::unordered_map<const void *, expression> &func_map, const std::string &s) const
157 {
158     return diff_impl(func_map, s);
159 }
160 
diff(std::unordered_map<const void *,expression> & func_map,const param & p) const161 expression binary_op::diff(std::unordered_map<const void *, expression> &func_map, const param &p) const
162 {
163     return diff_impl(func_map, p);
164 }
165 
166 namespace
167 {
168 
169 template <class T>
eval_bo_impl(const binary_op & bo,const std::unordered_map<std::string,T> & map,const std::vector<T> & pars)170 T eval_bo_impl(const binary_op &bo, const std::unordered_map<std::string, T> &map, const std::vector<T> &pars)
171 {
172     assert(bo.args().size() == 2u);
173     assert(bo.op() >= binary_op::type::add && bo.op() <= binary_op::type::div);
174 
175     switch (bo.op()) {
176         case binary_op::type::add:
177             return eval<T>(bo.lhs(), map, pars) + eval<T>(bo.rhs(), map, pars);
178         case binary_op::type::sub:
179             return eval<T>(bo.lhs(), map, pars) - eval<T>(bo.rhs(), map, pars);
180         case binary_op::type::mul:
181             return eval<T>(bo.lhs(), map, pars) * eval<T>(bo.rhs(), map, pars);
182         default:
183             return eval<T>(bo.lhs(), map, pars) / eval<T>(bo.rhs(), map, pars);
184     }
185 }
186 
187 } // namespace
188 
eval_dbl(const std::unordered_map<std::string,double> & map,const std::vector<double> & pars) const189 double binary_op::eval_dbl(const std::unordered_map<std::string, double> &map, const std::vector<double> &pars) const
190 {
191     return eval_bo_impl<double>(*this, map, pars);
192 }
193 
eval_ldbl(const std::unordered_map<std::string,long double> & map,const std::vector<long double> & pars) const194 long double binary_op::eval_ldbl(const std::unordered_map<std::string, long double> &map,
195                                  const std::vector<long double> &pars) const
196 {
197     return eval_bo_impl<long double>(*this, map, pars);
198 }
199 
200 #if defined(HEYOKA_HAVE_REAL128)
201 
eval_f128(const std::unordered_map<std::string,mppp::real128> & map,const std::vector<mppp::real128> & pars) const202 mppp::real128 binary_op::eval_f128(const std::unordered_map<std::string, mppp::real128> &map,
203                                    const std::vector<mppp::real128> &pars) const
204 {
205     return eval_bo_impl<mppp::real128>(*this, map, pars);
206 }
207 
208 #endif
209 
eval_batch_dbl(std::vector<double> & out_values,const std::unordered_map<std::string,std::vector<double>> & map,const std::vector<double> & pars) const210 void binary_op::eval_batch_dbl(std::vector<double> &out_values,
211                                const std::unordered_map<std::string, std::vector<double>> &map,
212                                const std::vector<double> &pars) const
213 {
214     assert(args().size() == 2u);
215     assert(m_type >= type::add && m_type <= type::div);
216 
217     auto tmp = out_values;
218     heyoka::eval_batch_dbl(out_values, lhs(), map, pars);
219     heyoka::eval_batch_dbl(tmp, rhs(), map, pars);
220     switch (m_type) {
221         case type::add:
222             std::transform(out_values.begin(), out_values.end(), tmp.begin(), out_values.begin(), std::plus<>());
223             break;
224         case type::sub:
225             std::transform(out_values.begin(), out_values.end(), tmp.begin(), out_values.begin(), std::minus<>());
226             break;
227         case type::mul:
228             std::transform(out_values.begin(), out_values.end(), tmp.begin(), out_values.begin(), std::multiplies<>());
229             break;
230         default:
231             std::transform(out_values.begin(), out_values.end(), tmp.begin(), out_values.begin(), std::divides<>());
232             break;
233     }
234 }
235 
236 namespace
237 {
238 
239 // Derivative of number +- number.
240 template <bool AddOrSub, typename T, typename U, typename V,
241           std::enable_if_t<std::conjunction_v<is_num_param<U>, is_num_param<V>>, int> = 0>
bo_taylor_diff_addsub_impl(llvm_state & s,const U & num0,const V & num1,const std::vector<llvm::Value * > &,llvm::Value * par_ptr,std::uint32_t,std::uint32_t order,std::uint32_t,std::uint32_t batch_size)242 llvm::Value *bo_taylor_diff_addsub_impl(llvm_state &s, const U &num0, const V &num1, const std::vector<llvm::Value *> &,
243                                         llvm::Value *par_ptr, std::uint32_t, std::uint32_t order, std::uint32_t,
244                                         std::uint32_t batch_size)
245 {
246     if (order == 0u) {
247         auto n0 = taylor_codegen_numparam<T>(s, num0, par_ptr, batch_size);
248         auto n1 = taylor_codegen_numparam<T>(s, num1, par_ptr, batch_size);
249 
250         return AddOrSub ? s.builder().CreateFAdd(n0, n1) : s.builder().CreateFSub(n0, n1);
251     } else {
252         return vector_splat(s.builder(), codegen<T>(s, number{0.}), batch_size);
253     }
254 }
255 
256 // Derivative of number +- var.
257 template <bool AddOrSub, typename T, typename U, std::enable_if_t<is_num_param_v<U>, int> = 0>
bo_taylor_diff_addsub_impl(llvm_state & s,const U & num,const variable & var,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t,std::uint32_t batch_size)258 llvm::Value *bo_taylor_diff_addsub_impl(llvm_state &s, const U &num, const variable &var,
259                                         const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr,
260                                         std::uint32_t n_uvars, std::uint32_t order, std::uint32_t,
261                                         std::uint32_t batch_size)
262 {
263     auto &builder = s.builder();
264 
265     auto ret = taylor_fetch_diff(arr, uname_to_index(var.name()), order, n_uvars);
266 
267     if (order == 0u) {
268         auto n = taylor_codegen_numparam<T>(s, num, par_ptr, batch_size);
269 
270         return AddOrSub ? builder.CreateFAdd(n, ret) : builder.CreateFSub(n, ret);
271     } else {
272         if constexpr (AddOrSub) {
273             return ret;
274         } else {
275             // Negate if we are doing a subtraction.
276             return builder.CreateFNeg(ret);
277         }
278     }
279 }
280 
281 // Derivative of var +- number.
282 template <bool AddOrSub, typename T, typename U, std::enable_if_t<is_num_param_v<U>, int> = 0>
bo_taylor_diff_addsub_impl(llvm_state & s,const variable & var,const U & num,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t,std::uint32_t batch_size)283 llvm::Value *bo_taylor_diff_addsub_impl(llvm_state &s, const variable &var, const U &num,
284                                         const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr,
285                                         std::uint32_t n_uvars, std::uint32_t order, std::uint32_t,
286                                         std::uint32_t batch_size)
287 {
288     auto ret = taylor_fetch_diff(arr, uname_to_index(var.name()), order, n_uvars);
289 
290     if (order == 0u) {
291         auto &builder = s.builder();
292 
293         auto n = taylor_codegen_numparam<T>(s, num, par_ptr, batch_size);
294 
295         return AddOrSub ? builder.CreateFAdd(ret, n) : builder.CreateFSub(ret, n);
296     } else {
297         return ret;
298     }
299 }
300 
301 // Derivative of var +- var.
302 template <bool AddOrSub, typename T>
bo_taylor_diff_addsub_impl(llvm_state & s,const variable & var0,const variable & var1,const std::vector<llvm::Value * > & arr,llvm::Value *,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t,std::uint32_t)303 llvm::Value *bo_taylor_diff_addsub_impl(llvm_state &s, const variable &var0, const variable &var1,
304                                         const std::vector<llvm::Value *> &arr, llvm::Value *, std::uint32_t n_uvars,
305                                         std::uint32_t order, std::uint32_t, std::uint32_t)
306 {
307     auto v0 = taylor_fetch_diff(arr, uname_to_index(var0.name()), order, n_uvars);
308     auto v1 = taylor_fetch_diff(arr, uname_to_index(var1.name()), order, n_uvars);
309 
310     if constexpr (AddOrSub) {
311         return s.builder().CreateFAdd(v0, v1);
312     } else {
313         return s.builder().CreateFSub(v0, v1);
314     }
315 }
316 
317 // All the other cases.
318 // LCOV_EXCL_START
319 template <bool, typename, typename V1, typename V2,
320           std::enable_if_t<!std::conjunction_v<is_num_param<V1>, is_num_param<V2>>, int> = 0>
bo_taylor_diff_addsub_impl(llvm_state &,const V1 &,const V2 &,const std::vector<llvm::Value * > &,llvm::Value *,std::uint32_t,std::uint32_t,std::uint32_t,std::uint32_t)321 llvm::Value *bo_taylor_diff_addsub_impl(llvm_state &, const V1 &, const V2 &, const std::vector<llvm::Value *> &,
322                                         llvm::Value *, std::uint32_t, std::uint32_t, std::uint32_t, std::uint32_t)
323 {
324     throw std::invalid_argument(
325         "An invalid argument type was encountered while trying to build the Taylor derivative of add()/sub()");
326 }
327 // LCOV_EXCL_STOP
328 
329 template <typename T>
bo_taylor_diff_add(llvm_state & s,const binary_op & bo,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t idx,std::uint32_t batch_size)330 llvm::Value *bo_taylor_diff_add(llvm_state &s, const binary_op &bo, const std::vector<llvm::Value *> &arr,
331                                 llvm::Value *par_ptr, std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx,
332                                 std::uint32_t batch_size)
333 {
334     return std::visit(
335         [&](const auto &v1, const auto &v2) {
336             return bo_taylor_diff_addsub_impl<true, T>(s, v1, v2, arr, par_ptr, n_uvars, order, idx, batch_size);
337         },
338         bo.lhs().value(), bo.rhs().value());
339 }
340 
341 template <typename T>
bo_taylor_diff_sub(llvm_state & s,const binary_op & bo,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t idx,std::uint32_t batch_size)342 llvm::Value *bo_taylor_diff_sub(llvm_state &s, const binary_op &bo, const std::vector<llvm::Value *> &arr,
343                                 llvm::Value *par_ptr, std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx,
344                                 std::uint32_t batch_size)
345 {
346     return std::visit(
347         [&](const auto &v1, const auto &v2) {
348             return bo_taylor_diff_addsub_impl<false, T>(s, v1, v2, arr, par_ptr, n_uvars, order, idx, batch_size);
349         },
350         bo.lhs().value(), bo.rhs().value());
351 }
352 
353 // Derivative of number * number.
354 template <typename T, typename U, typename V,
355           std::enable_if_t<std::conjunction_v<is_num_param<U>, is_num_param<V>>, int> = 0>
bo_taylor_diff_mul_impl(llvm_state & s,const U & num0,const V & num1,const std::vector<llvm::Value * > &,llvm::Value * par_ptr,std::uint32_t,std::uint32_t order,std::uint32_t,std::uint32_t batch_size)356 llvm::Value *bo_taylor_diff_mul_impl(llvm_state &s, const U &num0, const V &num1, const std::vector<llvm::Value *> &,
357                                      llvm::Value *par_ptr, std::uint32_t, std::uint32_t order, std::uint32_t,
358                                      std::uint32_t batch_size)
359 {
360     if (order == 0u) {
361         auto n0 = taylor_codegen_numparam<T>(s, num0, par_ptr, batch_size);
362         auto n1 = taylor_codegen_numparam<T>(s, num1, par_ptr, batch_size);
363 
364         return s.builder().CreateFMul(n0, n1);
365     } else {
366         return vector_splat(s.builder(), codegen<T>(s, number{0.}), batch_size);
367     }
368 }
369 
370 // Derivative of var * number.
371 template <typename T, typename U, std::enable_if_t<is_num_param_v<U>, int> = 0>
bo_taylor_diff_mul_impl(llvm_state & s,const variable & var,const U & num,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t,std::uint32_t batch_size)372 llvm::Value *bo_taylor_diff_mul_impl(llvm_state &s, const variable &var, const U &num,
373                                      const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, std::uint32_t n_uvars,
374                                      std::uint32_t order, std::uint32_t, std::uint32_t batch_size)
375 {
376     auto &builder = s.builder();
377 
378     auto ret = taylor_fetch_diff(arr, uname_to_index(var.name()), order, n_uvars);
379     auto mul = taylor_codegen_numparam<T>(s, num, par_ptr, batch_size);
380 
381     return builder.CreateFMul(mul, ret);
382 }
383 
384 // Derivative of number * var.
385 template <typename T, typename U, std::enable_if_t<is_num_param_v<U>, int> = 0>
bo_taylor_diff_mul_impl(llvm_state & s,const U & num,const variable & var,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t idx,std::uint32_t batch_size)386 llvm::Value *bo_taylor_diff_mul_impl(llvm_state &s, const U &num, const variable &var,
387                                      const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, std::uint32_t n_uvars,
388                                      std::uint32_t order, std::uint32_t idx, std::uint32_t batch_size)
389 {
390     // Return the derivative of var * number.
391     return bo_taylor_diff_mul_impl<T>(s, var, num, arr, par_ptr, n_uvars, order, idx, batch_size);
392 }
393 
394 // Derivative of var * var.
395 template <typename T>
bo_taylor_diff_mul_impl(llvm_state & s,const variable & var0,const variable & var1,const std::vector<llvm::Value * > & arr,llvm::Value *,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t,std::uint32_t)396 llvm::Value *bo_taylor_diff_mul_impl(llvm_state &s, const variable &var0, const variable &var1,
397                                      const std::vector<llvm::Value *> &arr, llvm::Value *, std::uint32_t n_uvars,
398                                      std::uint32_t order, std::uint32_t, std::uint32_t)
399 {
400     // Fetch the indices of the u variables.
401     const auto u_idx0 = uname_to_index(var0.name());
402     const auto u_idx1 = uname_to_index(var1.name());
403 
404     // NOTE: iteration in the [0, order] range
405     // (i.e., order inclusive).
406     std::vector<llvm::Value *> sum;
407     auto &builder = s.builder();
408     for (std::uint32_t j = 0; j <= order; ++j) {
409         auto v0 = taylor_fetch_diff(arr, u_idx0, order - j, n_uvars);
410         auto v1 = taylor_fetch_diff(arr, u_idx1, j, n_uvars);
411 
412         // Add v0*v1 to the sum.
413         sum.push_back(builder.CreateFMul(v0, v1));
414     }
415 
416     return pairwise_sum(builder, sum);
417 }
418 
419 // All the other cases.
420 // LCOV_EXCL_START
421 template <typename, typename V1, typename V2,
422           std::enable_if_t<!std::conjunction_v<is_num_param<V1>, is_num_param<V2>>, int> = 0>
bo_taylor_diff_mul_impl(llvm_state &,const V1 &,const V2 &,const std::vector<llvm::Value * > &,llvm::Value *,std::uint32_t,std::uint32_t,std::uint32_t,std::uint32_t)423 llvm::Value *bo_taylor_diff_mul_impl(llvm_state &, const V1 &, const V2 &, const std::vector<llvm::Value *> &,
424                                      llvm::Value *, std::uint32_t, std::uint32_t, std::uint32_t, std::uint32_t)
425 {
426     throw std::invalid_argument(
427         "An invalid argument type was encountered while trying to build the Taylor derivative of mul()");
428 }
429 // LCOV_EXCL_STOP
430 
431 template <typename T>
bo_taylor_diff_mul(llvm_state & s,const binary_op & bo,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t idx,std::uint32_t batch_size)432 llvm::Value *bo_taylor_diff_mul(llvm_state &s, const binary_op &bo, const std::vector<llvm::Value *> &arr,
433                                 llvm::Value *par_ptr, std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx,
434                                 std::uint32_t batch_size)
435 {
436     return std::visit(
437         [&](const auto &v1, const auto &v2) {
438             return bo_taylor_diff_mul_impl<T>(s, v1, v2, arr, par_ptr, n_uvars, order, idx, batch_size);
439         },
440         bo.lhs().value(), bo.rhs().value());
441 }
442 
443 // Derivative of number / number.
444 template <typename T, typename U, typename V,
445           std::enable_if_t<std::conjunction_v<is_num_param<U>, is_num_param<V>>, int> = 0>
bo_taylor_diff_div_impl(llvm_state & s,const U & num0,const V & num1,const std::vector<llvm::Value * > &,llvm::Value * par_ptr,std::uint32_t,std::uint32_t order,std::uint32_t,std::uint32_t batch_size)446 llvm::Value *bo_taylor_diff_div_impl(llvm_state &s, const U &num0, const V &num1, const std::vector<llvm::Value *> &,
447                                      llvm::Value *par_ptr, std::uint32_t, std::uint32_t order, std::uint32_t,
448                                      std::uint32_t batch_size)
449 {
450     if (order == 0u) {
451         auto n0 = taylor_codegen_numparam<T>(s, num0, par_ptr, batch_size);
452         auto n1 = taylor_codegen_numparam<T>(s, num1, par_ptr, batch_size);
453 
454         return s.builder().CreateFDiv(n0, n1);
455     } else {
456         return vector_splat(s.builder(), codegen<T>(s, number{0.}), batch_size);
457     }
458 }
459 
460 // Derivative of variable / variable or number / variable. These two cases
461 // are quite similar, so we handle them together.
462 template <typename T, typename U,
463           std::enable_if_t<
464               std::disjunction_v<std::is_same<U, number>, std::is_same<U, variable>, std::is_same<U, param>>, int> = 0>
bo_taylor_diff_div_impl(llvm_state & s,const U & nv,const variable & var1,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t idx,std::uint32_t batch_size)465 llvm::Value *bo_taylor_diff_div_impl(llvm_state &s, const U &nv, const variable &var1,
466                                      const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, std::uint32_t n_uvars,
467                                      std::uint32_t order, std::uint32_t idx, std::uint32_t batch_size)
468 {
469     auto &builder = s.builder();
470 
471     // Fetch the index of var1.
472     const auto u_idx1 = uname_to_index(var1.name());
473 
474     if (order == 0u) {
475         // Special casing for zero order.
476         auto numerator = [&]() -> llvm::Value * {
477             if constexpr (std::is_same_v<U, number> || std::is_same_v<U, param>) {
478                 return taylor_codegen_numparam<T>(s, nv, par_ptr, batch_size);
479             } else {
480                 return taylor_fetch_diff(arr, uname_to_index(nv.name()), 0, n_uvars);
481             }
482         }();
483 
484         return builder.CreateFDiv(numerator, taylor_fetch_diff(arr, u_idx1, 0, n_uvars));
485     }
486 
487     // NOTE: iteration in the [1, order] range
488     // (i.e., order inclusive).
489     std::vector<llvm::Value *> sum;
490     for (std::uint32_t j = 1; j <= order; ++j) {
491         auto v0 = taylor_fetch_diff(arr, idx, order - j, n_uvars);
492         auto v1 = taylor_fetch_diff(arr, u_idx1, j, n_uvars);
493 
494         // Add v0*v1 to the sum.
495         sum.push_back(builder.CreateFMul(v0, v1));
496     }
497 
498     // Init the return value as the result of the sum.
499     auto ret_acc = pairwise_sum(builder, sum);
500 
501     // Load the divisor for the quotient formula.
502     // This is the zero-th order derivative of var1.
503     auto div = taylor_fetch_diff(arr, u_idx1, 0, n_uvars);
504 
505     if constexpr (std::is_same_v<U, number> || std::is_same_v<U, param>) {
506         // nv is a number/param. Negate the accumulator
507         // and divide it by the divisor.
508         return builder.CreateFDiv(builder.CreateFNeg(ret_acc), div);
509     } else {
510         // nv is a variable. We need to fetch its
511         // derivative of order 'order' from the array of derivatives.
512         auto diff_nv_v = taylor_fetch_diff(arr, uname_to_index(nv.name()), order, n_uvars);
513 
514         // Produce the result: (diff_nv_v - ret_acc) / div.
515         return builder.CreateFDiv(builder.CreateFSub(diff_nv_v, ret_acc), div);
516     }
517 }
518 
519 // Derivative of variable / number.
520 template <typename T, typename U, std::enable_if_t<is_num_param_v<U>, int> = 0>
bo_taylor_diff_div_impl(llvm_state & s,const variable & var,const U & num,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t,std::uint32_t batch_size)521 llvm::Value *bo_taylor_diff_div_impl(llvm_state &s, const variable &var, const U &num,
522                                      const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, std::uint32_t n_uvars,
523                                      std::uint32_t order, std::uint32_t, std::uint32_t batch_size)
524 {
525     auto &builder = s.builder();
526 
527     auto ret = taylor_fetch_diff(arr, uname_to_index(var.name()), order, n_uvars);
528     auto div = taylor_codegen_numparam<T>(s, num, par_ptr, batch_size);
529 
530     return builder.CreateFDiv(ret, div);
531 }
532 
533 // All the other cases.
534 // LCOV_EXCL_START
535 template <typename, typename V1, typename V2,
536           std::enable_if_t<!std::conjunction_v<is_num_param<V1>, is_num_param<V2>>, int> = 0>
bo_taylor_diff_div_impl(llvm_state &,const V1 &,const V2 &,const std::vector<llvm::Value * > &,llvm::Value *,std::uint32_t,std::uint32_t,std::uint32_t,std::uint32_t)537 llvm::Value *bo_taylor_diff_div_impl(llvm_state &, const V1 &, const V2 &, const std::vector<llvm::Value *> &,
538                                      llvm::Value *, std::uint32_t, std::uint32_t, std::uint32_t, std::uint32_t)
539 {
540     throw std::invalid_argument(
541         "An invalid argument type was encountered while trying to build the Taylor derivative of div()");
542 }
543 // LCOV_EXCL_STOP
544 
545 template <typename T>
bo_taylor_diff_div(llvm_state & s,const binary_op & bo,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t idx,std::uint32_t batch_size)546 llvm::Value *bo_taylor_diff_div(llvm_state &s, const binary_op &bo, const std::vector<llvm::Value *> &arr,
547                                 llvm::Value *par_ptr, std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx,
548                                 std::uint32_t batch_size)
549 {
550     return std::visit(
551         [&](const auto &v1, const auto &v2) {
552             return bo_taylor_diff_div_impl<T>(s, v1, v2, arr, par_ptr, n_uvars, order, idx, batch_size);
553         },
554         bo.lhs().value(), bo.rhs().value());
555 }
556 
557 template <typename T>
taylor_diff_bo_impl(llvm_state & s,const binary_op & bo,const std::vector<std::uint32_t> & deps,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t idx,std::uint32_t batch_size)558 llvm::Value *taylor_diff_bo_impl(llvm_state &s, const binary_op &bo, const std::vector<std::uint32_t> &deps,
559                                  const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, std::uint32_t n_uvars,
560                                  std::uint32_t order, std::uint32_t idx, std::uint32_t batch_size)
561 {
562     assert(bo.args().size() == 2u);
563     assert(bo.op() >= binary_op::type::add && bo.op() <= binary_op::type::div);
564 
565     if (!deps.empty()) {
566         throw std::invalid_argument("The vector of hidden dependencies in the Taylor diff for a binary operator "
567                                     "should be empty, but instead it has a size of {}"_format(deps.size()));
568     }
569 
570     switch (bo.op()) {
571         case binary_op::type::add:
572             return bo_taylor_diff_add<T>(s, bo, arr, par_ptr, n_uvars, order, idx, batch_size);
573         case binary_op::type::sub:
574             return bo_taylor_diff_sub<T>(s, bo, arr, par_ptr, n_uvars, order, idx, batch_size);
575         case binary_op::type::mul:
576             return bo_taylor_diff_mul<T>(s, bo, arr, par_ptr, n_uvars, order, idx, batch_size);
577         default:
578             return bo_taylor_diff_div<T>(s, bo, arr, par_ptr, n_uvars, order, idx, batch_size);
579     }
580 }
581 
582 } // namespace
583 
taylor_diff_dbl(llvm_state & s,const std::vector<std::uint32_t> & deps,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,llvm::Value *,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t idx,std::uint32_t batch_size,bool) const584 llvm::Value *binary_op::taylor_diff_dbl(llvm_state &s, const std::vector<std::uint32_t> &deps,
585                                         const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, llvm::Value *,
586                                         std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx,
587                                         std::uint32_t batch_size, bool) const
588 {
589 
590     return taylor_diff_bo_impl<double>(s, *this, deps, arr, par_ptr, n_uvars, order, idx, batch_size);
591 }
592 
taylor_diff_ldbl(llvm_state & s,const std::vector<std::uint32_t> & deps,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,llvm::Value *,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t idx,std::uint32_t batch_size,bool) const593 llvm::Value *binary_op::taylor_diff_ldbl(llvm_state &s, const std::vector<std::uint32_t> &deps,
594                                          const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, llvm::Value *,
595                                          std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx,
596                                          std::uint32_t batch_size, bool) const
597 {
598     return taylor_diff_bo_impl<long double>(s, *this, deps, arr, par_ptr, n_uvars, order, idx, batch_size);
599 }
600 
601 #if defined(HEYOKA_HAVE_REAL128)
602 
taylor_diff_f128(llvm_state & s,const std::vector<std::uint32_t> & deps,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,llvm::Value *,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t idx,std::uint32_t batch_size,bool) const603 llvm::Value *binary_op::taylor_diff_f128(llvm_state &s, const std::vector<std::uint32_t> &deps,
604                                          const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, llvm::Value *,
605                                          std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx,
606                                          std::uint32_t batch_size, bool) const
607 {
608     return taylor_diff_bo_impl<mppp::real128>(s, *this, deps, arr, par_ptr, n_uvars, order, idx, batch_size);
609 }
610 
611 #endif
612 
613 namespace
614 {
615 
616 // Helper to implement the function for the differentiation of
617 // 'number/param op number/param' in compact mode. The function will always return zero,
618 // unless the order is 0 (in which case it will return the result of the codegen).
619 template <typename T, typename U, typename V>
bo_taylor_c_diff_func_num_num(llvm_state & s,const binary_op & bo,const U & n0,const V & n1,std::uint32_t n_uvars,std::uint32_t batch_size,const std::string & op_name)620 llvm::Function *bo_taylor_c_diff_func_num_num(llvm_state &s, const binary_op &bo, const U &n0, const V &n1,
621                                               std::uint32_t n_uvars, std::uint32_t batch_size,
622                                               const std::string &op_name)
623 {
624     auto &module = s.module();
625     auto &builder = s.builder();
626     auto &context = s.context();
627 
628     // Fetch the floating-point type.
629     auto val_t = to_llvm_vector_type<T>(context, batch_size);
630 
631     // Fetch the function name and arguments.
632     const auto na_pair = taylor_c_diff_func_name_args<T>(context, op_name, n_uvars, batch_size, {n0, n1});
633     const auto &fname = na_pair.first;
634     const auto &fargs = na_pair.second;
635 
636     // Try to see if we already created the function.
637     auto f = module.getFunction(fname);
638 
639     if (f == nullptr) {
640         // The function was not created before, do it now.
641 
642         // Fetch the current insertion block.
643         auto orig_bb = builder.GetInsertBlock();
644 
645         // The return type is val_t.
646         auto *ft = llvm::FunctionType::get(val_t, fargs, false);
647         // Create the function
648         f = llvm::Function::Create(ft, llvm::Function::InternalLinkage, fname, &module);
649         assert(f != nullptr);
650 
651         // Fetch the necessary function arguments.
652         auto ord = f->args().begin();
653         auto par_ptr = f->args().begin() + 3;
654         auto num0 = f->args().begin() + 5;
655         auto num1 = f->args().begin() + 6;
656 
657         // Create a new basic block to start insertion into.
658         builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f));
659 
660         // Create the return value.
661         auto retval = builder.CreateAlloca(val_t);
662 
663         llvm_if_then_else(
664             s, builder.CreateICmpEQ(ord, builder.getInt32(0)),
665             [&]() {
666                 // If the order is zero, run the codegen.
667                 auto vnum0 = taylor_c_diff_numparam_codegen(s, n0, num0, par_ptr, batch_size);
668                 auto vnum1 = taylor_c_diff_numparam_codegen(s, n1, num1, par_ptr, batch_size);
669 
670                 switch (bo.op()) {
671                     case binary_op::type::add:
672                         builder.CreateStore(builder.CreateFAdd(vnum0, vnum1), retval);
673                         break;
674                     case binary_op::type::sub:
675                         builder.CreateStore(builder.CreateFSub(vnum0, vnum1), retval);
676                         break;
677                     case binary_op::type::mul:
678                         builder.CreateStore(builder.CreateFMul(vnum0, vnum1), retval);
679                         break;
680                     default:
681                         builder.CreateStore(builder.CreateFDiv(vnum0, vnum1), retval);
682                 }
683             },
684             [&]() {
685                 // Otherwise, return zero.
686                 builder.CreateStore(vector_splat(builder, codegen<T>(s, number{0.}), batch_size), retval);
687             });
688 
689         // Return the result.
690         builder.CreateRet(builder.CreateLoad(retval));
691 
692         // Verify.
693         s.verify_function(f);
694 
695         // Restore the original insertion block.
696         builder.SetInsertPoint(orig_bb);
697     } else {
698         // The function was created before. Check if the signatures match.
699         // NOTE: there could be a mismatch if the derivative function was created
700         // and then optimised - optimisation might remove arguments which are compile-time
701         // constants.
702         if (!compare_function_signature(f, val_t, fargs)) {
703             throw std::invalid_argument("Inconsistent function signature for the Taylor derivative of {}() "
704                                         "in compact mode detected"_format(op_name));
705         }
706     }
707 
708     return f;
709 }
710 
711 // Derivative of number/param +- number/param.
712 template <bool AddOrSub, typename T, typename U, typename V,
713           std::enable_if_t<std::conjunction_v<is_num_param<U>, is_num_param<V>>, int> = 0>
bo_taylor_c_diff_func_addsub_impl(llvm_state & s,const binary_op & bo,const U & num0,const V & num1,std::uint32_t n_uvars,std::uint32_t batch_size)714 llvm::Function *bo_taylor_c_diff_func_addsub_impl(llvm_state &s, const binary_op &bo, const U &num0, const V &num1,
715                                                   std::uint32_t n_uvars, std::uint32_t batch_size)
716 {
717     return bo_taylor_c_diff_func_num_num<T>(s, bo, num0, num1, n_uvars, batch_size, AddOrSub ? "add" : "sub");
718 }
719 
720 // Derivative of number +- var.
721 template <bool AddOrSub, typename T, typename U, std::enable_if_t<is_num_param_v<U>, int> = 0>
bo_taylor_c_diff_func_addsub_impl(llvm_state & s,const binary_op &,const U & n,const variable & var,std::uint32_t n_uvars,std::uint32_t batch_size)722 llvm::Function *bo_taylor_c_diff_func_addsub_impl(llvm_state &s, const binary_op &, const U &n, const variable &var,
723                                                   std::uint32_t n_uvars, std::uint32_t batch_size)
724 {
725     auto &module = s.module();
726     auto &builder = s.builder();
727     auto &context = s.context();
728 
729     // Fetch the floating-point type.
730     auto val_t = to_llvm_vector_type<T>(context, batch_size);
731 
732     // Fetch the function name and arguments.
733     const auto na_pair
734         = taylor_c_diff_func_name_args<T>(context, AddOrSub ? "add" : "sub", n_uvars, batch_size, {n, var});
735     const auto &fname = na_pair.first;
736     const auto &fargs = na_pair.second;
737 
738     // Try to see if we already created the function.
739     auto f = module.getFunction(fname);
740 
741     if (f == nullptr) {
742         // The function was not created before, do it now.
743 
744         // Fetch the current insertion block.
745         auto orig_bb = builder.GetInsertBlock();
746 
747         // The return type is val_t.
748         auto *ft = llvm::FunctionType::get(val_t, fargs, false);
749         // Create the function
750         f = llvm::Function::Create(ft, llvm::Function::InternalLinkage, fname, &module);
751         assert(f != nullptr);
752 
753         // Fetch the necessary function arguments.
754         auto order = f->args().begin();
755         auto diff_arr = f->args().begin() + 2;
756         auto par_ptr = f->args().begin() + 3;
757         auto num = f->args().begin() + 5;
758         auto var_idx = f->args().begin() + 6;
759 
760         // Create a new basic block to start insertion into.
761         builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f));
762 
763         // Create the return value.
764         auto retval = builder.CreateAlloca(val_t);
765 
766         llvm_if_then_else(
767             s, builder.CreateICmpEQ(order, builder.getInt32(0)),
768             [&]() {
769                 // For order zero, run the codegen.
770                 auto num_vec = taylor_c_diff_numparam_codegen(s, n, num, par_ptr, batch_size);
771                 auto ret = taylor_c_load_diff(s, diff_arr, n_uvars, builder.getInt32(0), var_idx);
772 
773                 builder.CreateStore(AddOrSub ? builder.CreateFAdd(num_vec, ret) : builder.CreateFSub(num_vec, ret),
774                                     retval);
775             },
776             [&]() {
777                 // Load the derivative.
778                 auto ret = taylor_c_load_diff(s, diff_arr, n_uvars, order, var_idx);
779 
780                 if constexpr (!AddOrSub) {
781                     ret = builder.CreateFNeg(ret);
782                 }
783 
784                 // Create the return value.
785                 builder.CreateStore(ret, retval);
786             });
787 
788         // Return the result.
789         builder.CreateRet(builder.CreateLoad(retval));
790 
791         // Verify.
792         s.verify_function(f);
793 
794         // Restore the original insertion block.
795         builder.SetInsertPoint(orig_bb);
796     } else {
797         // The function was created before. Check if the signatures match.
798         // NOTE: there could be a mismatch if the derivative function was created
799         // and then optimised - optimisation might remove arguments which are compile-time
800         // constants.
801         if (!compare_function_signature(f, val_t, fargs)) {
802             throw std::invalid_argument(
803                 "Inconsistent function signature for the Taylor derivative of addition in compact mode detected");
804         }
805     }
806 
807     return f;
808 }
809 
810 // Derivative of var +- number.
811 template <bool AddOrSub, typename T, typename U, std::enable_if_t<is_num_param_v<U>, int> = 0>
bo_taylor_c_diff_func_addsub_impl(llvm_state & s,const binary_op &,const variable & var,const U & n,std::uint32_t n_uvars,std::uint32_t batch_size)812 llvm::Function *bo_taylor_c_diff_func_addsub_impl(llvm_state &s, const binary_op &, const variable &var, const U &n,
813                                                   std::uint32_t n_uvars, std::uint32_t batch_size)
814 {
815     auto &module = s.module();
816     auto &builder = s.builder();
817     auto &context = s.context();
818 
819     // Fetch the floating-point type.
820     auto val_t = to_llvm_vector_type<T>(context, batch_size);
821 
822     // Fetch the function name and arguments.
823     const auto na_pair
824         = taylor_c_diff_func_name_args<T>(context, AddOrSub ? "add" : "sub", n_uvars, batch_size, {var, n});
825     const auto &fname = na_pair.first;
826     const auto &fargs = na_pair.second;
827 
828     // Try to see if we already created the function.
829     auto f = module.getFunction(fname);
830 
831     if (f == nullptr) {
832         // The function was not created before, do it now.
833 
834         // Fetch the current insertion block.
835         auto orig_bb = builder.GetInsertBlock();
836 
837         // The return type is val_t.
838         auto *ft = llvm::FunctionType::get(val_t, fargs, false);
839         // Create the function
840         f = llvm::Function::Create(ft, llvm::Function::InternalLinkage, fname, &module);
841         assert(f != nullptr);
842 
843         // Fetch the necessary arguments.
844         auto order = f->args().begin();
845         auto diff_arr = f->args().begin() + 2;
846         auto par_ptr = f->args().begin() + 3;
847         auto var_idx = f->args().begin() + 5;
848         auto num = f->args().begin() + 6;
849 
850         // Create a new basic block to start insertion into.
851         builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f));
852 
853         // Create the return value.
854         auto retval = builder.CreateAlloca(val_t);
855 
856         llvm_if_then_else(
857             s, builder.CreateICmpEQ(order, builder.getInt32(0)),
858             [&]() {
859                 // For order zero, run the codegen.
860                 auto ret = taylor_c_load_diff(s, diff_arr, n_uvars, builder.getInt32(0), var_idx);
861                 auto num_vec = taylor_c_diff_numparam_codegen(s, n, num, par_ptr, batch_size);
862 
863                 builder.CreateStore(AddOrSub ? builder.CreateFAdd(ret, num_vec) : builder.CreateFSub(ret, num_vec),
864                                     retval);
865             },
866             [&]() {
867                 // Create the return value.
868                 builder.CreateStore(taylor_c_load_diff(s, diff_arr, n_uvars, order, var_idx), retval);
869             });
870 
871         // Return the result.
872         builder.CreateRet(builder.CreateLoad(retval));
873 
874         // Verify.
875         s.verify_function(f);
876 
877         // Restore the original insertion block.
878         builder.SetInsertPoint(orig_bb);
879     } else {
880         // The function was created before. Check if the signatures match.
881         // NOTE: there could be a mismatch if the derivative function was created
882         // and then optimised - optimisation might remove arguments which are compile-time
883         // constants.
884         if (!compare_function_signature(f, val_t, fargs)) {
885             throw std::invalid_argument(
886                 "Inconsistent function signature for the Taylor derivative of addition in compact mode detected");
887         }
888     }
889 
890     return f;
891 }
892 
893 // Derivative of var +- var.
894 template <bool AddOrSub, typename T>
bo_taylor_c_diff_func_addsub_impl(llvm_state & s,const binary_op &,const variable & var0,const variable & var1,std::uint32_t n_uvars,std::uint32_t batch_size)895 llvm::Function *bo_taylor_c_diff_func_addsub_impl(llvm_state &s, const binary_op &, const variable &var0,
896                                                   const variable &var1, std::uint32_t n_uvars, std::uint32_t batch_size)
897 {
898     auto &module = s.module();
899     auto &builder = s.builder();
900     auto &context = s.context();
901 
902     // Fetch the floating-point type.
903     auto val_t = to_llvm_vector_type<T>(context, batch_size);
904 
905     // Fetch the function name and arguments.
906     const auto na_pair
907         = taylor_c_diff_func_name_args<T>(context, AddOrSub ? "add" : "sub", n_uvars, batch_size, {var0, var1});
908     const auto &fname = na_pair.first;
909     const auto &fargs = na_pair.second;
910 
911     // Try to see if we already created the function.
912     auto f = module.getFunction(fname);
913 
914     if (f == nullptr) {
915         // The function was not created before, do it now.
916 
917         // Fetch the current insertion block.
918         auto orig_bb = builder.GetInsertBlock();
919 
920         // The return type is val_t.
921         auto *ft = llvm::FunctionType::get(val_t, fargs, false);
922         // Create the function
923         f = llvm::Function::Create(ft, llvm::Function::InternalLinkage, fname, &module);
924         assert(f != nullptr);
925 
926         // Fetch the necessary function arguments.
927         auto order = f->args().begin();
928         auto diff_arr = f->args().begin() + 2;
929         auto var_idx0 = f->args().begin() + 5;
930         auto var_idx1 = f->args().begin() + 6;
931 
932         // Create a new basic block to start insertion into.
933         builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f));
934 
935         auto v0 = taylor_c_load_diff(s, diff_arr, n_uvars, order, var_idx0);
936         auto v1 = taylor_c_load_diff(s, diff_arr, n_uvars, order, var_idx1);
937 
938         // Create the return value.
939         if constexpr (AddOrSub) {
940             builder.CreateRet(builder.CreateFAdd(v0, v1));
941         } else {
942             builder.CreateRet(builder.CreateFSub(v0, v1));
943         }
944 
945         // Verify.
946         s.verify_function(f);
947 
948         // Restore the original insertion block.
949         builder.SetInsertPoint(orig_bb);
950     } else {
951         // The function was created before. Check if the signatures match.
952         // NOTE: there could be a mismatch if the derivative function was created
953         // and then optimised - optimisation might remove arguments which are compile-time
954         // constants.
955         if (!compare_function_signature(f, val_t, fargs)) {
956             throw std::invalid_argument(
957                 "Inconsistent function signature for the Taylor derivative of addition in compact mode detected");
958         }
959     }
960 
961     return f;
962 }
963 
964 // All the other cases.
965 // LCOV_EXCL_START
966 template <bool, typename, typename V1, typename V2,
967           std::enable_if_t<!std::conjunction_v<is_num_param<V1>, is_num_param<V2>>, int> = 0>
bo_taylor_c_diff_func_addsub_impl(llvm_state &,const binary_op &,const V1 &,const V2 &,std::uint32_t,std::uint32_t)968 llvm::Function *bo_taylor_c_diff_func_addsub_impl(llvm_state &, const binary_op &, const V1 &, const V2 &,
969                                                   std::uint32_t, std::uint32_t)
970 {
971     throw std::invalid_argument("An invalid argument type was encountered while trying to build the Taylor derivative "
972                                 "of add()/sub() in compact mode");
973 }
974 // LCOV_EXCL_STOP
975 
976 template <typename T>
bo_taylor_c_diff_func_add(llvm_state & s,const binary_op & bo,std::uint32_t n_uvars,std::uint32_t batch_size)977 llvm::Function *bo_taylor_c_diff_func_add(llvm_state &s, const binary_op &bo, std::uint32_t n_uvars,
978                                           std::uint32_t batch_size)
979 {
980     return std::visit(
981         [&](const auto &v1, const auto &v2) {
982             return bo_taylor_c_diff_func_addsub_impl<true, T>(s, bo, v1, v2, n_uvars, batch_size);
983         },
984         bo.lhs().value(), bo.rhs().value());
985 }
986 
987 template <typename T>
bo_taylor_c_diff_func_sub(llvm_state & s,const binary_op & bo,std::uint32_t n_uvars,std::uint32_t batch_size)988 llvm::Function *bo_taylor_c_diff_func_sub(llvm_state &s, const binary_op &bo, std::uint32_t n_uvars,
989                                           std::uint32_t batch_size)
990 {
991     return std::visit(
992         [&](const auto &v1, const auto &v2) {
993             return bo_taylor_c_diff_func_addsub_impl<false, T>(s, bo, v1, v2, n_uvars, batch_size);
994         },
995         bo.lhs().value(), bo.rhs().value());
996 }
997 
998 // Derivative of number/param * number/param.
999 template <typename T, typename U, typename V,
1000           std::enable_if_t<std::conjunction_v<is_num_param<U>, is_num_param<V>>, int> = 0>
bo_taylor_c_diff_func_mul_impl(llvm_state & s,const binary_op & bo,const U & num0,const V & num1,std::uint32_t n_uvars,std::uint32_t batch_size)1001 llvm::Function *bo_taylor_c_diff_func_mul_impl(llvm_state &s, const binary_op &bo, const U &num0, const V &num1,
1002                                                std::uint32_t n_uvars, std::uint32_t batch_size)
1003 {
1004     return bo_taylor_c_diff_func_num_num<T>(s, bo, num0, num1, n_uvars, batch_size, "mul");
1005 }
1006 
1007 // Derivative of var * number.
1008 template <typename T, typename U, std::enable_if_t<is_num_param_v<U>, int> = 0>
bo_taylor_c_diff_func_mul_impl(llvm_state & s,const binary_op &,const variable & var,const U & n,std::uint32_t n_uvars,std::uint32_t batch_size)1009 llvm::Function *bo_taylor_c_diff_func_mul_impl(llvm_state &s, const binary_op &, const variable &var, const U &n,
1010                                                std::uint32_t n_uvars, std::uint32_t batch_size)
1011 {
1012     auto &module = s.module();
1013     auto &builder = s.builder();
1014     auto &context = s.context();
1015 
1016     // Fetch the floating-point type.
1017     auto val_t = to_llvm_vector_type<T>(context, batch_size);
1018 
1019     // Fetch the function name and arguments.
1020     const auto na_pair = taylor_c_diff_func_name_args<T>(context, "mul", n_uvars, batch_size, {var, n});
1021     const auto &fname = na_pair.first;
1022     const auto &fargs = na_pair.second;
1023 
1024     // Try to see if we already created the function.
1025     auto f = module.getFunction(fname);
1026 
1027     if (f == nullptr) {
1028         // The function was not created before, do it now.
1029 
1030         // Fetch the current insertion block.
1031         auto orig_bb = builder.GetInsertBlock();
1032 
1033         // The return type is val_t.
1034         auto *ft = llvm::FunctionType::get(val_t, fargs, false);
1035         // Create the function
1036         f = llvm::Function::Create(ft, llvm::Function::InternalLinkage, fname, &module);
1037         assert(f != nullptr);
1038 
1039         // Fetch the necessary function arguments.
1040         auto order = f->args().begin();
1041         auto diff_arr = f->args().begin() + 2;
1042         auto par_ptr = f->args().begin() + 3;
1043         auto var_idx = f->args().begin() + 5;
1044         auto num = f->args().begin() + 6;
1045 
1046         // Create a new basic block to start insertion into.
1047         builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f));
1048 
1049         // Load the derivative.
1050         auto ret = taylor_c_load_diff(s, diff_arr, n_uvars, order, var_idx);
1051 
1052         // Create the return value.
1053         builder.CreateRet(builder.CreateFMul(ret, taylor_c_diff_numparam_codegen(s, n, num, par_ptr, batch_size)));
1054 
1055         // Verify.
1056         s.verify_function(f);
1057 
1058         // Restore the original insertion block.
1059         builder.SetInsertPoint(orig_bb);
1060     } else {
1061         // The function was created before. Check if the signatures match.
1062         // NOTE: there could be a mismatch if the derivative function was created
1063         // and then optimised - optimisation might remove arguments which are compile-time
1064         // constants.
1065         if (!compare_function_signature(f, val_t, fargs)) {
1066             throw std::invalid_argument(
1067                 "Inconsistent function signature for the Taylor derivative of multiplication in compact mode detected");
1068         }
1069     }
1070 
1071     return f;
1072 }
1073 
1074 // Derivative of number * var.
1075 template <typename T, typename U, std::enable_if_t<is_num_param_v<U>, int> = 0>
bo_taylor_c_diff_func_mul_impl(llvm_state & s,const binary_op &,const U & n,const variable & var,std::uint32_t n_uvars,std::uint32_t batch_size)1076 llvm::Function *bo_taylor_c_diff_func_mul_impl(llvm_state &s, const binary_op &, const U &n, const variable &var,
1077                                                std::uint32_t n_uvars, std::uint32_t batch_size)
1078 {
1079     auto &module = s.module();
1080     auto &builder = s.builder();
1081     auto &context = s.context();
1082 
1083     // Fetch the floating-point type.
1084     auto val_t = to_llvm_vector_type<T>(context, batch_size);
1085 
1086     // Fetch the function name and arguments.
1087     const auto na_pair = taylor_c_diff_func_name_args<T>(context, "mul", n_uvars, batch_size, {n, var});
1088     const auto &fname = na_pair.first;
1089     const auto &fargs = na_pair.second;
1090 
1091     // Try to see if we already created the function.
1092     auto f = module.getFunction(fname);
1093 
1094     if (f == nullptr) {
1095         // The function was not created before, do it now.
1096 
1097         // Fetch the current insertion block.
1098         auto orig_bb = builder.GetInsertBlock();
1099 
1100         // The return type is val_t.
1101         auto *ft = llvm::FunctionType::get(val_t, fargs, false);
1102         // Create the function
1103         f = llvm::Function::Create(ft, llvm::Function::InternalLinkage, fname, &module);
1104         assert(f != nullptr);
1105 
1106         // Fetch the necessary function arguments.
1107         auto order = f->args().begin();
1108         auto diff_arr = f->args().begin() + 2;
1109         auto par_ptr = f->args().begin() + 3;
1110         auto num = f->args().begin() + 5;
1111         auto var_idx = f->args().begin() + 6;
1112 
1113         // Create a new basic block to start insertion into.
1114         builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f));
1115 
1116         // Load the derivative.
1117         auto ret = taylor_c_load_diff(s, diff_arr, n_uvars, order, var_idx);
1118 
1119         // Create the return value.
1120         builder.CreateRet(builder.CreateFMul(ret, taylor_c_diff_numparam_codegen(s, n, num, par_ptr, batch_size)));
1121 
1122         // Verify.
1123         s.verify_function(f);
1124 
1125         // Restore the original insertion block.
1126         builder.SetInsertPoint(orig_bb);
1127     } else {
1128         // The function was created before. Check if the signatures match.
1129         // NOTE: there could be a mismatch if the derivative function was created
1130         // and then optimised - optimisation might remove arguments which are compile-time
1131         // constants.
1132         if (!compare_function_signature(f, val_t, fargs)) {
1133             throw std::invalid_argument(
1134                 "Inconsistent function signature for the Taylor derivative of multiplication in compact mode detected");
1135         }
1136     }
1137 
1138     return f;
1139 }
1140 
1141 // Derivative of var * var.
1142 template <typename T>
bo_taylor_c_diff_func_mul_impl(llvm_state & s,const binary_op &,const variable & var0,const variable & var1,std::uint32_t n_uvars,std::uint32_t batch_size)1143 llvm::Function *bo_taylor_c_diff_func_mul_impl(llvm_state &s, const binary_op &, const variable &var0,
1144                                                const variable &var1, std::uint32_t n_uvars, std::uint32_t batch_size)
1145 {
1146     auto &module = s.module();
1147     auto &builder = s.builder();
1148     auto &context = s.context();
1149 
1150     // Fetch the floating-point type.
1151     auto val_t = to_llvm_vector_type<T>(context, batch_size);
1152 
1153     // Fetch the function name and arguments.
1154     const auto na_pair = taylor_c_diff_func_name_args<T>(context, "mul", n_uvars, batch_size, {var0, var1});
1155     const auto &fname = na_pair.first;
1156     const auto &fargs = na_pair.second;
1157 
1158     // Try to see if we already created the function.
1159     auto f = module.getFunction(fname);
1160 
1161     if (f == nullptr) {
1162         // The function was not created before, do it now.
1163 
1164         // Fetch the current insertion block.
1165         auto orig_bb = builder.GetInsertBlock();
1166 
1167         // The return type is val_t.
1168         auto *ft = llvm::FunctionType::get(val_t, fargs, false);
1169         // Create the function
1170         f = llvm::Function::Create(ft, llvm::Function::InternalLinkage, fname, &module);
1171         assert(f != nullptr);
1172 
1173         // Fetch the necessary function arguments.
1174         auto ord = f->args().begin();
1175         auto diff_ptr = f->args().begin() + 2;
1176         auto idx0 = f->args().begin() + 5;
1177         auto idx1 = f->args().begin() + 6;
1178 
1179         // Create a new basic block to start insertion into.
1180         builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f));
1181 
1182         // Create the accumulator.
1183         auto acc = builder.CreateAlloca(val_t);
1184         builder.CreateStore(vector_splat(builder, codegen<T>(s, number{0.}), batch_size), acc);
1185 
1186         // Run the loop.
1187         llvm_loop_u32(s, builder.getInt32(0), builder.CreateAdd(ord, builder.getInt32(1)), [&](llvm::Value *j) {
1188             auto b_nj = taylor_c_load_diff(s, diff_ptr, n_uvars, builder.CreateSub(ord, j), idx0);
1189             auto cj = taylor_c_load_diff(s, diff_ptr, n_uvars, j, idx1);
1190             builder.CreateStore(builder.CreateFAdd(builder.CreateLoad(acc), builder.CreateFMul(b_nj, cj)), acc);
1191         });
1192 
1193         // Create the return value.
1194         builder.CreateRet(builder.CreateLoad(acc));
1195 
1196         // Verify.
1197         s.verify_function(f);
1198 
1199         // Restore the original insertion block.
1200         builder.SetInsertPoint(orig_bb);
1201     } else {
1202         // The function was created before. Check if the signatures match.
1203         // NOTE: there could be a mismatch if the derivative function was created
1204         // and then optimised - optimisation might remove arguments which are compile-time
1205         // constants.
1206         if (!compare_function_signature(f, val_t, fargs)) {
1207             throw std::invalid_argument(
1208                 "Inconsistent function signature for the Taylor derivative of multiplication in compact mode detected");
1209         }
1210     }
1211 
1212     return f;
1213 }
1214 
1215 // All the other cases.
1216 // LCOV_EXCL_START
1217 template <typename, typename V1, typename V2,
1218           std::enable_if_t<!std::conjunction_v<is_num_param<V1>, is_num_param<V2>>, int> = 0>
bo_taylor_c_diff_func_mul_impl(llvm_state &,const binary_op &,const V1 &,const V2 &,std::uint32_t,std::uint32_t)1219 llvm::Function *bo_taylor_c_diff_func_mul_impl(llvm_state &, const binary_op &, const V1 &, const V2 &, std::uint32_t,
1220                                                std::uint32_t)
1221 {
1222     throw std::invalid_argument("An invalid argument type was encountered while trying to build the Taylor derivative "
1223                                 "of mul() in compact mode");
1224 }
1225 // LCOV_EXCL_STOP
1226 
1227 template <typename T>
bo_taylor_c_diff_func_mul(llvm_state & s,const binary_op & bo,std::uint32_t n_uvars,std::uint32_t batch_size)1228 llvm::Function *bo_taylor_c_diff_func_mul(llvm_state &s, const binary_op &bo, std::uint32_t n_uvars,
1229                                           std::uint32_t batch_size)
1230 {
1231     return std::visit(
1232         [&](const auto &v1, const auto &v2) {
1233             return bo_taylor_c_diff_func_mul_impl<T>(s, bo, v1, v2, n_uvars, batch_size);
1234         },
1235         bo.lhs().value(), bo.rhs().value());
1236 }
1237 
1238 // Derivative of number/param / number/param.
1239 template <typename T, typename U, typename V,
1240           std::enable_if_t<std::conjunction_v<is_num_param<U>, is_num_param<V>>, int> = 0>
bo_taylor_c_diff_func_div_impl(llvm_state & s,const binary_op & bo,const U & num0,const V & num1,std::uint32_t n_uvars,std::uint32_t batch_size)1241 llvm::Function *bo_taylor_c_diff_func_div_impl(llvm_state &s, const binary_op &bo, const U &num0, const V &num1,
1242                                                std::uint32_t n_uvars, std::uint32_t batch_size)
1243 {
1244     return bo_taylor_c_diff_func_num_num<T>(s, bo, num0, num1, n_uvars, batch_size, "div");
1245 }
1246 
1247 // Derivative of var / number.
1248 template <typename T, typename U, std::enable_if_t<is_num_param_v<U>, int> = 0>
bo_taylor_c_diff_func_div_impl(llvm_state & s,const binary_op &,const variable & var,const U & n,std::uint32_t n_uvars,std::uint32_t batch_size)1249 llvm::Function *bo_taylor_c_diff_func_div_impl(llvm_state &s, const binary_op &, const variable &var, const U &n,
1250                                                std::uint32_t n_uvars, std::uint32_t batch_size)
1251 {
1252     auto &module = s.module();
1253     auto &builder = s.builder();
1254     auto &context = s.context();
1255 
1256     // Fetch the floating-point type.
1257     auto val_t = to_llvm_vector_type<T>(context, batch_size);
1258 
1259     // Fetch the function name and arguments.
1260     const auto na_pair = taylor_c_diff_func_name_args<T>(context, "div", n_uvars, batch_size, {var, n});
1261     const auto &fname = na_pair.first;
1262     const auto &fargs = na_pair.second;
1263 
1264     // Try to see if we already created the function.
1265     auto f = module.getFunction(fname);
1266 
1267     if (f == nullptr) {
1268         // The function was not created before, do it now.
1269 
1270         // Fetch the current insertion block.
1271         auto orig_bb = builder.GetInsertBlock();
1272 
1273         // The return type is val_t.
1274         auto *ft = llvm::FunctionType::get(val_t, fargs, false);
1275         // Create the function
1276         f = llvm::Function::Create(ft, llvm::Function::InternalLinkage, fname, &module);
1277         assert(f != nullptr);
1278 
1279         // Fetch the necessary function arguments.
1280         auto order = f->args().begin();
1281         auto diff_arr = f->args().begin() + 2;
1282         auto par_ptr = f->args().begin() + 3;
1283         auto var_idx = f->args().begin() + 5;
1284         auto num = f->args().begin() + 6;
1285 
1286         // Create a new basic block to start insertion into.
1287         builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f));
1288 
1289         // Load the derivative.
1290         auto ret = taylor_c_load_diff(s, diff_arr, n_uvars, order, var_idx);
1291 
1292         // Create the return value.
1293         builder.CreateRet(builder.CreateFDiv(ret, taylor_c_diff_numparam_codegen(s, n, num, par_ptr, batch_size)));
1294 
1295         // Verify.
1296         s.verify_function(f);
1297 
1298         // Restore the original insertion block.
1299         builder.SetInsertPoint(orig_bb);
1300     } else {
1301         // The function was created before. Check if the signatures match.
1302         // NOTE: there could be a mismatch if the derivative function was created
1303         // and then optimised - optimisation might remove arguments which are compile-time
1304         // constants.
1305         if (!compare_function_signature(f, val_t, fargs)) {
1306             throw std::invalid_argument(
1307                 "Inconsistent function signature for the Taylor derivative of division in compact mode detected");
1308         }
1309     }
1310 
1311     return f;
1312 }
1313 
// Derivative of number / var.
//
// Returns the compact-mode function computing the Taylor derivative of order
// ord of u = n / b, where n satisfies is_num_param_v (number or parameter) and
// b is a variable. The generated code evaluates:
//
// - for ord == 0:  u^[0]   = n / b^[0];
// - for ord >= 1:  u^[ord] = -(sum_{j=1}^{ord} b^[j] * u^[ord - j]) / b^[0],
//   because the derivatives of n of order >= 1 are zero.
//
// If a function with the same name already exists in the module it is reused,
// after checking its signature; std::invalid_argument is thrown on mismatch.
template <typename T, typename U, std::enable_if_t<is_num_param_v<U>, int> = 0>
llvm::Function *bo_taylor_c_diff_func_div_impl(llvm_state &s, const binary_op &, const U &n, const variable &var,
                                               std::uint32_t n_uvars, std::uint32_t batch_size)
{
    auto &module = s.module();
    auto &builder = s.builder();
    auto &context = s.context();

    // Fetch the floating-point type.
    auto val_t = to_llvm_vector_type<T>(context, batch_size);

    // Fetch the function name and arguments.
    const auto na_pair = taylor_c_diff_func_name_args<T>(context, "div", n_uvars, batch_size, {n, var});
    const auto &fname = na_pair.first;
    const auto &fargs = na_pair.second;

    // Try to see if we already created the function.
    auto f = module.getFunction(fname);

    if (f == nullptr) {
        // The function was not created before, do it now.

        // Fetch the current insertion block (restored at the end).
        auto orig_bb = builder.GetInsertBlock();

        // The return type is val_t.
        auto *ft = llvm::FunctionType::get(val_t, fargs, false);
        // Create the function
        f = llvm::Function::Create(ft, llvm::Function::InternalLinkage, fname, &module);
        assert(f != nullptr);

        // Fetch the necessary function arguments.
        // NOTE: the number/param argument (num) is used only in the
        // order-zero branch below: its derivatives of order >= 1
        // are always zero, so they never appear in the recurrence.
        // Argument 4 is not needed here.
        auto ord = f->args().begin();
        auto u_idx = f->args().begin() + 1;
        auto diff_ptr = f->args().begin() + 2;
        auto par_ptr = f->args().begin() + 3;
        auto num = f->args().begin() + 5;
        auto var_idx = f->args().begin() + 6;

        // Create a new basic block to start insertion into.
        builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f));

        // Stack slot for the return value, written by either branch below.
        auto retval = builder.CreateAlloca(val_t);

        // Create the accumulator (only used in the ord >= 1 branch).
        auto acc = builder.CreateAlloca(val_t);

        llvm_if_then_else(
            s, builder.CreateICmpEQ(ord, builder.getInt32(0)),
            [&]() {
                // For order zero, run the codegen.
                // u^[0] = n / b^[0].
                auto num_vec = taylor_c_diff_numparam_codegen(s, n, num, par_ptr, batch_size);
                auto ret = taylor_c_load_diff(s, diff_ptr, n_uvars, builder.getInt32(0), var_idx);

                builder.CreateStore(builder.CreateFDiv(num_vec, ret), retval);
            },
            [&]() {
                // Init the accumulator.
                builder.CreateStore(vector_splat(builder, codegen<T>(s, number{0.}), batch_size), acc);

                // Run the loop: acc = sum_{j=1}^{ord} b^[j] * u^[ord - j]
                // (the end bound ord + 1 is presumably exclusive).
                llvm_loop_u32(s, builder.getInt32(1), builder.CreateAdd(ord, builder.getInt32(1)), [&](llvm::Value *j) {
                    auto cj = taylor_c_load_diff(s, diff_ptr, n_uvars, j, var_idx);
                    auto a_nj = taylor_c_load_diff(s, diff_ptr, n_uvars, builder.CreateSub(ord, j), u_idx);
                    builder.CreateStore(builder.CreateFAdd(builder.CreateLoad(acc), builder.CreateFMul(cj, a_nj)), acc);
                });

                // Negate the loop summation.
                auto ret = builder.CreateFNeg(builder.CreateLoad(acc));

                // Divide by b^[0] and store into retval.
                builder.CreateStore(
                    builder.CreateFDiv(ret, taylor_c_load_diff(s, diff_ptr, n_uvars, builder.getInt32(0), var_idx)),
                    retval);
            });

        // Return the result.
        builder.CreateRet(builder.CreateLoad(retval));

        // Verify.
        s.verify_function(f);

        // Restore the original insertion block.
        builder.SetInsertPoint(orig_bb);
    } else {
        // The function was created before. Check if the signatures match.
        // NOTE: there could be a mismatch if the derivative function was created
        // and then optimised - optimisation might remove arguments which are compile-time
        // constants.
        if (!compare_function_signature(f, val_t, fargs)) {
            throw std::invalid_argument(
                "Inconsistent function signature for the Taylor derivative of division in compact mode detected");
        }
    }

    return f;
}
1416 
1417 // Derivative of var / var.
1418 template <typename T>
bo_taylor_c_diff_func_div_impl(llvm_state & s,const binary_op &,const variable & var0,const variable & var1,std::uint32_t n_uvars,std::uint32_t batch_size)1419 llvm::Function *bo_taylor_c_diff_func_div_impl(llvm_state &s, const binary_op &, const variable &var0,
1420                                                const variable &var1, std::uint32_t n_uvars, std::uint32_t batch_size)
1421 {
1422     auto &module = s.module();
1423     auto &builder = s.builder();
1424     auto &context = s.context();
1425 
1426     // Fetch the floating-point type.
1427     auto val_t = to_llvm_vector_type<T>(context, batch_size);
1428 
1429     // Fetch the function name and arguments.
1430     const auto na_pair = taylor_c_diff_func_name_args<T>(context, "div", n_uvars, batch_size, {var0, var1});
1431     const auto &fname = na_pair.first;
1432     const auto &fargs = na_pair.second;
1433 
1434     // Try to see if we already created the function.
1435     auto f = module.getFunction(fname);
1436 
1437     if (f == nullptr) {
1438         // The function was not created before, do it now.
1439 
1440         // Fetch the current insertion block.
1441         auto orig_bb = builder.GetInsertBlock();
1442 
1443         // The return type is val_t.
1444         auto *ft = llvm::FunctionType::get(val_t, fargs, false);
1445         // Create the function
1446         f = llvm::Function::Create(ft, llvm::Function::InternalLinkage, fname, &module);
1447         assert(f != nullptr);
1448 
1449         // Fetch the necessary function arguments.
1450         auto ord = f->args().begin();
1451         auto u_idx = f->args().begin() + 1;
1452         auto diff_ptr = f->args().begin() + 2;
1453         auto var_idx0 = f->args().begin() + 5;
1454         auto var_idx1 = f->args().begin() + 6;
1455 
1456         // Create a new basic block to start insertion into.
1457         builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f));
1458 
1459         // Create the accumulator.
1460         auto acc = builder.CreateAlloca(val_t);
1461         builder.CreateStore(vector_splat(builder, codegen<T>(s, number{0.}), batch_size), acc);
1462 
1463         // Run the loop.
1464         llvm_loop_u32(s, builder.getInt32(1), builder.CreateAdd(ord, builder.getInt32(1)), [&](llvm::Value *j) {
1465             auto cj = taylor_c_load_diff(s, diff_ptr, n_uvars, j, var_idx1);
1466             auto a_nj = taylor_c_load_diff(s, diff_ptr, n_uvars, builder.CreateSub(ord, j), u_idx);
1467             builder.CreateStore(builder.CreateFAdd(builder.CreateLoad(acc), builder.CreateFMul(cj, a_nj)), acc);
1468         });
1469 
1470         auto ret = builder.CreateFSub(taylor_c_load_diff(s, diff_ptr, n_uvars, ord, var_idx0), builder.CreateLoad(acc));
1471 
1472         // Divide and return.
1473         builder.CreateRet(
1474             builder.CreateFDiv(ret, taylor_c_load_diff(s, diff_ptr, n_uvars, builder.getInt32(0), var_idx1)));
1475 
1476         // Verify.
1477         s.verify_function(f);
1478 
1479         // Restore the original insertion block.
1480         builder.SetInsertPoint(orig_bb);
1481     } else {
1482         // The function was created before. Check if the signatures match.
1483         // NOTE: there could be a mismatch if the derivative function was created
1484         // and then optimised - optimisation might remove arguments which are compile-time
1485         // constants.
1486         if (!compare_function_signature(f, val_t, fargs)) {
1487             throw std::invalid_argument(
1488                 "Inconsistent function signature for the Taylor derivative of division in compact mode detected");
1489         }
1490     }
1491 
1492     return f;
1493 }
1494 
1495 // All the other cases.
1496 // LCOV_EXCL_START
1497 template <typename, typename V1, typename V2,
1498           std::enable_if_t<!std::conjunction_v<is_num_param<V1>, is_num_param<V2>>, int> = 0>
bo_taylor_c_diff_func_div_impl(llvm_state &,const binary_op &,const V1 &,const V2 &,std::uint32_t,std::uint32_t)1499 llvm::Function *bo_taylor_c_diff_func_div_impl(llvm_state &, const binary_op &, const V1 &, const V2 &, std::uint32_t,
1500                                                std::uint32_t)
1501 {
1502     throw std::invalid_argument("An invalid argument type was encountered while trying to build the Taylor derivative "
1503                                 "of div() in compact mode");
1504 }
1505 // LCOV_EXCL_STOP
1506 
1507 template <typename T>
bo_taylor_c_diff_func_div(llvm_state & s,const binary_op & bo,std::uint32_t n_uvars,std::uint32_t batch_size)1508 llvm::Function *bo_taylor_c_diff_func_div(llvm_state &s, const binary_op &bo, std::uint32_t n_uvars,
1509                                           std::uint32_t batch_size)
1510 {
1511     return std::visit(
1512         [&](const auto &v1, const auto &v2) {
1513             return bo_taylor_c_diff_func_div_impl<T>(s, bo, v1, v2, n_uvars, batch_size);
1514         },
1515         bo.lhs().value(), bo.rhs().value());
1516 }
1517 
1518 template <typename T>
taylor_c_diff_func_bo_impl(llvm_state & s,const binary_op & bo,std::uint32_t n_uvars,std::uint32_t batch_size)1519 llvm::Function *taylor_c_diff_func_bo_impl(llvm_state &s, const binary_op &bo, std::uint32_t n_uvars,
1520                                            std::uint32_t batch_size)
1521 {
1522     switch (bo.op()) {
1523         case binary_op::type::add:
1524             return bo_taylor_c_diff_func_add<T>(s, bo, n_uvars, batch_size);
1525         case binary_op::type::sub:
1526             return bo_taylor_c_diff_func_sub<T>(s, bo, n_uvars, batch_size);
1527         case binary_op::type::mul:
1528             return bo_taylor_c_diff_func_mul<T>(s, bo, n_uvars, batch_size);
1529         default:
1530             return bo_taylor_c_diff_func_div<T>(s, bo, n_uvars, batch_size);
1531     }
1532 }
1533 
1534 } // namespace
1535 
taylor_c_diff_func_dbl(llvm_state & s,std::uint32_t n_uvars,std::uint32_t batch_size,bool) const1536 llvm::Function *binary_op::taylor_c_diff_func_dbl(llvm_state &s, std::uint32_t n_uvars, std::uint32_t batch_size,
1537                                                   bool) const
1538 {
1539     return taylor_c_diff_func_bo_impl<double>(s, *this, n_uvars, batch_size);
1540 }
1541 
taylor_c_diff_func_ldbl(llvm_state & s,std::uint32_t n_uvars,std::uint32_t batch_size,bool) const1542 llvm::Function *binary_op::taylor_c_diff_func_ldbl(llvm_state &s, std::uint32_t n_uvars, std::uint32_t batch_size,
1543                                                    bool) const
1544 {
1545     return taylor_c_diff_func_bo_impl<long double>(s, *this, n_uvars, batch_size);
1546 }
1547 
1548 #if defined(HEYOKA_HAVE_REAL128)
1549 
taylor_c_diff_func_f128(llvm_state & s,std::uint32_t n_uvars,std::uint32_t batch_size,bool) const1550 llvm::Function *binary_op::taylor_c_diff_func_f128(llvm_state &s, std::uint32_t n_uvars, std::uint32_t batch_size,
1551                                                    bool) const
1552 {
1553     return taylor_c_diff_func_bo_impl<mppp::real128>(s, *this, n_uvars, batch_size);
1554 }
1555 
1556 #endif
1557 
1558 } // namespace detail
1559 
add(expression x,expression y)1560 expression add(expression x, expression y)
1561 {
1562     return expression{func{detail::binary_op(detail::binary_op::type::add, std::move(x), std::move(y))}};
1563 }
1564 
sub(expression x,expression y)1565 expression sub(expression x, expression y)
1566 {
1567     return expression{func{detail::binary_op(detail::binary_op::type::sub, std::move(x), std::move(y))}};
1568 }
1569 
mul(expression x,expression y)1570 expression mul(expression x, expression y)
1571 {
1572     return expression{func{detail::binary_op(detail::binary_op::type::mul, std::move(x), std::move(y))}};
1573 }
1574 
div(expression x,expression y)1575 expression div(expression x, expression y)
1576 {
1577     return expression{func{detail::binary_op(detail::binary_op::type::div, std::move(x), std::move(y))}};
1578 }
1579 
1580 } // namespace heyoka
1581 
1582 HEYOKA_S11N_FUNC_EXPORT_IMPLEMENT(heyoka::detail::binary_op)
1583