1 #ifndef STAN_MATH_REV_FUN_SVD_U_HPP
2 #define STAN_MATH_REV_FUN_SVD_U_HPP
3 
4 #include <stan/math/prim/fun/Eigen.hpp>
5 #include <stan/math/prim/err/check_nonzero_size.hpp>
6 #include <stan/math/prim/fun/typedefs.hpp>
7 #include <stan/math/rev/meta.hpp>
8 #include <stan/math/rev/core.hpp>
9 #include <stan/math/rev/fun/value_of.hpp>
10 
11 namespace stan {
12 namespace math {
13 
14 /**
15  * Given input matrix m, return matrix U where `m = UDV^{T}`
16  *
17  * Adjoint update equation comes from Equation (4) in Differentiable Programming
18  * Tensor Networks(H. Liao, J. Liu, et al., arXiv:1903.09650).
19  *
20  * @tparam EigMat type of input matrix
21  * @param m MxN input matrix
22  * @return Orthogonal matrix U
23  */
24 template <typename EigMat, require_rev_matrix_t<EigMat>* = nullptr>
svd_U(const EigMat & m)25 inline auto svd_U(const EigMat& m) {
26   using ret_type = return_var_matrix_t<Eigen::MatrixXd, EigMat>;
27   check_nonzero_size("svd_U", "m", m);
28 
29   const int M = std::min(m.rows(), m.cols());
30   auto arena_m = to_arena(m);
31 
32   Eigen::JacobiSVD<Eigen::MatrixXd> svd(
33       arena_m.val(), Eigen::ComputeThinU | Eigen::ComputeThinV);
34 
35   auto arena_D = to_arena(svd.singularValues());
36 
37   arena_t<Eigen::MatrixXd> arena_Fp(M, M);
38 
39   for (int i = 0; i < M; i++) {
40     for (int j = 0; j < M; j++) {
41       if (j == i) {
42         arena_Fp(i, j) = 0.0;
43       } else {
44         arena_Fp(i, j)
45             = 1.0 / (arena_D[j] - arena_D[i]) + 1.0 / (arena_D[i] + arena_D[j]);
46       }
47     }
48   }
49 
50   arena_t<ret_type> arena_U = svd.matrixU();
51   auto arena_V = to_arena(svd.matrixV());
52 
53   reverse_pass_callback([arena_m, arena_U, arena_D, arena_V, arena_Fp,
54                          M]() mutable {
55     Eigen::MatrixXd UUadjT = arena_U.val_op().transpose() * arena_U.adj_op();
56     arena_m.adj()
57         += .5 * arena_U.val_op()
58                * (arena_Fp.array() * (UUadjT - UUadjT.transpose()).array())
59                      .matrix()
60                * arena_V.transpose()
61            + (Eigen::MatrixXd::Identity(arena_m.rows(), arena_m.rows())
62               - arena_U.val_op() * arena_U.val_op().transpose())
63                  * arena_U.adj_op() * arena_D.asDiagonal().inverse()
64                  * arena_V.transpose();
65   });
66 
67   return ret_type(arena_U);
68 }
69 
70 }  // namespace math
71 }  // namespace stan
72 
73 #endif
74