1 /* ************************************************************************
2  * Copyright 2013 Advanced Micro Devices, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  * ************************************************************************/
16 
17 //#define DEBUG_TBMV
18 
19 #include <stdio.h>
20 #include <stdlib.h>
21 #include <string.h>
22 #include <clBLAS.h>
23 
24 #include <devinfo.h>
25 #include "clblas-internal.h"
26 #include "solution_seq.h"
27 
28 clblasStatus
doTbmv(CLBlasKargs * kargs,clblasOrder order,clblasUplo uplo,clblasTranspose trans,clblasDiag diag,size_t N,size_t K,const cl_mem A,size_t offa,size_t lda,cl_mem x,size_t offx,int incx,cl_mem y,cl_uint numCommandQueues,cl_command_queue * commandQueues,cl_uint numEventsInWaitList,const cl_event * eventWaitList,cl_event * events)29 doTbmv(
30 	CLBlasKargs *kargs,
31     clblasOrder order,
32     clblasUplo uplo,
33     clblasTranspose trans,
34     clblasDiag diag,
35     size_t N,
36     size_t K,
37     const cl_mem A,
38     size_t offa,
39     size_t lda,
40     cl_mem x,
41     size_t offx,
42     int incx,
43 	cl_mem y, // Scratch Buffer
44     cl_uint numCommandQueues,
45     cl_command_queue *commandQueues,
46     cl_uint numEventsInWaitList,
47     const cl_event *eventWaitList,
48     cl_event *events)
49 {
50     cl_int err;
51     ListHead seq;
52 	size_t sizeOfVector;
53 	cl_event *newEventWaitList;
54     clblasStatus retCode = clblasSuccess;
55 
56     if (!clblasInitialized) {
57         return clblasNotInitialized;
58     }
59 
60     /* Validate arguments */
61 
62     if ((retCode = checkMemObjects(A, x, y, true, A_MAT_ERRSET, X_VEC_ERRSET, Y_VEC_ERRSET))) {
63 	    #ifdef DEBUG_TBMV
64         printf("Invalid mem object..\n");
65         #endif
66         return retCode;
67     }
68 
69     if ((retCode = checkBandedMatrixSizes(kargs->dtype, order, trans, N, N, K, 0, A, offa, lda, A_MAT_ERRSET))) {
70 		#ifdef DEBUG_TBMV
71         printf("Invalid Size for A\n");
72         #endif
73         return retCode;
74     }
75     if ((retCode = checkVectorSizes(kargs->dtype, N, x, offx, incx, X_VEC_ERRSET))) {
76 		#ifdef DEBUG_TBMV
77         printf("Invalid Size for X\n");
78         #endif
79         return retCode;
80     }
81     if ((retCode = checkVectorSizes(kargs->dtype, N, y, 0, incx, Y_VEC_ERRSET))) {
82 		#ifdef DEBUG_TBMV
83         printf("Invalid Size for scratch vector\n");
84         #endif
85         return retCode;
86     }
87 
88 	#ifdef DEBUG_TBMV
89 	printf("DoTbmv being called...\n");
90 	#endif
91 
92 	if ((commandQueues == NULL) || (numCommandQueues == 0))
93 	{
94 		return clblasInvalidValue;
95 	}
96     numCommandQueues = 1;
97 
98 	if ((numEventsInWaitList !=0) && (eventWaitList == NULL))
99 	{
100 		return clblasInvalidEventWaitList;
101 	}
102 
103 	newEventWaitList = malloc((numEventsInWaitList+1) * sizeof(cl_event));
104 	if (newEventWaitList == NULL)
105 	{
106 		return clblasOutOfHostMemory;
107 	}
108 	if (numEventsInWaitList != 0 )
109 	{
110 		memcpy(newEventWaitList, eventWaitList, numEventsInWaitList*sizeof(cl_event));
111 	}
112 
113 	/*
114  	 * ASSUMPTION:
115  	 * doTBMV assumes "commandQueue" of 0. The same is reflected in
116 	 * "makeSolutionSeq" as well. If either of them changes in future,
117 	 * this code needs to be revisited.
118   	 */
119 	sizeOfVector = (1 + (N-1)*abs(incx)) * dtypeSize(kargs->dtype);
120 	err = clEnqueueCopyBuffer(commandQueues[0], x, y, offx*dtypeSize(kargs->dtype), 0, sizeOfVector,
121 							  numEventsInWaitList, eventWaitList, &newEventWaitList[numEventsInWaitList]);
122 	if (err != CL_SUCCESS)
123 	{
124 		free(newEventWaitList);
125 		return err;
126 	}
127 
128     kargs->order = order;
129     kargs->uplo = uplo;
130     kargs->transA = trans;
131 	kargs->diag = diag;
132 	kargs->M = N;
133     kargs->N = N;
134     if( uplo == clblasUpper )
135     {
136         kargs->KL = 0;
137         kargs->KU = K;
138     }
139     else    {
140         kargs->KL = K;
141         kargs->KU = 0;
142     }
143     kargs->A = A;
144     kargs->lda.matrix = lda;
145     kargs->B = y;       // Now it becomes x = A * y
146     kargs->ldb.vector = incx;
147     kargs->C = x;
148     kargs->ldc.vector = incx;
149     kargs->offBX = 0;           // Not used by assignKargs(); Just for clarity
150     kargs->offCY = offx;
151 	kargs->offa = offa;
152 	kargs->offA = offa;
153 
154 	#ifdef DEBUG_TBMV
155 	printf("Calling makeSolutionSeq : TBMV\n");
156 	#endif
157 
158     listInitHead(&seq);
159     err = makeSolutionSeq(CLBLAS_GBMV, kargs, numCommandQueues, commandQueues,
160         				  numEventsInWaitList+1, newEventWaitList, events, &seq);
161     if (err == CL_SUCCESS) {
162        	err = executeSolutionSeq(&seq);
163     }
164 
165     freeSolutionSeq(&seq);
166 	free(newEventWaitList);
167     return (clblasStatus)err;
168 }
169 
170 clblasStatus
clblasStbmv(clblasOrder order,clblasUplo uplo,clblasTranspose trans,clblasDiag diag,size_t N,size_t K,const cl_mem A,size_t offa,size_t lda,cl_mem X,size_t offx,int incx,cl_mem scratchBuff,cl_uint numCommandQueues,cl_command_queue * commandQueues,cl_uint numEventsInWaitList,const cl_event * eventWaitList,cl_event * events)171 clblasStbmv(
172     clblasOrder order,
173     clblasUplo uplo,
174     clblasTranspose trans,
175     clblasDiag diag,
176     size_t N,
177     size_t K,
178     const cl_mem A,
179     size_t offa,
180     size_t lda,
181     cl_mem X,
182     size_t offx,
183     int incx,
184 	cl_mem scratchBuff,
185     cl_uint numCommandQueues,
186     cl_command_queue *commandQueues,
187     cl_uint numEventsInWaitList,
188     const cl_event *eventWaitList,
189     cl_event *events)
190 {
191     CLBlasKargs kargs;
192 	#ifdef DEBUG_TBMV
193 	printf("STBMV Called\n");
194 	#endif
195 
196     memset(&kargs, 0, sizeof(kargs));
197     kargs.dtype = TYPE_FLOAT;
198     kargs.pigFuncID = CLBLAS_TBMV;
199 
200     return doTbmv(&kargs, order, uplo, trans, diag, N, K, A, offa, lda, X, offx, incx, scratchBuff, numCommandQueues, commandQueues,
201                    numEventsInWaitList, eventWaitList, events);
202 }
203 
204 clblasStatus
clblasDtbmv(clblasOrder order,clblasUplo uplo,clblasTranspose trans,clblasDiag diag,size_t N,size_t K,const cl_mem A,size_t offa,size_t lda,cl_mem X,size_t offx,int incx,cl_mem scratchBuff,cl_uint numCommandQueues,cl_command_queue * commandQueues,cl_uint numEventsInWaitList,const cl_event * eventWaitList,cl_event * events)205 clblasDtbmv(
206     clblasOrder order,
207     clblasUplo uplo,
208     clblasTranspose trans,
209     clblasDiag diag,
210     size_t N,
211     size_t K,
212     const cl_mem A,
213     size_t offa,
214     size_t lda,
215     cl_mem X,
216     size_t offx,
217     int incx,
218 	cl_mem scratchBuff,
219     cl_uint numCommandQueues,
220     cl_command_queue *commandQueues,
221     cl_uint numEventsInWaitList,
222     const cl_event *eventWaitList,
223     cl_event *events)
224 {
225     CLBlasKargs kargs;
226 	#ifdef DEBUG_TBMV
227 	printf("DTBMV called\n");
228 	#endif
229 
230     memset(&kargs, 0, sizeof(kargs));
231     kargs.dtype = TYPE_DOUBLE;
232     kargs.pigFuncID = CLBLAS_TBMV;
233 
234     return doTbmv(&kargs, order, uplo, trans, diag, N, K, A, offa, lda, X, offx, incx, scratchBuff, numCommandQueues, commandQueues,
235                    numEventsInWaitList, eventWaitList, events);
236 }
237 
238 
239 clblasStatus
clblasCtbmv(clblasOrder order,clblasUplo uplo,clblasTranspose trans,clblasDiag diag,size_t N,size_t K,const cl_mem A,size_t offa,size_t lda,cl_mem X,size_t offx,int incx,cl_mem scratchBuff,cl_uint numCommandQueues,cl_command_queue * commandQueues,cl_uint numEventsInWaitList,const cl_event * eventWaitList,cl_event * events)240 clblasCtbmv(
241     clblasOrder order,
242     clblasUplo uplo,
243     clblasTranspose trans,
244     clblasDiag diag,
245     size_t N,
246     size_t K,
247     const cl_mem A,
248     size_t offa,
249     size_t lda,
250     cl_mem X,
251     size_t offx,
252     int incx,
253 	cl_mem scratchBuff,
254     cl_uint numCommandQueues,
255     cl_command_queue *commandQueues,
256     cl_uint numEventsInWaitList,
257     const cl_event *eventWaitList,
258     cl_event *events)
259 {
260     CLBlasKargs kargs;
261 	#ifdef DEBUG_TBMV
262 	printf("CTBMV called\n");
263 	#endif
264 
265     memset(&kargs, 0, sizeof(kargs));
266     kargs.dtype = TYPE_COMPLEX_FLOAT;
267     kargs.pigFuncID = CLBLAS_TBMV;
268 
269     return doTbmv(&kargs, order, uplo, trans, diag, N, K, A, offa, lda, X, offx, incx, scratchBuff, numCommandQueues, commandQueues,
270                    numEventsInWaitList, eventWaitList, events);
271 }
272 
273 clblasStatus
clblasZtbmv(clblasOrder order,clblasUplo uplo,clblasTranspose trans,clblasDiag diag,size_t N,size_t K,const cl_mem A,size_t offa,size_t lda,cl_mem X,size_t offx,int incx,cl_mem scratchBuff,cl_uint numCommandQueues,cl_command_queue * commandQueues,cl_uint numEventsInWaitList,const cl_event * eventWaitList,cl_event * events)274 clblasZtbmv(
275     clblasOrder order,
276     clblasUplo uplo,
277     clblasTranspose trans,
278     clblasDiag diag,
279     size_t N,
280     size_t K,
281     const cl_mem A,
282     size_t offa,
283     size_t lda,
284     cl_mem X,
285     size_t offx,
286     int incx,
287 	cl_mem scratchBuff,
288     cl_uint numCommandQueues,
289     cl_command_queue *commandQueues,
290     cl_uint numEventsInWaitList,
291     const cl_event *eventWaitList,
292     cl_event *events)
293 {
294     CLBlasKargs kargs;
295 	#ifdef DEBUG_TBMV
296 	printf("ZTBMV called\n");
297 	#endif
298 
299     memset(&kargs, 0, sizeof(kargs));
300     kargs.dtype = TYPE_COMPLEX_DOUBLE;
301     kargs.pigFuncID = CLBLAS_TBMV;
302 
303     return doTbmv(&kargs, order, uplo, trans, diag, N, K, A, offa, lda, X, offx, incx, scratchBuff, numCommandQueues, commandQueues,
304                    numEventsInWaitList, eventWaitList, events);
305 }
306