1 #ifndef STAN_MATH_PRIM_META_OPERANDS_AND_PARTIALS_HPP
2 #define STAN_MATH_PRIM_META_OPERANDS_AND_PARTIALS_HPP
3 
4 #include <stan/math/prim/fun/Eigen.hpp>
5 #include <stan/math/prim/meta/require_generics.hpp>
6 #include <stan/math/prim/meta/return_type.hpp>
7 #include <stan/math/prim/functor/broadcast_array.hpp>
8 #include <vector>
9 #include <type_traits>
10 #include <tuple>
11 
12 namespace stan {
13 namespace math {
14 template <typename Op1 = double, typename Op2 = double, typename Op3 = double,
15           typename Op4 = double, typename Op5 = double,
16           typename T_return_type = return_type_t<Op1, Op2, Op3, Op4, Op5>>
17 class operands_and_partials;  // Forward declaration
18 
19 namespace internal {
20 template <typename ViewElt, typename Op, typename = void>
21 struct ops_partials_edge;
22 /**
23  * Class representing an edge with an inner type of double. This class
24  *  should never be used by the program and only exists so that
25  *  developer can write functions using `operands_and_partials` that works for
26  *  double, vars, and fvar types.
27  * @tparam ViewElt One of `double`, `var`, `fvar`.
28  * @tparam Op The type of the input operand. It's scalar type
29  *  for this specialization must be an `Arithmetic`
30  */
31 template <typename ViewElt, typename Op>
32 struct ops_partials_edge<ViewElt, Op, require_st_arithmetic<Op>> {
33   using inner_op = std::conditional_t<is_eigen<value_type_t<Op>>::value,
34                                       value_type_t<Op>, Op>;
35   using partials_t = empty_broadcast_array<ViewElt, inner_op>;
36   /**
37    * The `partials_` are always called in `if` statements that will be
38    *  removed by the dead code elimination pass of the compiler. So if we ever
39    *  move up to C++17 these can be made into `constexpr if` and
40    *  this can be deleted.
41    */
42   partials_t partials_;
43   empty_broadcast_array<partials_t, inner_op> partials_vec_;
44   static constexpr double operands_{0};
ops_partials_edgestan::math::internal::ops_partials_edge45   ops_partials_edge() {}
46   template <typename T>
ops_partials_edgestan::math::internal::ops_partials_edge47   explicit ops_partials_edge(T&& /* op */) noexcept {}
48 
49   /**
50    * Get the operand for the edge. For doubles this is a compile time
51    * expression returning zero.
52    */
operandstan::math::internal::ops_partials_edge53   static constexpr double operand() noexcept { return 0.0; }
54 
55   /**
56    * Get the partial for the edge. For doubles this is a compile time
57    * expression returning zero.
58    */
partialstan::math::internal::ops_partials_edge59   static constexpr double partial() noexcept { return 0.0; }
60   /**
61    * Return the tangent for the edge. For doubles this is a compile time
62    * expression returning zero.
63    */
dxstan::math::internal::ops_partials_edge64   static constexpr double dx() noexcept { return 0.0; }
65   /**
66    * Return the size of the operand for the edge. For doubles this is a compile
67    * time expression returning zero.
68    */
sizestan::math::internal::ops_partials_edge69   static constexpr int size() noexcept { return 0; }  // reverse mode
70 
71  private:
72   template <typename, typename, typename, typename, typename, typename>
73   friend class stan::math::operands_and_partials;
74 };
75 template <typename ViewElt, typename Op>
76 constexpr double
77     ops_partials_edge<ViewElt, Op, require_st_arithmetic<Op>>::operands_;
78 }  // namespace internal
79 
80 /** \ingroup type_trait
81  * \callergraph
82  * This template builds partial derivatives with respect to a
83  * set of
84  * operands. There are two reason for the generality of this
85  * class. The first is to handle vector and scalar arguments
86  * without needing to write additional code. The second is to use
87  * this class for writing probability distributions that handle
88  * primitives, reverse mode, and forward mode variables
89  * seamlessly.
90  *
91  * Conceptually, this class is used when we want to manually calculate
92  * the derivative of a function and store this manual result on the
93  * autodiff stack in a sort of "compressed" form. Think of it like an
94  * easy-to-use interface to rev/core/precomputed_gradients.
95  *
96  * This class supports nested container ("multivariate") use-cases
97  * as well by exposing a partials_vec_ member on edges of the
98  * appropriate type.
99  *
100  * This base template is instantiated when all operands are
101  * primitives and we don't want to calculate derivatives at
102  * all. So all Op1 - Op5 must be arithmetic primitives
103  * like int or double. This is controlled with the
104  * T_return_type type parameter.
105  *
106  * @tparam Op1 type of the first operand
107  * @tparam Op2 type of the second operand
108  * @tparam Op3 type of the third operand
109  * @tparam Op4 type of the fourth operand
110  * @tparam Op5 type of the fifth operand
111  * @tparam T_return_type return type of the expression. This defaults
112  *   to calling a template metaprogram that calculates the scalar
113  *   promotion of Op1..Op4
114  */
115 template <typename Op1, typename Op2, typename Op3, typename Op4, typename Op5,
116           typename T_return_type>
117 class operands_and_partials {
118  public:
operands_and_partials(const Op1 &)119   explicit operands_and_partials(const Op1& /* op1 */) noexcept {}
operands_and_partials(const Op1 &,const Op2 &)120   operands_and_partials(const Op1& /* op1 */, const Op2& /* op2 */) noexcept {}
operands_and_partials(const Op1 &,const Op2 &,const Op3 &)121   operands_and_partials(const Op1& /* op1 */, const Op2& /* op2 */,
122                         const Op3& /* op3 */) noexcept {}
operands_and_partials(const Op1 &,const Op2 &,const Op3 &,const Op4 &)123   operands_and_partials(const Op1& /* op1 */, const Op2& /* op2 */,
124                         const Op3& /* op3 */, const Op4& /* op4 */) noexcept {}
operands_and_partials(const Op1 &,const Op2 &,const Op3 &,const Op4 &,const Op5 &)125   operands_and_partials(const Op1& /* op1 */, const Op2& /* op2 */,
126                         const Op3& /* op3 */, const Op4& /* op4 */,
127                         const Op5& /* op5 */) noexcept {}
128 
129   /** \ingroup type_trait
130    * Build the node to be stored on the autodiff graph.
131    * This should contain both the value and the tangent.
132    *
133    * For scalars (this implementation), we don't calculate any derivatives.
134    * For reverse mode, we end up returning a type of var that will calculate
135    * the appropriate adjoint using the stored operands and partials.
136    * Forward mode just calculates the tangent on the spot and returns it in
137    * a vanilla fvar.
138    *
139    * @param value the return value of the function we are compressing
140    * @return the value with its derivative
141    */
build(double value) const142   inline double build(double value) const noexcept { return value; }
143 
144   // These will always be 0 size base template instantiations (above).
145   internal::ops_partials_edge<double, std::decay_t<Op1>> edge1_;
146   internal::ops_partials_edge<double, std::decay_t<Op2>> edge2_;
147   internal::ops_partials_edge<double, std::decay_t<Op3>> edge3_;
148   internal::ops_partials_edge<double, std::decay_t<Op4>> edge4_;
149   internal::ops_partials_edge<double, std::decay_t<Op5>> edge5_;
150 };
151 
152 }  // namespace math
153 }  // namespace stan
154 #endif
155