1 /*********************************************************************/
2 /*                                                                   */
3 /*             Optimized BLAS libraries                              */
4 /*                     By Kazushige Goto <kgoto@tacc.utexas.edu>     */
5 /*                                                                   */
6 /* Copyright (c) The University of Texas, 2009. All rights reserved. */
7 /* UNIVERSITY EXPRESSLY DISCLAIMS ANY AND ALL WARRANTIES CONCERNING  */
8 /* THIS SOFTWARE AND DOCUMENTATION, INCLUDING ANY WARRANTIES OF      */
9 /* MERCHANTABILITY, FITNESS FOR ANY PARTICULAR PURPOSE,              */
10 /* NON-INFRINGEMENT AND WARRANTIES OF PERFORMANCE, AND ANY WARRANTY  */
11 /* THAT MIGHT OTHERWISE ARISE FROM COURSE OF DEALING OR USAGE OF     */
12 /* TRADE. NO WARRANTY IS EITHER EXPRESS OR IMPLIED WITH RESPECT TO   */
13 /* THE USE OF THE SOFTWARE OR DOCUMENTATION.                         */
14 /* Under no circumstances shall University be liable for incidental, */
15 /* special, indirect, direct or consequential damages or loss of     */
16 /* profits, interruption of business, or related expenses which may  */
17 /* arise from use of Software or Documentation, including but not    */
18 /* limited to those resulting from defects in Software and/or        */
19 /* Documentation, or loss or inaccuracy of data of any kind.         */
20 /*********************************************************************/
21 
22 #include <stdio.h>
23 #include <ctype.h>
24 #include "common.h"
25 #ifdef FUNCTION_PROFILE
26 #include "functable.h"
27 #endif
28 
/* Routine name handed to xerbla when argument checking fails.
   First letter: S/D/Q for real single/double/extended precision,
   C/Z/X for the complex counterparts.  HEMM selects the Hermitian
   HER2K name; it only has an effect in complex builds. */
#if defined(COMPLEX) && defined(HEMM)
#  if defined(XDOUBLE)
#    define ERROR_NAME "XHER2K"
#  elif defined(DOUBLE)
#    define ERROR_NAME "ZHER2K"
#  else
#    define ERROR_NAME "CHER2K"
#  endif
#elif defined(COMPLEX)
#  if defined(XDOUBLE)
#    define ERROR_NAME "XSYR2K"
#  elif defined(DOUBLE)
#    define ERROR_NAME "ZSYR2K"
#  else
#    define ERROR_NAME "CSYR2K"
#  endif
#else
#  if defined(XDOUBLE)
#    define ERROR_NAME "QSYR2K"
#  elif defined(DOUBLE)
#    define ERROR_NAME "DSYR2K"
#  else
#    define ERROR_NAME "SSYR2K"
#  endif
#endif
56 
57 static int (*syr2k[])(blas_arg_t *, BLASLONG *, BLASLONG *, FLOAT *, FLOAT *, BLASLONG) = {
58 #ifndef HEMM
59   SYR2K_UN, SYR2K_UC, SYR2K_LN, SYR2K_LC,
60 #else
61   HER2K_UN, HER2K_UC, HER2K_LN, HER2K_LC,
62 #endif
63 };
64 
65 #ifndef CBLAS
66 
NAME(char * UPLO,char * TRANS,blasint * N,blasint * K,FLOAT * alpha,FLOAT * a,blasint * ldA,FLOAT * b,blasint * ldB,FLOAT * beta,FLOAT * c,blasint * ldC)67 void NAME(char *UPLO, char *TRANS,
68          blasint *N, blasint *K,
69          FLOAT *alpha, FLOAT *a, blasint *ldA,
70 	               FLOAT *b, blasint *ldB,
71          FLOAT *beta,  FLOAT *c, blasint *ldC){
72 
73   char uplo_arg  = *UPLO;
74   char trans_arg = *TRANS;
75 
76   blas_arg_t args;
77 
78   FLOAT *buffer;
79   FLOAT *sa, *sb;
80 
81 #ifdef SMP
82 #ifndef COMPLEX
83 #ifdef XDOUBLE
84   int mode  =  BLAS_XDOUBLE | BLAS_REAL;
85 #elif defined(DOUBLE)
86   int mode  =  BLAS_DOUBLE  | BLAS_REAL;
87 #else
88   int mode  =  BLAS_SINGLE  | BLAS_REAL;
89 #endif
90 #else
91 #ifdef XDOUBLE
92   int mode  =  BLAS_XDOUBLE | BLAS_COMPLEX;
93 #elif defined(DOUBLE)
94   int mode  =  BLAS_DOUBLE  | BLAS_COMPLEX;
95 #else
96   int mode  =  BLAS_SINGLE  | BLAS_COMPLEX;
97 #endif
98 #endif
99 #endif
100 
101   blasint info;
102   int uplo;
103   int trans;
104   int nrowa;
105 
106   PRINT_DEBUG_NAME;
107 
108   args.n = *N;
109   args.k = *K;
110 
111   args.a = (void *)a;
112   args.b = (void *)b;
113   args.c = (void *)c;
114 
115   args.lda = *ldA;
116   args.ldb = *ldB;
117   args.ldc = *ldC;
118 
119   args.alpha = (void *)alpha;
120   args.beta  = (void *)beta;
121 
122   TOUPPER(uplo_arg);
123   TOUPPER(trans_arg);
124 
125   uplo  = -1;
126   trans = -1;
127 
128   if (uplo_arg  == 'U') uplo  = 0;
129   if (uplo_arg  == 'L') uplo  = 1;
130 
131   if (trans_arg == 'N') trans = 0;
132   if (trans_arg == 'T') trans = 1;
133   if (trans_arg == 'R') trans = 0;
134   if (trans_arg == 'C') trans = 1;
135 
136   nrowa = args.n;
137   if (trans & 1) nrowa = args.k;
138 
139   info = 0;
140 
141   if (args.ldc < MAX(1,args.n)) info = 12;
142   if (args.ldb < MAX(1,nrowa))  info =  9;
143   if (args.lda < MAX(1,nrowa))  info =  7;
144   if (args.k < 0)               info =  4;
145   if (args.n < 0)               info =  3;
146   if (trans < 0)                info =  2;
147   if (uplo  < 0)                info =  1;
148 
149   if (info != 0) {
150     BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME));
151     return;
152   }
153 
154 #else
155 
156 void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE Trans,
157 	   blasint n, blasint k,
158 #ifndef COMPLEX
159 	   FLOAT alpha,
160 #else
161 	   FLOAT *alpha,
162 #endif
163 	   FLOAT *a, blasint lda,
164 	   FLOAT *b, blasint ldb,
165 #if !defined(COMPLEX) || defined(HEMM)
166 	   FLOAT beta,
167 #else
168 	   FLOAT *beta,
169 #endif
170 	   FLOAT *c, blasint ldc) {
171 
172   blas_arg_t args;
173   int uplo, trans;
174   blasint info, nrowa;
175 
176   FLOAT *buffer;
177   FLOAT *sa, *sb;
178 
179 #ifdef HEMM
180   FLOAT CAlpha[2];
181 #endif
182 
183 #ifdef SMP
184 #ifndef COMPLEX
185 #ifdef XDOUBLE
186   int mode  =  BLAS_XDOUBLE | BLAS_REAL;
187 #elif defined(DOUBLE)
188   int mode  =  BLAS_DOUBLE  | BLAS_REAL;
189 #else
190   int mode  =  BLAS_SINGLE  | BLAS_REAL;
191 #endif
192 #else
193 #ifdef XDOUBLE
194   int mode  =  BLAS_XDOUBLE | BLAS_COMPLEX;
195 #elif defined(DOUBLE)
196   int mode  =  BLAS_DOUBLE  | BLAS_COMPLEX;
197 #else
198   int mode  =  BLAS_SINGLE  | BLAS_COMPLEX;
199 #endif
200 #endif
201 #endif
202 
203   PRINT_DEBUG_CNAME;
204 
205   args.n = n;
206   args.k = k;
207 
208   args.a = (void *)a;
209   args.b = (void *)b;
210   args.c = (void *)c;
211 
212   args.lda = lda;
213   args.ldb = ldb;
214   args.ldc = ldc;
215 
216 #ifndef COMPLEX
217   args.alpha = (void *)&alpha;
218 #else
219   args.alpha = (void *)alpha;
220 #endif
221 
222 #if !defined(COMPLEX) || defined(HEMM)
223   args.beta  = (void *)&beta;
224 #else
225   args.beta  = (void *)beta;
226 #endif
227 
228   trans = -1;
229   uplo  = -1;
230   info  =  0;
231 
232   if (order == CblasColMajor) {
233     if (Uplo == CblasUpper) uplo  = 0;
234     if (Uplo == CblasLower) uplo  = 1;
235 
236     if (Trans == CblasNoTrans)     trans = 0;
237 #ifndef COMPLEX
238     if (Trans == CblasTrans)       trans = 1;
239     if (Trans == CblasConjNoTrans) trans = 0;
240     if (Trans == CblasConjTrans)   trans = 1;
241 #elif !defined(HEMM)
242     if (Trans == CblasTrans)       trans = 1;
243 #else
244     if (Trans == CblasConjTrans)   trans = 1;
245 #endif
246 
247     info = -1;
248 
249     nrowa = args.n;
250     if (trans & 1) nrowa = args.k;
251 
252     if (args.ldc < MAX(1,args.n)) info = 12;
253     if (args.ldb < MAX(1,nrowa))  info =  9;
254     if (args.lda < MAX(1,nrowa))  info =  7;
255     if (args.k < 0)               info =  4;
256     if (args.n < 0)               info =  3;
257     if (trans < 0)                info =  2;
258     if (uplo  < 0)                info =  1;
259   }
260 
261   if (order == CblasRowMajor) {
262 
263 #ifdef HEMM
264     CAlpha[0] =  alpha[0];
265     CAlpha[1] = -alpha[1];
266 
267     args.alpha = (void *)CAlpha;
268 #endif
269 
270     if (Uplo == CblasUpper) uplo  = 1;
271     if (Uplo == CblasLower) uplo  = 0;
272 
273     if (Trans == CblasNoTrans)     trans = 1;
274 #ifndef COMPLEX
275     if (Trans == CblasTrans)       trans = 0;
276     if (Trans == CblasConjNoTrans) trans = 1;
277     if (Trans == CblasConjTrans)   trans = 0;
278 #elif !defined(HEMM)
279     if (Trans == CblasTrans)       trans = 0;
280 #else
281     if (Trans == CblasConjTrans)   trans = 0;
282 #endif
283 
284     info = -1;
285 
286     nrowa = args.n;
287     if (trans & 1) nrowa = args.k;
288 
289     if (args.ldc < MAX(1,args.n)) info = 12;
290     if (args.ldb < MAX(1,nrowa))  info =  9;
291     if (args.lda < MAX(1,nrowa))  info =  7;
292     if (args.k < 0)               info =  4;
293     if (args.n < 0)               info =  3;
294     if (trans < 0)                info =  2;
295     if (uplo  < 0)                info =  1;
296   }
297 
298   if (info >= 0) {
299     BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME));
300     return;
301   }
302 
303 #endif
304 
305   if (args.n == 0) return;
306 
307   IDEBUG_START;
308 
309   FUNCTION_PROFILE_START();
310 
311   buffer = (FLOAT *)blas_memory_alloc(0);
312 
313   sa = (FLOAT *)((BLASLONG)buffer + GEMM_OFFSET_A);
314   sb = (FLOAT *)(((BLASLONG)sa + ((GEMM_P * GEMM_Q * COMPSIZE * SIZE + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
315 
316 #ifdef SMP
317   if (!trans){
318     mode |= (BLAS_TRANSA_N | BLAS_TRANSB_T);
319   } else {
320     mode |= (BLAS_TRANSA_T | BLAS_TRANSB_N);
321   }
322 
323   mode |= (uplo  << BLAS_UPLO_SHIFT);
324 
325   args.common = NULL;
326   args.nthreads = num_cpu_avail(3);
327 
328   if (args.nthreads == 1) {
329 #endif
330 
331     (syr2k[(uplo << 1) | trans ])(&args, NULL, NULL, sa, sb, 0);
332 
333 #ifdef SMP
334 
335   } else {
336 
337     syrk_thread(mode, &args, NULL, NULL, syr2k[(uplo << 1) | trans ], sa, sb, args.nthreads);
338 
339   }
340 #endif
341 
342   blas_memory_free(buffer);
343 
344   FUNCTION_PROFILE_END(COMPSIZE * COMPSIZE, 2 * args.n * args.k + args.n * args.n, 2 * args.n * args.n * args.k);
345 
346   IDEBUG_END;
347 
348   return;
349 }
350