1 // Copyright 2020, 2021 Francesco Biscani (bluescarni@gmail.com), Dario Izzo (dario.izzo@gmail.com)
2 //
3 // This file is part of the heyoka library.
4 //
5 // This Source Code Form is subject to the terms of the Mozilla
6 // Public License v. 2.0. If a copy of the MPL was not distributed
7 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
8 
9 #include <heyoka/config.hpp>
10 
11 #include <cassert>
12 #include <cmath>
13 #include <cstdint>
14 #include <initializer_list>
15 #include <stdexcept>
16 #include <string>
17 #include <type_traits>
18 #include <unordered_map>
19 #include <utility>
20 #include <variant>
21 #include <vector>
22 
23 #include <boost/numeric/conversion/cast.hpp>
24 
25 #include <fmt/format.h>
26 
27 #include <llvm/IR/Attributes.h>
28 #include <llvm/IR/BasicBlock.h>
29 #include <llvm/IR/DerivedTypes.h>
30 #include <llvm/IR/Function.h>
31 #include <llvm/IR/IRBuilder.h>
32 #include <llvm/IR/LLVMContext.h>
33 #include <llvm/IR/Module.h>
34 #include <llvm/IR/Type.h>
35 #include <llvm/IR/Value.h>
36 #include <llvm/Support/Casting.h>
37 
38 #if defined(HEYOKA_HAVE_REAL128)
39 
40 #include <mp++/real128.hpp>
41 
42 #endif
43 
44 #include <heyoka/detail/llvm_helpers.hpp>
45 #include <heyoka/detail/llvm_vector_type.hpp>
46 #include <heyoka/detail/sleef.hpp>
47 #include <heyoka/detail/string_conv.hpp>
48 #include <heyoka/detail/taylor_common.hpp>
49 #include <heyoka/expression.hpp>
50 #include <heyoka/func.hpp>
51 #include <heyoka/llvm_state.hpp>
52 #include <heyoka/math/sigmoid.hpp>
53 #include <heyoka/math/square.hpp>
54 #include <heyoka/number.hpp>
55 #include <heyoka/s11n.hpp>
56 #include <heyoka/taylor.hpp>
57 #include <heyoka/variable.hpp>
58 
59 #if defined(_MSC_VER) && !defined(__clang__)
60 
61 // NOTE: MSVC has issues with the other "using"
62 // statement form.
63 using namespace fmt::literals;
64 
65 #else
66 
67 using fmt::literals::operator""_format;
68 
69 #endif
70 
71 // The sigmoid is not in the standard library and thus needs wrappers
72 // for its double, long double and 128 versions.
73 
// Double-precision sigmoid: 1 / (1 + e^(-x)).
//
// The function has C linkage and its name is passed (as the string
// "heyoka_sigmoid") to call_extern_vec() by the double-precision codegen.
extern "C" HEYOKA_DLL_PUBLIC double heyoka_sigmoid(double x) noexcept
{
    // Denominator of the sigmoid.
    const double denom = 1. + std::exp(-x);

    return 1. / denom;
}
78 
// Extended-precision sigmoid: 1 / (1 + e^(-x)).
//
// The function has C linkage and its name is passed (as the string
// "heyoka_sigmoidl") to call_extern_vec() by the long double codegen.
extern "C" HEYOKA_DLL_PUBLIC long double heyoka_sigmoidl(long double x) noexcept
{
    // Denominator of the sigmoid (the sum is carried out in long double).
    const long double denom = 1. + std::exp(-x);

    return 1. / denom;
}
83 
84 #if defined(HEYOKA_HAVE_REAL128)
85 
heyoka_sigmoid128(__float128 x)86 extern "C" HEYOKA_DLL_PUBLIC __float128 heyoka_sigmoid128(__float128 x) noexcept
87 {
88     return (1. / (1. + mppp::exp(-mppp::real128{x}))).m_value;
89 }
90 
91 #endif
92 
93 namespace heyoka
94 {
95 
96 namespace detail
97 {
98 
sigmoid_impl(expression e)99 sigmoid_impl::sigmoid_impl(expression e) : func_base("sigmoid", std::vector{std::move(e)}) {}
100 
sigmoid_impl()101 sigmoid_impl::sigmoid_impl() : sigmoid_impl(0_dbl) {}
102 
codegen_dbl(llvm_state & s,const std::vector<llvm::Value * > & args) const103 llvm::Value *sigmoid_impl::codegen_dbl(llvm_state &s, const std::vector<llvm::Value *> &args) const
104 {
105     assert(args.size() == 1u);
106     assert(args[0] != nullptr);
107 
108     if (auto vec_t = llvm::dyn_cast<llvm_vector_type>(args[0]->getType())) {
109         const auto batch_size = boost::numeric_cast<std::uint32_t>(vec_t->getNumElements());
110 
111         if (const auto sfn = sleef_function_name(s.context(), "exp", vec_t->getElementType(), batch_size);
112             !sfn.empty()) {
113             auto &builder = s.builder();
114 
115             // Compute -arg.
116             auto m_arg = builder.CreateFNeg(args[0]);
117 
118             // Compute e^(-arg).
119             auto e_m_arg = llvm_invoke_external(
120                 s, sfn, vec_t, {m_arg},
121                 // NOTE: in theory we may add ReadNone here as well,
122                 // but for some reason, at least up to LLVM 10,
123                 // this causes strange codegen issues. Revisit
124                 // in the future.
125                 {llvm::Attribute::NoUnwind, llvm::Attribute::Speculatable, llvm::Attribute::WillReturn});
126 
127             // Return 1 / (1 + e_m_arg).
128             auto one_fp = vector_splat(builder, codegen<double>(s, number{1.}), batch_size);
129             return builder.CreateFDiv(one_fp, builder.CreateFAdd(one_fp, e_m_arg));
130         }
131     }
132 
133     return call_extern_vec(s, args[0], "heyoka_sigmoid");
134 }
135 
codegen_ldbl(llvm_state & s,const std::vector<llvm::Value * > & args) const136 llvm::Value *sigmoid_impl::codegen_ldbl(llvm_state &s, const std::vector<llvm::Value *> &args) const
137 {
138     assert(args.size() == 1u);
139     assert(args[0] != nullptr);
140 
141     return call_extern_vec(s, args[0], "heyoka_sigmoidl");
142 }
143 
144 #if defined(HEYOKA_HAVE_REAL128)
145 
codegen_f128(llvm_state & s,const std::vector<llvm::Value * > & args) const146 llvm::Value *sigmoid_impl::codegen_f128(llvm_state &s, const std::vector<llvm::Value *> &args) const
147 {
148     assert(args.size() == 1u);
149     assert(args[0] != nullptr);
150 
151     return call_extern_vec(s, args[0], "heyoka_sigmoid128");
152 }
153 
154 #endif
155 
eval_dbl(const std::unordered_map<std::string,double> & map,const std::vector<double> & pars) const156 double sigmoid_impl::eval_dbl(const std::unordered_map<std::string, double> &map, const std::vector<double> &pars) const
157 {
158     assert(args().size() == 1u);
159 
160     return heyoka_sigmoid(heyoka::eval_dbl(args()[0], map, pars));
161 }
162 
eval_ldbl(const std::unordered_map<std::string,long double> & map,const std::vector<long double> & pars) const163 long double sigmoid_impl::eval_ldbl(const std::unordered_map<std::string, long double> &map,
164                                     const std::vector<long double> &pars) const
165 {
166     assert(args().size() == 1u);
167 
168     return heyoka_sigmoidl(heyoka::eval_ldbl(args()[0], map, pars));
169 }
170 
171 #if defined(HEYOKA_HAVE_REAL128)
eval_f128(const std::unordered_map<std::string,mppp::real128> & map,const std::vector<mppp::real128> & pars) const172 mppp::real128 sigmoid_impl::eval_f128(const std::unordered_map<std::string, mppp::real128> &map,
173                                       const std::vector<mppp::real128> &pars) const
174 {
175     assert(args().size() == 1u);
176 
177     return mppp::real128(heyoka_sigmoid128(heyoka::eval_f128(args()[0], map, pars).m_value));
178 }
179 #endif
180 
eval_batch_dbl(std::vector<double> & out,const std::unordered_map<std::string,std::vector<double>> & map,const std::vector<double> & pars) const181 void sigmoid_impl::eval_batch_dbl(std::vector<double> &out,
182                                   const std::unordered_map<std::string, std::vector<double>> &map,
183                                   const std::vector<double> &pars) const
184 {
185     assert(args().size() == 1u);
186 
187     heyoka::eval_batch_dbl(out, args()[0], map, pars);
188     for (auto &el : out) {
189         el = heyoka_sigmoid(el);
190     }
191 }
192 
eval_num_dbl(const std::vector<double> & a) const193 double sigmoid_impl::eval_num_dbl(const std::vector<double> &a) const
194 {
195     if (a.size() != 1u) {
196         throw std::invalid_argument(
197             "Inconsistent number of arguments when computing the numerical value of the "
198             "sigmoid over doubles (1 argument was expected, but {} arguments were provided"_format(a.size()));
199     }
200 
201     return heyoka_sigmoid(a[0]);
202 }
203 
deval_num_dbl(const std::vector<double> & a,std::vector<double>::size_type i) const204 double sigmoid_impl::deval_num_dbl(const std::vector<double> &a, std::vector<double>::size_type i) const
205 {
206     if (a.size() != 1u || i != 0u) {
207         throw std::invalid_argument("Inconsistent number of arguments or derivative requested when computing the "
208                                     "numerical derivative of the sigmoid");
209     }
210     auto sigma = heyoka_sigmoid(a[0]);
211     return sigma * (1 - sigma);
212 }
213 
taylor_decompose(taylor_dc_t & u_vars_defs)214 taylor_dc_t::size_type sigmoid_impl::taylor_decompose(taylor_dc_t &u_vars_defs) &&
215 {
216     assert(args().size() == 1u);
217 
218     // Append the sigmoid decomposition.
219     u_vars_defs.emplace_back(func{std::move(*this)}, std::vector<std::uint32_t>{});
220 
221     // Append the auxiliary function sigmoid(arg) * sigmoid(arg).
222     u_vars_defs.emplace_back(square(expression{"u_{}"_format(u_vars_defs.size() - 1u)}), std::vector<std::uint32_t>{});
223 
224     // Add the hidden dep.
225     (u_vars_defs.end() - 2)->second.push_back(boost::numeric_cast<std::uint32_t>(u_vars_defs.size() - 1u));
226 
227     return u_vars_defs.size() - 2u;
228 }
229 
230 namespace
231 {
232 
233 // Derivative of sigmoid(number).
234 template <typename T, typename U, std::enable_if_t<is_num_param_v<U>, int> = 0>
taylor_diff_sigmoid_impl(llvm_state & s,const sigmoid_impl & f,const std::vector<std::uint32_t> &,const U & num,const std::vector<llvm::Value * > &,llvm::Value * par_ptr,std::uint32_t,std::uint32_t order,std::uint32_t,std::uint32_t batch_size)235 llvm::Value *taylor_diff_sigmoid_impl(llvm_state &s, const sigmoid_impl &f, const std::vector<std::uint32_t> &,
236                                       const U &num, const std::vector<llvm::Value *> &, llvm::Value *par_ptr,
237                                       std::uint32_t, std::uint32_t order, std::uint32_t, std::uint32_t batch_size)
238 {
239     if (order == 0u) {
240         return codegen_from_values<T>(s, f, {taylor_codegen_numparam<T>(s, num, par_ptr, batch_size)});
241     } else {
242         return vector_splat(s.builder(), codegen<T>(s, number{0.}), batch_size);
243     }
244 }
245 
246 // Derivative of sigmoid(variable).
247 template <typename T>
taylor_diff_sigmoid_impl(llvm_state & s,const sigmoid_impl & f,const std::vector<std::uint32_t> & deps,const variable & var,const std::vector<llvm::Value * > & arr,llvm::Value *,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t a_idx,std::uint32_t batch_size)248 llvm::Value *taylor_diff_sigmoid_impl(llvm_state &s, const sigmoid_impl &f, const std::vector<std::uint32_t> &deps,
249                                       const variable &var, const std::vector<llvm::Value *> &arr, llvm::Value *,
250                                       std::uint32_t n_uvars, std::uint32_t order, std::uint32_t a_idx,
251                                       std::uint32_t batch_size)
252 {
253     auto &builder = s.builder();
254 
255     // Fetch the index of the variable.
256     const auto u_idx = uname_to_index(var.name());
257 
258     if (order == 0u) {
259         return codegen_from_values<T>(s, f, {taylor_fetch_diff(arr, u_idx, 0, n_uvars)});
260     }
261 
262     // NOTE: iteration in the [1, order] range.
263     std::vector<llvm::Value *> sum;
264     for (std::uint32_t j = 1; j <= order; ++j) {
265         // NOTE: the only hidden dependency contains the index of the
266         // u variable whose definition is sigmoid(var) * sigmoid(var).
267         auto anj = taylor_fetch_diff(arr, a_idx, order - j, n_uvars);
268         auto bj = taylor_fetch_diff(arr, u_idx, j, n_uvars);
269         auto cnj = taylor_fetch_diff(arr, deps[0], order - j, n_uvars);
270 
271         auto fac = vector_splat(builder, codegen<T>(s, number(static_cast<T>(j))), batch_size);
272 
273         // Add j*(anj-cnj)*bj to the sum.
274         auto tmp1 = builder.CreateFSub(anj, cnj);
275         auto tmp2 = builder.CreateFMul(tmp1, bj);
276         auto tmp3 = builder.CreateFMul(tmp2, fac);
277 
278         sum.push_back(tmp3);
279     }
280 
281     // Init the return value as the result of the sum.
282     auto ret_acc = pairwise_sum(builder, sum);
283 
284     // Finalise the return value: ret_acc / n.
285     auto div = vector_splat(builder, codegen<T>(s, number(static_cast<T>(order))), batch_size);
286 
287     return builder.CreateFDiv(ret_acc, div);
288 }
289 
290 // All the other cases.
291 template <typename T, typename U, std::enable_if_t<!is_num_param_v<U>, int> = 0>
taylor_diff_sigmoid_impl(llvm_state &,const sigmoid_impl &,const std::vector<std::uint32_t> &,const U &,const std::vector<llvm::Value * > &,llvm::Value *,std::uint32_t,std::uint32_t,std::uint32_t,std::uint32_t)292 llvm::Value *taylor_diff_sigmoid_impl(llvm_state &, const sigmoid_impl &, const std::vector<std::uint32_t> &, const U &,
293                                       const std::vector<llvm::Value *> &, llvm::Value *, std::uint32_t, std::uint32_t,
294                                       std::uint32_t, std::uint32_t)
295 {
296     throw std::invalid_argument(
297         "An invalid argument type was encountered while trying to build the Taylor derivative of a sigmoid");
298 }
299 
300 template <typename T>
taylor_diff_sigmoid(llvm_state & s,const sigmoid_impl & f,const std::vector<std::uint32_t> & deps,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t idx,std::uint32_t batch_size)301 llvm::Value *taylor_diff_sigmoid(llvm_state &s, const sigmoid_impl &f, const std::vector<std::uint32_t> &deps,
302                                  const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, std::uint32_t n_uvars,
303                                  std::uint32_t order, std::uint32_t idx, std::uint32_t batch_size)
304 {
305     assert(f.args().size() == 1u);
306 
307     if (deps.size() != 1u) {
308         throw std::invalid_argument(
309             "A hidden dependency vector of size 1 is expected in order to compute the Taylor "
310             "derivative of the sigmoid, but a vector of size {} was passed instead"_format(deps.size()));
311     }
312 
313     return std::visit(
314         [&](const auto &v) {
315             return taylor_diff_sigmoid_impl<T>(s, f, deps, v, arr, par_ptr, n_uvars, order, idx, batch_size);
316         },
317         f.args()[0].value());
318 }
319 
320 } // namespace
321 
taylor_diff_dbl(llvm_state & s,const std::vector<std::uint32_t> & deps,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,llvm::Value *,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t idx,std::uint32_t batch_size,bool) const322 llvm::Value *sigmoid_impl::taylor_diff_dbl(llvm_state &s, const std::vector<std::uint32_t> &deps,
323                                            const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, llvm::Value *,
324                                            std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx,
325                                            std::uint32_t batch_size, bool) const
326 {
327     return taylor_diff_sigmoid<double>(s, *this, deps, arr, par_ptr, n_uvars, order, idx, batch_size);
328 }
329 
taylor_diff_ldbl(llvm_state & s,const std::vector<std::uint32_t> & deps,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,llvm::Value *,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t idx,std::uint32_t batch_size,bool) const330 llvm::Value *sigmoid_impl::taylor_diff_ldbl(llvm_state &s, const std::vector<std::uint32_t> &deps,
331                                             const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, llvm::Value *,
332                                             std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx,
333                                             std::uint32_t batch_size, bool) const
334 {
335     return taylor_diff_sigmoid<long double>(s, *this, deps, arr, par_ptr, n_uvars, order, idx, batch_size);
336 }
337 
338 #if defined(HEYOKA_HAVE_REAL128)
339 
taylor_diff_f128(llvm_state & s,const std::vector<std::uint32_t> & deps,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,llvm::Value *,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t idx,std::uint32_t batch_size,bool) const340 llvm::Value *sigmoid_impl::taylor_diff_f128(llvm_state &s, const std::vector<std::uint32_t> &deps,
341                                             const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, llvm::Value *,
342                                             std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx,
343                                             std::uint32_t batch_size, bool) const
344 {
345     return taylor_diff_sigmoid<mppp::real128>(s, *this, deps, arr, par_ptr, n_uvars, order, idx, batch_size);
346 }
347 
348 #endif
349 
350 namespace
351 {
352 
353 // Derivative of sigmoid(number).
354 template <typename T, typename U, std::enable_if_t<is_num_param_v<U>, int> = 0>
taylor_c_diff_func_sigmoid_impl(llvm_state & s,const sigmoid_impl & fn,const U & num,std::uint32_t n_uvars,std::uint32_t batch_size)355 llvm::Function *taylor_c_diff_func_sigmoid_impl(llvm_state &s, const sigmoid_impl &fn, const U &num,
356                                                 std::uint32_t n_uvars, std::uint32_t batch_size)
357 {
358     return taylor_c_diff_func_unary_num_det<T>(s, fn, num, n_uvars, batch_size, "sigmoid", 1);
359 }
360 
361 // Derivative of sigmoid(variable).
362 template <typename T>
taylor_c_diff_func_sigmoid_impl(llvm_state & s,const sigmoid_impl & fn,const variable & var,std::uint32_t n_uvars,std::uint32_t batch_size)363 llvm::Function *taylor_c_diff_func_sigmoid_impl(llvm_state &s, const sigmoid_impl &fn, const variable &var,
364                                                 std::uint32_t n_uvars, std::uint32_t batch_size)
365 {
366     auto &module = s.module();
367     auto &builder = s.builder();
368     auto &context = s.context();
369 
370     // Fetch the floating-point type.
371     auto val_t = to_llvm_vector_type<T>(context, batch_size);
372 
373     const auto na_pair = taylor_c_diff_func_name_args<T>(context, "sigmoid", n_uvars, batch_size, {var}, 1);
374     const auto &fname = na_pair.first;
375     const auto &fargs = na_pair.second;
376 
377     // Try to see if we already created the function.
378     auto f = module.getFunction(fname);
379 
380     if (f == nullptr) {
381         // The function was not created before, do it now.
382 
383         // Fetch the current insertion block.
384         auto orig_bb = builder.GetInsertBlock();
385 
386         // The return type is val_t.
387         auto *ft = llvm::FunctionType::get(val_t, fargs, false);
388         // Create the function
389         f = llvm::Function::Create(ft, llvm::Function::InternalLinkage, fname, &module);
390         assert(f != nullptr);
391 
392         // Fetch the necessary function arguments.
393         auto ord = f->args().begin();
394         auto a_idx = f->args().begin() + 1;
395         auto diff_ptr = f->args().begin() + 2;
396         auto b_idx = f->args().begin() + 5;
397         auto dep_idx = f->args().begin() + 6;
398 
399         // Create a new basic block to start insertion into.
400         builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f));
401 
402         // Create the return value.
403         auto retval = builder.CreateAlloca(val_t);
404 
405         // Create the accumulator.
406         auto acc = builder.CreateAlloca(val_t);
407 
408         llvm_if_then_else(
409             s, builder.CreateICmpEQ(ord, builder.getInt32(0)),
410             [&]() {
411                 // For order 0, invoke the function on the order 0 of b_idx.
412                 builder.CreateStore(codegen_from_values<T>(
413                                         s, fn, {taylor_c_load_diff(s, diff_ptr, n_uvars, builder.getInt32(0), b_idx)}),
414                                     retval);
415             },
416             [&]() {
417                 // Init the accumulator.
418                 builder.CreateStore(vector_splat(builder, codegen<T>(s, number{0.}), batch_size), acc);
419 
420                 // Run the loop.
421                 llvm_loop_u32(s, builder.getInt32(1), builder.CreateAdd(ord, builder.getInt32(1)), [&](llvm::Value *j) {
422                     auto anj = taylor_c_load_diff(s, diff_ptr, n_uvars, builder.CreateSub(ord, j), a_idx);
423                     auto bj = taylor_c_load_diff(s, diff_ptr, n_uvars, j, b_idx);
424                     auto cnj = taylor_c_load_diff(s, diff_ptr, n_uvars, builder.CreateSub(ord, j), dep_idx);
425 
426                     // Compute the factor j.
427                     auto fac = vector_splat(builder, builder.CreateUIToFP(j, to_llvm_type<T>(context)), batch_size);
428 
429                     // Add  j*(anj-cnj)*bj into the sum.
430                     auto tmp1 = builder.CreateFSub(anj, cnj);
431                     auto tmp2 = builder.CreateFMul(tmp1, bj);
432                     auto tmp3 = builder.CreateFMul(tmp2, fac);
433 
434                     builder.CreateStore(builder.CreateFAdd(builder.CreateLoad(acc), tmp3), acc);
435                 });
436 
437                 // Divide by the order to produce the return value.
438                 auto ord_v = vector_splat(builder, builder.CreateUIToFP(ord, to_llvm_type<T>(context)), batch_size);
439 
440                 builder.CreateStore(builder.CreateFDiv(builder.CreateLoad(acc), ord_v), retval);
441             });
442 
443         // Return the result.
444         builder.CreateRet(builder.CreateLoad(retval));
445 
446         // Verify.
447         s.verify_function(f);
448 
449         // Restore the original insertion block.
450         builder.SetInsertPoint(orig_bb);
451     } else {
452         // The function was created before. Check if the signatures match.
453         // NOTE: there could be a mismatch if the derivative function was created
454         // and then optimised - optimisation might remove arguments which are compile-time
455         // constants.
456         if (!compare_function_signature(f, val_t, fargs)) {
457             throw std::invalid_argument(
458                 "Inconsistent function signature for the Taylor derivative of the sigmpid in compact mode detected");
459         }
460     }
461 
462     return f;
463 }
464 
465 // All the other cases.
466 template <typename T, typename U, std::enable_if_t<!is_num_param_v<U>, int> = 0>
taylor_c_diff_func_sigmoid_impl(llvm_state &,const sigmoid_impl &,const U &,std::uint32_t,std::uint32_t)467 llvm::Function *taylor_c_diff_func_sigmoid_impl(llvm_state &, const sigmoid_impl &, const U &, std::uint32_t,
468                                                 std::uint32_t)
469 {
470     throw std::invalid_argument("An invalid argument type was encountered while trying to build the Taylor derivative "
471                                 "of a sigmpid in compact mode");
472 }
473 
474 template <typename T>
taylor_c_diff_func_sigmoid(llvm_state & s,const sigmoid_impl & fn,std::uint32_t n_uvars,std::uint32_t batch_size)475 llvm::Function *taylor_c_diff_func_sigmoid(llvm_state &s, const sigmoid_impl &fn, std::uint32_t n_uvars,
476                                            std::uint32_t batch_size)
477 {
478     assert(fn.args().size() == 1u);
479 
480     return std::visit([&](const auto &v) { return taylor_c_diff_func_sigmoid_impl<T>(s, fn, v, n_uvars, batch_size); },
481                       fn.args()[0].value());
482 }
483 
484 } // namespace
485 
taylor_c_diff_func_dbl(llvm_state & s,std::uint32_t n_uvars,std::uint32_t batch_size,bool) const486 llvm::Function *sigmoid_impl::taylor_c_diff_func_dbl(llvm_state &s, std::uint32_t n_uvars, std::uint32_t batch_size,
487                                                      bool) const
488 {
489     return taylor_c_diff_func_sigmoid<double>(s, *this, n_uvars, batch_size);
490 }
491 
taylor_c_diff_func_ldbl(llvm_state & s,std::uint32_t n_uvars,std::uint32_t batch_size,bool) const492 llvm::Function *sigmoid_impl::taylor_c_diff_func_ldbl(llvm_state &s, std::uint32_t n_uvars, std::uint32_t batch_size,
493                                                       bool) const
494 {
495     return taylor_c_diff_func_sigmoid<long double>(s, *this, n_uvars, batch_size);
496 }
497 
498 #if defined(HEYOKA_HAVE_REAL128)
499 
taylor_c_diff_func_f128(llvm_state & s,std::uint32_t n_uvars,std::uint32_t batch_size,bool) const500 llvm::Function *sigmoid_impl::taylor_c_diff_func_f128(llvm_state &s, std::uint32_t n_uvars, std::uint32_t batch_size,
501                                                       bool) const
502 {
503     return taylor_c_diff_func_sigmoid<mppp::real128>(s, *this, n_uvars, batch_size);
504 }
505 
506 #endif
507 
gradient() const508 std::vector<expression> sigmoid_impl::gradient() const
509 {
510     assert(args().size() == 1u);
511     // NOTE: if single-precision floats are implemented,
512     // should 1_dbl become 1_flt?
513     return {(1_dbl - sigmoid(args()[0])) * sigmoid(args()[0])};
514 }
515 
516 } // namespace detail
517 
sigmoid(expression e)518 expression sigmoid(expression e)
519 {
520     return expression{func{detail::sigmoid_impl(std::move(e))}};
521 }
522 
523 } // namespace heyoka
524 
// Invoke the s11n export-implementation macro for sigmoid_impl.
// NOTE(review): the macro is defined outside this file; presumably it
// registers the class with the serialisation machinery - confirm
// against heyoka/s11n.hpp.
525 HEYOKA_S11N_FUNC_EXPORT_IMPLEMENT(heyoka::detail::sigmoid_impl)
526