// Image-based 2D convolution kernels: three 1x1 fast paths and a generic KxK path.
//
// NOTE(review): FLOAT, FLOAT4, FLOAT16, CONVERT_FLOAT4, RI_F and WI_F are not
// defined in this file; presumably the host injects them as build options
// (float vs. half variants) -- confirm against the program builder.
#ifdef MNN_SUPPORT_FP16
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#endif

// Reads one FLOAT4 texel for column `i` of the 4-wide output tile in conv_2d.
// Expects in_width##i, in_idx, input_shape, in_hb_value and in##i in scope at
// the expansion site. Columns outside [0, input_shape.y) are redirected to
// x = -1. The CLK_ADDRESS_CLAMP sampler then returns the border colour there
// (zero for RGBA images), so horizontal padding contributes nothing.
#define READ_INPUT_IMAGE(i, base)                                                                          \
    do {                                                                                                   \
        int in_width_value##i = in_width##i + (base);                                                      \
        in_width_value##i =                                                                                \
            select(in_idx + in_width_value##i, -1, (in_width_value##i < 0 || in_width_value##i >= input_shape.y)); \
        in##i = RI_F(input, SAMPLER, (int2)(in_width_value##i, in_hb_value));                              \
    } while (0)

// out##i += in##i.x * weights0 + in##i.y * weights1 + in##i.z * weights2 + in##i.w * weights3
#define CALCULATE_OUTPUT(i)                      \
    do {                                         \
        out##i = mad(in##i.x, weights0, out##i); \
        out##i = mad(in##i.y, weights1, out##i); \
        out##i = mad(in##i.z, weights2, out##i); \
        out##i = mad(in##i.w, weights3, out##i); \
    } while (0)

// Same as CALCULATE_OUTPUT, but the input comes from the shared tile
// in_sm##i[local_idx] (conv_2d_1x1_local).
#define CALCULATE_OUTPUT_OPT(i)                                 \
    do {                                                        \
        out##i = mad(in_sm##i[local_idx].x, weights0, out##i); \
        out##i = mad(in_sm##i[local_idx].y, weights1, out##i); \
        out##i = mad(in_sm##i[local_idx].z, weights2, out##i); \
        out##i = mad(in_sm##i[local_idx].w, weights3, out##i); \
    } while (0)

// Real (un-padded) global sizes, passed as leading kernel arguments so that
// work-items added by rounding the NDRange up can be detected.
#define GLOBAL_SIZE_2_DIMS __private const int global_size_dim0, __private const int global_size_dim1,

__constant sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;

// Early exit for padding work-items. Only safe in kernels without barriers.
#define DEAL_NON_UNIFORM_DIM2(input1, input2)                       \
    if (input1 >= global_size_dim0 || input2 >= global_size_dim1) { \
        return;                                                     \
    }

#define GLOBAL_SIZE_3_DIMS \
    __private const int global_size_dim0, __private const int global_size_dim1, __private const int global_size_dim2,

// Early exit for padding work-items. Only safe in kernels without barriers.
#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3)                                             \
    if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { \
        return;                                                                                   \
    }

// Edge length of the shared tile in conv_2d_1x1_local.
#define UNIT 4

// 1x1 convolution (stride 1, no padding) with weights and bias read from plain
// buffers. Each work-item produces one FLOAT4 (4 output channels) for 4
// consecutive output columns.
//
// global id 0: out_c_block * out_w_blocks + out_w_block
// global id 1: image row, used unchanged for both input and output
//
// Weight layout implied by the indexing: for output block oc and input block
// ic, FLOAT4 number (oc * in_c_block + ic) * 4 + j holds the weights of output
// channel 4*oc + j for input channels 4*ic .. 4*ic + 3.
__kernel
#if SET_ATTRIBUTE
__attribute__((work_group_size_hint(16, 16, 1)))
#endif
void conv_2d_1x1_mali(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, __read_only image2d_t input,
#ifdef BUFFER_INP_FP32
                      __global const float *kernel_ptr,
                      __global const float *bias_ptr,
#else
                      __global const FLOAT *kernel_ptr,
                      __global const FLOAT *bias_ptr,
#endif
                      __write_only image2d_t output,
                      __private const int in_c_block, __private const int out_h,
                      __private const int out_w) {

    const int out_c_w_idx = get_global_id(0); // c/4 * w/4
    const int out_b_h_idx = get_global_id(1); // image row

    DEAL_NON_UNIFORM_DIM2(out_c_w_idx, out_b_h_idx);

    const int out_c_idx = out_c_w_idx / out_w_blocks;
    const int out_w_idx = out_c_w_idx % out_w_blocks;

    const int out_w4_idx = mul24(out_w_idx, 4);

#ifdef BUFFER_INP_FP32
    FLOAT4 out0 = CONVERT_FLOAT4(vload4(out_c_idx, (__global float *)bias_ptr));
#else
    FLOAT4 out0 = vload4(out_c_idx, (__global FLOAT *)bias_ptr);
#endif
    FLOAT4 out1 = out0;
    FLOAT4 out2 = out0;
    FLOAT4 out3 = out0;

    FLOAT4 weights0;
    FLOAT4 weights1;
    FLOAT4 weights2;
    FLOAT4 weights3;

    FLOAT4 in0;
    FLOAT4 in1;
    FLOAT4 in2;
    FLOAT4 in3;

    // The four output columns handled by this work-item.
    const int in_x0 = out_w4_idx;
    const int in_x1 = out_w4_idx + 1;
    const int in_x2 = out_w4_idx + 2;
    const int in_x3 = out_w4_idx + 3;

    for (int in_channel_block_idx = 0; in_channel_block_idx < in_c_block; ++in_channel_block_idx) {
        // Channel blocks are laid out side by side along x, out_w texels each.
        int input_width_base = mul24(in_channel_block_idx, out_w);

        int offset = mad24(out_c_idx, in_c_block, in_channel_block_idx) * 4;
        in0 = RI_F(input, SAMPLER, (int2)(input_width_base + in_x0, out_b_h_idx));
        in1 = RI_F(input, SAMPLER, (int2)(input_width_base + in_x1, out_b_h_idx));
        in2 = RI_F(input, SAMPLER, (int2)(input_width_base + in_x2, out_b_h_idx));
        in3 = RI_F(input, SAMPLER, (int2)(input_width_base + in_x3, out_b_h_idx));

#ifdef BUFFER_INP_FP32
        weights0 = CONVERT_FLOAT4(vload4(offset, (__global float *)kernel_ptr));
        weights1 = CONVERT_FLOAT4(vload4(offset + 1, (__global float *)kernel_ptr));
        weights2 = CONVERT_FLOAT4(vload4(offset + 2, (__global float *)kernel_ptr));
        weights3 = CONVERT_FLOAT4(vload4(offset + 3, (__global float *)kernel_ptr));
#else
        weights0 = vload4(offset, (__global FLOAT *)kernel_ptr);
        weights1 = vload4(offset + 1, (__global FLOAT *)kernel_ptr);
        weights2 = vload4(offset + 2, (__global FLOAT *)kernel_ptr);
        weights3 = vload4(offset + 3, (__global FLOAT *)kernel_ptr);
#endif

        // Component j of each output gets the dot product of output channel
        // j's weights with the 4 input channels of this block.
        out0.x += dot(weights0, in0);
        out0.y += dot(weights1, in0);
        out0.z += dot(weights2, in0);
        out0.w += dot(weights3, in0);

        out1.x += dot(weights0, in1);
        out1.y += dot(weights1, in1);
        out1.z += dot(weights2, in1);
        out1.w += dot(weights3, in1);

        out2.x += dot(weights0, in2);
        out2.y += dot(weights1, in2);
        out2.z += dot(weights2, in2);
        out2.w += dot(weights3, in2);

        out3.x += dot(weights0, in3);
        out3.y += dot(weights1, in3);
        out3.z += dot(weights2, in3);
        out3.w += dot(weights3, in3);
    }

#ifdef RELU
    out0 = fmax(out0, (FLOAT4)0);
    out1 = fmax(out1, (FLOAT4)0);
    out2 = fmax(out2, (FLOAT4)0);
    out3 = fmax(out3, (FLOAT4)0);
#endif

#ifdef RELU6
    out0 = clamp(out0, (FLOAT4)0, (FLOAT4)6);
    out1 = clamp(out1, (FLOAT4)0, (FLOAT4)6);
    out2 = clamp(out2, (FLOAT4)0, (FLOAT4)6);
    out3 = clamp(out3, (FLOAT4)0, (FLOAT4)6);
#endif

    const int out_x_base = out_c_idx * out_w;

    // The last column block may be partial: only write the real columns.
    const int remain = out_w - out_w4_idx;
    int output_idx = out_x_base + out_w4_idx;

    if (remain >= 4) {
        WI_F(output, (int2)(output_idx, out_b_h_idx), out0);
        WI_F(output, (int2)(output_idx + 1, out_b_h_idx), out1);
        WI_F(output, (int2)(output_idx + 2, out_b_h_idx), out2);
        WI_F(output, (int2)(output_idx + 3, out_b_h_idx), out3);
    } else if (remain == 3) {
        WI_F(output, (int2)(output_idx, out_b_h_idx), out0);
        WI_F(output, (int2)(output_idx + 1, out_b_h_idx), out1);
        WI_F(output, (int2)(output_idx + 2, out_b_h_idx), out2);
    } else if (remain == 2) {
        WI_F(output, (int2)(output_idx, out_b_h_idx), out0);
        WI_F(output, (int2)(output_idx + 1, out_b_h_idx), out1);
    } else if (remain == 1) {
        WI_F(output, (int2)(output_idx, out_b_h_idx), out0);
    }
}

// 1x1 convolution (stride 1, no padding) that stages input texels in __local
// memory, so that work-items of one group share them.
//
// global id 0: output channel block (c/4). Local id 0 (`row`) selects which
//              input channel block of the current tile this work-item loads.
// global id 1: output column block (w/4). Local id 1 (`col`) selects the tile
//              column this work-item reads back.
// global id 2: image row.
//
// NOTE(review): the shared arrays hold UNIT*UNIT entries indexed by
// col * UNIT + row, so this assumes a local size of (UNIT, UNIT, 1). Confirm
// with the host-side launch.
__kernel void conv_2d_1x1_local(GLOBAL_SIZE_3_DIMS __read_only image2d_t input, __read_only image2d_t weights,
                                __read_only image2d_t bias,
                                __write_only image2d_t output,
                                __private const int in_c_block, __private const int out_h,
                                __private const int out_w) {

    const int row = get_local_id(0);
    const int col = get_local_id(1);

    const int out_c_idx = get_global_id(0);   // c/4
    const int out_w_idx = get_global_id(1);   // w/4
    const int out_b_h_idx = get_global_id(2); // image row

    // Padding work-items (beyond the real global size) must NOT return here.
    // Each of them loads a slot of the shared tiles that in-range neighbours
    // with the same `col` read back. Every work-item of the group must also
    // reach both barriers in the tile loop. They only skip the final store.
    const bool is_active = out_c_idx < global_size_dim0 && out_w_idx < global_size_dim1 &&
                           out_b_h_idx < global_size_dim2;

    const int out_w4_idx = mul24(out_w_idx, 4);

    FLOAT4 out0 = RI_F(bias, SAMPLER, (int2)(out_c_idx, 0));
    FLOAT4 out1 = out0;
    FLOAT4 out2 = out0;
    FLOAT4 out3 = out0;

    FLOAT4 weights0;
    FLOAT4 weights1;
    FLOAT4 weights2;
    FLOAT4 weights3;

    __local FLOAT4 in_sm0[UNIT * UNIT];
    __local FLOAT4 in_sm1[UNIT * UNIT];
    __local FLOAT4 in_sm2[UNIT * UNIT];
    __local FLOAT4 in_sm3[UNIT * UNIT];

    const int tiles = (in_c_block + UNIT - 1) / UNIT;

    const int col_x_unit = mul24(col, UNIT);
    const int in_index = col_x_unit + row;

    for (int t = 0; t < tiles; ++t) {
        // This work-item loads input channel block t*UNIT + row for the 4
        // columns of its column block. Past in_c_block the x coordinate runs
        // off the image and the clamp sampler yields the border texel.
        // NOTE(review): this assumes the input image width is exactly
        // in_c_block * out_w.
        const int in_c = mad24(t, UNIT, row);
        const int in_c_w_idx = mad24(in_c, out_w, out_w4_idx);

        in_sm0[in_index] = RI_F(input, SAMPLER, (int2)(in_c_w_idx, out_b_h_idx));
        in_sm1[in_index] = RI_F(input, SAMPLER, (int2)(in_c_w_idx + 1, out_b_h_idx));
        in_sm2[in_index] = RI_F(input, SAMPLER, (int2)(in_c_w_idx + 2, out_b_h_idx));
        in_sm3[in_index] = RI_F(input, SAMPLER, (int2)(in_c_w_idx + 3, out_b_h_idx));

        // The tiles live in __local memory, so the fence has to cover local
        // memory. A global fence alone does not make these stores visible.
        barrier(CLK_LOCAL_MEM_FENCE);

        const int kernel_index = mul24(t, UNIT * 4);

        for (int k = 0; k < UNIT; k++) {
            int kernel_cx4 = mad24(k, 4, kernel_index);
            const int local_idx = col_x_unit + k;

            weights0 = RI_F(weights, SAMPLER, (int2)(kernel_cx4++, out_c_idx));
            weights1 = RI_F(weights, SAMPLER, (int2)(kernel_cx4++, out_c_idx));
            weights2 = RI_F(weights, SAMPLER, (int2)(kernel_cx4++, out_c_idx));
            weights3 = RI_F(weights, SAMPLER, (int2)(kernel_cx4++, out_c_idx));

            CALCULATE_OUTPUT_OPT(0);
            CALCULATE_OUTPUT_OPT(1);
            CALCULATE_OUTPUT_OPT(2);
            CALCULATE_OUTPUT_OPT(3);
        }
        // Keep the next tile's stores from overwriting data still being read.
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    if (!is_active) {
        return;
    }

#ifdef RELU
    out0 = fmax(out0, (FLOAT4)0);
    out1 = fmax(out1, (FLOAT4)0);
    out2 = fmax(out2, (FLOAT4)0);
    out3 = fmax(out3, (FLOAT4)0);
#endif

#ifdef RELU6
    out0 = clamp(out0, (FLOAT4)0, (FLOAT4)6);
    out1 = clamp(out1, (FLOAT4)0, (FLOAT4)6);
    out2 = clamp(out2, (FLOAT4)0, (FLOAT4)6);
    out3 = clamp(out3, (FLOAT4)0, (FLOAT4)6);
#endif

    const int out_x_base = out_c_idx * out_w;

    // The last column block may be partial: only write the real columns.
    const int remain = out_w - out_w4_idx;
    int output_idx = out_x_base + out_w4_idx;
    if (remain >= 4) {
        WI_F(output, (int2)(output_idx, out_b_h_idx), out0);
        WI_F(output, (int2)(output_idx + 1, out_b_h_idx), out1);
        WI_F(output, (int2)(output_idx + 2, out_b_h_idx), out2);
        WI_F(output, (int2)(output_idx + 3, out_b_h_idx), out3);
    } else if (remain == 3) {
        WI_F(output, (int2)(output_idx, out_b_h_idx), out0);
        WI_F(output, (int2)(output_idx + 1, out_b_h_idx), out1);
        WI_F(output, (int2)(output_idx + 2, out_b_h_idx), out2);
    } else if (remain == 2) {
        WI_F(output, (int2)(output_idx, out_b_h_idx), out0);
        WI_F(output, (int2)(output_idx + 1, out_b_h_idx), out1);
    } else if (remain == 1) {
        WI_F(output, (int2)(output_idx, out_b_h_idx), out0);
    }
}

// 1x1 convolution with stride and no padding. Weights and bias are images.
// Each work-item produces one FLOAT4 (4 output channels) for 4 output columns.
//
// global id 0: output_channel_block * output_width_4 + output_width_block
// global id 1: output row index; batch = idx / output_shape.x
// input_shape / output_shape: .x = height, .y = width; stride_shape likewise.
// Weights image: texel (in_block*4 + j, out_block) holds the FLOAT4 that scales
// input channel 4*in_block + j (see CALCULATE_OUTPUT).
__kernel
#if SET_ATTRIBUTE
__attribute__((work_group_size_hint(16, 16, 1)))
#endif
void conv_2d_1x1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only image2d_t weights,
                 __read_only image2d_t bias,
                 __write_only image2d_t output,
                 __private const int2 input_shape,
                 __private const int in_channel_block, __private const int2 output_shape,
                 __private const int2 stride_shape,
                 __private const int output_width_4) {

    const int output_channel_width_idx = get_global_id(0);
    const int output_batch_height_idx = get_global_id(1);
    DEAL_NON_UNIFORM_DIM2(output_channel_width_idx, output_batch_height_idx);

    const int output_channel_block_idx = output_channel_width_idx / output_width_4;
    const int output_width_block_idx = output_channel_width_idx % output_width_4;

    FLOAT4 out0 = RI_F(bias, SAMPLER, (int2)(output_channel_block_idx, 0));
    FLOAT4 out1 = out0;
    FLOAT4 out2 = out0;
    FLOAT4 out3 = out0;

    int in_x0 = mul24(output_width_block_idx, stride_shape.y * 4);
    int in_x1 = in_x0 + stride_shape.y;
    int in_x2 = in_x1 + stride_shape.y;
    int in_x3 = in_x2 + stride_shape.y;

    // Push columns past the input width to INT_MIN. Even after the non-negative
    // channel-block base is added, x stays negative and the clamp sampler
    // returns the border texel instead of reading the next channel block.
    in_x0 = select(in_x0, INT_MIN, in_x0 >= input_shape.y);
    in_x1 = select(in_x1, INT_MIN, in_x1 >= input_shape.y);
    in_x2 = select(in_x2, INT_MIN, in_x2 >= input_shape.y);
    in_x3 = select(in_x3, INT_MIN, in_x3 >= input_shape.y);

    int batch_index = output_batch_height_idx / output_shape.x;
    int input_height_block_idx =
        mul24((output_batch_height_idx % output_shape.x), stride_shape.x) + batch_index * input_shape.x;

    FLOAT4 in0;
    FLOAT4 in1;
    FLOAT4 in2;
    FLOAT4 in3;
    FLOAT4 weights0;
    FLOAT4 weights1;
    FLOAT4 weights2;
    FLOAT4 weights3;

    for (int in_channel_block_idx = 0; in_channel_block_idx < in_channel_block; ++in_channel_block_idx) {
        int input_width_base = in_channel_block_idx * input_shape.y;
        int weights_width_base = in_channel_block_idx << 2;
        in0 = RI_F(input, SAMPLER, (int2)(input_width_base + in_x0, input_height_block_idx));
        in1 = RI_F(input, SAMPLER, (int2)(input_width_base + in_x1, input_height_block_idx));
        in2 = RI_F(input, SAMPLER, (int2)(input_width_base + in_x2, input_height_block_idx));
        in3 = RI_F(input, SAMPLER, (int2)(input_width_base + in_x3, input_height_block_idx));

        weights0 = RI_F(weights, SAMPLER, (int2)(weights_width_base + 0, output_channel_block_idx));
        weights1 = RI_F(weights, SAMPLER, (int2)(weights_width_base + 1, output_channel_block_idx));
        weights2 = RI_F(weights, SAMPLER, (int2)(weights_width_base + 2, output_channel_block_idx));
        weights3 = RI_F(weights, SAMPLER, (int2)(weights_width_base + 3, output_channel_block_idx));

        CALCULATE_OUTPUT(0);
        CALCULATE_OUTPUT(1);
        CALCULATE_OUTPUT(2);
        CALCULATE_OUTPUT(3);
    }

#ifdef RELU
    out0 = fmax(out0, (FLOAT4)0);
    out1 = fmax(out1, (FLOAT4)0);
    out2 = fmax(out2, (FLOAT4)0);
    out3 = fmax(out3, (FLOAT4)0);
#endif

#ifdef RELU6
    out0 = clamp(out0, (FLOAT4)0, (FLOAT4)6);
    out1 = clamp(out1, (FLOAT4)0, (FLOAT4)6);
    out2 = clamp(out2, (FLOAT4)0, (FLOAT4)6);
    out3 = clamp(out3, (FLOAT4)0, (FLOAT4)6);
#endif

    const int out_x_base = mul24(output_channel_block_idx, output_shape.y);
    int out_x_idx = output_width_block_idx << 2;

    // The last column block may be partial: only write the real columns.
    const int remain = output_shape.y - out_x_idx;
    int output_idx = out_x_base + out_x_idx;
    if (remain >= 4) {
        WI_F(output, (int2)(output_idx, output_batch_height_idx), out0);
        WI_F(output, (int2)(output_idx + 1, output_batch_height_idx), out1);
        WI_F(output, (int2)(output_idx + 2, output_batch_height_idx), out2);
        WI_F(output, (int2)(output_idx + 3, output_batch_height_idx), out3);
    } else if (remain == 3) {
        WI_F(output, (int2)(output_idx, output_batch_height_idx), out0);
        WI_F(output, (int2)(output_idx + 1, output_batch_height_idx), out1);
        WI_F(output, (int2)(output_idx + 2, output_batch_height_idx), out2);
    } else if (remain == 2) {
        WI_F(output, (int2)(output_idx, output_batch_height_idx), out0);
        WI_F(output, (int2)(output_idx + 1, output_batch_height_idx), out1);
    } else if (remain == 1) {
        WI_F(output, (int2)(output_idx, output_batch_height_idx), out0);
    }
}

// Generic KxK convolution with stride, padding and dilation. Weights are an
// image; bias is optional (BIAS). Each work-item produces one FLOAT4
// (4 output channels) for 4 output columns.
//
// global id 0: out_channel_block * out_width_blocks + out_width_block
// global id 1: output row index; batch = idx / output_shape.x
// *_shape arguments: .x = height, .y = width.
// Weights image: row out_channel_block * (kh * kw) + ky * kw + kx; column
// in_block * 4 + j holds the FLOAT4 that scales input channel 4*in_block + j.
__kernel
#if SET_ATTRIBUTE
__attribute__((work_group_size_hint(16, 16, 1)))
#endif
void conv_2d(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only image2d_t weights,
#ifdef BIAS
             __read_only image2d_t bias,
#endif
             __write_only image2d_t output,
             __private const int2 input_shape,
             __private const int in_channel_block_length,
             __private const int2 output_shape,
             __private const int2 weights_shape,
             __private const int2 stride_shape,
             __private const int2 padding_shape,
             __private const int2 dilation_shape,
             __private const int out_width_blocks) {

    const int output_channel_width_idx = get_global_id(0);
    const int output_batch_height_idx = get_global_id(1);
    DEAL_NON_UNIFORM_DIM2(output_channel_width_idx, output_batch_height_idx);

    const int out_channel_block_idx = output_channel_width_idx / out_width_blocks;
    const int out_width_block_idx = output_channel_width_idx % out_width_blocks;

#ifdef BIAS
    FLOAT4 out0 = RI_F(bias, SAMPLER, (int2)(out_channel_block_idx, 0));
#else
    FLOAT4 out0 = (FLOAT4)0;
#endif
    FLOAT4 out1 = out0;
    FLOAT4 out2 = out0;
    FLOAT4 out3 = out0;

    // Leftmost input column (before dilation offsets) of each of the 4 outputs.
    int in_width0 = mad24(out_width_block_idx, stride_shape.y << 2, -padding_shape.y);
    int in_width1 = in_width0 + stride_shape.y;
    int in_width2 = in_width0 + stride_shape.y * 2;
    int in_width3 = in_width0 + stride_shape.y * 3;

    const int height_start = mad24((output_batch_height_idx % output_shape.x), stride_shape.x, -padding_shape.x);
    // Number of leading kernel rows that fall above the input
    // (ceil(-height_start / dilation) when height_start < 0, else 0).
    const int skipped_rows =
        select(0, (-height_start + dilation_shape.x - 1) / dilation_shape.x, height_start < 0);
    int in_height_start = mad24(skipped_rows, dilation_shape.x, height_start);
    int in_height_end = min(mad24(weights_shape.x, dilation_shape.x, height_start), input_shape.x);

    const int batch_idx = mul24((output_batch_height_idx / output_shape.x), input_shape.x);
    const int weights_h_idx = mul24(out_channel_block_idx, mul24(weights_shape.y, weights_shape.x)) +
                              mul24(skipped_rows, weights_shape.y);

    FLOAT4 in0, in1, in2, in3;
    FLOAT4 weights0, weights1, weights2, weights3;
    for (int in_channel_block_idx = 0; in_channel_block_idx < in_channel_block_length; ++in_channel_block_idx) {
        const int in_idx = mul24(in_channel_block_idx, input_shape.y);
        int weights_x_idx = in_channel_block_idx << 2;
        int weights_y_idx = weights_h_idx;
        for (int iy = in_height_start; iy < in_height_end; iy += dilation_shape.x) {
            int in_hb_value = iy + batch_idx;
            for (int w = 0; w < weights_shape.y; w++) {
                int input_width_base = mul24(w, dilation_shape.y);
                READ_INPUT_IMAGE(0, input_width_base);
                READ_INPUT_IMAGE(1, input_width_base);
                READ_INPUT_IMAGE(2, input_width_base);
                READ_INPUT_IMAGE(3, input_width_base);

                weights0 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 0, weights_y_idx));
                weights1 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 1, weights_y_idx));
                weights2 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 2, weights_y_idx));
                weights3 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 3, weights_y_idx++));

                CALCULATE_OUTPUT(0);
                CALCULATE_OUTPUT(1);
                CALCULATE_OUTPUT(2);
                CALCULATE_OUTPUT(3);
            }
        }
    }

#ifdef RELU
    out0 = fmax(out0, (FLOAT4)0);
    out1 = fmax(out1, (FLOAT4)0);
    out2 = fmax(out2, (FLOAT4)0);
    out3 = fmax(out3, (FLOAT4)0);
#endif

#ifdef RELU6
    out0 = clamp(out0, (FLOAT4)0, (FLOAT4)6);
    out1 = clamp(out1, (FLOAT4)0, (FLOAT4)6);
    out2 = clamp(out2, (FLOAT4)0, (FLOAT4)6);
    out3 = clamp(out3, (FLOAT4)0, (FLOAT4)6);
#endif

    const int out_x_base = mul24(out_channel_block_idx, output_shape.y);
    int out_x_idx = out_width_block_idx << 2;

    // The last column block may be partial: only write the real columns.
    const int remain = output_shape.y - out_x_idx;
    int output_idx = out_x_base + out_x_idx;
    if (remain >= 4) {
        WI_F(output, (int2)(output_idx, output_batch_height_idx), out0);
        WI_F(output, (int2)(output_idx + 1, output_batch_height_idx), out1);
        WI_F(output, (int2)(output_idx + 2, output_batch_height_idx), out2);
        WI_F(output, (int2)(output_idx + 3, output_batch_height_idx), out3);
    } else if (remain == 3) {
        WI_F(output, (int2)(output_idx, output_batch_height_idx), out0);
        WI_F(output, (int2)(output_idx + 1, output_batch_height_idx), out1);
        WI_F(output, (int2)(output_idx + 2, output_batch_height_idx), out2);
    } else if (remain == 2) {
        WI_F(output, (int2)(output_idx, output_batch_height_idx), out0);
        WI_F(output, (int2)(output_idx + 1, output_batch_height_idx), out1);
    } else if (remain == 1) {
        WI_F(output, (int2)(output_idx, output_batch_height_idx), out0);
    }
}