1 /*
2 
3     Copyright (C) 2014, The University of Texas at Austin
4 
5     This file is part of libflame and is available under the 3-Clause
6     BSD license, which can be found in the LICENSE file at the top-level
7     directory, or at http://opensource.org/licenses/BSD-3-Clause
8 
9 */
10 
11 #include "FLAME.h"
12 
13 #ifdef FLA_ENABLE_NON_CRITICAL_CODE
14 
FLA_LU_nopiv_opt_var2(FLA_Obj A)15 FLA_Error FLA_LU_nopiv_opt_var2( FLA_Obj A )
16 {
17   FLA_Datatype datatype;
18   int          m_A, n_A;
19   int          rs_A, cs_A;
20 
21   datatype = FLA_Obj_datatype( A );
22 
23   m_A      = FLA_Obj_length( A );
24   n_A      = FLA_Obj_width( A );
25   rs_A     = FLA_Obj_row_stride( A );
26   cs_A     = FLA_Obj_col_stride( A );
27 
28 
29   switch ( datatype )
30   {
31     case FLA_FLOAT:
32     {
33       float* buff_A = FLA_FLOAT_PTR( A );
34 
35       FLA_LU_nopiv_ops_var2( m_A,
36                              n_A,
37                              buff_A, rs_A, cs_A );
38 
39       break;
40     }
41 
42     case FLA_DOUBLE:
43     {
44       double* buff_A = FLA_DOUBLE_PTR( A );
45 
46       FLA_LU_nopiv_opd_var2( m_A,
47                              n_A,
48                              buff_A, rs_A, cs_A );
49 
50       break;
51     }
52 
53     case FLA_COMPLEX:
54     {
55       scomplex* buff_A = FLA_COMPLEX_PTR( A );
56 
57       FLA_LU_nopiv_opc_var2( m_A,
58                              n_A,
59                              buff_A, rs_A, cs_A );
60 
61       break;
62     }
63 
64     case FLA_DOUBLE_COMPLEX:
65     {
66       dcomplex* buff_A = FLA_DOUBLE_COMPLEX_PTR( A );
67 
68       FLA_LU_nopiv_opz_var2( m_A,
69                              n_A,
70                              buff_A, rs_A, cs_A );
71 
72       break;
73     }
74   }
75 
76   return FLA_SUCCESS;
77 }
78 
79 
80 
FLA_LU_nopiv_ops_var2(int m_A,int n_A,float * buff_A,int rs_A,int cs_A)81 FLA_Error FLA_LU_nopiv_ops_var2( int m_A,
82                                  int n_A,
83                                  float* buff_A, int rs_A, int cs_A )
84 {
85   float*    buff_1  = FLA_FLOAT_PTR( FLA_ONE );
86   float*    buff_m1 = FLA_FLOAT_PTR( FLA_MINUS_ONE );
87   int       min_m_n = min( m_A, n_A );
88   int       i;
89 
90   for ( i = 0; i < min_m_n; ++i )
91   {
92     float*    A00       = buff_A + (0  )*cs_A + (0  )*rs_A;
93     float*    a10t      = buff_A + (0  )*cs_A + (i  )*rs_A;
94     float*    a01       = buff_A + (i  )*cs_A + (0  )*rs_A;
95     float*    alpha11   = buff_A + (i  )*cs_A + (i  )*rs_A;
96     float*    A02       = buff_A + (i+1)*cs_A + (0  )*rs_A;
97     float*    a12t      = buff_A + (i+1)*cs_A + (i  )*rs_A;
98 
99     int       n_ahead   = n_A - i - 1;
100     int       mn_behind = i;
101 
102     /*------------------------------------------------------------*/
103 
104     // FLA_Trsv_external( FLA_UPPER_TRIANGULAR, FLA_TRANSPOSE, FLA_NONUNIT_DIAG, A00, a10t );
105     bl1_strsv( BLIS1_UPPER_TRIANGULAR,
106                BLIS1_TRANSPOSE,
107                BLIS1_NONUNIT_DIAG,
108                mn_behind,
109                A00, rs_A, cs_A,
110                a10t, cs_A );
111 
112     // FLA_Dots_external( FLA_MINUS_ONE, a10t, a01, FLA_ONE, alpha11 );
113     bl1_sdots( BLIS1_NO_CONJUGATE,
114                mn_behind,
115                buff_m1,
116                a10t, cs_A,
117                a01, rs_A,
118                buff_1,
119                alpha11 );
120 
121     // FLA_Gemv_external( FLA_TRANSPOSE, FLA_MINUS_ONE, A02, a10t, FLA_ONE, a12t );
122     bl1_sgemv( BLIS1_TRANSPOSE,
123                BLIS1_NO_CONJUGATE,
124                mn_behind,
125                n_ahead,
126                buff_m1,
127                A02, rs_A, cs_A,
128                a10t, cs_A,
129                buff_1,
130                a12t, cs_A );
131 
132     /*------------------------------------------------------------*/
133 
134   }
135 
136   if ( m_A > n_A )
137   {
138     float*    ATL = buff_A;
139     float*    ABL = buff_A + n_A*rs_A;
140 
141     // FLA_Trsm_external( FLA_RIGHT, FLA_UPPER_TRIANGULAR,
142     //                    FLA_NO_TRANSPOSE, FLA_NONUNIT_DIAG,
143     //                    FLA_ONE, ATL, ABL );
144     bl1_strsm( BLIS1_RIGHT,
145                BLIS1_UPPER_TRIANGULAR,
146                BLIS1_NO_TRANSPOSE,
147                BLIS1_NONUNIT_DIAG,
148                m_A - n_A,
149                n_A,
150                buff_1,
151                ATL, rs_A, cs_A,
152                ABL, rs_A, cs_A );
153   }
154 
155   return FLA_SUCCESS;
156 }
157 
158 
159 
FLA_LU_nopiv_opd_var2(int m_A,int n_A,double * buff_A,int rs_A,int cs_A)160 FLA_Error FLA_LU_nopiv_opd_var2( int m_A,
161                                  int n_A,
162                                  double* buff_A, int rs_A, int cs_A )
163 {
164   double*   buff_1  = FLA_DOUBLE_PTR( FLA_ONE );
165   double*   buff_m1 = FLA_DOUBLE_PTR( FLA_MINUS_ONE );
166   int       min_m_n = min( m_A, n_A );
167   int       i;
168 
169   for ( i = 0; i < min_m_n; ++i )
170   {
171     double*   A00       = buff_A + (0  )*cs_A + (0  )*rs_A;
172     double*   a10t      = buff_A + (0  )*cs_A + (i  )*rs_A;
173     double*   a01       = buff_A + (i  )*cs_A + (0  )*rs_A;
174     double*   alpha11   = buff_A + (i  )*cs_A + (i  )*rs_A;
175     double*   A02       = buff_A + (i+1)*cs_A + (0  )*rs_A;
176     double*   a12t      = buff_A + (i+1)*cs_A + (i  )*rs_A;
177 
178     int       n_ahead   = n_A - i - 1;
179     int       mn_behind = i;
180 
181     /*------------------------------------------------------------*/
182 
183     // FLA_Trsv_external( FLA_UPPER_TRIANGULAR, FLA_TRANSPOSE, FLA_NONUNIT_DIAG, A00, a10t );
184     bl1_dtrsv( BLIS1_UPPER_TRIANGULAR,
185                BLIS1_TRANSPOSE,
186                BLIS1_NONUNIT_DIAG,
187                mn_behind,
188                A00, rs_A, cs_A,
189                a10t, cs_A );
190 
191     // FLA_Dots_external( FLA_MINUS_ONE, a10t, a01, FLA_ONE, alpha11 );
192     bl1_ddots( BLIS1_NO_CONJUGATE,
193                mn_behind,
194                buff_m1,
195                a10t, cs_A,
196                a01, rs_A,
197                buff_1,
198                alpha11 );
199 
200     // FLA_Gemv_external( FLA_TRANSPOSE, FLA_MINUS_ONE, A02, a10t, FLA_ONE, a12t );
201     bl1_dgemv( BLIS1_TRANSPOSE,
202                BLIS1_NO_CONJUGATE,
203                mn_behind,
204                n_ahead,
205                buff_m1,
206                A02, rs_A, cs_A,
207                a10t, cs_A,
208                buff_1,
209                a12t, cs_A );
210 
211     /*------------------------------------------------------------*/
212 
213   }
214 
215   if ( m_A > n_A )
216   {
217     double*   ATL = buff_A;
218     double*   ABL = buff_A + n_A*rs_A;
219 
220     // FLA_Trsm_external( FLA_RIGHT, FLA_UPPER_TRIANGULAR,
221     //                    FLA_NO_TRANSPOSE, FLA_NONUNIT_DIAG,
222     //                    FLA_ONE, ATL, ABL );
223     bl1_dtrsm( BLIS1_RIGHT,
224                BLIS1_UPPER_TRIANGULAR,
225                BLIS1_NO_TRANSPOSE,
226                BLIS1_NONUNIT_DIAG,
227                m_A - n_A,
228                n_A,
229                buff_1,
230                ATL, rs_A, cs_A,
231                ABL, rs_A, cs_A );
232   }
233 
234   return FLA_SUCCESS;
235 }
236 
237 
238 
FLA_LU_nopiv_opc_var2(int m_A,int n_A,scomplex * buff_A,int rs_A,int cs_A)239 FLA_Error FLA_LU_nopiv_opc_var2( int m_A,
240                                  int n_A,
241                                  scomplex* buff_A, int rs_A, int cs_A )
242 {
243   scomplex* buff_1  = FLA_COMPLEX_PTR( FLA_ONE );
244   scomplex* buff_m1 = FLA_COMPLEX_PTR( FLA_MINUS_ONE );
245   int       min_m_n = min( m_A, n_A );
246   int       i;
247 
248   for ( i = 0; i < min_m_n; ++i )
249   {
250     scomplex* A00       = buff_A + (0  )*cs_A + (0  )*rs_A;
251     scomplex* a10t      = buff_A + (0  )*cs_A + (i  )*rs_A;
252     scomplex* a01       = buff_A + (i  )*cs_A + (0  )*rs_A;
253     scomplex* alpha11   = buff_A + (i  )*cs_A + (i  )*rs_A;
254     scomplex* A02       = buff_A + (i+1)*cs_A + (0  )*rs_A;
255     scomplex* a12t      = buff_A + (i+1)*cs_A + (i  )*rs_A;
256 
257     int       n_ahead   = n_A - i - 1;
258     int       mn_behind = i;
259 
260     /*------------------------------------------------------------*/
261 
262     // FLA_Trsv_external( FLA_UPPER_TRIANGULAR, FLA_TRANSPOSE, FLA_NONUNIT_DIAG, A00, a10t );
263     bl1_ctrsv( BLIS1_UPPER_TRIANGULAR,
264                BLIS1_TRANSPOSE,
265                BLIS1_NONUNIT_DIAG,
266                mn_behind,
267                A00, rs_A, cs_A,
268                a10t, cs_A );
269 
270     // FLA_Dots_external( FLA_MINUS_ONE, a10t, a01, FLA_ONE, alpha11 );
271     bl1_cdots( BLIS1_NO_CONJUGATE,
272                mn_behind,
273                buff_m1,
274                a10t, cs_A,
275                a01, rs_A,
276                buff_1,
277                alpha11 );
278 
279     // FLA_Gemv_external( FLA_TRANSPOSE, FLA_MINUS_ONE, A02, a10t, FLA_ONE, a12t );
280     bl1_cgemv( BLIS1_TRANSPOSE,
281                BLIS1_NO_CONJUGATE,
282                mn_behind,
283                n_ahead,
284                buff_m1,
285                A02, rs_A, cs_A,
286                a10t, cs_A,
287                buff_1,
288                a12t, cs_A );
289 
290     /*------------------------------------------------------------*/
291 
292   }
293 
294   if ( m_A > n_A )
295   {
296     scomplex* ATL = buff_A;
297     scomplex* ABL = buff_A + n_A*rs_A;
298 
299     // FLA_Trsm_external( FLA_RIGHT, FLA_UPPER_TRIANGULAR,
300     //                    FLA_NO_TRANSPOSE, FLA_NONUNIT_DIAG,
301     //                    FLA_ONE, ATL, ABL );
302     bl1_ctrsm( BLIS1_RIGHT,
303                BLIS1_UPPER_TRIANGULAR,
304                BLIS1_NO_TRANSPOSE,
305                BLIS1_NONUNIT_DIAG,
306                m_A - n_A,
307                n_A,
308                buff_1,
309                ATL, rs_A, cs_A,
310                ABL, rs_A, cs_A );
311   }
312 
313   return FLA_SUCCESS;
314 }
315 
316 
317 
FLA_LU_nopiv_opz_var2(int m_A,int n_A,dcomplex * buff_A,int rs_A,int cs_A)318 FLA_Error FLA_LU_nopiv_opz_var2( int m_A,
319                                  int n_A,
320                                  dcomplex* buff_A, int rs_A, int cs_A )
321 {
322   dcomplex* buff_1  = FLA_DOUBLE_COMPLEX_PTR( FLA_ONE );
323   dcomplex* buff_m1 = FLA_DOUBLE_COMPLEX_PTR( FLA_MINUS_ONE );
324   int       min_m_n = min( m_A, n_A );
325   int       i;
326 
327   for ( i = 0; i < min_m_n; ++i )
328   {
329     dcomplex* A00       = buff_A + (0  )*cs_A + (0  )*rs_A;
330     dcomplex* a10t      = buff_A + (0  )*cs_A + (i  )*rs_A;
331     dcomplex* a01       = buff_A + (i  )*cs_A + (0  )*rs_A;
332     dcomplex* alpha11   = buff_A + (i  )*cs_A + (i  )*rs_A;
333     dcomplex* A02       = buff_A + (i+1)*cs_A + (0  )*rs_A;
334     dcomplex* a12t      = buff_A + (i+1)*cs_A + (i  )*rs_A;
335 
336     int       n_ahead   = n_A - i - 1;
337     int       mn_behind = i;
338 
339     /*------------------------------------------------------------*/
340 
341     // FLA_Trsv_external( FLA_UPPER_TRIANGULAR, FLA_TRANSPOSE, FLA_NONUNIT_DIAG, A00, a10t );
342     bl1_ztrsv( BLIS1_UPPER_TRIANGULAR,
343                BLIS1_TRANSPOSE,
344                BLIS1_NONUNIT_DIAG,
345                mn_behind,
346                A00, rs_A, cs_A,
347                a10t, cs_A );
348 
349     // FLA_Dots_external( FLA_MINUS_ONE, a10t, a01, FLA_ONE, alpha11 );
350     bl1_zdots( BLIS1_NO_CONJUGATE,
351                mn_behind,
352                buff_m1,
353                a10t, cs_A,
354                a01, rs_A,
355                buff_1,
356                alpha11 );
357 
358     // FLA_Gemv_external( FLA_TRANSPOSE, FLA_MINUS_ONE, A02, a10t, FLA_ONE, a12t );
359     bl1_zgemv( BLIS1_TRANSPOSE,
360                BLIS1_NO_CONJUGATE,
361                mn_behind,
362                n_ahead,
363                buff_m1,
364                A02, rs_A, cs_A,
365                a10t, cs_A,
366                buff_1,
367                a12t, cs_A );
368 
369     /*------------------------------------------------------------*/
370 
371   }
372 
373   if ( m_A > n_A )
374   {
375     dcomplex* ATL = buff_A;
376     dcomplex* ABL = buff_A + n_A*rs_A;
377 
378     // FLA_Trsm_external( FLA_RIGHT, FLA_UPPER_TRIANGULAR,
379     //                    FLA_NO_TRANSPOSE, FLA_NONUNIT_DIAG,
380     //                    FLA_ONE, ATL, ABL );
381     bl1_ztrsm( BLIS1_RIGHT,
382                BLIS1_UPPER_TRIANGULAR,
383                BLIS1_NO_TRANSPOSE,
384                BLIS1_NONUNIT_DIAG,
385                m_A - n_A,
386                n_A,
387                buff_1,
388                ATL, rs_A, cs_A,
389                ABL, rs_A, cs_A );
390   }
391 
392   return FLA_SUCCESS;
393 }
394 
395 #endif
396