1
2// =================================================================================================
3// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
4// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
5// width of 100 characters per line.
6//
7// Author(s):
8//   Cedric Nugteren <www.cedricnugteren.nl>
9//
10// This file contains the common defines and type-defs for the CLBlast OpenCL kernels.
11//
12// =================================================================================================
13
14// Enables loading of this file using the C++ pre-processor's #include (C++11 standard raw string
15// literal). Comment-out this line for syntax-highlighting when developing.
16R"(
17// =================================================================================================
18
// Routine marker for this kernel source.
// NOTE(review): nothing in this file reads ROUTINE_GEMMBATCHED; presumably kernel sources that
// are appended after this prelude test it with #ifdef. Confirm against those sources.
#define ROUTINE_GEMMBATCHED

// Half-precision mode selection. Nothing is defined unless USE_HALF is set:
//   USE_HALF and FP16_SUPPORT     -> FP16_COMPUTE
//   USE_HALF without FP16_SUPPORT -> FP16_STORAGE
#if defined(USE_HALF) && defined(FP16_SUPPORT)
    #define FP16_COMPUTE
#elif defined(USE_HALF)
    #define FP16_STORAGE
#endif
28
// Default for PRECISION when the host did not supply one: 16 (half) when FP16_COMPUTE is
// active, 32 (single) otherwise. These are the only two values this file provides types for.
#if !defined(PRECISION)
    #if defined(FP16_COMPUTE)
      #define PRECISION 16
    #else
      #define PRECISION 32
    #endif
#endif
36
// Element type of the network buffers: half whenever either FP16 mode is active, float otherwise.
#if defined(FP16_STORAGE) || defined(FP16_COMPUTE)
    typedef half net_t;
#else
    typedef float net_t;
#endif

// Element accessors for net_t buffers. In FP16_STORAGE mode the elements go through
// vload_half/vstore_half; in every other mode they are plain indexed reads and writes.
#if defined(FP16_STORAGE)
    #define vload_net_t(offset,p) vload_half(offset,p)
    #define vstore_net_t(data,offset,p) vstore_half(data,offset,p)
#else
    #define vload_net_t(offset,p) ((p)[(offset)])
    #define vstore_net_t(data,offset,p) (((p)[(offset)])=(data))
#endif
50
51// =================================================================================================
// Enable support for half-precision (the cl_khr_fp16 extension). Only done for non-CUDA builds
// and only when the kernels are compiled with PRECISION == 16.
#if !defined(CUDA) && PRECISION == 16
  #pragma OPENCL EXTENSION cl_khr_fp16: enable
#endif
58
// Half-precision: constants first, then the scalar and vector type aliases.
// NOTE(review): SQ2 and SMALLEST are unsuffixed literals here, and 1.0e14 is larger in magnitude
// than the largest finite half value; kept exactly as-is.
#if PRECISION == 16
  #define ZERO 0
  #define ONE 1
  #define SQ2 1.4142135623730951
  #define SMALLEST -1.0e14
  typedef half real;
  typedef half2 real2;
  typedef half4 real4;
  typedef half8 real8;
  typedef half16 real16;
#endif

// Single-precision: same set of names, backed by float.
#if PRECISION == 32
  #define ZERO 0.0f
  #define ONE 1.0f
  #define SQ2 1.4142135623730951f
  #define SMALLEST -1.0e37f
  typedef float real;
  typedef float2 real2;
  typedef float4 real4;
  typedef float8 real8;
  typedef float16 real16;
#endif
83
84// Single-element version of a complex number
85  typedef real singlereal;
86
87// Converts a 'real argument' value to a 'real' value as passed to the kernel. Normally there is no
88// conversion, but half-precision is not supported as kernel argument so it is converted from float.
89#if PRECISION == 16
90  typedef float real_arg;
91  #define GetRealArg(x) (half)x
92#else
93  typedef real real_arg;
94  #define GetRealArg(x) x
95#endif
96
// Address-space qualifier for pointers to local memory objects. It is a macro so that a build
// can pre-define it differently (CUDA doesn't need the qualifier); the default is __local.
#if !defined(LOCAL_PTR)
  #define LOCAL_PTR __local
#endif
101
102// =================================================================================================
103
// The OpenCL built-in mad() instruction is not IEEE754 compliant, so it is switched off unless
// the host explicitly enables it for specific devices (see src/routine.cpp).
#if !defined(USE_CL_MAD)
  #define USE_CL_MAD 0
#endif
109
// Note on all macros below: every argument is parenthesized in the expansion so that compound
// expressions passed as arguments keep their own grouping.

// Sets a variable to zero
#define SetToZero(a) (a) = ZERO

// Sets a variable to zero (only the imaginary part). Expands to nothing: the types in this file
// are real-valued.
#define ImagToZero(a)

// Sets a variable to one
#define SetToOne(a) (a) = ONE

// Determines whether a variable is zero. Without the inner parentheses an argument such as
// 'cond ? x : y' would be compared only in its last operand.
#define IsZero(a) ((a) == ZERO)

// The absolute value (component-wise), stored back into the argument
#define AbsoluteValue(value) (value) = fabs(value)

// Negation (component-wise), stored back into the argument
#define Negate(value) (value) = -(value)
127
// Adds two variables: c = a + b. Arguments are parenthesized so compound operands keep their
// own grouping.
#define Add(c,a,b) (c) = (a) + (b)

// Subtracts two variables: c = a - b. With parentheses, Subtract(c, x, y - z) subtracts the
// whole of (y - z).
#define Subtract(c,a,b) (c) = (a) - (b)

// The scalar multiply function: c = a * b. With parentheses, Multiply(c, x + y, z) multiplies
// the whole of (x + y).
#define Multiply(c,a,b) (c) = (a) * (b)
136
// The scalar multiply-add function: c = c + a * b. Uses the built-in mad() only when
// USE_CL_MAD is 1. Arguments are parenthesized so compound operands are multiplied as a whole.
#if USE_CL_MAD == 1
  #define MultiplyAdd(c,a,b) (c) = mad((a), (b), (c))
#else
  #define MultiplyAdd(c,a,b) (c) += (a) * (b)
#endif

// The scalar multiply-subtract function: c = c - a * b
#define MultiplySubtract(c,a,b) (c) -= (a) * (b)
146
// The scalar division function: full division, c = a / b. Arguments are parenthesized so a
// compound divisor such as 'x + y' is used as a whole.
#define DivideFull(c,a,b) (c) = (a) / (b)

// The scalar AXPBY function: e = a*b + c*d, with each operand parenthesized
#define AXPBY(e,a,b,c,d) (e) = (a)*(b) + (c)*(d)

// The complex conjugate operation for complex transforms. Expands to nothing: the types in this
// file are real-valued.
#define COMPLEX_CONJUGATE(value)
155
156// =================================================================================================
157
// Function qualifier used by the helper functions in the kernels. It expands to 'inline' only
// when USE_INLINE_KEYWORD is defined, because some compilers don't support the inline keyword.
#if !defined(USE_INLINE_KEYWORD)
  #define INLINE_FUNC
#else
  #define INLINE_FUNC inline
#endif
164
165// =================================================================================================
166
// Switch for shuffled workgroup indices (used to avoid partition camping). Off by default; the
// host enables it for specific devices (see src/routine.cc).
#if !defined(USE_STAGGERED_INDICES)
  #define USE_STAGGERED_INDICES 0
#endif
172
173// Staggered/shuffled group indices to avoid partition camping (AMD GPUs). Formula's are taken from:
174// http://docs.nvidia.com/cuda/samples/6_Advanced/transpose/doc/MatrixTranspose.pdf
175// More details: https://github.com/CNugteren/CLBlast/issues/53
176#if USE_STAGGERED_INDICES == 1
177  INLINE_FUNC int GetGroupIDFlat() {
178    return get_group_id(0) + get_num_groups(0) * get_group_id(1);
179  }
180  INLINE_FUNC int GetGroupID1() {
181    return (GetGroupIDFlat()) % get_num_groups(1);
182  }
183  INLINE_FUNC int GetGroupID0() {
184    return ((GetGroupIDFlat() / get_num_groups(1)) + GetGroupID1()) % get_num_groups(0);
185  }
186#else
187  INLINE_FUNC int GetGroupID1() { return get_group_id(1); }
188  INLINE_FUNC int GetGroupID0() { return get_group_id(0); }
189#endif
190
191// =================================================================================================
192
193// End of the C++11 raw string literal
194)"
195
196// =================================================================================================
197