1 // Copyright 2020, 2021 Francesco Biscani (bluescarni@gmail.com), Dario Izzo (dario.izzo@gmail.com)
2 //
3 // This file is part of the heyoka library.
4 //
5 // This Source Code Form is subject to the terms of the Mozilla
6 // Public License v. 2.0. If a copy of the MPL was not distributed
7 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
8
9 #include <heyoka/config.hpp>
10
11 #include <algorithm>
12 #include <cassert>
13 #include <cstdint>
14 #include <ostream>
15 #include <stdexcept>
16 #include <string>
17 #include <type_traits>
18 #include <unordered_map>
19 #include <utility>
20 #include <variant>
21 #include <vector>
22
23 #include <boost/numeric/conversion/cast.hpp>
24
25 #include <fmt/format.h>
26
27 #include <llvm/IR/Attributes.h>
28 #include <llvm/IR/BasicBlock.h>
29 #include <llvm/IR/DerivedTypes.h>
30 #include <llvm/IR/Function.h>
31 #include <llvm/IR/IRBuilder.h>
32 #include <llvm/IR/LLVMContext.h>
33 #include <llvm/IR/Module.h>
34 #include <llvm/IR/Value.h>
35
36 #if defined(HEYOKA_HAVE_REAL128)
37
38 #include <mp++/real128.hpp>
39
40 #endif
41
42 #include <heyoka/detail/llvm_helpers.hpp>
43 #include <heyoka/detail/string_conv.hpp>
44 #include <heyoka/detail/type_traits.hpp>
45 #include <heyoka/expression.hpp>
46 #include <heyoka/func.hpp>
47 #include <heyoka/llvm_state.hpp>
48 #include <heyoka/math/square.hpp>
49 #include <heyoka/math/sum.hpp>
50 #include <heyoka/math/sum_sq.hpp>
51 #include <heyoka/number.hpp>
52 #include <heyoka/param.hpp>
53 #include <heyoka/s11n.hpp>
54 #include <heyoka/taylor.hpp>
55 #include <heyoka/variable.hpp>
56
57 #if defined(_MSC_VER) && !defined(__clang__)
58
59 // NOTE: MSVC has issues with the other "using"
60 // statement form.
61 using namespace fmt::literals;
62
63 #else
64
65 using fmt::literals::operator""_format;
66
67 #endif
68
69 namespace heyoka
70 {
71
72 namespace detail
73 {
74
sum_sq_impl()75 sum_sq_impl::sum_sq_impl() : sum_sq_impl(std::vector<expression>{}) {}
76
sum_sq_impl(std::vector<expression> v)77 sum_sq_impl::sum_sq_impl(std::vector<expression> v) : func_base("sum_sq", std::move(v)) {}
78
to_stream(std::ostream & os) const79 void sum_sq_impl::to_stream(std::ostream &os) const
80 {
81 if (args().size() == 1u) {
82 // NOTE: avoid brackets if there's only 1 argument.
83 os << args()[0] << "**2";
84 } else {
85 os << '(';
86
87 for (decltype(args().size()) i = 0; i < args().size(); ++i) {
88 os << args()[i] << "**2";
89 if (i != args().size() - 1u) {
90 os << " + ";
91 }
92 }
93
94 os << ')';
95 }
96 }
97
98 template <typename T>
diff_impl(std::unordered_map<const void *,expression> & func_map,const T & x) const99 expression sum_sq_impl::diff_impl(std::unordered_map<const void *, expression> &func_map, const T &x) const
100 {
101 std::vector<expression> terms;
102 terms.reserve(args().size());
103
104 for (const auto &arg : args()) {
105 terms.push_back(arg * detail::diff(func_map, arg, x));
106 }
107
108 return 2_dbl * sum(std::move(terms));
109 }
110
diff(std::unordered_map<const void *,expression> & func_map,const std::string & s) const111 expression sum_sq_impl::diff(std::unordered_map<const void *, expression> &func_map, const std::string &s) const
112 {
113 return diff_impl(func_map, s);
114 }
115
diff(std::unordered_map<const void *,expression> & func_map,const param & p) const116 expression sum_sq_impl::diff(std::unordered_map<const void *, expression> &func_map, const param &p) const
117 {
118 return diff_impl(func_map, p);
119 }
120
121 namespace
122 {
123
124 template <typename T>
sum_sq_taylor_diff_impl(llvm_state & s,const sum_sq_impl & sf,const std::vector<std::uint32_t> & deps,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t batch_size)125 llvm::Value *sum_sq_taylor_diff_impl(llvm_state &s, const sum_sq_impl &sf, const std::vector<std::uint32_t> &deps,
126 const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, std::uint32_t n_uvars,
127 std::uint32_t order, std::uint32_t batch_size)
128 {
129 // NOTE: this is prevented in the implementation
130 // of the sum_sq() function.
131 assert(!sf.args().empty());
132
133 if (!deps.empty()) {
134 // LCOV_EXCL_START
135 throw std::invalid_argument("The vector of hidden dependencies in the Taylor diff for a sum of squares "
136 "should be empty, but instead it has a size of {}"_format(deps.size()));
137 // LCOV_EXCL_STOP
138 }
139
140 auto &builder = s.builder();
141
142 // Each vector in v_sums will contain the terms in the summation in the formula
143 // for the computation of the Taylor derivative of square() for each argument in sf.
144 std::vector<std::vector<llvm::Value *>> v_sums;
145 v_sums.resize(boost::numeric_cast<decltype(v_sums.size())>(sf.args().size()));
146
147 // This function calculates the j-th term in the summation in the formula for the
148 // Taylor derivative of square() for each k-th argument in sf, and appends the result
149 // to the k-th entry in v_sums.
150 auto looper = [&](std::uint32_t j) {
151 for (decltype(sf.args().size()) k = 0; k < sf.args().size(); ++k) {
152 std::visit(
153 [&](const auto &v) {
154 using type = detail::uncvref_t<decltype(v)>;
155
156 if constexpr (std::is_same_v<type, variable>) {
157 // Variable.
158 const auto u_idx = uname_to_index(v.name());
159
160 auto v0 = taylor_fetch_diff(arr, u_idx, order - j, n_uvars);
161 auto v1 = taylor_fetch_diff(arr, u_idx, j, n_uvars);
162
163 v_sums[k].push_back(builder.CreateFMul(v0, v1));
164 } else if constexpr (is_num_param_v<type>) {
165 // Number/param.
166
167 // NOTE: for number/params, all terms in the summation
168 // will be zero. Thus, ensure that v_sums[k] just
169 // contains a single zero.
170 if (v_sums[k].empty()) {
171 v_sums[k].push_back(vector_splat(builder, codegen<T>(s, number{0.}), batch_size));
172 }
173 } else {
174 // LCOV_EXCL_START
175 throw std::invalid_argument(
176 "An invalid argument type was encountered while trying to build the "
177 "Taylor derivative of a sum of squares");
178 // LCOV_EXCL_STOP
179 }
180 },
181 sf.args()[k].value());
182 }
183 };
184
185 if (order % 2u == 1u) {
186 // Odd order.
187 for (std::uint32_t j = 0; j <= (order - 1u) / 2u; ++j) {
188 looper(j);
189 }
190
191 // Pairwise sum each item in v_sums.
192 std::vector<llvm::Value *> tmp;
193 tmp.reserve(boost::numeric_cast<decltype(tmp.size())>(v_sums.size()));
194 for (auto &v_sum : v_sums) {
195 tmp.push_back(pairwise_sum(builder, v_sum));
196 }
197
198 // Sum the sums.
199 pairwise_sum(builder, tmp);
200
201 // Multiply by 2 and return.
202 return builder.CreateFAdd(tmp[0], tmp[0]);
203 } else {
204 // Even order.
205 for (std::uint32_t j = 0; order > 0u && j <= (order - 2u) / 2u; ++j) {
206 looper(j);
207 }
208
209 // Pairwise sum each item in v_sums, multiply the result by 2 and add the
210 // term outside the summation.
211 std::vector<llvm::Value *> tmp;
212 tmp.reserve(boost::numeric_cast<decltype(tmp.size())>(v_sums.size()));
213 for (decltype(sf.args().size()) k = 0; k < sf.args().size(); ++k) {
214 // Compute the term outside the summation and store it in tmp.
215 tmp.push_back(std::visit(
216 [&](const auto &v) -> llvm::Value * {
217 using type = detail::uncvref_t<decltype(v)>;
218
219 if constexpr (std::is_same_v<type, variable>) {
220 // Variable.
221 auto val = taylor_fetch_diff(arr, uname_to_index(v.name()), order / 2u, n_uvars);
222 return builder.CreateFMul(val, val);
223 } else if constexpr (is_num_param_v<type>) {
224 // Number/param.
225 if (order == 0u) {
226 auto val = taylor_codegen_numparam<T>(s, v, par_ptr, batch_size);
227 return builder.CreateFMul(val, val);
228 } else {
229 return vector_splat(builder, codegen<T>(s, number{0.}), batch_size);
230 }
231 } else {
232 // LCOV_EXCL_START
233 throw std::invalid_argument(
234 "An invalid argument type was encountered while trying to build the "
235 "Taylor derivative of a sum of squares");
236 // LCOV_EXCL_STOP
237 }
238 },
239 sf.args()[k].value()));
240
241 // NOTE: avoid doing the pairwise sum if the order is 0, in which case
242 // the items in v_sums are all empty and tmp.back() contains only the term
243 // outside the summation.
244 if (order > 0u) {
245 auto p_sum = pairwise_sum(builder, v_sums[k]);
246 // Muliply the pairwise sum by 2.
247 p_sum = builder.CreateFAdd(p_sum, p_sum);
248 // Add it to the term outside the sum.
249 tmp.back() = builder.CreateFAdd(p_sum, tmp.back());
250 }
251 }
252
253 // Sum the sums and return.
254 return pairwise_sum(builder, tmp);
255 }
256 }
257
258 } // namespace
259
taylor_diff_dbl(llvm_state & s,const std::vector<std::uint32_t> & deps,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,llvm::Value *,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t,std::uint32_t batch_size,bool) const260 llvm::Value *sum_sq_impl::taylor_diff_dbl(llvm_state &s, const std::vector<std::uint32_t> &deps,
261 const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, llvm::Value *,
262 std::uint32_t n_uvars, std::uint32_t order, std::uint32_t,
263 std::uint32_t batch_size, bool) const
264 {
265 return sum_sq_taylor_diff_impl<double>(s, *this, deps, arr, par_ptr, n_uvars, order, batch_size);
266 }
267
taylor_diff_ldbl(llvm_state & s,const std::vector<std::uint32_t> & deps,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,llvm::Value *,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t,std::uint32_t batch_size,bool) const268 llvm::Value *sum_sq_impl::taylor_diff_ldbl(llvm_state &s, const std::vector<std::uint32_t> &deps,
269 const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, llvm::Value *,
270 std::uint32_t n_uvars, std::uint32_t order, std::uint32_t,
271 std::uint32_t batch_size, bool) const
272 {
273 return sum_sq_taylor_diff_impl<long double>(s, *this, deps, arr, par_ptr, n_uvars, order, batch_size);
274 }
275
276 #if defined(HEYOKA_HAVE_REAL128)
277
taylor_diff_f128(llvm_state & s,const std::vector<std::uint32_t> & deps,const std::vector<llvm::Value * > & arr,llvm::Value * par_ptr,llvm::Value *,std::uint32_t n_uvars,std::uint32_t order,std::uint32_t,std::uint32_t batch_size,bool) const278 llvm::Value *sum_sq_impl::taylor_diff_f128(llvm_state &s, const std::vector<std::uint32_t> &deps,
279 const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, llvm::Value *,
280 std::uint32_t n_uvars, std::uint32_t order, std::uint32_t,
281 std::uint32_t batch_size, bool) const
282 {
283 return sum_sq_taylor_diff_impl<mppp::real128>(s, *this, deps, arr, par_ptr, n_uvars, order, batch_size);
284 }
285
286 #endif
287
288 namespace
289 {
290
// Build, or fetch if it was already created, the LLVM function computing the
// Taylor derivative of the sum of squares sf in compact mode, for the
// floating-point type T and the given batch size.
//
// The function name and argument list are obtained from
// taylor_c_diff_func_name_args(). If a function with that name already exists
// in the module, its signature is checked and the existing function is returned.
//
// For each argument a_k of sf, the generated code computes, for odd order n,
//   2 * sum_{j=0}^{(n-1)/2} a_k^[n-j] * a_k^[j],
// and, for even order n,
//   2 * sum_{j=0}^{(n-2)/2} a_k^[n-j] * a_k^[j] + (a_k^[n/2])**2,
// with an empty summation at order 0. It then returns the sum over k of these
// quantities. Number/param arguments contribute only at order 0.
//
// Throws std::invalid_argument if an existing function with the same name has
// a different signature, or if an argument of sf is neither a variable, a
// number nor a param.
template <typename T>
llvm::Function *sum_sq_taylor_c_diff_func_impl(llvm_state &s, const sum_sq_impl &sf, std::uint32_t n_uvars,
                                               std::uint32_t batch_size)
{
    // NOTE: this is prevented in the implementation
    // of the sum_sq() function.
    assert(!sf.args().empty());

    auto &md = s.module();
    auto &builder = s.builder();
    auto &context = s.context();

    // Fetch the floating-point type.
    auto val_t = to_llvm_vector_type<T>(context, batch_size);

    // Build the vector of arguments needed to determine the function name.
    std::vector<std::variant<variable, number, param>> nm_args;
    nm_args.reserve(static_cast<decltype(nm_args.size())>(sf.args().size()));
    for (const auto &arg : sf.args()) {
        nm_args.push_back(std::visit(
            [](const auto &v) -> std::variant<variable, number, param> {
                using type = detail::uncvref_t<decltype(v)>;

                if constexpr (std::is_same_v<type, func>) {
                    // LCOV_EXCL_START
                    // NOTE: func arguments are not expected here (hence the assert).
                    // In release builds, 'throw;' without an active exception
                    // results in std::terminate().
                    assert(false);
                    throw;
                    // LCOV_EXCL_STOP
                } else {
                    return v;
                }
            },
            arg.value()));
    }

    // Fetch the function name and arguments.
    const auto na_pair = taylor_c_diff_func_name_args<T>(context, "sum_sq", n_uvars, batch_size, nm_args);
    const auto &fname = na_pair.first;
    const auto &fargs = na_pair.second;

    // Try to see if we already created the function.
    auto f = md.getFunction(fname);

    if (f == nullptr) {
        // The function was not created before, do it now.

        // Fetch the current insertion block.
        auto orig_bb = builder.GetInsertBlock();

        // The return type is val_t.
        auto *ft = llvm::FunctionType::get(val_t, fargs, false);
        // Create the function
        f = llvm::Function::Create(ft, llvm::Function::InternalLinkage, fname, &md);
        assert(f != nullptr);
        // NOTE: force inline.
        f->addFnAttr(llvm::Attribute::AlwaysInline);

        // Fetch the necessary function arguments: argument 0 is the derivative
        // order, argument 2 the array of derivatives, argument 3 the pointer
        // passed to taylor_c_diff_numparam_codegen() as par_ptr and, from
        // argument 5 onwards, one argument per term of sf.
        // NOTE(review): arguments 1 and 4 are unused here; their meaning is
        // defined by taylor_c_diff_func_name_args() - confirm there.
        auto order = f->args().begin();
        auto diff_arr = f->args().begin() + 2;
        auto par_ptr = f->args().begin() + 3;
        auto terms = f->args().begin() + 5;

        // Create a new basic block to start insertion into.
        builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f));

        // Create the accumulators for each argument in the summation, and init them to zero.
        std::vector<llvm::Value *> v_accs;
        v_accs.resize(boost::numeric_cast<decltype(v_accs.size())>(sf.args().size()));
        for (auto &acc : v_accs) {
            acc = builder.CreateAlloca(val_t);
            builder.CreateStore(vector_splat(builder, codegen<T>(s, number{0.}), batch_size), acc);
        }

        // Create the return value.
        auto retval = builder.CreateAlloca(val_t);

        // This function calculates the j-th term in the summation in the formula for the
        // Taylor derivative of square() for each k-th argument in sf, and accumulates the result
        // into the k-th entry in v_accs.
        auto looper = [&](llvm::Value *j) {
            for (decltype(sf.args().size()) k = 0; k < sf.args().size(); ++k) {
                std::visit(
                    [&](const auto &v) {
                        using type = detail::uncvref_t<decltype(v)>;

                        if constexpr (std::is_same_v<type, variable>) {
                            // Variable: load the derivatives of orders (order - j) and j.
                            auto v0 = taylor_c_load_diff(s, diff_arr, n_uvars, builder.CreateSub(order, j), terms + k);
                            auto v1 = taylor_c_load_diff(s, diff_arr, n_uvars, j, terms + k);

                            // Update the k-th accumulator.
                            builder.CreateStore(
                                builder.CreateFAdd(builder.CreateLoad(v_accs[k]), builder.CreateFMul(v0, v1)),
                                v_accs[k]);
                        } else if constexpr (is_num_param_v<type>) {
                            // Number/param: nothing to do, leave the accumulator to zero.
                        } else {
                            // LCOV_EXCL_START
                            throw std::invalid_argument(
                                "An invalid argument type was encountered while trying to build the "
                                "Taylor derivative of a sum of squares in compact mode");
                            // LCOV_EXCL_STOP
                        }
                    },
                    sf.args()[k].value());
            }
        };

        // Distinguish odd/even cases: odd_or_even is true
        // if the order is odd.
        const auto odd_or_even
            = builder.CreateICmpEQ(builder.CreateURem(order, builder.getInt32(2)), builder.getInt32(1));

        llvm_if_then_else(
            s, odd_or_even,
            [&]() {
                // Odd order: j runs over [0, (order - 1) / 2].
                const auto loop_end = builder.CreateAdd(
                    builder.CreateUDiv(builder.CreateSub(order, builder.getInt32(1)), builder.getInt32(2)),
                    builder.getInt32(1));

                llvm_loop_u32(s, builder.getInt32(0), loop_end, [&](llvm::Value *j) { looper(j); });

                // Run a pairwise summation on the vector of accumulators.
                std::vector<llvm::Value *> tmp;
                tmp.reserve(v_accs.size());
                for (auto &acc : v_accs) {
                    tmp.push_back(builder.CreateLoad(acc));
                }
                auto ret = pairwise_sum(builder, tmp);

                // Return 2 * ret.
                builder.CreateStore(builder.CreateFAdd(ret, ret), retval);
            },
            [&]() {
                // Even order.
                // NOTE: run the loop only if we are not at order 0.
                llvm_if_then_else(
                    s, builder.CreateICmpEQ(order, builder.getInt32(0)),
                    []() {
                        // Order 0, do nothing.
                    },
                    [&]() {
                        // Order 2 or higher: j runs over [0, (order - 2) / 2].
                        const auto loop_end = builder.CreateAdd(
                            builder.CreateUDiv(builder.CreateSub(order, builder.getInt32(2)), builder.getInt32(2)),
                            builder.getInt32(1));

                        llvm_loop_u32(s, builder.getInt32(0), loop_end, [&](llvm::Value *j) { looper(j); });
                    });

                // Multiply each accumulator by two and add the term outside the summation.
                std::vector<llvm::Value *> tmp;
                tmp.reserve(v_accs.size());
                for (decltype(sf.args().size()) k = 0; k < sf.args().size(); ++k) {
                    // Load the current accumulator and multiply it by 2.
                    auto acc_val = builder.CreateLoad(v_accs[k]);
                    auto acc2 = builder.CreateFAdd(acc_val, acc_val);

                    // Load the external term, i.e., the square of the
                    // derivative of order (order / 2).
                    auto ex_term = std::visit( // LCOV_EXCL_LINE
                        [&](const auto &v) -> llvm::Value * {
                            using type = detail::uncvref_t<decltype(v)>;

                            if constexpr (std::is_same_v<type, variable>) {
                                // Variable.
                                auto val = taylor_c_load_diff(
                                    s, diff_arr, n_uvars, builder.CreateUDiv(order, builder.getInt32(2)), terms + k);
                                return builder.CreateFMul(val, val);
                            } else if constexpr (is_num_param_v<type>) {
                                // Number/param.
                                // NOTE(review): this alloca is emitted outside the entry block.
                                // LLVM's mem2reg pass generally only promotes entry-block
                                // allocas - confirm whether this affects the optimised code.
                                auto ret = builder.CreateAlloca(val_t);

                                llvm_if_then_else(
                                    s, builder.CreateICmpEQ(order, builder.getInt32(0)),
                                    [&]() {
                                        // Order 0, store the num/param.
                                        builder.CreateStore(
                                            taylor_c_diff_numparam_codegen(s, v, terms + k, par_ptr, batch_size), ret);
                                    },
                                    [&]() {
                                        // Order 2 or higher, store zero.
                                        builder.CreateStore(
                                            vector_splat(builder, codegen<T>(s, number{0.}), batch_size), ret);
                                    });

                                auto val = builder.CreateLoad(ret);

                                return builder.CreateFMul(val, val);
                            } else {
                                // LCOV_EXCL_START
                                throw std::invalid_argument(
                                    "An invalid argument type was encountered while trying to build the "
                                    "Taylor derivative of a sum of squares in compact mode");
                                // LCOV_EXCL_STOP
                            }
                        },
                        sf.args()[k].value());

                    // Compute the Taylor derivative for the current argument.
                    tmp.push_back(builder.CreateFAdd(acc2, ex_term));
                }

                // Return the pairwise sum.
                builder.CreateStore(pairwise_sum(builder, tmp), retval);
            });

        // Create the return value.
        builder.CreateRet(builder.CreateLoad(retval));

        // Verify.
        s.verify_function(f);

        // Restore the original insertion block.
        builder.SetInsertPoint(orig_bb);
    } else {
        // The function was created before. Check if the signatures match.
        // NOTE: there could be a mismatch if the derivative function was created
        // and then optimised - optimisation might remove arguments which are compile-time
        // constants.
        if (!compare_function_signature(f, val_t, fargs)) {
            // LCOV_EXCL_START
            throw std::invalid_argument(
                "Inconsistent function signature for the Taylor derivative of sum_sq() in compact mode detected");
            // LCOV_EXCL_STOP
        }
    }

    return f;
}
521
522 } // namespace
523
taylor_c_diff_func_dbl(llvm_state & s,std::uint32_t n_uvars,std::uint32_t batch_size,bool) const524 llvm::Function *sum_sq_impl::taylor_c_diff_func_dbl(llvm_state &s, std::uint32_t n_uvars, std::uint32_t batch_size,
525 bool) const
526 {
527 return sum_sq_taylor_c_diff_func_impl<double>(s, *this, n_uvars, batch_size);
528 }
529
taylor_c_diff_func_ldbl(llvm_state & s,std::uint32_t n_uvars,std::uint32_t batch_size,bool) const530 llvm::Function *sum_sq_impl::taylor_c_diff_func_ldbl(llvm_state &s, std::uint32_t n_uvars, std::uint32_t batch_size,
531 bool) const
532 {
533 return sum_sq_taylor_c_diff_func_impl<long double>(s, *this, n_uvars, batch_size);
534 }
535
536 #if defined(HEYOKA_HAVE_REAL128)
537
taylor_c_diff_func_f128(llvm_state & s,std::uint32_t n_uvars,std::uint32_t batch_size,bool) const538 llvm::Function *sum_sq_impl::taylor_c_diff_func_f128(llvm_state &s, std::uint32_t n_uvars, std::uint32_t batch_size,
539 bool) const
540 {
541 return sum_sq_taylor_c_diff_func_impl<mppp::real128>(s, *this, n_uvars, batch_size);
542 }
543
544 #endif
545
546 } // namespace detail
547
sum_sq(std::vector<expression> args,std::uint32_t split)548 expression sum_sq(std::vector<expression> args, std::uint32_t split)
549 {
550 if (split < 2u) {
551 throw std::invalid_argument(
552 "The 'split' value for a sum of squares must be at least 2, but it is {} instead"_format(split));
553 }
554
555 // Partition args so that all zeroes are at the end.
556 const auto n_end_it = std::stable_partition(args.begin(), args.end(), [](const expression &ex) {
557 return !std::holds_alternative<number>(ex.value()) || !is_zero(std::get<number>(ex.value()));
558 });
559
560 // If we have one or more zeroes, eliminate them
561 args.erase(n_end_it, args.end());
562
563 // Special cases.
564 if (args.empty()) {
565 return 0_dbl;
566 }
567
568 if (args.size() == 1u) {
569 return square(std::move(args[0]));
570 }
571
572 // NOTE: ret_seq will contain a sequence
573 // of sum_sqs each containing 'split' terms.
574 // tmp is a temporary vector
575 // used to accumulate the arguments to each
576 // sum_sq in ret_seq.
577 std::vector<expression> ret_seq, tmp;
578 for (auto &arg : args) {
579 // LCOV_EXCL_START
580 #if !defined(NDEBUG)
581 // NOTE: there cannot be zero numbers here because
582 // we removed them.
583 if (auto nptr = std::get_if<number>(&arg.value()); nptr && is_zero(*nptr)) {
584 assert(false);
585 }
586 #endif
587 // LCOV_EXCL_STOP
588
589 tmp.push_back(std::move(arg));
590 if (tmp.size() == split) {
591 // NOTE: after the move, tmp is guaranteed to be empty.
592 ret_seq.emplace_back(func{detail::sum_sq_impl{std::move(tmp)}});
593 assert(tmp.empty());
594 }
595 }
596
597 // NOTE: tmp is not empty if 'split' does not divide
598 // exactly args.size(). In such a case, we need to do the
599 // last iteration manually.
600 if (!tmp.empty()) {
601 // NOTE: contrary to the previous loop, here we could
602 // in principle end up creating a sum_sq_impl with only one
603 // term. In such a case, for consistency with the general
604 // behaviour of sum_sq({arg}), return arg*arg directly.
605 if (tmp.size() == 1u) {
606 ret_seq.emplace_back(square(std::move(tmp[0])));
607 } else {
608 ret_seq.emplace_back(func{detail::sum_sq_impl{std::move(tmp)}});
609 }
610 }
611
612 // Perform a sum over the sum_sqs.
613 return sum(std::move(ret_seq));
614 }
615
616 } // namespace heyoka
617
// Serialization export implementation for sum_sq_impl (macro from heyoka/s11n.hpp).
// NOTE(review): presumably this registers the type with the Boost.Serialization
// export machinery so that func objects wrapping a sum_sq_impl can be
// (de)serialized - confirm against heyoka/s11n.hpp.
HEYOKA_S11N_FUNC_EXPORT_IMPLEMENT(heyoka::detail::sum_sq_impl)
619