1 /*
2 
3     Copyright (C) 2014, The University of Texas at Austin
4 
5     This file is part of libflame and is available under the 3-Clause
6     BSD license, which can be found in the LICENSE file at the top-level
7     directory, or at http://opensource.org/licenses/BSD-3-Clause
8 
9 */
10 
11 #include "FLAME.h"
12 
13 #ifdef FLA_ENABLE_NON_CRITICAL_CODE
14 
FLA_Chol_l_opt_var1(FLA_Obj A)15 FLA_Error FLA_Chol_l_opt_var1( FLA_Obj A )
16 {
17   FLA_Error    r_val = FLA_SUCCESS;
18   FLA_Datatype datatype;
19   int          mn_A;
20   int          rs_A, cs_A;
21 
22   datatype = FLA_Obj_datatype( A );
23 
24   mn_A     = FLA_Obj_length( A );
25   rs_A     = FLA_Obj_row_stride( A );
26   cs_A     = FLA_Obj_col_stride( A );
27 
28 
29   switch ( datatype )
30   {
31     case FLA_FLOAT:
32     {
33       float* buff_A = FLA_FLOAT_PTR( A );
34 
35       r_val = FLA_Chol_l_ops_var1( mn_A,
36                                    buff_A, rs_A, cs_A );
37 
38       break;
39     }
40 
41     case FLA_DOUBLE:
42     {
43       double* buff_A = FLA_DOUBLE_PTR( A );
44 
45       r_val = FLA_Chol_l_opd_var1( mn_A,
46                                    buff_A, rs_A, cs_A );
47 
48       break;
49     }
50 
51     case FLA_COMPLEX:
52     {
53       scomplex* buff_A = FLA_COMPLEX_PTR( A );
54 
55       r_val = FLA_Chol_l_opc_var1( mn_A,
56                                    buff_A, rs_A, cs_A );
57 
58       break;
59     }
60 
61     case FLA_DOUBLE_COMPLEX:
62     {
63       dcomplex* buff_A = FLA_DOUBLE_COMPLEX_PTR( A );
64 
65       r_val = FLA_Chol_l_opz_var1( mn_A,
66                                    buff_A, rs_A, cs_A );
67 
68       break;
69     }
70   }
71 
72   return r_val;
73 }
74 
75 
76 
FLA_Chol_l_ops_var1(int mn_A,float * buff_A,int rs_A,int cs_A)77 FLA_Error FLA_Chol_l_ops_var1( int mn_A,
78                                float* buff_A, int rs_A, int cs_A )
79 {
80   float*    buff_1  = FLA_FLOAT_PTR( FLA_ONE );
81   float*    buff_m1 = FLA_FLOAT_PTR( FLA_MINUS_ONE );
82   int       i;
83   FLA_Error e_val;
84 
85   for ( i = 0; i < mn_A; ++i )
86   {
87     float*    A00       = buff_A + (0  )*cs_A + (0  )*rs_A;
88     float*    a10t      = buff_A + (0  )*cs_A + (i  )*rs_A;
89     float*    alpha11   = buff_A + (i  )*cs_A + (i  )*rs_A;
90 
91     int       mn_behind = i;
92 
93     /*------------------------------------------------------------*/
94 
95     // FLA_Trsv_external( FLA_LOWER_TRIANGULAR, FLA_CONJ_NO_TRANSPOSE, FLA_NONUNIT_DIAG, A00, a10t );
96     bl1_strsv( BLIS1_LOWER_TRIANGULAR,
97                BLIS1_CONJ_NO_TRANSPOSE,
98                BLIS1_NONUNIT_DIAG,
99                mn_behind,
100                A00, rs_A, cs_A,
101                a10t, cs_A );
102 
103     // FLA_Dotcs_external( FLA_CONJUGATE, FLA_MINUS_ONE, a10t, a10t, FLA_ONE, alpha11 );
104     bl1_sdots( BLIS1_CONJUGATE,
105                mn_behind,
106                buff_m1,
107                a10t, cs_A,
108                a10t, cs_A,
109                buff_1,
110                alpha11 );
111 
112     // r_val = FLA_Sqrt( alpha11 );
113     // if ( r_val != FLA_SUCCESS )
114     //   return ( FLA_Obj_length( A00 ) + 1 );
115     bl1_ssqrte( alpha11, &e_val );
116     if ( e_val != FLA_SUCCESS ) return mn_behind;
117 
118     /*------------------------------------------------------------*/
119 
120   }
121 
122   return FLA_SUCCESS;
123 }
124 
125 
126 
FLA_Chol_l_opd_var1(int mn_A,double * buff_A,int rs_A,int cs_A)127 FLA_Error FLA_Chol_l_opd_var1( int mn_A,
128                                double* buff_A, int rs_A, int cs_A )
129 {
130   double*   buff_1  = FLA_DOUBLE_PTR( FLA_ONE );
131   double*   buff_m1 = FLA_DOUBLE_PTR( FLA_MINUS_ONE );
132   int       i;
133   FLA_Error e_val;
134 
135   for ( i = 0; i < mn_A; ++i )
136   {
137     double*   A00       = buff_A + (0  )*cs_A + (0  )*rs_A;
138     double*   a10t      = buff_A + (0  )*cs_A + (i  )*rs_A;
139     double*   alpha11   = buff_A + (i  )*cs_A + (i  )*rs_A;
140 
141     int       mn_behind = i;
142 
143     /*------------------------------------------------------------*/
144 
145     // FLA_Trsv_external( FLA_LOWER_TRIANGULAR, FLA_CONJ_NO_TRANSPOSE, FLA_NONUNIT_DIAG, A00, a10t );
146     bl1_dtrsv( BLIS1_LOWER_TRIANGULAR,
147                BLIS1_CONJ_NO_TRANSPOSE,
148                BLIS1_NONUNIT_DIAG,
149                mn_behind,
150                A00, rs_A, cs_A,
151                a10t, cs_A );
152 
153     // FLA_Dotcs_external( FLA_CONJUGATE, FLA_MINUS_ONE, a10t, a10t, FLA_ONE, alpha11 );
154     bl1_ddots( BLIS1_CONJUGATE,
155                mn_behind,
156                buff_m1,
157                a10t, cs_A,
158                a10t, cs_A,
159                buff_1,
160                alpha11 );
161 
162     // r_val = FLA_Sqrt( alpha11 );
163     // if ( r_val != FLA_SUCCESS )
164     //   return ( FLA_Obj_length( A00 ) + 1 );
165     bl1_dsqrte( alpha11, &e_val );
166     if ( e_val != FLA_SUCCESS ) return mn_behind;
167 
168     /*------------------------------------------------------------*/
169 
170   }
171 
172   return FLA_SUCCESS;
173 }
174 
175 
176 
FLA_Chol_l_opc_var1(int mn_A,scomplex * buff_A,int rs_A,int cs_A)177 FLA_Error FLA_Chol_l_opc_var1( int mn_A,
178                                scomplex* buff_A, int rs_A, int cs_A )
179 {
180   scomplex* buff_1  = FLA_COMPLEX_PTR( FLA_ONE );
181   scomplex* buff_m1 = FLA_COMPLEX_PTR( FLA_MINUS_ONE );
182   int       i;
183   FLA_Error e_val;
184 
185   for ( i = 0; i < mn_A; ++i )
186   {
187     scomplex* A00       = buff_A + (0  )*cs_A + (0  )*rs_A;
188     scomplex* a10t      = buff_A + (0  )*cs_A + (i  )*rs_A;
189     scomplex* alpha11   = buff_A + (i  )*cs_A + (i  )*rs_A;
190 
191     int       mn_behind = i;
192 
193     /*------------------------------------------------------------*/
194 
195     // FLA_Trsv_external( FLA_LOWER_TRIANGULAR, FLA_CONJ_NO_TRANSPOSE, FLA_NONUNIT_DIAG, A00, a10t );
196     bl1_ctrsv( BLIS1_LOWER_TRIANGULAR,
197                BLIS1_CONJ_NO_TRANSPOSE,
198                BLIS1_NONUNIT_DIAG,
199                mn_behind,
200                A00, rs_A, cs_A,
201                a10t, cs_A );
202 
203     // FLA_Dotcs_external( FLA_CONJUGATE, FLA_MINUS_ONE, a10t, a10t, FLA_ONE, alpha11 );
204     bl1_cdots( BLIS1_CONJUGATE,
205                mn_behind,
206                buff_m1,
207                a10t, cs_A,
208                a10t, cs_A,
209                buff_1,
210                alpha11 );
211 
212     // r_val = FLA_Sqrt( alpha11 );
213     // if ( r_val != FLA_SUCCESS )
214     //   return ( FLA_Obj_length( A00 ) + 1 );
215     bl1_csqrte( alpha11, &e_val );
216     if ( e_val != FLA_SUCCESS ) return mn_behind;
217 
218     /*------------------------------------------------------------*/
219 
220   }
221 
222   return FLA_SUCCESS;
223 }
224 
225 
226 
FLA_Chol_l_opz_var1(int mn_A,dcomplex * buff_A,int rs_A,int cs_A)227 FLA_Error FLA_Chol_l_opz_var1( int mn_A,
228                                dcomplex* buff_A, int rs_A, int cs_A )
229 {
230   dcomplex* buff_1  = FLA_DOUBLE_COMPLEX_PTR( FLA_ONE );
231   dcomplex* buff_m1 = FLA_DOUBLE_COMPLEX_PTR( FLA_MINUS_ONE );
232   int       i;
233   FLA_Error e_val;
234 
235   for ( i = 0; i < mn_A; ++i )
236   {
237     dcomplex* A00       = buff_A + (0  )*cs_A + (0  )*rs_A;
238     dcomplex* a10t      = buff_A + (0  )*cs_A + (i  )*rs_A;
239     dcomplex* alpha11   = buff_A + (i  )*cs_A + (i  )*rs_A;
240 
241     int       mn_behind = i;
242 
243     /*------------------------------------------------------------*/
244 
245     // FLA_Trsv_external( FLA_LOWER_TRIANGULAR, FLA_CONJ_NO_TRANSPOSE, FLA_NONUNIT_DIAG, A00, a10t );
246     bl1_ztrsv( BLIS1_LOWER_TRIANGULAR,
247                BLIS1_CONJ_NO_TRANSPOSE,
248                BLIS1_NONUNIT_DIAG,
249                mn_behind,
250                A00, rs_A, cs_A,
251                a10t, cs_A );
252 
253     // FLA_Dotcs_external( FLA_CONJUGATE, FLA_MINUS_ONE, a10t, a10t, FLA_ONE, alpha11 );
254     bl1_zdots( BLIS1_CONJUGATE,
255                mn_behind,
256                buff_m1,
257                a10t, cs_A,
258                a10t, cs_A,
259                buff_1,
260                alpha11 );
261 
262     // r_val = FLA_Sqrt( alpha11 );
263     // if ( r_val != FLA_SUCCESS )
264     //   return ( FLA_Obj_length( A00 ) + 1 );
265     bl1_zsqrte( alpha11, &e_val );
266     if ( e_val != FLA_SUCCESS ) return mn_behind;
267 
268     /*------------------------------------------------------------*/
269 
270   }
271 
272   return FLA_SUCCESS;
273 }
274 
275 #endif
276