1#ifdef MNN_SUPPORT_FP16
2#pragma OPENCL EXTENSION cl_khr_fp16 : enable
3#endif
// READ_INPUT_IMAGE(i, base): load input texel number i for the current kernel tap into in<i>.
//
// Expands in the caller's scope and relies on these caller-side names:
//   in_width<i>   - x position of output column i before the tap offset is applied
//   in_idx        - x offset added to the position (the caller passes the start of the
//                   current input-channel block)
//   input_shape.y - bound the position is checked against
//   in_hb_value   - y coordinate of the read
//   in<i>         - destination of the read
// It also declares a new local, in_width_value<i>, in the enclosing scope, so it can be
// used only once per i within a given scope.
//
// A position outside [0, input_shape.y) is replaced by x = -1 so the read falls outside
// the image. NOTE(review): this presumably yields the border colour through SAMPLER's
// CLK_ADDRESS_CLAMP mode (i.e. zero padding) - confirm against the RI_F definition.
#define READ_INPUT_IMAGE(i, base)                                                                         \
    int in_width_value##i = in_width##i + base;                                                           \
    in_width_value##i =                                                                                   \
        select(in_idx + in_width_value##i, -1, (in_width_value##i < 0 || in_width_value##i >= input_shape.y)); \
    in##i = RI_F(input, SAMPLER, (int2)(in_width_value##i, in_hb_value));
9
// CALCULATE_OUTPUT(i): accumulate one input texel into accumulator out<i>.
// Each component of in<i> scales one of the caller's weight vectors:
//   out<i> += in<i>.x * weights0 + in<i>.y * weights1 + in<i>.z * weights2 + in<i>.w * weights3
// evaluated as four successive mad() calls in that order.
// Requires in<i>, out<i> and weights0..weights3 to be in scope at the expansion site.
#define CALCULATE_OUTPUT(i)                  \
    out##i = mad(in##i.x, weights0, out##i); \
    out##i = mad(in##i.y, weights1, out##i); \
    out##i = mad(in##i.z, weights2, out##i); \
    out##i = mad(in##i.w, weights3, out##i);
15
// CALCULATE_OUTPUT_OPT(i): same accumulation as CALCULATE_OUTPUT(i), except that the input
// texel is taken from the array in_sm<i> at index local_idx instead of from in<i>.
// Requires in_sm<i>, local_idx, out<i> and weights0..weights3 to be in scope at the
// expansion site.
#define CALCULATE_OUTPUT_OPT(i)                  \
    out##i = mad(in_sm##i[local_idx].x, weights0, out##i); \
    out##i = mad(in_sm##i[local_idx].y, weights1, out##i); \
    out##i = mad(in_sm##i[local_idx].z, weights2, out##i); \
    out##i = mad(in_sm##i[local_idx].w, weights3, out##i);
21
// Leading kernel parameters that carry the logical 2-D global size. The expansion ends with
// a comma, so the macro must be followed directly by the next parameter.
#define GLOBAL_SIZE_2_DIMS __private const int global_size_dim0, __private const int global_size_dim1,

// Sampler used for the image reads below: unnormalized coordinates, clamp addressing,
// nearest filtering.
__constant sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;

// Return from the enclosing kernel when either index is at or beyond the logical 2-D
// global size declared by GLOBAL_SIZE_2_DIMS.
#define DEAL_NON_UNIFORM_DIM2(input1, input2)                       \
    if (input1 >= global_size_dim0 || input2 >= global_size_dim1) { \
        return;                                                     \
    }

// 3-D counterpart of GLOBAL_SIZE_2_DIMS; the expansion also ends with a comma.
#define GLOBAL_SIZE_3_DIMS \
    __private const int global_size_dim0, __private const int global_size_dim1, __private const int global_size_dim2,

// 3-D counterpart of DEAL_NON_UNIFORM_DIM2.
#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3)                                             \
    if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { \
        return;                                                                                   \
    }

// Tile edge length used by conv_2d_1x1_local: its local-memory tiles hold UNIT*UNIT texels
// and it consumes UNIT input-channel blocks per tile.
#define UNIT 4
40
// 1x1 convolution whose weights and bias are read from global buffers rather than images.
//
// Each work-item accumulates up to four horizontally adjacent output texels (out0..out3)
// for one output-channel block and one image row:
//   get_global_id(0) = out_c_idx * out_w_blocks + out_w_idx
//   get_global_id(1) = out_b_h_idx, used directly as the row of both input and output
//
// Parameters:
//   global_size_dim0/1 - logical global size; work-items outside it return immediately
//   out_w_blocks       - divisor that splits get_global_id(0) into (channel block, column group)
//   input              - input image; input-channel block c starts at x = c * out_w
//   kernel_ptr         - weights, 16 consecutive scalars (four 4-vectors) per
//                        (output block, input block) pair, output block major
//   bias_ptr           - one 4-vector per output-channel block
//   output             - output image; output-channel block c starts at x = c * out_w
//   in_c_block         - number of input-channel blocks accumulated over
//   out_h              - not used by this kernel
//   out_w              - output width in texels, also used as the per-block input pitch
//
// Compile-time switches: BUFFER_INP_FP32 (buffers hold float and are converted with
// CONVERT_FLOAT4), RELU / RELU6 (optional activation), SET_ATTRIBUTE (work-group size hint).
__kernel
#if SET_ATTRIBUTE
__attribute__((work_group_size_hint(16, 16, 1)))
#endif
void conv_2d_1x1_mali(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, __read_only image2d_t input,
                          #ifdef BUFFER_INP_FP32
                          __global const float *kernel_ptr,
                          __global const float *bias_ptr,
                          #else
                          __global const FLOAT *kernel_ptr,
                          __global const FLOAT *bias_ptr,
                          #endif
                          __write_only image2d_t output,
                          __private const int in_c_block, __private const int out_h,
                          __private const int out_w) {

    const int out_c_w_idx = get_global_id(0); // fused (output-channel block, column group)
    const int out_b_h_idx = get_global_id(1); // image row

    DEAL_NON_UNIFORM_DIM2(out_c_w_idx, out_b_h_idx);

    const int out_c_idx  = out_c_w_idx / out_w_blocks;
    const int out_w_idx  = out_c_w_idx % out_w_blocks;
    const int out_w4_idx = mul24(out_w_idx, 4); // first of the four output columns

    // Every accumulator starts from the bias of this output-channel block.
    // vload4 takes a pointer-to-const, so the buffers are passed without casting const away.
    #ifdef BUFFER_INP_FP32
    FLOAT4 out0 = CONVERT_FLOAT4(vload4(out_c_idx, bias_ptr));
    #else
    FLOAT4 out0 = vload4(out_c_idx, bias_ptr);
    #endif
    FLOAT4 out1 = out0;
    FLOAT4 out2 = out0;
    FLOAT4 out3 = out0;

    FLOAT4 weights0;
    FLOAT4 weights1;
    FLOAT4 weights2;
    FLOAT4 weights3;

    FLOAT4 in0;
    FLOAT4 in1;
    FLOAT4 in2;
    FLOAT4 in3;

    const int input_width_idx0 = out_w4_idx;
    const int input_width_idx1 = out_w4_idx + 1;
    const int input_width_idx2 = out_w4_idx + 2;
    const int input_width_idx3 = out_w4_idx + 3;

    for (int in_channel_block_idx = 0; in_channel_block_idx < in_c_block; ++in_channel_block_idx) {
        const int input_width_base = mul24(in_channel_block_idx, out_w);

        // Index, in units of 4-vectors, of the first weight vector of this
        // (output block, input block) pair.
        const int offset = mad24(out_c_idx, in_c_block, in_channel_block_idx)*4;

        in0 = RI_F(input, SAMPLER, (int2)(input_width_base + input_width_idx0, out_b_h_idx));
        in1 = RI_F(input, SAMPLER, (int2)(input_width_base + input_width_idx1, out_b_h_idx));
        in2 = RI_F(input, SAMPLER, (int2)(input_width_base + input_width_idx2, out_b_h_idx));
        in3 = RI_F(input, SAMPLER, (int2)(input_width_base + input_width_idx3, out_b_h_idx));

        #ifdef BUFFER_INP_FP32
        weights0 = CONVERT_FLOAT4(vload4(offset, kernel_ptr));
        weights1 = CONVERT_FLOAT4(vload4(offset + 1, kernel_ptr));
        weights2 = CONVERT_FLOAT4(vload4(offset + 2, kernel_ptr));
        weights3 = CONVERT_FLOAT4(vload4(offset + 3, kernel_ptr));
        #else
        weights0 = vload4(offset, kernel_ptr);
        weights1 = vload4(offset + 1, kernel_ptr);
        weights2 = vload4(offset + 2, kernel_ptr);
        weights3 = vload4(offset + 3, kernel_ptr);
        #endif

        // weights<k> holds the four input-channel coefficients of output component k.
        out0.x += dot(weights0, in0);
        out0.y += dot(weights1, in0);
        out0.z += dot(weights2, in0);
        out0.w += dot(weights3, in0);

        out1.x += dot(weights0, in1);
        out1.y += dot(weights1, in1);
        out1.z += dot(weights2, in1);
        out1.w += dot(weights3, in1);

        out2.x += dot(weights0, in2);
        out2.y += dot(weights1, in2);
        out2.z += dot(weights2, in2);
        out2.w += dot(weights3, in2);

        out3.x += dot(weights0, in3);
        out3.y += dot(weights1, in3);
        out3.z += dot(weights2, in3);
        out3.w += dot(weights3, in3);
    }

#ifdef RELU
    out0 = fmax(out0, (FLOAT4)0);
    out1 = fmax(out1, (FLOAT4)0);
    out2 = fmax(out2, (FLOAT4)0);
    out3 = fmax(out3, (FLOAT4)0);
#endif

#ifdef RELU6
    out0 = clamp(out0, (FLOAT4)0, (FLOAT4)6);
    out1 = clamp(out1, (FLOAT4)0, (FLOAT4)6);
    out2 = clamp(out2, (FLOAT4)0, (FLOAT4)6);
    out3 = clamp(out3, (FLOAT4)0, (FLOAT4)6);
#endif

    const int out_x_base = out_c_idx*out_w;

    // Number of valid columns left in this row from out_w4_idx on; columns at or beyond
    // out_w are never written.
    const int remain = out_w - out_w4_idx;
    const int output_idx = out_x_base + out_w4_idx;

    if (remain >= 1) {
        WI_F(output, (int2)(output_idx, out_b_h_idx), out0);
    }
    if (remain >= 2) {
        WI_F(output, (int2)(output_idx + 1, out_b_h_idx), out1);
    }
    if (remain >= 3) {
        WI_F(output, (int2)(output_idx + 2, out_b_h_idx), out2);
    }
    if (remain >= 4) {
        WI_F(output, (int2)(output_idx + 3, out_b_h_idx), out3);
    }

}
172
// 1x1 convolution that stages input texels in local memory so that the work-items of a
// work-group sharing the same column group can reuse each other's reads.
//
// Work-item mapping:
//   get_global_id(0) / get_local_id(0) ("row") - output-channel block
//   get_global_id(1) / get_local_id(1) ("col") - group of four output columns
//   get_global_id(2)                           - image row
//
// In every tile step t, the work-item at (row, col) loads the four input texels of
// input-channel block t*UNIT + row for its column group into slot col*UNIT + row of
// in_sm0..in_sm3. After the barrier each work-item consumes slots col*UNIT + k, k = 0..UNIT-1.
// NOTE(review): this requires a local work size of exactly UNIT in dimension 0 and at most
// UNIT in dimension 1 (in_sm* hold UNIT*UNIT entries) - confirm against the host dispatch.
//
// Parameters:
//   global_size_dim0/1/2 - logical global size; work-items outside it write nothing
//   input                - input image; input-channel block c starts at x = c * out_w
//   weights              - weight image; row = output-channel block, four texels per
//                          input-channel block
//   bias                 - bias image; one texel per output-channel block in row 0
//   output               - output image; output-channel block c starts at x = c * out_w
//   in_c_block           - number of input-channel blocks
//   out_h                - not used by this kernel
//   out_w                - output width in texels, also used as the per-block input pitch
//
// RELU / RELU6 select an optional activation at compile time.
__kernel void conv_2d_1x1_local(GLOBAL_SIZE_3_DIMS __read_only image2d_t input, __read_only image2d_t weights,
                          __read_only image2d_t bias,
                          __write_only image2d_t output,
                          __private const int in_c_block, __private const int out_h,
                          __private const int out_w) {

    const int row = get_local_id(0);
    const int col = get_local_id(1);

    const int out_c_idx   = get_global_id(0); // output-channel block
    const int out_w_idx   = get_global_id(1); // group of four output columns
    const int out_b_h_idx = get_global_id(2); // image row

    // Do not return early for out-of-range work-items: every work-item of the group has to
    // reach each barrier below, and in-range neighbours read the tile slot this work-item
    // fills. The range check is applied just before the output writes instead.
    // NOTE(review): out-of-range work-items therefore issue image reads at out-of-range
    // coordinates; this relies on SAMPLER's CLK_ADDRESS_CLAMP mode making them harmless.
    const bool in_range = out_c_idx < global_size_dim0 && out_w_idx < global_size_dim1 &&
                          out_b_h_idx < global_size_dim2;

    const int out_w4_idx = mul24(out_w_idx, 4);

    FLOAT4 out0 = RI_F(bias, SAMPLER, (int2)(out_c_idx, 0));
    FLOAT4 out1 = out0;
    FLOAT4 out2 = out0;
    FLOAT4 out3 = out0;

    FLOAT4 weights0;
    FLOAT4 weights1;
    FLOAT4 weights2;
    FLOAT4 weights3;

    // One tile per output column (0..3) of the column group.
    __local FLOAT4 in_sm0[UNIT*UNIT];
    __local FLOAT4 in_sm1[UNIT*UNIT];
    __local FLOAT4 in_sm2[UNIT*UNIT];
    __local FLOAT4 in_sm3[UNIT*UNIT];

    // Number of UNIT-sized groups of input-channel blocks, rounded up.
    const int tiles = (in_c_block + UNIT -1)/ UNIT;

    const int col_x_unit = mul24(col, UNIT);
    const int in_index   = col_x_unit + row;

    for (int t = 0; t < tiles; ++t) {

        const int in_c       = mad24(t, UNIT, row);
        const int in_c_w_idx = mad24(in_c, out_w, out_w4_idx);

        in_sm0[in_index] = RI_F(input, SAMPLER, (int2)(in_c_w_idx, out_b_h_idx));
        in_sm1[in_index] = RI_F(input, SAMPLER, (int2)(in_c_w_idx+1, out_b_h_idx));
        in_sm2[in_index] = RI_F(input, SAMPLER, (int2)(in_c_w_idx+2, out_b_h_idx));
        in_sm3[in_index] = RI_F(input, SAMPLER, (int2)(in_c_w_idx+3, out_b_h_idx));

        // The tiles live in __local memory, so a local fence is what makes the stores
        // above visible to the other work-items of the group.
        barrier(CLK_LOCAL_MEM_FENCE);

        const int kernel_index = mul24(t, UNIT*4);

        for (int k = 0; k < UNIT; k++) {

            const int kernel_cx4 = mad24(k, 4, kernel_index);
            const int local_idx  = col_x_unit + k;

            weights0 = RI_F(weights, SAMPLER, (int2)(kernel_cx4, out_c_idx));
            weights1 = RI_F(weights, SAMPLER, (int2)(kernel_cx4 + 1, out_c_idx));
            weights2 = RI_F(weights, SAMPLER, (int2)(kernel_cx4 + 2, out_c_idx));
            weights3 = RI_F(weights, SAMPLER, (int2)(kernel_cx4 + 3, out_c_idx));

            CALCULATE_OUTPUT_OPT(0);
            CALCULATE_OUTPUT_OPT(1);
            CALCULATE_OUTPUT_OPT(2);
            CALCULATE_OUTPUT_OPT(3);
        }

        // Keep the next iteration from overwriting the tiles while they are still being read.
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    // No barrier follows, so out-of-range work-items can leave now.
    if (!in_range) {
        return;
    }

#ifdef RELU
    out0 = fmax(out0, (FLOAT4)0);
    out1 = fmax(out1, (FLOAT4)0);
    out2 = fmax(out2, (FLOAT4)0);
    out3 = fmax(out3, (FLOAT4)0);
#endif

#ifdef RELU6
    out0 = clamp(out0, (FLOAT4)0, (FLOAT4)6);
    out1 = clamp(out1, (FLOAT4)0, (FLOAT4)6);
    out2 = clamp(out2, (FLOAT4)0, (FLOAT4)6);
    out3 = clamp(out3, (FLOAT4)0, (FLOAT4)6);
#endif

    const int out_x_base = out_c_idx*out_w;

    // Number of valid columns left in this row from out_w4_idx on; columns at or beyond
    // out_w are never written.
    const int remain = out_w - out_w4_idx;
    const int output_idx = out_x_base + out_w4_idx;
    if (remain >= 4) {
        WI_F(output, (int2)(output_idx, out_b_h_idx), out0);
        WI_F(output, (int2)(output_idx + 1, out_b_h_idx), out1);
        WI_F(output, (int2)(output_idx + 2, out_b_h_idx), out2);
        WI_F(output, (int2)(output_idx + 3, out_b_h_idx), out3);
    } else if (remain == 3) {
        WI_F(output, (int2)(output_idx, out_b_h_idx), out0);
        WI_F(output, (int2)(output_idx + 1, out_b_h_idx), out1);
        WI_F(output, (int2)(output_idx + 2, out_b_h_idx), out2);
    } else if (remain == 2) {
        WI_F(output, (int2)(output_idx, out_b_h_idx), out0);
        WI_F(output, (int2)(output_idx + 1, out_b_h_idx), out1);
    } else if (remain == 1) {
        WI_F(output, (int2)(output_idx, out_b_h_idx), out0);
    }

}
277
// Strided 1x1 convolution with weights and bias supplied as images.
//
// Each work-item accumulates up to four horizontally adjacent output texels (out0..out3)
// for one output-channel block and one output row:
//   get_global_id(0) = output-channel block * output_width_4 + column group
//   get_global_id(1) = batch * output_shape.x + output row
//
// Parameters:
//   global_size_dim0/1 - logical global size; work-items outside it return immediately
//   input, weights, bias, output - images read / written through RI_F / WI_F
//   input_shape        - .x is the input height, .y the input width (per channel block)
//   in_channel_block   - number of input-channel blocks accumulated over
//   output_shape       - .x is the output height, .y the output width
//   stride_shape       - .x is the vertical stride, .y the horizontal stride
//   output_width_4     - divisor that splits get_global_id(0) into (channel block, column group)
//
// RELU / RELU6 select an optional activation at compile time.
__kernel
#if SET_ATTRIBUTE
__attribute__((work_group_size_hint(16, 16, 1)))
#endif
void conv_2d_1x1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only image2d_t weights,
                          __read_only image2d_t bias,
                          __write_only image2d_t output,
                          __private const int2 input_shape,
                          __private const int in_channel_block, __private const int2 output_shape,
                          __private const int2 stride_shape,
                          __private const int output_width_4) {

    const int gid_cw = get_global_id(0);
    const int gid_bh = get_global_id(1);
    DEAL_NON_UNIFORM_DIM2(gid_cw, gid_bh);

    const int oc_block = gid_cw / output_width_4;
    const int ow_block = gid_cw % output_width_4;

    // All four accumulators start from the bias of this output-channel block.
    FLOAT4 out0 = RI_F(bias, SAMPLER, (int2)(oc_block, 0));
    FLOAT4 out1 = out0;
    FLOAT4 out2 = out0;
    FLOAT4 out3 = out0;

    // Input x positions feeding the four output columns, one horizontal stride apart.
    int in_x0 = mul24(ow_block, stride_shape.y*4);
    int in_x1 = in_x0 + stride_shape.y;
    int in_x2 = in_x1 + stride_shape.y;
    int in_x3 = in_x2 + stride_shape.y;

    // Positions at or beyond the input width are replaced by INT_MIN, so that the coordinate
    // stays negative (outside the image) once the channel-block offset is added.
    in_x0 = select(in_x0, INT_MIN, in_x0 >= input_shape.y);
    in_x1 = select(in_x1, INT_MIN, in_x1 >= input_shape.y);
    in_x2 = select(in_x2, INT_MIN, in_x2 >= input_shape.y);
    in_x3 = select(in_x3, INT_MIN, in_x3 >= input_shape.y);

    const int batch = gid_bh / output_shape.x;
    const int in_y  = mul24((gid_bh % output_shape.x), stride_shape.x) + batch * input_shape.x;

    FLOAT4 in0;
    FLOAT4 in1;
    FLOAT4 in2;
    FLOAT4 in3;
    FLOAT4 weights0;
    FLOAT4 weights1;
    FLOAT4 weights2;
    FLOAT4 weights3;

    for (int ic = 0; ic < in_channel_block; ++ic) {
        const int block_x  = ic * input_shape.y; // start of this channel block in the input image
        const int weight_x = ic << 2;            // four weight texels per input-channel block

        in0 = RI_F(input, SAMPLER, (int2)(block_x + in_x0, in_y));
        in1 = RI_F(input, SAMPLER, (int2)(block_x + in_x1, in_y));
        in2 = RI_F(input, SAMPLER, (int2)(block_x + in_x2, in_y));
        in3 = RI_F(input, SAMPLER, (int2)(block_x + in_x3, in_y));

        weights0 = RI_F(weights, SAMPLER, (int2)(weight_x + 0, oc_block));
        weights1 = RI_F(weights, SAMPLER, (int2)(weight_x + 1, oc_block));
        weights2 = RI_F(weights, SAMPLER, (int2)(weight_x + 2, oc_block));
        weights3 = RI_F(weights, SAMPLER, (int2)(weight_x + 3, oc_block));

        CALCULATE_OUTPUT(0);
        CALCULATE_OUTPUT(1);
        CALCULATE_OUTPUT(2);
        CALCULATE_OUTPUT(3);
    }

#ifdef RELU
    out0 = fmax(out0, (FLOAT4)0);
    out1 = fmax(out1, (FLOAT4)0);
    out2 = fmax(out2, (FLOAT4)0);
    out3 = fmax(out3, (FLOAT4)0);
#endif

#ifdef RELU6
    out0 = clamp(out0, (FLOAT4)0, (FLOAT4)6);
    out1 = clamp(out1, (FLOAT4)0, (FLOAT4)6);
    out2 = clamp(out2, (FLOAT4)0, (FLOAT4)6);
    out3 = clamp(out3, (FLOAT4)0, (FLOAT4)6);
#endif

    const int first_col = ow_block << 2;
    const int dst_x     = mul24(oc_block, output_shape.y) + first_col;

    // Number of valid columns left in this row from first_col on; columns at or beyond
    // the output width are never written.
    const int remain = output_shape.y - first_col;
    if (remain > 0) {
        WI_F(output, (int2)(dst_x, gid_bh), out0);
    }
    if (remain > 1) {
        WI_F(output, (int2)(dst_x + 1, gid_bh), out1);
    }
    if (remain > 2) {
        WI_F(output, (int2)(dst_x + 2, gid_bh), out2);
    }
    if (remain > 3) {
        WI_F(output, (int2)(dst_x + 3, gid_bh), out3);
    }
}
378
// General 2-D convolution (stride, padding, dilation) with weights supplied as an image.
//
// Each work-item accumulates up to four horizontally adjacent output texels (out0..out3)
// for one output-channel block and one output row:
//   get_global_id(0) = output-channel block * out_width_blocks + column group
//   get_global_id(1) = batch * output_shape.x + output row
//
// Parameters (for every int2, .x is the vertical and .y the horizontal component):
//   global_size_dim0/1      - logical global size; work-items outside it return immediately
//   input, weights, output  - images read / written through RI_F / WI_F
//   bias                    - present only when BIAS is defined; otherwise accumulation
//                             starts from zero
//   input_shape             - input height / width (per channel block)
//   in_channel_block_length - number of input-channel blocks accumulated over
//   output_shape            - output height / width
//   weights_shape           - kernel height / width
//   stride_shape, padding_shape, dilation_shape - the usual convolution geometry
//   out_width_blocks        - divisor that splits get_global_id(0) into
//                             (channel block, column group)
//
// RELU / RELU6 select an optional activation at compile time.
__kernel
#if SET_ATTRIBUTE
__attribute__((work_group_size_hint(16, 16, 1)))
#endif
void conv_2d(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only image2d_t weights,
#ifdef BIAS
                      __read_only image2d_t bias,
#endif
                      __write_only image2d_t output,
                      __private const int2 input_shape,
                      __private const int in_channel_block_length,
                      __private const int2 output_shape,
                      __private const int2 weights_shape,
                      __private const int2 stride_shape,
                      __private const int2 padding_shape,
                      __private const int2 dilation_shape,
                      __private const int out_width_blocks) {

    const int gid_cw = get_global_id(0);
    const int gid_bh = get_global_id(1);
    DEAL_NON_UNIFORM_DIM2(gid_cw, gid_bh);

    const int oc_block = gid_cw / out_width_blocks;
    const int ow_block = gid_cw % out_width_blocks;

#ifdef BIAS
    FLOAT4 out0 = RI_F(bias, SAMPLER, (int2)(oc_block, 0));
#else
    FLOAT4 out0 = (FLOAT4)0;
#endif
    FLOAT4 out1 = out0;
    FLOAT4 out2 = out0;
    FLOAT4 out3 = out0;

    // Left-most input x of each of the four output columns, before the kernel-tap offset.
    // READ_INPUT_IMAGE picks these up by name.
    const int in_width0 = mad24(ow_block, stride_shape.y<<2, -padding_shape.y);
    const int in_width1 = in_width0 + stride_shape.y;
    const int in_width2 = in_width0 + stride_shape.y * 2;
    const int in_width3 = in_width0 + stride_shape.y * 3;

    // Top input row touched by the kernel window; negative when it starts in the padding.
    const int height_start = mad24((gid_bh % output_shape.x), stride_shape.x, -padding_shape.x);

    // Number of leading kernel rows that fall entirely above the image and are skipped.
    const int skipped_rows = select(0, (-height_start + dilation_shape.x - 1) / dilation_shape.x, height_start < 0);

    const int first_iy = mad24(skipped_rows, dilation_shape.x, height_start);
    const int end_iy   = min(mad24(weights_shape.x, dilation_shape.x, height_start), input_shape.x);

    const int batch_row_offset = mul24((gid_bh / output_shape.x), input_shape.x);

    // Row of the weight image holding the first kernel tap that is actually used.
    const int weights_first_row = mul24(oc_block, mul24(weights_shape.y, weights_shape.x)) + mul24(skipped_rows, weights_shape.y);

    FLOAT4 in0, in1, in2, in3;
    FLOAT4 weights0, weights1, weights2, weights3;
    for (int ic = 0; ic < in_channel_block_length; ++ic) {
        // in_idx and in_hb_value are read by READ_INPUT_IMAGE and must keep these names.
        const int in_idx        = mul24(ic, input_shape.y);
        const int weights_x_idx = ic << 2;
        int weights_y_idx       = weights_first_row;

        for (int iy = first_iy; iy < end_iy; iy += dilation_shape.x) {
            const int in_hb_value = iy + batch_row_offset;

            for (int kx = 0; kx < weights_shape.y; ++kx) {
                const int tap_offset = mul24(kx, dilation_shape.y);
                READ_INPUT_IMAGE(0, tap_offset);
                READ_INPUT_IMAGE(1, tap_offset);
                READ_INPUT_IMAGE(2, tap_offset);
                READ_INPUT_IMAGE(3, tap_offset);

                weights0 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 0, weights_y_idx));
                weights1 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 1, weights_y_idx));
                weights2 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 2, weights_y_idx));
                weights3 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 3, weights_y_idx));
                // One weight-image row per kernel tap.
                ++weights_y_idx;

                CALCULATE_OUTPUT(0);
                CALCULATE_OUTPUT(1);
                CALCULATE_OUTPUT(2);
                CALCULATE_OUTPUT(3);
            }
        }
    }

#ifdef RELU
    out0 = fmax(out0, (FLOAT4)0);
    out1 = fmax(out1, (FLOAT4)0);
    out2 = fmax(out2, (FLOAT4)0);
    out3 = fmax(out3, (FLOAT4)0);
#endif

#ifdef RELU6
    out0 = clamp(out0, (FLOAT4)0, (FLOAT4)6);
    out1 = clamp(out1, (FLOAT4)0, (FLOAT4)6);
    out2 = clamp(out2, (FLOAT4)0, (FLOAT4)6);
    out3 = clamp(out3, (FLOAT4)0, (FLOAT4)6);
#endif

    const int first_col = ow_block << 2;
    const int dst_x     = mul24(oc_block, output_shape.y) + first_col;

    // Number of valid columns left in this row from first_col on; columns at or beyond
    // the output width are never written.
    const int remain = output_shape.y - first_col;
    if (remain > 0) {
        WI_F(output, (int2)(dst_x, gid_bh), out0);
    }
    if (remain > 1) {
        WI_F(output, (int2)(dst_x + 1, gid_bh), out1);
    }
    if (remain > 2) {
        WI_F(output, (int2)(dst_x + 2, gid_bh), out2);
    }
    if (remain > 3) {
        WI_F(output, (int2)(dst_x + 3, gid_bh), out3);
    }
}
488