1 #ifndef STAN_MATH_REV_FUN_SVD_U_HPP
2 #define STAN_MATH_REV_FUN_SVD_U_HPP
3
4 #include <stan/math/prim/fun/Eigen.hpp>
5 #include <stan/math/prim/err/check_nonzero_size.hpp>
6 #include <stan/math/prim/fun/typedefs.hpp>
7 #include <stan/math/rev/meta.hpp>
8 #include <stan/math/rev/core.hpp>
9 #include <stan/math/rev/fun/value_of.hpp>
10
11 namespace stan {
12 namespace math {
13
14 /**
15 * Given input matrix m, return matrix U where `m = UDV^{T}`
16 *
17 * Adjoint update equation comes from Equation (4) in Differentiable Programming
18 * Tensor Networks(H. Liao, J. Liu, et al., arXiv:1903.09650).
19 *
20 * @tparam EigMat type of input matrix
21 * @param m MxN input matrix
22 * @return Orthogonal matrix U
23 */
24 template <typename EigMat, require_rev_matrix_t<EigMat>* = nullptr>
svd_U(const EigMat & m)25 inline auto svd_U(const EigMat& m) {
26 using ret_type = return_var_matrix_t<Eigen::MatrixXd, EigMat>;
27 check_nonzero_size("svd_U", "m", m);
28
29 const int M = std::min(m.rows(), m.cols());
30 auto arena_m = to_arena(m);
31
32 Eigen::JacobiSVD<Eigen::MatrixXd> svd(
33 arena_m.val(), Eigen::ComputeThinU | Eigen::ComputeThinV);
34
35 auto arena_D = to_arena(svd.singularValues());
36
37 arena_t<Eigen::MatrixXd> arena_Fp(M, M);
38
39 for (int i = 0; i < M; i++) {
40 for (int j = 0; j < M; j++) {
41 if (j == i) {
42 arena_Fp(i, j) = 0.0;
43 } else {
44 arena_Fp(i, j)
45 = 1.0 / (arena_D[j] - arena_D[i]) + 1.0 / (arena_D[i] + arena_D[j]);
46 }
47 }
48 }
49
50 arena_t<ret_type> arena_U = svd.matrixU();
51 auto arena_V = to_arena(svd.matrixV());
52
53 reverse_pass_callback([arena_m, arena_U, arena_D, arena_V, arena_Fp,
54 M]() mutable {
55 Eigen::MatrixXd UUadjT = arena_U.val_op().transpose() * arena_U.adj_op();
56 arena_m.adj()
57 += .5 * arena_U.val_op()
58 * (arena_Fp.array() * (UUadjT - UUadjT.transpose()).array())
59 .matrix()
60 * arena_V.transpose()
61 + (Eigen::MatrixXd::Identity(arena_m.rows(), arena_m.rows())
62 - arena_U.val_op() * arena_U.val_op().transpose())
63 * arena_U.adj_op() * arena_D.asDiagonal().inverse()
64 * arena_V.transpose();
65 });
66
67 return ret_type(arena_U);
68 }
69
70 } // namespace math
71 } // namespace stan
72
73 #endif
74