/*******************************************************************************
* Copyright 2019-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "gpu/ocl/ocl_math_utils.h"

#if ELEMENT_SIZE == 2
#pragma OPENCL EXTENSION cl_intel_subgroups_short : enable
#define ELEMENT ushort
#define ELEMENT2 ushort2
#define ELEMENT4 ushort4
#define ELEMENT8 ushort8
#define ELEMENT16 ushort16
#define ELEMENT_INT ushort2
#define ELEMENT_INT4 ushort8
#define VLOAD_ELEMENT_INT vload2
#define ELEMENTS_PER_INT 2
#define BLOCK_READ_ELEMENT2 intel_sub_group_block_read_us2
#define BLOCK_READ_ELEMENT4 intel_sub_group_block_read_us4
#define BLOCK_READ_ELEMENT_INT intel_sub_group_block_read_us2
#define MASKED_BLOCK_READ_ELEMENT_INT masked_block_read_element2
#define BLOCK_WRITE_ELEMENT_INT4 intel_sub_group_block_write_us8
#elif ELEMENT_SIZE == 1
#define ELEMENT uchar
#define ELEMENT2 uchar2
#define ELEMENT4 uchar4
#define ELEMENT8 uchar8
#define ELEMENT16 uchar16
#define ELEMENT_INT uchar4
#define ELEMENT_INT4 uchar16
#define VLOAD_ELEMENT_INT vload4
#define BLOCK_READ_ELEMENT2 intel_sub_group_block_read_uc2
#define BLOCK_READ_ELEMENT4 intel_sub_group_block_read_uc4
#define BLOCK_READ_ELEMENT_INT intel_sub_group_block_read_uc4
#define MASKED_BLOCK_READ_ELEMENT_INT masked_block_read_element4
#define BLOCK_WRITE_ELEMENT_INT4 intel_sub_group_block_write_uc16
#define ELEMENTS_PER_INT 4
#define SUM_T int
#define SUM_T4 int4
#define CONVERT_SUM_T convert_int
#define CONVERT_SUM_T4 convert_int4
#if COPY_SIGNED
#define AS_SIGNED_ELEMENT as_char
#define AS_SIGNED_ELEMENT4 as_char4
#define AS_SIGNED_ELEMENT_INT as_char4
#define SIGNED_ELEMENT_INT char4
#else
#define AS_SIGNED_ELEMENT as_uchar
#define AS_SIGNED_ELEMENT4 as_uchar4
#define AS_SIGNED_ELEMENT_INT as_uchar4
#define SIGNED_ELEMENT_INT uchar4
#endif
#else
#error Unsupported element size.
#endif

#if !COPY_A && !COPY_B
#error Source matrix not defined.
#endif
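// The two helpers below stand in for the intel_sub_group_block_read_*
// built-ins when the copied region is not a full block: each lane loads an
// element only if its row index (lane id plus a multiple of the sub-group
// size) is below the remainder `rem`, and out-of-range slots are zero-filled.
// For example, with an 8-wide sub-group and rem == 10,
// masked_block_read_element4() fills v.s0 for rows 0-7, v.s1 for rows 8-9,
// and zeroes everything else.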
inline ELEMENT2 masked_block_read_element2(global ELEMENT *p, int rem) {
    ELEMENT2 v;
    int lid = get_sub_group_local_id();
    int sg = get_sub_group_size();

    v.s0 = (lid < rem) ? p[lid] : 0;
    v.s1 = (lid + sg < rem) ? p[lid + sg] : 0;

    return v;
}

inline ELEMENT4 masked_block_read_element4(global ELEMENT *p, int rem) {
    ELEMENT4 v;
    int lid = get_sub_group_local_id();
    int sg = get_sub_group_size();

    v.s0 = (lid < rem) ? p[lid] : 0;
    v.s1 = (lid + sg < rem) ? p[lid + sg] : 0;
    v.s2 = (lid + 2 * sg < rem) ? p[lid + 2 * sg] : 0;
    v.s3 = (lid + 3 * sg < rem) ? p[lid + 3 * sg] : 0;

    return v;
}

__attribute__((overloadable)) inline int sum(int v) {
    return sub_group_reduce_add(v);
}

__attribute__((overloadable)) inline int sum(int4 v) {
    return sub_group_reduce_add(v.s0) + sub_group_reduce_add(v.s1)
            + sub_group_reduce_add(v.s2) + sub_group_reduce_add(v.s3);
}

// Dead code at the required sub-group size of 8 (the guard is never true);
// it appears to exist to reference the IGC idpas built-in so that a DPAS
// instruction stays in the compiled kernel.
void dummy_dpas() {
    if (get_sub_group_local_id() >= 16) {
        int __builtin_IB_sub_group_idpas_s8_s8_8_1(int, int, int8)
                __attribute__((const));
        global volatile int *_;

        int z = __builtin_IB_sub_group_idpas_s8_s8_8_1(0, _[0], 1);
        for (int i = 0; i < z; i++)
            (void)_[0];
    }
}

#define DUMMY_DPAS dummy_dpas()

// PARTIAL_LOAD reads up to ELEMENTS_PER_INT consecutive elements at p into
// regs[cc], clamped to the column remainder crem and guarded by the row
// remainder rrem.
#if ELEMENT_SIZE == 2
#define PARTIAL_LOAD(regs, rrem, crem, cc, p) \
    if ((2 * cc + 1) < crem) { \
        if (lid < rrem) regs[cc] = vload2(0, p); \
    } else if ((2 * cc) < crem) { \
        if (lid < rrem) regs[cc].s0 = *(p); \
    }
#elif ELEMENT_SIZE == 1
#define PARTIAL_LOAD(regs, rrem, crem, cc, p) \
    if ((4 * cc + 3) < crem) { \
        if (lid < rrem) regs[cc] = vload4(0, p); \
    } else if ((4 * cc + 2) < crem) { \
        if (lid < rrem) regs[cc].s012 = vload3(0, p); \
    } else if ((4 * cc + 1) < crem) { \
        if (lid < rrem) regs[cc].s01 = vload2(0, p); \
    } else if (4 * cc < crem) { \
        if (lid < rrem) regs[cc].s0 = *(p); \
    }
#endif

#if COPY_A

#define UNROLL_M 32
#define UNROLL_K (32 / ELEMENT_SIZE)

// For integer copies, row sums are stored after the packed A panel, at the
// k-aligned end of the block.
#if COPY_SUM
#define GET_A_SUM_ADDRESS \
    int k_align = (k + UNROLL_K - 1) & ~(UNROLL_K - 1); \
    global int *a_sum = (global int *)(a_packed + offseta_packed \
            + m0 * lda_packed + k_align * UNROLL_M);
#else
#define GET_A_SUM_ADDRESS
#endif

#if COPY_CLEAR_SUM

// A sum clear kernel: initialize row sums to zero.
__attribute__((intel_reqd_sub_group_size(8))) kernel void
xe_hp_systolic_gemm_copy(long m, long k, global ELEMENT *a_packed,
        int offseta_packed, int lda_packed) {

    uint m0 = (sub_group_broadcast(get_global_id(0), 0) / 8) * UNROLL_M;

    GET_A_SUM_ADDRESS;

    uint4 zero = 0;
    intel_sub_group_block_write4(a_sum, zero);
}

#elif !COPY_TRANS

#if ELEMENT_SIZE == 2
#define REPACK_REG(rr, cc) \
    blk_r[rr].s##cc = (((uint)c[2 * cc + 1].s##rr) << 16) | c[2 * cc].s##rr
#elif ELEMENT_SIZE == 1
#define REPACK_REG(rr, cc) \
    blk_r[rr].s##cc = (((uint)c[4 * cc + 3].s##rr) << 24) \
            | (((uint)c[4 * cc + 2].s##rr) << 16) \
            | (((uint)c[4 * cc + 1].s##rr) << 8) | c[4 * cc].s##rr
#endif

#define REPACK_CC(cc) \
    REPACK_REG(0, cc); \
    REPACK_REG(1, cc); \
    REPACK_REG(2, cc); \
    REPACK_REG(3, cc)

#define REPACK \
    REPACK_CC(0); \
    REPACK_CC(1); \
    REPACK_CC(2); \
    REPACK_CC(3); \
    REPACK_CC(4); \
    REPACK_CC(5); \
    REPACK_CC(6); \
    REPACK_CC(7)

// Nontranspose A copy.
// Each thread packs a 32x16 (f16/bf16) or 32x32 (u8/s8) block of A.
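// A sketch of the packed A layout, as read from REPACK (not an authoritative
// spec): rows are handled in groups of 8, one row per sub-group lane. Within
// a group, each uint of blk_r packs ELEMENTS_PER_INT consecutive k-values of
// one row, and the block write lays the dwords out k-group-major, row-minor.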
__attribute__((intel_reqd_sub_group_size(8))) kernel void
xe_hp_systolic_gemm_copy(long m, long k, global ELEMENT *a, long offseta,
        long lda, global ELEMENT *a_packed, int offseta_packed,
        int lda_packed) {

    int lid = get_sub_group_local_id();
    uint m0 = (sub_group_broadcast(get_global_id(0), 0) / 8) * UNROLL_M;
    uint k0 = get_global_id(1) * UNROLL_K;
    int mrem = m - m0;
    int krem = k - k0;

    // Use block reads only when the source address, leading dimension, and
    // offset are all 4-byte (ELEMENTS_PER_INT-element) aligned.
    bool aligned = ((as_long(a) | lda | offseta) & (ELEMENTS_PER_INT - 1)) == 0;

    if (mrem <= 0 || krem <= 0) return;

    GET_A_SUM_ADDRESS;

    a += offseta + m0 + k0 * lda;
    a_packed += offseta_packed + m0 * lda_packed + k0 * UNROLL_M;

    // Read all columns.
    ELEMENT4 c[UNROLL_K];

    if (mrem >= UNROLL_M && krem >= UNROLL_K && aligned) {
        for (int h = 0; h < UNROLL_K; h++)
            c[h] = BLOCK_READ_ELEMENT4(a + h * lda);
    } else {
        for (int h = 0; h < UNROLL_K; h++)
            if (h < krem)
                c[h] = masked_block_read_element4(a + h * lda, mrem);
            else
                c[h] = 0;
    }

    // Rearrange.
    uint8 blk_r[UNROLL_M / 8];
    REPACK;

    // Write out.
    for (int rr = 0; rr < UNROLL_M / 8; rr++)
        intel_sub_group_block_write8(
                (global uint *)(a_packed + rr * UNROLL_K * 8), blk_r[rr]);

    // Sum if needed. Atomics are required because several k-blocks
    // accumulate into the same row-sum slots.
#if COPY_SUM
    SUM_T4 sum = 0;
    for (int h = 0; h < UNROLL_K; h++)
        sum += CONVERT_SUM_T4(AS_SIGNED_ELEMENT4(c[h]));
    atomic_add(a_sum + lid, sum.s0);
    atomic_add(a_sum + lid + 8, sum.s1);
    atomic_add(a_sum + lid + 16, sum.s2);
    atomic_add(a_sum + lid + 24, sum.s3);
#endif

    DUMMY_DPAS;
}

#else /* COPY_TRANS */

// Transpose A copy.
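// In the transposed case A is read row-wise: lane `lid` of each sub-group
// handles row rr * 8 + lid and vloads ELEMENTS_PER_INT consecutive k-values
// at a time, so the data arrives already dword-packed and no cross-lane
// repacking is needed before the block write.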
__attribute__((intel_reqd_workgroup_walk_order(1, 0)))
__attribute__((intel_reqd_sub_group_size(8))) kernel void
xe_hp_systolic_gemm_copy(long m, long k, global ELEMENT *a, long offseta,
        long lda, global ELEMENT *a_packed, int offseta_packed,
        int lda_packed) {

    int lid = get_sub_group_local_id();
    uint m0 = (sub_group_broadcast(get_global_id(0), 0) / 8) * UNROLL_M;
    uint k0 = get_global_id(1) * UNROLL_K;
    int mrem = m - m0;
    int krem = k - k0;

    if (mrem <= 0 || krem <= 0) return;

    GET_A_SUM_ADDRESS;

    a += offseta + m0 * lda + k0;
    a_packed += offseta_packed + m0 * lda_packed + k0 * UNROLL_M;

#if COPY_SUM
    SUM_T sum[UNROLL_M / 8] = {0};
#endif

    for (int rr = 0; rr < UNROLL_M / 8; rr++, mrem -= 8) {
        ELEMENT_INT regs[8];

        if (mrem >= UNROLL_M && krem >= UNROLL_K) {
            for (int cc = 0; cc < UNROLL_K / ELEMENTS_PER_INT; cc++)
                regs[cc] = VLOAD_ELEMENT_INT(0,
                        a + ((rr * 8) + lid) * lda + (cc * ELEMENTS_PER_INT));
        } else {
            for (int cc = 0; cc < UNROLL_K / ELEMENTS_PER_INT; cc++) {
                regs[cc] = 0;
                PARTIAL_LOAD(regs, mrem, krem, cc,
                        a + ((rr * 8) + lid) * lda + (cc * ELEMENTS_PER_INT));
            }
        }

        uint8 blk_r;
        blk_r.s0 = as_uint(regs[0]);
        blk_r.s1 = as_uint(regs[1]);
        blk_r.s2 = as_uint(regs[2]);
        blk_r.s3 = as_uint(regs[3]);
        blk_r.s4 = as_uint(regs[4]);
        blk_r.s5 = as_uint(regs[5]);
        blk_r.s6 = as_uint(regs[6]);
        blk_r.s7 = as_uint(regs[7]);

#if COPY_SUM
        for (int cc = 0; cc < UNROLL_K / ELEMENTS_PER_INT; cc++) {
            sum[rr] += CONVERT_SUM_T(AS_SIGNED_ELEMENT(regs[cc].s0));
            sum[rr] += CONVERT_SUM_T(AS_SIGNED_ELEMENT(regs[cc].s1));
            sum[rr] += CONVERT_SUM_T(AS_SIGNED_ELEMENT(regs[cc].s2));
            sum[rr] += CONVERT_SUM_T(AS_SIGNED_ELEMENT(regs[cc].s3));
        }
#endif

        intel_sub_group_block_write8(
                (global uint *)(a_packed + rr * UNROLL_K * 8), blk_r);
    }

#if COPY_SUM
    atomic_add(a_sum + lid, sum[0]);
    atomic_add(a_sum + lid + 8, sum[1]);
    atomic_add(a_sum + lid + 16, sum[2]);
    atomic_add(a_sum + lid + 24, sum[3]);
#endif

    DUMMY_DPAS;
}

#endif /* !COPY_TRANS */
#endif /* COPY_A */

#if COPY_B

#define UNROLL_K (32 / ELEMENT_SIZE)

#if ELEMENT_SIZE == 2
// Pack four columns into one register group; each dword ends up holding two
// k-values (k and k + 8) of one column.
#define REPACK_CC(cc) \
    do { \
        colgroups[cc].s01 = cols[cc * 4]; \
        colgroups[cc].s23 = cols[cc * 4 + 1]; \
        colgroups[cc].s45 = cols[cc * 4 + 2]; \
        colgroups[cc].s67 = cols[cc * 4 + 3]; \
    } while (false)
// Word-interleave two k-rows (k = lid and k = lid + 8) read by the
// transposed copy.
#define REPACK_CC2(cc) \
    do { \
        colgroups[cc].s02 = cols[cc * 2]; \
        colgroups[cc].s13 = cols2[cc * 2]; \
        colgroups[cc].s46 = cols[cc * 2 + 1]; \
        colgroups[cc].s57 = cols2[cc * 2 + 1]; \
    } while (false)
#elif ELEMENT_SIZE == 1
// Pack four columns into one register group; each dword ends up holding
// four k-values, 8 apart, of one column.
#define REPACK_CC(cc) \
    do { \
        colgroups[cc].s0123 = cols[cc * 4]; \
        colgroups[cc].s4567 = cols[cc * 4 + 1]; \
        colgroups[cc].s89ab = cols[cc * 4 + 2]; \
        colgroups[cc].scdef = cols[cc * 4 + 3]; \
    } while (false)
// Byte-interleave four k-rows (k = lid + {0, 8, 16, 24}) read by the
// transposed copy.
#define REPACK_CC4(cc) \
    do { \
        colgroups[cc].s048c = cols[cc]; \
        colgroups[cc].s159d = cols2[cc]; \
        colgroups[cc].s26ae = cols3[cc]; \
        colgroups[cc].s37bf = cols4[cc]; \
    } while (false)
#endif

// For integer copies, column sums are stored after the packed B panel, at
// the k-aligned end of the block.
#if COPY_SUM
#define GET_B_SUM_ADDRESS \
    int k_align = (k + UNROLL_K - 1) & ~(UNROLL_K - 1); \
    global int *b_sum = (global int *)(b_packed + offsetb_packed \
            + n0 * ldb_packed + k_align * UNROLL_N);
#else
#define GET_B_SUM_ADDRESS
#endif

#if COPY_CLEAR_SUM

// B sum clear kernel: initialize column sums to zero.
__attribute__((intel_reqd_sub_group_size(8))) kernel void
xe_hp_systolic_gemm_copy(long k, long n, global ELEMENT *b_packed,
        int offsetb_packed, int ldb_packed) {

    uint n0 = (sub_group_broadcast(get_global_id(0), 0) / 8) * UNROLL_N;

    GET_B_SUM_ADDRESS;

    uint4 zero = 0;
    intel_sub_group_block_write4(b_sum, zero);
#if UNROLL_N > 32
    intel_sub_group_block_write2(b_sum + 32, zero.s01);
#endif
}

#elif !COPY_TRANS

// Nontranspose B copy.
// Each thread packs a 16x{32,48} (f16/bf16) or 32x{32,48} (u8/s8) block of B.
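// My reading of the layout this kernel produces (not an authoritative spec):
// columns are packed in groups of four, and each dword collects the
// ELEMENTS_PER_INT k-values that the block read deposits in one lane, i.e.
// k = lid + {0, 8} (f16/bf16) or k = lid + {0, 8, 16, 24} (u8/s8) of a
// single column. UNROLL_N is processed in two halves, presumably to bound
// register pressure.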
__attribute__((intel_reqd_sub_group_size(8))) kernel void
xe_hp_systolic_gemm_copy(long k, long n, global ELEMENT *b, long offsetb,
        long ldb, global ELEMENT *b_packed, int offsetb_packed,
        int ldb_packed) {

    int lid = get_sub_group_local_id();
    uint k0 = (sub_group_broadcast(get_global_id(0), 0) / 8) * UNROLL_K;
    uint n0 = get_global_id(1) * UNROLL_N;
    int krem = k - k0;
    int nrem = n - n0;
    bool aligned = ((as_long(b) | ldb | offsetb) & (ELEMENTS_PER_INT - 1)) == 0;

    if (nrem <= 0 || krem <= 0) return;

    GET_B_SUM_ADDRESS;
    b += offsetb + k0 + n0 * ldb;
    b_packed += offsetb_packed + n0 * ldb_packed + k0 * UNROLL_N;

    // Copy in two halves.

#define UNROLL_N_CHUNK (UNROLL_N / 2)
#if COPY_SUM
    SUM_T sums[UNROLL_N];
#endif
    ELEMENT_INT cols[UNROLL_N / 2];

    for (int c0 = 0; c0 < UNROLL_N;
            c0 += UNROLL_N_CHUNK, nrem -= UNROLL_N_CHUNK) {
        // Read all columns.
        if (krem >= UNROLL_K && nrem >= UNROLL_N_CHUNK && aligned) {
            for (int c = 0; c < UNROLL_N_CHUNK; c++)
                cols[c] = BLOCK_READ_ELEMENT_INT(b + (c + c0) * ldb);
        } else {
            for (int c = 0; c < UNROLL_N_CHUNK; c++)
                if (c < nrem)
                    cols[c] = MASKED_BLOCK_READ_ELEMENT_INT(
                            b + (c + c0) * ldb, krem);
                else
                    cols[c] = 0;
        }

        // Repack.
        ELEMENT_INT4 colgroups[UNROLL_N_CHUNK / 4];
        for (int cc = 0; cc < UNROLL_N_CHUNK / 4; cc++)
            REPACK_CC(cc);

        // Write out.
        for (int cc = 0; cc < UNROLL_N_CHUNK / 4; cc++)
            BLOCK_WRITE_ELEMENT_INT4(
                    b_packed + (cc * 4 + c0) * UNROLL_K, colgroups[cc]);

        // Sum if needed.
#if COPY_SUM
        for (int c = 0; c < UNROLL_N_CHUNK; c++)
            sums[c + c0] = sum(CONVERT_SUM_T4(AS_SIGNED_ELEMENT_INT(cols[c])));
#endif
    }

    // Accumulate sums.
#if COPY_SUM
    for (int c0 = 0; c0 < UNROLL_N; c0 += get_sub_group_size())
        atomic_add(b_sum + c0 + lid, sums[c0 + lid]);
#endif

    DUMMY_DPAS;
}

#else /* COPY_TRANS */

// Each lane holds a different set of k-rows, so sum() reduces partial column
// sums across the sub-group before they are accumulated.
#define ADD_SUM(coln) \
    for (int cc = 0; cc < UNROLL_N / 4; cc++) { \
        sums[4 * cc + 0] \
                += sum(CONVERT_SUM_T(AS_SIGNED_ELEMENT(coln[cc].s0))); \
        sums[4 * cc + 1] \
                += sum(CONVERT_SUM_T(AS_SIGNED_ELEMENT(coln[cc].s1))); \
        sums[4 * cc + 2] \
                += sum(CONVERT_SUM_T(AS_SIGNED_ELEMENT(coln[cc].s2))); \
        sums[4 * cc + 3] \
                += sum(CONVERT_SUM_T(AS_SIGNED_ELEMENT(coln[cc].s3))); \
    }

// Transpose B copy.
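// In the transposed case each lane reads along k: rows k = lid and
// k = lid + 8 cover the full UNROLL_K = 16 for f16/bf16, while u8/s8
// (UNROLL_K = 32) additionally reads k = lid + 16 and k = lid + 24 below.
// REPACK_CC2 / REPACK_CC4 then interleave those rows into the same
// dword-packed column layout as the nontransposed copy.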
__attribute__((intel_reqd_workgroup_walk_order(1, 0)))
__attribute__((intel_reqd_sub_group_size(8))) kernel void
xe_hp_systolic_gemm_copy(long k, long n, global ELEMENT *b, long offsetb,
        long ldb, global ELEMENT *b_packed, int offsetb_packed,
        int ldb_packed) {

    int lid = get_sub_group_local_id();
    uint k0 = (sub_group_broadcast(get_global_id(0), 0) / 8) * UNROLL_K;
    uint n0 = get_global_id(1) * UNROLL_N;
    int krem = k - k0;
    int nrem = n - n0;
    int sg = get_sub_group_size();

    if (nrem <= 0 || krem <= 0) return;

    GET_B_SUM_ADDRESS;
    b += offsetb + n0 + k0 * ldb;
    b_packed += offsetb_packed + n0 * ldb_packed + k0 * UNROLL_N;

    // Read upper 16x{32,48} submatrix.
    ELEMENT_INT cols[UNROLL_N / ELEMENTS_PER_INT];
    ELEMENT_INT cols2[UNROLL_N / ELEMENTS_PER_INT];
    ELEMENT_INT4 colgroups[UNROLL_N / 4];
    if (krem >= 2 * sg && nrem >= UNROLL_N) {
        for (int cc = 0; cc < UNROLL_N / ELEMENTS_PER_INT; cc++) {
            cols[cc] = VLOAD_ELEMENT_INT(
                    0, b + cc * ELEMENTS_PER_INT + lid * ldb);
            cols2[cc] = VLOAD_ELEMENT_INT(
                    0, b + cc * ELEMENTS_PER_INT + (lid + sg) * ldb);
        }
    } else {
        for (int cc = 0; cc < UNROLL_N / ELEMENTS_PER_INT; cc++) {
            cols[cc] = 0;
            cols2[cc] = 0;
            PARTIAL_LOAD(cols, krem, nrem, cc,
                    b + cc * ELEMENTS_PER_INT + lid * ldb);
            PARTIAL_LOAD(cols2, krem - sg, nrem, cc,
                    b + cc * ELEMENTS_PER_INT + (lid + sg) * ldb);
        }
    }
#if ELEMENT_SIZE == 2
    // Repack.
    for (int cc = 0; cc < UNROLL_N / 4; cc++)
        REPACK_CC2(cc);
#else
    // Read lower 16x{32,48} submatrix.
    ELEMENT_INT cols3[UNROLL_N / ELEMENTS_PER_INT];
    ELEMENT_INT cols4[UNROLL_N / ELEMENTS_PER_INT];
    krem -= 2 * sg;
    if (krem >= 2 * sg && nrem >= UNROLL_N) {
        for (int cc = 0; cc < UNROLL_N / ELEMENTS_PER_INT; cc++) {
            cols3[cc] = VLOAD_ELEMENT_INT(
                    0, b + cc * ELEMENTS_PER_INT + (lid + 2 * sg) * ldb);
            cols4[cc] = VLOAD_ELEMENT_INT(
                    0, b + cc * ELEMENTS_PER_INT + (lid + 3 * sg) * ldb);
        }
    } else {
        for (int cc = 0; cc < UNROLL_N / ELEMENTS_PER_INT; cc++) {
            cols3[cc] = 0;
            cols4[cc] = 0;
            PARTIAL_LOAD(cols3, krem, nrem, cc,
                    b + cc * ELEMENTS_PER_INT + (lid + 2 * sg) * ldb);
            PARTIAL_LOAD(cols4, krem - sg, nrem, cc,
                    b + cc * ELEMENTS_PER_INT + (lid + 3 * sg) * ldb);
        }
    }
    for (int cc = 0; cc < UNROLL_N / 4; cc++)
        REPACK_CC4(cc);
#endif

    // Write out.
    for (int cc = 0; cc < UNROLL_N / 4; cc++)
        BLOCK_WRITE_ELEMENT_INT4(b_packed + cc * 4 * UNROLL_K, colgroups[cc]);

#if COPY_SUM
    SUM_T sums[UNROLL_N] = {0};
    ADD_SUM(cols);
    ADD_SUM(cols2);
    ADD_SUM(cols3);
    ADD_SUM(cols4);

    for (int c0 = 0; c0 < UNROLL_N; c0 += get_sub_group_size())
        atomic_add(b_sum + c0 + lid, sums[c0 + lid]);
#endif

    DUMMY_DPAS;
}

#endif /* !COPY_TRANS */
#endif /* COPY_B */
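// Note on launch geometry (inferred from the index math above, not from the
// host code): global dimension 0 enumerates 8-wide sub-groups, one per
// UNROLL_M rows of A (or UNROLL_K rows of B), and global dimension 1
// enumerates UNROLL_K (A) or UNROLL_N (B) blocks; e.g. the nontransposed A
// copy expects roughly gws = (ceil(m / UNROLL_M) * 8, ceil(k / UNROLL_K)).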