1 2 // ================================================================================================= 3 // This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This 4 // project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max- 5 // width of 100 characters per line. 6 // 7 // Author(s): 8 // Cedric Nugteren <www.cedricnugteren.nl> 9 // 10 // This file contains the interface to the CLBlast BLAS routines. It also contains the definitions 11 // of the returned status codes and the layout and transpose types. This is the only header users 12 // of CLBlast should include and use. 13 // 14 // ================================================================================================= 15 16 #ifndef CLBLAST_CLBLAST_H_ 17 #define CLBLAST_CLBLAST_H_ 18 19 #include <cstdlib> // For size_t 20 #include <string> // For OverrideParameters function 21 #include <unordered_map> // For OverrideParameters function 22 23 // Includes the normal OpenCL C header 24 #if defined(__APPLE__) || defined(__MACOSX) 25 #include <OpenCL/opencl.h> 26 #else 27 #include <CL/opencl.h> 28 #endif 29 30 // Exports library functions under Windows when building a DLL. See also: 31 // https://msdn.microsoft.com/en-us/library/a90k134d.aspx 32 #if defined(_WIN32) && defined(CLBLAST_DLL) 33 #if defined(COMPILING_DLL) 34 #define PUBLIC_API __declspec(dllexport) 35 #else 36 #define PUBLIC_API __declspec(dllimport) 37 #endif 38 #else 39 #define PUBLIC_API 40 #endif 41 42 namespace clblast { 43 // ================================================================================================= 44 45 // Status codes. These codes can be returned by functions declared in this header file. The error 46 // codes match either the standard OpenCL error codes or the clBLAS error codes. 47 enum class StatusCode { 48 49 // Status codes in common with the OpenCL standard 50 kSuccess = 0, // CL_SUCCESS 51 kOpenCLCompilerNotAvailable= -3, // CL_COMPILER_NOT_AVAILABLE 52 kTempBufferAllocFailure = -4, // CL_MEM_OBJECT_ALLOCATION_FAILURE 53 kOpenCLOutOfResources = -5, // CL_OUT_OF_RESOURCES 54 kOpenCLOutOfHostMemory = -6, // CL_OUT_OF_HOST_MEMORY 55 kOpenCLBuildProgramFailure = -11, // CL_BUILD_PROGRAM_FAILURE: OpenCL compilation error 56 kInvalidValue = -30, // CL_INVALID_VALUE 57 kInvalidCommandQueue = -36, // CL_INVALID_COMMAND_QUEUE 58 kInvalidMemObject = -38, // CL_INVALID_MEM_OBJECT 59 kInvalidBinary = -42, // CL_INVALID_BINARY 60 kInvalidBuildOptions = -43, // CL_INVALID_BUILD_OPTIONS 61 kInvalidProgram = -44, // CL_INVALID_PROGRAM 62 kInvalidProgramExecutable = -45, // CL_INVALID_PROGRAM_EXECUTABLE 63 kInvalidKernelName = -46, // CL_INVALID_KERNEL_NAME 64 kInvalidKernelDefinition = -47, // CL_INVALID_KERNEL_DEFINITION 65 kInvalidKernel = -48, // CL_INVALID_KERNEL 66 kInvalidArgIndex = -49, // CL_INVALID_ARG_INDEX 67 kInvalidArgValue = -50, // CL_INVALID_ARG_VALUE 68 kInvalidArgSize = -51, // CL_INVALID_ARG_SIZE 69 kInvalidKernelArgs = -52, // CL_INVALID_KERNEL_ARGS 70 kInvalidLocalNumDimensions = -53, // CL_INVALID_WORK_DIMENSION: Too many thread dimensions 71 kInvalidLocalThreadsTotal = -54, // CL_INVALID_WORK_GROUP_SIZE: Too many threads in total 72 kInvalidLocalThreadsDim = -55, // CL_INVALID_WORK_ITEM_SIZE: ... or for a specific dimension 73 kInvalidGlobalOffset = -56, // CL_INVALID_GLOBAL_OFFSET 74 kInvalidEventWaitList = -57, // CL_INVALID_EVENT_WAIT_LIST 75 kInvalidEvent = -58, // CL_INVALID_EVENT 76 kInvalidOperation = -59, // CL_INVALID_OPERATION 77 kInvalidBufferSize = -61, // CL_INVALID_BUFFER_SIZE 78 kInvalidGlobalWorkSize = -63, // CL_INVALID_GLOBAL_WORK_SIZE 79 80 // Status codes in common with the clBLAS library 81 kNotImplemented = -1024, // Routine or functionality not implemented yet 82 kInvalidMatrixA = -1022, // Matrix A is not a valid OpenCL buffer 83 kInvalidMatrixB = -1021, // Matrix B is not a valid OpenCL buffer 84 kInvalidMatrixC = -1020, // Matrix C is not a valid OpenCL buffer 85 kInvalidVectorX = -1019, // Vector X is not a valid OpenCL buffer 86 kInvalidVectorY = -1018, // Vector Y is not a valid OpenCL buffer 87 kInvalidDimension = -1017, // Dimensions M, N, and K have to be larger than zero 88 kInvalidLeadDimA = -1016, // LD of A is smaller than the matrix's first dimension 89 kInvalidLeadDimB = -1015, // LD of B is smaller than the matrix's first dimension 90 kInvalidLeadDimC = -1014, // LD of C is smaller than the matrix's first dimension 91 kInvalidIncrementX = -1013, // Increment of vector X cannot be zero 92 kInvalidIncrementY = -1012, // Increment of vector Y cannot be zero 93 kInsufficientMemoryA = -1011, // Matrix A's OpenCL buffer is too small 94 kInsufficientMemoryB = -1010, // Matrix B's OpenCL buffer is too small 95 kInsufficientMemoryC = -1009, // Matrix C's OpenCL buffer is too small 96 kInsufficientMemoryX = -1008, // Vector X's OpenCL buffer is too small 97 kInsufficientMemoryY = -1007, // Vector Y's OpenCL buffer is too small 98 99 // Custom additional status codes for CLBlast 100 kInvalidBatchCount = -2049, // The batch count needs to be positive 101 kInvalidOverrideKernel = -2048, // Trying to override parameters for an invalid kernel 102 kMissingOverrideParameter = -2047, // Missing override parameter(s) for the target kernel 103 kInvalidLocalMemUsage = -2046, // Not enough local memory available on this device 104 kNoHalfPrecision = -2045, // Half precision (16-bits) not supported by the device 105 kNoDoublePrecision = -2044, // Double precision (64-bits) not supported by the device 106 kInvalidVectorScalar = -2043, // The unit-sized vector is not a valid OpenCL buffer 107 kInsufficientMemoryScalar = -2042, // The unit-sized vector's OpenCL buffer is too small 108 kDatabaseError = -2041, // Entry for the device was not found in the database 109 kUnknownError = -2040, // A catch-all error code representing an unspecified error 110 kUnexpectedError = -2039, // A catch-all error code representing an unexpected exception 111 }; 112 113 // Matrix layout and transpose types 114 enum class Layout { kRowMajor = 101, kColMajor = 102 }; 115 enum class Transpose { kNo = 111, kYes = 112, kConjugate = 113 }; 116 enum class Triangle { kUpper = 121, kLower = 122 }; 117 enum class Diagonal { kNonUnit = 131, kUnit = 132 }; 118 enum class Side { kLeft = 141, kRight = 142 }; 119 120 // Precision scoped enum (values in bits) 121 enum class Precision { kHalf = 16, kSingle = 32, kDouble = 64, 122 kComplexSingle = 3232, kComplexDouble = 6464, kAny = -1 }; 123 124 // ================================================================================================= 125 // BLAS level-1 (vector-vector) routines 126 // ================================================================================================= 127 128 // Generate givens plane rotation: SROTG/DROTG 129 template <typename T> 130 StatusCode Rotg(cl_mem sa_buffer, const size_t sa_offset, 131 cl_mem sb_buffer, const size_t sb_offset, 132 cl_mem sc_buffer, const size_t sc_offset, 133 cl_mem ss_buffer, const size_t ss_offset, 134 cl_command_queue* queue, cl_event* event = nullptr); 135 136 // Generate modified givens plane rotation: SROTMG/DROTMG 137 template <typename T> 138 StatusCode Rotmg(cl_mem sd1_buffer, const size_t sd1_offset, 139 cl_mem sd2_buffer, const size_t sd2_offset, 140 cl_mem sx1_buffer, const size_t sx1_offset, 141 const cl_mem sy1_buffer, const size_t sy1_offset, 142 cl_mem sparam_buffer, const size_t sparam_offset, 143 cl_command_queue* queue, cl_event* event = nullptr); 144 145 // Apply givens plane rotation: SROT/DROT 146 template <typename T> 147 StatusCode Rot(const size_t n, 148 cl_mem x_buffer, const size_t x_offset, const size_t x_inc, 149 cl_mem y_buffer, const size_t y_offset, const size_t y_inc, 150 const T cos, 151 const T sin, 152 cl_command_queue* queue, cl_event* event = nullptr); 153 154 // Apply modified givens plane rotation: SROTM/DROTM 155 template <typename T> 156 StatusCode Rotm(const size_t n, 157 cl_mem x_buffer, const size_t x_offset, const size_t x_inc, 158 cl_mem y_buffer, const size_t y_offset, const size_t y_inc, 159 cl_mem sparam_buffer, const size_t sparam_offset, 160 cl_command_queue* queue, cl_event* event = nullptr); 161 162 // Swap two vectors: SSWAP/DSWAP/CSWAP/ZSWAP/HSWAP 163 template <typename T> 164 StatusCode Swap(const size_t n, 165 cl_mem x_buffer, const size_t x_offset, const size_t x_inc, 166 cl_mem y_buffer, const size_t y_offset, const size_t y_inc, 167 cl_command_queue* queue, cl_event* event = nullptr); 168 169 // Vector scaling: SSCAL/DSCAL/CSCAL/ZSCAL/HSCAL 170 template <typename T> 171 StatusCode Scal(const size_t n, 172 const T alpha, 173 cl_mem x_buffer, const size_t x_offset, const size_t x_inc, 174 cl_command_queue* queue, cl_event* event = nullptr); 175 176 // Vector copy: SCOPY/DCOPY/CCOPY/ZCOPY/HCOPY 177 template <typename T> 178 StatusCode Copy(const size_t n, 179 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc, 180 cl_mem y_buffer, const size_t y_offset, const size_t y_inc, 181 cl_command_queue* queue, cl_event* event = nullptr); 182 183 // Vector-times-constant plus vector: SAXPY/DAXPY/CAXPY/ZAXPY/HAXPY 184 template <typename T> 185 StatusCode Axpy(const size_t n, 186 const T alpha, 187 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc, 188 cl_mem y_buffer, const size_t y_offset, const size_t y_inc, 189 cl_command_queue* queue, cl_event* event = nullptr); 190 191 // Dot product of two vectors: SDOT/DDOT/HDOT 192 template <typename T> 193 StatusCode Dot(const size_t n, 194 cl_mem dot_buffer, const size_t dot_offset, 195 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc, 196 const cl_mem y_buffer, const size_t y_offset, const size_t y_inc, 197 cl_command_queue* queue, cl_event* event = nullptr); 198 199 // Dot product of two complex vectors: CDOTU/ZDOTU 200 template <typename T> 201 StatusCode Dotu(const size_t n, 202 cl_mem dot_buffer, const size_t dot_offset, 203 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc, 204 const cl_mem y_buffer, const size_t y_offset, const size_t y_inc, 205 cl_command_queue* queue, cl_event* event = nullptr); 206 207 // Dot product of two complex vectors, one conjugated: CDOTC/ZDOTC 208 template <typename T> 209 StatusCode Dotc(const size_t n, 210 cl_mem dot_buffer, const size_t dot_offset, 211 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc, 212 const cl_mem y_buffer, const size_t y_offset, const size_t y_inc, 213 cl_command_queue* queue, cl_event* event = nullptr); 214 215 // Euclidian norm of a vector: SNRM2/DNRM2/ScNRM2/DzNRM2/HNRM2 216 template <typename T> 217 StatusCode Nrm2(const size_t n, 218 cl_mem nrm2_buffer, const size_t nrm2_offset, 219 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc, 220 cl_command_queue* queue, cl_event* event = nullptr); 221 222 // Absolute sum of values in a vector: SASUM/DASUM/ScASUM/DzASUM/HASUM 223 template <typename T> 224 StatusCode Asum(const size_t n, 225 cl_mem asum_buffer, const size_t asum_offset, 226 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc, 227 cl_command_queue* queue, cl_event* event = nullptr); 228 229 // Sum of values in a vector (non-BLAS function): SSUM/DSUM/ScSUM/DzSUM/HSUM 230 template <typename T> 231 StatusCode Sum(const size_t n, 232 cl_mem sum_buffer, const size_t sum_offset, 233 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc, 234 cl_command_queue* queue, cl_event* event = nullptr); 235 236 // Index of absolute maximum value in a vector: iSAMAX/iDAMAX/iCAMAX/iZAMAX/iHAMAX 237 template <typename T> 238 StatusCode Amax(const size_t n, 239 cl_mem imax_buffer, const size_t imax_offset, 240 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc, 241 cl_command_queue* queue, cl_event* event = nullptr); 242 243 // Index of absolute minimum value in a vector (non-BLAS function): iSAMIN/iDAMIN/iCAMIN/iZAMIN/iHAMIN 244 template <typename T> 245 StatusCode Amin(const size_t n, 246 cl_mem imin_buffer, const size_t imin_offset, 247 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc, 248 cl_command_queue* queue, cl_event* event = nullptr); 249 250 // Index of maximum value in a vector (non-BLAS function): iSMAX/iDMAX/iCMAX/iZMAX/iHMAX 251 template <typename T> 252 StatusCode Max(const size_t n, 253 cl_mem imax_buffer, const size_t imax_offset, 254 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc, 255 cl_command_queue* queue, cl_event* event = nullptr); 256 257 // Index of minimum value in a vector (non-BLAS function): iSMIN/iDMIN/iCMIN/iZMIN/iHMIN 258 template <typename T> 259 StatusCode Min(const size_t n, 260 cl_mem imin_buffer, const size_t imin_offset, 261 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc, 262 cl_command_queue* queue, cl_event* event = nullptr); 263 264 // ================================================================================================= 265 // BLAS level-2 (matrix-vector) routines 266 // ================================================================================================= 267 268 // General matrix-vector multiplication: SGEMV/DGEMV/CGEMV/ZGEMV/HGEMV 269 template <typename T> 270 StatusCode Gemv(const Layout layout, const Transpose a_transpose, 271 const size_t m, const size_t n, 272 const T alpha, 273 const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, 274 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc, 275 const T beta, 276 cl_mem y_buffer, const size_t y_offset, const size_t y_inc, 277 cl_command_queue* queue, cl_event* event = nullptr); 278 279 // General banded matrix-vector multiplication: SGBMV/DGBMV/CGBMV/ZGBMV/HGBMV 280 template <typename T> 281 StatusCode Gbmv(const Layout layout, const Transpose a_transpose, 282 const size_t m, const size_t n, const size_t kl, const size_t ku, 283 const T alpha, 284 const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, 285 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc, 286 const T beta, 287 cl_mem y_buffer, const size_t y_offset, const size_t y_inc, 288 cl_command_queue* queue, cl_event* event = nullptr); 289 290 // Hermitian matrix-vector multiplication: CHEMV/ZHEMV 291 template <typename T> 292 StatusCode Hemv(const Layout layout, const Triangle triangle, 293 const size_t n, 294 const T alpha, 295 const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, 296 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc, 297 const T beta, 298 cl_mem y_buffer, const size_t y_offset, const size_t y_inc, 299 cl_command_queue* queue, cl_event* event = nullptr); 300 301 // Hermitian banded matrix-vector multiplication: CHBMV/ZHBMV 302 template <typename T> 303 StatusCode Hbmv(const Layout layout, const Triangle triangle, 304 const size_t n, const size_t k, 305 const T alpha, 306 const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, 307 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc, 308 const T beta, 309 cl_mem y_buffer, const size_t y_offset, const size_t y_inc, 310 cl_command_queue* queue, cl_event* event = nullptr); 311 312 // Hermitian packed matrix-vector multiplication: CHPMV/ZHPMV 313 template <typename T> 314 StatusCode Hpmv(const Layout layout, const Triangle triangle, 315 const size_t n, 316 const T alpha, 317 const cl_mem ap_buffer, const size_t ap_offset, 318 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc, 319 const T beta, 320 cl_mem y_buffer, const size_t y_offset, const size_t y_inc, 321 cl_command_queue* queue, cl_event* event = nullptr); 322 323 // Symmetric matrix-vector multiplication: SSYMV/DSYMV/HSYMV 324 template <typename T> 325 StatusCode Symv(const Layout layout, const Triangle triangle, 326 const size_t n, 327 const T alpha, 328 const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, 329 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc, 330 const T beta, 331 cl_mem y_buffer, const size_t y_offset, const size_t y_inc, 332 cl_command_queue* queue, cl_event* event = nullptr); 333 334 // Symmetric banded matrix-vector multiplication: SSBMV/DSBMV/HSBMV 335 template <typename T> 336 StatusCode Sbmv(const Layout layout, const Triangle triangle, 337 const size_t n, const size_t k, 338 const T alpha, 339 const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, 340 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc, 341 const T beta, 342 cl_mem y_buffer, const size_t y_offset, const size_t y_inc, 343 cl_command_queue* queue, cl_event* event = nullptr); 344 345 // Symmetric packed matrix-vector multiplication: SSPMV/DSPMV/HSPMV 346 template <typename T> 347 StatusCode Spmv(const Layout layout, const Triangle triangle, 348 const size_t n, 349 const T alpha, 350 const cl_mem ap_buffer, const size_t ap_offset, 351 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc, 352 const T beta, 353 cl_mem y_buffer, const size_t y_offset, const size_t y_inc, 354 cl_command_queue* queue, cl_event* event = nullptr); 355 356 // Triangular matrix-vector multiplication: STRMV/DTRMV/CTRMV/ZTRMV/HTRMV 357 template <typename T> 358 StatusCode Trmv(const Layout layout, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal, 359 const size_t n, 360 const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, 361 cl_mem x_buffer, const size_t x_offset, const size_t x_inc, 362 cl_command_queue* queue, cl_event* event = nullptr); 363 364 // Triangular banded matrix-vector multiplication: STBMV/DTBMV/CTBMV/ZTBMV/HTBMV 365 template <typename T> 366 StatusCode Tbmv(const Layout layout, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal, 367 const size_t n, const size_t k, 368 const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, 369 cl_mem x_buffer, const size_t x_offset, const size_t x_inc, 370 cl_command_queue* queue, cl_event* event = nullptr); 371 372 // Triangular packed matrix-vector multiplication: STPMV/DTPMV/CTPMV/ZTPMV/HTPMV 373 template <typename T> 374 StatusCode Tpmv(const Layout layout, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal, 375 const size_t n, 376 const cl_mem ap_buffer, const size_t ap_offset, 377 cl_mem x_buffer, const size_t x_offset, const size_t x_inc, 378 cl_command_queue* queue, cl_event* event = nullptr); 379 380 // Solves a triangular system of equations: STRSV/DTRSV/CTRSV/ZTRSV 381 template <typename T> 382 StatusCode Trsv(const Layout layout, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal, 383 const size_t n, 384 const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, 385 cl_mem x_buffer, const size_t x_offset, const size_t x_inc, 386 cl_command_queue* queue, cl_event* event = nullptr); 387 388 // Solves a banded triangular system of equations: STBSV/DTBSV/CTBSV/ZTBSV 389 template <typename T> 390 StatusCode Tbsv(const Layout layout, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal, 391 const size_t n, const size_t k, 392 const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, 393 cl_mem x_buffer, const size_t x_offset, const size_t x_inc, 394 cl_command_queue* queue, cl_event* event = nullptr); 395 396 // Solves a packed triangular system of equations: STPSV/DTPSV/CTPSV/ZTPSV 397 template <typename T> 398 StatusCode Tpsv(const Layout layout, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal, 399 const size_t n, 400 const cl_mem ap_buffer, const size_t ap_offset, 401 cl_mem x_buffer, const size_t x_offset, const size_t x_inc, 402 cl_command_queue* queue, cl_event* event = nullptr); 403 404 // General rank-1 matrix update: SGER/DGER/HGER 405 template <typename T> 406 StatusCode Ger(const Layout layout, 407 const size_t m, const size_t n, 408 const T alpha, 409 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc, 410 const cl_mem y_buffer, const size_t y_offset, const size_t y_inc, 411 cl_mem a_buffer, const size_t a_offset, const size_t a_ld, 412 cl_command_queue* queue, cl_event* event = nullptr); 413 414 // General rank-1 complex matrix update: CGERU/ZGERU 415 template <typename T> 416 StatusCode Geru(const Layout layout, 417 const size_t m, const size_t n, 418 const T alpha, 419 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc, 420 const cl_mem y_buffer, const size_t y_offset, const size_t y_inc, 421 cl_mem a_buffer, const size_t a_offset, const size_t a_ld, 422 cl_command_queue* queue, cl_event* event = nullptr); 423 424 // General rank-1 complex conjugated matrix update: CGERC/ZGERC 425 template <typename T> 426 StatusCode Gerc(const Layout layout, 427 const size_t m, const size_t n, 428 const T alpha, 429 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc, 430 const cl_mem y_buffer, const size_t y_offset, const size_t y_inc, 431 cl_mem a_buffer, const size_t a_offset, const size_t a_ld, 432 cl_command_queue* queue, cl_event* event = nullptr); 433 434 // Hermitian rank-1 matrix update: CHER/ZHER 435 template <typename T> 436 StatusCode Her(const Layout layout, const Triangle triangle, 437 const size_t n, 438 const T alpha, 439 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc, 440 cl_mem a_buffer, const size_t a_offset, const size_t a_ld, 441 cl_command_queue* queue, cl_event* event = nullptr); 442 443 // Hermitian packed rank-1 matrix update: CHPR/ZHPR 444 template <typename T> 445 StatusCode Hpr(const Layout layout, const Triangle triangle, 446 const size_t n, 447 const T alpha, 448 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc, 449 cl_mem ap_buffer, const size_t ap_offset, 450 cl_command_queue* queue, cl_event* event = nullptr); 451 452 // Hermitian rank-2 matrix update: CHER2/ZHER2 453 template <typename T> 454 StatusCode Her2(const Layout layout, const Triangle triangle, 455 const size_t n, 456 const T alpha, 457 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc, 458 const cl_mem y_buffer, const size_t y_offset, const size_t y_inc, 459 cl_mem a_buffer, const size_t a_offset, const size_t a_ld, 460 cl_command_queue* queue, cl_event* event = nullptr); 461 462 // Hermitian packed rank-2 matrix update: CHPR2/ZHPR2 463 template <typename T> 464 StatusCode Hpr2(const Layout layout, const Triangle triangle, 465 const size_t n, 466 const T alpha, 467 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc, 468 const cl_mem y_buffer, const size_t y_offset, const size_t y_inc, 469 cl_mem ap_buffer, const size_t ap_offset, 470 cl_command_queue* queue, cl_event* event = nullptr); 471 472 // Symmetric rank-1 matrix update: SSYR/DSYR/HSYR 473 template <typename T> 474 StatusCode Syr(const Layout layout, const Triangle triangle, 475 const size_t n, 476 const T alpha, 477 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc, 478 cl_mem a_buffer, const size_t a_offset, const size_t a_ld, 479 cl_command_queue* queue, cl_event* event = nullptr); 480 481 // Symmetric packed rank-1 matrix update: SSPR/DSPR/HSPR 482 template <typename T> 483 StatusCode Spr(const Layout layout, const Triangle triangle, 484 const size_t n, 485 const T alpha, 486 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc, 487 cl_mem ap_buffer, const size_t ap_offset, 488 cl_command_queue* queue, cl_event* event = nullptr); 489 490 // Symmetric rank-2 matrix update: SSYR2/DSYR2/HSYR2 491 template <typename T> 492 StatusCode Syr2(const Layout layout, const Triangle triangle, 493 const size_t n, 494 const T alpha, 495 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc, 496 const cl_mem y_buffer, const size_t y_offset, const size_t y_inc, 497 cl_mem a_buffer, const size_t a_offset, const size_t a_ld, 498 cl_command_queue* queue, cl_event* event = nullptr); 499 500 // Symmetric packed rank-2 matrix update: SSPR2/DSPR2/HSPR2 501 template <typename T> 502 StatusCode Spr2(const Layout layout, const Triangle triangle, 503 const size_t n, 504 const T alpha, 505 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc, 506 const cl_mem y_buffer, const size_t y_offset, const size_t y_inc, 507 cl_mem ap_buffer, const size_t ap_offset, 508 cl_command_queue* queue, cl_event* event = nullptr); 509 510 // ================================================================================================= 511 // BLAS level-3 (matrix-matrix) routines 512 // ================================================================================================= 513 514 // General matrix-matrix multiplication: SGEMM/DGEMM/CGEMM/ZGEMM/HGEMM 515 template <typename T> 516 StatusCode Gemm(const Layout layout, const Transpose a_transpose, const Transpose b_transpose, 517 const size_t m, const size_t n, const size_t k, 518 const T alpha, 519 const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, 520 const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, 521 const T beta, 522 cl_mem c_buffer, const size_t c_offset, const size_t c_ld, 523 cl_command_queue* queue, cl_event* event = nullptr); 524 525 // Symmetric matrix-matrix multiplication: SSYMM/DSYMM/CSYMM/ZSYMM/HSYMM 526 template <typename T> 527 StatusCode Symm(const Layout layout, const Side side, const Triangle triangle, 528 const size_t m, const size_t n, 529 const T alpha, 530 const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, 531 const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, 532 const T beta, 533 cl_mem c_buffer, const size_t c_offset, const size_t c_ld, 534 cl_command_queue* queue, cl_event* event = nullptr); 535 536 // Hermitian matrix-matrix multiplication: CHEMM/ZHEMM 537 template <typename T> 538 StatusCode Hemm(const Layout layout, const Side side, const Triangle triangle, 539 const size_t m, const size_t n, 540 const T alpha, 541 const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, 542 const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, 543 const T beta, 544 cl_mem c_buffer, const size_t c_offset, const size_t c_ld, 545 cl_command_queue* queue, cl_event* event = nullptr); 546 547 // Rank-K update of a symmetric matrix: SSYRK/DSYRK/CSYRK/ZSYRK/HSYRK 548 template <typename T> 549 StatusCode Syrk(const Layout layout, const Triangle triangle, const Transpose a_transpose, 550 const size_t n, const size_t k, 551 const T alpha, 552 const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, 553 const T beta, 554 cl_mem c_buffer, const size_t c_offset, const size_t c_ld, 555 cl_command_queue* queue, cl_event* event = nullptr); 556 557 // Rank-K update of a hermitian matrix: CHERK/ZHERK 558 template <typename T> 559 StatusCode Herk(const Layout layout, const Triangle triangle, const Transpose a_transpose, 560 const size_t n, const size_t k, 561 const T alpha, 562 const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, 563 const T beta, 564 cl_mem c_buffer, const size_t c_offset, const size_t c_ld, 565 cl_command_queue* queue, cl_event* event = nullptr); 566 567 // Rank-2K update of a symmetric matrix: SSYR2K/DSYR2K/CSYR2K/ZSYR2K/HSYR2K 568 template <typename T> 569 StatusCode Syr2k(const Layout layout, const Triangle triangle, const Transpose ab_transpose, 570 const size_t n, const size_t k, 571 const T alpha, 572 const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, 573 const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, 574 const T beta, 575 cl_mem c_buffer, const size_t c_offset, const size_t c_ld, 576 cl_command_queue* queue, cl_event* event = nullptr); 577 578 // Rank-2K update of a hermitian matrix: CHER2K/ZHER2K 579 template <typename T, typename U> 580 StatusCode Her2k(const Layout layout, const Triangle triangle, const Transpose ab_transpose, 581 const size_t n, const size_t k, 582 const T alpha, 583 const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, 584 const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, 585 const U beta, 586 cl_mem c_buffer, const size_t c_offset, const size_t c_ld, 587 cl_command_queue* queue, cl_event* event = nullptr); 588 589 // Triangular matrix-matrix multiplication: STRMM/DTRMM/CTRMM/ZTRMM/HTRMM 590 template <typename T> 591 StatusCode Trmm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal, 592 const size_t m, const size_t n, 593 const T alpha, 594 const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, 595 cl_mem b_buffer, const size_t b_offset, const size_t b_ld, 596 cl_command_queue* queue, cl_event* event = nullptr); 597 598 // Solves a triangular system of equations: STRSM/DTRSM/CTRSM/ZTRSM 599 template <typename T> 600 StatusCode Trsm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal, 601 const size_t m, const size_t n, 602 const T alpha, 603 const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, 604 cl_mem b_buffer, const size_t b_offset, const size_t b_ld, 605 cl_command_queue* queue, cl_event* event = nullptr); 606 607 // ================================================================================================= 608 // Extra non-BLAS routines (level-X) 609 // ================================================================================================= 610 611 // Scaling and out-place transpose/copy (non-BLAS function): SOMATCOPY/DOMATCOPY/COMATCOPY/ZOMATCOPY/HOMATCOPY 612 template <typename T> 613 StatusCode Omatcopy(const Layout layout, const Transpose a_transpose, 614 const size_t m, const size_t n, 615 const T alpha, 616 const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, 617 cl_mem b_buffer, const size_t b_offset, const size_t b_ld, 618 cl_command_queue* queue, cl_event* event = nullptr); 619 620 // Im2col function (non-BLAS function): SIM2COL/DIM2COL/CIM2COL/ZIM2COL/HIM2COL 621 template <typename T> 622 StatusCode Im2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, 623 const cl_mem im_buffer, const size_t im_offset, 624 cl_mem col_buffer, const size_t col_offset, 625 cl_command_queue* queue, cl_event* event = nullptr); 626 627 // Batched version of AXPY: SAXPYBATCHED/DAXPYBATCHED/CAXPYBATCHED/ZAXPYBATCHED/HAXPYBATCHED 628 template <typename T> 629 StatusCode AxpyBatched(const size_t n, 630 const T *alphas, 631 const cl_mem x_buffer, const size_t *x_offsets, const size_t x_inc, 632 cl_mem y_buffer, const size_t *y_offsets, const size_t y_inc, 633 const size_t batch_count, 634 cl_command_queue* queue, cl_event* event = nullptr); 635 636 // Batched version of GEMM: SGEMMBATCHED/DGEMMBATCHED/CGEMMBATCHED/ZGEMMBATCHED/HGEMMBATCHED 637 template <typename T> 638 StatusCode GemmBatched(const Layout layout, const Transpose a_transpose, const Transpose b_transpose, 639 const size_t m, const size_t n, const size_t k, 640 const T *alphas, 641 const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld, 642 const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld, 643 const T *betas, 644 cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld, 645 const size_t batch_count, 646 cl_command_queue* queue, cl_event* event = nullptr); 647 648 // ================================================================================================= 649 650 // CLBlast stores binaries of compiled kernels into a cache in case the same kernel is used later on 651 // for the same device. This cache can be cleared to free up system memory or in case of debugging. 652 StatusCode PUBLIC_API ClearCache(); 653 654 // The cache can also be pre-initialized for a specific device with all possible CLBLast kernels. 655 // Further CLBlast routine calls will then run at maximum speed. 656 StatusCode PUBLIC_API FillCache(const cl_device_id device); 657 658 // ================================================================================================= 659 660 // Overrides tuning parameters for a specific device-precision-kernel combination. The next time 661 // the target routine is called it will re-compile and use the new parameters from then on. 662 StatusCode PUBLIC_API OverrideParameters(const cl_device_id device, const std::string &kernel_name, 663 const Precision precision, 664 const std::unordered_map<std::string,size_t> ¶meters); 665 666 // ================================================================================================= 667 668 } // namespace clblast 669 670 // CLBLAST_CLBLAST_H_ 671 #endif 672