/*******************************************************************************
* Copyright 2019-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "gpu/ocl/ocl_math_utils.h"

#if ELEMENT_SIZE == 2
#pragma OPENCL EXTENSION cl_intel_subgroups_short : enable
#define ELEMENT ushort
#define ELEMENT2 ushort2
#define ELEMENT4 ushort4
#define ELEMENT8 ushort8
#define ELEMENT16 ushort16
#define ELEMENT_INT ushort2
#define ELEMENT_INT4 ushort8
#define VLOAD_ELEMENT_INT vload2
#define ELEMENTS_PER_INT 2
#define BLOCK_READ_ELEMENT2 intel_sub_group_block_read_us2
#define BLOCK_READ_ELEMENT4 intel_sub_group_block_read_us4
#define BLOCK_READ_ELEMENT_INT intel_sub_group_block_read_us2
#define MASKED_BLOCK_READ_ELEMENT_INT masked_block_read_element2
#define BLOCK_WRITE_ELEMENT_INT4 intel_sub_group_block_write_us8
#elif ELEMENT_SIZE == 1
#define ELEMENT uchar
#define ELEMENT2 uchar2
#define ELEMENT4 uchar4
#define ELEMENT8 uchar8
#define ELEMENT16 uchar16
#define ELEMENT_INT uchar4
#define ELEMENT_INT4 uchar16
#define VLOAD_ELEMENT_INT vload4
#define BLOCK_READ_ELEMENT2 intel_sub_group_block_read_uc2
#define BLOCK_READ_ELEMENT4 intel_sub_group_block_read_uc4
#define BLOCK_READ_ELEMENT_INT intel_sub_group_block_read_uc4
#define MASKED_BLOCK_READ_ELEMENT_INT masked_block_read_element4
#define BLOCK_WRITE_ELEMENT_INT4 intel_sub_group_block_write_uc16
#define ELEMENTS_PER_INT 4
#define SUM_T int
#define SUM_T4 int4
#define CONVERT_SUM_T convert_int
#define CONVERT_SUM_T4 convert_int4
#if COPY_SIGNED
#define AS_SIGNED_ELEMENT as_char
#define AS_SIGNED_ELEMENT4 as_char4
#define AS_SIGNED_ELEMENT_INT as_char4
#define SIGNED_ELEMENT_INT char4
#else
#define AS_SIGNED_ELEMENT as_uchar
#define AS_SIGNED_ELEMENT4 as_uchar4
#define AS_SIGNED_ELEMENT_INT as_uchar4
#define SIGNED_ELEMENT_INT uchar4
#endif
#else
#error Unsupported element size.
#endif

#if !COPY_A && !COPY_B
#error Source matrix not defined.
#endif

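// Masked equivalents of the sub-group block reads: each work-item gathers
// sub-group-strided elements, substituting zero for anything at or past the
// `rem` limit.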
inline ELEMENT2 masked_block_read_element2(global ELEMENT *p, int rem) {
    ELEMENT2 v;
    int lid = get_sub_group_local_id();
    int sg = get_sub_group_size();

    v.s0 = (lid < rem) ? p[lid] : 0;
    v.s1 = (lid + sg < rem) ? p[lid + sg] : 0;

    return v;
}

inline ELEMENT4 masked_block_read_element4(global ELEMENT *p, int rem) {
    ELEMENT4 v;
    int lid = get_sub_group_local_id();
    int sg = get_sub_group_size();

    v.s0 = (lid < rem) ? p[lid] : 0;
    v.s1 = (lid + sg < rem) ? p[lid + sg] : 0;
    v.s2 = (lid + 2 * sg < rem) ? p[lid + 2 * sg] : 0;
    v.s3 = (lid + 3 * sg < rem) ? p[lid + 3 * sg] : 0;

    return v;
}

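// Sub-group-wide sums of per-work-item accumulators.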
__attribute__((overloadable)) inline int sum(int v) {
    return sub_group_reduce_add(v);
}

__attribute__((overloadable)) inline int sum(int4 v) {
    return sub_group_reduce_add(v.s0) + sub_group_reduce_add(v.s1)
            + sub_group_reduce_add(v.s2) + sub_group_reduce_add(v.s3);
}

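// Dead code referencing the dpas builtin: with the required sub-group size of
// 8, the local id is always < 16, so this branch never executes.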
void dummy_dpas() {
    if (get_sub_group_local_id() >= 16) {
        int __builtin_IB_sub_group_idpas_s8_s8_8_1(int, int, int8)
                __attribute__((const));
        global volatile int *_;

        int z = __builtin_IB_sub_group_idpas_s8_s8_8_1(0, _[0], 1);
        for (int i = 0; i < z; i++)
            (void)_[0];
    }
}

#define DUMMY_DPAS dummy_dpas()

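// PARTIAL_LOAD reads one register's worth of elements (up to
// ELEMENTS_PER_INT) starting at p, loading only the columns that remain
// (crem) and only if this work-item's row is in range (rrem).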
#if ELEMENT_SIZE == 2
#define PARTIAL_LOAD(regs, rrem, crem, cc, p) \
    if ((2 * cc + 1) < crem) { \
        if (lid < rrem) regs[cc] = vload2(0, p); \
    } else if ((2 * cc) < crem) { \
        if (lid < rrem) regs[cc].s0 = *(p); \
    }
#elif ELEMENT_SIZE == 1
#define PARTIAL_LOAD(regs, rrem, crem, cc, p) \
    if ((4 * cc + 3) < crem) { \
        if (lid < rrem) regs[cc] = vload4(0, p); \
    } else if ((4 * cc + 2) < crem) { \
        if (lid < rrem) regs[cc].s012 = vload3(0, p); \
    } else if ((4 * cc + 1) < crem) { \
        if (lid < rrem) regs[cc].s01 = vload2(0, p); \
    } else if (4 * cc < crem) { \
        if (lid < rrem) regs[cc].s0 = *(p); \
    }
#endif

#if COPY_A

#define UNROLL_M 32
#define UNROLL_K (32 / ELEMENT_SIZE)

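// When COPY_SUM is enabled, per-row sums are stored as int32 directly after
// the packed A panel, with k rounded up to a multiple of UNROLL_K.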
#if COPY_SUM
#define GET_A_SUM_ADDRESS \
    int k_align = (k + UNROLL_K - 1) & ~(UNROLL_K - 1); \
    global int *a_sum = (global int *)(a_packed + offseta_packed \
            + m0 * lda_packed + k_align * UNROLL_M);
#else
#define GET_A_SUM_ADDRESS
#endif

#if COPY_CLEAR_SUM

// A sum clear kernel: initialize row sums to zero.
__attribute__((intel_reqd_sub_group_size(8))) kernel void
xe_hp_systolic_gemm_copy(long m, long k, global ELEMENT *a_packed,
        int offseta_packed, int lda_packed) {

    uint m0 = (sub_group_broadcast(get_global_id(0), 0) / 8) * UNROLL_M;

    GET_A_SUM_ADDRESS;

    uint4 zero = 0;
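    // One block write across the 8-wide sub-group zeros all UNROLL_M (32)
    // row sums.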
    intel_sub_group_block_write4(a_sum, zero);
}

#elif !COPY_TRANS

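// REPACK rearranges the per-column registers c[] (one register per k-value)
// into row blocks blk_r: each dword of blk_r[rr] packs ELEMENTS_PER_INT
// consecutive k-values of row rr * 8 + lid.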
#if ELEMENT_SIZE == 2
#define REPACK_REG(rr, cc) \
    blk_r[rr].s##cc = (((uint)c[2 * cc + 1].s##rr) << 16) | c[2 * cc].s##rr
#elif ELEMENT_SIZE == 1
#define REPACK_REG(rr, cc) \
    blk_r[rr].s##cc = (((uint)c[4 * cc + 3].s##rr) << 24) \
            | (((uint)c[4 * cc + 2].s##rr) << 16) \
            | (((uint)c[4 * cc + 1].s##rr) << 8) | c[4 * cc].s##rr
#endif

#define REPACK_CC(cc) \
    REPACK_REG(0, cc); \
    REPACK_REG(1, cc); \
    REPACK_REG(2, cc); \
    REPACK_REG(3, cc)

#define REPACK \
    REPACK_CC(0); \
    REPACK_CC(1); \
    REPACK_CC(2); \
    REPACK_CC(3); \
    REPACK_CC(4); \
    REPACK_CC(5); \
    REPACK_CC(6); \
    REPACK_CC(7)

// Nontranspose A copy.
// Each thread packs a 32x16 (f16/bf16) or 32x32 (u8/s8) block of A.
__attribute__((intel_reqd_sub_group_size(8))) kernel void
xe_hp_systolic_gemm_copy(long m, long k, global ELEMENT *a, long offseta,
        long lda, global ELEMENT *a_packed, int offseta_packed,
        int lda_packed) {

    int lid = get_sub_group_local_id();
    uint m0 = (sub_group_broadcast(get_global_id(0), 0) / 8) * UNROLL_M;
    uint k0 = get_global_id(1) * UNROLL_K;
    int mrem = m - m0;
    int krem = k - k0;
    bool aligned = ((as_long(a) | lda | offseta) & (ELEMENTS_PER_INT - 1)) == 0;

    if (mrem <= 0 || krem <= 0) return;

    GET_A_SUM_ADDRESS;

    a += offseta + m0 + k0 * lda;
    a_packed += offseta_packed + m0 * lda_packed + k0 * UNROLL_M;

    // Read all columns.
    ELEMENT4 c[UNROLL_K];

    if (mrem >= UNROLL_M && krem >= UNROLL_K && aligned) {
        for (int h = 0; h < UNROLL_K; h++)
            c[h] = BLOCK_READ_ELEMENT4(a + h * lda);
    } else {
        for (int h = 0; h < UNROLL_K; h++)
            if (h < krem)
                c[h] = masked_block_read_element4(a + h * lda, mrem);
            else
                c[h] = 0;
    }

    // Rearrange.
    uint8 blk_r[UNROLL_M / 8];
    REPACK;

    // Write out.
    for (int rr = 0; rr < UNROLL_M / 8; rr++)
        intel_sub_group_block_write8(
                (global uint *)(a_packed + rr * UNROLL_K * 8), blk_r[rr]);

        // Sum if needed.
#if COPY_SUM
    SUM_T4 sum = 0;
    for (int h = 0; h < UNROLL_K; h++)
        sum += CONVERT_SUM_T4(AS_SIGNED_ELEMENT4(c[h]));
    atomic_add(a_sum + lid, sum.s0);
    atomic_add(a_sum + lid + 8, sum.s1);
    atomic_add(a_sum + lid + 16, sum.s2);
    atomic_add(a_sum + lid + 24, sum.s3);
#endif

    DUMMY_DPAS;
}

#else /* COPY_TRANS */

// Transpose A copy.
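// Each thread packs a 32x16 (f16/bf16) or 32x32 (u8/s8) block of A, reading
// it 8 rows at a time (one row of UNROLL_K contiguous elements per work-item).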
__attribute__((intel_reqd_workgroup_walk_order(1, 0)))
__attribute__((intel_reqd_sub_group_size(8))) kernel void
xe_hp_systolic_gemm_copy(long m, long k, global ELEMENT *a, long offseta,
        long lda, global ELEMENT *a_packed, int offseta_packed,
        int lda_packed) {

    int lid = get_sub_group_local_id();
    uint m0 = (sub_group_broadcast(get_global_id(0), 0) / 8) * UNROLL_M;
    uint k0 = get_global_id(1) * UNROLL_K;
    int mrem = m - m0;
    int krem = k - k0;

    if (mrem <= 0 || krem <= 0) return;

    GET_A_SUM_ADDRESS;

    a += offseta + m0 * lda + k0;
    a_packed += offseta_packed + m0 * lda_packed + k0 * UNROLL_M;

#if COPY_SUM
    SUM_T sum[UNROLL_M / 8] = {0};
#endif

    for (int rr = 0; rr < UNROLL_M / 8; rr++, mrem -= 8) {
        ELEMENT_INT regs[8];

        if (mrem >= UNROLL_M && krem >= UNROLL_K) {
            for (int cc = 0; cc < UNROLL_K / ELEMENTS_PER_INT; cc++)
                regs[cc] = VLOAD_ELEMENT_INT(0,
                        a + ((rr * 8) + lid) * lda + (cc * ELEMENTS_PER_INT));
        } else {
            for (int cc = 0; cc < UNROLL_K / ELEMENTS_PER_INT; cc++) {
                regs[cc] = 0;
                PARTIAL_LOAD(regs, mrem, krem, cc,
                        a + ((rr * 8) + lid) * lda + (cc * ELEMENTS_PER_INT));
            }
        }

        uint8 blk_r;
        blk_r.s0 = as_uint(regs[0]);
        blk_r.s1 = as_uint(regs[1]);
        blk_r.s2 = as_uint(regs[2]);
        blk_r.s3 = as_uint(regs[3]);
        blk_r.s4 = as_uint(regs[4]);
        blk_r.s5 = as_uint(regs[5]);
        blk_r.s6 = as_uint(regs[6]);
        blk_r.s7 = as_uint(regs[7]);

#if COPY_SUM
        for (int cc = 0; cc < UNROLL_K / ELEMENTS_PER_INT; cc++) {
            sum[rr] += CONVERT_SUM_T(AS_SIGNED_ELEMENT(regs[cc].s0));
            sum[rr] += CONVERT_SUM_T(AS_SIGNED_ELEMENT(regs[cc].s1));
            sum[rr] += CONVERT_SUM_T(AS_SIGNED_ELEMENT(regs[cc].s2));
            sum[rr] += CONVERT_SUM_T(AS_SIGNED_ELEMENT(regs[cc].s3));
        }
#endif

        intel_sub_group_block_write8(
                (global uint *)(a_packed + rr * UNROLL_K * 8), blk_r);
    }

#if COPY_SUM
    atomic_add(a_sum + lid, sum[0]);
    atomic_add(a_sum + lid + 8, sum[1]);
    atomic_add(a_sum + lid + 16, sum[2]);
    atomic_add(a_sum + lid + 24, sum[3]);
#endif

    DUMMY_DPAS;
}

#endif /* !COPY_TRANS */
#endif /* COPY_A */

#if COPY_B

#define UNROLL_K (32 / ELEMENT_SIZE)

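// UNROLL_N (32 or 48, per the block comments below) is expected to be defined
// externally at build time. In the REPACK_CC* variants that follow, each
// dword of a column group packs ELEMENTS_PER_INT consecutive k-values of a
// single column, four columns per group.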
#if ELEMENT_SIZE == 2
#define REPACK_CC(cc) \
    do { \
        colgroups[cc].s01 = cols[cc * 4]; \
        colgroups[cc].s23 = cols[cc * 4 + 1]; \
        colgroups[cc].s45 = cols[cc * 4 + 2]; \
        colgroups[cc].s67 = cols[cc * 4 + 3]; \
    } while (false)
#define REPACK_CC2(cc) \
    do { \
        colgroups[cc].s02 = cols[cc * 2]; \
        colgroups[cc].s13 = cols2[cc * 2]; \
        colgroups[cc].s46 = cols[cc * 2 + 1]; \
        colgroups[cc].s57 = cols2[cc * 2 + 1]; \
    } while (false)
#elif ELEMENT_SIZE == 1
#define REPACK_CC(cc) \
    do { \
        colgroups[cc].s0123 = cols[cc * 4]; \
        colgroups[cc].s4567 = cols[cc * 4 + 1]; \
        colgroups[cc].s89ab = cols[cc * 4 + 2]; \
        colgroups[cc].scdef = cols[cc * 4 + 3]; \
    } while (false)
#define REPACK_CC4(cc) \
    do { \
        colgroups[cc].s048c = cols[cc]; \
        colgroups[cc].s159d = cols2[cc]; \
        colgroups[cc].s26ae = cols3[cc]; \
        colgroups[cc].s37bf = cols4[cc]; \
    } while (false)
#endif

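// When COPY_SUM is enabled, per-column sums are stored as int32 directly
// after the packed B panel, with k rounded up to a multiple of UNROLL_K.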
#if COPY_SUM
#define GET_B_SUM_ADDRESS \
    int k_align = (k + UNROLL_K - 1) & ~(UNROLL_K - 1); \
    global int *b_sum = (global int *)(b_packed + offsetb_packed \
            + n0 * ldb_packed + k_align * UNROLL_N);
#else
#define GET_B_SUM_ADDRESS
#endif

#if COPY_CLEAR_SUM

// B sum clear kernel: initialize column sums to zero.
__attribute__((intel_reqd_sub_group_size(8))) kernel void
xe_hp_systolic_gemm_copy(long k, long n, global ELEMENT *b_packed,
        int offsetb_packed, int ldb_packed) {

    uint n0 = (sub_group_broadcast(get_global_id(0), 0) / 8) * UNROLL_N;

    GET_B_SUM_ADDRESS;

    uint4 zero = 0;
    intel_sub_group_block_write4(b_sum, zero);
#if UNROLL_N > 32
    intel_sub_group_block_write2(b_sum + 32, zero.s01);
#endif
}

#elif !COPY_TRANS

// Nontranspose B copy.
// Each thread packs a 16x{32,48} (f16/bf16) or 32x{32,48} (u8/s8) block of B.
__attribute__((intel_reqd_sub_group_size(8))) kernel void
xe_hp_systolic_gemm_copy(long k, long n, global ELEMENT *b, long offsetb,
        long ldb, global ELEMENT *b_packed, int offsetb_packed,
        int ldb_packed) {

    int lid = get_sub_group_local_id();
    uint k0 = (sub_group_broadcast(get_global_id(0), 0) / 8) * UNROLL_K;
    uint n0 = get_global_id(1) * UNROLL_N;
    int krem = k - k0;
    int nrem = n - n0;
    bool aligned = ((as_long(b) | ldb | offsetb) & (ELEMENTS_PER_INT - 1)) == 0;

    if (nrem <= 0 || krem <= 0) return;

    GET_B_SUM_ADDRESS;
    b += offsetb + k0 + n0 * ldb;
    b_packed += offsetb_packed + n0 * ldb_packed + k0 * UNROLL_N;

    // Copy in two halves.

#define UNROLL_N_CHUNK (UNROLL_N / 2)
#if COPY_SUM
    SUM_T sums[UNROLL_N];
#endif
    ELEMENT_INT cols[UNROLL_N / 2];

    for (int c0 = 0; c0 < UNROLL_N;
            c0 += UNROLL_N_CHUNK, nrem -= UNROLL_N_CHUNK) {
        // Read all columns.
        if (krem >= UNROLL_K && nrem >= UNROLL_N_CHUNK && aligned) {
            for (int c = 0; c < UNROLL_N_CHUNK; c++)
                cols[c] = BLOCK_READ_ELEMENT_INT(b + (c + c0) * ldb);
        } else {
            for (int c = 0; c < UNROLL_N_CHUNK; c++)
                if (c < nrem)
                    cols[c] = MASKED_BLOCK_READ_ELEMENT_INT(
                            b + (c + c0) * ldb, krem);
                else
                    cols[c] = 0;
        }

        // Repack.
        ELEMENT_INT4 colgroups[UNROLL_N_CHUNK / 4];
        for (int cc = 0; cc < UNROLL_N_CHUNK / 4; cc++)
            REPACK_CC(cc);

        // Write out.
        for (int cc = 0; cc < UNROLL_N_CHUNK / 4; cc++)
            BLOCK_WRITE_ELEMENT_INT4(
                    b_packed + (cc * 4 + c0) * UNROLL_K, colgroups[cc]);

            // Sum if needed.
#if COPY_SUM
        for (int c = 0; c < UNROLL_N_CHUNK; c++)
            sums[c + c0] = sum(CONVERT_SUM_T4(AS_SIGNED_ELEMENT_INT(cols[c])));
#endif
    }

    // Accumulate sums.
#if COPY_SUM
    for (int c0 = 0; c0 < UNROLL_N; c0 += get_sub_group_size())
        atomic_add(b_sum + c0 + lid, sums[c0 + lid]);
#endif

    DUMMY_DPAS;
}

#else /* COPY_TRANS */

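// ADD_SUM accumulates each column's contribution from one group of k-rows
// (coln), reducing across the sub-group into the per-column sums[].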
#define ADD_SUM(coln) \
    for (int cc = 0; cc < UNROLL_N / 4; cc++) { \
        sums[4 * cc + 0] \
                += sum(CONVERT_SUM_T(AS_SIGNED_ELEMENT(coln[cc].s0))); \
        sums[4 * cc + 1] \
                += sum(CONVERT_SUM_T(AS_SIGNED_ELEMENT(coln[cc].s1))); \
        sums[4 * cc + 2] \
                += sum(CONVERT_SUM_T(AS_SIGNED_ELEMENT(coln[cc].s2))); \
        sums[4 * cc + 3] \
                += sum(CONVERT_SUM_T(AS_SIGNED_ELEMENT(coln[cc].s3))); \
    }

// Transpose B copy.
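// Each thread packs a 16x{32,48} (f16/bf16) or 32x{32,48} (u8/s8) block of B;
// each work-item loads whole k-rows of UNROLL_N contiguous elements (two row
// groups for f16/bf16, four for u8/s8).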
__attribute__((intel_reqd_workgroup_walk_order(1, 0)))
__attribute__((intel_reqd_sub_group_size(8))) kernel void
xe_hp_systolic_gemm_copy(long k, long n, global ELEMENT *b, long offsetb,
        long ldb, global ELEMENT *b_packed, int offsetb_packed,
        int ldb_packed) {

    int lid = get_sub_group_local_id();
    uint k0 = (sub_group_broadcast(get_global_id(0), 0) / 8) * UNROLL_K;
    uint n0 = get_global_id(1) * UNROLL_N;
    int krem = k - k0;
    int nrem = n - n0;
    int sg = get_sub_group_size();

    if (nrem <= 0 || krem <= 0) return;

    GET_B_SUM_ADDRESS;
    b += offsetb + n0 + k0 * ldb;
    b_packed += offsetb_packed + n0 * ldb_packed + k0 * UNROLL_N;

    // Read upper 16x{32,48} submatrix.
    ELEMENT_INT cols[UNROLL_N / ELEMENTS_PER_INT];
    ELEMENT_INT cols2[UNROLL_N / ELEMENTS_PER_INT];
    ELEMENT_INT4 colgroups[UNROLL_N / 4];
    if (krem >= 2 * sg && nrem >= UNROLL_N) {
        for (int cc = 0; cc < UNROLL_N / ELEMENTS_PER_INT; cc++) {
            cols[cc] = VLOAD_ELEMENT_INT(
                    0, b + cc * ELEMENTS_PER_INT + lid * ldb);
            cols2[cc] = VLOAD_ELEMENT_INT(
                    0, b + cc * ELEMENTS_PER_INT + (lid + sg) * ldb);
        }
    } else {
        for (int cc = 0; cc < UNROLL_N / ELEMENTS_PER_INT; cc++) {
            cols[cc] = 0;
            cols2[cc] = 0;
            PARTIAL_LOAD(cols, krem, nrem, cc,
                    b + cc * ELEMENTS_PER_INT + lid * ldb);
            PARTIAL_LOAD(cols2, krem - sg, nrem, cc,
                    b + cc * ELEMENTS_PER_INT + (lid + sg) * ldb);
        }
    }
#if ELEMENT_SIZE == 2
    // Repack.
    for (int cc = 0; cc < UNROLL_N / 4; cc++)
        REPACK_CC2(cc);
#else
    // Read lower 16x{32,48} submatrix.
    ELEMENT_INT cols3[UNROLL_N / ELEMENTS_PER_INT];
    ELEMENT_INT cols4[UNROLL_N / ELEMENTS_PER_INT];
    krem -= 2 * sg;
    if (krem >= 2 * sg && nrem >= UNROLL_N) {
        for (int cc = 0; cc < UNROLL_N / ELEMENTS_PER_INT; cc++) {
            cols3[cc] = VLOAD_ELEMENT_INT(
                    0, b + cc * ELEMENTS_PER_INT + (lid + 2 * sg) * ldb);
            cols4[cc] = VLOAD_ELEMENT_INT(
                    0, b + cc * ELEMENTS_PER_INT + (lid + 3 * sg) * ldb);
        }
    } else {
        for (int cc = 0; cc < UNROLL_N / ELEMENTS_PER_INT; cc++) {
            cols3[cc] = 0;
            cols4[cc] = 0;
            PARTIAL_LOAD(cols3, krem, nrem, cc,
                    b + cc * ELEMENTS_PER_INT + (lid + 2 * sg) * ldb);
            PARTIAL_LOAD(cols4, krem - sg, nrem, cc,
                    b + cc * ELEMENTS_PER_INT + (lid + 3 * sg) * ldb);
        }
    }
    for (int cc = 0; cc < UNROLL_N / 4; cc++)
        REPACK_CC4(cc);
#endif

    // Write out.
    for (int cc = 0; cc < UNROLL_N / 4; cc++)
        BLOCK_WRITE_ELEMENT_INT4(b_packed + cc * 4 * UNROLL_K, colgroups[cc]);

#if COPY_SUM
    SUM_T sums[UNROLL_N] = {0};
    ADD_SUM(cols);
    ADD_SUM(cols2);
    ADD_SUM(cols3);
    ADD_SUM(cols4);

    for (int c0 = 0; c0 < UNROLL_N; c0 += get_sub_group_size())
        atomic_add(b_sum + c0 + lid, sums[c0 + lid]);
#endif

    DUMMY_DPAS;
}

#endif /* !COPY_TRANS */
#endif /* COPY_B */