1
2#include "FLAME.h"
3
4FLA_Error FLA_Fused_Gerc2_opt_var1( FLA_Obj alpha, FLA_Obj u, FLA_Obj y, FLA_Obj z, FLA_Obj v, FLA_Obj A )
5{
6/*
7   Effective computation:
8   A = A + alpha * ( u * y' + z * v' );
9*/
10  FLA_Datatype datatype;
11  int          m_A, n_A;
12  int          rs_A, cs_A;
13  int          inc_u, inc_y, inc_z, inc_v;
14
15  datatype = FLA_Obj_datatype( A );
16
17  m_A      = FLA_Obj_length( A );
18  n_A      = FLA_Obj_width( A );
19
20  rs_A     = FLA_Obj_row_stride( A );
21  cs_A     = FLA_Obj_col_stride( A );
22
23  inc_u    = FLA_Obj_vector_inc( u );
24  inc_y    = FLA_Obj_vector_inc( y );
25  inc_z    = FLA_Obj_vector_inc( z );
26  inc_v    = FLA_Obj_vector_inc( v );
27
28
29  switch ( datatype )
30  {
31    case FLA_FLOAT:
32    {
33      float* buff_A = FLA_FLOAT_PTR( A );
34      float* buff_u = FLA_FLOAT_PTR( u );
35      float* buff_y = FLA_FLOAT_PTR( y );
36      float* buff_z = FLA_FLOAT_PTR( z );
37      float* buff_v = FLA_FLOAT_PTR( v );
38      float* buff_alpha = FLA_FLOAT_PTR( alpha );
39
40      FLA_Fused_Gerc2_ops_var1( m_A,
41                                n_A,
42                                buff_alpha,
43                                buff_u, inc_u,
44                                buff_y, inc_y,
45                                buff_z, inc_z,
46                                buff_v, inc_v,
47                                buff_A, rs_A, cs_A );
48
49      break;
50    }
51
52    case FLA_DOUBLE:
53    {
54      double* buff_A = FLA_DOUBLE_PTR( A );
55      double* buff_u = FLA_DOUBLE_PTR( u );
56      double* buff_y = FLA_DOUBLE_PTR( y );
57      double* buff_z = FLA_DOUBLE_PTR( z );
58      double* buff_v = FLA_DOUBLE_PTR( v );
59      double* buff_alpha = FLA_DOUBLE_PTR( alpha );
60
61      FLA_Fused_Gerc2_opd_var1( m_A,
62                                n_A,
63                                buff_alpha,
64                                buff_u, inc_u,
65                                buff_y, inc_y,
66                                buff_z, inc_z,
67                                buff_v, inc_v,
68                                buff_A, rs_A, cs_A );
69
70      break;
71    }
72
73    case FLA_COMPLEX:
74    {
75      scomplex* buff_A = FLA_COMPLEX_PTR( A );
76      scomplex* buff_u = FLA_COMPLEX_PTR( u );
77      scomplex* buff_y = FLA_COMPLEX_PTR( y );
78      scomplex* buff_z = FLA_COMPLEX_PTR( z );
79      scomplex* buff_v = FLA_COMPLEX_PTR( v );
80      scomplex* buff_alpha = FLA_COMPLEX_PTR( alpha );
81
82      FLA_Fused_Gerc2_opc_var1( m_A,
83                                n_A,
84                                buff_alpha,
85                                buff_u, inc_u,
86                                buff_y, inc_y,
87                                buff_z, inc_z,
88                                buff_v, inc_v,
89                                buff_A, rs_A, cs_A );
90
91      break;
92    }
93
94    case FLA_DOUBLE_COMPLEX:
95    {
96      dcomplex* buff_A = FLA_DOUBLE_COMPLEX_PTR( A );
97      dcomplex* buff_u = FLA_DOUBLE_COMPLEX_PTR( u );
98      dcomplex* buff_y = FLA_DOUBLE_COMPLEX_PTR( y );
99      dcomplex* buff_z = FLA_DOUBLE_COMPLEX_PTR( z );
100      dcomplex* buff_v = FLA_DOUBLE_COMPLEX_PTR( v );
101      dcomplex* buff_alpha = FLA_DOUBLE_COMPLEX_PTR( alpha );
102
103      FLA_Fused_Gerc2_opz_var1( m_A,
104                                n_A,
105                                buff_alpha,
106                                buff_u, inc_u,
107                                buff_y, inc_y,
108                                buff_z, inc_z,
109                                buff_v, inc_v,
110                                buff_A, rs_A, cs_A );
111
112      break;
113    }
114  }
115
116  return FLA_SUCCESS;
117}
118
119
120
121FLA_Error FLA_Fused_Gerc2_ops_var1( int m_A,
122                                    int n_A,
123                                    float* buff_alpha,
124                                    float* buff_u, int inc_u,
125                                    float* buff_y, int inc_y,
126                                    float* buff_z, int inc_z,
127                                    float* buff_v, int inc_v,
128                                    float* buff_A, int rs_A, int cs_A )
129{
130  int       i;
131
132  for ( i = 0; i < n_A; ++i )
133  {
134    float*    a1       = buff_A + (i  )*cs_A + (0  )*rs_A;
135    float*    u        = buff_u;
136    float*    psi1     = buff_y + (i  )*inc_y;
137    float*    z        = buff_z;
138    float*    nu1      = buff_v + (i  )*inc_v;
139    float*    alpha    = buff_alpha;
140    float     temp1;
141    float     temp2;
142
143    /*------------------------------------------------------------*/
144
145    // bl1_smult3( alpha, psi1, &temp1 );
146    temp1 = *alpha * *psi1;
147
148    // bl1_smult3( alpha, nu1, &temp2 );
149    temp2 = *alpha * *nu1;
150
151    // bl1_saxpyv( BLIS1_NO_CONJUGATE,
152    //             m_A,
153    //             &temp1,
154    //             u,  inc_u,
155    //             a1, rs_A );
156    F77_saxpy( &m_A,
157               &temp1,
158               u,  &inc_u,
159               a1, &rs_A );
160
161    // bl1_saxpyv( BLIS1_NO_CONJUGATE,
162    //             m_A,
163    //             &temp2,
164    //             z,  inc_z,
165    //             a1, rs_A );
166    F77_saxpy( &m_A,
167               &temp2,
168               z,  &inc_z,
169               a1, &rs_A );
170
171    /*------------------------------------------------------------*/
172
173  }
174
175  return FLA_SUCCESS;
176}
177
178
179
180FLA_Error FLA_Fused_Gerc2_opd_var1( int m_A,
181                                    int n_A,
182                                    double* buff_alpha,
183                                    double* buff_u, int inc_u,
184                                    double* buff_y, int inc_y,
185                                    double* buff_z, int inc_z,
186                                    double* buff_v, int inc_v,
187                                    double* buff_A, int rs_A, int cs_A )
188{
189  int       i;
190
191  for ( i = 0; i < n_A; ++i )
192  {
193/*
194   Effective computation:
195   A = A + alpha * ( u * y' + z * v' );
196*/
197    double*   restrict a1       = buff_A + (i  )*cs_A + (0  )*rs_A;
198    double*   restrict u        = buff_u;
199    double*   restrict psi1     = buff_y + (i  )*inc_y;
200    double*   restrict z        = buff_z;
201    double*   restrict nu1      = buff_v + (i  )*inc_v;
202    double*   restrict alpha    = buff_alpha;
203    double    alpha_conj_psi1;
204    double    alpha_conj_nu1;
205
206    /*------------------------------------------------------------*/
207
208    bl1_dmult3( alpha, psi1, &alpha_conj_psi1 );
209
210    bl1_dmult3( alpha, nu1, &alpha_conj_nu1 );
211
212    bl1_daxpyv2b( m_A,
213	              &alpha_conj_psi1,
214	              &alpha_conj_nu1,
215	              u,  inc_u,
216	              z,  inc_z,
217	              a1, rs_A );
218
219    /*------------------------------------------------------------*/
220
221  }
222
223  return FLA_SUCCESS;
224}
225
226
227
228FLA_Error FLA_Fused_Gerc2_opc_var1( int m_A,
229                                    int n_A,
230                                    scomplex* buff_alpha,
231                                    scomplex* buff_u, int inc_u,
232                                    scomplex* buff_y, int inc_y,
233                                    scomplex* buff_z, int inc_z,
234                                    scomplex* buff_v, int inc_v,
235                                    scomplex* buff_A, int rs_A, int cs_A )
236{
237  int       i;
238
239  for ( i = 0; i < n_A; ++i )
240  {
241    scomplex* a1       = buff_A + (i  )*cs_A + (0  )*rs_A;
242    scomplex* u        = buff_u;
243    scomplex* psi1     = buff_y + (i  )*inc_y;
244    scomplex* z        = buff_z;
245    scomplex* nu1      = buff_v + (i  )*inc_v;
246    scomplex* alpha    = buff_alpha;
247    scomplex  psi1_conj;
248    scomplex  nu1_conj;
249    scomplex  temp1;
250    scomplex  temp2;
251
252    /*------------------------------------------------------------*/
253
254    bl1_ccopyconj( psi1, &psi1_conj );
255    bl1_cmult3( alpha, &psi1_conj, &temp1 );
256
257    bl1_ccopyconj( nu1, &nu1_conj );
258    bl1_cmult3( alpha, &nu1_conj, &temp2 );
259
260    // bl1_caxpyv( BLIS1_NO_CONJUGATE,
261    //             m_A,
262    //             &temp1,
263    //             u,  inc_u,
264    //             a1, rs_A );
265    F77_caxpy( &m_A,
266               &temp1,
267               u,  &inc_u,
268               a1, &rs_A );
269
270    // bl1_caxpyv( BLIS1_NO_CONJUGATE,
271    //             m_A,
272    //             &temp2,
273    //             z,  inc_z,
274    //             a1, rs_A );
275    F77_caxpy( &m_A,
276               &temp2,
277               z,  &inc_z,
278               a1, &rs_A );
279
280    /*------------------------------------------------------------*/
281
282  }
283
284  return FLA_SUCCESS;
285}
286
287
288
289FLA_Error FLA_Fused_Gerc2_opz_var1( int m_A,
290                                    int n_A,
291                                    dcomplex* buff_alpha,
292                                    dcomplex* buff_u, int inc_u,
293                                    dcomplex* buff_y, int inc_y,
294                                    dcomplex* buff_z, int inc_z,
295                                    dcomplex* buff_v, int inc_v,
296                                    dcomplex* buff_A, int rs_A, int cs_A )
297{
298  int i;
299
300  for ( i = 0; i < n_A; ++i )
301  {
302    dcomplex* restrict a1       = buff_A + (i  )*cs_A + (0  )*rs_A;
303    dcomplex* restrict u        = buff_u;
304    dcomplex* restrict psi1     = buff_y + (i  )*inc_y;
305    dcomplex* restrict z        = buff_z;
306    dcomplex* restrict nu1      = buff_v + (i  )*inc_v;
307    dcomplex* restrict alpha    = buff_alpha;
308    dcomplex  conj_psi1;
309    dcomplex  conj_nu1;
310    dcomplex  alpha_conj_psi1;
311    dcomplex  alpha_conj_nu1;
312
313    /*------------------------------------------------------------*/
314
315    bl1_zcopyconj( psi1, &conj_psi1 );
316    bl1_zmult3( alpha, &conj_psi1, &alpha_conj_psi1 );
317
318    bl1_zcopyconj( nu1, &conj_nu1 );
319    bl1_zmult3( alpha, &conj_nu1, &alpha_conj_nu1 );
320
321    bl1_zaxpyv2b( m_A,
322	              &alpha_conj_psi1,
323	              &alpha_conj_nu1,
324	              u,  inc_u,
325	              z,  inc_z,
326	              a1, rs_A );
327
328    /*------------------------------------------------------------*/
329
330  }
331
332  return FLA_SUCCESS;
333}
334
335