1/*******************************************************************************
2* Copyright 2020-2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8*     http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "gpu/ocl/ocl_post_ops.h"
18#include "gpu/ocl/ocl_types.h"
19
// Read functions.
// Forward declarations for the vectorized channel-block accessors defined at
// the bottom of this file. `idx` selects which NVECT-sized chunk of channels
// (0 or 1) the call touches; `blocks_stride`/`chunks_per_block` describe the
// blocked memory layout.
inline VECT_DATA_T read_vect_c_block(int idx, const __global DATA_T *ptr, int c,
        int blocks_stride, int chunks_per_block);
inline VECT_INT_T read_vect_c_block_int(int idx, const __global int *ptr, int c,
        int blocks_stride, int chunks_per_block);

// Write functions.
inline void write_vect_c_block(int idx, __global DATA_T *ptr, int c,
        int blocks_stride, int chunks_per_block, VECT_DATA_T block);
inline void write_vect_c_block_int(int idx, __global int *ptr, int c,
        int blocks_stride, int chunks_per_block, VECT_INT_T block);
31
// Select the accumulator type: bf16 always accumulates in f32, f16 keeps the
// native type, and for other data types f32 is used only by the averaging
// algorithms (max pooling can compare in the native type).
#if DT_BF16
#define USE_FLOATS true
#elif DT_F16
#define USE_FLOATS false
#else
#define USE_FLOATS (ALG_AVG_NP || ALG_AVG_P)
#endif
39
40#if IS_FWD
// Forward pooling kernel (max / average, chosen via the ALG_* build options).
// Each work-item produces up to two NVECT-wide chunks of channels (chunk
// indices 0 and 1) for one (mb, c, od, oh, ow) destination point. For max
// pooling in training mode the linearized kernel-tap index of the maximum is
// additionally stored into the workspace `ws` for the backward pass.
KERNEL_ATTR
__kernel void gen9_pooling_fwd(__global DATA_T *src, __global int *ws,
        __global DATA_T *dst POST_OP_ARGS) {
    const int mb = GWS_GET_MB();
    const int c = GWS_GET_C();
    const int od = GWS_GET_OD();
    const int oh = GWS_GET_OH();
    const int ow = GWS_GET_OW();

    // Calculate number of subgroup chunks inside C block
    // and stride between consecutive MB/C blocks
#if USE_MB_C_BLOCK
    const int src_stride = (SRC_SB0 > 1) ? SRC_SB0 : SRC_S0;
    const int dst_stride = (DST_SB0 > 1) ? DST_SB0 : DST_S0;
    const int src_chunks_per_c_block = CHUNKS_PER_C_BLOCK;
    const int dst_chunks_per_c_block = CHUNKS_PER_C_BLOCK;
#elif USE_ONLY_C_BLOCK
    const int src_stride = (SRC_B1 > 1) ? SRC_S1 : SUB_GROUP_SIZE;
    const int dst_stride = (DST_B1 > 1) ? DST_S1 : SUB_GROUP_SIZE;
    const int src_chunks_per_c_block
            = (SRC_B1 > 1) ? (SRC_B1 / SUB_GROUP_SIZE) : 1;
    const int dst_chunks_per_c_block
            = (DST_B1 > 1) ? (DST_B1 / SUB_GROUP_SIZE) : 1;
#endif

    // The workspace shares the destination's blocked layout.
    const int ws_stride = dst_stride;
    const int ws_chunks_per_c_block = dst_chunks_per_c_block;

    // Padded minibatch positions: zero-fill dst (and ws when present) and
    // skip the reduction entirely.
    if (mb >= SRC_D0) {
        VECT_DATA_T dst_zero = DATA_ZERO;
        VECT_INT_T ws_zero = 0;
        int off = DST_OFF(mb, c, od, oh, ow);
        write_vect_c_block(
                0, &dst[off], c, dst_stride, dst_chunks_per_c_block, dst_zero);
        write_vect_c_block(
                1, &dst[off], c, dst_stride, dst_chunks_per_c_block, dst_zero);
#if ALG_MAX && IS_TRAINING
        write_vect_c_block_int(
                0, &ws[off], c, ws_stride, ws_chunks_per_c_block, ws_zero);
        write_vect_c_block_int(
                1, &ws[off], c, ws_stride, ws_chunks_per_c_block, ws_zero);
#endif // ALG_MAX && IS_TRAINING

        return;
    }

    // Top-left-front corner of the pooling window in source coordinates
    // (may be negative inside the padding area).
    const int id = od * SD - PD;
    const int ih = oh * SH - PH;
    const int iw = ow * SW - PW;
    // Accumulators for the two channel chunks: max starts from the lowest
    // representable value, averaging starts from zero.
#if USE_FLOATS
    VECT_FLOAT_T D0 = ALG_MAX ? DATA_MIN : DATA_ZERO;
    VECT_FLOAT_T D1 = ALG_MAX ? DATA_MIN : DATA_ZERO;
#else // USE_FLOATS
    VECT_DATA_T D0 = ALG_MAX ? DATA_MIN : DATA_ZERO;
    VECT_DATA_T D1 = ALG_MAX ? DATA_MIN : DATA_ZERO;
#endif // USE_FLOATS
    VECT_INT_T WS0 = 0, WS1 = 0;

    // Reduce over the pooling window; taps that fall outside the source
    // tensor (into padding) are skipped.
    for (int kd = 0; kd < KD; ++kd)
        for (int kh = 0; kh < KH; ++kh) {
            for (int kw = 0; kw < KW; ++kw) {
                if (id + kd < 0 || id + kd >= ID) continue;
                if (ih + kh < 0 || ih + kh >= IH) continue;
                if (iw + kw < 0 || iw + kw >= IW) continue;

                int src_off = SRC_OFF(mb, c, id + kd, ih + kh, iw + kw);
#if USE_FLOATS
                VECT_FLOAT_T S0 = CONVERT_VECT_FLOAT_T(read_vect_c_block(0,
                        &src[src_off], c, src_stride, src_chunks_per_c_block));
                VECT_FLOAT_T S1 = CONVERT_VECT_FLOAT_T(read_vect_c_block(1,
                        &src[src_off], c, src_stride, src_chunks_per_c_block));
#else // USE_FLOATS
                VECT_DATA_T S0 = read_vect_c_block(0, &src[src_off], c,
                        src_stride, src_chunks_per_c_block);
                VECT_DATA_T S1 = read_vect_c_block(1, &src[src_off], c,
                        src_stride, src_chunks_per_c_block);
#endif // USE_FLOATS

#if ALG_MAX
#if IS_TRAINING
                // Track both the running max and the linearized tap index of
                // the element that produced it (per vector lane).
                VECT_INT_T CMP0 = isless(D0, S0);
                WS0 = select(WS0, kd * KH * KW + kh * KW + kw, CMP0);
                D0 = select(D0, S0, CMP0);

                VECT_INT_T CMP1 = isless(D1, S1);
                WS1 = select(WS1, kd * KH * KW + kh * KW + kw, CMP1);
                D1 = select(D1, S1, CMP1);

#else // TRAINING
                D0 = max(D0, S0);
                D1 = max(D1, S1);
#endif // TRAINING
#else // ALG_MAX
                D0 += S0;
#endif // ALG_MAX
            }
        }

#if ALG_AVG_P
    // avg_pool including padding: divisor is the full window size.
    D0 = D0 / (KD * KH * KW);
    D1 = D1 / (KD * KH * KW);

#endif // ALG_AVG_P

#if ALG_AVG_NP
    // avg_pool excluding padding: divisor is the number of window taps that
    // actually overlap the source tensor.
    const int id_start = max(od * SD - PD, 0);
    const int ih_start = max(oh * SH - PH, 0);
    const int iw_start = max(ow * SW - PW, 0);
    const int id_end = min(od * SD - PD + KD, ID);
    const int ih_end = min(oh * SH - PH + KH, IH);
    const int iw_end = min(ow * SW - PW + KW, IW);
    const DATA_T num_summands
            = (ih_end - ih_start) * (iw_end - iw_start) * (id_end - id_start);
    D0 = D0 / num_summands;
    D1 = D1 / num_summands;
#endif // ALG_AVG_NP

    int dst_off = DST_OFF(mb, c, od, oh, ow);
    VECT_DATA_T sum0;
    VECT_DATA_T sum1;
#if WITH_SUM
    // Sum post-op needs the current destination values.
    sum0 = read_vect_c_block(
            0, &dst[dst_off], c, dst_stride, dst_chunks_per_c_block);
    sum1 = read_vect_c_block(
            1, &dst[dst_off], c, dst_stride, dst_chunks_per_c_block);
#endif

    const int local_id = get_sub_group_local_id();

#if VECT_DT_N == 1
    // Scalar path: apply post-ops once per chunk, skipping padded channels.
    const int po_mb = mb;
    const int po_oc = c + local_id;
    if (po_oc < C_WO_PADDING) {
        POST_OP_DATA_T po_sum0 = DATA_TO_REF(sum0);
        float po_D0 = USE_FLOATS ? D0 : CONVERT_FLOAT_T(D0);
        APPLY_POST_OPS_SERIAL_BINARY_2D(
                po_D0, float, po_sum0, POST_OP_DATA_T, po_mb, 1, po_oc, 1);
        D0 = USE_FLOATS ? po_D0 : CONVERT_DATA_T(po_D0);

        POST_OP_DATA_T po_sum1 = DATA_TO_REF(sum1);
        float po_D1 = USE_FLOATS ? D1 : CONVERT_FLOAT_T(D1);
        APPLY_POST_OPS_SERIAL_BINARY_2D(
                po_D1, float, po_sum1, POST_OP_DATA_T, po_mb, 1, po_oc, 1);
        D1 = USE_FLOATS ? po_D1 : CONVERT_DATA_T(po_D1);
    }

#else
    // Vector path: apply post-ops lane by lane, recovering the logical
    // (mb, oc) coordinate of each lane from the blocked layout.
    for (int idx = 0; idx < VECT_DT_N; ++idx) {
#if USE_MB_C_BLOCK
        int c_sub_block_id = idx % CHUNKS_PER_C_BLOCK;
        int mb_sub_block_id = idx / CHUNKS_PER_C_BLOCK;
        const int po_oc = c + c_sub_block_id * SUB_GROUP_SIZE + local_id;
        int po_mb = (mb + mb_sub_block_id) % MB;
#else // USE_MB_C_BLOCK
        const int po_oc = c + idx * SUB_GROUP_SIZE + local_id;
        int po_mb = mb;
#endif // USE_MB_C_BLOCK

        if (po_mb >= MB || po_oc >= C_WO_PADDING) continue;

        float d0_i = USE_FLOATS ? D0[idx] : CONVERT_FLOAT_T(D0[idx]);
        POST_OP_DATA_T sum0_i = DATA_TO_REF(sum0[idx]);
        APPLY_POST_OPS_SERIAL_BINARY_2D(
                d0_i, float, sum0_i, POST_OP_DATA_T, po_mb, 1, po_oc, 1);
        D0[idx] = USE_FLOATS ? d0_i : CONVERT_DATA_T(d0_i);

        float d1_i = USE_FLOATS ? D1[idx] : CONVERT_FLOAT_T(D1[idx]);
        POST_OP_DATA_T sum1_i = DATA_TO_REF(sum1[idx]);
        // NOTE(review): the second chunk advances po_mb by VECT_DT_N before
        // the post-op — presumably it maps to the next group of minibatches
        // in MB-blocked layouts; confirm this is intended for plain layouts.
        po_mb += VECT_DT_N;
        APPLY_POST_OPS_SERIAL_BINARY_2D(
                d1_i, float, sum1_i, POST_OP_DATA_T, po_mb, 1, po_oc, 1);
        D1[idx] = USE_FLOATS ? d1_i : CONVERT_DATA_T(d1_i);
    }
#endif // #if VECT_DT_N == 1
#if USE_FLOATS
    VECT_DATA_T res0 = CONVERT_VECTOR_DATA_T(D0);
    VECT_DATA_T res1 = CONVERT_VECTOR_DATA_T(D1);
#else
    VECT_DATA_T res0 = D0;
    VECT_DATA_T res1 = D1;
#endif
    write_vect_c_block(
            0, &dst[dst_off], c, dst_stride, dst_chunks_per_c_block, res0);
    write_vect_c_block(
            1, &dst[dst_off], c, dst_stride, dst_chunks_per_c_block, res1);

#if ALG_MAX && IS_TRAINING
    // Workspace uses the same offsets as dst (identical layout).
    int ws_off = dst_off;
    write_vect_c_block_int(
            0, &ws[ws_off], c, ws_stride, ws_chunks_per_c_block, WS0);
    write_vect_c_block_int(
            1, &ws[ws_off], c, ws_stride, ws_chunks_per_c_block, WS1);
#endif // ALG_MAX && IS_TRAINING
}
236#endif
237
238#if IS_BWD
// Backward pooling kernel: scatters diff_dst gradients back into diff_src.
// Each work-item covers up to two NVECT-wide channel chunks of one
// (mb, c, id, ih, iw) source point and iterates over all kernel taps whose
// window contains that point. For max pooling, the workspace index saved by
// the forward pass masks out taps that did not produce the maximum.
KERNEL_ATTR
__kernel void gen9_pooling_bwd(__global DATA_T *diff_src, __global int *ws,
        __global DATA_T *diff_dst) {

    const int mb = GWS_GET_MB();
    const int c = GWS_GET_C();
    const int id = GWS_GET_ID();
    const int ih = GWS_GET_IH();
    const int iw = GWS_GET_IW();

    // Calculate number of subgroup chunks inside C block
    // and stride between consecutive MB/C blocks
#if USE_MB_C_BLOCK
    const int src_stride = (SRC_SB0 > 1) ? SRC_SB0 : SRC_S0;
    const int dst_stride = (DST_SB0 > 1) ? DST_SB0 : DST_S0;
    const int src_chunks_per_c_block = CHUNKS_PER_C_BLOCK;
    const int dst_chunks_per_c_block = CHUNKS_PER_C_BLOCK;
#elif USE_ONLY_C_BLOCK
    const int src_stride = (SRC_B1 > 1) ? SRC_S1 : SUB_GROUP_SIZE;
    const int dst_stride = (DST_B1 > 1) ? DST_S1 : SUB_GROUP_SIZE;
    const int src_chunks_per_c_block
            = (SRC_B1 > 1) ? (SRC_B1 / SUB_GROUP_SIZE) : 1;
    const int dst_chunks_per_c_block
            = (DST_B1 > 1) ? (DST_B1 / SUB_GROUP_SIZE) : 1;
#endif

    // The workspace shares the destination's blocked layout.
    const int ws_stride = dst_stride;
    const int ws_chunks_per_c_block = dst_chunks_per_c_block;

    // Gradient accumulators for the two channel chunks.
    VECT_FLOAT_T S0 = 0, S1 = 0;
    for (int kd = 0; kd < KD; kd++) {
        for (int kh = 0; kh < KH; kh++) {
            for (int kw = 0; kw < KW; kw++) {
                // Invert the forward mapping: a dst point (od, oh, ow)
                // touches this src point through tap (kd, kh, kw) only if
                // the offset divides evenly by the stride and lands in the
                // dst bounds.
                int od = (id + PD - kd);
                int oh = (ih + PH - kh);
                int ow = (iw + PW - kw);
                if (od % SD != 0 || oh % SH != 0 || ow % SW != 0) continue;
                od /= SD;
                oh /= SH;
                ow /= SW;
                if (od < 0 || od >= OD) continue;
                if (oh < 0 || oh >= OH) continue;
                if (ow < 0 || ow >= OW) continue;

                const int dst_off = DST_OFF(mb, c, od, oh, ow);
                VECT_FLOAT_T D0 = CONVERT_VECT_FLOAT_T(
                        read_vect_c_block(0, &diff_dst[dst_off], c, dst_stride,
                                dst_chunks_per_c_block));
                VECT_FLOAT_T D1 = CONVERT_VECT_FLOAT_T(
                        read_vect_c_block(1, &diff_dst[dst_off], c, dst_stride,
                                dst_chunks_per_c_block));

#if ALG_MAX
                // Zero the gradient in lanes where this tap was not the
                // argmax recorded by the forward pass.
                VECT_INT_T WS0 = read_vect_c_block_int(
                        0, &ws[dst_off], c, ws_stride, ws_chunks_per_c_block);
                VECT_INT_T WS1 = read_vect_c_block_int(
                        1, &ws[dst_off], c, ws_stride, ws_chunks_per_c_block);

                VECT_INT_T CMP0 = isnotequal(
                        AS_VECT_FLOAT_T(WS0 - kd * KH * KW - kh * KW - kw),
                        (VECT_FLOAT_T)0);
                D0 = select(D0, (VECT_FLOAT_T)0, CMP0);

                VECT_INT_T CMP1 = isnotequal(
                        AS_VECT_FLOAT_T(WS1 - kd * KH * KW - kh * KW - kw),
                        (VECT_FLOAT_T)0);
                D1 = select(D1, (VECT_FLOAT_T)0, CMP1);
#endif
#if ALG_AVG_NP
                // avg excluding padding: rescale by the number of window taps
                // that overlapped the source for this dst point.
                const int id_start = max(id - kd, 0);
                const int ih_start = max(ih - kh, 0);
                const int iw_start = max(iw - kw, 0);
                const int id_end = min(id - kd + KD, ID);
                const int ih_end = min(ih - kh + KH, IH);
                const int iw_end = min(iw - kw + KW, IW);
                const float num_summands = (ih_end - ih_start)
                        * (iw_end - iw_start) * (id_end - id_start);
                D0 /= num_summands;
                D1 /= num_summands;
#endif
                S0 += D0;
                S1 += D1;
            }
        }
    }
#if ALG_AVG_P
    // avg including padding: divide once by the full window size.
    S0 /= KD * KH * KW;
    S1 /= KD * KH * KW;
#endif

    int src_off = SRC_OFF(mb, c, id, ih, iw);
    write_vect_c_block(0, &diff_src[src_off], c, src_stride,
            src_chunks_per_c_block, CONVERT_VECTOR_DATA_T(S0));
    write_vect_c_block(1, &diff_src[src_off], c, src_stride,
            src_chunks_per_c_block, CONVERT_VECTOR_DATA_T(S1));
}
335#endif
336
337inline DATA_T read_c_block(const __global DATA_T *ptr, int c) {
338#if C_W_PADDING % SUB_GROUP_SIZE != 0
339    int local_id = get_sub_group_local_id();
340    int tail = C_WO_PADDING - c;
341    return (local_id < tail) ? ptr[local_id] : 0;
342#else
343    return AS_DATA_T(BLOCK_READ((const __global BLOCK_DATA_T *)ptr));
344#endif
345}
346
// Number of vector lanes that map to real (unpadded) channels when only the
// channel dimension is blocked; used to avoid touching lanes entirely past
// C_WO_PADDING in the per-chunk read path.
#define CALC_VECT_LEN() \
    ({ \
        int size; \
        if (USE_ONLY_C_BLOCK == 1 \
                && VECT_DT_N > C_WO_PADDING / SUB_GROUP_SIZE + 1) \
            size = C_WO_PADDING / SUB_GROUP_SIZE + 1; \
        else \
            size = VECT_DT_N; \
        size; \
    })
357
// Reads the idx-th NVECT-sized chunk of channel data as a vector.
// Fast path: when chunks within a C block are contiguous and the real channel
// count is a multiple of the block width, one vectorized subgroup block read
// suffices. Otherwise each subgroup chunk is gathered separately, mapping the
// linear chunk index onto (chunk-within-block, block) coordinates; lanes past
// the real channel count (see CALC_VECT_LEN) are zero-filled.
inline VECT_DATA_T read_vect_c_block(int idx, const __global DATA_T *ptr, int c,
        int blocks_stride, int chunks_per_block) {
    // Only NVECT chunks exist per work-item; higher indices read as zero.
    if (idx >= NVECT) return 0;

    if ((blocks_stride == chunks_per_block * SUB_GROUP_SIZE)
            && (C_WO_PADDING % (chunks_per_block * SUB_GROUP_SIZE) == 0)) {
        return AS_VECT_DATA_T(VECT_BLOCK_READ((const __global BLOCK_DATA_T *)ptr
                + idx * VECT_DT_N * SUB_GROUP_SIZE));
    } else {
        VECT_DATA_T ret;
        for (int i = 0; i < CALC_VECT_LEN(); i++) {
            const int offset_index = (idx * VECT_DT_N + i);
            const int local_c_block_index = offset_index % chunks_per_block;
            const int global_c_block_index = offset_index / chunks_per_block;
            const int ptr_offset = local_c_block_index * SUB_GROUP_SIZE
                    + global_c_block_index * blocks_stride;
            // Channel offset used for tail masking inside read_c_block.
            const int c_off
                    = (USE_ONLY_C_BLOCK ? offset_index * SUB_GROUP_SIZE
                                        : local_c_block_index * SUB_GROUP_SIZE);
#if VECT_DT_N == 1
            ret = read_c_block(ptr + ptr_offset, c + c_off);
#else
            ret[i] = read_c_block(ptr + ptr_offset, c + c_off);
#endif
        }
#if VECT_DT_N > 1
        // Zero lanes that fall entirely past the real channel count.
        for (int i = CALC_VECT_LEN(); i < VECT_DT_N; ++i) {
            ret[i] = 0;
        }
#endif
        return ret;
    }
}
391
392inline int read_c_block_int(const __global int *ptr, int c) {
393#if C_W_PADDING % SUB_GROUP_SIZE != 0
394    int local_id = get_sub_group_local_id();
395    int tail = C_WO_PADDING - c;
396    return (local_id < tail) ? ptr[local_id] : 0;
397#else
398    return as_int(intel_sub_group_block_read((const __global uint *)ptr));
399#endif
400}
401
// Reads the idx-th NVECT-sized chunk of int (workspace) data as a vector.
// Same layout mapping as read_vect_c_block. Note: unlike the DATA_T variant,
// this loop always runs VECT_DT_N iterations — there is no CALC_VECT_LEN
// tail clamp here; read_c_block_int's per-lane masking handles the tail.
inline VECT_INT_T read_vect_c_block_int(int idx, const __global int *ptr, int c,
        int blocks_stride, int chunks_per_block) {
    // Only NVECT chunks exist per work-item; higher indices read as zero.
    if (idx >= NVECT) return 0;

    if ((blocks_stride == chunks_per_block * SUB_GROUP_SIZE)
            && (C_WO_PADDING % (chunks_per_block * SUB_GROUP_SIZE) == 0)) {
        return AS_VECT_INT_T(VECT_UINT_READ(
                (const __global uint *)ptr + idx * VECT_DT_N * SUB_GROUP_SIZE));
    } else {
        VECT_INT_T ret;
        for (int i = 0; i < VECT_DT_N; i++) {
            const int offset_index = (idx * VECT_DT_N + i);
            const int local_c_block_index = offset_index % chunks_per_block;
            const int global_c_block_index = offset_index / chunks_per_block;
            const int ptr_offset = local_c_block_index * SUB_GROUP_SIZE
                    + global_c_block_index * blocks_stride;
            // Channel offset used for tail masking inside read_c_block_int.
            const int c_off
                    = (USE_ONLY_C_BLOCK ? offset_index * SUB_GROUP_SIZE
                                        : local_c_block_index * SUB_GROUP_SIZE);
#if VECT_DT_N == 1
            ret = read_c_block_int(ptr + ptr_offset, c + c_off);
#else
            ret[i] = read_c_block_int(ptr + ptr_offset, c + c_off);
#endif
        }
        return ret;
    }
}
430
// Writes one subgroup-wide chunk of DATA_T at `ptr`, keeping the padded
// channel region of the destination zeroed.
inline void write_c_block(__global DATA_T *ptr, int c, DATA_T value) {
#if C_W_PADDING % SUB_GROUP_SIZE != 0
    // Unaligned padded C: per-lane scatter, skipping lanes past the real
    // channel count. NOTE(review): lanes inside the padded region are left
    // untouched here, whereas write_c_block_int zeroes them — confirm the
    // asymmetry is intentional.
    int local_id = get_sub_group_local_id();
    int tail = C_WO_PADDING - c;

    if (local_id < tail) ptr[local_id] = value;
#else
#if C_WO_PADDING % SUB_GROUP_SIZE != 0
    // Zero the lanes between the real and the padded channel count before
    // the block write so padding stays clean.
    int local_id = get_sub_group_local_id();
    if (local_id >= C_WO_PADDING - c && local_id < C_W_PADDING - c) value = 0;
#endif
    // Chunks entirely inside padding are written as zeros.
    if (c >= C_WO_PADDING) {
        BLOCK_WRITE((__global BLOCK_DATA_T *)ptr,
                AS_BLOCK_DATA_T(CONVERT_DATA_T(DATA_ZERO)));
        return;
    }
    BLOCK_WRITE((__global BLOCK_DATA_T *)ptr, AS_BLOCK_DATA_T(value));
#endif
}
450
// Writes the idx-th NVECT-sized vector chunk of channel data.
// Uses a single vectorized subgroup block write on the contiguous/aligned
// fast path, otherwise scatters each subgroup chunk through write_c_block
// using the same (chunk-within-block, block) mapping as the readers.
inline void write_vect_c_block(int idx, __global DATA_T *ptr, int c,
        int blocks_stride, int chunks_per_block, VECT_DATA_T block) {
    // Only NVECT chunks exist per work-item; higher indices are no-ops.
    if (idx >= NVECT) return;

    if ((blocks_stride == chunks_per_block * SUB_GROUP_SIZE)
            && (C_WO_PADDING % (chunks_per_block * SUB_GROUP_SIZE) == 0)) {
        VECT_BLOCK_WRITE(
                (__global BLOCK_DATA_T *)ptr + idx * VECT_DT_N * SUB_GROUP_SIZE,
                AS_VECT_BLOCK_DATA_T(block));
    } else {
        for (int i = 0; i < VECT_DT_N; i++) {
            const int offset_index = (idx * VECT_DT_N + i);
            const int local_c_block_index = offset_index % chunks_per_block;
            const int global_c_block_index = offset_index / chunks_per_block;
            const int ptr_offset = local_c_block_index * SUB_GROUP_SIZE
                    + global_c_block_index * blocks_stride;
            // Channel offset used for tail masking inside write_c_block.
            const int c_off
                    = (USE_ONLY_C_BLOCK ? offset_index * SUB_GROUP_SIZE
                                        : local_c_block_index * SUB_GROUP_SIZE);
#if VECT_DT_N == 1
            write_c_block(ptr + ptr_offset, c + c_off, block);
#else
            write_c_block(ptr + ptr_offset, c + c_off, block[i]);
#endif
        }
    }
}
478
479inline void write_c_block_int(__global int *ptr, int c, int value) {
480#if C_WO_PADDING % SUB_GROUP_SIZE != 0
481    int local_id = get_sub_group_local_id();
482    int tail = C_WO_PADDING - c;
483    if (local_id < tail)
484        ptr[local_id] = value;
485    else if (local_id < C_W_PADDING - c) {
486        ptr[local_id] = 0;
487    } else
488        return;
489#else
490    if (c >= C_WO_PADDING) {
491        intel_sub_group_block_write((__global uint *)ptr, 0);
492        return;
493    }
494    intel_sub_group_block_write((__global uint *)ptr, as_uint(value));
495#endif
496}
497
// Writes the idx-th NVECT-sized vector chunk of int (workspace) data.
// Mirrors write_vect_c_block: vectorized subgroup block write on the
// contiguous/aligned fast path, per-chunk scatter through write_c_block_int
// otherwise.
inline void write_vect_c_block_int(int idx, __global int *ptr, int c,
        int blocks_stride, int chunks_per_block, VECT_INT_T block) {
    // Only NVECT chunks exist per work-item; higher indices are no-ops.
    if (idx >= NVECT) return;

    if ((blocks_stride == chunks_per_block * SUB_GROUP_SIZE)
            && (C_WO_PADDING % (chunks_per_block * SUB_GROUP_SIZE) == 0)) {
        VECT_UINT_WRITE((__global uint *)ptr + idx * VECT_DT_N * SUB_GROUP_SIZE,
                AS_VECT_UINT_T(block));
    } else {
        for (int i = 0; i < VECT_DT_N; i++) {
            const int offset_index = (idx * VECT_DT_N + i);
            const int local_c_block_index = offset_index % chunks_per_block;
            const int global_c_block_index = offset_index / chunks_per_block;
            const int ptr_offset = local_c_block_index * SUB_GROUP_SIZE
                    + global_c_block_index * blocks_stride;
            // Channel offset used for tail masking inside write_c_block_int.
            const int c_off
                    = (USE_ONLY_C_BLOCK ? offset_index * SUB_GROUP_SIZE
                                        : local_c_block_index * SUB_GROUP_SIZE);
#if VECT_DT_N == 1
            write_c_block_int(ptr + ptr_offset, c + c_off, block);
#else
            write_c_block_int(ptr + ptr_offset, c + c_off, block[i]);
#endif
        }
    }
}
524