1 /* ************************************************************************
2 * Copyright 2013 Advanced Micro Devices, Inc.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 * ************************************************************************/
16
17
18 #include <string.h>
19 #include <clBLAS.h>
20 #include <devinfo.h>
21
22 #include "clblas-internal.h"
23 #include "solution_seq.h"
24
25 static clblasStatus
doGbmv(CLBlasKargs * kargs,clblasOrder order,clblasTranspose transA,size_t M,size_t N,size_t KL,size_t KU,const cl_mem A,size_t offa,size_t lda,const cl_mem x,size_t offx,int incx,cl_mem y,size_t offy,int incy,cl_uint numCommandQueues,cl_command_queue * commandQueues,cl_uint numEventsInWaitList,const cl_event * eventWaitList,cl_event * events)26 doGbmv(
27 CLBlasKargs *kargs,
28 clblasOrder order,
29 clblasTranspose transA,
30 size_t M,
31 size_t N,
32 size_t KL,
33 size_t KU,
34 const cl_mem A,
35 size_t offa,
36 size_t lda,
37 const cl_mem x,
38 size_t offx,
39 int incx,
40 cl_mem y,
41 size_t offy,
42 int incy,
43 cl_uint numCommandQueues,
44 cl_command_queue *commandQueues,
45 cl_uint numEventsInWaitList,
46 const cl_event *eventWaitList,
47 cl_event *events)
48 {
49 cl_int err;
50 ListHead seq;
51 size_t sizev;
52 clblasStatus retCode = clblasSuccess;
53
54 if (!clblasInitialized) {
55 return clblasNotInitialized;
56 }
57 if ((commandQueues == NULL) || (numCommandQueues == 0))
58 {
59 return clblasInvalidValue;
60 }
61
62 if (commandQueues[0] == NULL)
63 {
64 return clblasInvalidCommandQueue;
65 }
66
67 if ((numEventsInWaitList !=0) && (eventWaitList == NULL))
68 {
69 return clblasInvalidEventWaitList;
70 }
71 /* Validate arguments */
72
73 if ((retCode = checkMemObjects(A, x, y, true, A_MAT_ERRSET, X_VEC_ERRSET, Y_VEC_ERRSET )))
74 {
75 return retCode;
76 }
77 if ((retCode = checkBandedMatrixSizes(kargs->dtype, order, clblasNoTrans,
78 M, N, KL, KU, A, offa, lda, A_MAT_ERRSET ))) {
79 return retCode;
80 }
81 sizev = (transA == clblasNoTrans) ? N : M;
82 if ((retCode = checkVectorSizes(kargs->dtype, sizev, x, offx, incx, X_VEC_ERRSET ))) {
83 return retCode;
84 }
85 sizev = (transA == clblasNoTrans) ? M : N;
86 if ((retCode = checkVectorSizes(kargs->dtype, sizev, y, offy, incy, Y_VEC_ERRSET ))) {
87 return retCode;
88 }
89
90 /* numCommandQueues will be hardcoded to 1 as of now. No multi-gpu support */
91 numCommandQueues = 1;
92
93 kargs->order = order;
94 kargs->transA = transA;
95 kargs->M = M;
96 kargs->N = N;
97 kargs->KL = KL;
98 kargs->KU = KU;
99 kargs->A = A;
100 kargs->offA = offa;
101 kargs->offa = offa;
102 kargs->lda.matrix = lda;
103 kargs->B = x;
104 kargs->offBX = offx;
105 kargs->ldb.vector = incx;
106 kargs->C = y;
107 kargs->offCY = offy;
108 kargs->ldc.vector = incy;
109
110 listInitHead(&seq);
111 err = makeSolutionSeq(CLBLAS_GBMV, kargs, numCommandQueues, commandQueues,
112 numEventsInWaitList, eventWaitList, events, &seq);
113
114 if (err == CL_SUCCESS) {
115 err = executeSolutionSeq(&seq);
116 }
117
118 freeSolutionSeq(&seq);
119
120 return (clblasStatus)err;
121 }
122
123 clblasStatus
clblasSgbmv(clblasOrder order,clblasTranspose transA,size_t M,size_t N,size_t KL,size_t KU,cl_float alpha,const cl_mem A,size_t offa,size_t lda,const cl_mem x,size_t offx,int incx,cl_float beta,cl_mem y,size_t offy,int incy,cl_uint numCommandQueues,cl_command_queue * commandQueues,cl_uint numEventsInWaitList,const cl_event * eventWaitList,cl_event * events)124 clblasSgbmv(
125 clblasOrder order,
126 clblasTranspose transA,
127 size_t M,
128 size_t N,
129 size_t KL,
130 size_t KU,
131 cl_float alpha,
132 const cl_mem A,
133 size_t offa,
134 size_t lda,
135 const cl_mem x,
136 size_t offx,
137 int incx,
138 cl_float beta,
139 cl_mem y,
140 size_t offy,
141 int incy,
142 cl_uint numCommandQueues,
143 cl_command_queue *commandQueues,
144 cl_uint numEventsInWaitList,
145 const cl_event *eventWaitList,
146 cl_event *events)
147 {
148 CLBlasKargs kargs;
149
150 memset(&kargs, 0, sizeof(kargs));
151 kargs.dtype = TYPE_FLOAT;
152 kargs.pigFuncID = CLBLAS_GBMV;
153 kargs.alpha.argFloat = alpha;
154 kargs.beta.argFloat = beta;
155
156 return doGbmv(&kargs, order, transA, M, N, KL, KU, A, offa, lda, x, offx, incx,
157 y, offy, incy, numCommandQueues, commandQueues,
158 numEventsInWaitList, eventWaitList, events);
159 }
160
161 clblasStatus
clblasDgbmv(clblasOrder order,clblasTranspose transA,size_t M,size_t N,size_t KL,size_t KU,cl_double alpha,const cl_mem A,size_t offa,size_t lda,const cl_mem x,size_t offx,int incx,cl_double beta,cl_mem y,size_t offy,int incy,cl_uint numCommandQueues,cl_command_queue * commandQueues,cl_uint numEventsInWaitList,const cl_event * eventWaitList,cl_event * events)162 clblasDgbmv(
163 clblasOrder order,
164 clblasTranspose transA,
165 size_t M,
166 size_t N,
167 size_t KL,
168 size_t KU,
169 cl_double alpha,
170 const cl_mem A,
171 size_t offa,
172 size_t lda,
173 const cl_mem x,
174 size_t offx,
175 int incx,
176 cl_double beta,
177 cl_mem y,
178 size_t offy,
179 int incy,
180 cl_uint numCommandQueues,
181 cl_command_queue *commandQueues,
182 cl_uint numEventsInWaitList,
183 const cl_event *eventWaitList,
184 cl_event *events)
185 {
186 CLBlasKargs kargs;
187
188 memset(&kargs, 0, sizeof(kargs));
189 kargs.dtype = TYPE_DOUBLE;
190 kargs.pigFuncID = CLBLAS_GBMV;
191 kargs.alpha.argDouble = alpha;
192 kargs.beta.argDouble = beta;
193
194 return doGbmv(&kargs, order, transA, M, N, KL, KU, A, offa, lda, x, offx, incx,
195 y, offy, incy, numCommandQueues, commandQueues,
196 numEventsInWaitList, eventWaitList, events);
197 }
198
199 clblasStatus
clblasCgbmv(clblasOrder order,clblasTranspose transA,size_t M,size_t N,size_t KL,size_t KU,cl_float2 alpha,const cl_mem A,size_t offa,size_t lda,const cl_mem x,size_t offx,int incx,cl_float2 beta,cl_mem y,size_t offy,int incy,cl_uint numCommandQueues,cl_command_queue * commandQueues,cl_uint numEventsInWaitList,const cl_event * eventWaitList,cl_event * events)200 clblasCgbmv(
201 clblasOrder order,
202 clblasTranspose transA,
203 size_t M,
204 size_t N,
205 size_t KL,
206 size_t KU,
207 cl_float2 alpha,
208 const cl_mem A,
209 size_t offa,
210 size_t lda,
211 const cl_mem x,
212 size_t offx,
213 int incx,
214 cl_float2 beta,
215 cl_mem y,
216 size_t offy,
217 int incy,
218 cl_uint numCommandQueues,
219 cl_command_queue *commandQueues,
220 cl_uint numEventsInWaitList,
221 const cl_event *eventWaitList,
222 cl_event *events)
223 {
224 CLBlasKargs kargs;
225
226 memset(&kargs, 0, sizeof(kargs));
227 kargs.dtype = TYPE_COMPLEX_FLOAT;
228 kargs.pigFuncID = CLBLAS_GBMV;
229 kargs.alpha.argFloatComplex = alpha;
230 kargs.beta.argFloatComplex = beta;
231
232 return doGbmv(&kargs, order, transA, M, N, KL, KU, A, offa, lda, x, offx, incx,
233 y, offy, incy, numCommandQueues, commandQueues,
234 numEventsInWaitList, eventWaitList, events);
235 }
236
237 clblasStatus
clblasZgbmv(clblasOrder order,clblasTranspose transA,size_t M,size_t N,size_t KL,size_t KU,cl_double2 alpha,const cl_mem A,size_t offa,size_t lda,const cl_mem x,size_t offx,int incx,cl_double2 beta,cl_mem y,size_t offy,int incy,cl_uint numCommandQueues,cl_command_queue * commandQueues,cl_uint numEventsInWaitList,const cl_event * eventWaitList,cl_event * events)238 clblasZgbmv(
239 clblasOrder order,
240 clblasTranspose transA,
241 size_t M,
242 size_t N,
243 size_t KL,
244 size_t KU,
245 cl_double2 alpha,
246 const cl_mem A,
247 size_t offa,
248 size_t lda,
249 const cl_mem x,
250 size_t offx,
251 int incx,
252 cl_double2 beta,
253 cl_mem y,
254 size_t offy,
255 int incy,
256 cl_uint numCommandQueues,
257 cl_command_queue *commandQueues,
258 cl_uint numEventsInWaitList,
259 const cl_event *eventWaitList,
260 cl_event *events)
261 {
262 CLBlasKargs kargs;
263
264 memset(&kargs, 0, sizeof(kargs));
265 kargs.dtype = TYPE_COMPLEX_DOUBLE;
266 kargs.pigFuncID = CLBLAS_GBMV;
267 kargs.alpha.argDoubleComplex = alpha;
268 kargs.beta.argDoubleComplex = beta;
269
270 return doGbmv(&kargs, order, transA, M, N, KL, KU, A, offa, lda, x, offx, incx,
271 y, offy, incy, numCommandQueues, commandQueues,
272 numEventsInWaitList, eventWaitList, events);
273 }
274
275
276
277