1 #include "relapack.h"
2 
3 static void RELAPACK_cpotrf_rec(const char *, const blasint *, float *,
4         const blasint *, blasint *);
5 
6 
7 /** CPOTRF computes the Cholesky factorization of a complex Hermitian positive definite matrix A.
8  *
9  * This routine is functionally equivalent to LAPACK's cpotrf.
10  * For details on its interface, see
11  * http://www.netlib.org/lapack/explore-html/dd/dce/cpotrf_8f.html
12  * */
RELAPACK_cpotrf(const char * uplo,const blasint * n,float * A,const blasint * ldA,blasint * info)13 void RELAPACK_cpotrf(
14     const char *uplo, const blasint *n,
15     float *A, const blasint *ldA,
16     blasint *info
17 ) {
18 
19     // Check arguments
20     const blasint lower = LAPACK(lsame)(uplo, "L");
21     const blasint upper = LAPACK(lsame)(uplo, "U");
22     *info = 0;
23     if (!lower && !upper)
24         *info = -1;
25     else if (*n < 0)
26         *info = -2;
27     else if (*ldA < MAX(1, *n))
28         *info = -4;
29     if (*info) {
30         const blasint minfo = -*info;
31         LAPACK(xerbla)("CPOTRF", &minfo, strlen("CPOTRF"));
32         return;
33     }
34 
35     if (*n == 0) return;
36 
37     // Clean char * arguments
38     const char cleanuplo = lower ? 'L' : 'U';
39 
40     // Recursive kernel
41     RELAPACK_cpotrf_rec(&cleanuplo, n, A, ldA, info);
42 }
43 
44 
45 /** cpotrf's recursive compute kernel */
RELAPACK_cpotrf_rec(const char * uplo,const blasint * n,float * A,const blasint * ldA,blasint * info)46 static void RELAPACK_cpotrf_rec(
47     const char *uplo, const blasint *n,
48     float *A, const blasint *ldA,
49     blasint *info
50 ){
51     if (*n == 0) return;
52 
53     if (*n <= MAX(CROSSOVER_CPOTRF, 1)) {
54         // Unblocked
55         LAPACK(cpotf2)(uplo, n, A, ldA, info);
56         return;
57     }
58 
59     // Constants
60     const float ONE[]  = { 1., 0. };
61     const float MONE[] = { -1., 0. };
62 
63     // Splitting
64     const blasint n1 = CREC_SPLIT(*n);
65     const blasint n2 = *n - n1;
66 
67     // A_TL A_TR
68     // A_BL A_BR
69     float *const A_TL = A;
70     float *const A_TR = A + 2 * *ldA * n1;
71     float *const A_BL = A                 + 2 * n1;
72     float *const A_BR = A + 2 * *ldA * n1 + 2 * n1;
73 
74     // recursion(A_TL)
75     RELAPACK_cpotrf_rec(uplo, &n1, A_TL, ldA, info);
76     if (*info)
77         return;
78 
79     if (*uplo == 'L') {
80         // A_BL = A_BL / A_TL'
81         BLAS(ctrsm)("R", "L", "C", "N", &n2, &n1, ONE, A_TL, ldA, A_BL, ldA);
82         // A_BR = A_BR - A_BL * A_BL'
83         BLAS(cherk)("L", "N", &n2, &n1, MONE, A_BL, ldA, ONE, A_BR, ldA);
84     } else {
85         // A_TR = A_TL' \ A_TR
86         BLAS(ctrsm)("L", "U", "C", "N", &n1, &n2, ONE, A_TL, ldA, A_TR, ldA);
87         // A_BR = A_BR - A_TR' * A_TR
88         BLAS(cherk)("U", "C", &n2, &n1, MONE, A_TR, ldA, ONE, A_BR, ldA);
89     }
90 
91     // recursion(A_BR)
92     RELAPACK_cpotrf_rec(uplo, &n2, A_BR, ldA, info);
93     if (*info)
94         *info += n1;
95 }
96