1 
2 // =================================================================================================
3 // This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
4 // project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
5 // width of 100 characters per line.
6 //
7 // Author(s):
8 //   Cedric Nugteren <www.cedricnugteren.nl>
9 //
10 // This file contains the interface to the CLBlast BLAS routines. It also contains the definitions
11 // of the returned status codes and the layout and transpose types. This is the only header users
12 // of CLBlast should include and use.
13 //
14 // =================================================================================================
15 
16 #ifndef CLBLAST_CLBLAST_H_
17 #define CLBLAST_CLBLAST_H_
18 
19 #include <cstdlib> // For size_t
20 #include <string> // For OverrideParameters function
21 #include <unordered_map> // For OverrideParameters function
22 
23 // Includes the normal OpenCL C header
24 #if defined(__APPLE__) || defined(__MACOSX)
25   #include <OpenCL/opencl.h>
26 #else
27   #include <CL/opencl.h>
28 #endif
29 
30 // Exports library functions under Windows when building a DLL. See also:
31 // https://msdn.microsoft.com/en-us/library/a90k134d.aspx
32 #if defined(_WIN32) && defined(CLBLAST_DLL)
33   #if defined(COMPILING_DLL)
34     #define PUBLIC_API __declspec(dllexport)
35   #else
36     #define PUBLIC_API __declspec(dllimport)
37   #endif
38 #else
39   #define PUBLIC_API
40 #endif
41 
42 namespace clblast {
43 // =================================================================================================
44 
// Status codes that the functions declared in this header can return. Every numeric value is
// taken from either the standard OpenCL error codes or the clBLAS error codes, so the two can be
// mapped onto each other one-to-one. The underlying type is spelled out as int (the default for a
// scoped enum) to make the value range explicit.
enum class StatusCode : int {

  // Values shared with the OpenCL standard (matching CL_* constant given on each line)
  kSuccess                    =     0, // CL_SUCCESS
  kOpenCLCompilerNotAvailable =    -3, // CL_COMPILER_NOT_AVAILABLE
  kTempBufferAllocFailure     =    -4, // CL_MEM_OBJECT_ALLOCATION_FAILURE
  kOpenCLOutOfResources       =    -5, // CL_OUT_OF_RESOURCES
  kOpenCLOutOfHostMemory      =    -6, // CL_OUT_OF_HOST_MEMORY
  kOpenCLBuildProgramFailure  =   -11, // CL_BUILD_PROGRAM_FAILURE: OpenCL compilation error
  kInvalidValue               =   -30, // CL_INVALID_VALUE
  kInvalidCommandQueue        =   -36, // CL_INVALID_COMMAND_QUEUE
  kInvalidMemObject           =   -38, // CL_INVALID_MEM_OBJECT
  kInvalidBinary              =   -42, // CL_INVALID_BINARY
  kInvalidBuildOptions        =   -43, // CL_INVALID_BUILD_OPTIONS
  kInvalidProgram             =   -44, // CL_INVALID_PROGRAM
  kInvalidProgramExecutable   =   -45, // CL_INVALID_PROGRAM_EXECUTABLE
  kInvalidKernelName          =   -46, // CL_INVALID_KERNEL_NAME
  kInvalidKernelDefinition    =   -47, // CL_INVALID_KERNEL_DEFINITION
  kInvalidKernel              =   -48, // CL_INVALID_KERNEL
  kInvalidArgIndex            =   -49, // CL_INVALID_ARG_INDEX
  kInvalidArgValue            =   -50, // CL_INVALID_ARG_VALUE
  kInvalidArgSize             =   -51, // CL_INVALID_ARG_SIZE
  kInvalidKernelArgs          =   -52, // CL_INVALID_KERNEL_ARGS
  kInvalidLocalNumDimensions  =   -53, // CL_INVALID_WORK_DIMENSION: too many thread dimensions
  kInvalidLocalThreadsTotal   =   -54, // CL_INVALID_WORK_GROUP_SIZE: too many threads in total
  kInvalidLocalThreadsDim     =   -55, // CL_INVALID_WORK_ITEM_SIZE: too many in one dimension
  kInvalidGlobalOffset        =   -56, // CL_INVALID_GLOBAL_OFFSET
  kInvalidEventWaitList       =   -57, // CL_INVALID_EVENT_WAIT_LIST
  kInvalidEvent               =   -58, // CL_INVALID_EVENT
  kInvalidOperation           =   -59, // CL_INVALID_OPERATION
  kInvalidBufferSize          =   -61, // CL_INVALID_BUFFER_SIZE
  kInvalidGlobalWorkSize      =   -63, // CL_INVALID_GLOBAL_WORK_SIZE

  // Values shared with the clBLAS library
  kNotImplemented             = -1024, // Routine or functionality not implemented yet
  kInvalidMatrixA             = -1022, // Matrix A is not a valid OpenCL buffer
  kInvalidMatrixB             = -1021, // Matrix B is not a valid OpenCL buffer
  kInvalidMatrixC             = -1020, // Matrix C is not a valid OpenCL buffer
  kInvalidVectorX             = -1019, // Vector X is not a valid OpenCL buffer
  kInvalidVectorY             = -1018, // Vector Y is not a valid OpenCL buffer
  kInvalidDimension           = -1017, // Dimensions M, N, and K must be larger than zero
  kInvalidLeadDimA            = -1016, // LD of A is below the matrix's first dimension
  kInvalidLeadDimB            = -1015, // LD of B is below the matrix's first dimension
  kInvalidLeadDimC            = -1014, // LD of C is below the matrix's first dimension
  kInvalidIncrementX          = -1013, // Increment of vector X must not be zero
  kInvalidIncrementY          = -1012, // Increment of vector Y must not be zero
  kInsufficientMemoryA        = -1011, // OpenCL buffer of matrix A is too small
  kInsufficientMemoryB        = -1010, // OpenCL buffer of matrix B is too small
  kInsufficientMemoryC        = -1009, // OpenCL buffer of matrix C is too small
  kInsufficientMemoryX        = -1008, // OpenCL buffer of vector X is too small
  kInsufficientMemoryY        = -1007, // OpenCL buffer of vector Y is too small

  // Values specific to CLBlast
  kInvalidBatchCount          = -2049, // Batch count must be positive
  kInvalidOverrideKernel      = -2048, // Parameter override targets an invalid kernel
  kMissingOverrideParameter   = -2047, // Override parameter(s) missing for the target kernel
  kInvalidLocalMemUsage       = -2046, // Device has too little local memory
  kNoHalfPrecision            = -2045, // Device lacks half precision (16-bit) support
  kNoDoublePrecision          = -2044, // Device lacks double precision (64-bit) support
  kInvalidVectorScalar        = -2043, // Unit-sized vector is not a valid OpenCL buffer
  kInsufficientMemoryScalar   = -2042, // OpenCL buffer of the unit-sized vector is too small
  kDatabaseError              = -2041, // No database entry was found for the device
  kUnknownError               = -2040, // Catch-all for an unspecified error
  kUnexpectedError            = -2039, // Catch-all for an unexpected exception
};
112 
// Scoped enums that select matrix storage order and operation variants. Each enum owns its own
// block of ten values (101.., 111.., 121.., 131.., 141..). NOTE(review): these look like they
// mirror the CBLAS enum values; confirm before relying on that elsewhere.
enum class Layout : int {
  kRowMajor = 101,
  kColMajor = 102
};
enum class Transpose : int {
  kNo        = 111,
  kYes       = 112,
  kConjugate = 113
};
enum class Triangle : int {
  kUpper = 121,
  kLower = 122
};
enum class Diagonal : int {
  kNonUnit = 131,
  kUnit    = 132
};
enum class Side : int {
  kLeft  = 141,
  kRight = 142
};
119 
// Data precision selector, with values expressed in bits. The complex entries concatenate the
// width of their two components (32|32 and 64|64); kAny carries the sentinel value -1.
enum class Precision : int {
  kHalf          = 16,
  kSingle        = 32,
  kDouble        = 64,
  kComplexSingle = 3232,
  kComplexDouble = 6464,
  kAny           = -1
};
123 
124 // =================================================================================================
125 // BLAS level-1 (vector-vector) routines
126 // =================================================================================================
127 
128 // Generate givens plane rotation: SROTG/DROTG
129 template <typename T>
130 StatusCode Rotg(cl_mem sa_buffer, const size_t sa_offset,
131                 cl_mem sb_buffer, const size_t sb_offset,
132                 cl_mem sc_buffer, const size_t sc_offset,
133                 cl_mem ss_buffer, const size_t ss_offset,
134                 cl_command_queue* queue, cl_event* event = nullptr);
135 
136 // Generate modified givens plane rotation: SROTMG/DROTMG
137 template <typename T>
138 StatusCode Rotmg(cl_mem sd1_buffer, const size_t sd1_offset,
139                  cl_mem sd2_buffer, const size_t sd2_offset,
140                  cl_mem sx1_buffer, const size_t sx1_offset,
141                  const cl_mem sy1_buffer, const size_t sy1_offset,
142                  cl_mem sparam_buffer, const size_t sparam_offset,
143                  cl_command_queue* queue, cl_event* event = nullptr);
144 
145 // Apply givens plane rotation: SROT/DROT
146 template <typename T>
147 StatusCode Rot(const size_t n,
148                cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
149                cl_mem y_buffer, const size_t y_offset, const size_t y_inc,
150                const T cos,
151                const T sin,
152                cl_command_queue* queue, cl_event* event = nullptr);
153 
154 // Apply modified givens plane rotation: SROTM/DROTM
155 template <typename T>
156 StatusCode Rotm(const size_t n,
157                 cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
158                 cl_mem y_buffer, const size_t y_offset, const size_t y_inc,
159                 cl_mem sparam_buffer, const size_t sparam_offset,
160                 cl_command_queue* queue, cl_event* event = nullptr);
161 
162 // Swap two vectors: SSWAP/DSWAP/CSWAP/ZSWAP/HSWAP
163 template <typename T>
164 StatusCode Swap(const size_t n,
165                 cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
166                 cl_mem y_buffer, const size_t y_offset, const size_t y_inc,
167                 cl_command_queue* queue, cl_event* event = nullptr);
168 
169 // Vector scaling: SSCAL/DSCAL/CSCAL/ZSCAL/HSCAL
170 template <typename T>
171 StatusCode Scal(const size_t n,
172                 const T alpha,
173                 cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
174                 cl_command_queue* queue, cl_event* event = nullptr);
175 
176 // Vector copy: SCOPY/DCOPY/CCOPY/ZCOPY/HCOPY
177 template <typename T>
178 StatusCode Copy(const size_t n,
179                 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
180                 cl_mem y_buffer, const size_t y_offset, const size_t y_inc,
181                 cl_command_queue* queue, cl_event* event = nullptr);
182 
183 // Vector-times-constant plus vector: SAXPY/DAXPY/CAXPY/ZAXPY/HAXPY
184 template <typename T>
185 StatusCode Axpy(const size_t n,
186                 const T alpha,
187                 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
188                 cl_mem y_buffer, const size_t y_offset, const size_t y_inc,
189                 cl_command_queue* queue, cl_event* event = nullptr);
190 
191 // Dot product of two vectors: SDOT/DDOT/HDOT
192 template <typename T>
193 StatusCode Dot(const size_t n,
194                cl_mem dot_buffer, const size_t dot_offset,
195                const cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
196                const cl_mem y_buffer, const size_t y_offset, const size_t y_inc,
197                cl_command_queue* queue, cl_event* event = nullptr);
198 
199 // Dot product of two complex vectors: CDOTU/ZDOTU
200 template <typename T>
201 StatusCode Dotu(const size_t n,
202                 cl_mem dot_buffer, const size_t dot_offset,
203                 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
204                 const cl_mem y_buffer, const size_t y_offset, const size_t y_inc,
205                 cl_command_queue* queue, cl_event* event = nullptr);
206 
207 // Dot product of two complex vectors, one conjugated: CDOTC/ZDOTC
208 template <typename T>
209 StatusCode Dotc(const size_t n,
210                 cl_mem dot_buffer, const size_t dot_offset,
211                 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
212                 const cl_mem y_buffer, const size_t y_offset, const size_t y_inc,
213                 cl_command_queue* queue, cl_event* event = nullptr);
214 
215 // Euclidian norm of a vector: SNRM2/DNRM2/ScNRM2/DzNRM2/HNRM2
216 template <typename T>
217 StatusCode Nrm2(const size_t n,
218                 cl_mem nrm2_buffer, const size_t nrm2_offset,
219                 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
220                 cl_command_queue* queue, cl_event* event = nullptr);
221 
222 // Absolute sum of values in a vector: SASUM/DASUM/ScASUM/DzASUM/HASUM
223 template <typename T>
224 StatusCode Asum(const size_t n,
225                 cl_mem asum_buffer, const size_t asum_offset,
226                 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
227                 cl_command_queue* queue, cl_event* event = nullptr);
228 
229 // Sum of values in a vector (non-BLAS function): SSUM/DSUM/ScSUM/DzSUM/HSUM
230 template <typename T>
231 StatusCode Sum(const size_t n,
232                cl_mem sum_buffer, const size_t sum_offset,
233                const cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
234                cl_command_queue* queue, cl_event* event = nullptr);
235 
236 // Index of absolute maximum value in a vector: iSAMAX/iDAMAX/iCAMAX/iZAMAX/iHAMAX
237 template <typename T>
238 StatusCode Amax(const size_t n,
239                 cl_mem imax_buffer, const size_t imax_offset,
240                 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
241                 cl_command_queue* queue, cl_event* event = nullptr);
242 
243 // Index of absolute minimum value in a vector (non-BLAS function): iSAMIN/iDAMIN/iCAMIN/iZAMIN/iHAMIN
244 template <typename T>
245 StatusCode Amin(const size_t n,
246                 cl_mem imin_buffer, const size_t imin_offset,
247                 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
248                 cl_command_queue* queue, cl_event* event = nullptr);
249 
250 // Index of maximum value in a vector (non-BLAS function): iSMAX/iDMAX/iCMAX/iZMAX/iHMAX
251 template <typename T>
252 StatusCode Max(const size_t n,
253                cl_mem imax_buffer, const size_t imax_offset,
254                const cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
255                cl_command_queue* queue, cl_event* event = nullptr);
256 
257 // Index of minimum value in a vector (non-BLAS function): iSMIN/iDMIN/iCMIN/iZMIN/iHMIN
258 template <typename T>
259 StatusCode Min(const size_t n,
260                cl_mem imin_buffer, const size_t imin_offset,
261                const cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
262                cl_command_queue* queue, cl_event* event = nullptr);
263 
264 // =================================================================================================
265 // BLAS level-2 (matrix-vector) routines
266 // =================================================================================================
267 
268 // General matrix-vector multiplication: SGEMV/DGEMV/CGEMV/ZGEMV/HGEMV
269 template <typename T>
270 StatusCode Gemv(const Layout layout, const Transpose a_transpose,
271                 const size_t m, const size_t n,
272                 const T alpha,
273                 const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
274                 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
275                 const T beta,
276                 cl_mem y_buffer, const size_t y_offset, const size_t y_inc,
277                 cl_command_queue* queue, cl_event* event = nullptr);
278 
279 // General banded matrix-vector multiplication: SGBMV/DGBMV/CGBMV/ZGBMV/HGBMV
280 template <typename T>
281 StatusCode Gbmv(const Layout layout, const Transpose a_transpose,
282                 const size_t m, const size_t n, const size_t kl, const size_t ku,
283                 const T alpha,
284                 const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
285                 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
286                 const T beta,
287                 cl_mem y_buffer, const size_t y_offset, const size_t y_inc,
288                 cl_command_queue* queue, cl_event* event = nullptr);
289 
290 // Hermitian matrix-vector multiplication: CHEMV/ZHEMV
291 template <typename T>
292 StatusCode Hemv(const Layout layout, const Triangle triangle,
293                 const size_t n,
294                 const T alpha,
295                 const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
296                 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
297                 const T beta,
298                 cl_mem y_buffer, const size_t y_offset, const size_t y_inc,
299                 cl_command_queue* queue, cl_event* event = nullptr);
300 
301 // Hermitian banded matrix-vector multiplication: CHBMV/ZHBMV
302 template <typename T>
303 StatusCode Hbmv(const Layout layout, const Triangle triangle,
304                 const size_t n, const size_t k,
305                 const T alpha,
306                 const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
307                 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
308                 const T beta,
309                 cl_mem y_buffer, const size_t y_offset, const size_t y_inc,
310                 cl_command_queue* queue, cl_event* event = nullptr);
311 
312 // Hermitian packed matrix-vector multiplication: CHPMV/ZHPMV
313 template <typename T>
314 StatusCode Hpmv(const Layout layout, const Triangle triangle,
315                 const size_t n,
316                 const T alpha,
317                 const cl_mem ap_buffer, const size_t ap_offset,
318                 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
319                 const T beta,
320                 cl_mem y_buffer, const size_t y_offset, const size_t y_inc,
321                 cl_command_queue* queue, cl_event* event = nullptr);
322 
323 // Symmetric matrix-vector multiplication: SSYMV/DSYMV/HSYMV
324 template <typename T>
325 StatusCode Symv(const Layout layout, const Triangle triangle,
326                 const size_t n,
327                 const T alpha,
328                 const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
329                 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
330                 const T beta,
331                 cl_mem y_buffer, const size_t y_offset, const size_t y_inc,
332                 cl_command_queue* queue, cl_event* event = nullptr);
333 
334 // Symmetric banded matrix-vector multiplication: SSBMV/DSBMV/HSBMV
335 template <typename T>
336 StatusCode Sbmv(const Layout layout, const Triangle triangle,
337                 const size_t n, const size_t k,
338                 const T alpha,
339                 const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
340                 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
341                 const T beta,
342                 cl_mem y_buffer, const size_t y_offset, const size_t y_inc,
343                 cl_command_queue* queue, cl_event* event = nullptr);
344 
345 // Symmetric packed matrix-vector multiplication: SSPMV/DSPMV/HSPMV
346 template <typename T>
347 StatusCode Spmv(const Layout layout, const Triangle triangle,
348                 const size_t n,
349                 const T alpha,
350                 const cl_mem ap_buffer, const size_t ap_offset,
351                 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
352                 const T beta,
353                 cl_mem y_buffer, const size_t y_offset, const size_t y_inc,
354                 cl_command_queue* queue, cl_event* event = nullptr);
355 
356 // Triangular matrix-vector multiplication: STRMV/DTRMV/CTRMV/ZTRMV/HTRMV
357 template <typename T>
358 StatusCode Trmv(const Layout layout, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal,
359                 const size_t n,
360                 const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
361                 cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
362                 cl_command_queue* queue, cl_event* event = nullptr);
363 
364 // Triangular banded matrix-vector multiplication: STBMV/DTBMV/CTBMV/ZTBMV/HTBMV
365 template <typename T>
366 StatusCode Tbmv(const Layout layout, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal,
367                 const size_t n, const size_t k,
368                 const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
369                 cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
370                 cl_command_queue* queue, cl_event* event = nullptr);
371 
372 // Triangular packed matrix-vector multiplication: STPMV/DTPMV/CTPMV/ZTPMV/HTPMV
373 template <typename T>
374 StatusCode Tpmv(const Layout layout, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal,
375                 const size_t n,
376                 const cl_mem ap_buffer, const size_t ap_offset,
377                 cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
378                 cl_command_queue* queue, cl_event* event = nullptr);
379 
380 // Solves a triangular system of equations: STRSV/DTRSV/CTRSV/ZTRSV
381 template <typename T>
382 StatusCode Trsv(const Layout layout, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal,
383                 const size_t n,
384                 const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
385                 cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
386                 cl_command_queue* queue, cl_event* event = nullptr);
387 
388 // Solves a banded triangular system of equations: STBSV/DTBSV/CTBSV/ZTBSV
389 template <typename T>
390 StatusCode Tbsv(const Layout layout, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal,
391                 const size_t n, const size_t k,
392                 const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
393                 cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
394                 cl_command_queue* queue, cl_event* event = nullptr);
395 
396 // Solves a packed triangular system of equations: STPSV/DTPSV/CTPSV/ZTPSV
397 template <typename T>
398 StatusCode Tpsv(const Layout layout, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal,
399                 const size_t n,
400                 const cl_mem ap_buffer, const size_t ap_offset,
401                 cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
402                 cl_command_queue* queue, cl_event* event = nullptr);
403 
404 // General rank-1 matrix update: SGER/DGER/HGER
405 template <typename T>
406 StatusCode Ger(const Layout layout,
407                const size_t m, const size_t n,
408                const T alpha,
409                const cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
410                const cl_mem y_buffer, const size_t y_offset, const size_t y_inc,
411                cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
412                cl_command_queue* queue, cl_event* event = nullptr);
413 
414 // General rank-1 complex matrix update: CGERU/ZGERU
415 template <typename T>
416 StatusCode Geru(const Layout layout,
417                 const size_t m, const size_t n,
418                 const T alpha,
419                 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
420                 const cl_mem y_buffer, const size_t y_offset, const size_t y_inc,
421                 cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
422                 cl_command_queue* queue, cl_event* event = nullptr);
423 
424 // General rank-1 complex conjugated matrix update: CGERC/ZGERC
425 template <typename T>
426 StatusCode Gerc(const Layout layout,
427                 const size_t m, const size_t n,
428                 const T alpha,
429                 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
430                 const cl_mem y_buffer, const size_t y_offset, const size_t y_inc,
431                 cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
432                 cl_command_queue* queue, cl_event* event = nullptr);
433 
434 // Hermitian rank-1 matrix update: CHER/ZHER
435 template <typename T>
436 StatusCode Her(const Layout layout, const Triangle triangle,
437                const size_t n,
438                const T alpha,
439                const cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
440                cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
441                cl_command_queue* queue, cl_event* event = nullptr);
442 
443 // Hermitian packed rank-1 matrix update: CHPR/ZHPR
444 template <typename T>
445 StatusCode Hpr(const Layout layout, const Triangle triangle,
446                const size_t n,
447                const T alpha,
448                const cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
449                cl_mem ap_buffer, const size_t ap_offset,
450                cl_command_queue* queue, cl_event* event = nullptr);
451 
452 // Hermitian rank-2 matrix update: CHER2/ZHER2
453 template <typename T>
454 StatusCode Her2(const Layout layout, const Triangle triangle,
455                 const size_t n,
456                 const T alpha,
457                 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
458                 const cl_mem y_buffer, const size_t y_offset, const size_t y_inc,
459                 cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
460                 cl_command_queue* queue, cl_event* event = nullptr);
461 
462 // Hermitian packed rank-2 matrix update: CHPR2/ZHPR2
463 template <typename T>
464 StatusCode Hpr2(const Layout layout, const Triangle triangle,
465                 const size_t n,
466                 const T alpha,
467                 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
468                 const cl_mem y_buffer, const size_t y_offset, const size_t y_inc,
469                 cl_mem ap_buffer, const size_t ap_offset,
470                 cl_command_queue* queue, cl_event* event = nullptr);
471 
472 // Symmetric rank-1 matrix update: SSYR/DSYR/HSYR
473 template <typename T>
474 StatusCode Syr(const Layout layout, const Triangle triangle,
475                const size_t n,
476                const T alpha,
477                const cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
478                cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
479                cl_command_queue* queue, cl_event* event = nullptr);
480 
481 // Symmetric packed rank-1 matrix update: SSPR/DSPR/HSPR
482 template <typename T>
483 StatusCode Spr(const Layout layout, const Triangle triangle,
484                const size_t n,
485                const T alpha,
486                const cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
487                cl_mem ap_buffer, const size_t ap_offset,
488                cl_command_queue* queue, cl_event* event = nullptr);
489 
490 // Symmetric rank-2 matrix update: SSYR2/DSYR2/HSYR2
491 template <typename T>
492 StatusCode Syr2(const Layout layout, const Triangle triangle,
493                 const size_t n,
494                 const T alpha,
495                 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
496                 const cl_mem y_buffer, const size_t y_offset, const size_t y_inc,
497                 cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
498                 cl_command_queue* queue, cl_event* event = nullptr);
499 
500 // Symmetric packed rank-2 matrix update: SSPR2/DSPR2/HSPR2
501 template <typename T>
502 StatusCode Spr2(const Layout layout, const Triangle triangle,
503                 const size_t n,
504                 const T alpha,
505                 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
506                 const cl_mem y_buffer, const size_t y_offset, const size_t y_inc,
507                 cl_mem ap_buffer, const size_t ap_offset,
508                 cl_command_queue* queue, cl_event* event = nullptr);
509 
510 // =================================================================================================
511 // BLAS level-3 (matrix-matrix) routines
512 // =================================================================================================
513 
514 // General matrix-matrix multiplication: SGEMM/DGEMM/CGEMM/ZGEMM/HGEMM
515 template <typename T>
516 StatusCode Gemm(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
517                 const size_t m, const size_t n, const size_t k,
518                 const T alpha,
519                 const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
520                 const cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
521                 const T beta,
522                 cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
523                 cl_command_queue* queue, cl_event* event = nullptr);
524 
525 // Symmetric matrix-matrix multiplication: SSYMM/DSYMM/CSYMM/ZSYMM/HSYMM
526 template <typename T>
527 StatusCode Symm(const Layout layout, const Side side, const Triangle triangle,
528                 const size_t m, const size_t n,
529                 const T alpha,
530                 const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
531                 const cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
532                 const T beta,
533                 cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
534                 cl_command_queue* queue, cl_event* event = nullptr);
535 
536 // Hermitian matrix-matrix multiplication: CHEMM/ZHEMM
537 template <typename T>
538 StatusCode Hemm(const Layout layout, const Side side, const Triangle triangle,
539                 const size_t m, const size_t n,
540                 const T alpha,
541                 const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
542                 const cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
543                 const T beta,
544                 cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
545                 cl_command_queue* queue, cl_event* event = nullptr);
546 
547 // Rank-K update of a symmetric matrix: SSYRK/DSYRK/CSYRK/ZSYRK/HSYRK
548 template <typename T>
549 StatusCode Syrk(const Layout layout, const Triangle triangle, const Transpose a_transpose,
550                 const size_t n, const size_t k,
551                 const T alpha,
552                 const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
553                 const T beta,
554                 cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
555                 cl_command_queue* queue, cl_event* event = nullptr);
556 
557 // Rank-K update of a hermitian matrix: CHERK/ZHERK
558 template <typename T>
559 StatusCode Herk(const Layout layout, const Triangle triangle, const Transpose a_transpose,
560                 const size_t n, const size_t k,
561                 const T alpha,
562                 const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
563                 const T beta,
564                 cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
565                 cl_command_queue* queue, cl_event* event = nullptr);
566 
567 // Rank-2K update of a symmetric matrix: SSYR2K/DSYR2K/CSYR2K/ZSYR2K/HSYR2K
568 template <typename T>
569 StatusCode Syr2k(const Layout layout, const Triangle triangle, const Transpose ab_transpose,
570                  const size_t n, const size_t k,
571                  const T alpha,
572                  const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
573                  const cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
574                  const T beta,
575                  cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
576                  cl_command_queue* queue, cl_event* event = nullptr);
577 
578 // Rank-2K update of a hermitian matrix: CHER2K/ZHER2K
579 template <typename T, typename U>
580 StatusCode Her2k(const Layout layout, const Triangle triangle, const Transpose ab_transpose,
581                  const size_t n, const size_t k,
582                  const T alpha,
583                  const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
584                  const cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
585                  const U beta,
586                  cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
587                  cl_command_queue* queue, cl_event* event = nullptr);
588 
589 // Triangular matrix-matrix multiplication: STRMM/DTRMM/CTRMM/ZTRMM/HTRMM
590 template <typename T>
591 StatusCode Trmm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal,
592                 const size_t m, const size_t n,
593                 const T alpha,
594                 const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
595                 cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
596                 cl_command_queue* queue, cl_event* event = nullptr);
597 
598 // Solves a triangular system of equations: STRSM/DTRSM/CTRSM/ZTRSM
599 template <typename T>
600 StatusCode Trsm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal,
601                 const size_t m, const size_t n,
602                 const T alpha,
603                 const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
604                 cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
605                 cl_command_queue* queue, cl_event* event = nullptr);
606 
607 // =================================================================================================
608 // Extra non-BLAS routines (level-X)
609 // =================================================================================================
610 
611 // Scaling and out-place transpose/copy (non-BLAS function): SOMATCOPY/DOMATCOPY/COMATCOPY/ZOMATCOPY/HOMATCOPY
612 template <typename T>
613 StatusCode Omatcopy(const Layout layout, const Transpose a_transpose,
614                     const size_t m, const size_t n,
615                     const T alpha,
616                     const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
617                     cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
618                     cl_command_queue* queue, cl_event* event = nullptr);
619 
620 // Im2col function (non-BLAS function): SIM2COL/DIM2COL/CIM2COL/ZIM2COL/HIM2COL
621 template <typename T>
622 StatusCode Im2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
623                   const cl_mem im_buffer, const size_t im_offset,
624                   cl_mem col_buffer, const size_t col_offset,
625                   cl_command_queue* queue, cl_event* event = nullptr);
626 
627 // Batched version of AXPY: SAXPYBATCHED/DAXPYBATCHED/CAXPYBATCHED/ZAXPYBATCHED/HAXPYBATCHED
628 template <typename T>
629 StatusCode AxpyBatched(const size_t n,
630                        const T *alphas,
631                        const cl_mem x_buffer, const size_t *x_offsets, const size_t x_inc,
632                        cl_mem y_buffer, const size_t *y_offsets, const size_t y_inc,
633                        const size_t batch_count,
634                        cl_command_queue* queue, cl_event* event = nullptr);
635 
636 // Batched version of GEMM: SGEMMBATCHED/DGEMMBATCHED/CGEMMBATCHED/ZGEMMBATCHED/HGEMMBATCHED
637 template <typename T>
638 StatusCode GemmBatched(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
639                        const size_t m, const size_t n, const size_t k,
640                        const T *alphas,
641                        const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
642                        const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
643                        const T *betas,
644                        cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
645                        const size_t batch_count,
646                        cl_command_queue* queue, cl_event* event = nullptr);
647 
648 // =================================================================================================
649 
650 // CLBlast stores binaries of compiled kernels into a cache in case the same kernel is used later on
651 // for the same device. This cache can be cleared to free up system memory or in case of debugging.
652 StatusCode PUBLIC_API ClearCache();
653 
654 // The cache can also be pre-initialized for a specific device with all possible CLBLast kernels.
655 // Further CLBlast routine calls will then run at maximum speed.
656 StatusCode PUBLIC_API FillCache(const cl_device_id device);
657 
658 // =================================================================================================
659 
660 // Overrides tuning parameters for a specific device-precision-kernel combination. The next time
661 // the target routine is called it will re-compile and use the new parameters from then on.
662 StatusCode PUBLIC_API OverrideParameters(const cl_device_id device, const std::string &kernel_name,
663                                          const Precision precision,
664                                          const std::unordered_map<std::string,size_t> &parameters);
665 
666 // =================================================================================================
667 
668 } // namespace clblast
669 
670 // CLBLAST_CLBLAST_H_
671 #endif
672