1/******************************************************************************* 2* Copyright 2020-2021 Intel Corporation 3* 4* Licensed under the Apache License, Version 2.0 (the "License"); 5* you may not use this file except in compliance with the License. 6* You may obtain a copy of the License at 7* 8* http://www.apache.org/licenses/LICENSE-2.0 9* 10* Unless required by applicable law or agreed to in writing, software 11* distributed under the License is distributed on an "AS IS" BASIS, 12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13* See the License for the specific language governing permissions and 14* limitations under the License. 15*******************************************************************************/ 16 17#include "gpu/ocl/ocl_post_ops.h" 18#include "gpu/ocl/ocl_types.h" 19 20// Read functions. 21inline VECT_DATA_T read_vect_c_block(int idx, const __global DATA_T *ptr, int c, 22 int blocks_stride, int chunks_per_block); 23inline VECT_INT_T read_vect_c_block_int(int idx, const __global int *ptr, int c, 24 int blocks_stride, int chunks_per_block); 25 26// Write functions. 
inline void write_vect_c_block(int idx, __global DATA_T *ptr, int c,
        int blocks_stride, int chunks_per_block, VECT_DATA_T block);
inline void write_vect_c_block_int(int idx, __global int *ptr, int c,
        int blocks_stride, int chunks_per_block, VECT_INT_T block);

// Accumulation type selection: bf16 always accumulates in float; f16 never
// does; other data types use float only for the averaging algorithms.
#if DT_BF16
#define USE_FLOATS true
#elif DT_F16
#define USE_FLOATS false
#else
#define USE_FLOATS (ALG_AVG_NP || ALG_AVG_P)
#endif

#if IS_FWD
// Forward pooling kernel. Each work-item processes up to two vector chunks
// (D0/D1, selected by the idx argument 0/1 of the read/write helpers) of the
// channel dimension for one (mb, c, od, oh, ow) output point. For max pooling
// in training mode, the argmax workspace `ws` records the flattened kernel
// index (kd * KH * KW + kh * KW + kw) of the selected element.
KERNEL_ATTR
__kernel void gen9_pooling_fwd(__global DATA_T *src, __global int *ws,
        __global DATA_T *dst POST_OP_ARGS) {
    const int mb = GWS_GET_MB();
    const int c = GWS_GET_C();
    const int od = GWS_GET_OD();
    const int oh = GWS_GET_OH();
    const int ow = GWS_GET_OW();

    // Calculate number of subgroup chunks inside C block
    // and stride between consecutive MB/C blocks
#if USE_MB_C_BLOCK
    const int src_stride = (SRC_SB0 > 1) ? SRC_SB0 : SRC_S0;
    const int dst_stride = (DST_SB0 > 1) ? DST_SB0 : DST_S0;
    const int src_chunks_per_c_block = CHUNKS_PER_C_BLOCK;
    const int dst_chunks_per_c_block = CHUNKS_PER_C_BLOCK;
#elif USE_ONLY_C_BLOCK
    const int src_stride = (SRC_B1 > 1) ? SRC_S1 : SUB_GROUP_SIZE;
    const int dst_stride = (DST_B1 > 1) ? DST_S1 : SUB_GROUP_SIZE;
    const int src_chunks_per_c_block
            = (SRC_B1 > 1) ? (SRC_B1 / SUB_GROUP_SIZE) : 1;
    const int dst_chunks_per_c_block
            = (DST_B1 > 1) ? (DST_B1 / SUB_GROUP_SIZE) : 1;
#endif

    const int ws_stride = dst_stride;
    const int ws_chunks_per_c_block = dst_chunks_per_c_block;

    // MB-padded tail: emit zeros into dst (and ws, if present) and exit so
    // padded batch positions hold well-defined values.
    if (mb >= SRC_D0) {
        VECT_DATA_T dst_zero = DATA_ZERO;
        VECT_INT_T ws_zero = 0;
        int off = DST_OFF(mb, c, od, oh, ow);
        write_vect_c_block(
                0, &dst[off], c, dst_stride, dst_chunks_per_c_block, dst_zero);
        write_vect_c_block(
                1, &dst[off], c, dst_stride, dst_chunks_per_c_block, dst_zero);
#if ALG_MAX && IS_TRAINING
        write_vect_c_block_int(
                0, &ws[off], c, ws_stride, ws_chunks_per_c_block, ws_zero);
        write_vect_c_block_int(
                1, &ws[off], c, ws_stride, ws_chunks_per_c_block, ws_zero);
#endif // ALG_MAX && IS_TRAINING

        return;
    }

    // Top-left-front corner of the pooling window in src coordinates
    // (may be negative due to padding).
    const int id = od * SD - PD;
    const int ih = oh * SH - PH;
    const int iw = ow * SW - PW;
    // Accumulators: DATA_MIN identity for max, zero for sum/average.
#if USE_FLOATS
    VECT_FLOAT_T D0 = ALG_MAX ? DATA_MIN : DATA_ZERO;
    VECT_FLOAT_T D1 = ALG_MAX ? DATA_MIN : DATA_ZERO;
#else // USE_FLOATS
    VECT_DATA_T D0 = ALG_MAX ? DATA_MIN : DATA_ZERO;
    VECT_DATA_T D1 = ALG_MAX ? DATA_MIN : DATA_ZERO;
#endif // USE_FLOATS
    VECT_INT_T WS0 = 0, WS1 = 0;

    // Reduce over the KD x KH x KW window, skipping out-of-bounds (padding)
    // positions. NOTE(review): the kd/kh bounds checks sit inside the
    // innermost kw loop; hoisting them would skip work earlier — confirm
    // the compiler does not already do this before changing it.
    for (int kd = 0; kd < KD; ++kd)
        for (int kh = 0; kh < KH; ++kh) {
            for (int kw = 0; kw < KW; ++kw) {
                if (id + kd < 0 || id + kd >= ID) continue;
                if (ih + kh < 0 || ih + kh >= IH) continue;
                if (iw + kw < 0 || iw + kw >= IW) continue;

                int src_off = SRC_OFF(mb, c, id + kd, ih + kh, iw + kw);
#if USE_FLOATS
                VECT_FLOAT_T S0 = CONVERT_VECT_FLOAT_T(read_vect_c_block(0,
                        &src[src_off], c, src_stride, src_chunks_per_c_block));
                VECT_FLOAT_T S1 = CONVERT_VECT_FLOAT_T(read_vect_c_block(1,
                        &src[src_off], c, src_stride, src_chunks_per_c_block));
#else // USE_FLOATS
                VECT_DATA_T S0 = read_vect_c_block(0, &src[src_off], c,
                        src_stride, src_chunks_per_c_block);
                VECT_DATA_T S1 = read_vect_c_block(1, &src[src_off], c,
                        src_stride, src_chunks_per_c_block);
#endif // USE_FLOATS

#if ALG_MAX
#if IS_TRAINING
                // Track both the max value and its flattened kernel index
                // for the backward pass.
                VECT_INT_T CMP0 = isless(D0, S0);
                WS0 = select(WS0, kd * KH * KW + kh * KW + kw, CMP0);
                D0 = select(D0, S0, CMP0);

                VECT_INT_T CMP1 = isless(D1, S1);
                WS1 = select(WS1, kd * KH * KW + kh * KW + kw, CMP1);
                D1 = select(D1, S1, CMP1);

#else // TRAINING
                D0 = max(D0, S0);
                D1 = max(D1, S1);
#endif // TRAINING
#else // ALG_MAX
                D0 += S0;
                D1 += S1;
#endif // ALG_MAX
            }
        }

#if ALG_AVG_P
    // avg_pool with padding included: divide by the full window size.
    D0 = D0 / (KD * KH * KW);
    D1 = D1 / (KD * KH * KW);

#endif // ALG_AVG_P

#if ALG_AVG_NP
    // avg_pool excluding padding: divide by the number of in-bounds
    // elements actually summed for this output point.
    const int id_start = max(od * SD - PD, 0);
    const int ih_start = max(oh * SH - PH, 0);
    const int iw_start = max(ow * SW - PW, 0);
    const int id_end = min(od * SD - PD + KD, ID);
    const int ih_end = min(oh * SH - PH + KH, IH);
    const int iw_end = min(ow * SW - PW + KW, IW);
    const DATA_T num_summands
            = (ih_end - ih_start) * (iw_end - iw_start) * (id_end - id_start);
    D0 = D0 / num_summands;
    D1 = D1 / num_summands;
#endif // ALG_AVG_NP

    int dst_off = DST_OFF(mb, c, od, oh, ow);
    VECT_DATA_T sum0;
    VECT_DATA_T sum1;
#if WITH_SUM
    // Pre-load current dst values for the sum post-op.
    sum0 = read_vect_c_block(
            0, &dst[dst_off], c, dst_stride, dst_chunks_per_c_block);
    sum1 = read_vect_c_block(
            1, &dst[dst_off], c, dst_stride, dst_chunks_per_c_block);
#endif

    const int local_id = get_sub_group_local_id();

    // Apply post-ops per logical (mb, oc) coordinate, skipping padded
    // channels (>= C_WO_PADDING).
#if VECT_DT_N == 1
    const int po_mb = mb;
    const int po_oc = c + local_id;
    if (po_oc < C_WO_PADDING) {
        POST_OP_DATA_T po_sum0 = DATA_TO_REF(sum0);
        float po_D0 = USE_FLOATS ? D0 : CONVERT_FLOAT_T(D0);
        APPLY_POST_OPS_SERIAL_BINARY_2D(
                po_D0, float, po_sum0, POST_OP_DATA_T, po_mb, 1, po_oc, 1);
        D0 = USE_FLOATS ? po_D0 : CONVERT_DATA_T(po_D0);

        POST_OP_DATA_T po_sum1 = DATA_TO_REF(sum1);
        float po_D1 = USE_FLOATS ? D1 : CONVERT_FLOAT_T(D1);
        APPLY_POST_OPS_SERIAL_BINARY_2D(
                po_D1, float, po_sum1, POST_OP_DATA_T, po_mb, 1, po_oc, 1);
        D1 = USE_FLOATS ? po_D1 : CONVERT_DATA_T(po_D1);
    }

#else
    for (int idx = 0; idx < VECT_DT_N; ++idx) {
#if USE_MB_C_BLOCK
        // Vector lane idx maps to a (mb, c) sub-block pair inside the
        // combined MB/C block.
        int c_sub_block_id = idx % CHUNKS_PER_C_BLOCK;
        int mb_sub_block_id = idx / CHUNKS_PER_C_BLOCK;
        const int po_oc = c + c_sub_block_id * SUB_GROUP_SIZE + local_id;
        int po_mb = (mb + mb_sub_block_id) % MB;
#else // USE_MB_C_BLOCK
        const int po_oc = c + idx * SUB_GROUP_SIZE + local_id;
        int po_mb = mb;
#endif // USE_MB_C_BLOCK

        if (po_mb >= MB || po_oc >= C_WO_PADDING) continue;

        float d0_i = USE_FLOATS ? D0[idx] : CONVERT_FLOAT_T(D0[idx]);
        POST_OP_DATA_T sum0_i = DATA_TO_REF(sum0[idx]);
        APPLY_POST_OPS_SERIAL_BINARY_2D(
                d0_i, float, sum0_i, POST_OP_DATA_T, po_mb, 1, po_oc, 1);
        D0[idx] = USE_FLOATS ? d0_i : CONVERT_DATA_T(d0_i);

        float d1_i = USE_FLOATS ? D1[idx] : CONVERT_FLOAT_T(D1[idx]);
        POST_OP_DATA_T sum1_i = DATA_TO_REF(sum1[idx]);
        // NOTE(review): D1 is the second NVECT chunk, so its batch index is
        // shifted by VECT_DT_N before applying post-ops; the shifted po_mb is
        // not re-checked against MB here — confirm the dispatcher guarantees
        // it stays in range.
        po_mb += VECT_DT_N;
        APPLY_POST_OPS_SERIAL_BINARY_2D(
                d1_i, float, sum1_i, POST_OP_DATA_T, po_mb, 1, po_oc, 1);
        D1[idx] = USE_FLOATS ? d1_i : CONVERT_DATA_T(d1_i);
    }
#endif // #if VECT_DT_N == 1
    // Convert accumulators back to the destination data type and store.
#if USE_FLOATS
    VECT_DATA_T res0 = CONVERT_VECTOR_DATA_T(D0);
    VECT_DATA_T res1 = CONVERT_VECTOR_DATA_T(D1);
#else
    VECT_DATA_T res0 = D0;
    VECT_DATA_T res1 = D1;
#endif
    write_vect_c_block(
            0, &dst[dst_off], c, dst_stride, dst_chunks_per_c_block, res0);
    write_vect_c_block(
            1, &dst[dst_off], c, dst_stride, dst_chunks_per_c_block, res1);

#if ALG_MAX && IS_TRAINING
    // Workspace shares dst's layout, so the same offset applies.
    int ws_off = dst_off;
    write_vect_c_block_int(
            0, &ws[ws_off], c, ws_stride, ws_chunks_per_c_block, WS0);
    write_vect_c_block_int(
            1, &ws[ws_off], c, ws_stride, ws_chunks_per_c_block, WS1);
#endif // ALG_MAX && IS_TRAINING
}
#endif

#if IS_BWD
// Backward pooling kernel: for each src point (mb, c, id, ih, iw), gather
// gradient contributions from every dst point whose window covers it. For
// max pooling, the workspace argmax index gates which contributions count;
// for averaging, each contribution is scaled by its window's divisor.
KERNEL_ATTR
__kernel void gen9_pooling_bwd(__global DATA_T *diff_src, __global int *ws,
        __global DATA_T *diff_dst) {

    const int mb = GWS_GET_MB();
    const int c = GWS_GET_C();
    const int id = GWS_GET_ID();
    const int ih = GWS_GET_IH();
    const int iw = GWS_GET_IW();

    // Calculate number of subgroup chunks inside C block
    // and stride between consecutive MB/C blocks
#if USE_MB_C_BLOCK
    const int src_stride = (SRC_SB0 > 1) ? SRC_SB0 : SRC_S0;
    const int dst_stride = (DST_SB0 > 1) ? DST_SB0 : DST_S0;
    const int src_chunks_per_c_block = CHUNKS_PER_C_BLOCK;
    const int dst_chunks_per_c_block = CHUNKS_PER_C_BLOCK;
#elif USE_ONLY_C_BLOCK
    const int src_stride = (SRC_B1 > 1) ? SRC_S1 : SUB_GROUP_SIZE;
    const int dst_stride = (DST_B1 > 1) ? DST_S1 : SUB_GROUP_SIZE;
    const int src_chunks_per_c_block
            = (SRC_B1 > 1) ? (SRC_B1 / SUB_GROUP_SIZE) : 1;
    const int dst_chunks_per_c_block
            = (DST_B1 > 1) ? (DST_B1 / SUB_GROUP_SIZE) : 1;
#endif

    const int ws_stride = dst_stride;
    const int ws_chunks_per_c_block = dst_chunks_per_c_block;

    VECT_FLOAT_T S0 = 0, S1 = 0;
    for (int kd = 0; kd < KD; kd++) {
        for (int kh = 0; kh < KH; kh++) {
            for (int kw = 0; kw < KW; kw++) {
                // Invert the forward mapping: find the dst point whose
                // window at offset (kd, kh, kw) touches this src point;
                // skip if the stride doesn't divide evenly or the result
                // is out of the dst range.
                int od = (id + PD - kd);
                int oh = (ih + PH - kh);
                int ow = (iw + PW - kw);
                if (od % SD != 0 || oh % SH != 0 || ow % SW != 0) continue;
                od /= SD;
                oh /= SH;
                ow /= SW;
                if (od < 0 || od >= OD) continue;
                if (oh < 0 || oh >= OH) continue;
                if (ow < 0 || ow >= OW) continue;

                const int dst_off = DST_OFF(mb, c, od, oh, ow);
                VECT_FLOAT_T D0 = CONVERT_VECT_FLOAT_T(
                        read_vect_c_block(0, &diff_dst[dst_off], c, dst_stride,
                                dst_chunks_per_c_block));
                VECT_FLOAT_T D1 = CONVERT_VECT_FLOAT_T(
                        read_vect_c_block(1, &diff_dst[dst_off], c, dst_stride,
                                dst_chunks_per_c_block));

#if ALG_MAX
                // Zero out lanes where the workspace argmax does not point
                // at this (kd, kh, kw) offset. The int difference is
                // bit-reinterpreted as float and compared to 0.0f to build
                // the per-lane mask (nonzero int bits compare unequal).
                VECT_INT_T WS0 = read_vect_c_block_int(
                        0, &ws[dst_off], c, ws_stride, ws_chunks_per_c_block);
                VECT_INT_T WS1 = read_vect_c_block_int(
                        1, &ws[dst_off], c, ws_stride, ws_chunks_per_c_block);

                VECT_INT_T CMP0 = isnotequal(
                        AS_VECT_FLOAT_T(WS0 - kd * KH * KW - kh * KW - kw),
                        (VECT_FLOAT_T)0);
                D0 = select(D0, (VECT_FLOAT_T)0, CMP0);

                VECT_INT_T CMP1 = isnotequal(
                        AS_VECT_FLOAT_T(WS1 - kd * KH * KW - kh * KW - kw),
                        (VECT_FLOAT_T)0);
                D1 = select(D1, (VECT_FLOAT_T)0, CMP1);
#endif
#if ALG_AVG_NP
                // Scale by the contributing window's own in-bounds element
                // count (matches the forward avg_np divisor).
                const int id_start = max(id - kd, 0);
                const int ih_start = max(ih - kh, 0);
                const int iw_start = max(iw - kw, 0);
                const int id_end = min(id - kd + KD, ID);
                const int ih_end = min(ih - kh + KH, IH);
                const int iw_end = min(iw - kw + KW, IW);
                const float num_summands = (ih_end - ih_start)
                        * (iw_end - iw_start) * (id_end - id_start);
                D0 /= num_summands;
                D1 /= num_summands;
#endif
                S0 += D0;
                S1 += D1;
            }
        }
    }
#if ALG_AVG_P
    // avg_pool with padding: every window has the same fixed divisor.
    S0 /= KD * KH * KW;
    S1 /= KD * KH * KW;
#endif

    int src_off = SRC_OFF(mb, c, id, ih, iw);
    write_vect_c_block(0, &diff_src[src_off], c, src_stride,
            src_chunks_per_c_block, CONVERT_VECTOR_DATA_T(S0));
    write_vect_c_block(1, &diff_src[src_off], c, src_stride,
            src_chunks_per_c_block, CONVERT_VECTOR_DATA_T(S1));
}
#endif

// Read one sub-group-wide chunk of DATA_T. When the padded channel count is
// not a multiple of the sub-group size, fall back to a guarded scalar read
// that returns 0 for lanes past the unpadded channel tail; otherwise use a
// sub-group block read.
inline DATA_T read_c_block(const __global DATA_T *ptr, int c) {
#if C_W_PADDING % SUB_GROUP_SIZE != 0
    int local_id = get_sub_group_local_id();
    int tail = C_WO_PADDING - c;
    return (local_id < tail) ? ptr[local_id] : 0;
#else
    return AS_DATA_T(BLOCK_READ((const __global BLOCK_DATA_T *)ptr));
#endif
}

// Number of vector lanes that map to real (unpadded) channels; the rest are
// skipped by the scalar read path. Statement-expression: evaluates to `size`.
#define CALC_VECT_LEN() \
    ({ \
        int size; \
        if (USE_ONLY_C_BLOCK == 1 \
                && VECT_DT_N > C_WO_PADDING / SUB_GROUP_SIZE + 1) \
            size = C_WO_PADDING / SUB_GROUP_SIZE + 1; \
        else \
            size = VECT_DT_N; \
        size; \
    })

// Read the idx-th VECT_DT_N-wide chunk. Fast path: one vector block read
// when chunks are contiguous and the unpadded channel count divides evenly
// into whole C blocks. Slow path: per-lane scalar reads with explicit
// block/chunk offset arithmetic; lanes beyond CALC_VECT_LEN() are zeroed.
inline VECT_DATA_T read_vect_c_block(int idx, const __global DATA_T *ptr, int c,
        int blocks_stride, int chunks_per_block) {
    if (idx >= NVECT) return 0;

    if ((blocks_stride == chunks_per_block * SUB_GROUP_SIZE)
            && (C_WO_PADDING % (chunks_per_block * SUB_GROUP_SIZE) == 0)) {
        return AS_VECT_DATA_T(VECT_BLOCK_READ((const __global BLOCK_DATA_T *)ptr
                + idx * VECT_DT_N * SUB_GROUP_SIZE));
    } else {
        VECT_DATA_T ret;
        for (int i = 0; i < CALC_VECT_LEN(); i++) {
            const int offset_index = (idx * VECT_DT_N + i);
            const int local_c_block_index = offset_index % chunks_per_block;
            const int global_c_block_index = offset_index / chunks_per_block;
            const int ptr_offset = local_c_block_index * SUB_GROUP_SIZE
                    + global_c_block_index * blocks_stride;
            const int c_off
                    = (USE_ONLY_C_BLOCK ? offset_index * SUB_GROUP_SIZE
                                        : local_c_block_index * SUB_GROUP_SIZE);
#if VECT_DT_N == 1
            ret = read_c_block(ptr + ptr_offset, c + c_off);
#else
            ret[i] = read_c_block(ptr + ptr_offset, c + c_off);
#endif
        }
#if VECT_DT_N > 1
        // Zero the trailing lanes that were skipped above so the caller
        // never sees uninitialized data.
        for (int i = CALC_VECT_LEN(); i < VECT_DT_N; ++i) {
            ret[i] = 0;
        }
#endif
        return ret;
    }
}

// Integer (workspace) analogue of read_c_block.
// NOTE(review): the guard here uses C_W_PADDING like read_c_block, while
// write_c_block_int below keys on C_WO_PADDING — confirm the asymmetry is
// intentional before unifying.
inline int read_c_block_int(const __global int *ptr, int c) {
#if C_W_PADDING % SUB_GROUP_SIZE != 0
    int local_id = get_sub_group_local_id();
    int tail = C_WO_PADDING - c;
    return (local_id < tail) ? ptr[local_id] : 0;
#else
    return as_int(intel_sub_group_block_read((const __global uint *)ptr));
#endif
}

// Integer (workspace) analogue of read_vect_c_block. Unlike the DATA_T
// variant, the slow path fills all VECT_DT_N lanes (no CALC_VECT_LEN cap).
inline VECT_INT_T read_vect_c_block_int(int idx, const __global int *ptr, int c,
        int blocks_stride, int chunks_per_block) {
    if (idx >= NVECT) return 0;

    if ((blocks_stride == chunks_per_block * SUB_GROUP_SIZE)
            && (C_WO_PADDING % (chunks_per_block * SUB_GROUP_SIZE) == 0)) {
        return AS_VECT_INT_T(VECT_UINT_READ(
                (const __global uint *)ptr + idx * VECT_DT_N * SUB_GROUP_SIZE));
    } else {
        VECT_INT_T ret;
        for (int i = 0; i < VECT_DT_N; i++) {
            const int offset_index = (idx * VECT_DT_N + i);
            const int local_c_block_index = offset_index % chunks_per_block;
            const int global_c_block_index = offset_index / chunks_per_block;
            const int ptr_offset = local_c_block_index * SUB_GROUP_SIZE
                    + global_c_block_index * blocks_stride;
            const int c_off
                    = (USE_ONLY_C_BLOCK ? offset_index * SUB_GROUP_SIZE
                                        : local_c_block_index * SUB_GROUP_SIZE);
#if VECT_DT_N == 1
            ret = read_c_block_int(ptr + ptr_offset, c + c_off);
#else
            ret[i] = read_c_block_int(ptr + ptr_offset, c + c_off);
#endif
        }
        return ret;
    }
}

// Write one sub-group-wide chunk of DATA_T. The scalar path writes only the
// unpadded tail; the block-write path zeroes lanes that fall in the padded
// channel region (and writes all-zeros for fully padded chunks) so padding
// stays clean.
inline void write_c_block(__global DATA_T *ptr, int c, DATA_T value) {
#if C_W_PADDING % SUB_GROUP_SIZE != 0
    int local_id = get_sub_group_local_id();
    int tail = C_WO_PADDING - c;

    if (local_id < tail) ptr[local_id] = value;
#else
#if C_WO_PADDING % SUB_GROUP_SIZE != 0
    int local_id = get_sub_group_local_id();
    if (local_id >= C_WO_PADDING - c && local_id < C_W_PADDING - c) value = 0;
#endif
    if (c >= C_WO_PADDING) {
        BLOCK_WRITE((__global BLOCK_DATA_T *)ptr,
                AS_BLOCK_DATA_T(CONVERT_DATA_T(DATA_ZERO)));
        return;
    }
    BLOCK_WRITE((__global BLOCK_DATA_T *)ptr, AS_BLOCK_DATA_T(value));
#endif
}

// Write the idx-th VECT_DT_N-wide chunk; mirrors read_vect_c_block's fast
// (vector block write) and slow (per-lane scalar write) paths.
inline void write_vect_c_block(int idx, __global DATA_T *ptr, int c,
        int blocks_stride, int chunks_per_block, VECT_DATA_T block) {
    if (idx >= NVECT) return;

    if ((blocks_stride == chunks_per_block * SUB_GROUP_SIZE)
            && (C_WO_PADDING % (chunks_per_block * SUB_GROUP_SIZE) == 0)) {
        VECT_BLOCK_WRITE(
                (__global BLOCK_DATA_T *)ptr + idx * VECT_DT_N * SUB_GROUP_SIZE,
                AS_VECT_BLOCK_DATA_T(block));
    } else {
        for (int i = 0; i < VECT_DT_N; i++) {
            const int offset_index = (idx * VECT_DT_N + i);
            const int local_c_block_index = offset_index % chunks_per_block;
            const int global_c_block_index = offset_index / chunks_per_block;
            const int ptr_offset = local_c_block_index * SUB_GROUP_SIZE
                    + global_c_block_index * blocks_stride;
            const int c_off
                    = (USE_ONLY_C_BLOCK ? offset_index * SUB_GROUP_SIZE
                                        : local_c_block_index * SUB_GROUP_SIZE);
#if VECT_DT_N == 1
            write_c_block(ptr + ptr_offset, c + c_off, block);
#else
            write_c_block(ptr + ptr_offset, c + c_off, block[i]);
#endif
        }
    }
}

// Integer (workspace) analogue of write_c_block: writes the value in the
// unpadded region, zeros in the padded region, and nothing past the padded
// extent on the scalar path.
inline void write_c_block_int(__global int *ptr, int c, int value) {
#if C_WO_PADDING % SUB_GROUP_SIZE != 0
    int local_id = get_sub_group_local_id();
    int tail = C_WO_PADDING - c;
    if (local_id < tail)
        ptr[local_id] = value;
    else if (local_id < C_W_PADDING - c) {
        ptr[local_id] = 0;
    } else
        return;
#else
    if (c >= C_WO_PADDING) {
        intel_sub_group_block_write((__global uint *)ptr, 0);
        return;
    }
    intel_sub_group_block_write((__global uint *)ptr, as_uint(value));
#endif
}

// Integer (workspace) analogue of write_vect_c_block.
inline void write_vect_c_block_int(int idx, __global int *ptr, int c,
        int blocks_stride, int chunks_per_block, VECT_INT_T block) {
    if (idx >= NVECT) return;

    if ((blocks_stride == chunks_per_block * SUB_GROUP_SIZE)
            && (C_WO_PADDING % (chunks_per_block * SUB_GROUP_SIZE) == 0)) {
        VECT_UINT_WRITE((__global uint *)ptr + idx * VECT_DT_N * SUB_GROUP_SIZE,
                AS_VECT_UINT_T(block));
    } else {
        for (int i = 0; i < VECT_DT_N; i++) {
            const int offset_index = (idx * VECT_DT_N + i);
            const int local_c_block_index = offset_index % chunks_per_block;
            const int global_c_block_index = offset_index / chunks_per_block;
            const int ptr_offset = local_c_block_index * SUB_GROUP_SIZE
                    + global_c_block_index * blocks_stride;
            const int c_off
                    = (USE_ONLY_C_BLOCK ? offset_index * SUB_GROUP_SIZE
                                        : local_c_block_index * SUB_GROUP_SIZE);
#if VECT_DT_N == 1
            write_c_block_int(ptr + ptr_offset, c + c_off, block);
#else
            write_c_block_int(ptr + ptr_offset, c + c_off, block[i]);
#endif
        }
    }
}