1 /* ************************************************************************
2 * Copyright 2013 Advanced Micro Devices, Inc.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 * ************************************************************************/
16
17 //#define DEBUG_TBMV
18
19 #include <stdio.h>
20 #include <stdlib.h>
21 #include <string.h>
22 #include <clBLAS.h>
23
24 #include <devinfo.h>
25 #include "clblas-internal.h"
26 #include "solution_seq.h"
27
28 clblasStatus
doTbmv(CLBlasKargs * kargs,clblasOrder order,clblasUplo uplo,clblasTranspose trans,clblasDiag diag,size_t N,size_t K,const cl_mem A,size_t offa,size_t lda,cl_mem x,size_t offx,int incx,cl_mem y,cl_uint numCommandQueues,cl_command_queue * commandQueues,cl_uint numEventsInWaitList,const cl_event * eventWaitList,cl_event * events)29 doTbmv(
30 CLBlasKargs *kargs,
31 clblasOrder order,
32 clblasUplo uplo,
33 clblasTranspose trans,
34 clblasDiag diag,
35 size_t N,
36 size_t K,
37 const cl_mem A,
38 size_t offa,
39 size_t lda,
40 cl_mem x,
41 size_t offx,
42 int incx,
43 cl_mem y, // Scratch Buffer
44 cl_uint numCommandQueues,
45 cl_command_queue *commandQueues,
46 cl_uint numEventsInWaitList,
47 const cl_event *eventWaitList,
48 cl_event *events)
49 {
50 cl_int err;
51 ListHead seq;
52 size_t sizeOfVector;
53 cl_event *newEventWaitList;
54 clblasStatus retCode = clblasSuccess;
55
56 if (!clblasInitialized) {
57 return clblasNotInitialized;
58 }
59
60 /* Validate arguments */
61
62 if ((retCode = checkMemObjects(A, x, y, true, A_MAT_ERRSET, X_VEC_ERRSET, Y_VEC_ERRSET))) {
63 #ifdef DEBUG_TBMV
64 printf("Invalid mem object..\n");
65 #endif
66 return retCode;
67 }
68
69 if ((retCode = checkBandedMatrixSizes(kargs->dtype, order, trans, N, N, K, 0, A, offa, lda, A_MAT_ERRSET))) {
70 #ifdef DEBUG_TBMV
71 printf("Invalid Size for A\n");
72 #endif
73 return retCode;
74 }
75 if ((retCode = checkVectorSizes(kargs->dtype, N, x, offx, incx, X_VEC_ERRSET))) {
76 #ifdef DEBUG_TBMV
77 printf("Invalid Size for X\n");
78 #endif
79 return retCode;
80 }
81 if ((retCode = checkVectorSizes(kargs->dtype, N, y, 0, incx, Y_VEC_ERRSET))) {
82 #ifdef DEBUG_TBMV
83 printf("Invalid Size for scratch vector\n");
84 #endif
85 return retCode;
86 }
87
88 #ifdef DEBUG_TBMV
89 printf("DoTbmv being called...\n");
90 #endif
91
92 if ((commandQueues == NULL) || (numCommandQueues == 0))
93 {
94 return clblasInvalidValue;
95 }
96 numCommandQueues = 1;
97
98 if ((numEventsInWaitList !=0) && (eventWaitList == NULL))
99 {
100 return clblasInvalidEventWaitList;
101 }
102
103 newEventWaitList = malloc((numEventsInWaitList+1) * sizeof(cl_event));
104 if (newEventWaitList == NULL)
105 {
106 return clblasOutOfHostMemory;
107 }
108 if (numEventsInWaitList != 0 )
109 {
110 memcpy(newEventWaitList, eventWaitList, numEventsInWaitList*sizeof(cl_event));
111 }
112
113 /*
114 * ASSUMPTION:
115 * doTBMV assumes "commandQueue" of 0. The same is reflected in
116 * "makeSolutionSeq" as well. If either of them changes in future,
117 * this code needs to be revisited.
118 */
119 sizeOfVector = (1 + (N-1)*abs(incx)) * dtypeSize(kargs->dtype);
120 err = clEnqueueCopyBuffer(commandQueues[0], x, y, offx*dtypeSize(kargs->dtype), 0, sizeOfVector,
121 numEventsInWaitList, eventWaitList, &newEventWaitList[numEventsInWaitList]);
122 if (err != CL_SUCCESS)
123 {
124 free(newEventWaitList);
125 return err;
126 }
127
128 kargs->order = order;
129 kargs->uplo = uplo;
130 kargs->transA = trans;
131 kargs->diag = diag;
132 kargs->M = N;
133 kargs->N = N;
134 if( uplo == clblasUpper )
135 {
136 kargs->KL = 0;
137 kargs->KU = K;
138 }
139 else {
140 kargs->KL = K;
141 kargs->KU = 0;
142 }
143 kargs->A = A;
144 kargs->lda.matrix = lda;
145 kargs->B = y; // Now it becomes x = A * y
146 kargs->ldb.vector = incx;
147 kargs->C = x;
148 kargs->ldc.vector = incx;
149 kargs->offBX = 0; // Not used by assignKargs(); Just for clarity
150 kargs->offCY = offx;
151 kargs->offa = offa;
152 kargs->offA = offa;
153
154 #ifdef DEBUG_TBMV
155 printf("Calling makeSolutionSeq : TBMV\n");
156 #endif
157
158 listInitHead(&seq);
159 err = makeSolutionSeq(CLBLAS_GBMV, kargs, numCommandQueues, commandQueues,
160 numEventsInWaitList+1, newEventWaitList, events, &seq);
161 if (err == CL_SUCCESS) {
162 err = executeSolutionSeq(&seq);
163 }
164
165 freeSolutionSeq(&seq);
166 free(newEventWaitList);
167 return (clblasStatus)err;
168 }
169
170 clblasStatus
clblasStbmv(clblasOrder order,clblasUplo uplo,clblasTranspose trans,clblasDiag diag,size_t N,size_t K,const cl_mem A,size_t offa,size_t lda,cl_mem X,size_t offx,int incx,cl_mem scratchBuff,cl_uint numCommandQueues,cl_command_queue * commandQueues,cl_uint numEventsInWaitList,const cl_event * eventWaitList,cl_event * events)171 clblasStbmv(
172 clblasOrder order,
173 clblasUplo uplo,
174 clblasTranspose trans,
175 clblasDiag diag,
176 size_t N,
177 size_t K,
178 const cl_mem A,
179 size_t offa,
180 size_t lda,
181 cl_mem X,
182 size_t offx,
183 int incx,
184 cl_mem scratchBuff,
185 cl_uint numCommandQueues,
186 cl_command_queue *commandQueues,
187 cl_uint numEventsInWaitList,
188 const cl_event *eventWaitList,
189 cl_event *events)
190 {
191 CLBlasKargs kargs;
192 #ifdef DEBUG_TBMV
193 printf("STBMV Called\n");
194 #endif
195
196 memset(&kargs, 0, sizeof(kargs));
197 kargs.dtype = TYPE_FLOAT;
198 kargs.pigFuncID = CLBLAS_TBMV;
199
200 return doTbmv(&kargs, order, uplo, trans, diag, N, K, A, offa, lda, X, offx, incx, scratchBuff, numCommandQueues, commandQueues,
201 numEventsInWaitList, eventWaitList, events);
202 }
203
204 clblasStatus
clblasDtbmv(clblasOrder order,clblasUplo uplo,clblasTranspose trans,clblasDiag diag,size_t N,size_t K,const cl_mem A,size_t offa,size_t lda,cl_mem X,size_t offx,int incx,cl_mem scratchBuff,cl_uint numCommandQueues,cl_command_queue * commandQueues,cl_uint numEventsInWaitList,const cl_event * eventWaitList,cl_event * events)205 clblasDtbmv(
206 clblasOrder order,
207 clblasUplo uplo,
208 clblasTranspose trans,
209 clblasDiag diag,
210 size_t N,
211 size_t K,
212 const cl_mem A,
213 size_t offa,
214 size_t lda,
215 cl_mem X,
216 size_t offx,
217 int incx,
218 cl_mem scratchBuff,
219 cl_uint numCommandQueues,
220 cl_command_queue *commandQueues,
221 cl_uint numEventsInWaitList,
222 const cl_event *eventWaitList,
223 cl_event *events)
224 {
225 CLBlasKargs kargs;
226 #ifdef DEBUG_TBMV
227 printf("DTBMV called\n");
228 #endif
229
230 memset(&kargs, 0, sizeof(kargs));
231 kargs.dtype = TYPE_DOUBLE;
232 kargs.pigFuncID = CLBLAS_TBMV;
233
234 return doTbmv(&kargs, order, uplo, trans, diag, N, K, A, offa, lda, X, offx, incx, scratchBuff, numCommandQueues, commandQueues,
235 numEventsInWaitList, eventWaitList, events);
236 }
237
238
239 clblasStatus
clblasCtbmv(clblasOrder order,clblasUplo uplo,clblasTranspose trans,clblasDiag diag,size_t N,size_t K,const cl_mem A,size_t offa,size_t lda,cl_mem X,size_t offx,int incx,cl_mem scratchBuff,cl_uint numCommandQueues,cl_command_queue * commandQueues,cl_uint numEventsInWaitList,const cl_event * eventWaitList,cl_event * events)240 clblasCtbmv(
241 clblasOrder order,
242 clblasUplo uplo,
243 clblasTranspose trans,
244 clblasDiag diag,
245 size_t N,
246 size_t K,
247 const cl_mem A,
248 size_t offa,
249 size_t lda,
250 cl_mem X,
251 size_t offx,
252 int incx,
253 cl_mem scratchBuff,
254 cl_uint numCommandQueues,
255 cl_command_queue *commandQueues,
256 cl_uint numEventsInWaitList,
257 const cl_event *eventWaitList,
258 cl_event *events)
259 {
260 CLBlasKargs kargs;
261 #ifdef DEBUG_TBMV
262 printf("CTBMV called\n");
263 #endif
264
265 memset(&kargs, 0, sizeof(kargs));
266 kargs.dtype = TYPE_COMPLEX_FLOAT;
267 kargs.pigFuncID = CLBLAS_TBMV;
268
269 return doTbmv(&kargs, order, uplo, trans, diag, N, K, A, offa, lda, X, offx, incx, scratchBuff, numCommandQueues, commandQueues,
270 numEventsInWaitList, eventWaitList, events);
271 }
272
273 clblasStatus
clblasZtbmv(clblasOrder order,clblasUplo uplo,clblasTranspose trans,clblasDiag diag,size_t N,size_t K,const cl_mem A,size_t offa,size_t lda,cl_mem X,size_t offx,int incx,cl_mem scratchBuff,cl_uint numCommandQueues,cl_command_queue * commandQueues,cl_uint numEventsInWaitList,const cl_event * eventWaitList,cl_event * events)274 clblasZtbmv(
275 clblasOrder order,
276 clblasUplo uplo,
277 clblasTranspose trans,
278 clblasDiag diag,
279 size_t N,
280 size_t K,
281 const cl_mem A,
282 size_t offa,
283 size_t lda,
284 cl_mem X,
285 size_t offx,
286 int incx,
287 cl_mem scratchBuff,
288 cl_uint numCommandQueues,
289 cl_command_queue *commandQueues,
290 cl_uint numEventsInWaitList,
291 const cl_event *eventWaitList,
292 cl_event *events)
293 {
294 CLBlasKargs kargs;
295 #ifdef DEBUG_TBMV
296 printf("ZTBMV called\n");
297 #endif
298
299 memset(&kargs, 0, sizeof(kargs));
300 kargs.dtype = TYPE_COMPLEX_DOUBLE;
301 kargs.pigFuncID = CLBLAS_TBMV;
302
303 return doTbmv(&kargs, order, uplo, trans, diag, N, K, A, offa, lda, X, offx, incx, scratchBuff, numCommandQueues, commandQueues,
304 numEventsInWaitList, eventWaitList, events);
305 }
306