1 2// ================================================================================================= 3// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This 4// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max- 5// width of 100 characters per line. 6// 7// Author(s): 8// Cedric Nugteren <www.cedricnugteren.nl> 9// 10// This file contains the common defines and type-defs for the CLBlast OpenCL kernels. 11// 12// ================================================================================================= 13 14// Enables loading of this file using the C++ pre-processor's #include (C++11 standard raw string 15// literal). Comment-out this line for syntax-highlighting when developing. 16R"( 17// ================================================================================================= 18 19#define ROUTINE_GEMMBATCHED 20 21#ifdef USE_HALF 22 #ifdef FP16_SUPPORT 23 #define FP16_COMPUTE 24 #else 25 #define FP16_STORAGE 26 #endif 27#endif 28 29#ifndef PRECISION 30 #ifdef FP16_COMPUTE 31 #define PRECISION 16 32 #else 33 #define PRECISION 32 // Data-types: half, single or double precision, complex or regular 34 #endif 35#endif 36 37#ifdef FP16_STORAGE 38 typedef half net_t; 39 #define vload_net_t(offset,p) vload_half(offset,p) 40 #define vstore_net_t(data,offset,p) vstore_half(data,offset,p) 41#else 42 #ifdef FP16_COMPUTE 43 typedef half net_t; 44 #else 45 typedef float net_t; 46 #endif 47 #define vload_net_t(offset,p) ((p)[(offset)]) 48 #define vstore_net_t(data,offset,p) (((p)[(offset)])=(data)) 49#endif 50 51// ================================================================================================= 52#ifndef CUDA 53 // Enable support for double-precision 54 #if PRECISION == 16 55 #pragma OPENCL EXTENSION cl_khr_fp16: enable 56 #endif 57#endif 58 59// Half-precision 60#if PRECISION == 16 61 typedef half real; 62 typedef half2 real2; 63 typedef half4 real4; 64 
typedef half8 real8; 65 typedef half16 real16; 66 #define SQ2 1.4142135623730951 67 #define ZERO 0 68 #define ONE 1 69 #define SMALLEST -1.0e14 70 71// Single-precision 72#elif PRECISION == 32 73 typedef float real; 74 typedef float2 real2; 75 typedef float4 real4; 76 typedef float8 real8; 77 typedef float16 real16; 78 #define SQ2 1.4142135623730951f 79 #define ZERO 0.0f 80 #define ONE 1.0f 81 #define SMALLEST -1.0e37f 82#endif 83 84// Single-element version of a complex number 85 typedef real singlereal; 86 87// Converts a 'real argument' value to a 'real' value as passed to the kernel. Normally there is no 88// conversion, but half-precision is not supported as kernel argument so it is converted from float. 89#if PRECISION == 16 90 typedef float real_arg; 91 #define GetRealArg(x) (half)x 92#else 93 typedef real real_arg; 94 #define GetRealArg(x) x 95#endif 96 97// Pointers to local memory objects (using a define because CUDA doesn't need them) 98#ifndef LOCAL_PTR 99 #define LOCAL_PTR __local 100#endif 101 102// ================================================================================================= 103 104// Don't use the non-IEEE754 compliant OpenCL built-in mad() instruction per default. For specific 105// devices, this is enabled (see src/routine.cpp). 
#ifndef USE_CL_MAD
  #define USE_CL_MAD 0
#endif

// Note on the macros below: every source operand is wrapped in parentheses so that compound
// expressions (e.g. 'x + y' or 'x * y') passed as arguments are evaluated as a unit. The
// destination operand is left as-is; it is expected to be a plain assignable expression.

// Sets a variable to zero
#define SetToZero(a) a = ZERO

// Sets a variable to zero (only the imaginary part): expands to nothing here
#define ImagToZero(a)

// Sets a variable to one
#define SetToOne(a) a = ONE

// Determines whether a variable is zero
#define IsZero(a) ((a) == ZERO)

// The absolute value (component-wise)
#define AbsoluteValue(value) value = fabs(value)

// Negation (component-wise)
#define Negate(value) value = -(value)

// Adds two variables: c = a + b
#define Add(c,a,b) c = (a) + (b)

// Subtracts two variables: c = a - b
#define Subtract(c,a,b) c = (a) - (b)

// The scalar multiply function: c = a * b
#define Multiply(c,a,b) c = (a) * (b)

// The scalar multiply-add function: c += a * b, optionally through the built-in mad()
#if USE_CL_MAD == 1
  #define MultiplyAdd(c,a,b) c = mad(a, b, c)
#else
  #define MultiplyAdd(c,a,b) c += (a) * (b)
#endif

// The scalar multiply-subtract function: c -= a * b
#define MultiplySubtract(c,a,b) c -= (a) * (b)

// The scalar division function: full division
#define DivideFull(c,a,b) c = (a) / (b)

// The scalar AXPBY function: e = a*b + c*d
#define AXPBY(e,a,b,c,d) e = (a)*(b) + (c)*(d)

// The complex conjugate operation for complex transforms: expands to nothing here
#define COMPLEX_CONJUGATE(value)

// =================================================================================================

// Force inlining functions or not: some compilers don't support the inline keyword
#ifdef USE_INLINE_KEYWORD
  #define INLINE_FUNC inline
#else
  #define INLINE_FUNC
#endif

// =================================================================================================

// Shuffled workgroup indices to avoid partition camping, see below. For specific devices, this is
// enabled (see src/routine.cc).
169#ifndef USE_STAGGERED_INDICES 170 #define USE_STAGGERED_INDICES 0 171#endif 172 173// Staggered/shuffled group indices to avoid partition camping (AMD GPUs). Formula's are taken from: 174// http://docs.nvidia.com/cuda/samples/6_Advanced/transpose/doc/MatrixTranspose.pdf 175// More details: https://github.com/CNugteren/CLBlast/issues/53 176#if USE_STAGGERED_INDICES == 1 177 INLINE_FUNC int GetGroupIDFlat() { 178 return get_group_id(0) + get_num_groups(0) * get_group_id(1); 179 } 180 INLINE_FUNC int GetGroupID1() { 181 return (GetGroupIDFlat()) % get_num_groups(1); 182 } 183 INLINE_FUNC int GetGroupID0() { 184 return ((GetGroupIDFlat() / get_num_groups(1)) + GetGroupID1()) % get_num_groups(0); 185 } 186#else 187 INLINE_FUNC int GetGroupID1() { return get_group_id(1); } 188 INLINE_FUNC int GetGroupID0() { return get_group_id(0); } 189#endif 190 191// ================================================================================================= 192 193// End of the C++11 raw string literal 194)" 195 196// ================================================================================================= 197