1 #include "relapack.h"
2 #include "stdlib.h"
3 
4 static void RELAPACK_zpbtrf_rec(const char *, const blasint *, const blasint *,
5     double *, const blasint *, double *, const blasint *, blasint *);
6 
7 
8 /** ZPBTRF computes the Cholesky factorization of a complex Hermitian positive definite band matrix A.
9  *
10  * This routine is functionally equivalent to LAPACK's zpbtrf.
11  * For details on its interface, see
12  * http://www.netlib.org/lapack/explore-html/db/da9/zpbtrf_8f.html
13  * */
RELAPACK_zpbtrf(const char * uplo,const blasint * n,const blasint * kd,double * Ab,const blasint * ldAb,blasint * info)14 void RELAPACK_zpbtrf(
15     const char *uplo, const blasint *n, const blasint *kd,
16     double *Ab, const blasint *ldAb,
17     blasint *info
18 ) {
19 
20     // Check arguments
21     const blasint lower = LAPACK(lsame)(uplo, "L");
22     const blasint upper = LAPACK(lsame)(uplo, "U");
23     *info = 0;
24     if (!lower && !upper)
25         *info = -1;
26     else if (*n < 0)
27         *info = -2;
28     else if (*kd < 0)
29         *info = -3;
30     else if (*ldAb < *kd + 1)
31         *info = -5;
32     if (*info) {
33         const blasint minfo = -*info;
34         LAPACK(xerbla)("ZPBTRF", &minfo, strlen("ZPBTRF"));
35         return;
36     }
37 
38     if (*n == 0) return;
39 
40     // Clean char * arguments
41     const char cleanuplo = lower ? 'L' : 'U';
42 
43     // Constant
44     const double ZERO[] = { 0., 0. };
45 
46     // Allocate work space
47     const blasint n1 = ZREC_SPLIT(*n);
48     const blasint mWork = abs((*kd > n1) ? (lower ? *n - *kd : n1) : *kd);
49     const blasint nWork = abs((*kd > n1) ? (lower ? n1 : *n - *kd) : *kd);
50     double *Work = malloc(mWork * nWork * 2 * sizeof(double));
51 
52     LAPACK(zlaset)(uplo, &mWork, &nWork, ZERO, ZERO, Work, &mWork);
53 
54     // Recursive kernel
55     RELAPACK_zpbtrf_rec(&cleanuplo, n, kd, Ab, ldAb, Work, &mWork, info);
56 
57     // Free work space
58     free(Work);
59 }
60 
61 
62 /** zpbtrf's recursive compute kernel */
RELAPACK_zpbtrf_rec(const char * uplo,const blasint * n,const blasint * kd,double * Ab,const blasint * ldAb,double * Work,const blasint * ldWork,blasint * info)63 static void RELAPACK_zpbtrf_rec(
64     const char *uplo, const blasint *n, const blasint *kd,
65     double *Ab, const blasint *ldAb,
66     double *Work, const blasint *ldWork,
67     blasint *info
68 ){
69 
70     if (*n <= MAX(CROSSOVER_ZPBTRF, 1) || *ldAb == 1) {
71         // Unblocked
72         LAPACK(zpbtf2)(uplo, n, kd, Ab, ldAb, info);
73         return;
74     }
75 
76     // Constants
77     const double ONE[]  = { 1., 0. };
78     const double MONE[] = { -1., 0. };
79 
80     // Unskew A
81     const blasint ldA[] = { *ldAb - 1 };
82     double *const A = Ab + 2 * ((*uplo == 'L') ? 0 : *kd);
83 
84     // Splitting
85     const blasint n1 = MIN(ZREC_SPLIT(*n), *kd);
86     const blasint n2 = *n - n1;
87 
88     // * *
89     // * Ab_BR
90     double *const Ab_BR = Ab + 2 * *ldAb * n1;
91 
92     // A_TL A_TR
93     // A_BL A_BR
94     double *const A_TL = A;
95     double *const A_TR = A + 2 * *ldA * n1;
96     double *const A_BL = A                 + 2 * n1;
97     double *const A_BR = A + 2 * *ldA * n1 + 2 * n1;
98 
99     // recursion(A_TL)
100     RELAPACK_zpotrf(uplo, &n1, A_TL, ldA, info);
101     if (*info)
102         return;
103 
104     // Banded splitting
105     const blasint n21 = MIN(n2, *kd - n1);
106     const blasint n22 = MIN(n2 - n21, *kd);
107 
108     //     n1    n21    n22
109     // n1  *     A_TRl  A_TRr
110     // n21 A_BLt A_BRtl A_BRtr
111     // n22 A_BLb A_BRbl A_BRbr
112     double *const A_TRl  = A_TR;
113     double *const A_TRr  = A_TR + 2 * *ldA * n21;
114     double *const A_BLt  = A_BL;
115     double *const A_BLb  = A_BL                   + 2 * n21;
116     double *const A_BRtl = A_BR;
117     double *const A_BRtr = A_BR + 2 * *ldA * n21;
118     double *const A_BRbl = A_BR                   + 2 * n21;
119     double *const A_BRbr = A_BR + 2 * *ldA * n21  + 2 * n21;
120 
121     if (*uplo == 'L') {
122         // A_BLt = ABLt / A_TL'
123         BLAS(ztrsm)("R", "L", "C", "N", &n21, &n1, ONE, A_TL, ldA, A_BLt, ldA);
124         // A_BRtl = A_BRtl - A_BLt * A_BLt'
125         BLAS(zherk)("L", "N", &n21, &n1, MONE, A_BLt, ldA, ONE, A_BRtl, ldA);
126         // Work = A_BLb
127         LAPACK(zlacpy)("U", &n22, &n1, A_BLb, ldA, Work, ldWork);
128         // Work = Work / A_TL'
129         BLAS(ztrsm)("R", "L", "C", "N", &n22, &n1, ONE, A_TL, ldA, Work, ldWork);
130         // A_BRbl = A_BRbl - Work * A_BLt'
131         BLAS(zgemm)("N", "C", &n22, &n21, &n1, MONE, Work, ldWork, A_BLt, ldA, ONE, A_BRbl, ldA);
132         // A_BRbr = A_BRbr - Work * Work'
133         BLAS(zherk)("L", "N", &n22, &n1, MONE, Work, ldWork, ONE, A_BRbr, ldA);
134         // A_BLb = Work
135         LAPACK(zlacpy)("U", &n22, &n1, Work, ldWork, A_BLb, ldA);
136     } else {
137         // A_TRl = A_TL' \ A_TRl
138         BLAS(ztrsm)("L", "U", "C", "N", &n1, &n21, ONE, A_TL, ldA, A_TRl, ldA);
139         // A_BRtl = A_BRtl - A_TRl' * A_TRl
140         BLAS(zherk)("U", "C", &n21, &n1, MONE, A_TRl, ldA, ONE, A_BRtl, ldA);
141         // Work = A_TRr
142         LAPACK(zlacpy)("L", &n1, &n22, A_TRr, ldA, Work, ldWork);
143         // Work = A_TL' \ Work
144         BLAS(ztrsm)("L", "U", "C", "N", &n1, &n22, ONE, A_TL, ldA, Work, ldWork);
145         // A_BRtr = A_BRtr - A_TRl' * Work
146         BLAS(zgemm)("C", "N", &n21, &n22, &n1, MONE, A_TRl, ldA, Work, ldWork, ONE, A_BRtr, ldA);
147         // A_BRbr = A_BRbr - Work' * Work
148         BLAS(zherk)("U", "C", &n22, &n1, MONE, Work, ldWork, ONE, A_BRbr, ldA);
149         // A_TRr = Work
150         LAPACK(zlacpy)("L", &n1, &n22, Work, ldWork, A_TRr, ldA);
151     }
152 
153     // recursion(A_BR)
154     if (*kd > n1 && ldA != 0)
155         RELAPACK_zpotrf(uplo, &n2, A_BR, ldA, info);
156     else
157         RELAPACK_zpbtrf_rec(uplo, &n2, kd, Ab_BR, ldAb, Work, ldWork, info);
158     if (*info)
159         *info += n1;
160 }
161