1 /* ************************************************************************
2  * Copyright 2013 Advanced Micro Devices, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  * ************************************************************************/
16 
17 
18 #include <string.h>
19 #include <clBLAS.h>
20 #include <devinfo.h>
21 
22 #include "clblas-internal.h"
23 #include "solution_seq.h"
24 
25 static clblasStatus
doGbmv(CLBlasKargs * kargs,clblasOrder order,clblasTranspose transA,size_t M,size_t N,size_t KL,size_t KU,const cl_mem A,size_t offa,size_t lda,const cl_mem x,size_t offx,int incx,cl_mem y,size_t offy,int incy,cl_uint numCommandQueues,cl_command_queue * commandQueues,cl_uint numEventsInWaitList,const cl_event * eventWaitList,cl_event * events)26 doGbmv(
27     CLBlasKargs *kargs,
28     clblasOrder order,
29     clblasTranspose transA,
30     size_t M,
31     size_t N,
32     size_t KL,
33     size_t KU,
34     const cl_mem A,
35     size_t offa,
36     size_t lda,
37     const cl_mem x,
38     size_t offx,
39     int incx,
40     cl_mem y,
41     size_t offy,
42     int incy,
43     cl_uint numCommandQueues,
44     cl_command_queue *commandQueues,
45     cl_uint numEventsInWaitList,
46     const cl_event *eventWaitList,
47     cl_event *events)
48 {
49     cl_int err;
50     ListHead seq;
51     size_t sizev;
52     clblasStatus retCode = clblasSuccess;
53 
54     if (!clblasInitialized) {
55         return clblasNotInitialized;
56     }
57     if ((commandQueues == NULL) || (numCommandQueues == 0))
58     {
59         return clblasInvalidValue;
60     }
61 
62     if (commandQueues[0] == NULL)
63     {
64         return clblasInvalidCommandQueue;
65     }
66 
67     if ((numEventsInWaitList !=0) && (eventWaitList == NULL))
68     {
69         return clblasInvalidEventWaitList;
70     }
71     /* Validate arguments */
72 
73     if ((retCode = checkMemObjects(A, x, y, true, A_MAT_ERRSET, X_VEC_ERRSET, Y_VEC_ERRSET )))
74     {
75         return retCode;
76     }
77     if ((retCode = checkBandedMatrixSizes(kargs->dtype, order, clblasNoTrans,
78                                             M, N, KL, KU, A, offa, lda, A_MAT_ERRSET ))) {
79         return retCode;
80     }
81     sizev = (transA == clblasNoTrans) ? N : M;
82     if ((retCode = checkVectorSizes(kargs->dtype, sizev, x, offx, incx, X_VEC_ERRSET ))) {
83         return retCode;
84     }
85     sizev = (transA == clblasNoTrans) ? M : N;
86     if ((retCode = checkVectorSizes(kargs->dtype, sizev, y, offy, incy, Y_VEC_ERRSET ))) {
87         return retCode;
88     }
89 
90     /* numCommandQueues will be hardcoded to 1 as of now. No multi-gpu support */
91     numCommandQueues = 1;
92 
93     kargs->order = order;
94     kargs->transA = transA;
95     kargs->M = M;
96     kargs->N = N;
97     kargs->KL = KL;
98     kargs->KU = KU;
99     kargs->A = A;
100     kargs->offA = offa;
101     kargs->offa = offa;
102     kargs->lda.matrix = lda;
103     kargs->B = x;
104     kargs->offBX = offx;
105     kargs->ldb.vector = incx;
106     kargs->C = y;
107     kargs->offCY = offy;
108     kargs->ldc.vector = incy;
109 
110     listInitHead(&seq);
111     err = makeSolutionSeq(CLBLAS_GBMV, kargs, numCommandQueues, commandQueues,
112                               numEventsInWaitList, eventWaitList, events, &seq);
113 
114     if (err == CL_SUCCESS) {
115         err = executeSolutionSeq(&seq);
116     }
117 
118     freeSolutionSeq(&seq);
119 
120     return (clblasStatus)err;
121 }
122 
123 clblasStatus
clblasSgbmv(clblasOrder order,clblasTranspose transA,size_t M,size_t N,size_t KL,size_t KU,cl_float alpha,const cl_mem A,size_t offa,size_t lda,const cl_mem x,size_t offx,int incx,cl_float beta,cl_mem y,size_t offy,int incy,cl_uint numCommandQueues,cl_command_queue * commandQueues,cl_uint numEventsInWaitList,const cl_event * eventWaitList,cl_event * events)124 clblasSgbmv(
125     clblasOrder order,
126     clblasTranspose transA,
127     size_t M,
128     size_t N,
129     size_t KL,
130     size_t KU,
131     cl_float alpha,
132     const cl_mem A,
133     size_t offa,
134     size_t lda,
135     const cl_mem x,
136     size_t offx,
137     int incx,
138     cl_float beta,
139     cl_mem y,
140     size_t offy,
141     int incy,
142     cl_uint numCommandQueues,
143     cl_command_queue *commandQueues,
144     cl_uint numEventsInWaitList,
145     const cl_event *eventWaitList,
146     cl_event *events)
147 {
148     CLBlasKargs kargs;
149 
150     memset(&kargs, 0, sizeof(kargs));
151     kargs.dtype = TYPE_FLOAT;
152     kargs.pigFuncID = CLBLAS_GBMV;
153     kargs.alpha.argFloat = alpha;
154     kargs.beta.argFloat = beta;
155 
156     return doGbmv(&kargs, order, transA, M, N, KL, KU, A, offa, lda, x, offx, incx,
157                   y, offy, incy, numCommandQueues, commandQueues,
158                   numEventsInWaitList, eventWaitList, events);
159 }
160 
161 clblasStatus
clblasDgbmv(clblasOrder order,clblasTranspose transA,size_t M,size_t N,size_t KL,size_t KU,cl_double alpha,const cl_mem A,size_t offa,size_t lda,const cl_mem x,size_t offx,int incx,cl_double beta,cl_mem y,size_t offy,int incy,cl_uint numCommandQueues,cl_command_queue * commandQueues,cl_uint numEventsInWaitList,const cl_event * eventWaitList,cl_event * events)162 clblasDgbmv(
163     clblasOrder order,
164     clblasTranspose transA,
165     size_t M,
166     size_t N,
167     size_t KL,
168     size_t KU,
169     cl_double alpha,
170     const cl_mem A,
171     size_t offa,
172     size_t lda,
173     const cl_mem x,
174     size_t offx,
175     int incx,
176     cl_double beta,
177     cl_mem y,
178     size_t offy,
179     int incy,
180     cl_uint numCommandQueues,
181     cl_command_queue *commandQueues,
182     cl_uint numEventsInWaitList,
183     const cl_event *eventWaitList,
184     cl_event *events)
185 {
186     CLBlasKargs kargs;
187 
188     memset(&kargs, 0, sizeof(kargs));
189     kargs.dtype = TYPE_DOUBLE;
190     kargs.pigFuncID = CLBLAS_GBMV;
191     kargs.alpha.argDouble = alpha;
192     kargs.beta.argDouble = beta;
193 
194     return doGbmv(&kargs, order, transA, M, N, KL, KU, A, offa, lda, x, offx, incx,
195                   y, offy, incy, numCommandQueues, commandQueues,
196                   numEventsInWaitList, eventWaitList, events);
197 }
198 
199 clblasStatus
clblasCgbmv(clblasOrder order,clblasTranspose transA,size_t M,size_t N,size_t KL,size_t KU,cl_float2 alpha,const cl_mem A,size_t offa,size_t lda,const cl_mem x,size_t offx,int incx,cl_float2 beta,cl_mem y,size_t offy,int incy,cl_uint numCommandQueues,cl_command_queue * commandQueues,cl_uint numEventsInWaitList,const cl_event * eventWaitList,cl_event * events)200 clblasCgbmv(
201     clblasOrder order,
202     clblasTranspose transA,
203     size_t M,
204     size_t N,
205     size_t KL,
206     size_t KU,
207     cl_float2 alpha,
208     const cl_mem A,
209     size_t offa,
210     size_t lda,
211     const cl_mem x,
212     size_t offx,
213     int incx,
214     cl_float2 beta,
215     cl_mem y,
216     size_t offy,
217     int incy,
218     cl_uint numCommandQueues,
219     cl_command_queue *commandQueues,
220     cl_uint numEventsInWaitList,
221     const cl_event *eventWaitList,
222     cl_event *events)
223 {
224     CLBlasKargs kargs;
225 
226     memset(&kargs, 0, sizeof(kargs));
227     kargs.dtype = TYPE_COMPLEX_FLOAT;
228     kargs.pigFuncID = CLBLAS_GBMV;
229     kargs.alpha.argFloatComplex = alpha;
230     kargs.beta.argFloatComplex = beta;
231 
232     return doGbmv(&kargs, order, transA, M, N, KL, KU, A, offa, lda, x, offx, incx,
233                   y, offy, incy, numCommandQueues, commandQueues,
234                   numEventsInWaitList, eventWaitList, events);
235 }
236 
237 clblasStatus
clblasZgbmv(clblasOrder order,clblasTranspose transA,size_t M,size_t N,size_t KL,size_t KU,cl_double2 alpha,const cl_mem A,size_t offa,size_t lda,const cl_mem x,size_t offx,int incx,cl_double2 beta,cl_mem y,size_t offy,int incy,cl_uint numCommandQueues,cl_command_queue * commandQueues,cl_uint numEventsInWaitList,const cl_event * eventWaitList,cl_event * events)238 clblasZgbmv(
239     clblasOrder order,
240     clblasTranspose transA,
241     size_t M,
242     size_t N,
243     size_t KL,
244     size_t KU,
245     cl_double2 alpha,
246     const cl_mem A,
247     size_t offa,
248     size_t lda,
249     const cl_mem x,
250     size_t offx,
251     int incx,
252     cl_double2 beta,
253     cl_mem y,
254     size_t offy,
255     int incy,
256     cl_uint numCommandQueues,
257     cl_command_queue *commandQueues,
258     cl_uint numEventsInWaitList,
259     const cl_event *eventWaitList,
260     cl_event *events)
261 {
262     CLBlasKargs kargs;
263 
264     memset(&kargs, 0, sizeof(kargs));
265     kargs.dtype = TYPE_COMPLEX_DOUBLE;
266     kargs.pigFuncID = CLBLAS_GBMV;
267     kargs.alpha.argDoubleComplex = alpha;
268     kargs.beta.argDoubleComplex = beta;
269 
270     return doGbmv(&kargs, order, transA, M, N, KL, KU, A, offa, lda, x, offx, incx,
271                   y, offy, incy, numCommandQueues, commandQueues,
272                   numEventsInWaitList, eventWaitList, events);
273 }
274 
275 
276 
277