1/******************************************************************************* 2* Copyright 2019-2020 Intel Corporation 3* 4* Licensed under the Apache License, Version 2.0 (the "License"); 5* you may not use this file except in compliance with the License. 6* You may obtain a copy of the License at 7* 8* http://www.apache.org/licenses/LICENSE-2.0 9* 10* Unless required by applicable law or agreed to in writing, software 11* distributed under the License is distributed on an "AS IS" BASIS, 12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13* See the License for the specific language governing permissions and 14* limitations under the License. 15*******************************************************************************/ 16 17#include "gpu/ocl/ocl_post_ops.h" 18#include "gpu/ocl/ocl_types.h" 19 20#if DT_F32 != 1 21#error "Only f32 implemented." 22#endif 23 24#define DO_FMA_NN(hh, i_mod_16, i_div_16, i_mod_4, i_div_4) \ 25 do { \ 26 c[i_div_4].s##i_mod_4 \ 27 = mad(sub_group_broadcast(a[hh].s##i_div_16, i_mod_16), \ 28 b.s##hh, c[i_div_4].s##i_mod_4); \ 29 } while (0) 30 31#define DO_FMA_NT(hh, i_mod_16, i_div_16, i_mod_4, i_div_4) \ 32 do { \ 33 c[i_div_4].s##i_mod_4 \ 34 = mad(sub_group_broadcast(a[hh].s##i_div_16, i_mod_16), b[hh], \ 35 c[i_div_4].s##i_mod_4); \ 36 } while (0) 37 38#if !defined(TRANS_A) 39#if !defined(TRANS_B) 40#define NN 41#define DO_FMA DO_FMA_NN 42#else 43#define NT 44#define DO_FMA DO_FMA_NT 45#endif 46#else 47#error "No superkernel implementation." 48#endif 49 50#define FMA_I_LOOP_32_ROW(hh) \ 51 do { \ 52 DO_FMA(hh, 0, 0, 0, 0); \ 53 DO_FMA(hh, 1, 0, 1, 0); \ 54 DO_FMA(hh, 2, 0, 2, 0); \ 55 DO_FMA(hh, 3, 0, 3, 0); \ 56 DO_FMA(hh, 4, 0, 0, 1); \ 57 DO_FMA(hh, 5, 0, 1, 1); \ 58 DO_FMA(hh, 6, 0, 2, 1); \ 59 DO_FMA(hh, 7, 0, 3, 1); \ 60 DO_FMA(hh, 8, 0, 0, 2); \ 61 DO_FMA(hh, 9, 0, 1, 2); \ 62 DO_FMA(hh, 10, 0, 2, 2); \ 63 DO_FMA(hh, 11, 0, 3, 2); \ 64 DO_FMA(hh, 12, 0, 0, 3); \ 65 DO_FMA(hh, 13, 0, 1, 3); \ 66 DO_FMA(hh, 14, 0, 2, 3); \ 67 DO_FMA(hh, 15, 0, 3, 3); \ 68 DO_FMA(hh, 16, 1, 0, 4); \ 69 DO_FMA(hh, 17, 1, 1, 4); \ 70 DO_FMA(hh, 18, 1, 2, 4); \ 71 DO_FMA(hh, 19, 1, 3, 4); \ 72 DO_FMA(hh, 20, 1, 0, 5); \ 73 DO_FMA(hh, 21, 1, 1, 5); \ 74 DO_FMA(hh, 22, 1, 2, 5); \ 75 DO_FMA(hh, 23, 1, 3, 5); \ 76 DO_FMA(hh, 24, 1, 0, 6); \ 77 DO_FMA(hh, 25, 1, 1, 6); \ 78 DO_FMA(hh, 26, 1, 2, 6); \ 79 DO_FMA(hh, 27, 1, 3, 6); \ 80 DO_FMA(hh, 28, 1, 0, 7); \ 81 DO_FMA(hh, 29, 1, 1, 7); \ 82 DO_FMA(hh, 30, 1, 2, 7); \ 83 DO_FMA(hh, 31, 1, 3, 7); \ 84 } while (0) 85 86#define FMA_I_LOOP_16_ROW(hh) \ 87 do { \ 88 DO_FMA(hh, 0, 0, 0, 0); \ 89 DO_FMA(hh, 1, 0, 1, 0); \ 90 DO_FMA(hh, 2, 0, 2, 0); \ 91 DO_FMA(hh, 3, 0, 3, 0); \ 92 DO_FMA(hh, 4, 0, 0, 1); \ 93 DO_FMA(hh, 5, 0, 1, 1); \ 94 DO_FMA(hh, 6, 0, 2, 1); \ 95 DO_FMA(hh, 7, 0, 3, 1); \ 96 DO_FMA(hh, 8, 0, 0, 2); \ 97 DO_FMA(hh, 9, 0, 1, 2); \ 98 DO_FMA(hh, 10, 0, 2, 2); \ 99 DO_FMA(hh, 11, 0, 3, 2); \ 100 DO_FMA(hh, 12, 0, 0, 3); \ 101 DO_FMA(hh, 13, 0, 1, 3); \ 102 DO_FMA(hh, 14, 0, 2, 3); \ 103 DO_FMA(hh, 15, 0, 3, 3); \ 104 } while (0) 105 106#if WITH_ELTWISE == 1 107#define POST_OP(val) \ 108 do { \ 109 if (last_k_block) \ 110 val = fwd_eltwise( \ 111 val, eltwise_alpha, eltwise_beta, eltwise_scale); \ 112 } while (0) 113#else 114#define POST_OP(val) 115#endif 116 117#define UPDATE_C_ROW(i, ii, betaZero) \ 118 do { \ 119 if (jrem > 0) \ 120 if (irem > i) { \ 121 float val = alpha * c[i / 4].s##ii \ 122 + ((betaZero) ? 0 : beta * *C); \ 123 POST_OP(val); \ 124 *C = val; \ 125 } \ 126 C++; \ 127 } while (0) 128 129#define UPDATE_C_32_ROW(betaZero) \ 130 do { \ 131 UPDATE_C_ROW(0, 0, betaZero); \ 132 UPDATE_C_ROW(1, 1, betaZero); \ 133 UPDATE_C_ROW(2, 2, betaZero); \ 134 UPDATE_C_ROW(3, 3, betaZero); \ 135 UPDATE_C_ROW(4, 0, betaZero); \ 136 UPDATE_C_ROW(5, 1, betaZero); \ 137 UPDATE_C_ROW(6, 2, betaZero); \ 138 UPDATE_C_ROW(7, 3, betaZero); \ 139 UPDATE_C_ROW(8, 0, betaZero); \ 140 UPDATE_C_ROW(9, 1, betaZero); \ 141 UPDATE_C_ROW(10, 2, betaZero); \ 142 UPDATE_C_ROW(11, 3, betaZero); \ 143 UPDATE_C_ROW(12, 0, betaZero); \ 144 UPDATE_C_ROW(13, 1, betaZero); \ 145 UPDATE_C_ROW(14, 2, betaZero); \ 146 UPDATE_C_ROW(15, 3, betaZero); \ 147 UPDATE_C_ROW(16, 0, betaZero); \ 148 UPDATE_C_ROW(17, 1, betaZero); \ 149 UPDATE_C_ROW(18, 2, betaZero); \ 150 UPDATE_C_ROW(19, 3, betaZero); \ 151 UPDATE_C_ROW(20, 0, betaZero); \ 152 UPDATE_C_ROW(21, 1, betaZero); \ 153 UPDATE_C_ROW(22, 2, betaZero); \ 154 UPDATE_C_ROW(23, 3, betaZero); \ 155 UPDATE_C_ROW(24, 0, betaZero); \ 156 UPDATE_C_ROW(25, 1, betaZero); \ 157 UPDATE_C_ROW(26, 2, betaZero); \ 158 UPDATE_C_ROW(27, 3, betaZero); \ 159 UPDATE_C_ROW(28, 0, betaZero); \ 160 UPDATE_C_ROW(29, 1, betaZero); \ 161 UPDATE_C_ROW(30, 2, betaZero); \ 162 UPDATE_C_ROW(31, 3, betaZero); \ 163 } while (0) 164 165#define SUPERKERNEL_PROLOGUE \ 166 global volatile int *p = plan; \ 167 int id = get_group_id(0); \ 168\ 169 A0 += offsetA; \ 170 B0 += offsetB; \ 171 C0 += offsetC; \ 172\ 173 while (id < threads) { \ 174 uint i0, j0; \ 175 uint kid0, kid1; \ 176\ 177 i0 = plan[2 * id + 2]; \ 178 j0 = plan[2 * id + 3]; \ 179 kid0 = (i0 >> 31); \ 180 kid1 = (j0 >> 31); \ 181 i0 &= ~(1 << 31); \ 182 j0 &= ~(1 << 31); \ 183 j0 += get_local_id(0); 184 185#define SUPERKERNEL_EPILOGUE \ 186 if (get_sub_group_local_id() == 0) id = atomic_inc(plan); \ 187\ 188 sub_group_barrier(0); \ 189 id = sub_group_broadcast(id, 0); \ 190 } \ 191 if (get_sub_group_local_id() == 0) { \ 192 if (atomic_inc(plan + 1) == (get_num_groups(0) - 1)) { \ 193 mem_fence(CLK_GLOBAL_MEM_FENCE); \ 194 plan[0] = get_num_groups(0); \ 195 plan[1] = 0; \ 196 } \ 197 } 198 199#ifdef NN 200__attribute__((intel_reqd_sub_group_size(16))) // attr:no-format 201kernel void 202gen9_gemm_nocopy_superkernel_f32(global int *plan, int threads, 203 global float *A0, global float *B0, global float *C0, long offsetA, 204 long offsetB, long offsetC, int lda, int ldb, int ldc, int m, int n, 205 int k, float alpha, float beta, int last_k_block, float eltwise_alpha, 206 float eltwise_beta, float eltwise_scale) { 207 SUPERKERNEL_PROLOGUE 208 209 float2 a[4]; // 32 x 4 block of A, 4x 32x1 block accesses 210 float4 b; // 4 x 16 block of B, 1x 4x16 scattered access 211 float4 c[8]; // 32 x 16 block of C, 8x 4x16 scattered access 212 213 int irem = m - i0; 214 int jrem = n - j0; 215 if (irem < 0) irem = 0; 216 if (jrem < 0) jrem = 0; 217 218 global float *A = A0 + i0; 219 global float *B = B0 + j0 * ldb; 220 global float *C = C0 + i0 + j0 * ldc; 221 222 global float *A_cols[4] = {A, A + lda, A + 2 * lda, A + 3 * lda}; 223 224 int ldax4 = lda << 2; 225 int ldbx4 = ldb << 2; 226 227 if (kid0 == 0) { 228 for (int z = 0; z < 8; z++) 229 c[z] = 0.f; 230 231 for (int h = 0; h < (k >> 2); h++) { 232 // Load A 233 for (int j = 0; j < 4; j++) { 234 a[j] = as_float2( 235 intel_sub_group_block_read2((global uint *)A_cols[j])); 236 A_cols[j] += ldax4; 237 } 238 239 // Load B 240 b = vload4(0, B); 241 B += 4; 242 243 // FMAs 244 FMA_I_LOOP_32_ROW(0); 245 FMA_I_LOOP_32_ROW(1); 246 FMA_I_LOOP_32_ROW(2); 247 FMA_I_LOOP_32_ROW(3); 248 } 249 250 int krem = k & 3; 251 if (krem > 0) { 252 for (int j = 0; j < 4; j++) 253 a[j] = as_float2( 254 intel_sub_group_block_read2((global uint *)A_cols[j])); 255 256 b = vload4(0, B); 257 258 FMA_I_LOOP_32_ROW(0); 259 if (krem > 1) FMA_I_LOOP_32_ROW(1); 260 if (krem > 2) FMA_I_LOOP_32_ROW(2); 261 } 262 } else { 263 if (irem > 16) irem = 16; 264 265 for (int z = 0; z < 4; z++) 266 c[z] = 0.f; 267 268 for (int h = 0; h < (k >> 2); h++) { 269 for (int j = 0; j < 4; j++) { 270 a[j].s0 = as_float( 271 intel_sub_group_block_read((global uint *)A_cols[j])); 272 A_cols[j] += ldax4; 273 } 274 275 b = vload4(0, B); 276 B += 4; 277 278 FMA_I_LOOP_16_ROW(0); 279 FMA_I_LOOP_16_ROW(1); 280 FMA_I_LOOP_16_ROW(2); 281 FMA_I_LOOP_16_ROW(3); 282 } 283 284 int krem = k & 3; 285 if (krem > 0) { 286 for (int j = 0; j < 4; j++) 287 a[j].s0 = as_float( 288 intel_sub_group_block_read((global uint *)A_cols[j])); 289 290 b = vload4(0, B); 291 292 FMA_I_LOOP_16_ROW(0); 293 if (krem > 1) FMA_I_LOOP_16_ROW(1); 294 if (krem > 2) FMA_I_LOOP_16_ROW(2); 295 } 296 } 297 298 if (beta == 0) 299 UPDATE_C_32_ROW(1); 300 else 301 UPDATE_C_32_ROW(0); 302 303 SUPERKERNEL_EPILOGUE 304} 305#endif 306 307#ifdef NT 308__attribute__((intel_reqd_sub_group_size(16))) // attr:no-format 309kernel void 310gen9_gemm_nocopy_superkernel_f32(global int *plan, int threads, 311 global float *A0, global float *B0, global float *C0, long offsetA, 312 long offsetB, long offsetC, int lda, int ldb, int ldc, int m, int n, 313 int k, float alpha, float beta, int last_k_block, float eltwise_alpha, 314 float eltwise_beta, float eltwise_scale) { 315 SUPERKERNEL_PROLOGUE 316 317 float2 a[2]; // 32 x 2 block of A, 2x 32x1 block accesses 318 float b[2]; // 2 x 16 block of B, 2x 1x16 block accesses 319 float4 c[8]; // 32 x 16 block of C, 8x 4x16 scattered access 320 321 int irem = m - i0; 322 int jrem = n - j0; 323 if (irem < 0) irem = 0; 324 if (jrem < 0) jrem = 0; 325 326 global float *A = A0 + i0; 327 global float *B = B0 + j0; 328 global float *C = C0 + i0 + j0 * ldc; 329 330 global float *A_cols[2] = {A, A + lda}; 331 global float *B_rows[2] = {B, B + ldb}; 332 333 int ldax2 = lda << 1; 334 int ldbx2 = ldb << 1; 335 336 if (kid0 == 0) { 337 for (int z = 0; z < 8; z++) 338 c[z] = 0.f; 339 340 for (int h = 0; h < (k >> 1); h++) { 341 // Load A 342 for (int j = 0; j < 2; j++) { 343 a[j] = as_float2( 344 intel_sub_group_block_read2((global uint *)A_cols[j])); 345 A_cols[j] += ldax2; 346 } 347 348 // Load B 349 for (int i = 0; i < 2; i++) { 350 b[i] = as_float( 351 intel_sub_group_block_read((global uint *)B_rows[i])); 352 B_rows[i] += ldbx2; 353 } 354 355 // FMAs 356 FMA_I_LOOP_32_ROW(0); 357 FMA_I_LOOP_32_ROW(1); 358 } 359 360 int krem = k & 1; 361 if (krem > 0) { 362 a[0] = as_float2( 363 intel_sub_group_block_read2((global uint *)A_cols[0])); 364 365 b[0] = as_float( 366 intel_sub_group_block_read((global uint *)B_rows[0])); 367 368 FMA_I_LOOP_32_ROW(0); 369 } 370 } else { 371 if (irem > 16) irem = 16; 372 373 for (int z = 0; z < 4; z++) 374 c[z] = 0.f; 375 376 for (int h = 0; h < (k >> 1); h++) { 377 for (int j = 0; j < 2; j++) { 378 a[j].s0 = as_float( 379 intel_sub_group_block_read((global uint *)A_cols[j])); 380 A_cols[j] += ldax2; 381 } 382 383 for (int i = 0; i < 2; i++) { 384 b[i] = as_float( 385 intel_sub_group_block_read((global uint *)B_rows[i])); 386 B_rows[i] += ldbx2; 387 } 388 389 FMA_I_LOOP_16_ROW(0); 390 FMA_I_LOOP_16_ROW(1); 391 } 392 393 int krem = k & 1; 394 if (krem > 0) { 395 a[0].s0 = as_float( 396 intel_sub_group_block_read((global uint *)A_cols[0])); 397 b[0] = as_float( 398 intel_sub_group_block_read((global uint *)B_rows[0])); 399 400 FMA_I_LOOP_16_ROW(0); 401 } 402 } 403 404 if (beta == 0) 405 UPDATE_C_32_ROW(1); 406 else 407 UPDATE_C_32_ROW(0); 408 409 SUPERKERNEL_EPILOGUE 410} 411#endif 412