1 #ifndef STAN_MATH_REV_FUN_APPEND_COL_HPP
2 #define STAN_MATH_REV_FUN_APPEND_COL_HPP
3 
4 #include <stan/math/prim/fun/Eigen.hpp>
5 #include <stan/math/prim/fun/append_col.hpp>
6 #include <stan/math/rev/core.hpp>
7 #include <stan/math/prim/meta.hpp>
8 #include <stan/math/prim/err.hpp>
9 #include <vector>
10 
11 namespace stan {
12 namespace math {
13 
14 /**
15  * Return the result of appending the second argument matrix after the
16  * first argument matrix, that is, putting them side by side, with
17  * the first matrix followed by the second matrix.
18  *
19  * Given input types result in following outputs:
20  * (matrix, matrix) -> matrix,
21  * (matrix, vector) -> matrix,
22  * (vector, matrix) -> matrix,
23  * (vector, vector) -> matrix,
24  * (row vector, row vector) -> row_vector.
25  *
26  * @tparam T1 A `var_value` with inner matrix type.
27  * @tparam T1 A `var_value` with inner matrix type.
28  *
29  * @param A First matrix.
30  * @param B Second matrix.
31  * @return Result of appending the first matrix followed by the
32  * second matrix side by side.
33  */
34 template <typename T1, typename T2, require_any_var_matrix_t<T1, T2>* = nullptr>
append_col(const T1 & A,const T2 & B)35 inline auto append_col(const T1& A, const T2& B) {
36   check_size_match("append_col", "columns of A", A.rows(), "columns of B",
37                    B.rows());
38   if (!is_constant<T1>::value && !is_constant<T2>::value) {
39     arena_t<promote_scalar_t<var, T1>> arena_A = A;
40     arena_t<promote_scalar_t<var, T2>> arena_B = B;
41     return make_callback_var(
42         append_col(value_of(arena_A), value_of(arena_B)),
43         [arena_A, arena_B](auto& vi) mutable {
44           arena_A.adj() += vi.adj().leftCols(arena_A.cols());
45           arena_B.adj() += vi.adj().rightCols(arena_B.cols());
46         });
47   } else if (!is_constant<T1>::value) {
48     arena_t<promote_scalar_t<var, T1>> arena_A = A;
49     return make_callback_var(append_col(value_of(arena_A), value_of(B)),
50                              [arena_A](auto& vi) mutable {
51                                arena_A.adj()
52                                    += vi.adj().leftCols(arena_A.cols());
53                              });
54   } else {
55     arena_t<promote_scalar_t<var, T2>> arena_B = B;
56     return make_callback_var(append_col(value_of(A), value_of(arena_B)),
57                              [arena_B](auto& vi) mutable {
58                                arena_B.adj()
59                                    += vi.adj().rightCols(arena_B.cols());
60                              });
61   }
62 }
63 
64 /**
65  * Return the result of stacking an scalar on top of the
66  * a row vector, with the result being a row vector.
67  *
68  * This function applies to (scalar, row vector) and returns a
69  * row vector.
70  *
71  * @tparam Scal type of the scalar
72  * @tparam RowVec A `var_value` with an inner type of row vector.
73  *
74  * @param A scalar.
75  * @param B row vector.
76  * @return Result of stacking the scalar on top of the row vector.
77  */
78 template <typename Scal, typename RowVec,
79           require_stan_scalar_t<Scal>* = nullptr,
80           require_t<is_eigen_row_vector<RowVec>>* = nullptr>
append_col(const Scal & A,const var_value<RowVec> & B)81 inline auto append_col(const Scal& A, const var_value<RowVec>& B) {
82   if (!is_constant<Scal>::value && !is_constant<RowVec>::value) {
83     var arena_A = A;
84     arena_t<promote_scalar_t<var, RowVec>> arena_B = B;
85     return make_callback_var(append_col(value_of(arena_A), value_of(arena_B)),
86                              [arena_A, arena_B](auto& vi) mutable {
87                                arena_A.adj() += vi.adj().coeff(0);
88                                arena_B.adj() += vi.adj().tail(arena_B.size());
89                              });
90   } else if (!is_constant<Scal>::value) {
91     var arena_A = A;
92     return make_callback_var(
93         append_col(value_of(arena_A), value_of(B)),
94         [arena_A](auto& vi) mutable { arena_A.adj() += vi.adj().coeff(0); });
95   } else {
96     arena_t<promote_scalar_t<var, RowVec>> arena_B = B;
97     return make_callback_var(append_col(value_of(A), value_of(arena_B)),
98                              [arena_B](auto& vi) mutable {
99                                arena_B.adj() += vi.adj().tail(arena_B.size());
100                              });
101   }
102 }
103 
104 /**
105  * Return the result of stacking a row vector on top of the
106  * an scalar, with the result being a row vector.
107  *
108  * This function applies to (row vector, scalar) and returns a
109  * row vector.
110  *
111  * @tparam RowVec A `var_value` with an inner type of row vector.
112  * @tparam Scal type of the scalar
113  *
114  * @param A row vector.
115  * @param B scalar.
116  * @return Result of stacking the row vector on top of the scalar.
117  */
118 template <typename RowVec, typename Scal,
119           require_t<is_eigen_row_vector<RowVec>>* = nullptr,
120           require_stan_scalar_t<Scal>* = nullptr>
append_col(const var_value<RowVec> & A,const Scal & B)121 inline auto append_col(const var_value<RowVec>& A, const Scal& B) {
122   if (!is_constant<RowVec>::value && !is_constant<Scal>::value) {
123     arena_t<promote_scalar_t<var, RowVec>> arena_A = A;
124     var arena_B = B;
125     return make_callback_var(append_col(value_of(arena_A), value_of(arena_B)),
126                              [arena_A, arena_B](auto& vi) mutable {
127                                arena_A.adj() += vi.adj().head(arena_A.size());
128                                arena_B.adj()
129                                    += vi.adj().coeff(vi.adj().size() - 1);
130                              });
131   } else if (!is_constant<RowVec>::value) {
132     arena_t<promote_scalar_t<var, RowVec>> arena_A = A;
133     return make_callback_var(append_col(value_of(arena_A), value_of(B)),
134                              [arena_A](auto& vi) mutable {
135                                arena_A.adj() += vi.adj().head(arena_A.size());
136                              });
137   } else {
138     var arena_B = B;
139     return make_callback_var(append_col(value_of(A), value_of(arena_B)),
140                              [arena_B](auto& vi) mutable {
141                                arena_B.adj()
142                                    += vi.adj().coeff(vi.adj().size() - 1);
143                              });
144   }
145 }
146 
147 }  // namespace math
148 }  // namespace stan
149 
150 #endif
151