1 #ifndef STAN_MATH_PRIM_FUN_ACCUMULATOR_HPP
2 #define STAN_MATH_PRIM_FUN_ACCUMULATOR_HPP
3 
4 #include <stan/math/prim/fun/Eigen.hpp>
5 #include <stan/math/prim/meta.hpp>
6 #include <stan/math/prim/fun/sum.hpp>
7 #include <vector>
8 #include <type_traits>
9 
10 namespace stan {
11 namespace math {
12 
13 /**
14  * Class to accumulate values and eventually return their sum.  If
15  * no values are ever added, the return value is 0.
16  *
17  * This class is useful for speeding up autodiff of long sums
18  * because it uses the <code>sum()</code> operation (either from
19  * <code>stan::math</code> or one defined by argument-dependent lookup.
20  *
21  * @tparam T Type of scalar added
22  */
23 template <typename T, typename = void>
24 class accumulator {
25  private:
26   std::vector<T> buf_;
27 
28  public:
29   /**
30    * Add the specified arithmetic type value to the buffer after
31    * static casting it to the class type <code>T</code>.
32    *
33    * <p>See the std library doc for <code>std::is_arithmetic</code>
34    * for information on what counts as an arithmetic type.
35    *
36    * @tparam S Type of argument
37    * @param x Value to add
38    */
39   template <typename S, typename = require_stan_scalar_t<S>>
add(S x)40   inline void add(S x) {
41     buf_.push_back(x);
42   }
43 
44   /**
45    * Add each entry in the specified matrix, vector, or row vector
46    * of values to the buffer.
47    *
48    * @tparam S type of the matrix
49    * @param m Matrix of values to add
50    */
51   template <typename S, require_matrix_t<S>* = nullptr>
add(const S & m)52   inline void add(const S& m) {
53     buf_.push_back(stan::math::sum(m));
54   }
55 
56   /**
57    * Recursively add each entry in the specified standard vector
58    * to the buffer.  This will allow vectors of primitives,
59    * autodiff variables to be added; if the vector entries
60    * are collections, their elements are recursively added.
61    *
62    * @tparam S Type of value to recursively add.
63    * @param xs Vector of entries to add
64    */
65   template <typename S>
add(const std::vector<S> & xs)66   inline void add(const std::vector<S>& xs) {
67     for (size_t i = 0; i < xs.size(); ++i) {
68       this->add(xs[i]);
69     }
70   }
71 
72 #ifdef STAN_OPENCL
73 
74   /**
75    * Sum each entry and then push to the buffer.
76    * @tparam S A Type inheriting from `matrix_cl_base`
77    * @param x An OpenCL matrix
78    */
79   template <typename S,
80             require_all_kernel_expressions_and_none_scalar_t<S>* = nullptr>
add(const S & xs)81   inline void add(const S& xs) {
82     buf_.push_back(stan::math::sum(xs));
83   }
84 
85 #endif
86 
87   /**
88    * Return the sum of the accumulated values.
89    *
90    * @return Sum of accumulated values.
91    */
sum() const92   inline T sum() const { return stan::math::sum(buf_); }
93 };
94 
95 }  // namespace math
96 }  // namespace stan
97 
98 #endif
99