1 #include "relapack.h"
2 
3 static void RELAPACK_strtri_rec(const char *, const char *, const blasint *,
4     float *, const blasint *, blasint *);
5 
6 
7 /** CTRTRI computes the inverse of a real upper or lower triangular matrix A.
8  *
9  * This routine is functionally equivalent to LAPACK's strtri.
10  * For details on its interface, see
11  * http://www.netlib.org/lapack/explore-html/de/d76/strtri_8f.html
12  * */
RELAPACK_strtri(const char * uplo,const char * diag,const blasint * n,float * A,const blasint * ldA,blasint * info)13 void RELAPACK_strtri(
14     const char *uplo, const char *diag, const blasint *n,
15     float *A, const blasint *ldA,
16     blasint *info
17 ) {
18 
19     // Check arguments
20     const blasint lower = LAPACK(lsame)(uplo, "L");
21     const blasint upper = LAPACK(lsame)(uplo, "U");
22     const blasint nounit = LAPACK(lsame)(diag, "N");
23     const blasint unit = LAPACK(lsame)(diag, "U");
24     *info = 0;
25     if (!lower && !upper)
26         *info = -1;
27     else if (!nounit && !unit)
28         *info = -2;
29     else if (*n < 0)
30         *info = -3;
31     else if (*ldA < MAX(1, *n))
32         *info = -5;
33     if (*info) {
34         const blasint minfo = -*info;
35         LAPACK(xerbla)("STRTRI", &minfo, strlen("STRTRI"));
36         return;
37     }
38 
39     // Clean char * arguments
40     const char cleanuplo = lower  ? 'L' : 'U';
41     const char cleandiag = nounit ? 'N' : 'U';
42 
43     // check for singularity
44     if (nounit) {
45         blasint i;
46         for (i = 0; i < *n; i++)
47             if (A[i + *ldA * i] == 0) {
48                 *info = i;
49                 return;
50             }
51     }
52 
53     // Recursive kernel
54     RELAPACK_strtri_rec(&cleanuplo, &cleandiag, n, A, ldA, info);
55 }
56 
57 
58 /** strtri's recursive compute kernel */
RELAPACK_strtri_rec(const char * uplo,const char * diag,const blasint * n,float * A,const blasint * ldA,blasint * info)59 static void RELAPACK_strtri_rec(
60     const char *uplo, const char *diag, const blasint *n,
61     float *A, const blasint *ldA,
62     blasint *info
63 ){
64 
65     if (*n <= MAX(CROSSOVER_STRTRI, 1)) {
66         // Unblocked
67         LAPACK(strti2)(uplo, diag, n, A, ldA, info);
68         return;
69     }
70 
71     // Constants
72     const float ONE[]  = { 1. };
73     const float MONE[] = { -1. };
74 
75     // Splitting
76     const blasint n1 = SREC_SPLIT(*n);
77     const blasint n2 = *n - n1;
78 
79     // A_TL A_TR
80     // A_BL A_BR
81     float *const A_TL = A;
82     float *const A_TR = A + *ldA * n1;
83     float *const A_BL = A             + n1;
84     float *const A_BR = A + *ldA * n1 + n1;
85 
86     // recursion(A_TL)
87     RELAPACK_strtri_rec(uplo, diag, &n1, A_TL, ldA, info);
88     if (*info)
89         return;
90 
91     if (*uplo == 'L') {
92         // A_BL = - A_BL * A_TL
93         BLAS(strmm)("R", "L", "N", diag, &n2, &n1, MONE, A_TL, ldA, A_BL, ldA);
94         // A_BL = A_BR \ A_BL
95         BLAS(strsm)("L", "L", "N", diag, &n2, &n1, ONE, A_BR, ldA, A_BL, ldA);
96     } else {
97         // A_TR = - A_TL * A_TR
98         BLAS(strmm)("L", "U", "N", diag, &n1, &n2, MONE, A_TL, ldA, A_TR, ldA);
99         // A_TR = A_TR / A_BR
100         BLAS(strsm)("R", "U", "N", diag, &n1, &n2, ONE, A_BR, ldA, A_TR, ldA);
101     }
102 
103     // recursion(A_BR)
104     RELAPACK_strtri_rec(uplo, diag, &n2, A_BR, ldA, info);
105     if (*info)
106         *info += n1;
107 }
108