#include "relapack.h"
#include <stddef.h>
2
3 static void RELAPACK_spotrf_rec(const char *, const blasint *, float *,
4 const blasint *, blasint *);
5
6
7 /** SPOTRF computes the Cholesky factorization of a real symmetric positive definite matrix A.
8 *
9 * This routine is functionally equivalent to LAPACK's spotrf.
10 * For details on its interface, see
11 * http://www.netlib.org/lapack/explore-html/d0/da2/spotrf_8f.html
12 * */
RELAPACK_spotrf(const char * uplo,const blasint * n,float * A,const blasint * ldA,blasint * info)13 void RELAPACK_spotrf(
14 const char *uplo, const blasint *n,
15 float *A, const blasint *ldA,
16 blasint *info
17 ) {
18
19 // Check arguments
20 const blasint lower = LAPACK(lsame)(uplo, "L");
21 const blasint upper = LAPACK(lsame)(uplo, "U");
22 *info = 0;
23 if (!lower && !upper)
24 *info = -1;
25 else if (*n < 0)
26 *info = -2;
27 else if (*ldA < MAX(1, *n))
28 *info = -4;
29 if (*info) {
30 const blasint minfo = -*info;
31 LAPACK(xerbla)("SPOTRF", &minfo, strlen("SPOTRF"));
32 return;
33 }
34
35 // Clean char * arguments
36 const char cleanuplo = lower ? 'L' : 'U';
37
38 // Recursive kernel
39 RELAPACK_spotrf_rec(&cleanuplo, n, A, ldA, info);
40 }
41
42
43 /** spotrf's recursive compute kernel */
RELAPACK_spotrf_rec(const char * uplo,const blasint * n,float * A,const blasint * ldA,blasint * info)44 static void RELAPACK_spotrf_rec(
45 const char *uplo, const blasint *n,
46 float *A, const blasint *ldA,
47 blasint *info
48 ) {
49
50 if (*n <= MAX(CROSSOVER_SPOTRF, 1)) {
51 // Unblocked
52 LAPACK(spotf2)(uplo, n, A, ldA, info);
53 return;
54 }
55
56 // Constants
57 const float ONE[] = { 1. };
58 const float MONE[] = { -1. };
59
60 // Splitting
61 const blasint n1 = SREC_SPLIT(*n);
62 const blasint n2 = *n - n1;
63
64 // A_TL A_TR
65 // A_BL A_BR
66 float *const A_TL = A;
67 float *const A_TR = A + *ldA * n1;
68 float *const A_BL = A + n1;
69 float *const A_BR = A + *ldA * n1 + n1;
70
71 // recursion(A_TL)
72 RELAPACK_spotrf_rec(uplo, &n1, A_TL, ldA, info);
73 if (*info)
74 return;
75
76 if (*uplo == 'L') {
77 // A_BL = A_BL / A_TL'
78 BLAS(strsm)("R", "L", "T", "N", &n2, &n1, ONE, A_TL, ldA, A_BL, ldA);
79 // A_BR = A_BR - A_BL * A_BL'
80 BLAS(ssyrk)("L", "N", &n2, &n1, MONE, A_BL, ldA, ONE, A_BR, ldA);
81 } else {
82 // A_TR = A_TL' \ A_TR
83 BLAS(strsm)("L", "U", "T", "N", &n1, &n2, ONE, A_TL, ldA, A_TR, ldA);
84 // A_BR = A_BR - A_TR' * A_TR
85 BLAS(ssyrk)("U", "T", &n2, &n1, MONE, A_TR, ldA, ONE, A_BR, ldA);
86 }
87
88 // recursion(A_BR)
89 RELAPACK_spotrf_rec(uplo, &n2, A_BR, ldA, info);
90 if (*info)
91 *info += n1;
92 }
93