1/******************************************************************************* 2* Copyright 2019-2021 Intel Corporation 3* 4* Licensed under the Apache License, Version 2.0 (the "License"); 5* you may not use this file except in compliance with the License. 6* You may obtain a copy of the License at 7* 8* http://www.apache.org/licenses/LICENSE-2.0 9* 10* Unless required by applicable law or agreed to in writing, software 11* distributed under the License is distributed on an "AS IS" BASIS, 12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13* See the License for the specific language governing permissions and 14* limitations under the License. 15*******************************************************************************/ 16 17#include "gpu/ocl/ocl_post_ops.h" 18#include "gpu/ocl/ocl_types.h" 19 20#if IS_DW != 1 21#error "Kernel supports depth-wise convolutions only" 22#endif 23 24#ifdef DST_DT_S8 25#if VER_32MB16C 26#define DST_MB_BLOCK MB_BLOCK 27#else // VER_32MB16C 28#define DST_MB_BLOCK (MB_BLOCK * 2) 29#endif // VER_32MB16C 30#define DST_OC_BLOCK (OC_BLOCK * 2) 31#endif // DST_DT_S8 32 33#define APPLY_POST_OPS_COMMON(nelems, accumulator, dest_data, mb_shift) \ 34 { \ 35 const int po_mb = mb_shift + mb; \ 36 const int po_oc = g; \ 37 int po_mb_count; \ 38 if (VER_16MB16C == 1) { \ 39 po_mb_count = nelems; \ 40 } else { \ 41 po_mb_count = 1; \ 42 } \ 43 APPLY_POST_OPS_TRY_BURST(accumulator, DATA_T, dest_data, DATA_T, \ 44 po_mb, po_mb_count, po_oc, SUB_GROUP_SIZE, \ 45 get_sub_group_local_id()); \ 46 } 47 48__attribute__((reqd_work_group_size(LWS_0, LWS_1, LWS_2))) // attr:no-format 49#if SUB_GROUP_SIZE != 1 50__attribute__((intel_reqd_sub_group_size(SUB_GROUP_SIZE))) // attr:no-format 51#endif 52__kernel void 53gen9_conv_dw_fwd(const __global DATA_T *src, const __global DATA_T *wei, 54 const __global DATA_T *bias, __global DST_DATA_T *dst POST_OP_ARGS) { 55 56 MAYBE_SKIP_NON_UNIFORM_WG(); 57 58#if VER_8OW16C 59 const int osp = get_global_id(1); 60 const int od = osp / (OWB * OH); 61 const int ohw = osp % (OWB * OH); 62 const int ow = (ohw % OWB) * OW_BLOCK; 63 const int oh = ohw / OWB; 64 const int g 65 = (get_group_id(0) * (LWS_0 / SUB_GROUP_SIZE) + get_sub_group_id()) 66 * OC_BLOCK; 67 const int mb = get_global_id(2) * MB_BLOCK; 68 69 const int id = od * SD - PD; 70 const int ih = oh * SH - PH; 71 const int iw = ow * SW - PW; 72#ifdef DST_DT_S8 // 32c dst 73 const int G_32block = G % 32 ? (32 + G - (G % 32)) : G; 74 dst += mb * G_32block * OD * OH * OW 75 + (g / 32 * 32) * OD * OH * OW * MB_BLOCK 76 + (od * OH * OW + oh * OW + ow) * MB_BLOCK * (DST_OC_BLOCK) 77 + (g % 32); 78#else 79 dst += mb * G * OD * OH * OW + g * OD * OH * OW * MB_BLOCK 80 + (od * OH * OW + oh * OW + ow) * MB_BLOCK * OC_BLOCK; 81#endif 82 src += mb 83 * ((G_WO_PADDING / IC_BLOCK) 84 + (G_WO_PADDING % IC_BLOCK > 0 ? 1 : 0)) 85 * IC_BLOCK * ID * IH * IW 86 + g * ID * IH * IW * MB_BLOCK 87 + (id * IH * IW + ih * IW + iw) * MB_BLOCK * IC_BLOCK; 88 wei += g * KD * KH * KW; 89 90 DATA_T S00[OW_BLOCK] = {DATA_ZERO}; 91 if (WITH_BIAS) { 92 const int bg_off = g + get_sub_group_local_id(); 93 DATA_T b = (G_WO_PADDING % OC_BLOCK == 0 || bg_off < G_WO_PADDING) 94 ? bias[bg_off] 95 : DATA_ZERO; 96 unroll_for(int k = 0; k < OW_BLOCK; k++) { S00[k] = b; } 97 } 98 99#if KH != 1 || KW != 1 || KD != 1 100 for (int kd = 0; kd < KD; kd++) 101 for (int kh = 0; kh < KH; kh++) { 102 if (id + kd * (1 + DD) < 0 || id + kd * (1 + DD) >= ID) continue; 103 if (ih + kh * (1 + DH) < 0 || ih + kh * (1 + DH) >= IH) continue; 104 105 const __global DATA_T *src1 = src 106 + (kd * (1 + DD) * IH + kh * (1 + DH)) * IW * MB_BLOCK 107 * IC_BLOCK; 108 DATA_T tempA[SW * OW_BLOCK + KW * (1 + DW)] = {0}; 109 __attribute__((opencl_unroll_hint( 110 SW * OW_BLOCK + KW * (1 + DW)))) // attr:no-format 111 for (int i = 0; i < SW * OW_BLOCK + KW * (1 + DW); i++) { 112 if ((i + iw) >= 0 && (i + iw) < IW) { 113 tempA[i] = AS_DATA_T(BLOCK_READ((const __global BLOCK_DATA_T 114 *)(&src1[i * IC_BLOCK]))); 115 } 116 } 117 for (int kw = 0; kw < KW; kw++) { 118 const __global DATA_T *wei1 119 = wei + (kd * KH * KW + kh * KW + kw) * OC_BLOCK; 120#else 121 const int kw = 0; 122 const __global DATA_T *wei1 = wei; 123 const __global DATA_T *src1 = src; 124#endif 125 DATA_T B0 = AS_DATA_T( 126 BLOCK_READ((const __global BLOCK_DATA_T *)(wei1))); 127 DATA_T A0; 128 129 __attribute__((opencl_unroll_hint(OW_BLOCK))) // attr:no-format 130 for (int k = 0; k < OW_BLOCK; k++) { 131 if (G != G_WO_PADDING && g >= G_WO_PADDING) { 132 S00[k] = DATA_ZERO; 133 continue; 134 } 135#if KH != 1 || KW != 1 || KD != 1 136 A0 = tempA[k * SW + kw * (1 + DW)]; 137#else 138 if (iw + kw * (1 + DW) + k * SW < 0 139 || iw + kw * (1 + DW) + k * SW >= IW) 140 A0 = DATA_ZERO; 141 else 142 A0 = AS_DATA_T(BLOCK_READ( 143 (const __global BLOCK_DATA_T *)(&src1[k * SW * IC_BLOCK]))); 144#endif 145 S00[k] = fma(A0, (DATA_T)B0, S00[k]); 146 } 147#if KH != 1 || KW != 1 || KD != 1 148 } 149 } 150#endif 151 152 DATA_T D00[OW_BLOCK] = {0}; 153#if WITH_SUM 154#ifdef DST_DT_S8 155 __attribute__((opencl_unroll_hint(OW_BLOCK))) // attr:no-format 156 for (int k = 0; k < OW_BLOCK; k++) { 157 D00[k] = CONVERT_DATA_T(BLOCK_READ_DST( 158 (const __global DST_DATA_T *)&dst[k * DST_OC_BLOCK])); 159 } 160#else 161 __attribute__((opencl_unroll_hint(OW_BLOCK))) // attr:no-format 162 for (int k = 0; k < OW_BLOCK; k++) { 163 D00[k] = AS_DATA_T( 164 BLOCK_READ((const __global BLOCK_DATA_T *)&dst[k * OC_BLOCK])); 165 } 166#endif 167#endif 168 169 APPLY_POST_OPS_COMMON(OW_BLOCK, S00, D00, 0); 170 171 if (OW % OW_BLOCK == 0 || ow + OW_BLOCK <= OW) { 172 __attribute__((opencl_unroll_hint)) // attr:no-format 173 for (int k = 0; k < OW_BLOCK; k++) { 174#ifdef DST_DT_S8 175 BLOCK_WRITE_DST((__global DST_DATA_T *)&dst[k * DST_OC_BLOCK], 176 CONVERT_DST_DATA_T(S00[k])); 177#else 178 BLOCK_WRITE((__global BLOCK_DATA_T *)&dst[k * OC_BLOCK], 179 AS_UINT_T(S00[k])); 180#endif 181 } 182 } else { 183 __attribute__((opencl_unroll_hint)) // attr:no-format 184 for (int k = 0; k < OW % OW_BLOCK; k++) { 185#ifdef DST_DT_S8 186 BLOCK_WRITE_DST((__global DST_DATA_T *)&dst[k * DST_OC_BLOCK], 187 CONVERT_DST_DATA_T(S00[k])); 188#else 189 BLOCK_WRITE((__global BLOCK_DATA_T *)&dst[k * OC_BLOCK], 190 AS_UINT_T(S00[k])); 191#endif 192 } 193 } 194#endif 195 196#if VER_16MB16C || VER_32MB16C 197 const int osp = get_global_id(1); 198 const int od = osp / (OWB * OH); 199 const int ohw = osp % (OWB * OH); 200 const int ow = (ohw % OWB) * OW_BLOCK; 201 const int oh = ohw / OWB; 202 const int g 203 = (get_group_id(0) * (LWS_0 / SUB_GROUP_SIZE) + get_sub_group_id()) 204 * OC_BLOCK; 205 const int mb = get_global_id(2) * MB_BLOCK; 206 207 const int id = od * SD - PD; 208 const int ih = oh * SH - PH; 209 const int iw = ow * SW - PW; 210 211#ifdef DST_DT_S8 //32n32c dst 212 const int G_32block = G % 32 ? (32 + G - (G % 32)) : G; 213 dst += (mb / DST_MB_BLOCK) * G_32block * OD * OH * OW * DST_MB_BLOCK 214 + (mb % DST_MB_BLOCK) * DST_OC_BLOCK 215 + (g / DST_OC_BLOCK) * OD * OH * OW * DST_MB_BLOCK * DST_OC_BLOCK 216 + (od * OH * OW + oh * OW + ow) * DST_MB_BLOCK * DST_OC_BLOCK 217 + (g % DST_OC_BLOCK); 218#else 219 dst += mb * G * OD * OH * OW + g * OD * OH * OW * MB_BLOCK 220 + (od * OH * OW + oh * OW + ow) * MB_BLOCK * OC_BLOCK; 221#endif 222 src += mb 223 * ((G_WO_PADDING / IC_BLOCK) 224 + (G_WO_PADDING % IC_BLOCK > 0 ? 1 : 0)) 225 * IC_BLOCK * ID * IH * IW 226 + g * ID * IH * IW * MB_BLOCK 227 + (id * IH * IW + ih * IW + iw) * MB_BLOCK * IC_BLOCK; 228 wei += g * KD * KH * KW; 229 230 DATA8_T S00 = DATA_ZERO; 231 DATA8_T S01 = DATA_ZERO; 232#if VER_32MB16C 233 DATA8_T S02 = DATA_ZERO; 234 DATA8_T S03 = DATA_ZERO; 235#endif 236 237 if (WITH_BIAS) { 238 const int bg_off = g + get_sub_group_local_id(); 239 DATA_T b = (G_WO_PADDING % OC_BLOCK == 0 || bg_off < G_WO_PADDING) 240 ? bias[bg_off] 241 : DATA_ZERO; 242 unroll_for(int k = 0; k < 8; k++) { 243 S00[k] = b; 244 S01[k] = b; 245#if VER_32MB16C 246 S02[k] = b; 247 S03[k] = b; 248#endif 249 } 250 } 251 252#if KH != 1 || KW != 1 || KD != 1 253 for (int kd = 0; kd < KD; kd++) 254 for (int kh = 0; kh < KH; kh++) 255 for (int kw = 0; kw < KW; kw++) { 256 if (id + kd * (1 + DD) < 0 || id + kd * (1 + DD) >= ID) 257 continue; 258 if (ih + kh * (1 + DH) < 0 || ih + kh * (1 + DH) >= IH) 259 continue; 260 if (iw + kw * (1 + DW) < 0 || iw + kw * (1 + DW) >= IW) 261 continue; 262 263 const __global DATA_T *wei1 264 = wei + (kd * KH * KW + kh * KW + kw) * OC_BLOCK; 265 const __global DATA_T *src1 = src 266 + (kd * (1 + DD) * IH * IW + kh * (1 + DH) * IW 267 + kw * (1 + DW)) 268 * MB_BLOCK * IC_BLOCK; 269#else 270 const __global DATA_T *wei1 = wei; 271 const __global DATA_T *src1 = src; 272#endif 273 if (G != G_WO_PADDING && g >= G_WO_PADDING) { 274 S00 = DATA_ZERO; 275 S01 = DATA_ZERO; 276#if VER_32MB16C 277 S02 = DATA_ZERO; 278 S03 = DATA_ZERO; 279#endif 280 continue; 281 } 282 DATA8_T A0 = AS_DATA8_T( 283 BLOCK_READ8((const __global BLOCK_DATA_T *)(src1))); 284 DATA8_T A1 = AS_DATA8_T(BLOCK_READ8( 285 (const __global BLOCK_DATA_T *)&src1[8 * IC_BLOCK])); 286#if VER_32MB16C 287 DATA8_T A2 = AS_DATA8_T(BLOCK_READ8( 288 (const __global BLOCK_DATA_T *)&src1[16 * IC_BLOCK])); 289 DATA8_T A3 = AS_DATA8_T(BLOCK_READ8( 290 (const __global BLOCK_DATA_T *)&src1[24 * IC_BLOCK])); 291#endif 292 DATA_T B0 = AS_DATA_T( 293 BLOCK_READ((const __global BLOCK_DATA_T *)(wei1))); 294 295 S00 = fma(A0, (DATA8_T)B0, S00); 296 S01 = fma(A1, (DATA8_T)B0, S01); 297#if VER_32MB16C 298 S02 = fma(A2, (DATA8_T)B0, S02); 299 S03 = fma(A3, (DATA8_T)B0, S03); 300#endif 301#if KH != 1 || KW != 1 || KD != 1 302 } 303#endif 304 305 DATA8_T D00; 306 DATA8_T D01; 307#if VER_32MB16C 308 DATA8_T D02; 309 DATA8_T D03; 310#endif 311#if WITH_SUM 312#ifdef DST_DT_S8 313 for (int i = 0; i < 8; ++i) { 314 D00[i] = CONVERT_DATA_T( 315 BLOCK_READ_DST((__global DST_DATA_T *)&dst[i * 32])); 316 D01[i] = CONVERT_DATA_T( 317 BLOCK_READ_DST((__global DST_DATA_T *)&dst[(i * 32) + 256])); 318#if VER_32MB16C 319 D02[i] = CONVERT_DATA_T( 320 BLOCK_READ_DST((__global DST_DATA_T *)&dst[i * 32] + 512)); 321 D03[i] = CONVERT_DATA_T( 322 BLOCK_READ_DST((__global DST_DATA_T *)&dst[(i * 32) + 768])); 323#endif 324 } 325#else 326 D00 = AS_DATA8_T(BLOCK_READ8((const __global BLOCK_DATA_T *)dst)); 327 D01 = AS_DATA8_T( 328 BLOCK_READ8((const __global BLOCK_DATA_T *)&dst[8 * OC_BLOCK])); 329#if VER_32MB16C 330 D02 = AS_DATA8_T( 331 BLOCK_READ8((const __global BLOCK_DATA_T *)&dst[16 * OC_BLOCK])); 332 D03 = AS_DATA8_T( 333 BLOCK_READ8((const __global BLOCK_DATA_T *)&dst[24 * OC_BLOCK])); 334#endif 335#endif 336#endif 337 338 APPLY_POST_OPS_COMMON(8, S00, D00, 0); 339 APPLY_POST_OPS_COMMON(8, S01, D01, 8); 340#if VER_32MB16C 341 APPLY_POST_OPS_COMMON(8, S02, D02, 16); 342 APPLY_POST_OPS_COMMON(8, S03, D03, 24); 343#endif 344 345#ifdef DST_DT_S8 346 for (int i = 0; i < 8; ++i) { 347 BLOCK_WRITE_DST((__global DST_DATA_T *)&dst[i * DST_OC_BLOCK], 348 CONVERT_DST_DATA_T(S00[i])); 349 BLOCK_WRITE_DST((__global DST_DATA_T *)&dst[(i + 8) * DST_OC_BLOCK], 350 CONVERT_DST_DATA_T(S01[i])); 351#if VER_32MB16C 352 BLOCK_WRITE_DST((__global DST_DATA_T *)&dst[(i + 16) * DST_OC_BLOCK], 353 CONVERT_DST_DATA_T(S02[i])); 354 BLOCK_WRITE_DST((__global DST_DATA_T *)&dst[(i + 24) * DST_OC_BLOCK], 355 CONVERT_DST_DATA_T(S03[i])); 356#endif 357 } 358#else 359 BLOCK_WRITE8((__global BLOCK_DATA_T *)&dst[0], AS_UINT8_T(S00)); 360 BLOCK_WRITE8((__global BLOCK_DATA_T *)&dst[8 * OC_BLOCK], AS_UINT8_T(S01)); 361#if VER_32MB16C 362 BLOCK_WRITE8((__global BLOCK_DATA_T *)&dst[16 * OC_BLOCK], AS_UINT8_T(S02)); 363 BLOCK_WRITE8((__global BLOCK_DATA_T *)&dst[24 * OC_BLOCK], AS_UINT8_T(S03)); 364#endif 365#endif 366 367#endif 368 return; 369} 370