1 #ifndef STAN_MATH_PRIM_META_OPERANDS_AND_PARTIALS_HPP 2 #define STAN_MATH_PRIM_META_OPERANDS_AND_PARTIALS_HPP 3 4 #include <stan/math/prim/fun/Eigen.hpp> 5 #include <stan/math/prim/meta/require_generics.hpp> 6 #include <stan/math/prim/meta/return_type.hpp> 7 #include <stan/math/prim/functor/broadcast_array.hpp> 8 #include <vector> 9 #include <type_traits> 10 #include <tuple> 11 12 namespace stan { 13 namespace math { 14 template <typename Op1 = double, typename Op2 = double, typename Op3 = double, 15 typename Op4 = double, typename Op5 = double, 16 typename T_return_type = return_type_t<Op1, Op2, Op3, Op4, Op5>> 17 class operands_and_partials; // Forward declaration 18 19 namespace internal { 20 template <typename ViewElt, typename Op, typename = void> 21 struct ops_partials_edge; 22 /** 23 * Class representing an edge with an inner type of double. This class 24 * should never be used by the program and only exists so that 25 * developer can write functions using `operands_and_partials` that works for 26 * double, vars, and fvar types. 27 * @tparam ViewElt One of `double`, `var`, `fvar`. 28 * @tparam Op The type of the input operand. It's scalar type 29 * for this specialization must be an `Arithmetic` 30 */ 31 template <typename ViewElt, typename Op> 32 struct ops_partials_edge<ViewElt, Op, require_st_arithmetic<Op>> { 33 using inner_op = std::conditional_t<is_eigen<value_type_t<Op>>::value, 34 value_type_t<Op>, Op>; 35 using partials_t = empty_broadcast_array<ViewElt, inner_op>; 36 /** 37 * The `partials_` are always called in `if` statements that will be 38 * removed by the dead code elimination pass of the compiler. So if we ever 39 * move up to C++17 these can be made into `constexpr if` and 40 * this can be deleted. 41 */ 42 partials_t partials_; 43 empty_broadcast_array<partials_t, inner_op> partials_vec_; 44 static constexpr double operands_{0}; ops_partials_edgestan::math::internal::ops_partials_edge45 ops_partials_edge() {} 46 template <typename T> ops_partials_edgestan::math::internal::ops_partials_edge47 explicit ops_partials_edge(T&& /* op */) noexcept {} 48 49 /** 50 * Get the operand for the edge. For doubles this is a compile time 51 * expression returning zero. 52 */ operandstan::math::internal::ops_partials_edge53 static constexpr double operand() noexcept { return 0.0; } 54 55 /** 56 * Get the partial for the edge. For doubles this is a compile time 57 * expression returning zero. 58 */ partialstan::math::internal::ops_partials_edge59 static constexpr double partial() noexcept { return 0.0; } 60 /** 61 * Return the tangent for the edge. For doubles this is a compile time 62 * expression returning zero. 63 */ dxstan::math::internal::ops_partials_edge64 static constexpr double dx() noexcept { return 0.0; } 65 /** 66 * Return the size of the operand for the edge. For doubles this is a compile 67 * time expression returning zero. 68 */ sizestan::math::internal::ops_partials_edge69 static constexpr int size() noexcept { return 0; } // reverse mode 70 71 private: 72 template <typename, typename, typename, typename, typename, typename> 73 friend class stan::math::operands_and_partials; 74 }; 75 template <typename ViewElt, typename Op> 76 constexpr double 77 ops_partials_edge<ViewElt, Op, require_st_arithmetic<Op>>::operands_; 78 } // namespace internal 79 80 /** \ingroup type_trait 81 * \callergraph 82 * This template builds partial derivatives with respect to a 83 * set of 84 * operands. There are two reason for the generality of this 85 * class. The first is to handle vector and scalar arguments 86 * without needing to write additional code. The second is to use 87 * this class for writing probability distributions that handle 88 * primitives, reverse mode, and forward mode variables 89 * seamlessly. 90 * 91 * Conceptually, this class is used when we want to manually calculate 92 * the derivative of a function and store this manual result on the 93 * autodiff stack in a sort of "compressed" form. Think of it like an 94 * easy-to-use interface to rev/core/precomputed_gradients. 95 * 96 * This class supports nested container ("multivariate") use-cases 97 * as well by exposing a partials_vec_ member on edges of the 98 * appropriate type. 99 * 100 * This base template is instantiated when all operands are 101 * primitives and we don't want to calculate derivatives at 102 * all. So all Op1 - Op5 must be arithmetic primitives 103 * like int or double. This is controlled with the 104 * T_return_type type parameter. 105 * 106 * @tparam Op1 type of the first operand 107 * @tparam Op2 type of the second operand 108 * @tparam Op3 type of the third operand 109 * @tparam Op4 type of the fourth operand 110 * @tparam Op5 type of the fifth operand 111 * @tparam T_return_type return type of the expression. This defaults 112 * to calling a template metaprogram that calculates the scalar 113 * promotion of Op1..Op4 114 */ 115 template <typename Op1, typename Op2, typename Op3, typename Op4, typename Op5, 116 typename T_return_type> 117 class operands_and_partials { 118 public: operands_and_partials(const Op1 &)119 explicit operands_and_partials(const Op1& /* op1 */) noexcept {} operands_and_partials(const Op1 &,const Op2 &)120 operands_and_partials(const Op1& /* op1 */, const Op2& /* op2 */) noexcept {} operands_and_partials(const Op1 &,const Op2 &,const Op3 &)121 operands_and_partials(const Op1& /* op1 */, const Op2& /* op2 */, 122 const Op3& /* op3 */) noexcept {} operands_and_partials(const Op1 &,const Op2 &,const Op3 &,const Op4 &)123 operands_and_partials(const Op1& /* op1 */, const Op2& /* op2 */, 124 const Op3& /* op3 */, const Op4& /* op4 */) noexcept {} operands_and_partials(const Op1 &,const Op2 &,const Op3 &,const Op4 &,const Op5 &)125 operands_and_partials(const Op1& /* op1 */, const Op2& /* op2 */, 126 const Op3& /* op3 */, const Op4& /* op4 */, 127 const Op5& /* op5 */) noexcept {} 128 129 /** \ingroup type_trait 130 * Build the node to be stored on the autodiff graph. 131 * This should contain both the value and the tangent. 132 * 133 * For scalars (this implementation), we don't calculate any derivatives. 134 * For reverse mode, we end up returning a type of var that will calculate 135 * the appropriate adjoint using the stored operands and partials. 136 * Forward mode just calculates the tangent on the spot and returns it in 137 * a vanilla fvar. 138 * 139 * @param value the return value of the function we are compressing 140 * @return the value with its derivative 141 */ build(double value) const142 inline double build(double value) const noexcept { return value; } 143 144 // These will always be 0 size base template instantiations (above). 145 internal::ops_partials_edge<double, std::decay_t<Op1>> edge1_; 146 internal::ops_partials_edge<double, std::decay_t<Op2>> edge2_; 147 internal::ops_partials_edge<double, std::decay_t<Op3>> edge3_; 148 internal::ops_partials_edge<double, std::decay_t<Op4>> edge4_; 149 internal::ops_partials_edge<double, std::decay_t<Op5>> edge5_; 150 }; 151 152 } // namespace math 153 } // namespace stan 154 #endif 155