/*

   Copyright (C) 2014, The University of Texas at Austin

   This file is part of libflame and is available under the 3-Clause
   BSD license, which can be found in the LICENSE file at the top-level
   directory, or at http://opensource.org/licenses/BSD-3-Clause

*/
10
11 #include "FLAME.h"
12
FLA_Sylv_nn_opt_var1(FLA_Obj isgn,FLA_Obj A,FLA_Obj B,FLA_Obj C,FLA_Obj scale)13 FLA_Error FLA_Sylv_nn_opt_var1( FLA_Obj isgn, FLA_Obj A, FLA_Obj B, FLA_Obj C, FLA_Obj scale )
14 {
15 FLA_Datatype datatype;
16 int m_C, n_C;
17 int rs_A, cs_A;
18 int rs_B, cs_B;
19 int rs_C, cs_C;
20 int info;
21
22 datatype = FLA_Obj_datatype( A );
23
24 rs_A = FLA_Obj_row_stride( A );
25 cs_A = FLA_Obj_col_stride( A );
26
27 rs_B = FLA_Obj_row_stride( B );
28 cs_B = FLA_Obj_col_stride( B );
29
30 m_C = FLA_Obj_length( C );
31 n_C = FLA_Obj_width( C );
32 rs_C = FLA_Obj_row_stride( C );
33 cs_C = FLA_Obj_col_stride( C );
34
35
36 switch ( datatype )
37 {
38 case FLA_FLOAT:
39 {
40 int* buff_isgn = FLA_INT_PTR( isgn );
41 float* buff_A = FLA_FLOAT_PTR( A );
42 float* buff_B = FLA_FLOAT_PTR( B );
43 float* buff_C = FLA_FLOAT_PTR( C );
44 float* buff_scale = FLA_FLOAT_PTR( scale );
45 float sgn = ( float ) *buff_isgn;
46
47 FLA_Sylv_nn_ops_var1( sgn,
48 m_C,
49 n_C,
50 buff_A, rs_A, cs_A,
51 buff_B, rs_B, cs_B,
52 buff_C, rs_C, cs_C,
53 buff_scale,
54 &info );
55
56 break;
57 }
58
59 case FLA_DOUBLE:
60 {
61 int* buff_isgn = FLA_INT_PTR( isgn );
62 double* buff_A = FLA_DOUBLE_PTR( A );
63 double* buff_B = FLA_DOUBLE_PTR( B );
64 double* buff_C = FLA_DOUBLE_PTR( C );
65 double* buff_scale = FLA_DOUBLE_PTR( scale );
66 double sgn = ( double ) *buff_isgn;
67
68 FLA_Sylv_nn_opd_var1( sgn,
69 m_C,
70 n_C,
71 buff_A, rs_A, cs_A,
72 buff_B, rs_B, cs_B,
73 buff_C, rs_C, cs_C,
74 buff_scale,
75 &info );
76
77 break;
78 }
79
80 case FLA_COMPLEX:
81 {
82 int* buff_isgn = FLA_INT_PTR( isgn );
83 scomplex* buff_A = FLA_COMPLEX_PTR( A );
84 scomplex* buff_B = FLA_COMPLEX_PTR( B );
85 scomplex* buff_C = FLA_COMPLEX_PTR( C );
86 scomplex* buff_scale = FLA_COMPLEX_PTR( scale );
87 float sgn = ( float ) *buff_isgn;
88
89 FLA_Sylv_nn_opc_var1( sgn,
90 m_C,
91 n_C,
92 buff_A, rs_A, cs_A,
93 buff_B, rs_B, cs_B,
94 buff_C, rs_C, cs_C,
95 buff_scale,
96 &info );
97
98 break;
99 }
100
101 case FLA_DOUBLE_COMPLEX:
102 {
103 int* buff_isgn = FLA_INT_PTR( isgn );
104 dcomplex* buff_A = FLA_DOUBLE_COMPLEX_PTR( A );
105 dcomplex* buff_B = FLA_DOUBLE_COMPLEX_PTR( B );
106 dcomplex* buff_C = FLA_DOUBLE_COMPLEX_PTR( C );
107 dcomplex* buff_scale = FLA_DOUBLE_COMPLEX_PTR( scale );
108 double sgn = ( double ) *buff_isgn;
109
110 FLA_Sylv_nn_opz_var1( sgn,
111 m_C,
112 n_C,
113 buff_A, rs_A, cs_A,
114 buff_B, rs_B, cs_B,
115 buff_C, rs_C, cs_C,
116 buff_scale,
117 &info );
118
119 break;
120 }
121 }
122
123 return FLA_SUCCESS;
124 }
125
126
127
FLA_Sylv_nn_ops_var1(float sgn,int m_C,int n_C,float * buff_A,int rs_A,int cs_A,float * buff_B,int rs_B,int cs_B,float * buff_C,int rs_C,int cs_C,float * buff_scale,int * info)128 FLA_Error FLA_Sylv_nn_ops_var1( float sgn,
129 int m_C,
130 int n_C,
131 float* buff_A, int rs_A, int cs_A,
132 float* buff_B, int rs_B, int cs_B,
133 float* buff_C, int rs_C, int cs_C,
134 float* buff_scale,
135 int* info )
136 {
137 int l, k;
138
139 for ( l = 0; l < n_C; l++ )
140 {
141 for ( k = m_C - 1; k >= 0; k-- )
142 {
143 float* a12t = buff_A + (k+1)*cs_A + (k )*rs_A;
144 float* b01 = buff_B + (l )*cs_B + (0 )*rs_B;
145 float* c10t = buff_C + (0 )*cs_C + (k )*rs_C;
146 float* c21 = buff_C + (l )*cs_C + (k+1)*rs_C;
147 float* alpha11 = buff_A + (k )*cs_A + (k )*rs_A;
148 float* beta11 = buff_B + (l )*cs_B + (l )*rs_B;
149 float* ckl = buff_C + (l )*cs_C + (k )*rs_C;
150 float suml;
151 float sumr;
152 float vec;
153 float a11;
154 float x11;
155
156 int m_behind = m_C - k - 1;
157 int n_behind = l;
158
159 /*------------------------------------------------------------*/
160
161 bl1_sdot( BLIS1_NO_CONJUGATE,
162 m_behind,
163 a12t, cs_A,
164 c21, rs_C,
165 &suml );
166
167 bl1_sdot( BLIS1_NO_CONJUGATE,
168 n_behind,
169 c10t, cs_C,
170 b01, rs_B,
171 &sumr );
172
173 vec = (*ckl) - ( suml + sgn * sumr );
174
175 a11 = (*alpha11) + sgn * (*beta11);
176
177 bl1_sdiv3( &vec, &a11, &x11 );
178
179 *ckl = x11;
180
181 /*------------------------------------------------------------*/
182
183 }
184 }
185
186 return FLA_SUCCESS;
187 }
188
189
190
FLA_Sylv_nn_opd_var1(double sgn,int m_C,int n_C,double * buff_A,int rs_A,int cs_A,double * buff_B,int rs_B,int cs_B,double * buff_C,int rs_C,int cs_C,double * buff_scale,int * info)191 FLA_Error FLA_Sylv_nn_opd_var1( double sgn,
192 int m_C,
193 int n_C,
194 double* buff_A, int rs_A, int cs_A,
195 double* buff_B, int rs_B, int cs_B,
196 double* buff_C, int rs_C, int cs_C,
197 double* buff_scale,
198 int* info )
199 {
200 int l, k;
201
202 for ( l = 0; l < n_C; l++ )
203 {
204 for ( k = m_C - 1; k >= 0; k-- )
205 {
206 double* a12t = buff_A + (k+1)*cs_A + (k )*rs_A;
207 double* b01 = buff_B + (l )*cs_B + (0 )*rs_B;
208 double* c10t = buff_C + (0 )*cs_C + (k )*rs_C;
209 double* c21 = buff_C + (l )*cs_C + (k+1)*rs_C;
210 double* alpha11 = buff_A + (k )*cs_A + (k )*rs_A;
211 double* beta11 = buff_B + (l )*cs_B + (l )*rs_B;
212 double* ckl = buff_C + (l )*cs_C + (k )*rs_C;
213 double suml;
214 double sumr;
215 double vec;
216 double a11;
217 double x11;
218
219 int m_behind = m_C - k - 1;
220 int n_behind = l;
221
222 /*------------------------------------------------------------*/
223
224 bl1_ddot( BLIS1_NO_CONJUGATE,
225 m_behind,
226 a12t, cs_A,
227 c21, rs_C,
228 &suml );
229
230 bl1_ddot( BLIS1_NO_CONJUGATE,
231 n_behind,
232 c10t, cs_C,
233 b01, rs_B,
234 &sumr );
235
236 vec = (*ckl) - ( suml + sgn * sumr );
237
238 a11 = (*alpha11) + sgn * (*beta11);
239
240 bl1_ddiv3( &vec, &a11, &x11 );
241
242 *ckl = x11;
243
244 /*------------------------------------------------------------*/
245
246 }
247 }
248
249 return FLA_SUCCESS;
250 }
251
252
253
FLA_Sylv_nn_opc_var1(float sgn,int m_C,int n_C,scomplex * buff_A,int rs_A,int cs_A,scomplex * buff_B,int rs_B,int cs_B,scomplex * buff_C,int rs_C,int cs_C,scomplex * buff_scale,int * info)254 FLA_Error FLA_Sylv_nn_opc_var1( float sgn,
255 int m_C,
256 int n_C,
257 scomplex* buff_A, int rs_A, int cs_A,
258 scomplex* buff_B, int rs_B, int cs_B,
259 scomplex* buff_C, int rs_C, int cs_C,
260 scomplex* buff_scale,
261 int* info )
262 {
263 int l, k;
264
265 for ( l = 0; l < n_C; l++ )
266 {
267 for ( k = m_C - 1; k >= 0; k-- )
268 {
269 scomplex* a12t = buff_A + (k+1)*cs_A + (k )*rs_A;
270 scomplex* b01 = buff_B + (l )*cs_B + (0 )*rs_B;
271 scomplex* c10t = buff_C + (0 )*cs_C + (k )*rs_C;
272 scomplex* c21 = buff_C + (l )*cs_C + (k+1)*rs_C;
273 scomplex* alpha11 = buff_A + (k )*cs_A + (k )*rs_A;
274 scomplex* beta11 = buff_B + (l )*cs_B + (l )*rs_B;
275 scomplex* ckl = buff_C + (l )*cs_C + (k )*rs_C;
276 scomplex suml;
277 scomplex sumr;
278 scomplex vec;
279 scomplex a11;
280 scomplex x11;
281
282 int m_behind = m_C - k - 1;
283 int n_behind = l;
284
285 /*------------------------------------------------------------*/
286
287 bl1_cdot( BLIS1_NO_CONJUGATE,
288 m_behind,
289 a12t, cs_A,
290 c21, rs_C,
291 &suml );
292
293 bl1_cdot( BLIS1_NO_CONJUGATE,
294 n_behind,
295 c10t, cs_C,
296 b01, rs_B,
297 &sumr );
298
299 vec.real = ckl->real - ( suml.real + sgn * sumr.real );
300 vec.imag = ckl->imag - ( suml.imag + sgn * sumr.imag );
301
302 a11.real = alpha11->real + sgn * beta11->real;
303 a11.imag = alpha11->imag + sgn * beta11->imag;
304
305 bl1_cdiv3( &vec, &a11, &x11 );
306
307 *ckl = x11;
308
309 /*------------------------------------------------------------*/
310
311 }
312 }
313
314 return FLA_SUCCESS;
315 }
316
317
318
FLA_Sylv_nn_opz_var1(double sgn,int m_C,int n_C,dcomplex * buff_A,int rs_A,int cs_A,dcomplex * buff_B,int rs_B,int cs_B,dcomplex * buff_C,int rs_C,int cs_C,dcomplex * buff_scale,int * info)319 FLA_Error FLA_Sylv_nn_opz_var1( double sgn,
320 int m_C,
321 int n_C,
322 dcomplex* buff_A, int rs_A, int cs_A,
323 dcomplex* buff_B, int rs_B, int cs_B,
324 dcomplex* buff_C, int rs_C, int cs_C,
325 dcomplex* buff_scale,
326 int* info )
327 {
328 int l, k;
329
330 for ( l = 0; l < n_C; l++ )
331 {
332 for ( k = m_C - 1; k >= 0; k-- )
333 {
334 dcomplex* a12t = buff_A + (k+1)*cs_A + (k )*rs_A;
335 dcomplex* b01 = buff_B + (l )*cs_B + (0 )*rs_B;
336 dcomplex* c10t = buff_C + (0 )*cs_C + (k )*rs_C;
337 dcomplex* c21 = buff_C + (l )*cs_C + (k+1)*rs_C;
338 dcomplex* alpha11 = buff_A + (k )*cs_A + (k )*rs_A;
339 dcomplex* beta11 = buff_B + (l )*cs_B + (l )*rs_B;
340 dcomplex* ckl = buff_C + (l )*cs_C + (k )*rs_C;
341 dcomplex suml;
342 dcomplex sumr;
343 dcomplex vec;
344 dcomplex a11;
345 dcomplex x11;
346
347 int m_behind = m_C - k - 1;
348 int n_behind = l;
349
350 /*------------------------------------------------------------*/
351
352 bl1_zdot( BLIS1_NO_CONJUGATE,
353 m_behind,
354 a12t, cs_A,
355 c21, rs_C,
356 &suml );
357
358 bl1_zdot( BLIS1_NO_CONJUGATE,
359 n_behind,
360 c10t, cs_C,
361 b01, rs_B,
362 &sumr );
363
364 vec.real = ckl->real - ( suml.real + sgn * sumr.real );
365 vec.imag = ckl->imag - ( suml.imag + sgn * sumr.imag );
366
367 a11.real = alpha11->real + sgn * beta11->real;
368 a11.imag = alpha11->imag + sgn * beta11->imag;
369
370 bl1_zdiv3( &vec, &a11, &x11 );
371
372 *ckl = x11;
373
374 /*------------------------------------------------------------*/
375
376 }
377 }
378
379 return FLA_SUCCESS;
380 }
381
382