// Buffer-based 2D convolution kernels.
//
// All feature maps are stored channel-blocked: element (b, c, h, w) lives at
//   (((b * C4 + c / 4) * H + h) * W + w) * 4 + c % 4
// where C4 is the number of 4-channel blocks, so one FLOAT4 holds the four
// channels of one block at one pixel.
//
// Names supplied by the host at build time (not defined in this file):
//   FLOAT, FLOAT4     element scalar / 4-vector type.
//   IN_C_BLOCK        trip count of the input-channel-block loops.
//                     NOTE(review): presumably equal to the runtime
//                     in_c_blocks / in_c_block argument, which is still used for
//                     address math - confirm on the host side.
//   RELU, RELU6       optional fused activation: fmax(x, 0) / clamp(x, 0, 6).
//   MNN_SUPPORT_FP16  enables the cl_khr_fp16 extension.
//
// Every kernel takes its logical 2-D global size as the first two arguments
// (GLOBAL_SIZE_2_DIMS); work-items at or beyond that size exit immediately, so
// the launched range may be larger than the logical one.

#ifdef MNN_SUPPORT_FP16
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#endif

// Loads input pixel in##i at width position in_width##i + base, yielding zero
// when that position lies outside [0, input_shape.y). Not referenced by any
// kernel in this file; it relies on in_width##i, in_idx, input_shape, in_w_idx,
// inp_offset, in_b_idx, in_channel_block_length, in_channel_block_idx,
// in_h_idx and input being in scope where it is expanded.
#define READ_INPUT_IMAGE(i, base) \
    int in_width_value##i = in_width##i + base; \
    in_width_value##i = \
        select(in_idx + in_width_value##i, -1, (in_width_value##i < 0 || in_width_value##i >= input_shape.y)); \
    in_w_idx = in_width_value##i % input_shape.y; \
    inp_offset = (((in_b_idx*in_channel_block_length + in_channel_block_idx)*input_shape.x + in_h_idx)* input_shape.y + in_w_idx)*4; \
    in##i = (in_width_value##i)==-1 ? (FLOAT4)0 : vload4(0, input+inp_offset);

// out##i += in##i.x*weights0 + in##i.y*weights1 + in##i.z*weights2 + in##i.w*weights3.
// Not referenced by any kernel in this file.
#define CALCULATE_OUTPUT(i) \
    out##i = mad(in##i.x, weights0, out##i); \
    out##i = mad(in##i.y, weights1, out##i); \
    out##i = mad(in##i.z, weights2, out##i); \
    out##i = mad(in##i.w, weights3, out##i);

// Leading kernel parameters carrying the logical 2-D global size.
#define GLOBAL_SIZE_2_DIMS __private const int global_size_dim0, __private const int global_size_dim1,

// Returns early from work-items outside the logical global size.
#define DEAL_NON_UNIFORM_DIM2(input1, input2) \
    if ((input1) >= global_size_dim0 || (input2) >= global_size_dim1) { \
        return; \
    }

// General convolution (any filter size, stride, padding and dilation).
// Each work-item computes one 4-channel output block at one output pixel.
//   dim0 = out_c_idx * out_hw.y + out_w_idx
//   dim1 = out_b_idx * out_hw.x + out_h_idx
// All int2 arguments are (height, width) pairs. Filter taps that fall outside
// the input are skipped by clipping the tap range, so padding contributes 0.
// weight is indexed as
//   ((((4 * ic4 + k) * out_c_blocks + oc4) * filter_h + fy) * filter_w + fx) * 4 + lane
// i.e. for input channel 4*ic4+k there is one FLOAT4 per (oc4, fy, fx) holding
// the four output channels of block oc4.
// inChannel and out_w_blocks are not used by this kernel.
__kernel
void conv_2d_c4h1w1(GLOBAL_SIZE_2_DIMS
                    __global const FLOAT *input,
                    __global const FLOAT *weight,
                    __global const FLOAT *bias,
                    __global FLOAT *output,
                    __private const int2 in_hw,
                    __private const int inChannel,
                    __private const int in_c_blocks,
                    __private const int2 out_hw,
                    __private const int2 filter_hw,
                    __private const int2 stride_hw,
                    __private const int2 pad_hw,
                    __private const int2 dilate_hw,
                    __private const int out_w_blocks,
                    __private const int out_c_blocks) {
    const int out_c_w_idx = get_global_id(0); // c/4, w
    const int out_b_h_idx = get_global_id(1); // b, h

    DEAL_NON_UNIFORM_DIM2(out_c_w_idx, out_b_h_idx);

    const int out_c_idx = out_c_w_idx / out_hw.y;
    const int out_w_idx = out_c_w_idx % out_hw.y;
    const int out_b_idx = out_b_h_idx / out_hw.x; // equal to in_b_idx
    const int out_h_idx = out_b_h_idx % out_hw.x;

    FLOAT4 out0 = vload4(out_c_idx, bias);

    // Input coordinate of filter tap (0, 0); may be negative inside the padding.
    const int in_w_idx_base = mad24(out_w_idx, stride_hw.y, -pad_hw.y);
    const int in_h_idx_base = mad24(out_h_idx, stride_hw.x, -pad_hw.x);

    // First tap that lands at a non-negative input coordinate (ceil division).
    const int kw_start = select(0, (-in_w_idx_base + dilate_hw.y - 1) / dilate_hw.y, in_w_idx_base < 0);
    const int kh_start = select(0, (-in_h_idx_base + dilate_hw.x - 1) / dilate_hw.x, in_h_idx_base < 0);

    // Input range [start, end) covered by in-bounds taps.
    const int in_w_idx_start = mad24(kw_start, dilate_hw.y, in_w_idx_base);
    const int in_w_idx_end = min(mad24(filter_hw.y, dilate_hw.y, in_w_idx_base), in_hw.y);

    const int in_h_idx_start = mad24(kh_start, dilate_hw.x, in_h_idx_base);
    const int in_h_idx_end = min(mad24(filter_hw.x, dilate_hw.x, in_h_idx_base), in_hw.x);

    // Distance between the weight planes of consecutive input channels.
    const int weight_oc_offset = out_c_blocks * filter_hw.x * filter_hw.y * 4;
    for (ushort in_c_idx = 0; in_c_idx < (ushort)IN_C_BLOCK; in_c_idx++) {
        // Weight of (input channel 4*in_c_idx, out_c_idx, kh_start, kw_start).
        int weight_offset = ((((4*in_c_idx+0)*out_c_blocks + out_c_idx)*filter_hw.x + kh_start)*filter_hw.y + kw_start) * 4;
        for (int iy = in_h_idx_start; iy < in_h_idx_end; iy += dilate_hw.x) {
            for (int ix = in_w_idx_start; ix < in_w_idx_end; ix += dilate_hw.y) {
                const int inp_offset = (((out_b_idx * in_c_blocks + in_c_idx) * in_hw.x + iy) * in_hw.y + ix) * 4;
                FLOAT4 in0 = vload4(0, input + inp_offset);

                // Tap column relative to kw_start.
                const int filter_w_inc = (ix - in_w_idx_start) / dilate_hw.y;
                FLOAT4 weight0 = vload4(filter_w_inc, weight + weight_offset);
                FLOAT4 weight1 = vload4(filter_w_inc, weight + weight_offset + weight_oc_offset);
                FLOAT4 weight2 = vload4(filter_w_inc, weight + weight_offset + weight_oc_offset*2);
                FLOAT4 weight3 = vload4(filter_w_inc, weight + weight_offset + weight_oc_offset*3);

                out0 = mad(in0.x, weight0, out0);
                out0 = mad(in0.y, weight1, out0);
                out0 = mad(in0.z, weight2, out0);
                out0 = mad(in0.w, weight3, out0);
            }
            // Next filter row.
            weight_offset += 4*filter_hw.y;
        }
    }
#ifdef RELU
    out0 = fmax(out0, (FLOAT4)0);
#endif

#ifdef RELU6
    out0 = clamp(out0, (FLOAT4)0, (FLOAT4)6);
#endif

    const int out_offset = (((out_b_idx*out_c_blocks + out_c_idx)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
    vstore4(out0, 0, output + out_offset);
}

// Same as conv_2d_c4h1w1, but each work-item computes two consecutive output
// channel blocks (out_c_idx = 2 * (dim0 / out_hw.y) and out_c_idx + 1) from a
// single read of each input pixel.
//   dim0 = (out_c_idx / 2) * out_hw.y + out_w_idx
//   dim1 = out_b_idx * out_hw.x + out_h_idx
// The second block is stored only when out_c_idx + 1 < out_c_blocks.
// NOTE(review): bias and weights of the second block are loaded
// unconditionally; when out_c_blocks is odd this reads one block past the
// logical end - presumably the host pads those buffers; confirm.
// inChannel and out_w_blocks are not used by this kernel.
__kernel
void conv_2d_c8h1w1(GLOBAL_SIZE_2_DIMS
                    __global const FLOAT *input,
                    __global const FLOAT *weight,
                    __global const FLOAT *bias,
                    __global FLOAT *output,
                    __private const int2 in_hw,
                    __private const int inChannel,
                    __private const int in_c_blocks,
                    __private const int2 out_hw,
                    __private const int2 filter_hw,
                    __private const int2 stride_hw,
                    __private const int2 pad_hw,
                    __private const int2 dilate_hw,
                    __private const int out_w_blocks,
                    __private const int out_c_blocks) {
    const int out_c_w_idx = get_global_id(0); // c/8, w
    const int out_b_h_idx = get_global_id(1); // b, h

    DEAL_NON_UNIFORM_DIM2(out_c_w_idx, out_b_h_idx);

    const int out_c_idx = (out_c_w_idx / out_hw.y) << 1;
    const int out_w_idx = out_c_w_idx % out_hw.y;
    const int out_b_idx = out_b_h_idx / out_hw.x; // equal to in_b_idx
    const int out_h_idx = out_b_h_idx % out_hw.x;

    FLOAT4 out0 = vload4(out_c_idx, bias);
    FLOAT4 out1 = vload4(out_c_idx+1, bias);

    const int in_w_idx_base = mad24(out_w_idx, stride_hw.y, -pad_hw.y);
    const int in_h_idx_base = mad24(out_h_idx, stride_hw.x, -pad_hw.x);

    const int kw_start = select(0, (-in_w_idx_base + dilate_hw.y - 1) / dilate_hw.y, in_w_idx_base < 0);
    const int kh_start = select(0, (-in_h_idx_base + dilate_hw.x - 1) / dilate_hw.x, in_h_idx_base < 0);

    const int in_w_idx_start = mad24(kw_start, dilate_hw.y, in_w_idx_base);
    const int in_w_idx_end = min(mad24(filter_hw.y, dilate_hw.y, in_w_idx_base), in_hw.y);

    const int in_h_idx_start = mad24(kh_start, dilate_hw.x, in_h_idx_base);
    const int in_h_idx_end = min(mad24(filter_hw.x, dilate_hw.x, in_h_idx_base), in_hw.x);

    // Distance between consecutive output channel blocks / input channels.
    const int weight_oc_offset = filter_hw.x * filter_hw.y * 4;
    const int weight_ic_offset = out_c_blocks * weight_oc_offset;
    for (ushort in_c_idx = 0; in_c_idx < (ushort)IN_C_BLOCK; in_c_idx++) {
        int weight_offset = ((((4*in_c_idx+0)*out_c_blocks + out_c_idx)*filter_hw.x + kh_start)*filter_hw.y + kw_start) * 4;
        for (int iy = in_h_idx_start; iy < in_h_idx_end; iy += dilate_hw.x) {
            for (int ix = in_w_idx_start; ix < in_w_idx_end; ix += dilate_hw.y) {
                const int inp_offset = (((out_b_idx * in_c_blocks + in_c_idx) * in_hw.x + iy) * in_hw.y + ix) * 4;
                FLOAT4 in0 = vload4(0, input + inp_offset);

                const int filter_w_inc = (ix - in_w_idx_start) / dilate_hw.y;
                // Weights of output block out_c_idx.
                FLOAT4 weight0 = vload4(filter_w_inc, weight + weight_offset);
                FLOAT4 weight1 = vload4(filter_w_inc, weight + weight_offset + weight_ic_offset);
                FLOAT4 weight2 = vload4(filter_w_inc, weight + weight_offset + weight_ic_offset*2);
                FLOAT4 weight3 = vload4(filter_w_inc, weight + weight_offset + weight_ic_offset*3);

                out0 = mad(in0.x, weight0, out0);
                out0 = mad(in0.y, weight1, out0);
                out0 = mad(in0.z, weight2, out0);
                out0 = mad(in0.w, weight3, out0);

                // Weights of output block out_c_idx + 1.
                weight0 = vload4(filter_w_inc, weight + weight_offset + weight_oc_offset);
                weight1 = vload4(filter_w_inc, weight + weight_offset + weight_oc_offset + weight_ic_offset);
                weight2 = vload4(filter_w_inc, weight + weight_offset + weight_oc_offset + weight_ic_offset*2);
                weight3 = vload4(filter_w_inc, weight + weight_offset + weight_oc_offset + weight_ic_offset*3);

                out1 = mad(in0.x, weight0, out1);
                out1 = mad(in0.y, weight1, out1);
                out1 = mad(in0.z, weight2, out1);
                out1 = mad(in0.w, weight3, out1);
            }
            weight_offset += 4*filter_hw.y;
        }
    }
#ifdef RELU
    out0 = fmax(out0, (FLOAT4)0);
    out1 = fmax(out1, (FLOAT4)0);
#endif

#ifdef RELU6
    out0 = clamp(out0, (FLOAT4)0, (FLOAT4)6);
    out1 = clamp(out1, (FLOAT4)0, (FLOAT4)6);
#endif

    const int out_offset = (((out_b_idx*out_c_blocks + out_c_idx)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
    vstore4(out0, 0, output + out_offset);
    if (out_c_idx+1 >= out_c_blocks) return;
    vstore4(out1, 0, output + out_offset + out_hw.x*out_hw.y*4);
}

// General convolution computing one output channel block for two horizontally
// adjacent output pixels (out_w_idx and out_w_idx + 1).
//   dim0 = out_c_idx * out_w_blocks + out_w_idx / 2   (out_w_blocks = number of 2-wide groups)
//   dim1 = out_b_idx * out_hw.x + out_h_idx
// Filter rows outside the input are clipped; filter columns outside the input
// read as zero. The second pixel is stored only when out_w_idx + 1 < out_hw.y.
// Weight layout is the same as in conv_2d_c4h1w1. inChannel is not used.
__kernel
void conv_2d_c4h1w2(GLOBAL_SIZE_2_DIMS
                    __global const FLOAT *input,
                    __global const FLOAT *weight,
                    __global const FLOAT *bias,
                    __global FLOAT *output,
                    __private const int2 in_hw,
                    __private const int inChannel,
                    __private const int in_c_blocks,
                    __private const int2 out_hw,
                    __private const int2 filter_hw,
                    __private const int2 stride_hw,
                    __private const int2 pad_hw,
                    __private const int2 dilate_hw,
                    __private const int out_w_blocks, // number of 2-wide width groups
                    __private const int out_c_blocks) {
    const int out_c_w_idx = get_global_id(0); // c/4, w/2
    const int out_b_h_idx = get_global_id(1); // b, h

    DEAL_NON_UNIFORM_DIM2(out_c_w_idx, out_b_h_idx);

    const int out_c_idx = out_c_w_idx / out_w_blocks;
    const int out_w_idx = (out_c_w_idx % out_w_blocks) << 1;
    const int out_b_idx = out_b_h_idx / out_hw.x; // equal to in_b_idx
    const int out_h_idx = out_b_h_idx % out_hw.x;

    FLOAT4 out0 = vload4(out_c_idx, bias);
    FLOAT4 out1 = out0;

    const int in_w0_idx_base = mad24(out_w_idx, stride_hw.y, -pad_hw.y);
    const int in_w1_idx_base = in_w0_idx_base + stride_hw.y;

    const int in_h_idx_base = mad24(out_h_idx, stride_hw.x, -pad_hw.x);

    const int kh_start = select(0, (-in_h_idx_base + dilate_hw.x - 1) / dilate_hw.x, in_h_idx_base < 0);
    const int in_h_idx_start = mad24(kh_start, dilate_hw.x, in_h_idx_base);
    const int in_h_idx_end = min(mad24(filter_hw.x, dilate_hw.x, in_h_idx_base), in_hw.x);

    const int weight_oc_offset = out_c_blocks * filter_hw.x * filter_hw.y * 4;
    for (ushort in_c_idx = 0; in_c_idx < (ushort)IN_C_BLOCK; in_c_idx++) {
        // Weight of (input channel 4*in_c_idx, out_c_idx, kh_start, 0); advanced
        // by one tap per inner iteration, i.e. one full row per iy step.
        int weight_offset = ((((4*in_c_idx+0)*out_c_blocks + out_c_idx)*filter_hw.x + kh_start)*filter_hw.y + 0) * 4;

        for (int iy = in_h_idx_start; iy < in_h_idx_end; iy += dilate_hw.x) {
            // Start of input row iy of this channel block.
            const int inp_offset_base = (((out_b_idx * in_c_blocks + in_c_idx) * in_hw.x + iy) * in_hw.y + 0) * 4;

            for (int fw = 0; fw < filter_hw.y; fw++) {
                const int in_w0_idx = fw * dilate_hw.y + in_w0_idx_base;
                const int in_w1_idx = fw * dilate_hw.y + in_w1_idx_base;

                FLOAT4 in0 = (in_w0_idx < 0 || in_w0_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w0_idx, input + inp_offset_base);
                FLOAT4 in1 = (in_w1_idx < 0 || in_w1_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w1_idx, input + inp_offset_base);

                FLOAT4 weight0 = vload4(0, weight + weight_offset);
                FLOAT4 weight1 = vload4(0, weight + weight_offset + weight_oc_offset);
                FLOAT4 weight2 = vload4(0, weight + weight_offset + weight_oc_offset*2);
                FLOAT4 weight3 = vload4(0, weight + weight_offset + weight_oc_offset*3);

                out0 = mad(in0.x, weight0, out0);
                out0 = mad(in0.y, weight1, out0);
                out0 = mad(in0.z, weight2, out0);
                out0 = mad(in0.w, weight3, out0);

                out1 = mad(in1.x, weight0, out1);
                out1 = mad(in1.y, weight1, out1);
                out1 = mad(in1.z, weight2, out1);
                out1 = mad(in1.w, weight3, out1);

                weight_offset += 4;
            }
        }
    }
#ifdef RELU
    out0 = fmax(out0, (FLOAT4)0);
    out1 = fmax(out1, (FLOAT4)0);
#endif

#ifdef RELU6
    out0 = clamp(out0, (FLOAT4)0, (FLOAT4)6);
    out1 = clamp(out1, (FLOAT4)0, (FLOAT4)6);
#endif

    const int out_offset = (((out_b_idx*out_c_blocks + out_c_idx)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
    vstore4(out0, 0, output + out_offset);
    if (out_w_idx + 1 >= out_hw.y) return;
    vstore4(out1, 1, output + out_offset);
}

// General convolution computing one output channel block for four horizontally
// adjacent output pixels (out_w_idx .. out_w_idx + 3).
//   dim0 = out_c_idx * out_w_blocks + out_w_idx / 4   (out_w_blocks = number of 4-wide groups)
//   dim1 = out_b_idx * out_hw.x + out_h_idx
// Filter rows outside the input are clipped; filter columns outside the input
// read as zero. Pixels at or beyond out_hw.y are not stored.
// Weight layout is the same as in conv_2d_c4h1w1. inChannel is not used.
__kernel
void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS
                    __global const FLOAT *input,
                    __global const FLOAT *weight,
                    __global const FLOAT *bias,
                    __global FLOAT *output,
                    __private const int2 in_hw,
                    __private const int inChannel,
                    __private const int in_c_blocks,
                    __private const int2 out_hw,
                    __private const int2 filter_hw,
                    __private const int2 stride_hw,
                    __private const int2 pad_hw,
                    __private const int2 dilate_hw,
                    __private const int out_w_blocks,
                    __private const int out_c_blocks) {
    const int out_c_w_idx = get_global_id(0); // c/4, w/4
    const int out_b_h_idx = get_global_id(1); // b, h

    DEAL_NON_UNIFORM_DIM2(out_c_w_idx, out_b_h_idx);

    const int out_c_idx = out_c_w_idx / out_w_blocks;
    const int out_w_idx = (out_c_w_idx % out_w_blocks) << 2;
    const int out_b_idx = out_b_h_idx / out_hw.x; // equal to in_b_idx
    const int out_h_idx = out_b_h_idx % out_hw.x;

    FLOAT4 out0 = vload4(out_c_idx, bias);
    FLOAT4 out1 = out0;
    FLOAT4 out2 = out0;
    FLOAT4 out3 = out0;

    const int in_w0_idx_base = mad24(out_w_idx, stride_hw.y, -pad_hw.y);
    const int in_w1_idx_base = in_w0_idx_base + stride_hw.y;
    const int in_w2_idx_base = in_w1_idx_base + stride_hw.y;
    const int in_w3_idx_base = in_w2_idx_base + stride_hw.y;

    const int in_h_idx_base = mad24(out_h_idx, stride_hw.x, -pad_hw.x);

    const int kh_start = select(0, (-in_h_idx_base + dilate_hw.x - 1) / dilate_hw.x, in_h_idx_base < 0);
    const int in_h_idx_start = mad24(kh_start, dilate_hw.x, in_h_idx_base);
    const int in_h_idx_end = min(mad24(filter_hw.x, dilate_hw.x, in_h_idx_base), in_hw.x);

    const int weight_oc_offset = out_c_blocks * filter_hw.x * filter_hw.y * 4;
    for (ushort in_c_idx = 0; in_c_idx < (ushort)IN_C_BLOCK; in_c_idx++) {
        int weight_offset = ((((4*in_c_idx+0)*out_c_blocks + out_c_idx)*filter_hw.x + kh_start)*filter_hw.y + 0) * 4;

        for (int iy = in_h_idx_start; iy < in_h_idx_end; iy += dilate_hw.x) {
            const int inp_offset_base = (((out_b_idx * in_c_blocks + in_c_idx) * in_hw.x + iy) * in_hw.y + 0) * 4;

            for (int fw = 0; fw < filter_hw.y; fw++) {
                const int in_w0_idx = fw * dilate_hw.y + in_w0_idx_base;
                const int in_w1_idx = fw * dilate_hw.y + in_w1_idx_base;
                const int in_w2_idx = fw * dilate_hw.y + in_w2_idx_base;
                const int in_w3_idx = fw * dilate_hw.y + in_w3_idx_base;

                FLOAT4 in0 = (in_w0_idx < 0 || in_w0_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w0_idx, input + inp_offset_base);
                FLOAT4 in1 = (in_w1_idx < 0 || in_w1_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w1_idx, input + inp_offset_base);
                FLOAT4 in2 = (in_w2_idx < 0 || in_w2_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w2_idx, input + inp_offset_base);
                FLOAT4 in3 = (in_w3_idx < 0 || in_w3_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w3_idx, input + inp_offset_base);

                FLOAT4 weight0 = vload4(0, weight + weight_offset);
                FLOAT4 weight1 = vload4(0, weight + weight_offset + weight_oc_offset);
                FLOAT4 weight2 = vload4(0, weight + weight_offset + weight_oc_offset*2);
                FLOAT4 weight3 = vload4(0, weight + weight_offset + weight_oc_offset*3);

                out0 = mad(in0.x, weight0, out0);
                out0 = mad(in0.y, weight1, out0);
                out0 = mad(in0.z, weight2, out0);
                out0 = mad(in0.w, weight3, out0);

                out1 = mad(in1.x, weight0, out1);
                out1 = mad(in1.y, weight1, out1);
                out1 = mad(in1.z, weight2, out1);
                out1 = mad(in1.w, weight3, out1);

                out2 = mad(in2.x, weight0, out2);
                out2 = mad(in2.y, weight1, out2);
                out2 = mad(in2.z, weight2, out2);
                out2 = mad(in2.w, weight3, out2);

                out3 = mad(in3.x, weight0, out3);
                out3 = mad(in3.y, weight1, out3);
                out3 = mad(in3.z, weight2, out3);
                out3 = mad(in3.w, weight3, out3);

                weight_offset += 4;
            }
        }
    }
#ifdef RELU
    out0 = fmax(out0, (FLOAT4)0);
    out1 = fmax(out1, (FLOAT4)0);
    out2 = fmax(out2, (FLOAT4)0);
    out3 = fmax(out3, (FLOAT4)0);
#endif

#ifdef RELU6
    out0 = clamp(out0, (FLOAT4)0, (FLOAT4)6);
    out1 = clamp(out1, (FLOAT4)0, (FLOAT4)6);
    out2 = clamp(out2, (FLOAT4)0, (FLOAT4)6);
    out3 = clamp(out3, (FLOAT4)0, (FLOAT4)6);
#endif

    const int out_offset = (((out_b_idx*out_c_blocks + out_c_idx)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
    vstore4(out0, 0, output + out_offset);
    if (out_w_idx + 1 >= out_hw.y) return;
    vstore4(out1, 1, output + out_offset);
    if (out_w_idx + 2 >= out_hw.y) return;
    vstore4(out2, 2, output + out_offset);
    if (out_w_idx + 3 >= out_hw.y) return;
    vstore4(out3, 3, output + out_offset);
}

// 1x1 convolution. The input is read at the same (b, h, w) as the output, so
// input and output share the out_h x out_w spatial size (stride 1, no padding).
// Each work-item computes one 4-channel output block for four adjacent pixels.
//   dim0 = out_c_idx * out_w_blocks + out_w_idx   (out_w_blocks = number of 4-wide groups)
//   dim1 = out_b_idx * out_h + out_h_idx
// kernel_ptr, in FLOAT4 units, is indexed as (oc4 * in_c_block + ic4) * 4 + j:
// row j holds the weights of output channel 4*oc4 + j for the four input
// channels of block ic4 (combined with dot()).
// NOTE(review): all four input pixels are loaded even when fewer than four
// remain in the row (only the stores are guarded); for the last row of the
// buffer this reads past the logical end - presumably padded; confirm.
__kernel
void conv_2d_1x1_c4h1w4(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks,
                        __global const FLOAT *input,
                        __global const FLOAT *kernel_ptr,
                        __global const FLOAT *bias_ptr,
                        __global FLOAT *output,
                        __private const int in_c_block,
                        __private const int out_h,
                        __private const int out_w,
                        __private const int out_c_block) {
    const int out_c_w_idx = get_global_id(0); // c/4, w/4
    const int out_b_h_idx = get_global_id(1); // b, h

    DEAL_NON_UNIFORM_DIM2(out_c_w_idx, out_b_h_idx);

    const int out_c_idx = out_c_w_idx / out_w_blocks;
    const int out_w_idx = out_c_w_idx % out_w_blocks;
    const int out_b_idx = out_b_h_idx / out_h; // equal to in_b_idx
    const int out_h_idx = out_b_h_idx % out_h; // equal to in_h_idx

    const int out_w4_idx = mul24(out_w_idx, 4);
    FLOAT4 out0 = vload4(out_c_idx, bias_ptr);
    FLOAT4 out1 = out0;
    FLOAT4 out2 = out0;
    FLOAT4 out3 = out0;

    // Weight row (FLOAT4 units) of (out_c_idx, input block 0); 4 rows per input block.
    int offset = mul24(out_c_idx, in_c_block) << 2;
    // Input element offset of (b, input block 0, h, out_w4_idx); advanced by one
    // channel-block plane per input block.
    int inp_offset = (((out_b_idx*in_c_block)*out_h + out_h_idx)*out_w + out_w4_idx) << 2;

    const int inp_add = out_h*out_w*4;
    for (ushort in_channel_block_idx = 0; in_channel_block_idx < (ushort)IN_C_BLOCK; ++in_channel_block_idx) {
        FLOAT4 in0 = vload4(0, input + inp_offset);
        FLOAT4 in1 = vload4(1, input + inp_offset);
        FLOAT4 in2 = vload4(2, input + inp_offset);
        FLOAT4 in3 = vload4(3, input + inp_offset);

        FLOAT4 weights0 = vload4(offset, kernel_ptr);
        FLOAT4 weights1 = vload4(offset + 1, kernel_ptr);
        FLOAT4 weights2 = vload4(offset + 2, kernel_ptr);
        FLOAT4 weights3 = vload4(offset + 3, kernel_ptr);

        out0.x += dot(weights0, in0);
        out0.y += dot(weights1, in0);
        out0.z += dot(weights2, in0);
        out0.w += dot(weights3, in0);

        out1.x += dot(weights0, in1);
        out1.y += dot(weights1, in1);
        out1.z += dot(weights2, in1);
        out1.w += dot(weights3, in1);

        out2.x += dot(weights0, in2);
        out2.y += dot(weights1, in2);
        out2.z += dot(weights2, in2);
        out2.w += dot(weights3, in2);

        out3.x += dot(weights0, in3);
        out3.y += dot(weights1, in3);
        out3.z += dot(weights2, in3);
        out3.w += dot(weights3, in3);

        offset += 4;
        inp_offset += inp_add;
    }

#ifdef RELU
    out0 = fmax(out0, (FLOAT4)0);
    out1 = fmax(out1, (FLOAT4)0);
    out2 = fmax(out2, (FLOAT4)0);
    out3 = fmax(out3, (FLOAT4)0);
#endif

#ifdef RELU6
    out0 = clamp(out0, (FLOAT4)0, (FLOAT4)6);
    out1 = clamp(out1, (FLOAT4)0, (FLOAT4)6);
    out2 = clamp(out2, (FLOAT4)0, (FLOAT4)6);
    out3 = clamp(out3, (FLOAT4)0, (FLOAT4)6);
#endif

    const int out_offset = (((out_b_idx*out_c_block + out_c_idx)*out_h + out_h_idx)*out_w + out_w4_idx)*4;

    // Number of valid pixels left in this row; only those are stored.
    const int remain = out_w - out_w4_idx;

    if (remain >= 4) {
        vstore4(out0, 0, output + out_offset);
        vstore4(out1, 1, output + out_offset);
        vstore4(out2, 2, output + out_offset);
        vstore4(out3, 3, output + out_offset);
    } else if (remain == 3) {
        vstore4(out0, 0, output + out_offset);
        vstore4(out1, 1, output + out_offset);
        vstore4(out2, 2, output + out_offset);
    } else if (remain == 2) {
        vstore4(out0, 0, output + out_offset);
        vstore4(out1, 1, output + out_offset);
    } else if (remain == 1) {
        vstore4(out0, 0, output + out_offset);
    }
}

// 1x1 convolution (same spatial assumptions as conv_2d_1x1_c4h1w4) computing
// two consecutive output channel blocks (2*out_c_idx and 2*out_c_idx + 1) for
// four adjacent pixels.
//   dim0 = out_c_idx * out_w_blocks + out_w_idx   (out_c_idx counts pairs of blocks)
//   dim1 = out_b_idx * out_h + out_h_idx
// kernel_ptr, in FLOAT4 units, is indexed as (oc8 * in_c_block + ic4) * 8 + j:
// rows 0-3 feed block 2*oc8, rows 4-7 feed block 2*oc8 + 1.
// The second block is stored only when 2*out_c_idx + 1 < out_c_block.
// NOTE(review): bias/weights of the second block and the four input pixels are
// loaded unconditionally - presumably buffers are padded; confirm.
__kernel
void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks,
                        __global const FLOAT *input,
                        __global const FLOAT *kernel_ptr,
                        __global const FLOAT *bias_ptr,
                        __global FLOAT *output,
                        __private const int in_c_block,
                        __private const int out_h,
                        __private const int out_w,
                        __private const int out_c_block) {
    const int out_c_w_idx = get_global_id(0); // c/8, w/4
    const int out_b_h_idx = get_global_id(1); // b, h

    DEAL_NON_UNIFORM_DIM2(out_c_w_idx, out_b_h_idx);

    const int out_c_idx = out_c_w_idx / out_w_blocks;
    const int out_w_idx = out_c_w_idx % out_w_blocks;
    const int out_b_idx = out_b_h_idx / out_h; // equal to in_b_idx
    const int out_h_idx = out_b_h_idx % out_h; // equal to in_h_idx

    const int out_w4_idx = mul24(out_w_idx, 4);
    FLOAT4 out0 = vload4(out_c_idx<<1, bias_ptr);
    FLOAT4 out1 = out0;
    FLOAT4 out2 = out0;
    FLOAT4 out3 = out0;

    FLOAT4 out4 = vload4((out_c_idx<<1)+1, bias_ptr);
    FLOAT4 out5 = out4;
    FLOAT4 out6 = out4;
    FLOAT4 out7 = out4;

    for (int in_channel_block_idx = 0; in_channel_block_idx < IN_C_BLOCK; ++in_channel_block_idx) {
        const int offset = mad24(out_c_idx, in_c_block, in_channel_block_idx)*8;
        const int inp_offset =
            (((out_b_idx*in_c_block + in_channel_block_idx)*out_h + out_h_idx)*out_w + out_w4_idx)*4;

        FLOAT4 in0 = vload4(0, input + inp_offset);
        FLOAT4 in1 = vload4(1, input + inp_offset);
        FLOAT4 in2 = vload4(2, input + inp_offset);
        FLOAT4 in3 = vload4(3, input + inp_offset);

        FLOAT4 weights0 = vload4(offset, kernel_ptr);
        FLOAT4 weights1 = vload4(offset + 1, kernel_ptr);
        FLOAT4 weights2 = vload4(offset + 2, kernel_ptr);
        FLOAT4 weights3 = vload4(offset + 3, kernel_ptr);
        FLOAT4 weights4 = vload4(offset + 4, kernel_ptr);
        FLOAT4 weights5 = vload4(offset + 5, kernel_ptr);
        FLOAT4 weights6 = vload4(offset + 6, kernel_ptr);
        FLOAT4 weights7 = vload4(offset + 7, kernel_ptr);

        out0.x += dot(weights0, in0);
        out0.y += dot(weights1, in0);
        out0.z += dot(weights2, in0);
        out0.w += dot(weights3, in0);

        out1.x += dot(weights0, in1);
        out1.y += dot(weights1, in1);
        out1.z += dot(weights2, in1);
        out1.w += dot(weights3, in1);

        out2.x += dot(weights0, in2);
        out2.y += dot(weights1, in2);
        out2.z += dot(weights2, in2);
        out2.w += dot(weights3, in2);

        out3.x += dot(weights0, in3);
        out3.y += dot(weights1, in3);
        out3.z += dot(weights2, in3);
        out3.w += dot(weights3, in3);

        out4.x += dot(weights4, in0);
        out4.y += dot(weights5, in0);
        out4.z += dot(weights6, in0);
        out4.w += dot(weights7, in0);

        out5.x += dot(weights4, in1);
        out5.y += dot(weights5, in1);
        out5.z += dot(weights6, in1);
        out5.w += dot(weights7, in1);

        out6.x += dot(weights4, in2);
        out6.y += dot(weights5, in2);
        out6.z += dot(weights6, in2);
        out6.w += dot(weights7, in2);

        out7.x += dot(weights4, in3);
        out7.y += dot(weights5, in3);
        out7.z += dot(weights6, in3);
        out7.w += dot(weights7, in3);
    }

#ifdef RELU
    out0 = fmax(out0, (FLOAT4)0);
    out1 = fmax(out1, (FLOAT4)0);
    out2 = fmax(out2, (FLOAT4)0);
    out3 = fmax(out3, (FLOAT4)0);

    out4 = fmax(out4, (FLOAT4)0);
    out5 = fmax(out5, (FLOAT4)0);
    out6 = fmax(out6, (FLOAT4)0);
    out7 = fmax(out7, (FLOAT4)0);
#endif

#ifdef RELU6
    out0 = clamp(out0, (FLOAT4)0, (FLOAT4)6);
    out1 = clamp(out1, (FLOAT4)0, (FLOAT4)6);
    out2 = clamp(out2, (FLOAT4)0, (FLOAT4)6);
    out3 = clamp(out3, (FLOAT4)0, (FLOAT4)6);

    out4 = clamp(out4, (FLOAT4)0, (FLOAT4)6);
    out5 = clamp(out5, (FLOAT4)0, (FLOAT4)6);
    out6 = clamp(out6, (FLOAT4)0, (FLOAT4)6);
    out7 = clamp(out7, (FLOAT4)0, (FLOAT4)6);
#endif

    const int out_offset = (((out_b_idx*out_c_block + out_c_idx*2)*out_h + out_h_idx)*out_w + out_w4_idx)*4;

    const int remain = out_w - out_w4_idx;

    __global FLOAT *_tempoutput = output + out_offset;
    // Same pixels in the next channel block.
    __global FLOAT *_tempoutput1 = _tempoutput + 4*out_h*out_w;

    if (remain >= 4) {
        vstore4(out0, 0, _tempoutput);
        vstore4(out1, 1, _tempoutput);
        vstore4(out2, 2, _tempoutput);
        vstore4(out3, 3, _tempoutput);
    } else if (remain == 3) {
        vstore4(out0, 0, _tempoutput);
        vstore4(out1, 1, _tempoutput);
        vstore4(out2, 2, _tempoutput);
    } else if (remain == 2) {
        vstore4(out0, 0, _tempoutput);
        vstore4(out1, 1, _tempoutput);
    } else if (remain == 1) {
        vstore4(out0, 0, _tempoutput);
    }
    if (out_c_idx*2+1 >= out_c_block) {
        return;
    }
    if (remain >= 4) {
        vstore4(out4, 0, _tempoutput1);
        vstore4(out5, 1, _tempoutput1);
        vstore4(out6, 2, _tempoutput1);
        vstore4(out7, 3, _tempoutput1);
    } else if (remain == 3) {
        vstore4(out4, 0, _tempoutput1);
        vstore4(out5, 1, _tempoutput1);
        vstore4(out6, 2, _tempoutput1);
    } else if (remain == 2) {
        vstore4(out4, 0, _tempoutput1);
        vstore4(out5, 1, _tempoutput1);
    } else if (remain == 1) {
        vstore4(out4, 0, _tempoutput1);
    }
}

// 1x1 convolution computing two consecutive output channel blocks for two
// adjacent pixels. Weight layout and channel-pair handling are the same as in
// conv_2d_1x1_c8h1w4.
//   dim0 = out_c_idx * out_w_blocks + out_w_idx   (out_w_blocks = number of 2-wide groups)
//   dim1 = out_b_idx * out_h + out_h_idx
// NOTE(review): the second pixel and the second block's bias/weights are
// loaded unconditionally - presumably buffers are padded; confirm.
__kernel
void conv_2d_1x1_c8h1w2(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks,
                        __global const FLOAT *input,
                        __global const FLOAT *kernel_ptr,
                        __global const FLOAT *bias_ptr,
                        __global FLOAT *output,
                        __private const int in_c_block,
                        __private const int out_h,
                        __private const int out_w,
                        __private const int out_c_block) {
    const int out_c_w_idx = get_global_id(0); // c/8, w/2
    const int out_b_h_idx = get_global_id(1); // b, h

    DEAL_NON_UNIFORM_DIM2(out_c_w_idx, out_b_h_idx);

    const int out_c_idx = out_c_w_idx / out_w_blocks;
    const int out_w_idx = out_c_w_idx % out_w_blocks;
    const int out_b_idx = out_b_h_idx / out_h; // equal to in_b_idx
    const int out_h_idx = out_b_h_idx % out_h; // equal to in_h_idx

    const int out_w2_idx = mul24(out_w_idx, 2);
    FLOAT4 out0 = vload4(out_c_idx<<1, bias_ptr);
    FLOAT4 out1 = out0;

    FLOAT4 out4 = vload4((out_c_idx<<1)+1, bias_ptr);
    FLOAT4 out5 = out4;

    for (int in_channel_block_idx = 0; in_channel_block_idx < IN_C_BLOCK; ++in_channel_block_idx) {
        const int offset = mad24(out_c_idx, in_c_block, in_channel_block_idx)*8;
        const int inp_offset =
            (((out_b_idx*in_c_block + in_channel_block_idx)*out_h + out_h_idx)*out_w + out_w2_idx)*4;

        FLOAT4 in0 = vload4(0, input + inp_offset);
        FLOAT4 in1 = vload4(1, input + inp_offset);

        FLOAT4 weights0 = vload4(offset, kernel_ptr);
        FLOAT4 weights1 = vload4(offset + 1, kernel_ptr);
        FLOAT4 weights2 = vload4(offset + 2, kernel_ptr);
        FLOAT4 weights3 = vload4(offset + 3, kernel_ptr);
        FLOAT4 weights4 = vload4(offset + 4, kernel_ptr);
        FLOAT4 weights5 = vload4(offset + 5, kernel_ptr);
        FLOAT4 weights6 = vload4(offset + 6, kernel_ptr);
        FLOAT4 weights7 = vload4(offset + 7, kernel_ptr);

        out0.x += dot(weights0, in0);
        out0.y += dot(weights1, in0);
        out0.z += dot(weights2, in0);
        out0.w += dot(weights3, in0);

        out1.x += dot(weights0, in1);
        out1.y += dot(weights1, in1);
        out1.z += dot(weights2, in1);
        out1.w += dot(weights3, in1);

        out4.x += dot(weights4, in0);
        out4.y += dot(weights5, in0);
        out4.z += dot(weights6, in0);
        out4.w += dot(weights7, in0);

        out5.x += dot(weights4, in1);
        out5.y += dot(weights5, in1);
        out5.z += dot(weights6, in1);
        out5.w += dot(weights7, in1);
    }

#ifdef RELU
    out0 = fmax(out0, (FLOAT4)0);
    out1 = fmax(out1, (FLOAT4)0);

    out4 = fmax(out4, (FLOAT4)0);
    out5 = fmax(out5, (FLOAT4)0);
#endif

#ifdef RELU6
    out0 = clamp(out0, (FLOAT4)0, (FLOAT4)6);
    out1 = clamp(out1, (FLOAT4)0, (FLOAT4)6);

    out4 = clamp(out4, (FLOAT4)0, (FLOAT4)6);
    out5 = clamp(out5, (FLOAT4)0, (FLOAT4)6);
#endif

    const int out_offset = (((out_b_idx*out_c_block + out_c_idx*2)*out_h + out_h_idx)*out_w + out_w2_idx)*4;

    const int remain = out_w - out_w2_idx;

    __global FLOAT *_tempoutput = output + out_offset;
    // Same pixels in the next channel block.
    __global FLOAT *_tempoutput1 = _tempoutput + 4*out_h*out_w;

    if (remain >= 2) {
        vstore4(out0, 0, _tempoutput);
        vstore4(out1, 1, _tempoutput);
    } else if (remain == 1) {
        vstore4(out0, 0, _tempoutput);
    }
    if (out_c_idx*2+1 >= out_c_block) {
        return;
    }
    if (remain >= 2) {
        vstore4(out4, 0, _tempoutput1);
        vstore4(out5, 1, _tempoutput1);
    } else if (remain == 1) {
        vstore4(out4, 0, _tempoutput1);
    }
}

// 1x1 convolution (same spatial assumptions as conv_2d_1x1_c4h1w4) computing
// one 4-channel output block at one pixel. Weight layout is the same as in
// conv_2d_1x1_c4h1w4.
//   dim0 = out_c_idx * out_w + out_w_idx
//   dim1 = out_b_idx * out_h + out_h_idx
// out_w_blocks is not used by this kernel; the input-block loop runs to the
// runtime in_c_block rather than IN_C_BLOCK.
__kernel
void conv_2d_1x1_c4h1w1(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks,
                        __global const FLOAT *input,
                        __global const FLOAT *kernel_ptr,
                        __global const FLOAT *bias_ptr,
                        __global FLOAT *output,
                        __private const int in_c_block,
                        __private const int out_h,
                        __private const int out_w,
                        __private const int out_c_block) {
    const int out_c_w_idx = get_global_id(0); // c/4, w
    const int out_b_h_idx = get_global_id(1); // b, h

    DEAL_NON_UNIFORM_DIM2(out_c_w_idx, out_b_h_idx);

    const int out_c_idx = out_c_w_idx / out_w;
    const int out_w_idx = out_c_w_idx % out_w;
    const int out_b_idx = out_b_h_idx / out_h; // equal to in_b_idx
    const int out_h_idx = out_b_h_idx % out_h; // equal to in_h_idx

    FLOAT4 out0 = vload4(out_c_idx, bias_ptr);

    for (int in_channel_block_idx = 0; in_channel_block_idx < in_c_block; ++in_channel_block_idx) {
        const int offset = mad24(out_c_idx, in_c_block, in_channel_block_idx)*4;
        const int inp_offset =
            (((out_b_idx*in_c_block + in_channel_block_idx)*out_h + out_h_idx)*out_w + out_w_idx)*4;

        FLOAT4 in0 = vload4(0, input + inp_offset);

        FLOAT4 weights0 = vload4(offset, kernel_ptr);
        FLOAT4 weights1 = vload4(offset + 1, kernel_ptr);
        FLOAT4 weights2 = vload4(offset + 2, kernel_ptr);
        FLOAT4 weights3 = vload4(offset + 3, kernel_ptr);

        out0.x += dot(weights0, in0);
        out0.y += dot(weights1, in0);
        out0.z += dot(weights2, in0);
        out0.w += dot(weights3, in0);
    }

#ifdef RELU
    out0 = fmax(out0, (FLOAT4)0);
#endif

#ifdef RELU6
    out0 = clamp(out0, (FLOAT4)0, (FLOAT4)6);
#endif

    const int out_offset = (((out_b_idx*out_c_block + out_c_idx)*out_h + out_h_idx)*out_w + out_w_idx)*4;

    vstore4(out0, 0, output + out_offset);
}

// 1x1 convolution (same spatial assumptions as conv_2d_1x1_c4h1w4) computing
// one 4-channel output block for two adjacent pixels. Weight layout is the same
// as in conv_2d_1x1_c4h1w4.
//   dim0 = out_c_idx * out_w_blocks + out_w_idx   (out_w_blocks = number of 2-wide groups)
//   dim1 = out_b_idx * out_h + out_h_idx
// The input-block loop runs to the runtime in_c_block rather than IN_C_BLOCK.
// NOTE(review): the second input pixel is loaded even when only one remains in
// the row - presumably the buffer is padded; confirm.
__kernel
void conv_2d_1x1_c4h1w2(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks,
                        __global const FLOAT *input,
                        __global const FLOAT *kernel_ptr,
                        __global const FLOAT *bias_ptr,
                        __global FLOAT *output,
                        __private const int in_c_block,
                        __private const int out_h,
                        __private const int out_w,
                        __private const int out_c_block) {
    const int out_c_w_idx = get_global_id(0); // c/4, w/2
    const int out_b_h_idx = get_global_id(1); // b, h

    DEAL_NON_UNIFORM_DIM2(out_c_w_idx, out_b_h_idx);

    const int out_c_idx = out_c_w_idx / out_w_blocks;
    const int out_w_idx = out_c_w_idx % out_w_blocks;
    const int out_b_idx = out_b_h_idx / out_h; // equal to in_b_idx
    const int out_h_idx = out_b_h_idx % out_h; // equal to in_h_idx

    const int out_w2_idx = mul24(out_w_idx, 2);

    FLOAT4 out0 = vload4(out_c_idx, bias_ptr);
    FLOAT4 out1 = out0;

    for (int in_channel_block_idx = 0; in_channel_block_idx < in_c_block; ++in_channel_block_idx) {
        const int offset = mad24(out_c_idx, in_c_block, in_channel_block_idx)*4;
        const int inp_offset =
            (((out_b_idx*in_c_block + in_channel_block_idx)*out_h + out_h_idx)*out_w + out_w2_idx)*4;

        FLOAT4 in0 = vload4(0, input + inp_offset);
        FLOAT4 in1 = vload4(1, input + inp_offset);

        FLOAT4 weights0 = vload4(offset, kernel_ptr);
        FLOAT4 weights1 = vload4(offset + 1, kernel_ptr);
        FLOAT4 weights2 = vload4(offset + 2, kernel_ptr);
        FLOAT4 weights3 = vload4(offset + 3, kernel_ptr);

        out0.x += dot(weights0, in0);
        out0.y += dot(weights1, in0);
        out0.z += dot(weights2, in0);
        out0.w += dot(weights3, in0);

        out1.x += dot(weights0, in1);
        out1.y += dot(weights1, in1);
        out1.z += dot(weights2, in1);
        out1.w += dot(weights3, in1);
    }

#ifdef RELU
    out0 = fmax(out0, (FLOAT4)0);
    out1 = fmax(out1, (FLOAT4)0);
#endif

#ifdef RELU6
    out0 = clamp(out0, (FLOAT4)0, (FLOAT4)6);
    out1 = clamp(out1, (FLOAT4)0, (FLOAT4)6);
#endif

    const int out_offset = (((out_b_idx*out_c_block + out_c_idx)*out_h + out_h_idx)*out_w + out_w2_idx)*4;

    const int remain = out_w - out_w2_idx;

    if (remain >= 2) {
        vstore4(out0, 0, output + out_offset);
        vstore4(out1, 1, output + out_offset);
    } else if (remain == 1) {
        vstore4(out0, 0, output + out_offset);
    }
}