1 #ifndef STAN_MATH_REV_FUN_SVD_V_HPP
2 #define STAN_MATH_REV_FUN_SVD_V_HPP
3
4 #include <stan/math/prim/fun/Eigen.hpp>
5 #include <stan/math/prim/err/check_nonzero_size.hpp>
6 #include <stan/math/prim/fun/typedefs.hpp>
7 #include <stan/math/rev/meta.hpp>
8 #include <stan/math/rev/core.hpp>
9 #include <stan/math/rev/fun/value_of.hpp>
10
11 namespace stan {
12 namespace math {
13
14 /**
15 * Given input matrix m, return matrix V where `m = UDV^{T}`
16 *
17 * Adjoint update equation comes from Equation (4) in Differentiable Programming
18 * Tensor Networks(H. Liao, J. Liu, et al., arXiv:1903.09650).
19 *
20 * @tparam EigMat type of input matrix
21 * @param m MxN input matrix
22 * @return Orthogonal matrix V
23 */
24 template <typename EigMat, require_rev_matrix_t<EigMat>* = nullptr>
svd_V(const EigMat & m)25 inline auto svd_V(const EigMat& m) {
26 using ret_type = return_var_matrix_t<Eigen::MatrixXd, EigMat>;
27 check_nonzero_size("svd_V", "m", m);
28
29 const int M = std::min(m.rows(), m.cols());
30 auto arena_m = to_arena(m);
31
32 Eigen::JacobiSVD<Eigen::MatrixXd> svd(
33 arena_m.val(), Eigen::ComputeThinU | Eigen::ComputeThinV);
34
35 auto arena_D = to_arena(svd.singularValues());
36
37 arena_t<Eigen::MatrixXd> arena_Fm(M, M);
38
39 for (int i = 0; i < M; i++) {
40 for (int j = 0; j < M; j++) {
41 if (j == i) {
42 arena_Fm(i, j) = 0.0;
43 } else {
44 arena_Fm(i, j)
45 = 1.0 / (arena_D[j] - arena_D[i]) - 1.0 / (arena_D[i] + arena_D[j]);
46 }
47 }
48 }
49
50 auto arena_U = to_arena(svd.matrixU());
51 arena_t<ret_type> arena_V = svd.matrixV();
52
53 reverse_pass_callback([arena_m, arena_U, arena_D, arena_V, arena_Fm,
54 M]() mutable {
55 Eigen::MatrixXd VTVadj = arena_V.val_op().transpose() * arena_V.adj_op();
56 arena_m.adj()
57 += 0.5 * arena_U
58 * (arena_Fm.array() * (VTVadj - VTVadj.transpose()).array())
59 .matrix()
60 * arena_V.val_op().transpose()
61 + arena_U * arena_D.asDiagonal().inverse()
62 * arena_V.adj_op().transpose()
63 * (Eigen::MatrixXd::Identity(arena_m.cols(), arena_m.cols())
64 - arena_V.val_op() * arena_V.val_op().transpose());
65 });
66
67 return ret_type(arena_V);
68 }
69
70 } // namespace math
71 } // namespace stan
72
73 #endif
74