1 /* ************************************************************************
2  * Copyright 2013 Advanced Micro Devices, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  * ************************************************************************/
16 
17 
18 #include <stdio.h>
19 #include <string.h>
20 #include <clBLAS.h>
21 
22 #include <devinfo.h>
23 #include "clblas-internal.h"
24 #include "solution_seq.h"
25 
26 //#define DEBUG_GEMM_2
27 
28 int
gemmHasMTail(size_t M,int vecLen,clblasOrder order,clblasTranspose transA,clblasTranspose transB)29 gemmHasMTail(size_t M,  int vecLen, clblasOrder order, clblasTranspose transA, clblasTranspose transB)
30 {
31 	transB = transB;    // Dummy- to remove warning
32     if (order == clblasColumnMajor)
33 	{
34 		if (transA == clblasNoTrans)
35 		{
36 			return (M % vecLen);
37 		} else {
38 			return 0;
39 		}
40 	} else {
41 		printf("gemmHasMTail: Not handling Row Major - FIXME\n");
42 		return 0;
43 	}
44 }
45 
46 int
gemmHasNTail(size_t N,int vecLen,clblasOrder order,clblasTranspose transA,clblasTranspose transB)47 gemmHasNTail(size_t N, int vecLen, clblasOrder order, clblasTranspose transA, clblasTranspose transB)
48 {
49 	if (order == clblasColumnMajor)
50 	{
51 		if (transA == clblasNoTrans)
52 		{
53 			if (transB == clblasNoTrans)
54 			{
55 				return 0;
56 			} else {
57 				return (N % vecLen);
58 			}
59 		} else {
60 			if (transB == clblasNoTrans)
61 			{
62 				return 0;
63 			} else {
64 				return (N % vecLen);
65 			}
66 		}
67 	} else {
68 		printf("gemmHasNTail: Not handling Row Major - FIXME\n");
69 		return 0;
70 	}
71 }
72 
73 int
gemmHasTails(size_t M,size_t N,size_t K,int vecLen,clblasOrder order,clblasTranspose transA,clblasTranspose transB)74 gemmHasTails(size_t M,  size_t N, size_t K, int vecLen, clblasOrder order, clblasTranspose transA, clblasTranspose transB)
75 {
76 	K = K;  // Dummy- to remove warning
77     if (order == clblasColumnMajor)
78 	{
79 		if (transA == clblasNoTrans)
80 		{
81 			if (transB == clblasNoTrans)
82 			{
83 				return (M % vecLen);
84 			} else {
85 				return ((M % vecLen) || (N % vecLen));
86 			}
87 		} else {
88 			if (transB == clblasNoTrans)
89 			{
90 				//
91 				// Vectoring on A is on K dimension and we handle tail directly in the kernel
92 				//
93 				return 0;
94 			} else {
95 				return (N % vecLen);
96 			}
97 		}
98 	} else {
99 		printf("gemmHasTails: Not handling Row Major - FIXME\n");
100 		return 0;
101 	}
102 }
103 
executeGEMM(CLBlasKargs * kargs,cl_uint numCommandQueues,cl_command_queue * commandQueues,cl_uint numEventsInWaitList,const cl_event * eventWaitList,cl_event * events)104 clblasStatus executeGEMM( CLBlasKargs *kargs, cl_uint numCommandQueues, cl_command_queue *commandQueues, cl_uint numEventsInWaitList,
105                              const cl_event *eventWaitList, cl_event *events)
106 {
107     cl_int err = CL_SUCCESS;
108     ListHead seq, tailSeq;
109 	cl_event nontail;
110 	cl_uint gemmVeclen;
111 	CLBLASKernExtra *kextra;
112     size_t M, N, K;
113 
114     M = kargs->M; N = kargs->N; K = kargs->K;
115     #ifdef DEBUG_GEMM_2
116     printf("executeGEMM Called\n");
117     #endif
118     listInitHead(&seq);
119     err = makeSolutionSeq(CLBLAS_GEMM2, kargs, numCommandQueues, commandQueues,
120         numEventsInWaitList, eventWaitList, &nontail, &seq);
121     if (err == CL_SUCCESS) {
122 	    ListNode *f = listNodeFirst(&seq);
123 		SolutionStep *gemm2;
124 		size_t tailStartM, tailStartN;
125 		bool processTails;
126 
127 		gemm2 = container_of(f, node, SolutionStep);
128 		kextra = gemm2->kernels[CLBLAS_COMPUTING_KERNEL]->extra;
129 		gemmVeclen = kextra->vecLen;
130 
131 		if (gemmHasTails(M, N, K, gemmVeclen, kargs->order, kargs->transA, kargs->transB) == 0)
132 		{
133 			#ifdef DEBUG_GEMM_2
134 			printf("No M or N Tails to process..\n");
135 			#endif
136 			processTails = false;
137 			gemm2->event = events;
138 		} else {
139 			processTails = true;
140 			if (gemmHasMTail(M, gemmVeclen, kargs->order, kargs->transA, kargs->transB))
141 			{
142 				tailStartM = M - (M%gemmVeclen);
143 			} else {
144 				tailStartM = M;
145 			}
146 
147 			if (gemmHasNTail(N, gemmVeclen, kargs->order, kargs->transA, kargs->transB))
148 			{
149 				tailStartN = N - (N%gemmVeclen);
150 			} else {
151 				tailStartN = N;
152             }
153 		}
154         err = executeSolutionSeq(&seq);
155 		if ((err == CL_SUCCESS) && (processTails == true))
156 		{
157 			CLBlasKargs targs;
158 
159 			memcpy(&targs, &gemm2->args, sizeof(CLBlasKargs));
160 			targs.tailStartM = tailStartM;
161 			targs.tailStartN = tailStartN;
162 			#ifdef DEBUG_GEMM_2
163 			printf("Processing Tails\n");
164 			#endif
165     		listInitHead(&tailSeq);
166     		err = makeSolutionSeq(CLBLAS_GEMM_TAIL, &targs, numCommandQueues, commandQueues,
167         						  1, &nontail, events, &tailSeq);
168 			if (err == CL_SUCCESS)
169 			{
170 				err = executeSolutionSeq(&tailSeq);
171 			}
172 			freeSolutionSeq(&tailSeq);
173 		}
174     }
175     freeSolutionSeq(&seq);
176     return (clblasStatus) err;
177 }
178 
179 static clblasStatus
doGemm(CLBlasKargs * kargs,clblasOrder order,clblasTranspose transA,clblasTranspose transB,size_t M,size_t N,size_t K,const cl_mem A,size_t offA,size_t lda,const cl_mem B,size_t offB,size_t ldb,cl_mem C,size_t offC,size_t ldc,cl_uint numCommandQueues,cl_command_queue * commandQueues,cl_uint numEventsInWaitList,const cl_event * eventWaitList,cl_event * events)180 doGemm(
181     CLBlasKargs *kargs,
182     clblasOrder order,
183     clblasTranspose transA,
184     clblasTranspose transB,
185     size_t M,
186     size_t N,
187     size_t K,
188     const cl_mem A,
189     size_t offA,
190     size_t lda,
191     const cl_mem B,
192     size_t offB,
193     size_t ldb,
194     cl_mem C,
195     size_t offC,
196     size_t ldc,
197     cl_uint numCommandQueues,
198     cl_command_queue *commandQueues,
199     cl_uint numEventsInWaitList,
200     const cl_event *eventWaitList,
201     cl_event *events)
202 {
203     clblasStatus err;
204     clblasStatus retCode = clblasSuccess;
205 
206     if (!clblasInitialized) {
207         return clblasNotInitialized;
208     }
209 
210     /* Validate arguments */
211 
212     if ((retCode = checkMemObjects(A, B, C, true, A_MAT_ERRSET, B_MAT_ERRSET, C_MAT_ERRSET))) {
213         return retCode;
214     }
215     if (K != 0) {
216         if ((retCode = checkMatrixSizes(kargs->dtype, order, transA, M, K, A, offA, lda, A_MAT_ERRSET))) {
217             return retCode;
218         }
219         if ((retCode = checkMatrixSizes(kargs->dtype, order, transB, K, N, B, offB, ldb, B_MAT_ERRSET))) {
220             return retCode;
221         }
222     }
223     if ((retCode = checkMatrixSizes(kargs->dtype, order, clblasNoTrans, M, N, C, offC, ldc, C_MAT_ERRSET))) {
224             return retCode;
225     }
226 
227 	numCommandQueues = 1;
228 	#ifdef DEBUG_2
229 	printf("DoGemm being called...\n");
230 	#endif
231     kargs->pigFuncID = CLBLAS_GEMM2;
232     kargs->order = order;
233     kargs->transA = transA;
234     kargs->transB = transB;
235     kargs->M = M;
236     kargs->N = N;
237     kargs->K = K;
238     kargs->A = A;
239     kargs->offA = offA;
240     kargs->offa = offA;
241     kargs->lda.matrix = lda;
242     kargs->B = B;
243     kargs->offBX = offB;
244     kargs->ldb.matrix = ldb;
245     kargs->C = C;
246     kargs->offCY = offC;
247     kargs->ldc.matrix = ldc;
248 
249     kargs->offsetM = 0;
250     kargs->offsetN = 0;
251     kargs->scimage[0] = 0;
252     kargs->scimage[1] = 0;
253 
254     err = executeGEMM(kargs, numCommandQueues, commandQueues, numEventsInWaitList, eventWaitList, events);
255     return err;
256 			}
257 
258 /*
259 clblasStatus
260 clblasSgemmV2(
261     clblasOrder order,
262     clblasTranspose transA,
263     clblasTranspose transB,
264     size_t M,
265     size_t N,
266     size_t K,
267     cl_float alpha,
268     const cl_mem A,
269     size_t lda,
270     const cl_mem B,
271     size_t ldb,
272     cl_float beta,
273     cl_mem C,
274     size_t ldc,
275     cl_uint numCommandQueues,
276     cl_command_queue *commandQueues,
277     cl_uint numEventsInWaitList,
278     const cl_event *eventWaitList,
279     cl_event *events)
280 {
281     CLBlasKargs kargs;
282 
283     memset(&kargs, 0, sizeof(kargs));
284     kargs.dtype = TYPE_FLOAT;
285     kargs.alpha.argFloat = alpha;
286     kargs.beta.argFloat = beta;
287 
288     return doGemm(&kargs, order, transA, transB, M, N, K, A, 0, lda, B, 0, ldb,
289                   C, 0, ldc, numCommandQueues, commandQueues,
290                   numEventsInWaitList, eventWaitList, events);
291 }
292 
293 clblasStatus
294 clblasDgemmV2(
295     clblasOrder order,
296     clblasTranspose transA,
297     clblasTranspose transB,
298     size_t M,
299     size_t N,
300     size_t K,
301     cl_double alpha,
302     const cl_mem A,
303     size_t lda,
304     const cl_mem B,
305     size_t ldb,
306     cl_double beta,
307     cl_mem C,
308     size_t ldc,
309     cl_uint numCommandQueues,
310     cl_command_queue *commandQueues,
311     cl_uint numEventsInWaitList,
312     const cl_event *eventWaitList,
313     cl_event *events)
314 {
315     CLBlasKargs kargs;
316 
317     memset(&kargs, 0, sizeof(kargs));
318     kargs.dtype = TYPE_DOUBLE;
319     kargs.alpha.argDouble = alpha;
320     kargs.beta.argDouble = beta;
321 
322     return doGemm(&kargs, order, transA, transB, M, N, K, A, 0, lda, B, 0, ldb,
323                   C, 0, ldc, numCommandQueues, commandQueues,
324                   numEventsInWaitList, eventWaitList, events);
325 }
326 
327 clblasStatus
328 clblasCgemmV2(
329     clblasOrder order,
330     clblasTranspose transA,
331     clblasTranspose transB,
332     size_t M,
333     size_t N,
334     size_t K,
335     FloatComplex alpha,
336     const cl_mem A,
337     size_t lda,
338     const cl_mem B,
339     size_t ldb,
340     FloatComplex beta,
341     cl_mem C,
342     size_t ldc,
343     cl_uint numCommandQueues,
344     cl_command_queue *commandQueues,
345     cl_uint numEventsInWaitList,
346     const cl_event *eventWaitList,
347     cl_event *events)
348 {
349     CLBlasKargs kargs;
350 
351     memset(&kargs, 0, sizeof(kargs));
352     kargs.dtype = TYPE_COMPLEX_FLOAT;
353     kargs.alpha.argFloatComplex = alpha;
354     kargs.beta.argFloatComplex = beta;
355 
356     return doGemm(&kargs, order, transA, transB, M, N, K, A, 0, lda, B, 0, ldb,
357                   C, 0, ldc, numCommandQueues, commandQueues,
358                   numEventsInWaitList, eventWaitList, events);
359 }
360 
361 clblasStatus
362 clblasZgemmV2(
363     clblasOrder order,
364     clblasTranspose transA,
365     clblasTranspose transB,
366     size_t M,
367     size_t N,
368     size_t K,
369     DoubleComplex alpha,
370     const cl_mem A,
371     size_t lda,
372     const cl_mem B,
373     size_t ldb,
374     DoubleComplex beta,
375     cl_mem C,
376     size_t ldc,
377     cl_uint numCommandQueues,
378     cl_command_queue *commandQueues,
379     cl_uint numEventsInWaitList,
380     const cl_event *eventWaitList,
381     cl_event *events)
382 {
383     CLBlasKargs kargs;
384 
385     memset(&kargs, 0, sizeof(kargs));
386     kargs.dtype = TYPE_COMPLEX_DOUBLE;
387     kargs.alpha.argDoubleComplex = alpha;
388     kargs.beta.argDoubleComplex = beta;
389 
390     return doGemm(&kargs, order, transA, transB, M, N, K, A, 0, lda, B, 0, ldb,
391                   C, 0, ldc, numCommandQueues, commandQueues,
392                   numEventsInWaitList, eventWaitList, events);
393 }
394 
395 clblasStatus
396 clblasSgemmExV2(
397     clblasOrder order,
398     clblasTranspose transA,
399     clblasTranspose transB,
400     size_t M,
401     size_t N,
402     size_t K,
403     cl_float alpha,
404     const cl_mem A,
405 	size_t offA,
406     size_t lda,
407     const cl_mem B,
408 	size_t offB,
409     size_t ldb,
410     cl_float beta,
411     cl_mem C,
412 	size_t offC,
413     size_t ldc,
414     cl_uint numCommandQueues,
415     cl_command_queue *commandQueues,
416     cl_uint numEventsInWaitList,
417     const cl_event *eventWaitList,
418     cl_event *events)
419 {
420     CLBlasKargs kargs;
421 
422     memset(&kargs, 0, sizeof(kargs));
423     kargs.dtype = TYPE_FLOAT;
424     kargs.alpha.argFloat = alpha;
425     kargs.beta.argFloat = beta;
426 
427     return doGemm(&kargs, order, transA, transB, M, N, K, A, offA, lda, B, offB, ldb,
428                   C, offC, ldc, numCommandQueues, commandQueues,
429                   numEventsInWaitList, eventWaitList, events);
430 }
431 
432 clblasStatus
433 clblasDgemmExV2(
434     clblasOrder order,
435     clblasTranspose transA,
436     clblasTranspose transB,
437     size_t M,
438     size_t N,
439     size_t K,
440     cl_double alpha,
441     const cl_mem A,
442 	size_t offA,
443     size_t lda,
444     const cl_mem B,
445 	size_t offB,
446     size_t ldb,
447     cl_double beta,
448     cl_mem C,
449 	size_t offC,
450     size_t ldc,
451     cl_uint numCommandQueues,
452     cl_command_queue *commandQueues,
453     cl_uint numEventsInWaitList,
454     const cl_event *eventWaitList,
455     cl_event *events)
456 {
457     CLBlasKargs kargs;
458 
459     memset(&kargs, 0, sizeof(kargs));
460     kargs.dtype = TYPE_DOUBLE;
461     kargs.alpha.argDouble = alpha;
462     kargs.beta.argDouble = beta;
463 
464     return doGemm(&kargs, order, transA, transB, M, N, K, A, offA, lda, B, offB, ldb,
465                   C, offC, ldc, numCommandQueues, commandQueues,
466                   numEventsInWaitList, eventWaitList, events);
467 }
468 
469 clblasStatus
470 clblasCgemmExV2(
471     clblasOrder order,
472     clblasTranspose transA,
473     clblasTranspose transB,
474     size_t M,
475     size_t N,
476     size_t K,
477     FloatComplex alpha,
478     const cl_mem A,
479 	size_t offA,
480     size_t lda,
481     const cl_mem B,
482 	size_t offB,
483     size_t ldb,
484     FloatComplex beta,
485     cl_mem C,
486 	size_t offC,
487     size_t ldc,
488     cl_uint numCommandQueues,
489     cl_command_queue *commandQueues,
490     cl_uint numEventsInWaitList,
491     const cl_event *eventWaitList,
492     cl_event *events)
493 {
494     CLBlasKargs kargs;
495 
496     memset(&kargs, 0, sizeof(kargs));
497     kargs.dtype = TYPE_COMPLEX_FLOAT;
498     kargs.alpha.argFloatComplex = alpha;
499     kargs.beta.argFloatComplex = beta;
500 
501     return doGemm(&kargs, order, transA, transB, M, N, K, A, offA, lda, B, offB, ldb,
502                   C, offC, ldc, numCommandQueues, commandQueues,
503                   numEventsInWaitList, eventWaitList, events);
504 }
505 
506 clblasStatus
507 clblasZgemmExV2(
508     clblasOrder order,
509     clblasTranspose transA,
510     clblasTranspose transB,
511     size_t M,
512     size_t N,
513     size_t K,
514     DoubleComplex alpha,
515     const cl_mem A,
516 	size_t offA,
517     size_t lda,
518     const cl_mem B,
519 	size_t offB,
520     size_t ldb,
521     DoubleComplex beta,
522     cl_mem C,
523 	size_t offC,
524     size_t ldc,
525     cl_uint numCommandQueues,
526     cl_command_queue *commandQueues,
527     cl_uint numEventsInWaitList,
528     const cl_event *eventWaitList,
529     cl_event *events)
530 {
531     CLBlasKargs kargs;
532 
533     memset(&kargs, 0, sizeof(kargs));
534     kargs.dtype = TYPE_COMPLEX_DOUBLE;
535     kargs.alpha.argDoubleComplex = alpha;
536     kargs.beta.argDoubleComplex = beta;
537 
538     return doGemm(&kargs, order, transA, transB, M, N, K, A, offA, lda, B, offB, ldb,
539                   C, offC, ldc, numCommandQueues, commandQueues,
540                   numEventsInWaitList, eventWaitList, events);
541 }
542 */
543