/*******************************************************************************
* Copyright 2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#if ELEMENT_SIZE == 2
#pragma OPENCL EXTENSION cl_intel_subgroups_short : enable
#define ELEMENT ushort
#define ELEMENT2 ushort2
#define ELEMENT4 ushort4
#define ELEMENT_WORD ushort
#define ELEMENT_WORD4 ushort4
#define ELEMENT_INT ushort2
#define ELEMENT_INT4 ushort8
#define VLOAD_ELEMENT_INT vload2
#define ELEMENTS_PER_INT 2
#define BLOCK_READ_ELEMENT2 intel_sub_group_block_read_us2
#define BLOCK_READ_ELEMENT4 intel_sub_group_block_read_us4
#define BLOCK_READ_ELEMENT_WORD intel_sub_group_block_read_us
#define MASKED_BLOCK_READ_ELEMENT_WORD masked_block_read_element
#define BLOCK_WRITE_ELEMENT_WORD4 intel_sub_group_block_write_us4
#define BLOCK_WRITE_ELEMENT_INT4 intel_sub_group_block_write_us8
#elif ELEMENT_SIZE == 1
#define ELEMENT uchar
#define ELEMENT2 uchar2
#define ELEMENT4 uchar4
#define ELEMENT_WORD uchar2
#define ELEMENT_WORD4 uchar8
#define ELEMENT_INT uchar4
#define ELEMENT_INT4 uchar16
#define VLOAD_ELEMENT_INT vload4
#define BLOCK_READ_ELEMENT2 intel_sub_group_block_read_uc2
#define BLOCK_READ_ELEMENT4 intel_sub_group_block_read_uc4
#define BLOCK_READ_ELEMENT_WORD intel_sub_group_block_read_uc2
#define MASKED_BLOCK_READ_ELEMENT_WORD masked_block_read_element2
#define BLOCK_WRITE_ELEMENT_WORD4 intel_sub_group_block_write_uc8
#define BLOCK_WRITE_ELEMENT_INT2 intel_sub_group_block_write_uc8
#define BLOCK_WRITE_ELEMENT_INT4 intel_sub_group_block_write_uc16
#define ELEMENTS_PER_INT 4
#define SUM_T int
#define SUM_T2 int2
#define SUM_T4 int4
#define CONVERT_SUM_T convert_int
#define CONVERT_SUM_T2 convert_int2
#define CONVERT_SUM_T4 convert_int4
#if COPY_SIGNED
#define AS_SIGNED_ELEMENT as_char
#define AS_SIGNED_ELEMENT4 as_char4
#define AS_SIGNED_ELEMENT_WORD as_char2
#define AS_SIGNED_ELEMENT_INT as_char4
#define SIGNED_ELEMENT_WORD char2
#define SIGNED_ELEMENT_INT char4
#else
#define AS_SIGNED_ELEMENT as_uchar
#define AS_SIGNED_ELEMENT4 as_uchar4
#define AS_SIGNED_ELEMENT_WORD as_uchar2
#define AS_SIGNED_ELEMENT_INT as_uchar4
#define SIGNED_ELEMENT_WORD uchar2
#define SIGNED_ELEMENT_INT uchar4
#endif
#else
#error Unsupported element size.
#endif
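
// Note on the macro families above: an ELEMENT_WORD spans 2 bytes and an
// ELEMENT_INT spans 4 bytes regardless of element size, so ELEMENT_INT is
// ushort2 for ELEMENT_SIZE == 2 but uchar4 for ELEMENT_SIZE == 1.
// ELEMENTS_PER_INT (= 4 / ELEMENT_SIZE) is the number of elements packed
// into one 32-bit dword.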

#if !COPY_A && !COPY_B
#error Source matrix not defined.
#endif

inline ELEMENT masked_block_read_element(global ELEMENT *p, int rem) {
    ELEMENT v;
    int lid = get_sub_group_local_id();
    int sg = get_sub_group_size();

    v = (lid < rem) ? p[lid] : 0;

    return v;
}

inline ELEMENT2 masked_block_read_element2(global ELEMENT *p, int rem) {
    ELEMENT2 v;
    int lid = get_sub_group_local_id();
    int sg = get_sub_group_size();

    v.s0 = (lid < rem) ? p[lid] : 0;
    v.s1 = (lid + sg < rem) ? p[lid + sg] : 0;

    return v;
}

inline ELEMENT4 masked_block_read_element4(global ELEMENT *p, int rem) {
    ELEMENT4 v;
    int lid = get_sub_group_local_id();
    int sg = get_sub_group_size();

    v.s0 = (lid < rem) ? p[lid] : 0;
    v.s1 = (lid + sg < rem) ? p[lid + sg] : 0;
    v.s2 = (lid + 2 * sg < rem) ? p[lid + 2 * sg] : 0;
    v.s3 = (lid + 3 * sg < rem) ? p[lid + 3 * sg] : 0;

    return v;
}
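
// The masked variants above emulate the intel_sub_group_block_read* layout
// for partial tiles: work-item lid of the subgroup loads p[lid],
// p[lid + sg], p[lid + 2 * sg], ..., exactly as a true block read would,
// with out-of-range entries (index >= rem) zero-filled.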

__attribute__((overloadable)) inline int sum(int v) {
    return sub_group_reduce_add(v);
}

__attribute__((overloadable)) inline int sum(int2 v) {
    return sub_group_reduce_add(v.s0) + sub_group_reduce_add(v.s1);
}

__attribute__((overloadable)) inline int sum(int4 v) {
    return sub_group_reduce_add(v.s0) + sub_group_reduce_add(v.s1)
            + sub_group_reduce_add(v.s2) + sub_group_reduce_add(v.s3);
}
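
// dummy_dpas() is deliberately dead code: with the required subgroup size
// of 16, get_sub_group_local_id() is always < 16, so the body never runs.
// Presumably it only keeps an idpas builtin visible to the compiler, and
// the DUMMY_DPAS macro below currently compiles the call out entirely.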

void dummy_dpas() {
    if (get_sub_group_local_id() >= 16) {
        int __builtin_IB_sub_group_idpas_s8_s8_8_1(int, int, int8)
                __attribute__((const));
        global volatile int *_;

        int z = __builtin_IB_sub_group_idpas_s8_s8_8_1(0, _[0], 1);
        for (int i = 0; i < z; i++)
            (void)_[0];
    }
}

#define DUMMY_DPAS /*dummy_dpas()*/

#if ELEMENT_SIZE == 2
#define PARTIAL_LOAD(regs, rrem, crem, cc, p) \
    if ((2 * cc + 1) < crem) { \
        if (lid < rrem) regs[cc] = vload2(0, p); \
    } else if ((2 * cc) < crem) { \
        if (lid < rrem) regs[cc].s0 = *(p); \
    }
#elif ELEMENT_SIZE == 1
#define PARTIAL_LOAD(regs, rrem, crem, cc, p) \
    if ((4 * cc + 3) < crem) { \
        if (lid < rrem) regs[cc] = vload4(0, p); \
    } else if ((4 * cc + 2) < crem) { \
        if (lid < rrem) regs[cc].s012 = vload3(0, p); \
    } else if ((4 * cc + 1) < crem) { \
        if (lid < rrem) regs[cc].s01 = vload2(0, p); \
    } else if (4 * cc < crem) { \
        if (lid < rrem) regs[cc].s0 = *(p); \
    }
#endif
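
// PARTIAL_LOAD(regs, rrem, crem, cc, p) loads the cc-th ELEMENT_INT group
// of a row, clamped to the row/column remainders rrem/crem. E.g. with
// ELEMENT_SIZE == 1, cc == 1 and crem == 6, only columns 4 and 5 are in
// range, so the vload2 branch fires and regs[1].s23 keeps its
// caller-supplied (zeroed) value.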

#if COPY_A

#define UNROLL_M 64
#define UNROLL_K (32 / ELEMENT_SIZE)

#if COPY_SUM
#define GET_A_SUM_ADDRESS \
    int k_align = (k + UNROLL_K - 1) & ~(UNROLL_K - 1); \
    global int *a_sum = (global int *)(a_packed + offseta_packed \
            + m0 * lda_packed + k_align * UNROLL_M);
#else
#define GET_A_SUM_ADDRESS
#endif
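
// With COPY_SUM, the UNROLL_M per-row sums are stored as ints immediately
// after this thread's packed A panel, with k rounded up to a multiple of
// UNROLL_K; e.g. for ELEMENT_SIZE == 1 (UNROLL_K = 32) and k = 70,
// k_align = 96.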

#if COPY_CLEAR_SUM

// A sum clear kernel: initialize row sums to zero.
__attribute__((intel_reqd_sub_group_size(16))) kernel void
xe_hpc_systolic_gemm_copy(long m, long k, global ELEMENT *a_packed,
        int offseta_packed, int lda_packed) {

    uint m0 = (sub_group_broadcast(get_global_id(0), 0) / 16) * UNROLL_M;

    GET_A_SUM_ADDRESS;

    uint4 zero = 0;
    intel_sub_group_block_write4((global uint *)a_sum, zero);
}

#elif !COPY_TRANS

#if ELEMENT_SIZE == 2
#define REPACK_REG(rr, cc) \
    blk_r[rr].s##cc = (((uint)c[2 * cc + 1].s##rr) << 16) | c[2 * cc].s##rr
#elif ELEMENT_SIZE == 1
#define REPACK_REG(rr, cc) \
    blk_r[rr].s##cc = (((uint)c[4 * cc + 3].s##rr) << 24) \
            | (((uint)c[4 * cc + 2].s##rr) << 16) \
            | (((uint)c[4 * cc + 1].s##rr) << 8) | c[4 * cc].s##rr
#endif

#define REPACK_CC(cc) \
    REPACK_REG(0, cc); \
    REPACK_REG(1, cc); \
    REPACK_REG(2, cc); \
    REPACK_REG(3, cc)

#define REPACK \
    REPACK_CC(0); \
    REPACK_CC(1); \
    REPACK_CC(2); \
    REPACK_CC(3); \
    REPACK_CC(4); \
    REPACK_CC(5); \
    REPACK_CC(6); \
    REPACK_CC(7)
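
// REPACK converts the block-read registers to the packed layout: for each
// 16-row group rr, blk_r[rr].s##cc packs ELEMENTS_PER_INT consecutive
// k values of one row into a dword. E.g. for ELEMENT_SIZE == 1,
// blk_r[0].s0 holds bytes c[3].s0 : c[2].s0 : c[1].s0 : c[0].s0
// (k = 3..0, high to low) for row m0 + lid.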

// Nontranspose A copy.
// Each thread packs a 64x16 (f16/bf16) or 64x32 (u8/s8) block of A.
__attribute__((intel_reqd_sub_group_size(16))) kernel void
xe_hpc_systolic_gemm_copy(long m, long k, global ELEMENT *a, long offseta,
        long lda, global ELEMENT *a_packed, int offseta_packed,
        int lda_packed) {

    int lid = get_sub_group_local_id();
    uint m0 = (sub_group_broadcast(get_global_id(0), 0) / 16) * UNROLL_M;
    uint k0 = get_global_id(1) * UNROLL_K;
    int mrem = m - m0;
    int krem = k - k0;
    bool aligned = ((as_long(a) | lda | offseta) & (ELEMENTS_PER_INT - 1)) == 0;

    if (mrem <= 0 || krem <= 0) return;

    GET_A_SUM_ADDRESS;

    a += offseta + m0 + k0 * lda;
    a_packed += offseta_packed + m0 * lda_packed + k0 * UNROLL_M;

    // Read all columns.
    ELEMENT4 c[UNROLL_K];

    if (mrem >= UNROLL_M && krem >= UNROLL_K && aligned) {
        for (int h = 0; h < UNROLL_K; h++)
            c[h] = BLOCK_READ_ELEMENT4(a + h * lda);
    } else {
        for (int h = 0; h < UNROLL_K; h++)
            if (h < krem)
                c[h] = masked_block_read_element4(a + h * lda, mrem);
            else
                c[h] = 0;
    }

    // Rearrange.
    uint8 blk_r[UNROLL_M / 16];
    REPACK;

    // Write out.
    for (int rr = 0; rr < UNROLL_M / 16; rr++)
        intel_sub_group_block_write8(
                (global uint *)(a_packed + rr * UNROLL_K * 16), blk_r[rr]);

    // Sum if needed.
#if COPY_SUM
    SUM_T4 sum = 0;
    for (int h = 0; h < UNROLL_K; h++)
        sum += CONVERT_SUM_T4(AS_SIGNED_ELEMENT4(c[h]));
    atomic_add(a_sum + lid, sum.s0);
    atomic_add(a_sum + lid + 16, sum.s1);
    atomic_add(a_sum + lid + 32, sum.s2);
    atomic_add(a_sum + lid + 48, sum.s3);
#endif

    DUMMY_DPAS;
}

#else /* COPY_TRANS */

// Transpose A copy.
__attribute__((intel_reqd_workgroup_walk_order(1, 0)))
__attribute__((intel_reqd_sub_group_size(16))) kernel void
xe_hpc_systolic_gemm_copy(long m, long k, global ELEMENT *a, long offseta,
        long lda, global ELEMENT *a_packed, int offseta_packed,
        int lda_packed) {

    int lid = get_sub_group_local_id();
    uint m0 = (sub_group_broadcast(get_global_id(0), 0) / 16) * UNROLL_M;
    uint k0 = get_global_id(1) * UNROLL_K;
    int mrem = m - m0;
    int krem = k - k0;

    if (mrem <= 0 || krem <= 0) return;

    GET_A_SUM_ADDRESS;

    a += offseta + m0 * lda + k0;
    a_packed += offseta_packed + m0 * lda_packed + k0 * UNROLL_M;

#if COPY_SUM
    SUM_T sum[UNROLL_M / 16] = {0};
#endif

    for (int rr = 0; rr < UNROLL_M / 16; rr++, mrem -= 16) {
        ELEMENT_INT regs[8];

        // mrem is decremented per 16-row chunk, so the fast path only needs
        // the current chunk's 16 rows to be in range.
        if (mrem >= 16 && krem >= UNROLL_K) {
            for (int cc = 0; cc < UNROLL_K / ELEMENTS_PER_INT; cc++)
                regs[cc] = VLOAD_ELEMENT_INT(0,
                        a + ((rr * 16) + lid) * lda + (cc * ELEMENTS_PER_INT));
        } else {
            for (int cc = 0; cc < UNROLL_K / ELEMENTS_PER_INT; cc++) {
                regs[cc] = 0;
                PARTIAL_LOAD(regs, mrem, krem, cc,
                        a + ((rr * 16) + lid) * lda + (cc * ELEMENTS_PER_INT));
            }
        }

        uint8 blk_r;
        blk_r.s0 = as_uint(regs[0]);
        blk_r.s1 = as_uint(regs[1]);
        blk_r.s2 = as_uint(regs[2]);
        blk_r.s3 = as_uint(regs[3]);
        blk_r.s4 = as_uint(regs[4]);
        blk_r.s5 = as_uint(regs[5]);
        blk_r.s6 = as_uint(regs[6]);
        blk_r.s7 = as_uint(regs[7]);

#if COPY_SUM
        for (int cc = 0; cc < UNROLL_K / ELEMENTS_PER_INT; cc++) {
            sum[rr] += CONVERT_SUM_T(AS_SIGNED_ELEMENT(regs[cc].s0));
            sum[rr] += CONVERT_SUM_T(AS_SIGNED_ELEMENT(regs[cc].s1));
            sum[rr] += CONVERT_SUM_T(AS_SIGNED_ELEMENT(regs[cc].s2));
            sum[rr] += CONVERT_SUM_T(AS_SIGNED_ELEMENT(regs[cc].s3));
        }
#endif

        intel_sub_group_block_write8(
                (global uint *)(a_packed + rr * UNROLL_K * 16), blk_r);
    }

#if COPY_SUM
    atomic_add(a_sum + lid, sum[0]);
    atomic_add(a_sum + lid + 16, sum[1]);
    atomic_add(a_sum + lid + 32, sum[2]);
    atomic_add(a_sum + lid + 48, sum[3]);
#endif

    DUMMY_DPAS;
}

#endif /* !COPY_TRANS */
#endif /* COPY_A */

#if COPY_B

#define UNROLL_K (32 / ELEMENT_SIZE)
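
// UNROLL_N is not defined in this file; it is expected to come from the
// kernel build options (presumably 32 or 48, matching the 16x{32,48} and
// 32x{32,48} B blocks handled below).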

#if ELEMENT_SIZE == 2
#define REPACK_CC_WORD(cc) \
    do { \
        colgroups[cc].s0 = cols[cc * 4]; \
        colgroups[cc].s1 = cols[cc * 4 + 1]; \
        colgroups[cc].s2 = cols[cc * 4 + 2]; \
        colgroups[cc].s3 = cols[cc * 4 + 3]; \
    } while (false)
#define REPACK_CC(cc) \
    do { \
        colgroups[cc].s01 = cols[cc * 4]; \
        colgroups[cc].s23 = cols[cc * 4 + 1]; \
        colgroups[cc].s45 = cols[cc * 4 + 2]; \
        colgroups[cc].s67 = cols[cc * 4 + 3]; \
    } while (false)
#elif ELEMENT_SIZE == 1
#define REPACK_CC_WORD(cc) \
    do { \
        colgroups[cc].s01 = cols[cc * 4]; \
        colgroups[cc].s23 = cols[cc * 4 + 1]; \
        colgroups[cc].s45 = cols[cc * 4 + 2]; \
        colgroups[cc].s67 = cols[cc * 4 + 3]; \
    } while (false)
#define REPACK_CC(cc) \
    do { \
        colgroups[cc].s0123 = cols[cc * 4]; \
        colgroups[cc].s4567 = cols[cc * 4 + 1]; \
        colgroups[cc].s89ab = cols[cc * 4 + 2]; \
        colgroups[cc].scdef = cols[cc * 4 + 3]; \
    } while (false)
#define REPACK_CC2(cc) \
    do { \
        colgroups[cc].s0246 = cols[cc * 2]; \
        colgroups[cc].s1357 = cols2[cc * 2]; \
        colgroups[cc].s8ace = cols[cc * 2 + 1]; \
        colgroups[cc].s9bdf = cols2[cc * 2 + 1]; \
    } while (false)
#endif
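
// For 8-bit elements, REPACK_CC_WORD gathers four 2-byte column reads,
// REPACK_CC gathers four 4-byte column loads, and REPACK_CC2 interleaves
// two row groups byte by byte (even bytes from cols, odd bytes from cols2);
// the transpose kernel uses the latter to merge the upper and lower 16-row
// submatrices.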

#if COPY_SUM
#define GET_B_SUM_ADDRESS \
    int k_align = (k + UNROLL_K - 1) & ~(UNROLL_K - 1); \
    global int *b_sum = (global int *)(b_packed + offsetb_packed \
            + n0 * ldb_packed + k_align * UNROLL_N);
#else
#define GET_B_SUM_ADDRESS
#endif

#if COPY_CLEAR_SUM

// B sum clear kernel: initialize column sums to zero.
__attribute__((intel_reqd_sub_group_size(16))) kernel void
xe_hpc_systolic_gemm_copy(long k, long n, global ELEMENT *b_packed,
        int offsetb_packed, int ldb_packed) {

    uint n0 = (sub_group_broadcast(get_global_id(0), 0) / 16) * UNROLL_N;

    GET_B_SUM_ADDRESS;

    uint2 zero = 0;
    intel_sub_group_block_write2((global uint *)b_sum, zero);
#if UNROLL_N > 32
    intel_sub_group_block_write((global uint *)(b_sum + 32), zero.s0);
#endif
}

#elif !COPY_TRANS

// Nontranspose B copy.
// Each thread packs a 16x{32,48} (f16/bf16) or 32x{32,48} (u8/s8) block of B.
__attribute__((intel_reqd_sub_group_size(16))) kernel void
xe_hpc_systolic_gemm_copy(long k, long n, global ELEMENT *b, long offsetb,
        long ldb, global ELEMENT *b_packed, int offsetb_packed,
        int ldb_packed) {

    int lid = get_sub_group_local_id();
    uint k0 = (sub_group_broadcast(get_global_id(0), 0) / 16) * UNROLL_K;
    uint n0 = get_global_id(1) * UNROLL_N;
    int krem = k - k0;
    int nrem = n - n0;
    bool aligned = ((as_long(b) | ldb | offsetb) & (ELEMENTS_PER_INT - 1)) == 0;

    if (nrem <= 0 || krem <= 0) return;

    GET_B_SUM_ADDRESS;
    b += offsetb + k0 + n0 * ldb;
    b_packed += offsetb_packed + n0 * ldb_packed + k0 * UNROLL_N;

    // Copy in two halves.

#define UNROLL_N_CHUNK (UNROLL_N / 2)
#if COPY_SUM
    SUM_T sums[UNROLL_N];
#endif
    ELEMENT_WORD cols[UNROLL_N / 2];

    for (int c0 = 0; c0 < UNROLL_N;
            c0 += UNROLL_N_CHUNK, nrem -= UNROLL_N_CHUNK) {
        // Read all columns.
        if (krem >= UNROLL_K && nrem >= UNROLL_N_CHUNK && aligned) {
            for (int c = 0; c < UNROLL_N_CHUNK; c++)
                cols[c] = BLOCK_READ_ELEMENT_WORD(b + (c + c0) * ldb);
        } else {
            for (int c = 0; c < UNROLL_N_CHUNK; c++)
                if (c < nrem)
                    cols[c] = MASKED_BLOCK_READ_ELEMENT_WORD(
                            b + (c + c0) * ldb, krem);
                else
                    cols[c] = 0;
        }

        // Repack.
        ELEMENT_WORD4 colgroups[UNROLL_N_CHUNK / 4];
        for (int cc = 0; cc < UNROLL_N_CHUNK / 4; cc++)
            REPACK_CC_WORD(cc);

        // Write out.
        for (int cc = 0; cc < UNROLL_N_CHUNK / 4; cc++)
            BLOCK_WRITE_ELEMENT_WORD4(
                    b_packed + (cc * 4 + c0) * UNROLL_K, colgroups[cc]);

        // Sum if needed.
#if COPY_SUM
        for (int c = 0; c < UNROLL_N_CHUNK; c++)
            sums[c + c0] = sum(CONVERT_SUM_T2(AS_SIGNED_ELEMENT_WORD(cols[c])));
#endif
    }

    // Accumulate sums.
#if COPY_SUM
    for (int c0 = 0; c0 < UNROLL_N; c0 += get_sub_group_size())
        atomic_add(b_sum + c0 + lid, sums[c0 + lid]);
#endif

    DUMMY_DPAS;
}

#else /* COPY_TRANS */

#define ADD_SUM(coln) \
    for (int cc = 0; cc < UNROLL_N / 4; cc++) { \
        sums[4 * cc + 0] \
                += sum(CONVERT_SUM_T(AS_SIGNED_ELEMENT(coln[cc].s0))); \
        sums[4 * cc + 1] \
                += sum(CONVERT_SUM_T(AS_SIGNED_ELEMENT(coln[cc].s1))); \
        sums[4 * cc + 2] \
                += sum(CONVERT_SUM_T(AS_SIGNED_ELEMENT(coln[cc].s2))); \
        sums[4 * cc + 3] \
                += sum(CONVERT_SUM_T(AS_SIGNED_ELEMENT(coln[cc].s3))); \
    }
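
// ADD_SUM (used only in the 8-bit path) folds one 16-row chunk into the
// per-column accumulators: each sum() call reduces a column's 16 values
// (one per work-item) across the subgroup, and components s0..s3 of
// coln[cc] cover columns 4 * cc .. 4 * cc + 3.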

// Transpose B copy.
__attribute__((intel_reqd_workgroup_walk_order(1, 0)))
__attribute__((intel_reqd_sub_group_size(16))) kernel void
xe_hpc_systolic_gemm_copy(long k, long n, global ELEMENT *b, long offsetb,
        long ldb, global ELEMENT *b_packed, int offsetb_packed,
        int ldb_packed) {

    int lid = get_sub_group_local_id();
    uint k0 = (sub_group_broadcast(get_global_id(0), 0) / 16) * UNROLL_K;
    uint n0 = get_global_id(1) * UNROLL_N;
    int krem = k - k0;
    int nrem = n - n0;
    int sg = get_sub_group_size();

    if (nrem <= 0 || krem <= 0) return;

    GET_B_SUM_ADDRESS;
    b += offsetb + n0 + k0 * ldb;
    b_packed += offsetb_packed + n0 * ldb_packed + k0 * UNROLL_N;

    // Read upper 16x48 submatrix.
    ELEMENT_INT cols[UNROLL_N / ELEMENTS_PER_INT];
    ELEMENT_INT4 colgroups[UNROLL_N / 8];
    if (krem >= sg && nrem >= UNROLL_N) {
        for (int cc = 0; cc < UNROLL_N / ELEMENTS_PER_INT; cc++)
            cols[cc] = VLOAD_ELEMENT_INT(
                    0, b + cc * ELEMENTS_PER_INT + lid * ldb);
    } else {
        for (int cc = 0; cc < UNROLL_N / ELEMENTS_PER_INT; cc++) {
            cols[cc] = 0;
            PARTIAL_LOAD(cols, krem, nrem, cc,
                    b + cc * ELEMENTS_PER_INT + lid * ldb);
        }
    }
#if ELEMENT_SIZE == 2
    // Repack.
    for (int cc = 0; cc < UNROLL_N / 8; cc++)
        REPACK_CC(cc);
#else
#if COPY_SUM
    SUM_T sums[UNROLL_N] = {0};
    ADD_SUM(cols);
#endif

    // Read lower 16x48 submatrix.
    ELEMENT_INT cols2[UNROLL_N / ELEMENTS_PER_INT];
    krem -= sg;
    if (krem >= sg && nrem >= UNROLL_N) {
        for (int cc = 0; cc < UNROLL_N / ELEMENTS_PER_INT; cc++)
            cols2[cc] = VLOAD_ELEMENT_INT(
                    0, b + cc * ELEMENTS_PER_INT + (lid + sg) * ldb);
    } else {
        for (int cc = 0; cc < UNROLL_N / ELEMENTS_PER_INT; cc++) {
            cols2[cc] = 0;
            PARTIAL_LOAD(cols2, krem, nrem, cc,
                    b + cc * ELEMENTS_PER_INT + (lid + sg) * ldb);
        }
    }
    for (int cc = 0; cc < UNROLL_N / 8; cc++)
        REPACK_CC2(cc);

#if COPY_SUM
    ADD_SUM(cols2);
#endif
#endif

    // Write out.
    for (int cc = 0; cc < UNROLL_N / 8; cc++)
        BLOCK_WRITE_ELEMENT_INT4(b_packed + cc * 8 * UNROLL_K, colgroups[cc]);

#if COPY_SUM
    for (int c0 = 0; c0 < UNROLL_N; c0 += get_sub_group_size())
        atomic_add(b_sum + c0 + lid, sums[c0 + lid]);
#endif

    DUMMY_DPAS;
}

#endif /* !COPY_TRANS */
#endif /* COPY_B */