1 // Copyright 2020, 2021 Francesco Biscani (bluescarni@gmail.com), Dario Izzo (dario.izzo@gmail.com)
2 //
3 // This file is part of the heyoka library.
4 //
5 // This Source Code Form is subject to the terms of the Mozilla
6 // Public License v. 2.0. If a copy of the MPL was not distributed
7 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
8 
9 #include <heyoka/config.hpp>
10 
11 #include <algorithm>
12 #include <cassert>
13 #include <cstdint>
14 #include <ostream>
15 #include <stdexcept>
16 #include <string>
17 #include <type_traits>
18 #include <unordered_map>
19 #include <utility>
20 #include <variant>
21 #include <vector>
22 
23 #include <boost/numeric/conversion/cast.hpp>
24 
25 #include <fmt/format.h>
26 
27 #include <llvm/IR/Attributes.h>
28 #include <llvm/IR/BasicBlock.h>
29 #include <llvm/IR/DerivedTypes.h>
30 #include <llvm/IR/Function.h>
31 #include <llvm/IR/IRBuilder.h>
32 #include <llvm/IR/LLVMContext.h>
33 #include <llvm/IR/Module.h>
34 #include <llvm/IR/Value.h>
35 
36 #if defined(HEYOKA_HAVE_REAL128)
37 
38 #include <mp++/real128.hpp>
39 
40 #endif
41 
42 #include <heyoka/detail/llvm_helpers.hpp>
43 #include <heyoka/detail/string_conv.hpp>
44 #include <heyoka/detail/type_traits.hpp>
45 #include <heyoka/expression.hpp>
46 #include <heyoka/func.hpp>
47 #include <heyoka/llvm_state.hpp>
48 #include <heyoka/math/square.hpp>
49 #include <heyoka/math/sum.hpp>
50 #include <heyoka/math/sum_sq.hpp>
51 #include <heyoka/number.hpp>
52 #include <heyoka/param.hpp>
53 #include <heyoka/s11n.hpp>
54 #include <heyoka/taylor.hpp>
55 #include <heyoka/variable.hpp>
56 
57 #if defined(_MSC_VER) && !defined(__clang__)
58 
59 // NOTE: MSVC has issues with the other "using"
60 // statement form.
61 using namespace fmt::literals;
62 
63 #else
64 
65 using fmt::literals::operator""_format;
66 
67 #endif
68 
69 namespace heyoka
70 {
71 
72 namespace detail
73 {
74 
sum_sq_impl()75 sum_sq_impl::sum_sq_impl() : sum_sq_impl(std::vector<expression>{}) {}
76 
sum_sq_impl(std::vector<expression> v)77 sum_sq_impl::sum_sq_impl(std::vector<expression> v) : func_base("sum_sq", std::move(v)) {}
78 
to_stream(std::ostream & os) const79 void sum_sq_impl::to_stream(std::ostream &os) const
80 {
81     if (args().size() == 1u) {
82         // NOTE: avoid brackets if there's only 1 argument.
83         os << args()[0] << "**2";
84     } else {
85         os << '(';
86 
87         for (decltype(args().size()) i = 0; i < args().size(); ++i) {
88             os << args()[i] << "**2";
89             if (i != args().size() - 1u) {
90                 os << " + ";
91             }
92         }
93 
94         os << ')';
95     }
96 }
97 
98 template <typename T>
diff_impl(std::unordered_map<const void *,expression> & func_map,const T & x) const99 expression sum_sq_impl::diff_impl(std::unordered_map<const void *, expression> &func_map, const T &x) const
100 {
101     std::vector<expression> terms;
102     terms.reserve(args().size());
103 
104     for (const auto &arg : args()) {
105         terms.push_back(arg * detail::diff(func_map, arg, x));
106     }
107 
108     return 2_dbl * sum(std::move(terms));
109 }
110 
diff(std::unordered_map<const void *,expression> & func_map,const std::string & s) const111 expression sum_sq_impl::diff(std::unordered_map<const void *, expression> &func_map, const std::string &s) const
112 {
113     return diff_impl(func_map, s);
114 }
115 
diff(std::unordered_map<const void *,expression> & func_map,const param & p) const116 expression sum_sq_impl::diff(std::unordered_map<const void *, expression> &func_map, const param &p) const
117 {
118     return diff_impl(func_map, p);
119 }
120 
121 namespace
122 {
123 
124 template <typename T>
sum_sq_taylor_diff_impl(llvm_state & s,const sum_sq_impl & sf,const std::vector<std::uint32_t> & deps,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t batch_size)125 llvm::Value *sum_sq_taylor_diff_impl(llvm_state &s, const sum_sq_impl &sf, const std::vector<std::uint32_t> &deps,
126                                      const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, std::uint32_t n_uvars,
127                                      std::uint32_t order, std::uint32_t batch_size)
128 {
129     // NOTE: this is prevented in the implementation
130     // of the sum_sq() function.
131     assert(!sf.args().empty());
132 
133     if (!deps.empty()) {
134         // LCOV_EXCL_START
135         throw std::invalid_argument("The vector of hidden dependencies in the Taylor diff for a sum of squares "
136                                     "should be empty, but instead it has a size of {}"_format(deps.size()));
137         // LCOV_EXCL_STOP
138     }
139 
140     auto &builder = s.builder();
141 
142     // Each vector in v_sums will contain the terms in the summation in the formula
143     // for the computation of the Taylor derivative of square() for each argument in sf.
144     std::vector<std::vector<llvm::Value *>> v_sums;
145     v_sums.resize(boost::numeric_cast<decltype(v_sums.size())>(sf.args().size()));
146 
147     // This function calculates the j-th term in the summation in the formula for the
148     // Taylor derivative of square() for each k-th argument in sf, and appends the result
149     // to the k-th entry in v_sums.
150     auto looper = [&](std::uint32_t j) {
151         for (decltype(sf.args().size()) k = 0; k < sf.args().size(); ++k) {
152             std::visit(
153                 [&](const auto &v) {
154                     using type = detail::uncvref_t<decltype(v)>;
155 
156                     if constexpr (std::is_same_v<type, variable>) {
157                         // Variable.
158                         const auto u_idx = uname_to_index(v.name());
159 
160                         auto v0 = taylor_fetch_diff(arr, u_idx, order - j, n_uvars);
161                         auto v1 = taylor_fetch_diff(arr, u_idx, j, n_uvars);
162 
163                         v_sums[k].push_back(builder.CreateFMul(v0, v1));
164                     } else if constexpr (is_num_param_v<type>) {
165                         // Number/param.
166 
167                         // NOTE: for number/params, all terms in the summation
168                         // will be zero. Thus, ensure that v_sums[k] just
169                         // contains a single zero.
170                         if (v_sums[k].empty()) {
171                             v_sums[k].push_back(vector_splat(builder, codegen<T>(s, number{0.}), batch_size));
172                         }
173                     } else {
174                         // LCOV_EXCL_START
175                         throw std::invalid_argument(
176                             "An invalid argument type was encountered while trying to build the "
177                             "Taylor derivative of a sum of squares");
178                         // LCOV_EXCL_STOP
179                     }
180                 },
181                 sf.args()[k].value());
182         }
183     };
184 
185     if (order % 2u == 1u) {
186         // Odd order.
187         for (std::uint32_t j = 0; j <= (order - 1u) / 2u; ++j) {
188             looper(j);
189         }
190 
191         // Pairwise sum each item in v_sums.
192         std::vector<llvm::Value *> tmp;
193         tmp.reserve(boost::numeric_cast<decltype(tmp.size())>(v_sums.size()));
194         for (auto &v_sum : v_sums) {
195             tmp.push_back(pairwise_sum(builder, v_sum));
196         }
197 
198         // Sum the sums.
199         pairwise_sum(builder, tmp);
200 
201         // Multiply by 2 and return.
202         return builder.CreateFAdd(tmp[0], tmp[0]);
203     } else {
204         // Even order.
205         for (std::uint32_t j = 0; order > 0u && j <= (order - 2u) / 2u; ++j) {
206             looper(j);
207         }
208 
209         // Pairwise sum each item in v_sums, multiply the result by 2 and add the
210         // term outside the summation.
211         std::vector<llvm::Value *> tmp;
212         tmp.reserve(boost::numeric_cast<decltype(tmp.size())>(v_sums.size()));
213         for (decltype(sf.args().size()) k = 0; k < sf.args().size(); ++k) {
214             // Compute the term outside the summation and store it in tmp.
215             tmp.push_back(std::visit(
216                 [&](const auto &v) -> llvm::Value * {
217                     using type = detail::uncvref_t<decltype(v)>;
218 
219                     if constexpr (std::is_same_v<type, variable>) {
220                         // Variable.
221                         auto val = taylor_fetch_diff(arr, uname_to_index(v.name()), order / 2u, n_uvars);
222                         return builder.CreateFMul(val, val);
223                     } else if constexpr (is_num_param_v<type>) {
224                         // Number/param.
225                         if (order == 0u) {
226                             auto val = taylor_codegen_numparam<T>(s, v, par_ptr, batch_size);
227                             return builder.CreateFMul(val, val);
228                         } else {
229                             return vector_splat(builder, codegen<T>(s, number{0.}), batch_size);
230                         }
231                     } else {
232                         // LCOV_EXCL_START
233                         throw std::invalid_argument(
234                             "An invalid argument type was encountered while trying to build the "
235                             "Taylor derivative of a sum of squares");
236                         // LCOV_EXCL_STOP
237                     }
238                 },
239                 sf.args()[k].value()));
240 
241             // NOTE: avoid doing the pairwise sum if the order is 0, in which case
242             // the items in v_sums are all empty and tmp.back() contains only the term
243             // outside the summation.
244             if (order > 0u) {
245                 auto p_sum = pairwise_sum(builder, v_sums[k]);
246                 // Muliply the pairwise sum by 2.
247                 p_sum = builder.CreateFAdd(p_sum, p_sum);
248                 // Add it to the term outside the sum.
249                 tmp.back() = builder.CreateFAdd(p_sum, tmp.back());
250             }
251         }
252 
253         // Sum the sums and return.
254         return pairwise_sum(builder, tmp);
255     }
256 }
257 
258 } // namespace
259 
taylor_diff_dbl(llvm_state & s,const std::vector<std::uint32_t> & deps,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,llvm::Value *,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t,std::uint32_t batch_size,bool) const260 llvm::Value *sum_sq_impl::taylor_diff_dbl(llvm_state &s, const std::vector<std::uint32_t> &deps,
261                                           const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, llvm::Value *,
262                                           std::uint32_t n_uvars, std::uint32_t order, std::uint32_t,
263                                           std::uint32_t batch_size, bool) const
264 {
265     return sum_sq_taylor_diff_impl<double>(s, *this, deps, arr, par_ptr, n_uvars, order, batch_size);
266 }
267 
taylor_diff_ldbl(llvm_state & s,const std::vector<std::uint32_t> & deps,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,llvm::Value *,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t,std::uint32_t batch_size,bool) const268 llvm::Value *sum_sq_impl::taylor_diff_ldbl(llvm_state &s, const std::vector<std::uint32_t> &deps,
269                                            const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, llvm::Value *,
270                                            std::uint32_t n_uvars, std::uint32_t order, std::uint32_t,
271                                            std::uint32_t batch_size, bool) const
272 {
273     return sum_sq_taylor_diff_impl<long double>(s, *this, deps, arr, par_ptr, n_uvars, order, batch_size);
274 }
275 
276 #if defined(HEYOKA_HAVE_REAL128)
277 
taylor_diff_f128(llvm_state & s,const std::vector<std::uint32_t> & deps,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,llvm::Value *,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t,std::uint32_t batch_size,bool) const278 llvm::Value *sum_sq_impl::taylor_diff_f128(llvm_state &s, const std::vector<std::uint32_t> &deps,
279                                            const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, llvm::Value *,
280                                            std::uint32_t n_uvars, std::uint32_t order, std::uint32_t,
281                                            std::uint32_t batch_size, bool) const
282 {
283     return sum_sq_taylor_diff_impl<mppp::real128>(s, *this, deps, arr, par_ptr, n_uvars, order, batch_size);
284 }
285 
286 #endif
287 
288 namespace
289 {
290 
291 template <typename T>
sum_sq_taylor_c_diff_func_impl(llvm_state & s,const sum_sq_impl & sf,std::uint32_t n_uvars,std::uint32_t batch_size)292 llvm::Function *sum_sq_taylor_c_diff_func_impl(llvm_state &s, const sum_sq_impl &sf, std::uint32_t n_uvars,
293                                                std::uint32_t batch_size)
294 {
295     // NOTE: this is prevented in the implementation
296     // of the sum() function.
297     assert(!sf.args().empty());
298 
299     auto &md = s.module();
300     auto &builder = s.builder();
301     auto &context = s.context();
302 
303     // Fetch the floating-point type.
304     auto val_t = to_llvm_vector_type<T>(context, batch_size);
305 
306     // Build the vector of arguments needed to determine the function name.
307     std::vector<std::variant<variable, number, param>> nm_args;
308     nm_args.reserve(static_cast<decltype(nm_args.size())>(sf.args().size()));
309     for (const auto &arg : sf.args()) {
310         nm_args.push_back(std::visit(
311             [](const auto &v) -> std::variant<variable, number, param> {
312                 using type = detail::uncvref_t<decltype(v)>;
313 
314                 if constexpr (std::is_same_v<type, func>) {
315                     // LCOV_EXCL_START
316                     assert(false);
317                     throw;
318                     // LCOV_EXCL_STOP
319                 } else {
320                     return v;
321                 }
322             },
323             arg.value()));
324     }
325 
326     // Fetch the function name and arguments.
327     const auto na_pair = taylor_c_diff_func_name_args<T>(context, "sum_sq", n_uvars, batch_size, nm_args);
328     const auto &fname = na_pair.first;
329     const auto &fargs = na_pair.second;
330 
331     // Try to see if we already created the function.
332     auto f = md.getFunction(fname);
333 
334     if (f == nullptr) {
335         // The function was not created before, do it now.
336 
337         // Fetch the current insertion block.
338         auto orig_bb = builder.GetInsertBlock();
339 
340         // The return type is val_t.
341         auto *ft = llvm::FunctionType::get(val_t, fargs, false);
342         // Create the function
343         f = llvm::Function::Create(ft, llvm::Function::InternalLinkage, fname, &md);
344         assert(f != nullptr);
345         // NOTE: force inline.
346         f->addFnAttr(llvm::Attribute::AlwaysInline);
347 
348         // Fetch the necessary function arguments.
349         auto order = f->args().begin();
350         auto diff_arr = f->args().begin() + 2;
351         auto par_ptr = f->args().begin() + 3;
352         auto terms = f->args().begin() + 5;
353 
354         // Create a new basic block to start insertion into.
355         builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f));
356 
357         // Create the accumulators for each argument in the summation, and init them to zero.
358         std::vector<llvm::Value *> v_accs;
359         v_accs.resize(boost::numeric_cast<decltype(v_accs.size())>(sf.args().size()));
360         for (auto &acc : v_accs) {
361             acc = builder.CreateAlloca(val_t);
362             builder.CreateStore(vector_splat(builder, codegen<T>(s, number{0.}), batch_size), acc);
363         }
364 
365         // Create the return value.
366         auto retval = builder.CreateAlloca(val_t);
367 
368         // This function calculates the j-th term in the summation in the formula for the
369         // Taylor derivative of square() for each k-th argument in sf, and accumulates the result
370         // into the k-th entry in v_accs.
371         auto looper = [&](llvm::Value *j) {
372             for (decltype(sf.args().size()) k = 0; k < sf.args().size(); ++k) {
373                 std::visit(
374                     [&](const auto &v) {
375                         using type = detail::uncvref_t<decltype(v)>;
376 
377                         if constexpr (std::is_same_v<type, variable>) {
378                             // Variable.
379                             auto v0 = taylor_c_load_diff(s, diff_arr, n_uvars, builder.CreateSub(order, j), terms + k);
380                             auto v1 = taylor_c_load_diff(s, diff_arr, n_uvars, j, terms + k);
381 
382                             // Update the k-th accumulator.
383                             builder.CreateStore(
384                                 builder.CreateFAdd(builder.CreateLoad(v_accs[k]), builder.CreateFMul(v0, v1)),
385                                 v_accs[k]);
386                         } else if constexpr (is_num_param_v<type>) {
387                             // Number/param: nothing to do, leave the accumulator to zero.
388                         } else {
389                             // LCOV_EXCL_START
390                             throw std::invalid_argument(
391                                 "An invalid argument type was encountered while trying to build the "
392                                 "Taylor derivative of a sum of squares in compact mode");
393                             // LCOV_EXCL_STOP
394                         }
395                     },
396                     sf.args()[k].value());
397             }
398         };
399 
400         // Distinguish odd/even cases.
401         const auto odd_or_even
402             = builder.CreateICmpEQ(builder.CreateURem(order, builder.getInt32(2)), builder.getInt32(1));
403 
404         llvm_if_then_else(
405             s, odd_or_even,
406             [&]() {
407                 // Odd order.
408                 const auto loop_end = builder.CreateAdd(
409                     builder.CreateUDiv(builder.CreateSub(order, builder.getInt32(1)), builder.getInt32(2)),
410                     builder.getInt32(1));
411 
412                 llvm_loop_u32(s, builder.getInt32(0), loop_end, [&](llvm::Value *j) { looper(j); });
413 
414                 // Run a pairwise summation on the vector of accumulators.
415                 std::vector<llvm::Value *> tmp;
416                 tmp.reserve(v_accs.size());
417                 for (auto &acc : v_accs) {
418                     tmp.push_back(builder.CreateLoad(acc));
419                 }
420                 auto ret = pairwise_sum(builder, tmp);
421 
422                 // Return 2 * ret.
423                 builder.CreateStore(builder.CreateFAdd(ret, ret), retval);
424             },
425             [&]() {
426                 // Even order.
427                 // NOTE: run the loop only if we are not at order 0.
428                 llvm_if_then_else(
429                     s, builder.CreateICmpEQ(order, builder.getInt32(0)),
430                     []() {
431                         // Order 0, do nothing.
432                     },
433                     [&]() {
434                         // Order 2 or higher.
435                         const auto loop_end = builder.CreateAdd(
436                             builder.CreateUDiv(builder.CreateSub(order, builder.getInt32(2)), builder.getInt32(2)),
437                             builder.getInt32(1));
438 
439                         llvm_loop_u32(s, builder.getInt32(0), loop_end, [&](llvm::Value *j) { looper(j); });
440                     });
441 
442                 // Multiply each accumulator by two and add the term outside the summation.
443                 std::vector<llvm::Value *> tmp;
444                 tmp.reserve(v_accs.size());
445                 for (decltype(sf.args().size()) k = 0; k < sf.args().size(); ++k) {
446                     // Load the current accumulator and multiply it by 2.
447                     auto acc_val = builder.CreateLoad(v_accs[k]);
448                     auto acc2 = builder.CreateFAdd(acc_val, acc_val);
449 
450                     // Load the external term.
451                     auto ex_term = std::visit( // LCOV_EXCL_LINE
452                         [&](const auto &v) -> llvm::Value * {
453                             using type = detail::uncvref_t<decltype(v)>;
454 
455                             if constexpr (std::is_same_v<type, variable>) {
456                                 // Variable.
457                                 auto val = taylor_c_load_diff(
458                                     s, diff_arr, n_uvars, builder.CreateUDiv(order, builder.getInt32(2)), terms + k);
459                                 return builder.CreateFMul(val, val);
460                             } else if constexpr (is_num_param_v<type>) {
461                                 // Number/param.
462                                 auto ret = builder.CreateAlloca(val_t);
463 
464                                 llvm_if_then_else(
465                                     s, builder.CreateICmpEQ(order, builder.getInt32(0)),
466                                     [&]() {
467                                         // Order 0, store the num/param.
468                                         builder.CreateStore(
469                                             taylor_c_diff_numparam_codegen(s, v, terms + k, par_ptr, batch_size), ret);
470                                     },
471                                     [&]() {
472                                         // Order 2 or higher, store zero.
473                                         builder.CreateStore(
474                                             vector_splat(builder, codegen<T>(s, number{0.}), batch_size), ret);
475                                     });
476 
477                                 auto val = builder.CreateLoad(ret);
478 
479                                 return builder.CreateFMul(val, val);
480                             } else {
481                                 // LCOV_EXCL_START
482                                 throw std::invalid_argument(
483                                     "An invalid argument type was encountered while trying to build the "
484                                     "Taylor derivative of a sum of squares in compact mode");
485                                 // LCOV_EXCL_STOP
486                             }
487                         },
488                         sf.args()[k].value());
489 
490                     // Compute the Taylor derivative for the current argument.
491                     tmp.push_back(builder.CreateFAdd(acc2, ex_term));
492                 }
493 
494                 // Return the pairwise sum.
495                 builder.CreateStore(pairwise_sum(builder, tmp), retval);
496             });
497 
498         // Create the return value.
499         builder.CreateRet(builder.CreateLoad(retval));
500 
501         // Verify.
502         s.verify_function(f);
503 
504         // Restore the original insertion block.
505         builder.SetInsertPoint(orig_bb);
506     } else {
507         // The function was created before. Check if the signatures match.
508         // NOTE: there could be a mismatch if the derivative function was created
509         // and then optimised - optimisation might remove arguments which are compile-time
510         // constants.
511         if (!compare_function_signature(f, val_t, fargs)) {
512             // LCOV_EXCL_START
513             throw std::invalid_argument(
514                 "Inconsistent function signature for the Taylor derivative of sum_sq() in compact mode detected");
515             // LCOV_EXCL_STOP
516         }
517     }
518 
519     return f;
520 }
521 
522 } // namespace
523 
taylor_c_diff_func_dbl(llvm_state & s,std::uint32_t n_uvars,std::uint32_t batch_size,bool) const524 llvm::Function *sum_sq_impl::taylor_c_diff_func_dbl(llvm_state &s, std::uint32_t n_uvars, std::uint32_t batch_size,
525                                                     bool) const
526 {
527     return sum_sq_taylor_c_diff_func_impl<double>(s, *this, n_uvars, batch_size);
528 }
529 
taylor_c_diff_func_ldbl(llvm_state & s,std::uint32_t n_uvars,std::uint32_t batch_size,bool) const530 llvm::Function *sum_sq_impl::taylor_c_diff_func_ldbl(llvm_state &s, std::uint32_t n_uvars, std::uint32_t batch_size,
531                                                      bool) const
532 {
533     return sum_sq_taylor_c_diff_func_impl<long double>(s, *this, n_uvars, batch_size);
534 }
535 
536 #if defined(HEYOKA_HAVE_REAL128)
537 
taylor_c_diff_func_f128(llvm_state & s,std::uint32_t n_uvars,std::uint32_t batch_size,bool) const538 llvm::Function *sum_sq_impl::taylor_c_diff_func_f128(llvm_state &s, std::uint32_t n_uvars, std::uint32_t batch_size,
539                                                      bool) const
540 {
541     return sum_sq_taylor_c_diff_func_impl<mppp::real128>(s, *this, n_uvars, batch_size);
542 }
543 
544 #endif
545 
546 } // namespace detail
547 
sum_sq(std::vector<expression> args,std::uint32_t split)548 expression sum_sq(std::vector<expression> args, std::uint32_t split)
549 {
550     if (split < 2u) {
551         throw std::invalid_argument(
552             "The 'split' value for a sum of squares must be at least 2, but it is {} instead"_format(split));
553     }
554 
555     // Partition args so that all zeroes are at the end.
556     const auto n_end_it = std::stable_partition(args.begin(), args.end(), [](const expression &ex) {
557         return !std::holds_alternative<number>(ex.value()) || !is_zero(std::get<number>(ex.value()));
558     });
559 
560     // If we have one or more zeroes, eliminate them
561     args.erase(n_end_it, args.end());
562 
563     // Special cases.
564     if (args.empty()) {
565         return 0_dbl;
566     }
567 
568     if (args.size() == 1u) {
569         return square(std::move(args[0]));
570     }
571 
572     // NOTE: ret_seq will contain a sequence
573     // of sum_sqs each containing 'split' terms.
574     // tmp is a temporary vector
575     // used to accumulate the arguments to each
576     // sum_sq in ret_seq.
577     std::vector<expression> ret_seq, tmp;
578     for (auto &arg : args) {
579         // LCOV_EXCL_START
580 #if !defined(NDEBUG)
581         // NOTE: there cannot be zero numbers here because
582         // we removed them.
583         if (auto nptr = std::get_if<number>(&arg.value()); nptr && is_zero(*nptr)) {
584             assert(false);
585         }
586 #endif
587         // LCOV_EXCL_STOP
588 
589         tmp.push_back(std::move(arg));
590         if (tmp.size() == split) {
591             // NOTE: after the move, tmp is guaranteed to be empty.
592             ret_seq.emplace_back(func{detail::sum_sq_impl{std::move(tmp)}});
593             assert(tmp.empty());
594         }
595     }
596 
597     // NOTE: tmp is not empty if 'split' does not divide
598     // exactly args.size(). In such a case, we need to do the
599     // last iteration manually.
600     if (!tmp.empty()) {
601         // NOTE: contrary to the previous loop, here we could
602         // in principle end up creating a sum_sq_impl with only one
603         // term. In such a case, for consistency with the general
604         // behaviour of sum_sq({arg}), return arg*arg directly.
605         if (tmp.size() == 1u) {
606             ret_seq.emplace_back(square(std::move(tmp[0])));
607         } else {
608             ret_seq.emplace_back(func{detail::sum_sq_impl{std::move(tmp)}});
609         }
610     }
611 
612     // Perform a sum over the sum_sqs.
613     return sum(std::move(ret_seq));
614 }
615 
616 } // namespace heyoka
617 
// NOTE(review): presumably the implementation half of the serialisation export
// for sum_sq_impl (macro expected to come from heyoka/s11n.hpp) — confirm.
HEYOKA_S11N_FUNC_EXPORT_IMPLEMENT(heyoka::detail::sum_sq_impl)
619