1 /* ************************************************************************
2 * Copyright 2013 Advanced Micro Devices, Inc.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 * ************************************************************************/
16
17
18 #include <stdio.h>
19 #include <string.h>
20 #include <clBLAS.h>
21
22 #include <devinfo.h>
23 #include "clblas-internal.h"
24 #include "solution_seq.h"
25
26 //#define DEBUG_GEMM_2
27
28 int
gemmHasMTail(size_t M,int vecLen,clblasOrder order,clblasTranspose transA,clblasTranspose transB)29 gemmHasMTail(size_t M, int vecLen, clblasOrder order, clblasTranspose transA, clblasTranspose transB)
30 {
31 transB = transB; // Dummy- to remove warning
32 if (order == clblasColumnMajor)
33 {
34 if (transA == clblasNoTrans)
35 {
36 return (M % vecLen);
37 } else {
38 return 0;
39 }
40 } else {
41 printf("gemmHasMTail: Not handling Row Major - FIXME\n");
42 return 0;
43 }
44 }
45
46 int
gemmHasNTail(size_t N,int vecLen,clblasOrder order,clblasTranspose transA,clblasTranspose transB)47 gemmHasNTail(size_t N, int vecLen, clblasOrder order, clblasTranspose transA, clblasTranspose transB)
48 {
49 if (order == clblasColumnMajor)
50 {
51 if (transA == clblasNoTrans)
52 {
53 if (transB == clblasNoTrans)
54 {
55 return 0;
56 } else {
57 return (N % vecLen);
58 }
59 } else {
60 if (transB == clblasNoTrans)
61 {
62 return 0;
63 } else {
64 return (N % vecLen);
65 }
66 }
67 } else {
68 printf("gemmHasNTail: Not handling Row Major - FIXME\n");
69 return 0;
70 }
71 }
72
73 int
gemmHasTails(size_t M,size_t N,size_t K,int vecLen,clblasOrder order,clblasTranspose transA,clblasTranspose transB)74 gemmHasTails(size_t M, size_t N, size_t K, int vecLen, clblasOrder order, clblasTranspose transA, clblasTranspose transB)
75 {
76 K = K; // Dummy- to remove warning
77 if (order == clblasColumnMajor)
78 {
79 if (transA == clblasNoTrans)
80 {
81 if (transB == clblasNoTrans)
82 {
83 return (M % vecLen);
84 } else {
85 return ((M % vecLen) || (N % vecLen));
86 }
87 } else {
88 if (transB == clblasNoTrans)
89 {
90 //
91 // Vectoring on A is on K dimension and we handle tail directly in the kernel
92 //
93 return 0;
94 } else {
95 return (N % vecLen);
96 }
97 }
98 } else {
99 printf("gemmHasTails: Not handling Row Major - FIXME\n");
100 return 0;
101 }
102 }
103
executeGEMM(CLBlasKargs * kargs,cl_uint numCommandQueues,cl_command_queue * commandQueues,cl_uint numEventsInWaitList,const cl_event * eventWaitList,cl_event * events)104 clblasStatus executeGEMM( CLBlasKargs *kargs, cl_uint numCommandQueues, cl_command_queue *commandQueues, cl_uint numEventsInWaitList,
105 const cl_event *eventWaitList, cl_event *events)
106 {
107 cl_int err = CL_SUCCESS;
108 ListHead seq, tailSeq;
109 cl_event nontail;
110 cl_uint gemmVeclen;
111 CLBLASKernExtra *kextra;
112 size_t M, N, K;
113
114 M = kargs->M; N = kargs->N; K = kargs->K;
115 #ifdef DEBUG_GEMM_2
116 printf("executeGEMM Called\n");
117 #endif
118 listInitHead(&seq);
119 err = makeSolutionSeq(CLBLAS_GEMM2, kargs, numCommandQueues, commandQueues,
120 numEventsInWaitList, eventWaitList, &nontail, &seq);
121 if (err == CL_SUCCESS) {
122 ListNode *f = listNodeFirst(&seq);
123 SolutionStep *gemm2;
124 size_t tailStartM, tailStartN;
125 bool processTails;
126
127 gemm2 = container_of(f, node, SolutionStep);
128 kextra = gemm2->kernels[CLBLAS_COMPUTING_KERNEL]->extra;
129 gemmVeclen = kextra->vecLen;
130
131 if (gemmHasTails(M, N, K, gemmVeclen, kargs->order, kargs->transA, kargs->transB) == 0)
132 {
133 #ifdef DEBUG_GEMM_2
134 printf("No M or N Tails to process..\n");
135 #endif
136 processTails = false;
137 gemm2->event = events;
138 } else {
139 processTails = true;
140 if (gemmHasMTail(M, gemmVeclen, kargs->order, kargs->transA, kargs->transB))
141 {
142 tailStartM = M - (M%gemmVeclen);
143 } else {
144 tailStartM = M;
145 }
146
147 if (gemmHasNTail(N, gemmVeclen, kargs->order, kargs->transA, kargs->transB))
148 {
149 tailStartN = N - (N%gemmVeclen);
150 } else {
151 tailStartN = N;
152 }
153 }
154 err = executeSolutionSeq(&seq);
155 if ((err == CL_SUCCESS) && (processTails == true))
156 {
157 CLBlasKargs targs;
158
159 memcpy(&targs, &gemm2->args, sizeof(CLBlasKargs));
160 targs.tailStartM = tailStartM;
161 targs.tailStartN = tailStartN;
162 #ifdef DEBUG_GEMM_2
163 printf("Processing Tails\n");
164 #endif
165 listInitHead(&tailSeq);
166 err = makeSolutionSeq(CLBLAS_GEMM_TAIL, &targs, numCommandQueues, commandQueues,
167 1, &nontail, events, &tailSeq);
168 if (err == CL_SUCCESS)
169 {
170 err = executeSolutionSeq(&tailSeq);
171 }
172 freeSolutionSeq(&tailSeq);
173 }
174 }
175 freeSolutionSeq(&seq);
176 return (clblasStatus) err;
177 }
178
179 static clblasStatus
doGemm(CLBlasKargs * kargs,clblasOrder order,clblasTranspose transA,clblasTranspose transB,size_t M,size_t N,size_t K,const cl_mem A,size_t offA,size_t lda,const cl_mem B,size_t offB,size_t ldb,cl_mem C,size_t offC,size_t ldc,cl_uint numCommandQueues,cl_command_queue * commandQueues,cl_uint numEventsInWaitList,const cl_event * eventWaitList,cl_event * events)180 doGemm(
181 CLBlasKargs *kargs,
182 clblasOrder order,
183 clblasTranspose transA,
184 clblasTranspose transB,
185 size_t M,
186 size_t N,
187 size_t K,
188 const cl_mem A,
189 size_t offA,
190 size_t lda,
191 const cl_mem B,
192 size_t offB,
193 size_t ldb,
194 cl_mem C,
195 size_t offC,
196 size_t ldc,
197 cl_uint numCommandQueues,
198 cl_command_queue *commandQueues,
199 cl_uint numEventsInWaitList,
200 const cl_event *eventWaitList,
201 cl_event *events)
202 {
203 clblasStatus err;
204 clblasStatus retCode = clblasSuccess;
205
206 if (!clblasInitialized) {
207 return clblasNotInitialized;
208 }
209
210 /* Validate arguments */
211
212 if ((retCode = checkMemObjects(A, B, C, true, A_MAT_ERRSET, B_MAT_ERRSET, C_MAT_ERRSET))) {
213 return retCode;
214 }
215 if (K != 0) {
216 if ((retCode = checkMatrixSizes(kargs->dtype, order, transA, M, K, A, offA, lda, A_MAT_ERRSET))) {
217 return retCode;
218 }
219 if ((retCode = checkMatrixSizes(kargs->dtype, order, transB, K, N, B, offB, ldb, B_MAT_ERRSET))) {
220 return retCode;
221 }
222 }
223 if ((retCode = checkMatrixSizes(kargs->dtype, order, clblasNoTrans, M, N, C, offC, ldc, C_MAT_ERRSET))) {
224 return retCode;
225 }
226
227 numCommandQueues = 1;
228 #ifdef DEBUG_2
229 printf("DoGemm being called...\n");
230 #endif
231 kargs->pigFuncID = CLBLAS_GEMM2;
232 kargs->order = order;
233 kargs->transA = transA;
234 kargs->transB = transB;
235 kargs->M = M;
236 kargs->N = N;
237 kargs->K = K;
238 kargs->A = A;
239 kargs->offA = offA;
240 kargs->offa = offA;
241 kargs->lda.matrix = lda;
242 kargs->B = B;
243 kargs->offBX = offB;
244 kargs->ldb.matrix = ldb;
245 kargs->C = C;
246 kargs->offCY = offC;
247 kargs->ldc.matrix = ldc;
248
249 kargs->offsetM = 0;
250 kargs->offsetN = 0;
251 kargs->scimage[0] = 0;
252 kargs->scimage[1] = 0;
253
254 err = executeGEMM(kargs, numCommandQueues, commandQueues, numEventsInWaitList, eventWaitList, events);
255 return err;
256 }
257
258 /*
259 clblasStatus
260 clblasSgemmV2(
261 clblasOrder order,
262 clblasTranspose transA,
263 clblasTranspose transB,
264 size_t M,
265 size_t N,
266 size_t K,
267 cl_float alpha,
268 const cl_mem A,
269 size_t lda,
270 const cl_mem B,
271 size_t ldb,
272 cl_float beta,
273 cl_mem C,
274 size_t ldc,
275 cl_uint numCommandQueues,
276 cl_command_queue *commandQueues,
277 cl_uint numEventsInWaitList,
278 const cl_event *eventWaitList,
279 cl_event *events)
280 {
281 CLBlasKargs kargs;
282
283 memset(&kargs, 0, sizeof(kargs));
284 kargs.dtype = TYPE_FLOAT;
285 kargs.alpha.argFloat = alpha;
286 kargs.beta.argFloat = beta;
287
288 return doGemm(&kargs, order, transA, transB, M, N, K, A, 0, lda, B, 0, ldb,
289 C, 0, ldc, numCommandQueues, commandQueues,
290 numEventsInWaitList, eventWaitList, events);
291 }
292
293 clblasStatus
294 clblasDgemmV2(
295 clblasOrder order,
296 clblasTranspose transA,
297 clblasTranspose transB,
298 size_t M,
299 size_t N,
300 size_t K,
301 cl_double alpha,
302 const cl_mem A,
303 size_t lda,
304 const cl_mem B,
305 size_t ldb,
306 cl_double beta,
307 cl_mem C,
308 size_t ldc,
309 cl_uint numCommandQueues,
310 cl_command_queue *commandQueues,
311 cl_uint numEventsInWaitList,
312 const cl_event *eventWaitList,
313 cl_event *events)
314 {
315 CLBlasKargs kargs;
316
317 memset(&kargs, 0, sizeof(kargs));
318 kargs.dtype = TYPE_DOUBLE;
319 kargs.alpha.argDouble = alpha;
320 kargs.beta.argDouble = beta;
321
322 return doGemm(&kargs, order, transA, transB, M, N, K, A, 0, lda, B, 0, ldb,
323 C, 0, ldc, numCommandQueues, commandQueues,
324 numEventsInWaitList, eventWaitList, events);
325 }
326
327 clblasStatus
328 clblasCgemmV2(
329 clblasOrder order,
330 clblasTranspose transA,
331 clblasTranspose transB,
332 size_t M,
333 size_t N,
334 size_t K,
335 FloatComplex alpha,
336 const cl_mem A,
337 size_t lda,
338 const cl_mem B,
339 size_t ldb,
340 FloatComplex beta,
341 cl_mem C,
342 size_t ldc,
343 cl_uint numCommandQueues,
344 cl_command_queue *commandQueues,
345 cl_uint numEventsInWaitList,
346 const cl_event *eventWaitList,
347 cl_event *events)
348 {
349 CLBlasKargs kargs;
350
351 memset(&kargs, 0, sizeof(kargs));
352 kargs.dtype = TYPE_COMPLEX_FLOAT;
353 kargs.alpha.argFloatComplex = alpha;
354 kargs.beta.argFloatComplex = beta;
355
356 return doGemm(&kargs, order, transA, transB, M, N, K, A, 0, lda, B, 0, ldb,
357 C, 0, ldc, numCommandQueues, commandQueues,
358 numEventsInWaitList, eventWaitList, events);
359 }
360
361 clblasStatus
362 clblasZgemmV2(
363 clblasOrder order,
364 clblasTranspose transA,
365 clblasTranspose transB,
366 size_t M,
367 size_t N,
368 size_t K,
369 DoubleComplex alpha,
370 const cl_mem A,
371 size_t lda,
372 const cl_mem B,
373 size_t ldb,
374 DoubleComplex beta,
375 cl_mem C,
376 size_t ldc,
377 cl_uint numCommandQueues,
378 cl_command_queue *commandQueues,
379 cl_uint numEventsInWaitList,
380 const cl_event *eventWaitList,
381 cl_event *events)
382 {
383 CLBlasKargs kargs;
384
385 memset(&kargs, 0, sizeof(kargs));
386 kargs.dtype = TYPE_COMPLEX_DOUBLE;
387 kargs.alpha.argDoubleComplex = alpha;
388 kargs.beta.argDoubleComplex = beta;
389
390 return doGemm(&kargs, order, transA, transB, M, N, K, A, 0, lda, B, 0, ldb,
391 C, 0, ldc, numCommandQueues, commandQueues,
392 numEventsInWaitList, eventWaitList, events);
393 }
394
395 clblasStatus
396 clblasSgemmExV2(
397 clblasOrder order,
398 clblasTranspose transA,
399 clblasTranspose transB,
400 size_t M,
401 size_t N,
402 size_t K,
403 cl_float alpha,
404 const cl_mem A,
405 size_t offA,
406 size_t lda,
407 const cl_mem B,
408 size_t offB,
409 size_t ldb,
410 cl_float beta,
411 cl_mem C,
412 size_t offC,
413 size_t ldc,
414 cl_uint numCommandQueues,
415 cl_command_queue *commandQueues,
416 cl_uint numEventsInWaitList,
417 const cl_event *eventWaitList,
418 cl_event *events)
419 {
420 CLBlasKargs kargs;
421
422 memset(&kargs, 0, sizeof(kargs));
423 kargs.dtype = TYPE_FLOAT;
424 kargs.alpha.argFloat = alpha;
425 kargs.beta.argFloat = beta;
426
427 return doGemm(&kargs, order, transA, transB, M, N, K, A, offA, lda, B, offB, ldb,
428 C, offC, ldc, numCommandQueues, commandQueues,
429 numEventsInWaitList, eventWaitList, events);
430 }
431
432 clblasStatus
433 clblasDgemmExV2(
434 clblasOrder order,
435 clblasTranspose transA,
436 clblasTranspose transB,
437 size_t M,
438 size_t N,
439 size_t K,
440 cl_double alpha,
441 const cl_mem A,
442 size_t offA,
443 size_t lda,
444 const cl_mem B,
445 size_t offB,
446 size_t ldb,
447 cl_double beta,
448 cl_mem C,
449 size_t offC,
450 size_t ldc,
451 cl_uint numCommandQueues,
452 cl_command_queue *commandQueues,
453 cl_uint numEventsInWaitList,
454 const cl_event *eventWaitList,
455 cl_event *events)
456 {
457 CLBlasKargs kargs;
458
459 memset(&kargs, 0, sizeof(kargs));
460 kargs.dtype = TYPE_DOUBLE;
461 kargs.alpha.argDouble = alpha;
462 kargs.beta.argDouble = beta;
463
464 return doGemm(&kargs, order, transA, transB, M, N, K, A, offA, lda, B, offB, ldb,
465 C, offC, ldc, numCommandQueues, commandQueues,
466 numEventsInWaitList, eventWaitList, events);
467 }
468
469 clblasStatus
470 clblasCgemmExV2(
471 clblasOrder order,
472 clblasTranspose transA,
473 clblasTranspose transB,
474 size_t M,
475 size_t N,
476 size_t K,
477 FloatComplex alpha,
478 const cl_mem A,
479 size_t offA,
480 size_t lda,
481 const cl_mem B,
482 size_t offB,
483 size_t ldb,
484 FloatComplex beta,
485 cl_mem C,
486 size_t offC,
487 size_t ldc,
488 cl_uint numCommandQueues,
489 cl_command_queue *commandQueues,
490 cl_uint numEventsInWaitList,
491 const cl_event *eventWaitList,
492 cl_event *events)
493 {
494 CLBlasKargs kargs;
495
496 memset(&kargs, 0, sizeof(kargs));
497 kargs.dtype = TYPE_COMPLEX_FLOAT;
498 kargs.alpha.argFloatComplex = alpha;
499 kargs.beta.argFloatComplex = beta;
500
501 return doGemm(&kargs, order, transA, transB, M, N, K, A, offA, lda, B, offB, ldb,
502 C, offC, ldc, numCommandQueues, commandQueues,
503 numEventsInWaitList, eventWaitList, events);
504 }
505
506 clblasStatus
507 clblasZgemmExV2(
508 clblasOrder order,
509 clblasTranspose transA,
510 clblasTranspose transB,
511 size_t M,
512 size_t N,
513 size_t K,
514 DoubleComplex alpha,
515 const cl_mem A,
516 size_t offA,
517 size_t lda,
518 const cl_mem B,
519 size_t offB,
520 size_t ldb,
521 DoubleComplex beta,
522 cl_mem C,
523 size_t offC,
524 size_t ldc,
525 cl_uint numCommandQueues,
526 cl_command_queue *commandQueues,
527 cl_uint numEventsInWaitList,
528 const cl_event *eventWaitList,
529 cl_event *events)
530 {
531 CLBlasKargs kargs;
532
533 memset(&kargs, 0, sizeof(kargs));
534 kargs.dtype = TYPE_COMPLEX_DOUBLE;
535 kargs.alpha.argDoubleComplex = alpha;
536 kargs.beta.argDoubleComplex = beta;
537
538 return doGemm(&kargs, order, transA, transB, M, N, K, A, offA, lda, B, offB, ldb,
539 C, offC, ldc, numCommandQueues, commandQueues,
540 numEventsInWaitList, eventWaitList, events);
541 }
542 */
543