1 // Copyright 2020, 2021 Francesco Biscani (bluescarni@gmail.com), Dario Izzo (dario.izzo@gmail.com)
2 //
3 // This file is part of the heyoka library.
4 //
5 // This Source Code Form is subject to the terms of the Mozilla
6 // Public License v. 2.0. If a copy of the MPL was not distributed
7 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
8
9 #include <heyoka/config.hpp>
10
11 #include <algorithm>
12 #include <cassert>
13 #include <cstddef>
14 #include <cstdint>
15 #include <functional>
16 #include <initializer_list>
17 #include <ostream>
18 #include <stdexcept>
19 #include <string>
20 #include <type_traits>
21 #include <unordered_map>
22 #include <utility>
23 #include <variant>
24 #include <vector>
25
26 #include <fmt/format.h>
27
28 #include <llvm/IR/BasicBlock.h>
29 #include <llvm/IR/DerivedTypes.h>
30 #include <llvm/IR/Function.h>
31 #include <llvm/IR/IRBuilder.h>
32 #include <llvm/IR/LLVMContext.h>
33 #include <llvm/IR/Module.h>
34 #include <llvm/IR/Type.h>
35 #include <llvm/IR/Value.h>
36
37 #if defined(HEYOKA_HAVE_REAL128)
38
39 #include <mp++/real128.hpp>
40
41 #endif
42
43 #include <heyoka/detail/fwd_decl.hpp>
44 #include <heyoka/detail/llvm_helpers.hpp>
45 #include <heyoka/detail/string_conv.hpp>
46 #include <heyoka/expression.hpp>
47 #include <heyoka/func.hpp>
48 #include <heyoka/llvm_state.hpp>
49 #include <heyoka/math/binary_op.hpp>
50 #include <heyoka/number.hpp>
51 #include <heyoka/s11n.hpp>
52 #include <heyoka/taylor.hpp>
53 #include <heyoka/variable.hpp>
54
55 #if defined(_MSC_VER) && !defined(__clang__)
56
57 // NOTE: MSVC has issues with the other "using"
58 // statement form.
59 using namespace fmt::literals;
60
61 #else
62
63 using fmt::literals::operator""_format;
64
65 #endif
66
67 namespace heyoka
68 {
69
70 namespace detail
71 {
72
// Default constructor: delegates to the main constructor, producing the
// addition of two zero literals (0_dbl + 0_dbl).
binary_op::binary_op() : binary_op(type::add, 0_dbl, 0_dbl) {}
74
binary_op(type t,expression a,expression b)75 binary_op::binary_op(type t, expression a, expression b)
76 : func_base("binary_op", std::vector{std::move(a), std::move(b)}), m_type(t)
77 {
78 assert(m_type >= type::add && m_type <= type::div);
79 }
80
extra_equal_to(const func & f) const81 bool binary_op::extra_equal_to(const func &f) const
82 {
83 // NOTE: this should be ensured by the
84 // implementation of func's equality operator.
85 assert(f.extract<binary_op>() == f.get_ptr());
86
87 return static_cast<const binary_op *>(f.get_ptr())->m_type == m_type;
88 }
89
extra_hash() const90 std::size_t binary_op::extra_hash() const
91 {
92 return std::hash<type>{}(m_type);
93 }
94
to_stream(std::ostream & os) const95 void binary_op::to_stream(std::ostream &os) const
96 {
97 assert(args().size() == 2u);
98 assert(m_type >= type::add && m_type <= type::div);
99
100 os << '(' << lhs() << ' ';
101
102 switch (m_type) {
103 case type::add:
104 os << '+';
105 break;
106 case type::sub:
107 os << '-';
108 break;
109 case type::mul:
110 os << '*';
111 break;
112 default:
113 os << '/';
114 break;
115 }
116
117 os << ' ' << rhs() << ')';
118 }
119
// Return the kind of arithmetic operation (add, sub, mul or div)
// performed by this binary operator.
binary_op::type binary_op::op() const
{
    return m_type;
}
124
// Return a reference to the left-hand side operand (the first argument).
const expression &binary_op::lhs() const
{
    assert(args().size() == 2u);
    return args()[0];
}
130
// Return a reference to the right-hand side operand (the second argument).
const expression &binary_op::rhs() const
{
    assert(args().size() == 2u);
    return args()[1];
}
136
137 template <typename T>
diff_impl(std::unordered_map<const void *,expression> & func_map,const T & x) const138 expression binary_op::diff_impl(std::unordered_map<const void *, expression> &func_map, const T &x) const
139 {
140 assert(args().size() == 2u);
141 assert(m_type >= type::add && m_type <= type::div);
142
143 switch (m_type) {
144 case type::add:
145 return detail::diff(func_map, lhs(), x) + detail::diff(func_map, rhs(), x);
146 case type::sub:
147 return detail::diff(func_map, lhs(), x) - detail::diff(func_map, rhs(), x);
148 case type::mul:
149 return detail::diff(func_map, lhs(), x) * rhs() + lhs() * detail::diff(func_map, rhs(), x);
150 default:
151 return (detail::diff(func_map, lhs(), x) * rhs() - lhs() * detail::diff(func_map, rhs(), x))
152 / (rhs() * rhs());
153 }
154 }
155
diff(std::unordered_map<const void *,expression> & func_map,const std::string & s) const156 expression binary_op::diff(std::unordered_map<const void *, expression> &func_map, const std::string &s) const
157 {
158 return diff_impl(func_map, s);
159 }
160
diff(std::unordered_map<const void *,expression> & func_map,const param & p) const161 expression binary_op::diff(std::unordered_map<const void *, expression> &func_map, const param &p) const
162 {
163 return diff_impl(func_map, p);
164 }
165
166 namespace
167 {
168
169 template <class T>
eval_bo_impl(const binary_op & bo,const std::unordered_map<std::string,T> & map,const std::vector<T> & pars)170 T eval_bo_impl(const binary_op &bo, const std::unordered_map<std::string, T> &map, const std::vector<T> &pars)
171 {
172 assert(bo.args().size() == 2u);
173 assert(bo.op() >= binary_op::type::add && bo.op() <= binary_op::type::div);
174
175 switch (bo.op()) {
176 case binary_op::type::add:
177 return eval<T>(bo.lhs(), map, pars) + eval<T>(bo.rhs(), map, pars);
178 case binary_op::type::sub:
179 return eval<T>(bo.lhs(), map, pars) - eval<T>(bo.rhs(), map, pars);
180 case binary_op::type::mul:
181 return eval<T>(bo.lhs(), map, pars) * eval<T>(bo.rhs(), map, pars);
182 default:
183 return eval<T>(bo.lhs(), map, pars) / eval<T>(bo.rhs(), map, pars);
184 }
185 }
186
187 } // namespace
188
eval_dbl(const std::unordered_map<std::string,double> & map,const std::vector<double> & pars) const189 double binary_op::eval_dbl(const std::unordered_map<std::string, double> &map, const std::vector<double> &pars) const
190 {
191 return eval_bo_impl<double>(*this, map, pars);
192 }
193
eval_ldbl(const std::unordered_map<std::string,long double> & map,const std::vector<long double> & pars) const194 long double binary_op::eval_ldbl(const std::unordered_map<std::string, long double> &map,
195 const std::vector<long double> &pars) const
196 {
197 return eval_bo_impl<long double>(*this, map, pars);
198 }
199
200 #if defined(HEYOKA_HAVE_REAL128)
201
eval_f128(const std::unordered_map<std::string,mppp::real128> & map,const std::vector<mppp::real128> & pars) const202 mppp::real128 binary_op::eval_f128(const std::unordered_map<std::string, mppp::real128> &map,
203 const std::vector<mppp::real128> &pars) const
204 {
205 return eval_bo_impl<mppp::real128>(*this, map, pars);
206 }
207
208 #endif
209
eval_batch_dbl(std::vector<double> & out_values,const std::unordered_map<std::string,std::vector<double>> & map,const std::vector<double> & pars) const210 void binary_op::eval_batch_dbl(std::vector<double> &out_values,
211 const std::unordered_map<std::string, std::vector<double>> &map,
212 const std::vector<double> &pars) const
213 {
214 assert(args().size() == 2u);
215 assert(m_type >= type::add && m_type <= type::div);
216
217 auto tmp = out_values;
218 heyoka::eval_batch_dbl(out_values, lhs(), map, pars);
219 heyoka::eval_batch_dbl(tmp, rhs(), map, pars);
220 switch (m_type) {
221 case type::add:
222 std::transform(out_values.begin(), out_values.end(), tmp.begin(), out_values.begin(), std::plus<>());
223 break;
224 case type::sub:
225 std::transform(out_values.begin(), out_values.end(), tmp.begin(), out_values.begin(), std::minus<>());
226 break;
227 case type::mul:
228 std::transform(out_values.begin(), out_values.end(), tmp.begin(), out_values.begin(), std::multiplies<>());
229 break;
230 default:
231 std::transform(out_values.begin(), out_values.end(), tmp.begin(), out_values.begin(), std::divides<>());
232 break;
233 }
234 }
235
236 namespace
237 {
238
239 // Derivative of number +- number.
240 template <bool AddOrSub, typename T, typename U, typename V,
241 std::enable_if_t<std::conjunction_v<is_num_param<U>, is_num_param<V>>, int> = 0>
bo_taylor_diff_addsub_impl(llvm_state & s,const U & num0,const V & num1,const std::vector<llvm::Value * > &,llvm::Value * par_ptr,std::uint32_t,std::uint32_t order,std::uint32_t,std::uint32_t batch_size)242 llvm::Value *bo_taylor_diff_addsub_impl(llvm_state &s, const U &num0, const V &num1, const std::vector<llvm::Value *> &,
243 llvm::Value *par_ptr, std::uint32_t, std::uint32_t order, std::uint32_t,
244 std::uint32_t batch_size)
245 {
246 if (order == 0u) {
247 auto n0 = taylor_codegen_numparam<T>(s, num0, par_ptr, batch_size);
248 auto n1 = taylor_codegen_numparam<T>(s, num1, par_ptr, batch_size);
249
250 return AddOrSub ? s.builder().CreateFAdd(n0, n1) : s.builder().CreateFSub(n0, n1);
251 } else {
252 return vector_splat(s.builder(), codegen<T>(s, number{0.}), batch_size);
253 }
254 }
255
256 // Derivative of number +- var.
257 template <bool AddOrSub, typename T, typename U, std::enable_if_t<is_num_param_v<U>, int> = 0>
bo_taylor_diff_addsub_impl(llvm_state & s,const U & num,const variable & var,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t,std::uint32_t batch_size)258 llvm::Value *bo_taylor_diff_addsub_impl(llvm_state &s, const U &num, const variable &var,
259 const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr,
260 std::uint32_t n_uvars, std::uint32_t order, std::uint32_t,
261 std::uint32_t batch_size)
262 {
263 auto &builder = s.builder();
264
265 auto ret = taylor_fetch_diff(arr, uname_to_index(var.name()), order, n_uvars);
266
267 if (order == 0u) {
268 auto n = taylor_codegen_numparam<T>(s, num, par_ptr, batch_size);
269
270 return AddOrSub ? builder.CreateFAdd(n, ret) : builder.CreateFSub(n, ret);
271 } else {
272 if constexpr (AddOrSub) {
273 return ret;
274 } else {
275 // Negate if we are doing a subtraction.
276 return builder.CreateFNeg(ret);
277 }
278 }
279 }
280
281 // Derivative of var +- number.
282 template <bool AddOrSub, typename T, typename U, std::enable_if_t<is_num_param_v<U>, int> = 0>
bo_taylor_diff_addsub_impl(llvm_state & s,const variable & var,const U & num,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t,std::uint32_t batch_size)283 llvm::Value *bo_taylor_diff_addsub_impl(llvm_state &s, const variable &var, const U &num,
284 const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr,
285 std::uint32_t n_uvars, std::uint32_t order, std::uint32_t,
286 std::uint32_t batch_size)
287 {
288 auto ret = taylor_fetch_diff(arr, uname_to_index(var.name()), order, n_uvars);
289
290 if (order == 0u) {
291 auto &builder = s.builder();
292
293 auto n = taylor_codegen_numparam<T>(s, num, par_ptr, batch_size);
294
295 return AddOrSub ? builder.CreateFAdd(ret, n) : builder.CreateFSub(ret, n);
296 } else {
297 return ret;
298 }
299 }
300
301 // Derivative of var +- var.
302 template <bool AddOrSub, typename T>
bo_taylor_diff_addsub_impl(llvm_state & s,const variable & var0,const variable & var1,const std::vector<llvm::Value * > & arr,llvm::Value *,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t,std::uint32_t)303 llvm::Value *bo_taylor_diff_addsub_impl(llvm_state &s, const variable &var0, const variable &var1,
304 const std::vector<llvm::Value *> &arr, llvm::Value *, std::uint32_t n_uvars,
305 std::uint32_t order, std::uint32_t, std::uint32_t)
306 {
307 auto v0 = taylor_fetch_diff(arr, uname_to_index(var0.name()), order, n_uvars);
308 auto v1 = taylor_fetch_diff(arr, uname_to_index(var1.name()), order, n_uvars);
309
310 if constexpr (AddOrSub) {
311 return s.builder().CreateFAdd(v0, v1);
312 } else {
313 return s.builder().CreateFSub(v0, v1);
314 }
315 }
316
317 // All the other cases.
318 // LCOV_EXCL_START
319 template <bool, typename, typename V1, typename V2,
320 std::enable_if_t<!std::conjunction_v<is_num_param<V1>, is_num_param<V2>>, int> = 0>
bo_taylor_diff_addsub_impl(llvm_state &,const V1 &,const V2 &,const std::vector<llvm::Value * > &,llvm::Value *,std::uint32_t,std::uint32_t,std::uint32_t,std::uint32_t)321 llvm::Value *bo_taylor_diff_addsub_impl(llvm_state &, const V1 &, const V2 &, const std::vector<llvm::Value *> &,
322 llvm::Value *, std::uint32_t, std::uint32_t, std::uint32_t, std::uint32_t)
323 {
324 throw std::invalid_argument(
325 "An invalid argument type was encountered while trying to build the Taylor derivative of add()/sub()");
326 }
327 // LCOV_EXCL_STOP
328
329 template <typename T>
bo_taylor_diff_add(llvm_state & s,const binary_op & bo,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t idx,std::uint32_t batch_size)330 llvm::Value *bo_taylor_diff_add(llvm_state &s, const binary_op &bo, const std::vector<llvm::Value *> &arr,
331 llvm::Value *par_ptr, std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx,
332 std::uint32_t batch_size)
333 {
334 return std::visit(
335 [&](const auto &v1, const auto &v2) {
336 return bo_taylor_diff_addsub_impl<true, T>(s, v1, v2, arr, par_ptr, n_uvars, order, idx, batch_size);
337 },
338 bo.lhs().value(), bo.rhs().value());
339 }
340
341 template <typename T>
bo_taylor_diff_sub(llvm_state & s,const binary_op & bo,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t idx,std::uint32_t batch_size)342 llvm::Value *bo_taylor_diff_sub(llvm_state &s, const binary_op &bo, const std::vector<llvm::Value *> &arr,
343 llvm::Value *par_ptr, std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx,
344 std::uint32_t batch_size)
345 {
346 return std::visit(
347 [&](const auto &v1, const auto &v2) {
348 return bo_taylor_diff_addsub_impl<false, T>(s, v1, v2, arr, par_ptr, n_uvars, order, idx, batch_size);
349 },
350 bo.lhs().value(), bo.rhs().value());
351 }
352
353 // Derivative of number * number.
354 template <typename T, typename U, typename V,
355 std::enable_if_t<std::conjunction_v<is_num_param<U>, is_num_param<V>>, int> = 0>
bo_taylor_diff_mul_impl(llvm_state & s,const U & num0,const V & num1,const std::vector<llvm::Value * > &,llvm::Value * par_ptr,std::uint32_t,std::uint32_t order,std::uint32_t,std::uint32_t batch_size)356 llvm::Value *bo_taylor_diff_mul_impl(llvm_state &s, const U &num0, const V &num1, const std::vector<llvm::Value *> &,
357 llvm::Value *par_ptr, std::uint32_t, std::uint32_t order, std::uint32_t,
358 std::uint32_t batch_size)
359 {
360 if (order == 0u) {
361 auto n0 = taylor_codegen_numparam<T>(s, num0, par_ptr, batch_size);
362 auto n1 = taylor_codegen_numparam<T>(s, num1, par_ptr, batch_size);
363
364 return s.builder().CreateFMul(n0, n1);
365 } else {
366 return vector_splat(s.builder(), codegen<T>(s, number{0.}), batch_size);
367 }
368 }
369
370 // Derivative of var * number.
371 template <typename T, typename U, std::enable_if_t<is_num_param_v<U>, int> = 0>
bo_taylor_diff_mul_impl(llvm_state & s,const variable & var,const U & num,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t,std::uint32_t batch_size)372 llvm::Value *bo_taylor_diff_mul_impl(llvm_state &s, const variable &var, const U &num,
373 const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, std::uint32_t n_uvars,
374 std::uint32_t order, std::uint32_t, std::uint32_t batch_size)
375 {
376 auto &builder = s.builder();
377
378 auto ret = taylor_fetch_diff(arr, uname_to_index(var.name()), order, n_uvars);
379 auto mul = taylor_codegen_numparam<T>(s, num, par_ptr, batch_size);
380
381 return builder.CreateFMul(mul, ret);
382 }
383
// Derivative of number * var.
// Multiplication is commutative, so this forwards to the var * number
// overload with the operands swapped.
template <typename T, typename U, std::enable_if_t<is_num_param_v<U>, int> = 0>
llvm::Value *bo_taylor_diff_mul_impl(llvm_state &s, const U &num, const variable &var,
                                     const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, std::uint32_t n_uvars,
                                     std::uint32_t order, std::uint32_t idx, std::uint32_t batch_size)
{
    // Return the derivative of var * number.
    return bo_taylor_diff_mul_impl<T>(s, var, num, arr, par_ptr, n_uvars, order, idx, batch_size);
}
393
394 // Derivative of var * var.
395 template <typename T>
bo_taylor_diff_mul_impl(llvm_state & s,const variable & var0,const variable & var1,const std::vector<llvm::Value * > & arr,llvm::Value *,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t,std::uint32_t)396 llvm::Value *bo_taylor_diff_mul_impl(llvm_state &s, const variable &var0, const variable &var1,
397 const std::vector<llvm::Value *> &arr, llvm::Value *, std::uint32_t n_uvars,
398 std::uint32_t order, std::uint32_t, std::uint32_t)
399 {
400 // Fetch the indices of the u variables.
401 const auto u_idx0 = uname_to_index(var0.name());
402 const auto u_idx1 = uname_to_index(var1.name());
403
404 // NOTE: iteration in the [0, order] range
405 // (i.e., order inclusive).
406 std::vector<llvm::Value *> sum;
407 auto &builder = s.builder();
408 for (std::uint32_t j = 0; j <= order; ++j) {
409 auto v0 = taylor_fetch_diff(arr, u_idx0, order - j, n_uvars);
410 auto v1 = taylor_fetch_diff(arr, u_idx1, j, n_uvars);
411
412 // Add v0*v1 to the sum.
413 sum.push_back(builder.CreateFMul(v0, v1));
414 }
415
416 return pairwise_sum(builder, sum);
417 }
418
419 // All the other cases.
420 // LCOV_EXCL_START
421 template <typename, typename V1, typename V2,
422 std::enable_if_t<!std::conjunction_v<is_num_param<V1>, is_num_param<V2>>, int> = 0>
bo_taylor_diff_mul_impl(llvm_state &,const V1 &,const V2 &,const std::vector<llvm::Value * > &,llvm::Value *,std::uint32_t,std::uint32_t,std::uint32_t,std::uint32_t)423 llvm::Value *bo_taylor_diff_mul_impl(llvm_state &, const V1 &, const V2 &, const std::vector<llvm::Value *> &,
424 llvm::Value *, std::uint32_t, std::uint32_t, std::uint32_t, std::uint32_t)
425 {
426 throw std::invalid_argument(
427 "An invalid argument type was encountered while trying to build the Taylor derivative of mul()");
428 }
429 // LCOV_EXCL_STOP
430
431 template <typename T>
bo_taylor_diff_mul(llvm_state & s,const binary_op & bo,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t idx,std::uint32_t batch_size)432 llvm::Value *bo_taylor_diff_mul(llvm_state &s, const binary_op &bo, const std::vector<llvm::Value *> &arr,
433 llvm::Value *par_ptr, std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx,
434 std::uint32_t batch_size)
435 {
436 return std::visit(
437 [&](const auto &v1, const auto &v2) {
438 return bo_taylor_diff_mul_impl<T>(s, v1, v2, arr, par_ptr, n_uvars, order, idx, batch_size);
439 },
440 bo.lhs().value(), bo.rhs().value());
441 }
442
443 // Derivative of number / number.
444 template <typename T, typename U, typename V,
445 std::enable_if_t<std::conjunction_v<is_num_param<U>, is_num_param<V>>, int> = 0>
bo_taylor_diff_div_impl(llvm_state & s,const U & num0,const V & num1,const std::vector<llvm::Value * > &,llvm::Value * par_ptr,std::uint32_t,std::uint32_t order,std::uint32_t,std::uint32_t batch_size)446 llvm::Value *bo_taylor_diff_div_impl(llvm_state &s, const U &num0, const V &num1, const std::vector<llvm::Value *> &,
447 llvm::Value *par_ptr, std::uint32_t, std::uint32_t order, std::uint32_t,
448 std::uint32_t batch_size)
449 {
450 if (order == 0u) {
451 auto n0 = taylor_codegen_numparam<T>(s, num0, par_ptr, batch_size);
452 auto n1 = taylor_codegen_numparam<T>(s, num1, par_ptr, batch_size);
453
454 return s.builder().CreateFDiv(n0, n1);
455 } else {
456 return vector_splat(s.builder(), codegen<T>(s, number{0.}), batch_size);
457 }
458 }
459
460 // Derivative of variable / variable or number / variable. These two cases
461 // are quite similar, so we handle them together.
462 template <typename T, typename U,
463 std::enable_if_t<
464 std::disjunction_v<std::is_same<U, number>, std::is_same<U, variable>, std::is_same<U, param>>, int> = 0>
bo_taylor_diff_div_impl(llvm_state & s,const U & nv,const variable & var1,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t idx,std::uint32_t batch_size)465 llvm::Value *bo_taylor_diff_div_impl(llvm_state &s, const U &nv, const variable &var1,
466 const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, std::uint32_t n_uvars,
467 std::uint32_t order, std::uint32_t idx, std::uint32_t batch_size)
468 {
469 auto &builder = s.builder();
470
471 // Fetch the index of var1.
472 const auto u_idx1 = uname_to_index(var1.name());
473
474 if (order == 0u) {
475 // Special casing for zero order.
476 auto numerator = [&]() -> llvm::Value * {
477 if constexpr (std::is_same_v<U, number> || std::is_same_v<U, param>) {
478 return taylor_codegen_numparam<T>(s, nv, par_ptr, batch_size);
479 } else {
480 return taylor_fetch_diff(arr, uname_to_index(nv.name()), 0, n_uvars);
481 }
482 }();
483
484 return builder.CreateFDiv(numerator, taylor_fetch_diff(arr, u_idx1, 0, n_uvars));
485 }
486
487 // NOTE: iteration in the [1, order] range
488 // (i.e., order inclusive).
489 std::vector<llvm::Value *> sum;
490 for (std::uint32_t j = 1; j <= order; ++j) {
491 auto v0 = taylor_fetch_diff(arr, idx, order - j, n_uvars);
492 auto v1 = taylor_fetch_diff(arr, u_idx1, j, n_uvars);
493
494 // Add v0*v1 to the sum.
495 sum.push_back(builder.CreateFMul(v0, v1));
496 }
497
498 // Init the return value as the result of the sum.
499 auto ret_acc = pairwise_sum(builder, sum);
500
501 // Load the divisor for the quotient formula.
502 // This is the zero-th order derivative of var1.
503 auto div = taylor_fetch_diff(arr, u_idx1, 0, n_uvars);
504
505 if constexpr (std::is_same_v<U, number> || std::is_same_v<U, param>) {
506 // nv is a number/param. Negate the accumulator
507 // and divide it by the divisor.
508 return builder.CreateFDiv(builder.CreateFNeg(ret_acc), div);
509 } else {
510 // nv is a variable. We need to fetch its
511 // derivative of order 'order' from the array of derivatives.
512 auto diff_nv_v = taylor_fetch_diff(arr, uname_to_index(nv.name()), order, n_uvars);
513
514 // Produce the result: (diff_nv_v - ret_acc) / div.
515 return builder.CreateFDiv(builder.CreateFSub(diff_nv_v, ret_acc), div);
516 }
517 }
518
519 // Derivative of variable / number.
520 template <typename T, typename U, std::enable_if_t<is_num_param_v<U>, int> = 0>
bo_taylor_diff_div_impl(llvm_state & s,const variable & var,const U & num,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t,std::uint32_t batch_size)521 llvm::Value *bo_taylor_diff_div_impl(llvm_state &s, const variable &var, const U &num,
522 const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, std::uint32_t n_uvars,
523 std::uint32_t order, std::uint32_t, std::uint32_t batch_size)
524 {
525 auto &builder = s.builder();
526
527 auto ret = taylor_fetch_diff(arr, uname_to_index(var.name()), order, n_uvars);
528 auto div = taylor_codegen_numparam<T>(s, num, par_ptr, batch_size);
529
530 return builder.CreateFDiv(ret, div);
531 }
532
533 // All the other cases.
534 // LCOV_EXCL_START
535 template <typename, typename V1, typename V2,
536 std::enable_if_t<!std::conjunction_v<is_num_param<V1>, is_num_param<V2>>, int> = 0>
bo_taylor_diff_div_impl(llvm_state &,const V1 &,const V2 &,const std::vector<llvm::Value * > &,llvm::Value *,std::uint32_t,std::uint32_t,std::uint32_t,std::uint32_t)537 llvm::Value *bo_taylor_diff_div_impl(llvm_state &, const V1 &, const V2 &, const std::vector<llvm::Value *> &,
538 llvm::Value *, std::uint32_t, std::uint32_t, std::uint32_t, std::uint32_t)
539 {
540 throw std::invalid_argument(
541 "An invalid argument type was encountered while trying to build the Taylor derivative of div()");
542 }
543 // LCOV_EXCL_STOP
544
545 template <typename T>
bo_taylor_diff_div(llvm_state & s,const binary_op & bo,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t idx,std::uint32_t batch_size)546 llvm::Value *bo_taylor_diff_div(llvm_state &s, const binary_op &bo, const std::vector<llvm::Value *> &arr,
547 llvm::Value *par_ptr, std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx,
548 std::uint32_t batch_size)
549 {
550 return std::visit(
551 [&](const auto &v1, const auto &v2) {
552 return bo_taylor_diff_div_impl<T>(s, v1, v2, arr, par_ptr, n_uvars, order, idx, batch_size);
553 },
554 bo.lhs().value(), bo.rhs().value());
555 }
556
557 template <typename T>
taylor_diff_bo_impl(llvm_state & s,const binary_op & bo,const std::vector<std::uint32_t> & deps,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t idx,std::uint32_t batch_size)558 llvm::Value *taylor_diff_bo_impl(llvm_state &s, const binary_op &bo, const std::vector<std::uint32_t> &deps,
559 const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, std::uint32_t n_uvars,
560 std::uint32_t order, std::uint32_t idx, std::uint32_t batch_size)
561 {
562 assert(bo.args().size() == 2u);
563 assert(bo.op() >= binary_op::type::add && bo.op() <= binary_op::type::div);
564
565 if (!deps.empty()) {
566 throw std::invalid_argument("The vector of hidden dependencies in the Taylor diff for a binary operator "
567 "should be empty, but instead it has a size of {}"_format(deps.size()));
568 }
569
570 switch (bo.op()) {
571 case binary_op::type::add:
572 return bo_taylor_diff_add<T>(s, bo, arr, par_ptr, n_uvars, order, idx, batch_size);
573 case binary_op::type::sub:
574 return bo_taylor_diff_sub<T>(s, bo, arr, par_ptr, n_uvars, order, idx, batch_size);
575 case binary_op::type::mul:
576 return bo_taylor_diff_mul<T>(s, bo, arr, par_ptr, n_uvars, order, idx, batch_size);
577 default:
578 return bo_taylor_diff_div<T>(s, bo, arr, par_ptr, n_uvars, order, idx, batch_size);
579 }
580 }
581
582 } // namespace
583
// Taylor derivative in double precision (non-compact mode). The time pointer
// and the trailing bool argument are unused by binary operators.
llvm::Value *binary_op::taylor_diff_dbl(llvm_state &s, const std::vector<std::uint32_t> &deps,
                                        const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, llvm::Value *,
                                        std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx,
                                        std::uint32_t batch_size, bool) const
{

    return taylor_diff_bo_impl<double>(s, *this, deps, arr, par_ptr, n_uvars, order, idx, batch_size);
}
592
// Taylor derivative in long double precision (non-compact mode). The time
// pointer and the trailing bool argument are unused by binary operators.
llvm::Value *binary_op::taylor_diff_ldbl(llvm_state &s, const std::vector<std::uint32_t> &deps,
                                         const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, llvm::Value *,
                                         std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx,
                                         std::uint32_t batch_size, bool) const
{
    return taylor_diff_bo_impl<long double>(s, *this, deps, arr, par_ptr, n_uvars, order, idx, batch_size);
}
600
601 #if defined(HEYOKA_HAVE_REAL128)
602
// Taylor derivative in quadruple precision (mppp::real128, non-compact mode).
// The time pointer and the trailing bool argument are unused by binary operators.
llvm::Value *binary_op::taylor_diff_f128(llvm_state &s, const std::vector<std::uint32_t> &deps,
                                         const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, llvm::Value *,
                                         std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx,
                                         std::uint32_t batch_size, bool) const
{
    return taylor_diff_bo_impl<mppp::real128>(s, *this, deps, arr, par_ptr, n_uvars, order, idx, batch_size);
}
610
611 #endif
612
613 namespace
614 {
615
// Helper to implement the function for the differentiation of
// 'number/param op number/param' in compact mode. The generated LLVM function
// returns zero for any order > 0, and the result of applying the operation
// (add/sub/mul/div, taken from bo.op()) to the two operands when the order is 0.
// If a function with the same name already exists in the module, it is reused
// after checking that its signature matches. Otherwise std::invalid_argument is thrown.
template <typename T, typename U, typename V>
llvm::Function *bo_taylor_c_diff_func_num_num(llvm_state &s, const binary_op &bo, const U &n0, const V &n1,
                                              std::uint32_t n_uvars, std::uint32_t batch_size,
                                              const std::string &op_name)
{
    auto &module = s.module();
    auto &builder = s.builder();
    auto &context = s.context();

    // Fetch the floating-point type.
    auto val_t = to_llvm_vector_type<T>(context, batch_size);

    // Fetch the function name and arguments.
    const auto na_pair = taylor_c_diff_func_name_args<T>(context, op_name, n_uvars, batch_size, {n0, n1});
    const auto &fname = na_pair.first;
    const auto &fargs = na_pair.second;

    // Try to see if we already created the function.
    auto f = module.getFunction(fname);

    if (f == nullptr) {
        // The function was not created before, do it now.

        // Fetch the current insertion block.
        auto orig_bb = builder.GetInsertBlock();

        // The return type is val_t.
        auto *ft = llvm::FunctionType::get(val_t, fargs, false);
        // Create the function
        f = llvm::Function::Create(ft, llvm::Function::InternalLinkage, fname, &module);
        assert(f != nullptr);

        // Fetch the necessary function arguments.
        // NOTE(review): positions 0, 3, 5 and 6 are assumed to be the order,
        // the parameter pointer and the two operands, per the argument list
        // built by taylor_c_diff_func_name_args() — confirm there.
        auto ord = f->args().begin();
        auto par_ptr = f->args().begin() + 3;
        auto num0 = f->args().begin() + 5;
        auto num1 = f->args().begin() + 6;

        // Create a new basic block to start insertion into.
        builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f));

        // Create the return value.
        auto retval = builder.CreateAlloca(val_t);

        llvm_if_then_else(
            s, builder.CreateICmpEQ(ord, builder.getInt32(0)),
            [&]() {
                // If the order is zero, run the codegen.
                auto vnum0 = taylor_c_diff_numparam_codegen(s, n0, num0, par_ptr, batch_size);
                auto vnum1 = taylor_c_diff_numparam_codegen(s, n1, num1, par_ptr, batch_size);

                // Apply the operation; any kind other than add/sub/mul is a division.
                switch (bo.op()) {
                    case binary_op::type::add:
                        builder.CreateStore(builder.CreateFAdd(vnum0, vnum1), retval);
                        break;
                    case binary_op::type::sub:
                        builder.CreateStore(builder.CreateFSub(vnum0, vnum1), retval);
                        break;
                    case binary_op::type::mul:
                        builder.CreateStore(builder.CreateFMul(vnum0, vnum1), retval);
                        break;
                    default:
                        builder.CreateStore(builder.CreateFDiv(vnum0, vnum1), retval);
                }
            },
            [&]() {
                // Otherwise, return zero.
                builder.CreateStore(vector_splat(builder, codegen<T>(s, number{0.}), batch_size), retval);
            });

        // Return the result.
        builder.CreateRet(builder.CreateLoad(retval));

        // Verify.
        s.verify_function(f);

        // Restore the original insertion block.
        builder.SetInsertPoint(orig_bb);
    } else {
        // The function was created before. Check if the signatures match.
        // NOTE: there could be a mismatch if the derivative function was created
        // and then optimised - optimisation might remove arguments which are compile-time
        // constants.
        if (!compare_function_signature(f, val_t, fargs)) {
            throw std::invalid_argument("Inconsistent function signature for the Taylor derivative of {}() "
                                        "in compact mode detected"_format(op_name));
        }
    }

    return f;
}
710
711 // Derivative of number/param +- number/param.
712 template <bool AddOrSub, typename T, typename U, typename V,
713 std::enable_if_t<std::conjunction_v<is_num_param<U>, is_num_param<V>>, int> = 0>
bo_taylor_c_diff_func_addsub_impl(llvm_state & s,const binary_op & bo,const U & num0,const V & num1,std::uint32_t n_uvars,std::uint32_t batch_size)714 llvm::Function *bo_taylor_c_diff_func_addsub_impl(llvm_state &s, const binary_op &bo, const U &num0, const V &num1,
715 std::uint32_t n_uvars, std::uint32_t batch_size)
716 {
717 return bo_taylor_c_diff_func_num_num<T>(s, bo, num0, num1, n_uvars, batch_size, AddOrSub ? "add" : "sub");
718 }
719
720 // Derivative of number +- var.
721 template <bool AddOrSub, typename T, typename U, std::enable_if_t<is_num_param_v<U>, int> = 0>
bo_taylor_c_diff_func_addsub_impl(llvm_state & s,const binary_op &,const U & n,const variable & var,std::uint32_t n_uvars,std::uint32_t batch_size)722 llvm::Function *bo_taylor_c_diff_func_addsub_impl(llvm_state &s, const binary_op &, const U &n, const variable &var,
723 std::uint32_t n_uvars, std::uint32_t batch_size)
724 {
725 auto &module = s.module();
726 auto &builder = s.builder();
727 auto &context = s.context();
728
729 // Fetch the floating-point type.
730 auto val_t = to_llvm_vector_type<T>(context, batch_size);
731
732 // Fetch the function name and arguments.
733 const auto na_pair
734 = taylor_c_diff_func_name_args<T>(context, AddOrSub ? "add" : "sub", n_uvars, batch_size, {n, var});
735 const auto &fname = na_pair.first;
736 const auto &fargs = na_pair.second;
737
738 // Try to see if we already created the function.
739 auto f = module.getFunction(fname);
740
741 if (f == nullptr) {
742 // The function was not created before, do it now.
743
744 // Fetch the current insertion block.
745 auto orig_bb = builder.GetInsertBlock();
746
747 // The return type is val_t.
748 auto *ft = llvm::FunctionType::get(val_t, fargs, false);
749 // Create the function
750 f = llvm::Function::Create(ft, llvm::Function::InternalLinkage, fname, &module);
751 assert(f != nullptr);
752
753 // Fetch the necessary function arguments.
754 auto order = f->args().begin();
755 auto diff_arr = f->args().begin() + 2;
756 auto par_ptr = f->args().begin() + 3;
757 auto num = f->args().begin() + 5;
758 auto var_idx = f->args().begin() + 6;
759
760 // Create a new basic block to start insertion into.
761 builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f));
762
763 // Create the return value.
764 auto retval = builder.CreateAlloca(val_t);
765
766 llvm_if_then_else(
767 s, builder.CreateICmpEQ(order, builder.getInt32(0)),
768 [&]() {
769 // For order zero, run the codegen.
770 auto num_vec = taylor_c_diff_numparam_codegen(s, n, num, par_ptr, batch_size);
771 auto ret = taylor_c_load_diff(s, diff_arr, n_uvars, builder.getInt32(0), var_idx);
772
773 builder.CreateStore(AddOrSub ? builder.CreateFAdd(num_vec, ret) : builder.CreateFSub(num_vec, ret),
774 retval);
775 },
776 [&]() {
777 // Load the derivative.
778 auto ret = taylor_c_load_diff(s, diff_arr, n_uvars, order, var_idx);
779
780 if constexpr (!AddOrSub) {
781 ret = builder.CreateFNeg(ret);
782 }
783
784 // Create the return value.
785 builder.CreateStore(ret, retval);
786 });
787
788 // Return the result.
789 builder.CreateRet(builder.CreateLoad(retval));
790
791 // Verify.
792 s.verify_function(f);
793
794 // Restore the original insertion block.
795 builder.SetInsertPoint(orig_bb);
796 } else {
797 // The function was created before. Check if the signatures match.
798 // NOTE: there could be a mismatch if the derivative function was created
799 // and then optimised - optimisation might remove arguments which are compile-time
800 // constants.
801 if (!compare_function_signature(f, val_t, fargs)) {
802 throw std::invalid_argument(
803 "Inconsistent function signature for the Taylor derivative of addition in compact mode detected");
804 }
805 }
806
807 return f;
808 }
809
810 // Derivative of var +- number.
811 template <bool AddOrSub, typename T, typename U, std::enable_if_t<is_num_param_v<U>, int> = 0>
bo_taylor_c_diff_func_addsub_impl(llvm_state & s,const binary_op &,const variable & var,const U & n,std::uint32_t n_uvars,std::uint32_t batch_size)812 llvm::Function *bo_taylor_c_diff_func_addsub_impl(llvm_state &s, const binary_op &, const variable &var, const U &n,
813 std::uint32_t n_uvars, std::uint32_t batch_size)
814 {
815 auto &module = s.module();
816 auto &builder = s.builder();
817 auto &context = s.context();
818
819 // Fetch the floating-point type.
820 auto val_t = to_llvm_vector_type<T>(context, batch_size);
821
822 // Fetch the function name and arguments.
823 const auto na_pair
824 = taylor_c_diff_func_name_args<T>(context, AddOrSub ? "add" : "sub", n_uvars, batch_size, {var, n});
825 const auto &fname = na_pair.first;
826 const auto &fargs = na_pair.second;
827
828 // Try to see if we already created the function.
829 auto f = module.getFunction(fname);
830
831 if (f == nullptr) {
832 // The function was not created before, do it now.
833
834 // Fetch the current insertion block.
835 auto orig_bb = builder.GetInsertBlock();
836
837 // The return type is val_t.
838 auto *ft = llvm::FunctionType::get(val_t, fargs, false);
839 // Create the function
840 f = llvm::Function::Create(ft, llvm::Function::InternalLinkage, fname, &module);
841 assert(f != nullptr);
842
843 // Fetch the necessary arguments.
844 auto order = f->args().begin();
845 auto diff_arr = f->args().begin() + 2;
846 auto par_ptr = f->args().begin() + 3;
847 auto var_idx = f->args().begin() + 5;
848 auto num = f->args().begin() + 6;
849
850 // Create a new basic block to start insertion into.
851 builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f));
852
853 // Create the return value.
854 auto retval = builder.CreateAlloca(val_t);
855
856 llvm_if_then_else(
857 s, builder.CreateICmpEQ(order, builder.getInt32(0)),
858 [&]() {
859 // For order zero, run the codegen.
860 auto ret = taylor_c_load_diff(s, diff_arr, n_uvars, builder.getInt32(0), var_idx);
861 auto num_vec = taylor_c_diff_numparam_codegen(s, n, num, par_ptr, batch_size);
862
863 builder.CreateStore(AddOrSub ? builder.CreateFAdd(ret, num_vec) : builder.CreateFSub(ret, num_vec),
864 retval);
865 },
866 [&]() {
867 // Create the return value.
868 builder.CreateStore(taylor_c_load_diff(s, diff_arr, n_uvars, order, var_idx), retval);
869 });
870
871 // Return the result.
872 builder.CreateRet(builder.CreateLoad(retval));
873
874 // Verify.
875 s.verify_function(f);
876
877 // Restore the original insertion block.
878 builder.SetInsertPoint(orig_bb);
879 } else {
880 // The function was created before. Check if the signatures match.
881 // NOTE: there could be a mismatch if the derivative function was created
882 // and then optimised - optimisation might remove arguments which are compile-time
883 // constants.
884 if (!compare_function_signature(f, val_t, fargs)) {
885 throw std::invalid_argument(
886 "Inconsistent function signature for the Taylor derivative of addition in compact mode detected");
887 }
888 }
889
890 return f;
891 }
892
893 // Derivative of var +- var.
894 template <bool AddOrSub, typename T>
bo_taylor_c_diff_func_addsub_impl(llvm_state & s,const binary_op &,const variable & var0,const variable & var1,std::uint32_t n_uvars,std::uint32_t batch_size)895 llvm::Function *bo_taylor_c_diff_func_addsub_impl(llvm_state &s, const binary_op &, const variable &var0,
896 const variable &var1, std::uint32_t n_uvars, std::uint32_t batch_size)
897 {
898 auto &module = s.module();
899 auto &builder = s.builder();
900 auto &context = s.context();
901
902 // Fetch the floating-point type.
903 auto val_t = to_llvm_vector_type<T>(context, batch_size);
904
905 // Fetch the function name and arguments.
906 const auto na_pair
907 = taylor_c_diff_func_name_args<T>(context, AddOrSub ? "add" : "sub", n_uvars, batch_size, {var0, var1});
908 const auto &fname = na_pair.first;
909 const auto &fargs = na_pair.second;
910
911 // Try to see if we already created the function.
912 auto f = module.getFunction(fname);
913
914 if (f == nullptr) {
915 // The function was not created before, do it now.
916
917 // Fetch the current insertion block.
918 auto orig_bb = builder.GetInsertBlock();
919
920 // The return type is val_t.
921 auto *ft = llvm::FunctionType::get(val_t, fargs, false);
922 // Create the function
923 f = llvm::Function::Create(ft, llvm::Function::InternalLinkage, fname, &module);
924 assert(f != nullptr);
925
926 // Fetch the necessary function arguments.
927 auto order = f->args().begin();
928 auto diff_arr = f->args().begin() + 2;
929 auto var_idx0 = f->args().begin() + 5;
930 auto var_idx1 = f->args().begin() + 6;
931
932 // Create a new basic block to start insertion into.
933 builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f));
934
935 auto v0 = taylor_c_load_diff(s, diff_arr, n_uvars, order, var_idx0);
936 auto v1 = taylor_c_load_diff(s, diff_arr, n_uvars, order, var_idx1);
937
938 // Create the return value.
939 if constexpr (AddOrSub) {
940 builder.CreateRet(builder.CreateFAdd(v0, v1));
941 } else {
942 builder.CreateRet(builder.CreateFSub(v0, v1));
943 }
944
945 // Verify.
946 s.verify_function(f);
947
948 // Restore the original insertion block.
949 builder.SetInsertPoint(orig_bb);
950 } else {
951 // The function was created before. Check if the signatures match.
952 // NOTE: there could be a mismatch if the derivative function was created
953 // and then optimised - optimisation might remove arguments which are compile-time
954 // constants.
955 if (!compare_function_signature(f, val_t, fargs)) {
956 throw std::invalid_argument(
957 "Inconsistent function signature for the Taylor derivative of addition in compact mode detected");
958 }
959 }
960
961 return f;
962 }
963
964 // All the other cases.
965 // LCOV_EXCL_START
966 template <bool, typename, typename V1, typename V2,
967 std::enable_if_t<!std::conjunction_v<is_num_param<V1>, is_num_param<V2>>, int> = 0>
bo_taylor_c_diff_func_addsub_impl(llvm_state &,const binary_op &,const V1 &,const V2 &,std::uint32_t,std::uint32_t)968 llvm::Function *bo_taylor_c_diff_func_addsub_impl(llvm_state &, const binary_op &, const V1 &, const V2 &,
969 std::uint32_t, std::uint32_t)
970 {
971 throw std::invalid_argument("An invalid argument type was encountered while trying to build the Taylor derivative "
972 "of add()/sub() in compact mode");
973 }
974 // LCOV_EXCL_STOP
975
976 template <typename T>
bo_taylor_c_diff_func_add(llvm_state & s,const binary_op & bo,std::uint32_t n_uvars,std::uint32_t batch_size)977 llvm::Function *bo_taylor_c_diff_func_add(llvm_state &s, const binary_op &bo, std::uint32_t n_uvars,
978 std::uint32_t batch_size)
979 {
980 return std::visit(
981 [&](const auto &v1, const auto &v2) {
982 return bo_taylor_c_diff_func_addsub_impl<true, T>(s, bo, v1, v2, n_uvars, batch_size);
983 },
984 bo.lhs().value(), bo.rhs().value());
985 }
986
987 template <typename T>
bo_taylor_c_diff_func_sub(llvm_state & s,const binary_op & bo,std::uint32_t n_uvars,std::uint32_t batch_size)988 llvm::Function *bo_taylor_c_diff_func_sub(llvm_state &s, const binary_op &bo, std::uint32_t n_uvars,
989 std::uint32_t batch_size)
990 {
991 return std::visit(
992 [&](const auto &v1, const auto &v2) {
993 return bo_taylor_c_diff_func_addsub_impl<false, T>(s, bo, v1, v2, n_uvars, batch_size);
994 },
995 bo.lhs().value(), bo.rhs().value());
996 }
997
998 // Derivative of number/param * number/param.
999 template <typename T, typename U, typename V,
1000 std::enable_if_t<std::conjunction_v<is_num_param<U>, is_num_param<V>>, int> = 0>
bo_taylor_c_diff_func_mul_impl(llvm_state & s,const binary_op & bo,const U & num0,const V & num1,std::uint32_t n_uvars,std::uint32_t batch_size)1001 llvm::Function *bo_taylor_c_diff_func_mul_impl(llvm_state &s, const binary_op &bo, const U &num0, const V &num1,
1002 std::uint32_t n_uvars, std::uint32_t batch_size)
1003 {
1004 return bo_taylor_c_diff_func_num_num<T>(s, bo, num0, num1, n_uvars, batch_size, "mul");
1005 }
1006
1007 // Derivative of var * number.
1008 template <typename T, typename U, std::enable_if_t<is_num_param_v<U>, int> = 0>
bo_taylor_c_diff_func_mul_impl(llvm_state & s,const binary_op &,const variable & var,const U & n,std::uint32_t n_uvars,std::uint32_t batch_size)1009 llvm::Function *bo_taylor_c_diff_func_mul_impl(llvm_state &s, const binary_op &, const variable &var, const U &n,
1010 std::uint32_t n_uvars, std::uint32_t batch_size)
1011 {
1012 auto &module = s.module();
1013 auto &builder = s.builder();
1014 auto &context = s.context();
1015
1016 // Fetch the floating-point type.
1017 auto val_t = to_llvm_vector_type<T>(context, batch_size);
1018
1019 // Fetch the function name and arguments.
1020 const auto na_pair = taylor_c_diff_func_name_args<T>(context, "mul", n_uvars, batch_size, {var, n});
1021 const auto &fname = na_pair.first;
1022 const auto &fargs = na_pair.second;
1023
1024 // Try to see if we already created the function.
1025 auto f = module.getFunction(fname);
1026
1027 if (f == nullptr) {
1028 // The function was not created before, do it now.
1029
1030 // Fetch the current insertion block.
1031 auto orig_bb = builder.GetInsertBlock();
1032
1033 // The return type is val_t.
1034 auto *ft = llvm::FunctionType::get(val_t, fargs, false);
1035 // Create the function
1036 f = llvm::Function::Create(ft, llvm::Function::InternalLinkage, fname, &module);
1037 assert(f != nullptr);
1038
1039 // Fetch the necessary function arguments.
1040 auto order = f->args().begin();
1041 auto diff_arr = f->args().begin() + 2;
1042 auto par_ptr = f->args().begin() + 3;
1043 auto var_idx = f->args().begin() + 5;
1044 auto num = f->args().begin() + 6;
1045
1046 // Create a new basic block to start insertion into.
1047 builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f));
1048
1049 // Load the derivative.
1050 auto ret = taylor_c_load_diff(s, diff_arr, n_uvars, order, var_idx);
1051
1052 // Create the return value.
1053 builder.CreateRet(builder.CreateFMul(ret, taylor_c_diff_numparam_codegen(s, n, num, par_ptr, batch_size)));
1054
1055 // Verify.
1056 s.verify_function(f);
1057
1058 // Restore the original insertion block.
1059 builder.SetInsertPoint(orig_bb);
1060 } else {
1061 // The function was created before. Check if the signatures match.
1062 // NOTE: there could be a mismatch if the derivative function was created
1063 // and then optimised - optimisation might remove arguments which are compile-time
1064 // constants.
1065 if (!compare_function_signature(f, val_t, fargs)) {
1066 throw std::invalid_argument(
1067 "Inconsistent function signature for the Taylor derivative of multiplication in compact mode detected");
1068 }
1069 }
1070
1071 return f;
1072 }
1073
1074 // Derivative of number * var.
1075 template <typename T, typename U, std::enable_if_t<is_num_param_v<U>, int> = 0>
bo_taylor_c_diff_func_mul_impl(llvm_state & s,const binary_op &,const U & n,const variable & var,std::uint32_t n_uvars,std::uint32_t batch_size)1076 llvm::Function *bo_taylor_c_diff_func_mul_impl(llvm_state &s, const binary_op &, const U &n, const variable &var,
1077 std::uint32_t n_uvars, std::uint32_t batch_size)
1078 {
1079 auto &module = s.module();
1080 auto &builder = s.builder();
1081 auto &context = s.context();
1082
1083 // Fetch the floating-point type.
1084 auto val_t = to_llvm_vector_type<T>(context, batch_size);
1085
1086 // Fetch the function name and arguments.
1087 const auto na_pair = taylor_c_diff_func_name_args<T>(context, "mul", n_uvars, batch_size, {n, var});
1088 const auto &fname = na_pair.first;
1089 const auto &fargs = na_pair.second;
1090
1091 // Try to see if we already created the function.
1092 auto f = module.getFunction(fname);
1093
1094 if (f == nullptr) {
1095 // The function was not created before, do it now.
1096
1097 // Fetch the current insertion block.
1098 auto orig_bb = builder.GetInsertBlock();
1099
1100 // The return type is val_t.
1101 auto *ft = llvm::FunctionType::get(val_t, fargs, false);
1102 // Create the function
1103 f = llvm::Function::Create(ft, llvm::Function::InternalLinkage, fname, &module);
1104 assert(f != nullptr);
1105
1106 // Fetch the necessary function arguments.
1107 auto order = f->args().begin();
1108 auto diff_arr = f->args().begin() + 2;
1109 auto par_ptr = f->args().begin() + 3;
1110 auto num = f->args().begin() + 5;
1111 auto var_idx = f->args().begin() + 6;
1112
1113 // Create a new basic block to start insertion into.
1114 builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f));
1115
1116 // Load the derivative.
1117 auto ret = taylor_c_load_diff(s, diff_arr, n_uvars, order, var_idx);
1118
1119 // Create the return value.
1120 builder.CreateRet(builder.CreateFMul(ret, taylor_c_diff_numparam_codegen(s, n, num, par_ptr, batch_size)));
1121
1122 // Verify.
1123 s.verify_function(f);
1124
1125 // Restore the original insertion block.
1126 builder.SetInsertPoint(orig_bb);
1127 } else {
1128 // The function was created before. Check if the signatures match.
1129 // NOTE: there could be a mismatch if the derivative function was created
1130 // and then optimised - optimisation might remove arguments which are compile-time
1131 // constants.
1132 if (!compare_function_signature(f, val_t, fargs)) {
1133 throw std::invalid_argument(
1134 "Inconsistent function signature for the Taylor derivative of multiplication in compact mode detected");
1135 }
1136 }
1137
1138 return f;
1139 }
1140
1141 // Derivative of var * var.
1142 template <typename T>
bo_taylor_c_diff_func_mul_impl(llvm_state & s,const binary_op &,const variable & var0,const variable & var1,std::uint32_t n_uvars,std::uint32_t batch_size)1143 llvm::Function *bo_taylor_c_diff_func_mul_impl(llvm_state &s, const binary_op &, const variable &var0,
1144 const variable &var1, std::uint32_t n_uvars, std::uint32_t batch_size)
1145 {
1146 auto &module = s.module();
1147 auto &builder = s.builder();
1148 auto &context = s.context();
1149
1150 // Fetch the floating-point type.
1151 auto val_t = to_llvm_vector_type<T>(context, batch_size);
1152
1153 // Fetch the function name and arguments.
1154 const auto na_pair = taylor_c_diff_func_name_args<T>(context, "mul", n_uvars, batch_size, {var0, var1});
1155 const auto &fname = na_pair.first;
1156 const auto &fargs = na_pair.second;
1157
1158 // Try to see if we already created the function.
1159 auto f = module.getFunction(fname);
1160
1161 if (f == nullptr) {
1162 // The function was not created before, do it now.
1163
1164 // Fetch the current insertion block.
1165 auto orig_bb = builder.GetInsertBlock();
1166
1167 // The return type is val_t.
1168 auto *ft = llvm::FunctionType::get(val_t, fargs, false);
1169 // Create the function
1170 f = llvm::Function::Create(ft, llvm::Function::InternalLinkage, fname, &module);
1171 assert(f != nullptr);
1172
1173 // Fetch the necessary function arguments.
1174 auto ord = f->args().begin();
1175 auto diff_ptr = f->args().begin() + 2;
1176 auto idx0 = f->args().begin() + 5;
1177 auto idx1 = f->args().begin() + 6;
1178
1179 // Create a new basic block to start insertion into.
1180 builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f));
1181
1182 // Create the accumulator.
1183 auto acc = builder.CreateAlloca(val_t);
1184 builder.CreateStore(vector_splat(builder, codegen<T>(s, number{0.}), batch_size), acc);
1185
1186 // Run the loop.
1187 llvm_loop_u32(s, builder.getInt32(0), builder.CreateAdd(ord, builder.getInt32(1)), [&](llvm::Value *j) {
1188 auto b_nj = taylor_c_load_diff(s, diff_ptr, n_uvars, builder.CreateSub(ord, j), idx0);
1189 auto cj = taylor_c_load_diff(s, diff_ptr, n_uvars, j, idx1);
1190 builder.CreateStore(builder.CreateFAdd(builder.CreateLoad(acc), builder.CreateFMul(b_nj, cj)), acc);
1191 });
1192
1193 // Create the return value.
1194 builder.CreateRet(builder.CreateLoad(acc));
1195
1196 // Verify.
1197 s.verify_function(f);
1198
1199 // Restore the original insertion block.
1200 builder.SetInsertPoint(orig_bb);
1201 } else {
1202 // The function was created before. Check if the signatures match.
1203 // NOTE: there could be a mismatch if the derivative function was created
1204 // and then optimised - optimisation might remove arguments which are compile-time
1205 // constants.
1206 if (!compare_function_signature(f, val_t, fargs)) {
1207 throw std::invalid_argument(
1208 "Inconsistent function signature for the Taylor derivative of multiplication in compact mode detected");
1209 }
1210 }
1211
1212 return f;
1213 }
1214
1215 // All the other cases.
1216 // LCOV_EXCL_START
1217 template <typename, typename V1, typename V2,
1218 std::enable_if_t<!std::conjunction_v<is_num_param<V1>, is_num_param<V2>>, int> = 0>
bo_taylor_c_diff_func_mul_impl(llvm_state &,const binary_op &,const V1 &,const V2 &,std::uint32_t,std::uint32_t)1219 llvm::Function *bo_taylor_c_diff_func_mul_impl(llvm_state &, const binary_op &, const V1 &, const V2 &, std::uint32_t,
1220 std::uint32_t)
1221 {
1222 throw std::invalid_argument("An invalid argument type was encountered while trying to build the Taylor derivative "
1223 "of mul() in compact mode");
1224 }
1225 // LCOV_EXCL_STOP
1226
1227 template <typename T>
bo_taylor_c_diff_func_mul(llvm_state & s,const binary_op & bo,std::uint32_t n_uvars,std::uint32_t batch_size)1228 llvm::Function *bo_taylor_c_diff_func_mul(llvm_state &s, const binary_op &bo, std::uint32_t n_uvars,
1229 std::uint32_t batch_size)
1230 {
1231 return std::visit(
1232 [&](const auto &v1, const auto &v2) {
1233 return bo_taylor_c_diff_func_mul_impl<T>(s, bo, v1, v2, n_uvars, batch_size);
1234 },
1235 bo.lhs().value(), bo.rhs().value());
1236 }
1237
1238 // Derivative of number/param / number/param.
1239 template <typename T, typename U, typename V,
1240 std::enable_if_t<std::conjunction_v<is_num_param<U>, is_num_param<V>>, int> = 0>
bo_taylor_c_diff_func_div_impl(llvm_state & s,const binary_op & bo,const U & num0,const V & num1,std::uint32_t n_uvars,std::uint32_t batch_size)1241 llvm::Function *bo_taylor_c_diff_func_div_impl(llvm_state &s, const binary_op &bo, const U &num0, const V &num1,
1242 std::uint32_t n_uvars, std::uint32_t batch_size)
1243 {
1244 return bo_taylor_c_diff_func_num_num<T>(s, bo, num0, num1, n_uvars, batch_size, "div");
1245 }
1246
1247 // Derivative of var / number.
1248 template <typename T, typename U, std::enable_if_t<is_num_param_v<U>, int> = 0>
bo_taylor_c_diff_func_div_impl(llvm_state & s,const binary_op &,const variable & var,const U & n,std::uint32_t n_uvars,std::uint32_t batch_size)1249 llvm::Function *bo_taylor_c_diff_func_div_impl(llvm_state &s, const binary_op &, const variable &var, const U &n,
1250 std::uint32_t n_uvars, std::uint32_t batch_size)
1251 {
1252 auto &module = s.module();
1253 auto &builder = s.builder();
1254 auto &context = s.context();
1255
1256 // Fetch the floating-point type.
1257 auto val_t = to_llvm_vector_type<T>(context, batch_size);
1258
1259 // Fetch the function name and arguments.
1260 const auto na_pair = taylor_c_diff_func_name_args<T>(context, "div", n_uvars, batch_size, {var, n});
1261 const auto &fname = na_pair.first;
1262 const auto &fargs = na_pair.second;
1263
1264 // Try to see if we already created the function.
1265 auto f = module.getFunction(fname);
1266
1267 if (f == nullptr) {
1268 // The function was not created before, do it now.
1269
1270 // Fetch the current insertion block.
1271 auto orig_bb = builder.GetInsertBlock();
1272
1273 // The return type is val_t.
1274 auto *ft = llvm::FunctionType::get(val_t, fargs, false);
1275 // Create the function
1276 f = llvm::Function::Create(ft, llvm::Function::InternalLinkage, fname, &module);
1277 assert(f != nullptr);
1278
1279 // Fetch the necessary function arguments.
1280 auto order = f->args().begin();
1281 auto diff_arr = f->args().begin() + 2;
1282 auto par_ptr = f->args().begin() + 3;
1283 auto var_idx = f->args().begin() + 5;
1284 auto num = f->args().begin() + 6;
1285
1286 // Create a new basic block to start insertion into.
1287 builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f));
1288
1289 // Load the derivative.
1290 auto ret = taylor_c_load_diff(s, diff_arr, n_uvars, order, var_idx);
1291
1292 // Create the return value.
1293 builder.CreateRet(builder.CreateFDiv(ret, taylor_c_diff_numparam_codegen(s, n, num, par_ptr, batch_size)));
1294
1295 // Verify.
1296 s.verify_function(f);
1297
1298 // Restore the original insertion block.
1299 builder.SetInsertPoint(orig_bb);
1300 } else {
1301 // The function was created before. Check if the signatures match.
1302 // NOTE: there could be a mismatch if the derivative function was created
1303 // and then optimised - optimisation might remove arguments which are compile-time
1304 // constants.
1305 if (!compare_function_signature(f, val_t, fargs)) {
1306 throw std::invalid_argument(
1307 "Inconsistent function signature for the Taylor derivative of division in compact mode detected");
1308 }
1309 }
1310
1311 return f;
1312 }
1313
1314 // Derivative of number / var.
1315 template <typename T, typename U, std::enable_if_t<is_num_param_v<U>, int> = 0>
bo_taylor_c_diff_func_div_impl(llvm_state & s,const binary_op &,const U & n,const variable & var,std::uint32_t n_uvars,std::uint32_t batch_size)1316 llvm::Function *bo_taylor_c_diff_func_div_impl(llvm_state &s, const binary_op &, const U &n, const variable &var,
1317 std::uint32_t n_uvars, std::uint32_t batch_size)
1318 {
1319 auto &module = s.module();
1320 auto &builder = s.builder();
1321 auto &context = s.context();
1322
1323 // Fetch the floating-point type.
1324 auto val_t = to_llvm_vector_type<T>(context, batch_size);
1325
1326 // Fetch the function name and arguments.
1327 const auto na_pair = taylor_c_diff_func_name_args<T>(context, "div", n_uvars, batch_size, {n, var});
1328 const auto &fname = na_pair.first;
1329 const auto &fargs = na_pair.second;
1330
1331 // Try to see if we already created the function.
1332 auto f = module.getFunction(fname);
1333
1334 if (f == nullptr) {
1335 // The function was not created before, do it now.
1336
1337 // Fetch the current insertion block.
1338 auto orig_bb = builder.GetInsertBlock();
1339
1340 // The return type is val_t.
1341 auto *ft = llvm::FunctionType::get(val_t, fargs, false);
1342 // Create the function
1343 f = llvm::Function::Create(ft, llvm::Function::InternalLinkage, fname, &module);
1344 assert(f != nullptr);
1345
1346 // Fetch the necessary function arguments.
1347 // NOTE: we don't need the number argument because
1348 // we only need its derivative of order n >= 1,
1349 // which is always zero.
1350 auto ord = f->args().begin();
1351 auto u_idx = f->args().begin() + 1;
1352 auto diff_ptr = f->args().begin() + 2;
1353 auto par_ptr = f->args().begin() + 3;
1354 auto num = f->args().begin() + 5;
1355 auto var_idx = f->args().begin() + 6;
1356
1357 // Create a new basic block to start insertion into.
1358 builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f));
1359
1360 // Create the return value.
1361 auto retval = builder.CreateAlloca(val_t);
1362
1363 // Create the accumulator.
1364 auto acc = builder.CreateAlloca(val_t);
1365
1366 llvm_if_then_else(
1367 s, builder.CreateICmpEQ(ord, builder.getInt32(0)),
1368 [&]() {
1369 // For order zero, run the codegen.
1370 auto num_vec = taylor_c_diff_numparam_codegen(s, n, num, par_ptr, batch_size);
1371 auto ret = taylor_c_load_diff(s, diff_ptr, n_uvars, builder.getInt32(0), var_idx);
1372
1373 builder.CreateStore(builder.CreateFDiv(num_vec, ret), retval);
1374 },
1375 [&]() {
1376 // Init the accumulator.
1377 builder.CreateStore(vector_splat(builder, codegen<T>(s, number{0.}), batch_size), acc);
1378
1379 // Run the loop.
1380 llvm_loop_u32(s, builder.getInt32(1), builder.CreateAdd(ord, builder.getInt32(1)), [&](llvm::Value *j) {
1381 auto cj = taylor_c_load_diff(s, diff_ptr, n_uvars, j, var_idx);
1382 auto a_nj = taylor_c_load_diff(s, diff_ptr, n_uvars, builder.CreateSub(ord, j), u_idx);
1383 builder.CreateStore(builder.CreateFAdd(builder.CreateLoad(acc), builder.CreateFMul(cj, a_nj)), acc);
1384 });
1385
1386 // Negate the loop summation.
1387 auto ret = builder.CreateFNeg(builder.CreateLoad(acc));
1388
1389 // Divide and return.
1390 builder.CreateStore(
1391 builder.CreateFDiv(ret, taylor_c_load_diff(s, diff_ptr, n_uvars, builder.getInt32(0), var_idx)),
1392 retval);
1393 });
1394
1395 // Return the result.
1396 builder.CreateRet(builder.CreateLoad(retval));
1397
1398 // Verify.
1399 s.verify_function(f);
1400
1401 // Restore the original insertion block.
1402 builder.SetInsertPoint(orig_bb);
1403 } else {
1404 // The function was created before. Check if the signatures match.
1405 // NOTE: there could be a mismatch if the derivative function was created
1406 // and then optimised - optimisation might remove arguments which are compile-time
1407 // constants.
1408 if (!compare_function_signature(f, val_t, fargs)) {
1409 throw std::invalid_argument(
1410 "Inconsistent function signature for the Taylor derivative of division in compact mode detected");
1411 }
1412 }
1413
1414 return f;
1415 }
1416
1417 // Derivative of var / var.
1418 template <typename T>
bo_taylor_c_diff_func_div_impl(llvm_state & s,const binary_op &,const variable & var0,const variable & var1,std::uint32_t n_uvars,std::uint32_t batch_size)1419 llvm::Function *bo_taylor_c_diff_func_div_impl(llvm_state &s, const binary_op &, const variable &var0,
1420 const variable &var1, std::uint32_t n_uvars, std::uint32_t batch_size)
1421 {
1422 auto &module = s.module();
1423 auto &builder = s.builder();
1424 auto &context = s.context();
1425
1426 // Fetch the floating-point type.
1427 auto val_t = to_llvm_vector_type<T>(context, batch_size);
1428
1429 // Fetch the function name and arguments.
1430 const auto na_pair = taylor_c_diff_func_name_args<T>(context, "div", n_uvars, batch_size, {var0, var1});
1431 const auto &fname = na_pair.first;
1432 const auto &fargs = na_pair.second;
1433
1434 // Try to see if we already created the function.
1435 auto f = module.getFunction(fname);
1436
1437 if (f == nullptr) {
1438 // The function was not created before, do it now.
1439
1440 // Fetch the current insertion block.
1441 auto orig_bb = builder.GetInsertBlock();
1442
1443 // The return type is val_t.
1444 auto *ft = llvm::FunctionType::get(val_t, fargs, false);
1445 // Create the function
1446 f = llvm::Function::Create(ft, llvm::Function::InternalLinkage, fname, &module);
1447 assert(f != nullptr);
1448
1449 // Fetch the necessary function arguments.
1450 auto ord = f->args().begin();
1451 auto u_idx = f->args().begin() + 1;
1452 auto diff_ptr = f->args().begin() + 2;
1453 auto var_idx0 = f->args().begin() + 5;
1454 auto var_idx1 = f->args().begin() + 6;
1455
1456 // Create a new basic block to start insertion into.
1457 builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f));
1458
1459 // Create the accumulator.
1460 auto acc = builder.CreateAlloca(val_t);
1461 builder.CreateStore(vector_splat(builder, codegen<T>(s, number{0.}), batch_size), acc);
1462
1463 // Run the loop.
1464 llvm_loop_u32(s, builder.getInt32(1), builder.CreateAdd(ord, builder.getInt32(1)), [&](llvm::Value *j) {
1465 auto cj = taylor_c_load_diff(s, diff_ptr, n_uvars, j, var_idx1);
1466 auto a_nj = taylor_c_load_diff(s, diff_ptr, n_uvars, builder.CreateSub(ord, j), u_idx);
1467 builder.CreateStore(builder.CreateFAdd(builder.CreateLoad(acc), builder.CreateFMul(cj, a_nj)), acc);
1468 });
1469
1470 auto ret = builder.CreateFSub(taylor_c_load_diff(s, diff_ptr, n_uvars, ord, var_idx0), builder.CreateLoad(acc));
1471
1472 // Divide and return.
1473 builder.CreateRet(
1474 builder.CreateFDiv(ret, taylor_c_load_diff(s, diff_ptr, n_uvars, builder.getInt32(0), var_idx1)));
1475
1476 // Verify.
1477 s.verify_function(f);
1478
1479 // Restore the original insertion block.
1480 builder.SetInsertPoint(orig_bb);
1481 } else {
1482 // The function was created before. Check if the signatures match.
1483 // NOTE: there could be a mismatch if the derivative function was created
1484 // and then optimised - optimisation might remove arguments which are compile-time
1485 // constants.
1486 if (!compare_function_signature(f, val_t, fargs)) {
1487 throw std::invalid_argument(
1488 "Inconsistent function signature for the Taylor derivative of division in compact mode detected");
1489 }
1490 }
1491
1492 return f;
1493 }
1494
1495 // All the other cases.
1496 // LCOV_EXCL_START
1497 template <typename, typename V1, typename V2,
1498 std::enable_if_t<!std::conjunction_v<is_num_param<V1>, is_num_param<V2>>, int> = 0>
bo_taylor_c_diff_func_div_impl(llvm_state &,const binary_op &,const V1 &,const V2 &,std::uint32_t,std::uint32_t)1499 llvm::Function *bo_taylor_c_diff_func_div_impl(llvm_state &, const binary_op &, const V1 &, const V2 &, std::uint32_t,
1500 std::uint32_t)
1501 {
1502 throw std::invalid_argument("An invalid argument type was encountered while trying to build the Taylor derivative "
1503 "of div() in compact mode");
1504 }
1505 // LCOV_EXCL_STOP
1506
1507 template <typename T>
bo_taylor_c_diff_func_div(llvm_state & s,const binary_op & bo,std::uint32_t n_uvars,std::uint32_t batch_size)1508 llvm::Function *bo_taylor_c_diff_func_div(llvm_state &s, const binary_op &bo, std::uint32_t n_uvars,
1509 std::uint32_t batch_size)
1510 {
1511 return std::visit(
1512 [&](const auto &v1, const auto &v2) {
1513 return bo_taylor_c_diff_func_div_impl<T>(s, bo, v1, v2, n_uvars, batch_size);
1514 },
1515 bo.lhs().value(), bo.rhs().value());
1516 }
1517
1518 template <typename T>
taylor_c_diff_func_bo_impl(llvm_state & s,const binary_op & bo,std::uint32_t n_uvars,std::uint32_t batch_size)1519 llvm::Function *taylor_c_diff_func_bo_impl(llvm_state &s, const binary_op &bo, std::uint32_t n_uvars,
1520 std::uint32_t batch_size)
1521 {
1522 switch (bo.op()) {
1523 case binary_op::type::add:
1524 return bo_taylor_c_diff_func_add<T>(s, bo, n_uvars, batch_size);
1525 case binary_op::type::sub:
1526 return bo_taylor_c_diff_func_sub<T>(s, bo, n_uvars, batch_size);
1527 case binary_op::type::mul:
1528 return bo_taylor_c_diff_func_mul<T>(s, bo, n_uvars, batch_size);
1529 default:
1530 return bo_taylor_c_diff_func_div<T>(s, bo, n_uvars, batch_size);
1531 }
1532 }
1533
1534 } // namespace
1535
taylor_c_diff_func_dbl(llvm_state & s,std::uint32_t n_uvars,std::uint32_t batch_size,bool) const1536 llvm::Function *binary_op::taylor_c_diff_func_dbl(llvm_state &s, std::uint32_t n_uvars, std::uint32_t batch_size,
1537 bool) const
1538 {
1539 return taylor_c_diff_func_bo_impl<double>(s, *this, n_uvars, batch_size);
1540 }
1541
taylor_c_diff_func_ldbl(llvm_state & s,std::uint32_t n_uvars,std::uint32_t batch_size,bool) const1542 llvm::Function *binary_op::taylor_c_diff_func_ldbl(llvm_state &s, std::uint32_t n_uvars, std::uint32_t batch_size,
1543 bool) const
1544 {
1545 return taylor_c_diff_func_bo_impl<long double>(s, *this, n_uvars, batch_size);
1546 }
1547
1548 #if defined(HEYOKA_HAVE_REAL128)
1549
taylor_c_diff_func_f128(llvm_state & s,std::uint32_t n_uvars,std::uint32_t batch_size,bool) const1550 llvm::Function *binary_op::taylor_c_diff_func_f128(llvm_state &s, std::uint32_t n_uvars, std::uint32_t batch_size,
1551 bool) const
1552 {
1553 return taylor_c_diff_func_bo_impl<mppp::real128>(s, *this, n_uvars, batch_size);
1554 }
1555
1556 #endif
1557
1558 } // namespace detail
1559
add(expression x,expression y)1560 expression add(expression x, expression y)
1561 {
1562 return expression{func{detail::binary_op(detail::binary_op::type::add, std::move(x), std::move(y))}};
1563 }
1564
sub(expression x,expression y)1565 expression sub(expression x, expression y)
1566 {
1567 return expression{func{detail::binary_op(detail::binary_op::type::sub, std::move(x), std::move(y))}};
1568 }
1569
mul(expression x,expression y)1570 expression mul(expression x, expression y)
1571 {
1572 return expression{func{detail::binary_op(detail::binary_op::type::mul, std::move(x), std::move(y))}};
1573 }
1574
div(expression x,expression y)1575 expression div(expression x, expression y)
1576 {
1577 return expression{func{detail::binary_op(detail::binary_op::type::div, std::move(x), std::move(y))}};
1578 }
1579
1580 } // namespace heyoka
1581
1582 HEYOKA_S11N_FUNC_EXPORT_IMPLEMENT(heyoka::detail::binary_op)
1583