1 2// ================================================================================================= 3// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This 4// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max- 5// width of 100 characters per line. 6// 7// Author(s): 8// Cedric Nugteren <www.cedricnugteren.nl> 9// 10// This file contains the Xgemv kernel (fast versions) for matrix-vector multiplication. 11// 12// ================================================================================================= 13 14// Enables loading of this file using the C++ pre-processor's #include (C++11 standard raw string 15// literal). Comment-out this line for syntax-highlighting when developing. 16R"( 17 18// ================================================================================================= 19 20// Parameters set by the tuner or by the database. Here they are given a basic default value in case 21// this kernel file is used outside of the CLBlast library. 
// 1: For the full version, see 'xgemv.opencl'

// 2: For the fast version
#ifndef WGS2
  #define WGS2 64     // The local work-group size
#endif
#ifndef WPT2
  #define WPT2 1      // The amount of work-per-thread
#endif
#ifndef VW2
  #define VW2 1       // Vector width of matrix A loads
#endif

// 3: For the fast rotated version
#ifndef WGS3
  #define WGS3 64     // The local work-group size
#endif
#ifndef WPT3
  #define WPT3 1      // The tile-size
#endif
#ifndef VW3
  #define VW3 1       // Vector width of matrix A loads
#endif

// =================================================================================================

// Data-widths for the 'fast' kernel
#if VW2 == 1
  typedef real realVF;
#elif VW2 == 2
  typedef real2 realVF;
#elif VW2 == 4
  typedef real4 realVF;
#elif VW2 == 8
  typedef real8 realVF;
#elif VW2 == 16
  typedef real16 realVF;
#else
  // Fail early with a clear message instead of an unknown-type error on 'realVF' further down
  #error "VW2 must be 1, 2, 4, 8 or 16"
#endif

// Data-widths for the 'fast' kernel with rotated matrix
#if VW3 == 1
  typedef real realVFR;
#elif VW3 == 2
  typedef real2 realVFR;
#elif VW3 == 4
  typedef real4 realVFR;
#elif VW3 == 8
  typedef real8 realVFR;
#elif VW3 == 16
  typedef real16 realVFR;
#else
  // Fail early with a clear message instead of an unknown-type error on 'realVFR' further down
  #error "VW3 must be 1, 2, 4, 8 or 16"
#endif

// =================================================================================================

// Loads a vector input value: returns agm[a_ld*y + x]. Note that 'a_ld' is counted in realVF
// elements (i.e. already divided by VW2), not in scalars. The kernels in this file index 'agm'
// directly with the same formula rather than calling this helper.
INLINE_FUNC realVF LoadMatrixAVF(const __global realVF* restrict agm, const int x, const int y,
                                 const int a_ld) {
  return agm[a_ld*y + x];
}

// =================================================================================================

// Faster version of the kernel, assuming that:
// --> 'm' and 'n' are multiples of WGS2
// --> 'a_offset' is 0
// --> 'a_ld' is a multiple of VW2
// --> 'a_rotated' is 0
// --> 'do_conjugate' is 0
// Computes y = alpha * A * x + beta * y, with column 'k' of A starting at vector index
// (a_ld/VW2)*k. There are no bounds checks: each work-item produces the WPT2 consecutive entries of
// y starting at WPT2*get_global_id(0), so the launched global size times WPT2 must not exceed 'm',
// and 'n' must be a multiple of WGS2 because x is read in unguarded chunks of WGS2.
// The arguments 'm', 'a_rotated', 'a_offset', 'do_conjugate', 'parameter', 'kl_unused' and
// 'ku_unused' are not read here; presumably they keep the argument list identical to the other
// GEMV kernels (NOTE(review): confirm against the host-side launcher).
__kernel __attribute__((reqd_work_group_size(WGS2, 1, 1)))
void XgemvFast(const int m, const int n,
               const real_arg arg_alpha,
               const real_arg arg_beta,
               const int a_rotated,
               const __global realVF* restrict agm, const int a_offset, const int a_ld,
               const __global real* restrict xgm, const int x_offset, const int x_inc,
               __global real* ygm, const int y_offset, const int y_inc,
               const int do_conjugate, const int parameter,
               const int kl_unused, const int ku_unused) {
  const real alpha = GetRealArg(arg_alpha);
  const real beta = GetRealArg(arg_beta);

  // Local memory for the vector X
  __local real xlm[WGS2];

  // The local ID is loop-invariant, so it is read once here instead of on every iteration
  const int lid = get_local_id(0);

  // Initializes the accumulation registers
  real acc[WPT2];
  #pragma unroll
  for (int w=0; w<WPT2; ++w) {
    SetToZero(acc[w]);
  }

  // Loops over work-group sized portions of the work
  for (int kwg=0; kwg<n; kwg+=WGS2) {

    // Loads the vector X into local memory: one element per work-item
    xlm[lid] = xgm[(kwg + lid)*x_inc + x_offset];

    // Synchronizes all threads in a workgroup (xlm must be complete before it is read)
    barrier(CLK_LOCAL_MEM_FENCE);

    // The multiply-add function (not rotated). Vector 'gid' of column 'k' holds the VW2
    // consecutive rows accumulated into acc[VW2*w + 0 .. VW2*w + VW2-1].
    #pragma unroll
    for (int kl=0; kl<WGS2; ++kl) {
      const int k = kwg + kl;
      #pragma unroll
      for (int w=0; w<WPT2/VW2; ++w) {
        const int gid = (WPT2/VW2)*get_global_id(0) + w;
        realVF avec = agm[(a_ld/VW2)*k + gid];
        #if VW2 == 1
          MultiplyAdd(acc[VW2*w+0], xlm[kl], avec);
        #elif VW2 == 2
          MultiplyAdd(acc[VW2*w+0], xlm[kl], avec.x);
          MultiplyAdd(acc[VW2*w+1], xlm[kl], avec.y);
        #elif VW2 == 4
          MultiplyAdd(acc[VW2*w+0], xlm[kl], avec.x);
          MultiplyAdd(acc[VW2*w+1], xlm[kl], avec.y);
          MultiplyAdd(acc[VW2*w+2], xlm[kl], avec.z);
          MultiplyAdd(acc[VW2*w+3], xlm[kl], avec.w);
        #elif VW2 == 8
          MultiplyAdd(acc[VW2*w+0], xlm[kl], avec.s0);
          MultiplyAdd(acc[VW2*w+1], xlm[kl], avec.s1);
          MultiplyAdd(acc[VW2*w+2], xlm[kl], avec.s2);
          MultiplyAdd(acc[VW2*w+3], xlm[kl], avec.s3);
          MultiplyAdd(acc[VW2*w+4], xlm[kl], avec.s4);
          MultiplyAdd(acc[VW2*w+5], xlm[kl], avec.s5);
          MultiplyAdd(acc[VW2*w+6], xlm[kl], avec.s6);
          MultiplyAdd(acc[VW2*w+7], xlm[kl], avec.s7);
        #elif VW2 == 16
          MultiplyAdd(acc[VW2*w+0], xlm[kl], avec.s0);
          MultiplyAdd(acc[VW2*w+1], xlm[kl], avec.s1);
          MultiplyAdd(acc[VW2*w+2], xlm[kl], avec.s2);
          MultiplyAdd(acc[VW2*w+3], xlm[kl], avec.s3);
          MultiplyAdd(acc[VW2*w+4], xlm[kl], avec.s4);
          MultiplyAdd(acc[VW2*w+5], xlm[kl], avec.s5);
          MultiplyAdd(acc[VW2*w+6], xlm[kl], avec.s6);
          MultiplyAdd(acc[VW2*w+7], xlm[kl], avec.s7);
          MultiplyAdd(acc[VW2*w+8], xlm[kl], avec.s8);
          MultiplyAdd(acc[VW2*w+9], xlm[kl], avec.s9);
          MultiplyAdd(acc[VW2*w+10], xlm[kl], avec.sA);
          MultiplyAdd(acc[VW2*w+11], xlm[kl], avec.sB);
          MultiplyAdd(acc[VW2*w+12], xlm[kl], avec.sC);
          MultiplyAdd(acc[VW2*w+13], xlm[kl], avec.sD);
          MultiplyAdd(acc[VW2*w+14], xlm[kl], avec.sE);
          MultiplyAdd(acc[VW2*w+15], xlm[kl], avec.sF);
        #endif
      }
    }

    // Synchronizes all threads in a workgroup (before xlm is overwritten in the next iteration)
    barrier(CLK_LOCAL_MEM_FENCE);
  }

  // Stores the final result: acc[w] belongs to entry WPT2*get_global_id(0) + w of y
  #pragma unroll
  for (int w=0; w<WPT2; ++w) {
    const int gid = WPT2*get_global_id(0) + w;
    real yval = ygm[gid*y_inc + y_offset];
    AXPBY(ygm[gid*y_inc + y_offset], alpha, acc[w], beta, yval);
  }
}

// =================================================================================================

// Faster version of the kernel, assuming that:
// --> 'm' is a multiple of WGS3 (each work-group produces WGS3 consecutive entries of y)
// --> 'n' is a multiple of WPT3 (the main loop advances by WPT3 without bounds checks)
// --> 'a_offset' is 0
// --> 'a_ld' is a multiple of VW3
// --> 'a_rotated' is 1
// --> 'do_conjugate' is 0
// Here row 'y' of A starts at vector index (a_ld/VW3)*y. The index arithmetic of the tile below
// appears to also require WPT3 to be a multiple of VW3 and WGS3 to be a multiple of WPT3/VW3
// (NOTE(review): these are presumably guaranteed by the tuner/database; confirm).
// The arguments 'm', 'a_rotated', 'a_offset', 'do_conjugate', 'parameter', 'kl_unused' and
// 'ku_unused' are not read here.
__kernel __attribute__((reqd_work_group_size(WGS3, 1, 1)))
void XgemvFastRot(const int m, const int n,
                  const real_arg arg_alpha,
                  const real_arg arg_beta,
                  const int a_rotated,
                  const __global realVFR* restrict agm, const int a_offset, const int a_ld,
                  const __global real* restrict xgm, const int x_offset, const int x_inc,
                  __global real* ygm, const int y_offset, const int y_inc,
                  const int do_conjugate, const int parameter,
                  const int kl_unused, const int ku_unused) {
  const real alpha = GetRealArg(arg_alpha);
  const real beta = GetRealArg(arg_beta);

  // Local memory to store a tile of the matrix (for coalescing). During the load, work-item 'lid'
  // reads WPT3/VW3 vectors along rows of A and scatters their components into column 'lid' of
  // the tile. During the compute phase each work-item gathers back exactly the WPT3 values of
  // row get_group_id(0)*WGS3 + lid (columns kwg .. kwg+WPT3-1) of A.
  __local real tile[WPT3][WGS3];
  const int lid = get_local_id(0);
  const int lid_mod = lid % (WPT3/VW3);
  const int lid_div = lid / (WPT3/VW3);

  // Local memory for the vector X
  __local real xlm[WPT3];

  // Initializes the accumulation register
  real acc;
  SetToZero(acc);

  // Loops over tile-sized portions of the work
  for (int kwg=0; kwg<n; kwg+=WPT3) {

    // Loads the vector X into local memory. The strided loop fills all WPT3 entries even when
    // WPT3 > WGS3 (the previous 'if (lid < WPT3)' left the tail uninitialized in that case);
    // when WPT3 <= WGS3 it is exactly one load for each work-item with lid < WPT3.
    for (int i=lid; i<WPT3; i+=WGS3) {
      xlm[i] = xgm[(kwg + i) * x_inc + x_offset];
    }

    // Loads the matrix A into local memory
    #pragma unroll
    for (int kl=0; kl<WPT3/VW3; ++kl) {
      const int x = (kwg/VW3) + lid_mod;
      const int y = get_group_id(0) * WGS3 + lid_div * (WPT3/VW3) + kl;
      realVFR avec = agm[(a_ld/VW3) * y + x];
      #if VW3 == 1
        tile[kl*VW3 + 0][lid] = avec;
      #elif VW3 == 2
        tile[kl*VW3 + 0][lid] = avec.x;
        tile[kl*VW3 + 1][lid] = avec.y;
      #elif VW3 == 4
        tile[kl*VW3 + 0][lid] = avec.x;
        tile[kl*VW3 + 1][lid] = avec.y;
        tile[kl*VW3 + 2][lid] = avec.z;
        tile[kl*VW3 + 3][lid] = avec.w;
      #elif VW3 == 8
        tile[kl*VW3 + 0][lid] = avec.s0;
        tile[kl*VW3 + 1][lid] = avec.s1;
        tile[kl*VW3 + 2][lid] = avec.s2;
        tile[kl*VW3 + 3][lid] = avec.s3;
        tile[kl*VW3 + 4][lid] = avec.s4;
        tile[kl*VW3 + 5][lid] = avec.s5;
        tile[kl*VW3 + 6][lid] = avec.s6;
        tile[kl*VW3 + 7][lid] = avec.s7;
      #elif VW3 == 16
        tile[kl*VW3 + 0][lid] = avec.s0;
        tile[kl*VW3 + 1][lid] = avec.s1;
        tile[kl*VW3 + 2][lid] = avec.s2;
        tile[kl*VW3 + 3][lid] = avec.s3;
        tile[kl*VW3 + 4][lid] = avec.s4;
        tile[kl*VW3 + 5][lid] = avec.s5;
        tile[kl*VW3 + 6][lid] = avec.s6;
        tile[kl*VW3 + 7][lid] = avec.s7;
        tile[kl*VW3 + 8][lid] = avec.s8;
        tile[kl*VW3 + 9][lid] = avec.s9;
        tile[kl*VW3 + 10][lid] = avec.sA;
        tile[kl*VW3 + 11][lid] = avec.sB;
        tile[kl*VW3 + 12][lid] = avec.sC;
        tile[kl*VW3 + 13][lid] = avec.sD;
        tile[kl*VW3 + 14][lid] = avec.sE;
        tile[kl*VW3 + 15][lid] = avec.sF;
      #endif
    }

    // Synchronizes all threads in a workgroup (both xlm and tile must be complete)
    barrier(CLK_LOCAL_MEM_FENCE);

    // The multiply-add function (rotated): tile[lid_mod*VW3 + v][lid_div*(WPT3/VW3) + kl] is the
    // entry of A in this work-item's row and column kwg + kl*VW3 + v, matching xlm[kl*VW3 + v]
    #pragma unroll
    for (int kl=0; kl<WPT3/VW3; ++kl) {
      #pragma unroll
      for (int v=0; v<VW3; ++v) {
        real aval = tile[lid_mod*VW3 + v][lid_div * (WPT3/VW3) + kl];
        real xval = xlm[kl*VW3 + v];
        MultiplyAdd(acc, xval, aval);
      }
    }

    // Synchronizes all threads in a workgroup (before xlm/tile are overwritten)
    barrier(CLK_LOCAL_MEM_FENCE);
  }

  // Stores the final result. The tile was addressed via get_group_id(0)*WGS3 + lid, which equals
  // get_global_id(0) only when no global work offset is used (NOTE(review): confirm launcher).
  const int gid = get_global_id(0);
  real yval = ygm[gid * y_inc + y_offset];
  AXPBY(ygm[gid * y_inc + y_offset], alpha, acc, beta, yval);
}

// =================================================================================================

// End of the C++11 raw string literal
)"

// =================================================================================================