/*******************************************************************************
* Copyright 2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#if ELEMENT_SIZE == 2
#pragma OPENCL EXTENSION cl_intel_subgroups_short : enable
#define ELEMENT ushort
#define ELEMENT2 ushort2
#define ELEMENT4 ushort4
#define ELEMENT_WORD ushort
#define ELEMENT_WORD4 ushort4
#define ELEMENT_INT ushort2
#define ELEMENT_INT4 ushort8
#define VLOAD_ELEMENT_INT vload2
#define ELEMENTS_PER_INT 2
#define BLOCK_READ_ELEMENT2 intel_sub_group_block_read_us2
#define BLOCK_READ_ELEMENT4 intel_sub_group_block_read_us4
#define BLOCK_READ_ELEMENT_WORD intel_sub_group_block_read_us
#define MASKED_BLOCK_READ_ELEMENT_WORD masked_block_read_element
#define BLOCK_WRITE_ELEMENT_WORD4 intel_sub_group_block_write_us4
#define BLOCK_WRITE_ELEMENT_INT4 intel_sub_group_block_write_us8
#elif ELEMENT_SIZE == 1
#define ELEMENT uchar
#define ELEMENT2 uchar2
#define ELEMENT4 uchar4
#define ELEMENT_WORD uchar2
#define ELEMENT_WORD4 uchar8
#define ELEMENT_INT uchar4
#define ELEMENT_INT4 uchar16
#define VLOAD_ELEMENT_INT vload4
#define BLOCK_READ_ELEMENT2 intel_sub_group_block_read_uc2
#define BLOCK_READ_ELEMENT4 intel_sub_group_block_read_uc4
#define BLOCK_READ_ELEMENT_WORD intel_sub_group_block_read_uc2
#define MASKED_BLOCK_READ_ELEMENT_WORD masked_block_read_element2
#define BLOCK_WRITE_ELEMENT_WORD4 intel_sub_group_block_write_uc8
#define BLOCK_WRITE_ELEMENT_INT2 intel_sub_group_block_write_uc8
#define BLOCK_WRITE_ELEMENT_INT4 intel_sub_group_block_write_uc16
#define ELEMENTS_PER_INT 4
#define SUM_T int
#define SUM_T2 int2
#define SUM_T4 int4
#define CONVERT_SUM_T convert_int
#define CONVERT_SUM_T2 convert_int2
#define CONVERT_SUM_T4 convert_int4
#if COPY_SIGNED
#define AS_SIGNED_ELEMENT as_char
#define AS_SIGNED_ELEMENT4 as_char4
#define AS_SIGNED_ELEMENT_WORD as_char2
#define AS_SIGNED_ELEMENT_INT as_char4
#define SIGNED_ELEMENT_WORD char2
#define SIGNED_ELEMENT_INT char4
#else
#define AS_SIGNED_ELEMENT as_uchar
#define AS_SIGNED_ELEMENT4 as_uchar4
#define AS_SIGNED_ELEMENT_WORD as_uchar2
#define AS_SIGNED_ELEMENT_INT as_uchar4
#define SIGNED_ELEMENT_WORD uchar2
#define SIGNED_ELEMENT_INT uchar4
#endif
#else
#error Unsupported element size.
#endif

#if !COPY_A && !COPY_B
#error Source matrix not defined.
#endif

inline ELEMENT masked_block_read_element(global ELEMENT *p, int rem) {
    ELEMENT v;
    int lid = get_sub_group_local_id();
    int sg = get_sub_group_size();

    v = (lid < rem) ? p[lid] : 0;

    return v;
}
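
// masked_block_read_element2/4 below extend the same edge masking to two and
// four elements per lane: element j comes from p[lid + j * sg] while that
// index is below rem, and is zero-filled otherwise, so callers can use one
// code path for full and partial tiles.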
inline ELEMENT2 masked_block_read_element2(global ELEMENT *p, int rem) {
    ELEMENT2 v;
    int lid = get_sub_group_local_id();
    int sg = get_sub_group_size();

    v.s0 = (lid < rem) ? p[lid] : 0;
    v.s1 = (lid + sg < rem) ? p[lid + sg] : 0;

    return v;
}

inline ELEMENT4 masked_block_read_element4(global ELEMENT *p, int rem) {
    ELEMENT4 v;
    int lid = get_sub_group_local_id();
    int sg = get_sub_group_size();

    v.s0 = (lid < rem) ? p[lid] : 0;
    v.s1 = (lid + sg < rem) ? p[lid + sg] : 0;
    v.s2 = (lid + 2 * sg < rem) ? p[lid + 2 * sg] : 0;
    v.s3 = (lid + 3 * sg < rem) ? p[lid + 3 * sg] : 0;

    return v;
}

__attribute__((overloadable)) inline int sum(int v) {
    return sub_group_reduce_add(v);
}

__attribute__((overloadable)) inline int sum(int2 v) {
    return sub_group_reduce_add(v.s0) + sub_group_reduce_add(v.s1);
}

__attribute__((overloadable)) inline int sum(int4 v) {
    return sub_group_reduce_add(v.s0) + sub_group_reduce_add(v.s1)
            + sub_group_reduce_add(v.s2) + sub_group_reduce_add(v.s3);
}

// Dead code referencing the IGC DPAS builtin. It never executes (the guard
// cannot be true at the required subgroup size of 16), and DUMMY_DPAS below
// currently expands to nothing; presumably it is kept so a DPAS reference
// can be forced into the kernels when needed.
void dummy_dpas() {
    if (get_sub_group_local_id() >= 16) {
        int __builtin_IB_sub_group_idpas_s8_s8_8_1(int, int, int8)
                __attribute__((const));
        global volatile int *_;

        int z = __builtin_IB_sub_group_idpas_s8_s8_8_1(0, _[0], 1);
        for (int i = 0; i < z; i++)
            (void)_[0];
    }
}

#define DUMMY_DPAS /*dummy_dpas()*/

#if ELEMENT_SIZE == 2
#define PARTIAL_LOAD(regs, rrem, crem, cc, p) \
    if ((2 * cc + 1) < crem) { \
        if (lid < rrem) regs[cc] = vload2(0, p); \
    } else if ((2 * cc) < crem) { \
        if (lid < rrem) regs[cc].s0 = *(p); \
    }
#elif ELEMENT_SIZE == 1
#define PARTIAL_LOAD(regs, rrem, crem, cc, p) \
    if ((4 * cc + 3) < crem) { \
        if (lid < rrem) regs[cc] = vload4(0, p); \
    } else if ((4 * cc + 2) < crem) { \
        if (lid < rrem) regs[cc].s012 = vload3(0, p); \
    } else if ((4 * cc + 1) < crem) { \
        if (lid < rrem) regs[cc].s01 = vload2(0, p); \
    } else if (4 * cc < crem) { \
        if (lid < rrem) regs[cc].s0 = *(p); \
    }
#endif

#if COPY_A

#define UNROLL_M 64
#define UNROLL_K (32 / ELEMENT_SIZE)

#if COPY_SUM
#define GET_A_SUM_ADDRESS \
    int k_align = (k + UNROLL_K - 1) & ~(UNROLL_K - 1); \
    global int *a_sum = (global int *)(a_packed + offseta_packed \
            + m0 * lda_packed + k_align * UNROLL_M);
#else
#define GET_A_SUM_ADDRESS
#endif

#if COPY_CLEAR_SUM

// A sum clear kernel: initialize row sums to zero.
__attribute__((intel_reqd_sub_group_size(16))) kernel void
xe_hpc_systolic_gemm_copy(long m, long k, global ELEMENT *a_packed,
        int offseta_packed, int lda_packed) {

    uint m0 = (sub_group_broadcast(get_global_id(0), 0) / 16) * UNROLL_M;

    GET_A_SUM_ADDRESS;

    uint4 zero = 0;
    intel_sub_group_block_write4(a_sum, zero);
}

#elif !COPY_TRANS

#if ELEMENT_SIZE == 2
#define REPACK_REG(rr, cc) \
    blk_r[rr].s##cc = (((uint)c[2 * cc + 1].s##rr) << 16) | c[2 * cc].s##rr
#elif ELEMENT_SIZE == 1
#define REPACK_REG(rr, cc) \
    blk_r[rr].s##cc = (((uint)c[4 * cc + 3].s##rr) << 24) \
            | (((uint)c[4 * cc + 2].s##rr) << 16) \
            | (((uint)c[4 * cc + 1].s##rr) << 8) | c[4 * cc].s##rr
#endif

#define REPACK_CC(cc) \
    REPACK_REG(0, cc); \
    REPACK_REG(1, cc); \
    REPACK_REG(2, cc); \
    REPACK_REG(3, cc)

#define REPACK \
    REPACK_CC(0); \
    REPACK_CC(1); \
    REPACK_CC(2); \
    REPACK_CC(3); \
    REPACK_CC(4); \
    REPACK_CC(5); \
    REPACK_CC(6); \
    REPACK_CC(7)

// Nontranspose A copy.
// Each thread packs a 64x16 (f16/bf16) or 64x32 (u8/s8) block of A.
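// A hedged sketch of the packed layout, inferred from the repack/write code
// below (not a normative statement of the format): within one
// UNROLL_M x UNROLL_K block, source element (i, h) appears to land at offset
//     (i / 16) * 16 * UNROLL_K
//         + (h / ELEMENTS_PER_INT) * 16 * ELEMENTS_PER_INT
//         + (i % 16) * ELEMENTS_PER_INT + (h % ELEMENTS_PER_INT)
// i.e. the 64 rows split into four 16-row slices, with ELEMENTS_PER_INT
// k-values of one row packed per dword, as the DPAS systolic layout expects.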
__attribute__((intel_reqd_sub_group_size(16))) kernel void
xe_hpc_systolic_gemm_copy(long m, long k, global ELEMENT *a, long offseta,
        long lda, global ELEMENT *a_packed, int offseta_packed,
        int lda_packed) {

    int lid = get_sub_group_local_id();
    uint m0 = (sub_group_broadcast(get_global_id(0), 0) / 16) * UNROLL_M;
    uint k0 = get_global_id(1) * UNROLL_K;
    int mrem = m - m0;
    int krem = k - k0;
    bool aligned = ((as_long(a) | lda | offseta) & (ELEMENTS_PER_INT - 1)) == 0;

    if (mrem <= 0 || krem <= 0) return;

    GET_A_SUM_ADDRESS;

    a += offseta + m0 + k0 * lda;
    a_packed += offseta_packed + m0 * lda_packed + k0 * UNROLL_M;

    // Read all columns.
    ELEMENT4 c[UNROLL_K];

    if (mrem >= UNROLL_M && krem >= UNROLL_K && aligned) {
        for (int h = 0; h < UNROLL_K; h++)
            c[h] = BLOCK_READ_ELEMENT4(a + h * lda);
    } else {
        for (int h = 0; h < UNROLL_K; h++)
            if (h < krem)
                c[h] = masked_block_read_element4(a + h * lda, mrem);
            else
                c[h] = 0;
    }

    // Rearrange.
    uint8 blk_r[UNROLL_M / 16];
    REPACK;

    // Write out.
    for (int rr = 0; rr < UNROLL_M / 16; rr++)
        intel_sub_group_block_write8(
                (global uint *)(a_packed + rr * UNROLL_K * 16), blk_r[rr]);

    // Sum if needed.
#if COPY_SUM
    SUM_T4 sum = 0;
    for (int h = 0; h < UNROLL_K; h++)
        sum += CONVERT_SUM_T4(AS_SIGNED_ELEMENT4(c[h]));
    atomic_add(a_sum + lid, sum.s0);
    atomic_add(a_sum + lid + 16, sum.s1);
    atomic_add(a_sum + lid + 32, sum.s2);
    atomic_add(a_sum + lid + 48, sum.s3);
#endif

    DUMMY_DPAS;
}

#else /* COPY_TRANS */

// Transpose A copy.
__attribute__((intel_reqd_workgroup_walk_order(1, 0)))
__attribute__((intel_reqd_sub_group_size(16))) kernel void
xe_hpc_systolic_gemm_copy(long m, long k, global ELEMENT *a, long offseta,
        long lda, global ELEMENT *a_packed, int offseta_packed,
        int lda_packed) {

    int lid = get_sub_group_local_id();
    uint m0 = (sub_group_broadcast(get_global_id(0), 0) / 16) * UNROLL_M;
    uint k0 = get_global_id(1) * UNROLL_K;
    int mrem = m - m0;
    int krem = k - k0;

    if (mrem <= 0 || krem <= 0) return;

    GET_A_SUM_ADDRESS;

    a += offseta + m0 * lda + k0;
    a_packed += offseta_packed + m0 * lda_packed + k0 * UNROLL_M;

#if COPY_SUM
    SUM_T sum[UNROLL_M / 16] = {0};
#endif

    for (int rr = 0; rr < UNROLL_M / 16; rr++, mrem -= 16) {
        ELEMENT_INT regs[8];

        if (mrem >= UNROLL_M && krem >= UNROLL_K) {
            for (int cc = 0; cc < UNROLL_K / ELEMENTS_PER_INT; cc++)
                regs[cc] = VLOAD_ELEMENT_INT(0,
                        a + ((rr * 16) + lid) * lda + (cc * ELEMENTS_PER_INT));
        } else {
            for (int cc = 0; cc < UNROLL_K / ELEMENTS_PER_INT; cc++) {
                regs[cc] = 0;
                PARTIAL_LOAD(regs, mrem, krem, cc,
                        a + ((rr * 16) + lid) * lda + (cc * ELEMENTS_PER_INT));
            }
        }

        uint8 blk_r;
        blk_r.s0 = as_uint(regs[0]);
        blk_r.s1 = as_uint(regs[1]);
        blk_r.s2 = as_uint(regs[2]);
        blk_r.s3 = as_uint(regs[3]);
        blk_r.s4 = as_uint(regs[4]);
        blk_r.s5 = as_uint(regs[5]);
        blk_r.s6 = as_uint(regs[6]);
        blk_r.s7 = as_uint(regs[7]);

#if COPY_SUM
        for (int cc = 0; cc < UNROLL_K / ELEMENTS_PER_INT; cc++) {
            sum[rr] += CONVERT_SUM_T(AS_SIGNED_ELEMENT(regs[cc].s0));
            sum[rr] += CONVERT_SUM_T(AS_SIGNED_ELEMENT(regs[cc].s1));
            sum[rr] += CONVERT_SUM_T(AS_SIGNED_ELEMENT(regs[cc].s2));
            sum[rr] += CONVERT_SUM_T(AS_SIGNED_ELEMENT(regs[cc].s3));
        }
#endif

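        // Write out this 16-row slice of the packed block.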
        intel_sub_group_block_write8(
                (global uint *)(a_packed + rr * UNROLL_K * 16), blk_r);
    }

#if COPY_SUM
    atomic_add(a_sum + lid, sum[0]);
    atomic_add(a_sum + lid + 16, sum[1]);
    atomic_add(a_sum + lid + 32, sum[2]);
    atomic_add(a_sum + lid + 48, sum[3]);
#endif

    DUMMY_DPAS;
}

#endif /* !COPY_TRANS */
#endif /* COPY_A */

#if COPY_B

#define UNROLL_K (32 / ELEMENT_SIZE)

#if ELEMENT_SIZE == 2
#define REPACK_CC_WORD(cc) \
    do { \
        colgroups[cc].s0 = cols[cc * 4]; \
        colgroups[cc].s1 = cols[cc * 4 + 1]; \
        colgroups[cc].s2 = cols[cc * 4 + 2]; \
        colgroups[cc].s3 = cols[cc * 4 + 3]; \
    } while (false)
#define REPACK_CC(cc) \
    do { \
        colgroups[cc].s01 = cols[cc * 4]; \
        colgroups[cc].s23 = cols[cc * 4 + 1]; \
        colgroups[cc].s45 = cols[cc * 4 + 2]; \
        colgroups[cc].s67 = cols[cc * 4 + 3]; \
    } while (false)
#elif ELEMENT_SIZE == 1
#define REPACK_CC_WORD(cc) \
    do { \
        colgroups[cc].s01 = cols[cc * 4]; \
        colgroups[cc].s23 = cols[cc * 4 + 1]; \
        colgroups[cc].s45 = cols[cc * 4 + 2]; \
        colgroups[cc].s67 = cols[cc * 4 + 3]; \
    } while (false)
#define REPACK_CC(cc) \
    do { \
        colgroups[cc].s0123 = cols[cc * 4]; \
        colgroups[cc].s4567 = cols[cc * 4 + 1]; \
        colgroups[cc].s89ab = cols[cc * 4 + 2]; \
        colgroups[cc].scdef = cols[cc * 4 + 3]; \
    } while (false)
#define REPACK_CC2(cc) \
    do { \
        colgroups[cc].s0246 = cols[cc * 2]; \
        colgroups[cc].s1357 = cols2[cc * 2]; \
        colgroups[cc].s8ace = cols[cc * 2 + 1]; \
        colgroups[cc].s9bdf = cols2[cc * 2 + 1]; \
    } while (false)
#endif

#if COPY_SUM
#define GET_B_SUM_ADDRESS \
    int k_align = (k + UNROLL_K - 1) & ~(UNROLL_K - 1); \
    global int *b_sum = (global int *)(b_packed + offsetb_packed \
            + n0 * ldb_packed + k_align * UNROLL_N);
#else
#define GET_B_SUM_ADDRESS
#endif

#if COPY_CLEAR_SUM

// B sum clear kernel: initialize column sums to zero.
__attribute__((intel_reqd_sub_group_size(16))) kernel void
xe_hpc_systolic_gemm_copy(long k, long n, global ELEMENT *b_packed,
        int offsetb_packed, int ldb_packed) {

    uint n0 = (sub_group_broadcast(get_global_id(0), 0) / 16) * UNROLL_N;

    GET_B_SUM_ADDRESS;

    uint2 zero = 0;
    intel_sub_group_block_write2(b_sum, zero);
#if UNROLL_N > 32
    intel_sub_group_block_write(b_sum + 32, zero.s0);
#endif
}

#elif !COPY_TRANS

// Nontranspose B copy.
// Each thread packs a 16x{32,48} (f16/bf16) or 32x{32,48} (u8/s8) block of B.
__attribute__((intel_reqd_sub_group_size(16))) kernel void
xe_hpc_systolic_gemm_copy(long k, long n, global ELEMENT *b, long offsetb,
        long ldb, global ELEMENT *b_packed, int offsetb_packed,
        int ldb_packed) {

    int lid = get_sub_group_local_id();
    uint k0 = (sub_group_broadcast(get_global_id(0), 0) / 16) * UNROLL_K;
    uint n0 = get_global_id(1) * UNROLL_N;
    int krem = k - k0;
    int nrem = n - n0;
    bool aligned = ((as_long(b) | ldb | offsetb) & (ELEMENTS_PER_INT - 1)) == 0;

    if (nrem <= 0 || krem <= 0) return;

    GET_B_SUM_ADDRESS;
    b += offsetb + k0 + n0 * ldb;
    b_packed += offsetb_packed + n0 * ldb_packed + k0 * UNROLL_N;

    // Copy in two halves.
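    // Each half covers UNROLL_N_CHUNK = UNROLL_N / 2 columns; nrem is
    // decremented between halves so edge masking stays correct.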

#define UNROLL_N_CHUNK (UNROLL_N / 2)
#if COPY_SUM
    SUM_T sums[UNROLL_N];
#endif
    ELEMENT_WORD cols[UNROLL_N / 2];

    for (int c0 = 0; c0 < UNROLL_N;
            c0 += UNROLL_N_CHUNK, nrem -= UNROLL_N_CHUNK) {
        // Read all columns.
        if (krem >= UNROLL_K && nrem >= UNROLL_N_CHUNK && aligned) {
            for (int c = 0; c < UNROLL_N_CHUNK; c++)
                cols[c] = BLOCK_READ_ELEMENT_WORD(b + (c + c0) * ldb);
        } else {
            for (int c = 0; c < UNROLL_N_CHUNK; c++)
                if (c < nrem)
                    cols[c] = MASKED_BLOCK_READ_ELEMENT_WORD(
                            b + (c + c0) * ldb, krem);
                else
                    cols[c] = 0;
        }

        // Repack.
        ELEMENT_WORD4 colgroups[UNROLL_N_CHUNK / 4];
        for (int cc = 0; cc < UNROLL_N_CHUNK / 4; cc++)
            REPACK_CC_WORD(cc);

        // Write out.
        for (int cc = 0; cc < UNROLL_N_CHUNK / 4; cc++)
            BLOCK_WRITE_ELEMENT_WORD4(
                    b_packed + (cc * 4 + c0) * UNROLL_K, colgroups[cc]);

        // Sum if needed.
#if COPY_SUM
        for (int c = 0; c < UNROLL_N_CHUNK; c++)
            sums[c + c0] = sum(CONVERT_SUM_T2(AS_SIGNED_ELEMENT_WORD(cols[c])));
#endif
    }

    // Accumulate sums.
#if COPY_SUM
    for (int c0 = 0; c0 < UNROLL_N; c0 += get_sub_group_size())
        atomic_add(b_sum + c0 + lid, sums[c0 + lid]);
#endif

    DUMMY_DPAS;
}

#else /* COPY_TRANS */

#define ADD_SUM(coln) \
    for (int cc = 0; cc < UNROLL_N / 4; cc++) { \
        sums[4 * cc + 0] \
                += sum(CONVERT_SUM_T(AS_SIGNED_ELEMENT(coln[cc].s0))); \
        sums[4 * cc + 1] \
                += sum(CONVERT_SUM_T(AS_SIGNED_ELEMENT(coln[cc].s1))); \
        sums[4 * cc + 2] \
                += sum(CONVERT_SUM_T(AS_SIGNED_ELEMENT(coln[cc].s2))); \
        sums[4 * cc + 3] \
                += sum(CONVERT_SUM_T(AS_SIGNED_ELEMENT(coln[cc].s3))); \
    }

// Transpose B copy.
__attribute__((intel_reqd_workgroup_walk_order(1, 0)))
__attribute__((intel_reqd_sub_group_size(16))) kernel void
xe_hpc_systolic_gemm_copy(long k, long n, global ELEMENT *b, long offsetb,
        long ldb, global ELEMENT *b_packed, int offsetb_packed,
        int ldb_packed) {

    int lid = get_sub_group_local_id();
    uint k0 = (sub_group_broadcast(get_global_id(0), 0) / 16) * UNROLL_K;
    uint n0 = get_global_id(1) * UNROLL_N;
    int krem = k - k0;
    int nrem = n - n0;
    int sg = get_sub_group_size();

    if (nrem <= 0 || krem <= 0) return;

    GET_B_SUM_ADDRESS;
    b += offsetb + n0 + k0 * ldb;
    b_packed += offsetb_packed + n0 * ldb_packed + k0 * UNROLL_N;

    // Read upper 16 x UNROLL_N submatrix.
    ELEMENT_INT cols[UNROLL_N / ELEMENTS_PER_INT];
    ELEMENT_INT4 colgroups[UNROLL_N / 8];
    if (krem >= sg && nrem >= UNROLL_N) {
        for (int cc = 0; cc < UNROLL_N / ELEMENTS_PER_INT; cc++)
            cols[cc] = VLOAD_ELEMENT_INT(
                    0, b + cc * ELEMENTS_PER_INT + lid * ldb);
    } else {
        for (int cc = 0; cc < UNROLL_N / ELEMENTS_PER_INT; cc++) {
            cols[cc] = 0;
            PARTIAL_LOAD(cols, krem, nrem, cc,
                    b + cc * ELEMENTS_PER_INT + lid * ldb);
        }
    }
#if ELEMENT_SIZE == 2
    // Repack.
    for (int cc = 0; cc < UNROLL_N / 8; cc++)
        REPACK_CC(cc);
#else
#if COPY_SUM
    SUM_T sums[UNROLL_N] = {0};
    ADD_SUM(cols);
#endif

    // Read lower 16 x UNROLL_N submatrix.
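    // (u8/s8 path) The second set of 16 k-rows read below is interleaved
    // byte-wise with the first by REPACK_CC2, so consecutive bytes of each
    // packed dword alternate between k-rows lid and lid + sg.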
    ELEMENT_INT cols2[UNROLL_N / ELEMENTS_PER_INT];
    krem -= sg;
    if (krem >= sg && nrem >= UNROLL_N) {
        for (int cc = 0; cc < UNROLL_N / ELEMENTS_PER_INT; cc++)
            cols2[cc] = VLOAD_ELEMENT_INT(
                    0, b + cc * ELEMENTS_PER_INT + (lid + sg) * ldb);
    } else {
        for (int cc = 0; cc < UNROLL_N / ELEMENTS_PER_INT; cc++) {
            cols2[cc] = 0;
            PARTIAL_LOAD(cols2, krem, nrem, cc,
                    b + cc * ELEMENTS_PER_INT + (lid + sg) * ldb);
        }
    }
    for (int cc = 0; cc < UNROLL_N / 8; cc++)
        REPACK_CC2(cc);

#if COPY_SUM
    ADD_SUM(cols2);
#endif
#endif

    // Write out.
    for (int cc = 0; cc < UNROLL_N / 8; cc++)
        BLOCK_WRITE_ELEMENT_INT4(b_packed + cc * 8 * UNROLL_K, colgroups[cc]);

#if COPY_SUM
    for (int c0 = 0; c0 < UNROLL_N; c0 += get_sub_group_size())
        atomic_add(b_sum + c0 + lid, sums[c0 + lid]);
#endif

    DUMMY_DPAS;
}

#endif /* !COPY_TRANS */
#endif /* COPY_B */
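
/* Build-time configuration consumed by this file (the exact option strings
 * the host passes are an assumption; the macros themselves are what the code
 * above references):
 *   ELEMENT_SIZE = 1 | 2    u8/s8 vs. f16/bf16 copies
 *   COPY_A / COPY_B         which matrix the kernel packs (at least one set)
 *   COPY_TRANS              nonzero if the source matrix is transposed
 *   COPY_SUM                also accumulate row/column sums (int8 GEMM)
 *   COPY_SIGNED             treat int8 source elements as signed
 *   COPY_CLEAR_SUM          build the sum-clearing kernel instead of a copy
 *   UNROLL_N = 32 | 48      B panel width; defined by the host, not here
 */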