1 #ifndef STAN_MATH_PRIM_FUN_ACCUMULATOR_HPP 2 #define STAN_MATH_PRIM_FUN_ACCUMULATOR_HPP 3 4 #include <stan/math/prim/fun/Eigen.hpp> 5 #include <stan/math/prim/meta.hpp> 6 #include <stan/math/prim/fun/sum.hpp> 7 #include <vector> 8 #include <type_traits> 9 10 namespace stan { 11 namespace math { 12 13 /** 14 * Class to accumulate values and eventually return their sum. If 15 * no values are ever added, the return value is 0. 16 * 17 * This class is useful for speeding up autodiff of long sums 18 * because it uses the <code>sum()</code> operation (either from 19 * <code>stan::math</code> or one defined by argument-dependent lookup. 20 * 21 * @tparam T Type of scalar added 22 */ 23 template <typename T, typename = void> 24 class accumulator { 25 private: 26 std::vector<T> buf_; 27 28 public: 29 /** 30 * Add the specified arithmetic type value to the buffer after 31 * static casting it to the class type <code>T</code>. 32 * 33 * <p>See the std library doc for <code>std::is_arithmetic</code> 34 * for information on what counts as an arithmetic type. 35 * 36 * @tparam S Type of argument 37 * @param x Value to add 38 */ 39 template <typename S, typename = require_stan_scalar_t<S>> add(S x)40 inline void add(S x) { 41 buf_.push_back(x); 42 } 43 44 /** 45 * Add each entry in the specified matrix, vector, or row vector 46 * of values to the buffer. 47 * 48 * @tparam S type of the matrix 49 * @param m Matrix of values to add 50 */ 51 template <typename S, require_matrix_t<S>* = nullptr> add(const S & m)52 inline void add(const S& m) { 53 buf_.push_back(stan::math::sum(m)); 54 } 55 56 /** 57 * Recursively add each entry in the specified standard vector 58 * to the buffer. This will allow vectors of primitives, 59 * autodiff variables to be added; if the vector entries 60 * are collections, their elements are recursively added. 61 * 62 * @tparam S Type of value to recursively add. 63 * @param xs Vector of entries to add 64 */ 65 template <typename S> add(const std::vector<S> & xs)66 inline void add(const std::vector<S>& xs) { 67 for (size_t i = 0; i < xs.size(); ++i) { 68 this->add(xs[i]); 69 } 70 } 71 72 #ifdef STAN_OPENCL 73 74 /** 75 * Sum each entry and then push to the buffer. 76 * @tparam S A Type inheriting from `matrix_cl_base` 77 * @param x An OpenCL matrix 78 */ 79 template <typename S, 80 require_all_kernel_expressions_and_none_scalar_t<S>* = nullptr> add(const S & xs)81 inline void add(const S& xs) { 82 buf_.push_back(stan::math::sum(xs)); 83 } 84 85 #endif 86 87 /** 88 * Return the sum of the accumulated values. 89 * 90 * @return Sum of accumulated values. 91 */ sum() const92 inline T sum() const { return stan::math::sum(buf_); } 93 }; 94 95 } // namespace math 96 } // namespace stan 97 98 #endif 99