1 #include "relapack.h"
2 
3 static void RELAPACK_dpotrf_rec(const char *, const blasint *, double *,
4         const blasint *, blasint *);
5 
6 
7 /** DPOTRF computes the Cholesky factorization of a real symmetric positive definite matrix A.
8  *
9  * This routine is functionally equivalent to LAPACK's dpotrf.
10  * For details on its interface, see
11  * http://www.netlib.org/lapack/explore-html/d0/d8a/dpotrf_8f.html
12  * */
RELAPACK_dpotrf(const char * uplo,const blasint * n,double * A,const blasint * ldA,blasint * info)13 void RELAPACK_dpotrf(
14     const char *uplo, const blasint *n,
15     double *A, const blasint *ldA,
16     blasint *info
17 ) {
18 
19     // Check arguments
20     const blasint lower = LAPACK(lsame)(uplo, "L");
21     const blasint upper = LAPACK(lsame)(uplo, "U");
22     *info = 0;
23     if (!lower && !upper)
24         *info = -1;
25     else if (*n < 0)
26         *info = -2;
27     else if (*ldA < MAX(1, *n))
28         *info = -4;
29     if (*info) {
30         const blasint minfo = -*info;
31         LAPACK(xerbla)("DPOTRF", &minfo, strlen("DPOTRF"));
32         return;
33     }
34 
35     // Clean char * arguments
36     const char cleanuplo = lower ? 'L' : 'U';
37 
38     // Recursive kernel
39     RELAPACK_dpotrf_rec(&cleanuplo, n, A, ldA, info);
40 }
41 
42 
43 /** dpotrf's recursive compute kernel */
RELAPACK_dpotrf_rec(const char * uplo,const blasint * n,double * A,const blasint * ldA,blasint * info)44 static void RELAPACK_dpotrf_rec(
45     const char *uplo, const blasint *n,
46     double *A, const blasint *ldA,
47     blasint *info
48 ){
49 
50     if (*n <= MAX(CROSSOVER_DPOTRF, 1)) {
51         // Unblocked
52         LAPACK(dpotf2)(uplo, n, A, ldA, info);
53         return;
54     }
55 
56     // Constants
57     const double ONE[]  = { 1. };
58     const double MONE[] = { -1. };
59 
60     // Splitting
61     const blasint n1 = DREC_SPLIT(*n);
62     const blasint n2 = *n - n1;
63 
64     // A_TL A_TR
65     // A_BL A_BR
66     double *const A_TL = A;
67     double *const A_TR = A + *ldA * n1;
68     double *const A_BL = A             + n1;
69     double *const A_BR = A + *ldA * n1 + n1;
70 
71     // recursion(A_TL)
72     RELAPACK_dpotrf_rec(uplo, &n1, A_TL, ldA, info);
73     if (*info)
74         return;
75 
76     if (*uplo == 'L') {
77         // A_BL = A_BL / A_TL'
78         BLAS(dtrsm)("R", "L", "T", "N", &n2, &n1, ONE, A_TL, ldA, A_BL, ldA);
79         // A_BR = A_BR - A_BL * A_BL'
80         BLAS(dsyrk)("L", "N", &n2, &n1, MONE, A_BL, ldA, ONE, A_BR, ldA);
81     } else {
82         // A_TR = A_TL' \ A_TR
83         BLAS(dtrsm)("L", "U", "T", "N", &n1, &n2, ONE, A_TL, ldA, A_TR, ldA);
84         // A_BR = A_BR - A_TR' * A_TR
85         BLAS(dsyrk)("U", "T", &n2, &n1, MONE, A_TR, ldA, ONE, A_BR, ldA);
86     }
87 
88     // recursion(A_BR)
89     RELAPACK_dpotrf_rec(uplo, &n2, A_BR, ldA, info);
90     if (*info)
91         *info += n1;
92 }
93