1 /*
2 
3     Copyright (C) 2014, The University of Texas at Austin
4 
5     This file is part of libflame and is available under the 3-Clause
6     BSD license, which can be found in the LICENSE file at the top-level
7     directory, or at http://opensource.org/licenses/BSD-3-Clause
8 
9 */
10 
11 #include "FLAME.h"
12 
FLA_Sylv_nn_opt_var1(FLA_Obj isgn,FLA_Obj A,FLA_Obj B,FLA_Obj C,FLA_Obj scale)13 FLA_Error FLA_Sylv_nn_opt_var1( FLA_Obj isgn, FLA_Obj A, FLA_Obj B, FLA_Obj C, FLA_Obj scale )
14 {
15   FLA_Datatype datatype;
16   int          m_C, n_C;
17   int          rs_A, cs_A;
18   int          rs_B, cs_B;
19   int          rs_C, cs_C;
20   int          info;
21 
22   datatype = FLA_Obj_datatype( A );
23 
24   rs_A     = FLA_Obj_row_stride( A );
25   cs_A     = FLA_Obj_col_stride( A );
26 
27   rs_B     = FLA_Obj_row_stride( B );
28   cs_B     = FLA_Obj_col_stride( B );
29 
30   m_C      = FLA_Obj_length( C );
31   n_C      = FLA_Obj_width( C );
32   rs_C     = FLA_Obj_row_stride( C );
33   cs_C     = FLA_Obj_col_stride( C );
34 
35 
36   switch ( datatype )
37   {
38     case FLA_FLOAT:
39     {
40       int*   buff_isgn  = FLA_INT_PTR( isgn );
41       float* buff_A     = FLA_FLOAT_PTR( A );
42       float* buff_B     = FLA_FLOAT_PTR( B );
43       float* buff_C     = FLA_FLOAT_PTR( C );
44       float* buff_scale = FLA_FLOAT_PTR( scale );
45       float  sgn        = ( float ) *buff_isgn;
46 
47       FLA_Sylv_nn_ops_var1( sgn,
48                             m_C,
49                             n_C,
50                             buff_A, rs_A, cs_A,
51                             buff_B, rs_B, cs_B,
52                             buff_C, rs_C, cs_C,
53                             buff_scale,
54                             &info );
55 
56       break;
57     }
58 
59     case FLA_DOUBLE:
60     {
61       int*    buff_isgn  = FLA_INT_PTR( isgn );
62       double* buff_A     = FLA_DOUBLE_PTR( A );
63       double* buff_B     = FLA_DOUBLE_PTR( B );
64       double* buff_C     = FLA_DOUBLE_PTR( C );
65       double* buff_scale = FLA_DOUBLE_PTR( scale );
66       double  sgn        = ( double ) *buff_isgn;
67 
68       FLA_Sylv_nn_opd_var1( sgn,
69                             m_C,
70                             n_C,
71                             buff_A, rs_A, cs_A,
72                             buff_B, rs_B, cs_B,
73                             buff_C, rs_C, cs_C,
74                             buff_scale,
75                             &info );
76 
77       break;
78     }
79 
80     case FLA_COMPLEX:
81     {
82       int*      buff_isgn  = FLA_INT_PTR( isgn );
83       scomplex* buff_A     = FLA_COMPLEX_PTR( A );
84       scomplex* buff_B     = FLA_COMPLEX_PTR( B );
85       scomplex* buff_C     = FLA_COMPLEX_PTR( C );
86       scomplex* buff_scale = FLA_COMPLEX_PTR( scale );
87       float     sgn        = ( float ) *buff_isgn;
88 
89       FLA_Sylv_nn_opc_var1( sgn,
90                             m_C,
91                             n_C,
92                             buff_A, rs_A, cs_A,
93                             buff_B, rs_B, cs_B,
94                             buff_C, rs_C, cs_C,
95                             buff_scale,
96                             &info );
97 
98       break;
99     }
100 
101     case FLA_DOUBLE_COMPLEX:
102     {
103       int*      buff_isgn  = FLA_INT_PTR( isgn );
104       dcomplex* buff_A     = FLA_DOUBLE_COMPLEX_PTR( A );
105       dcomplex* buff_B     = FLA_DOUBLE_COMPLEX_PTR( B );
106       dcomplex* buff_C     = FLA_DOUBLE_COMPLEX_PTR( C );
107       dcomplex* buff_scale = FLA_DOUBLE_COMPLEX_PTR( scale );
108       double    sgn        = ( double ) *buff_isgn;
109 
110       FLA_Sylv_nn_opz_var1( sgn,
111                             m_C,
112                             n_C,
113                             buff_A, rs_A, cs_A,
114                             buff_B, rs_B, cs_B,
115                             buff_C, rs_C, cs_C,
116                             buff_scale,
117                             &info );
118 
119       break;
120     }
121   }
122 
123   return FLA_SUCCESS;
124 }
125 
126 
127 
FLA_Sylv_nn_ops_var1(float sgn,int m_C,int n_C,float * buff_A,int rs_A,int cs_A,float * buff_B,int rs_B,int cs_B,float * buff_C,int rs_C,int cs_C,float * buff_scale,int * info)128 FLA_Error FLA_Sylv_nn_ops_var1( float sgn,
129                                 int m_C,
130                                 int n_C,
131                                 float* buff_A, int rs_A, int cs_A,
132                                 float* buff_B, int rs_B, int cs_B,
133                                 float* buff_C, int rs_C, int cs_C,
134                                 float* buff_scale,
135                                 int* info )
136 {
137   int l, k;
138 
139   for ( l = 0; l < n_C; l++ )
140   {
141     for ( k = m_C - 1; k >= 0; k-- )
142     {
143       float*    a12t     = buff_A + (k+1)*cs_A + (k  )*rs_A;
144       float*    b01      = buff_B + (l  )*cs_B + (0  )*rs_B;
145       float*    c10t     = buff_C + (0  )*cs_C + (k  )*rs_C;
146       float*    c21      = buff_C + (l  )*cs_C + (k+1)*rs_C;
147       float*    alpha11  = buff_A + (k  )*cs_A + (k  )*rs_A;
148       float*    beta11   = buff_B + (l  )*cs_B + (l  )*rs_B;
149       float*    ckl      = buff_C + (l  )*cs_C + (k  )*rs_C;
150       float     suml;
151       float     sumr;
152       float     vec;
153       float     a11;
154       float     x11;
155 
156       int       m_behind = m_C - k - 1;
157       int       n_behind = l;
158 
159       /*------------------------------------------------------------*/
160 
161       bl1_sdot( BLIS1_NO_CONJUGATE,
162                 m_behind,
163                 a12t, cs_A,
164                 c21, rs_C,
165                 &suml );
166 
167       bl1_sdot( BLIS1_NO_CONJUGATE,
168                 n_behind,
169                 c10t, cs_C,
170                 b01, rs_B,
171                 &sumr );
172 
173       vec = (*ckl) - ( suml + sgn * sumr );
174 
175       a11 = (*alpha11) + sgn * (*beta11);
176 
177       bl1_sdiv3( &vec, &a11, &x11 );
178 
179       *ckl = x11;
180 
181       /*------------------------------------------------------------*/
182 
183     }
184   }
185 
186   return FLA_SUCCESS;
187 }
188 
189 
190 
FLA_Sylv_nn_opd_var1(double sgn,int m_C,int n_C,double * buff_A,int rs_A,int cs_A,double * buff_B,int rs_B,int cs_B,double * buff_C,int rs_C,int cs_C,double * buff_scale,int * info)191 FLA_Error FLA_Sylv_nn_opd_var1( double sgn,
192                                 int m_C,
193                                 int n_C,
194                                 double* buff_A, int rs_A, int cs_A,
195                                 double* buff_B, int rs_B, int cs_B,
196                                 double* buff_C, int rs_C, int cs_C,
197                                 double* buff_scale,
198                                 int* info )
199 {
200   int l, k;
201 
202   for ( l = 0; l < n_C; l++ )
203   {
204     for ( k = m_C - 1; k >= 0; k-- )
205     {
206       double*   a12t     = buff_A + (k+1)*cs_A + (k  )*rs_A;
207       double*   b01      = buff_B + (l  )*cs_B + (0  )*rs_B;
208       double*   c10t     = buff_C + (0  )*cs_C + (k  )*rs_C;
209       double*   c21      = buff_C + (l  )*cs_C + (k+1)*rs_C;
210       double*   alpha11  = buff_A + (k  )*cs_A + (k  )*rs_A;
211       double*   beta11   = buff_B + (l  )*cs_B + (l  )*rs_B;
212       double*   ckl      = buff_C + (l  )*cs_C + (k  )*rs_C;
213       double    suml;
214       double    sumr;
215       double    vec;
216       double    a11;
217       double    x11;
218 
219       int       m_behind = m_C - k - 1;
220       int       n_behind = l;
221 
222       /*------------------------------------------------------------*/
223 
224       bl1_ddot( BLIS1_NO_CONJUGATE,
225                 m_behind,
226                 a12t, cs_A,
227                 c21, rs_C,
228                 &suml );
229 
230       bl1_ddot( BLIS1_NO_CONJUGATE,
231                 n_behind,
232                 c10t, cs_C,
233                 b01, rs_B,
234                 &sumr );
235 
236       vec = (*ckl) - ( suml + sgn * sumr );
237 
238       a11 = (*alpha11) + sgn * (*beta11);
239 
240       bl1_ddiv3( &vec, &a11, &x11 );
241 
242       *ckl = x11;
243 
244       /*------------------------------------------------------------*/
245 
246     }
247   }
248 
249   return FLA_SUCCESS;
250 }
251 
252 
253 
FLA_Sylv_nn_opc_var1(float sgn,int m_C,int n_C,scomplex * buff_A,int rs_A,int cs_A,scomplex * buff_B,int rs_B,int cs_B,scomplex * buff_C,int rs_C,int cs_C,scomplex * buff_scale,int * info)254 FLA_Error FLA_Sylv_nn_opc_var1( float sgn,
255                                 int m_C,
256                                 int n_C,
257                                 scomplex* buff_A, int rs_A, int cs_A,
258                                 scomplex* buff_B, int rs_B, int cs_B,
259                                 scomplex* buff_C, int rs_C, int cs_C,
260                                 scomplex* buff_scale,
261                                 int* info )
262 {
263   int l, k;
264 
265   for ( l = 0; l < n_C; l++ )
266   {
267     for ( k = m_C - 1; k >= 0; k-- )
268     {
269       scomplex* a12t     = buff_A + (k+1)*cs_A + (k  )*rs_A;
270       scomplex* b01      = buff_B + (l  )*cs_B + (0  )*rs_B;
271       scomplex* c10t     = buff_C + (0  )*cs_C + (k  )*rs_C;
272       scomplex* c21      = buff_C + (l  )*cs_C + (k+1)*rs_C;
273       scomplex* alpha11  = buff_A + (k  )*cs_A + (k  )*rs_A;
274       scomplex* beta11   = buff_B + (l  )*cs_B + (l  )*rs_B;
275       scomplex* ckl      = buff_C + (l  )*cs_C + (k  )*rs_C;
276       scomplex  suml;
277       scomplex  sumr;
278       scomplex  vec;
279       scomplex  a11;
280       scomplex  x11;
281 
282       int       m_behind = m_C - k - 1;
283       int       n_behind = l;
284 
285       /*------------------------------------------------------------*/
286 
287       bl1_cdot( BLIS1_NO_CONJUGATE,
288                 m_behind,
289                 a12t, cs_A,
290                 c21, rs_C,
291                 &suml );
292 
293       bl1_cdot( BLIS1_NO_CONJUGATE,
294                 n_behind,
295                 c10t, cs_C,
296                 b01, rs_B,
297                 &sumr );
298 
299       vec.real = ckl->real - ( suml.real + sgn * sumr.real );
300       vec.imag = ckl->imag - ( suml.imag + sgn * sumr.imag );
301 
302       a11.real = alpha11->real + sgn * beta11->real;
303       a11.imag = alpha11->imag + sgn * beta11->imag;
304 
305       bl1_cdiv3( &vec, &a11, &x11 );
306 
307       *ckl = x11;
308 
309       /*------------------------------------------------------------*/
310 
311     }
312   }
313 
314   return FLA_SUCCESS;
315 }
316 
317 
318 
FLA_Sylv_nn_opz_var1(double sgn,int m_C,int n_C,dcomplex * buff_A,int rs_A,int cs_A,dcomplex * buff_B,int rs_B,int cs_B,dcomplex * buff_C,int rs_C,int cs_C,dcomplex * buff_scale,int * info)319 FLA_Error FLA_Sylv_nn_opz_var1( double sgn,
320                                 int m_C,
321                                 int n_C,
322                                 dcomplex* buff_A, int rs_A, int cs_A,
323                                 dcomplex* buff_B, int rs_B, int cs_B,
324                                 dcomplex* buff_C, int rs_C, int cs_C,
325                                 dcomplex* buff_scale,
326                                 int* info )
327 {
328   int l, k;
329 
330   for ( l = 0; l < n_C; l++ )
331   {
332     for ( k = m_C - 1; k >= 0; k-- )
333     {
334       dcomplex* a12t     = buff_A + (k+1)*cs_A + (k  )*rs_A;
335       dcomplex* b01      = buff_B + (l  )*cs_B + (0  )*rs_B;
336       dcomplex* c10t     = buff_C + (0  )*cs_C + (k  )*rs_C;
337       dcomplex* c21      = buff_C + (l  )*cs_C + (k+1)*rs_C;
338       dcomplex* alpha11  = buff_A + (k  )*cs_A + (k  )*rs_A;
339       dcomplex* beta11   = buff_B + (l  )*cs_B + (l  )*rs_B;
340       dcomplex* ckl      = buff_C + (l  )*cs_C + (k  )*rs_C;
341       dcomplex  suml;
342       dcomplex  sumr;
343       dcomplex  vec;
344       dcomplex  a11;
345       dcomplex  x11;
346 
347       int       m_behind = m_C - k - 1;
348       int       n_behind = l;
349 
350       /*------------------------------------------------------------*/
351 
352       bl1_zdot( BLIS1_NO_CONJUGATE,
353                 m_behind,
354                 a12t, cs_A,
355                 c21, rs_C,
356                 &suml );
357 
358       bl1_zdot( BLIS1_NO_CONJUGATE,
359                 n_behind,
360                 c10t, cs_C,
361                 b01, rs_B,
362                 &sumr );
363 
364       vec.real = ckl->real - ( suml.real + sgn * sumr.real );
365       vec.imag = ckl->imag - ( suml.imag + sgn * sumr.imag );
366 
367       a11.real = alpha11->real + sgn * beta11->real;
368       a11.imag = alpha11->imag + sgn * beta11->imag;
369 
370       bl1_zdiv3( &vec, &a11, &x11 );
371 
372       *ckl = x11;
373 
374       /*------------------------------------------------------------*/
375 
376     }
377   }
378 
379   return FLA_SUCCESS;
380 }
381 
382