1#ifdef MNN_SUPPORT_FP16
2#pragma OPENCL EXTENSION cl_khr_fp16 : enable
3#endif
/*
 * READ_INPUT_IMAGE(i, base)
 *
 * Declares `inOffset<i>` and fills `inValue<i>` with one 4-wide load from `input`.
 * Steps, in order:
 *   1. inOffset<i> = inWidthOffset<i> + base
 *   2. inOffset<i> becomes -1 when it is outside [0, inputShape.y),
 *      otherwise inCurIdx + inOffset<i>
 *   3. in_c_block_idx / in_w_idx are derived from inOffset<i> by dividing /
 *      taking the remainder with inputShape.y
 *   4. inpOffset is the flat element offset
 *      (((in_b_idx*channelBlocks + in_c_block_idx)*inputShape.x + in_h_idx)*inputShape.y + in_w_idx)*4
 *   5. inValue<i> is zero when inOffset<i> == -1 or inHeightIdx == -1,
 *      otherwise the 4 elements at input + inpOffset.
 *
 * The expansion site must provide: inWidthOffset<i>, inCurIdx, inputShape,
 * in_c_block_idx, in_w_idx, inpOffset, in_b_idx, channelBlocks, in_h_idx,
 * inHeightIdx, inValue<i> and input.
 *
 * `base` is parenthesized so that an argument such as `k << 1` or `a ? b : c`
 * is added as a whole instead of being re-associated with the `+`.
 *
 * NOTE(review): no kernel in the visible part of this file uses this macro.
 */
#define READ_INPUT_IMAGE(i, base)                                                                                   \
    int inOffset##i = inWidthOffset##i + (base);                                                                    \
    inOffset##i =                                                                                                   \
        select(inCurIdx + inOffset##i, -1, (inOffset##i < 0 || inOffset##i >= inputShape.y));                       \
    in_c_block_idx = inOffset##i / inputShape.y;                                                                    \
    in_w_idx = inOffset##i % inputShape.y;                                                                          \
    inpOffset = (((in_b_idx*channelBlocks + in_c_block_idx)*inputShape.x + in_h_idx)* inputShape.y + in_w_idx)*4;   \
    inValue##i = ((inOffset##i)==-1 || inHeightIdx == -1) ? (FLOAT4)0 : vload4(0, input+inpOffset);
12
13
/*
 * CALCULATE_OUTPUT(i)
 *
 * Accumulates into outValue<i> the four products
 *   inValue<i>.x * weights0, .y * weights1, .z * weights2, .w * weights3
 * using mad(), in that order. weights0..weights3, inValue<i> and outValue<i>
 * must be in scope where the macro is expanded.
 */
#define CALCULATE_OUTPUT(i) \
    outValue##i = mad(inValue##i.x, weights0, outValue##i); outValue##i = mad(inValue##i.y, weights1, outValue##i); \
    outValue##i = mad(inValue##i.z, weights2, outValue##i); outValue##i = mad(inValue##i.w, weights3, outValue##i);
19
/*
 * Leading kernel parameters carrying the logical 2-D global size.
 * The expansion ends with a comma, so it is written directly in front of the
 * first real kernel argument. DEAL_NON_UNIFORM_DIM2 compares against these names.
 */
#define GLOBAL_SIZE_2_DIMS                  \
    __private const int global_size_dim0,   \
    __private const int global_size_dim1,
21
/*
 * DEAL_NON_UNIFORM_DIM2(input1, input2)
 *
 * Returns from the enclosing function when input1 >= global_size_dim0 or
 * input2 >= global_size_dim1. global_size_dim0 / global_size_dim1 must be in
 * scope at the expansion site (see GLOBAL_SIZE_2_DIMS).
 *
 * Both arguments are parenthesized so that expressions with operators of lower
 * precedence than `>=` (e.g. `a & b`, `c ? x : y`) are compared as a whole.
 */
#define DEAL_NON_UNIFORM_DIM2(input1, input2)                           \
    if ((input1) >= global_size_dim0 || (input2) >= global_size_dim1) { \
        return;                                                         \
    }
26
/*
 * DW_CONV_NEXT_LINE_CAL(x, y)
 *
 * Adds one 3-tap filter row to two accumulators with mad():
 *   x += inValue0*weights0 + inValue1*weights1 + inValue2*weights2
 *   y += inValue1*weights0 + inValue2*weights1 + inValue3*weights2
 * inValue0..inValue3 and weights0..weights2 must be in scope at the expansion
 * site; x and y must be assignable.
 */
#define DW_CONV_NEXT_LINE_CAL(x, y) \
    x = mad(inValue0, weights0, x); x = mad(inValue1, weights1, x); x = mad(inValue2, weights2, x); \
    y = mad(inValue1, weights0, y); y = mad(inValue2, weights1, y); y = mad(inValue3, weights2, y);
34
/*
 * Depthwise 2-D convolution. One work-item computes one channel block
 * (4 channels) of four consecutive output columns in a single output row.
 *
 * Buffer indexing used by this kernel (4 values per pixel, innermost):
 *   input  : [batch][c_blocks][in_hw.x ][in_hw.y ][4]
 *   output : [batch][c_blocks][out_hw.x][out_hw.y][4]
 *   filter : [filter_hw.x * filter_hw.y][c_blocks][4]
 *   bias   : [c_blocks][4]
 * For every int2 argument, .x is the row (height) quantity and .y the column
 * (width) quantity.
 *
 * global id 0 = channel_block * out_w_blocks + width_block
 * global id 1 = batch * out_hw.x + output_row
 *
 * Filter taps that fall outside the input contribute zero. Columns at or past
 * out_hw.y are computed but not stored. `channel` is not read here.
 * RELU / RELU6, when defined at build time, clamp the result before the store.
 */
__kernel
void depthwise_conv2d_c4h1w4(GLOBAL_SIZE_2_DIMS __global const FLOAT *input,
                             __global const FLOAT *filter,
                             __global const FLOAT *bias,
                             __global FLOAT *output,
                             __private const int2 in_hw,
                             __private const int channel,
                             __private const int2 out_hw,
                             __private const int2 filter_hw,
                             __private const int2 pad_hw,
                             __private const int2 dilate_hw,
                             __private const int2 stride_hw,
                             __private const int out_w_blocks,
                             __private const int c_blocks) {

    const int gid_cw = get_global_id(0); // channel block * out_w_blocks + width block
    const int gid_bh = get_global_id(1); // batch * out_hw.x + output row
    DEAL_NON_UNIFORM_DIM2(gid_cw, gid_bh);

    const int cb    = gid_cw / out_w_blocks;
    const int ow    = (gid_cw % out_w_blocks) << 2; // first of the four output columns
    const int batch = gid_bh / out_hw.x;
    const int oh    = gid_bh % out_hw.x;
    const int plane = batch * c_blocks + cb;        // flat (batch, channel block) index

    // Every accumulator starts from the bias of this channel block.
    FLOAT4 acc0 = vload4(cb, bias);
    FLOAT4 acc1 = acc0;
    FLOAT4 acc2 = acc0;
    FLOAT4 acc3 = acc0;

    // Input column hit by filter tap kw == 0, for each of the four outputs.
    const int iw0 = ow * stride_hw.y - pad_hw.y;
    const int iw1 = iw0 + stride_hw.y;
    const int iw2 = iw1 + stride_hw.y;
    const int iw3 = iw2 + stride_hw.y;
    // Input row hit by filter tap kh == 0.
    const int ih0 = oh * stride_hw.x - pad_hw.x;

    for (int kh = 0; kh < filter_hw.x; ++kh) {
        const int ih = ih0 + kh * dilate_hw.x;
        if (ih < 0 || ih >= in_hw.x) {
            continue; // this filter row lies entirely in the vertical padding
        }

        // Pixel index of column 0 of the current input row.
        const int row = (plane * in_hw.x + ih) * in_hw.y;

        for (int kw = 0; kw < filter_hw.y; ++kw) {
            const int shift = kw * dilate_hw.y;
            const int c0 = iw0 + shift;
            const int c1 = iw1 + shift;
            const int c2 = iw2 + shift;
            const int c3 = iw3 + shift;

            // The load is only evaluated for in-range columns; padding reads as zero.
            const FLOAT4 in0 = (c0 >= 0 && c0 < in_hw.y) ? vload4(row + c0, input) : (FLOAT4)0;
            const FLOAT4 in1 = (c1 >= 0 && c1 < in_hw.y) ? vload4(row + c1, input) : (FLOAT4)0;
            const FLOAT4 in2 = (c2 >= 0 && c2 < in_hw.y) ? vload4(row + c2, input) : (FLOAT4)0;
            const FLOAT4 in3 = (c3 >= 0 && c3 < in_hw.y) ? vload4(row + c3, input) : (FLOAT4)0;

            // filter is indexed [kh * filter_hw.y + kw][channel block]
            const FLOAT4 tap = vload4(mad24(kh, filter_hw.y, kw) * c_blocks + cb, filter);

            acc0 = mad(in0, tap, acc0);
            acc1 = mad(in1, tap, acc1);
            acc2 = mad(in2, tap, acc2);
            acc3 = mad(in3, tap, acc3);
        }
    }

#ifdef RELU
    acc0 = fmax(acc0, (FLOAT4)0);
    acc1 = fmax(acc1, (FLOAT4)0);
    acc2 = fmax(acc2, (FLOAT4)0);
    acc3 = fmax(acc3, (FLOAT4)0);
#endif

#ifdef RELU6
    acc0 = clamp(acc0, (FLOAT4)0, (FLOAT4)6);
    acc1 = clamp(acc1, (FLOAT4)0, (FLOAT4)6);
    acc2 = clamp(acc2, (FLOAT4)0, (FLOAT4)6);
    acc3 = clamp(acc3, (FLOAT4)0, (FLOAT4)6);
#endif

    // Pixel index of the first output column handled by this work-item.
    const int dst = (plane * out_hw.x + oh) * out_hw.y + ow;

    // Only the columns that exist in the output row are written.
    const int remain = out_hw.y - ow;
    if (remain > 0) {
        vstore4(acc0, dst, output);
    }
    if (remain > 1) {
        vstore4(acc1, dst + 1, output);
    }
    if (remain > 2) {
        vstore4(acc2, dst + 2, output);
    }
    if (remain > 3) {
        vstore4(acc3, dst + 3, output);
    }
}
130
/*
 * Depthwise 2-D convolution. One work-item computes one channel block
 * (4 channels) of two consecutive output columns in a single output row.
 *
 * Buffer indexing used by this kernel (4 values per pixel, innermost):
 *   input  : [batch][c_blocks][in_hw.x ][in_hw.y ][4]
 *   output : [batch][c_blocks][out_hw.x][out_hw.y][4]
 *   filter : [filter_hw.x * filter_hw.y][c_blocks][4]
 *   bias   : [c_blocks][4]
 * For every int2 argument, .x is the row (height) quantity and .y the column
 * (width) quantity.
 *
 * global id 0 = channel_block * out_w_blocks + width_block (2 columns per block)
 * global id 1 = batch * out_hw.x + output_row
 *
 * Filter taps that fall outside the input contribute zero. A second column at
 * or past out_hw.y is computed but not stored. `channel` is not read here.
 * RELU / RELU6, when defined at build time, clamp the result before the store.
 */
__kernel
void depthwise_conv2d_c4h1w2(GLOBAL_SIZE_2_DIMS __global const FLOAT *input,
                             __global const FLOAT *filter,
                             __global const FLOAT *bias,
                             __global FLOAT *output,
                             __private const int2 in_hw,
                             __private const int channel,
                             __private const int2 out_hw,
                             __private const int2 filter_hw,
                             __private const int2 pad_hw,
                             __private const int2 dilate_hw,
                             __private const int2 stride_hw,
                             __private const int out_w_blocks,
                             __private const int c_blocks) {

    const int gid_cw = get_global_id(0); // channel block * out_w_blocks + width block
    const int gid_bh = get_global_id(1); // batch * out_hw.x + output row
    DEAL_NON_UNIFORM_DIM2(gid_cw, gid_bh);

    const int cb    = gid_cw / out_w_blocks;
    const int ow    = (gid_cw % out_w_blocks) << 1; // first of the two output columns
    const int batch = gid_bh / out_hw.x;
    const int oh    = gid_bh % out_hw.x;
    const int plane = batch * c_blocks + cb;        // flat (batch, channel block) index

    // Both accumulators start from the bias of this channel block.
    FLOAT4 acc0 = vload4(cb, bias);
    FLOAT4 acc1 = acc0;

    // Input column hit by filter tap kw == 0, for each of the two outputs.
    const int iw0 = ow * stride_hw.y - pad_hw.y;
    const int iw1 = iw0 + stride_hw.y;
    // Input row hit by filter tap kh == 0.
    const int ih0 = oh * stride_hw.x - pad_hw.x;

    for (int kh = 0; kh < filter_hw.x; ++kh) {
        const int ih = ih0 + kh * dilate_hw.x;
        if (ih < 0 || ih >= in_hw.x) {
            continue; // this filter row lies entirely in the vertical padding
        }

        // Pixel index of column 0 of the current input row.
        const int row = (plane * in_hw.x + ih) * in_hw.y;

        for (int kw = 0; kw < filter_hw.y; ++kw) {
            const int shift = kw * dilate_hw.y;
            const int c0 = iw0 + shift;
            const int c1 = iw1 + shift;

            // The load is only evaluated for in-range columns; padding reads as zero.
            const FLOAT4 in0 = (c0 >= 0 && c0 < in_hw.y) ? vload4(row + c0, input) : (FLOAT4)0;
            const FLOAT4 in1 = (c1 >= 0 && c1 < in_hw.y) ? vload4(row + c1, input) : (FLOAT4)0;

            // filter is indexed [kh * filter_hw.y + kw][channel block]
            const FLOAT4 tap = vload4(mad24(kh, filter_hw.y, kw) * c_blocks + cb, filter);

            acc0 = mad(in0, tap, acc0);
            acc1 = mad(in1, tap, acc1);
        }
    }

#ifdef RELU
    acc0 = fmax(acc0, (FLOAT4)0);
    acc1 = fmax(acc1, (FLOAT4)0);
#endif

#ifdef RELU6
    acc0 = clamp(acc0, (FLOAT4)0, (FLOAT4)6);
    acc1 = clamp(acc1, (FLOAT4)0, (FLOAT4)6);
#endif

    // Pixel index of the first output column handled by this work-item.
    const int dst = (plane * out_hw.x + oh) * out_hw.y + ow;

    // Only the columns that exist in the output row are written.
    const int remain = out_hw.y - ow;
    if (remain > 0) {
        vstore4(acc0, dst, output);
    }
    if (remain > 1) {
        vstore4(acc1, dst + 1, output);
    }
}
205
/*
 * Depthwise 2-D convolution. One work-item computes one channel block
 * (4 channels) of a single output pixel.
 *
 * Buffer indexing used by this kernel (4 values per pixel, innermost):
 *   input  : [batch][c_blocks][in_hw.x ][in_hw.y ][4]
 *   output : [batch][c_blocks][out_hw.x][out_hw.y][4]
 *   filter : [filter_hw.x * filter_hw.y][c_blocks][4]
 *   bias   : [c_blocks][4]
 * For every int2 argument, .x is the row (height) quantity and .y the column
 * (width) quantity.
 *
 * global id 0 = channel_block * out_w_blocks + output_column
 * global id 1 = batch * out_hw.x + output_row
 *
 * Filter taps that fall outside the input contribute zero. The store is
 * unconditional, so the column index is not checked against out_hw.y here.
 * NOTE(review): presumably the host passes out_w_blocks == out_hw.y for this
 * variant; confirm against the caller.
 * `channel` is not read here.
 * RELU / RELU6, when defined at build time, clamp the result before the store.
 */
__kernel
void depthwise_conv2d_c4h1w1(GLOBAL_SIZE_2_DIMS __global const FLOAT *input,
                             __global const FLOAT *filter,
                             __global const FLOAT *bias,
                             __global FLOAT *output,
                             __private const int2 in_hw,
                             __private const int channel,
                             __private const int2 out_hw,
                             __private const int2 filter_hw,
                             __private const int2 pad_hw,
                             __private const int2 dilate_hw,
                             __private const int2 stride_hw,
                             __private const int out_w_blocks,
                             __private const int c_blocks) {

    const int gid_cw = get_global_id(0); // channel block * out_w_blocks + output column
    const int gid_bh = get_global_id(1); // batch * out_hw.x + output row
    DEAL_NON_UNIFORM_DIM2(gid_cw, gid_bh);

    const int cb    = gid_cw / out_w_blocks;
    const int ow    = gid_cw % out_w_blocks;
    const int batch = gid_bh / out_hw.x;
    const int oh    = gid_bh % out_hw.x;
    const int plane = batch * c_blocks + cb; // flat (batch, channel block) index

    // The accumulator starts from the bias of this channel block.
    FLOAT4 acc = vload4(cb, bias);

    // Input column / row hit by filter tap (kh, kw) == (0, 0).
    const int iw0 = ow * stride_hw.y - pad_hw.y;
    const int ih0 = oh * stride_hw.x - pad_hw.x;

    for (int kh = 0; kh < filter_hw.x; ++kh) {
        const int ih = ih0 + kh * dilate_hw.x;
        if (ih < 0 || ih >= in_hw.x) {
            continue; // this filter row lies entirely in the vertical padding
        }

        // Pixel index of column 0 of the current input row.
        const int row = (plane * in_hw.x + ih) * in_hw.y;

        for (int kw = 0; kw < filter_hw.y; ++kw) {
            const int col = iw0 + kw * dilate_hw.y;

            // The load is only evaluated for in-range columns; padding reads as zero.
            const FLOAT4 in0 = (col >= 0 && col < in_hw.y) ? vload4(row + col, input) : (FLOAT4)0;

            // filter is indexed [kh * filter_hw.y + kw][channel block]
            const FLOAT4 tap = vload4(mad24(kh, filter_hw.y, kw) * c_blocks + cb, filter);

            acc = mad(in0, tap, acc);
        }
    }

#ifdef RELU
    acc = fmax(acc, (FLOAT4)0);
#endif

#ifdef RELU6
    acc = clamp(acc, (FLOAT4)0, (FLOAT4)6);
#endif

    vstore4(acc, (plane * out_hw.x + oh) * out_hw.y + ow, output);
}
266
/*
 * Depthwise 2-D convolution for unit stride and unit dilation. One work-item
 * computes two adjacent channel blocks (2 x 4 channels) of four consecutive
 * output columns in a single output row.
 *
 * Buffer indexing used by this kernel (4 values per pixel, innermost):
 *   input  : [batch][c_blocks][in_hw.x ][in_hw.y ][4]
 *   output : [batch][c_blocks][out_hw.x][out_hw.y][4]
 *   filter : [filter_hw.x * filter_hw.y][c_blocks][4]
 *   bias   : [c_blocks][4]
 * For every int2 argument, .x is the row (height) quantity and .y the column
 * (width) quantity.
 *
 * global id 0 = channel_block_pair * out_w_blocks + width_block
 * global id 1 = batch * out_hw.x + output_row
 *
 * `channel`, `dilate_hw` and `stride_hw` are not read; the input coordinates
 * are computed as output coordinate - pad + tap index.
 *
 * When c_blocks is odd the last pair has no second channel block. Its results
 * are never stored, but bias/filter/input for it used to be read at block index
 * c_blocks, i.e. past the last valid block. The second block index is therefore
 * clamped to the first one for reads, which keeps every load inside the ranges
 * described above and leaves all stored values unchanged.
 *
 * Filter taps that fall outside the input contribute zero. Columns at or past
 * out_hw.y are computed but not stored.
 * RELU / RELU6, when defined at build time, clamp the result before the store.
 */
__kernel
void depthwise_conv2d_s1_c8h1w4(GLOBAL_SIZE_2_DIMS __global const FLOAT *input,
                                __global const FLOAT *filter,
                                __global const FLOAT *bias,
                                __global FLOAT *output,
                                __private const int2 in_hw,
                                __private const int channel,
                                __private const int2 out_hw,
                                __private const int2 filter_hw,
                                __private const int2 pad_hw,
                                __private const int2 dilate_hw,
                                __private const int2 stride_hw,
                                __private const int out_w_blocks,
                                __private const int c_blocks) {

    const int gid_cw = get_global_id(0); // channel block pair * out_w_blocks + width block
    const int gid_bh = get_global_id(1); // batch * out_hw.x + output row
    DEAL_NON_UNIFORM_DIM2(gid_cw, gid_bh);

    const int cb0   = (gid_cw / out_w_blocks) << 1; // first channel block of the pair
    const int ow    = (gid_cw % out_w_blocks) << 2; // first of the four output columns
    const int batch = gid_bh / out_hw.x;
    const int oh    = gid_bh % out_hw.x;

    // Second channel block of the pair; falls back to cb0 for reads when it does not exist.
    const int has_cb1 = (cb0 + 1 < c_blocks);
    const int cb1     = has_cb1 ? cb0 + 1 : cb0;

    const int plane0 = batch * c_blocks + cb0; // flat (batch, channel block) indices
    const int plane1 = batch * c_blocks + cb1;

    // Accumulators a* belong to cb0, b* to cb1; each starts from its block's bias.
    FLOAT4 a0 = vload4(cb0, bias);
    FLOAT4 a1 = a0;
    FLOAT4 a2 = a0;
    FLOAT4 a3 = a0;
    FLOAT4 b0 = vload4(cb1, bias);
    FLOAT4 b1 = b0;
    FLOAT4 b2 = b0;
    FLOAT4 b3 = b0;

    // Input column / row hit by filter tap (kh, kw) == (0, 0) for the first output column.
    const int iw0 = ow - pad_hw.y;
    const int ih0 = oh - pad_hw.x;

    for (int kh = 0; kh < filter_hw.x; ++kh) {
        const int ih = ih0 + kh;
        if (ih < 0 || ih >= in_hw.x) {
            continue; // this filter row lies entirely in the vertical padding
        }

        // Pixel index of column 0 of the current input row, per channel block.
        const int row0 = (plane0 * in_hw.x + ih) * in_hw.y;
        const int row1 = (plane1 * in_hw.x + ih) * in_hw.y;

        for (int kw = 0; kw < filter_hw.y; ++kw) {
            const int col = iw0 + kw;

            // Column validity is shared by both channel blocks.
            const int ok0 = (col     >= 0 && col     < in_hw.y);
            const int ok1 = (col + 1 >= 0 && col + 1 < in_hw.y);
            const int ok2 = (col + 2 >= 0 && col + 2 < in_hw.y);
            const int ok3 = (col + 3 >= 0 && col + 3 < in_hw.y);

            // Loads are only evaluated for in-range columns; padding reads as zero.
            const FLOAT4 x0 = ok0 ? vload4(row0 + col,     input) : (FLOAT4)0;
            const FLOAT4 x1 = ok1 ? vload4(row0 + col + 1, input) : (FLOAT4)0;
            const FLOAT4 x2 = ok2 ? vload4(row0 + col + 2, input) : (FLOAT4)0;
            const FLOAT4 x3 = ok3 ? vload4(row0 + col + 3, input) : (FLOAT4)0;

            const FLOAT4 y0 = ok0 ? vload4(row1 + col,     input) : (FLOAT4)0;
            const FLOAT4 y1 = ok1 ? vload4(row1 + col + 1, input) : (FLOAT4)0;
            const FLOAT4 y2 = ok2 ? vload4(row1 + col + 2, input) : (FLOAT4)0;
            const FLOAT4 y3 = ok3 ? vload4(row1 + col + 3, input) : (FLOAT4)0;

            // filter is indexed [kh * filter_hw.y + kw][channel block]
            const int tap_base = mad24(kh, filter_hw.y, kw) * c_blocks;
            const FLOAT4 tap0 = vload4(tap_base + cb0, filter);
            const FLOAT4 tap1 = vload4(tap_base + cb1, filter);

            a0 = mad(x0, tap0, a0);
            a1 = mad(x1, tap0, a1);
            a2 = mad(x2, tap0, a2);
            a3 = mad(x3, tap0, a3);

            b0 = mad(y0, tap1, b0);
            b1 = mad(y1, tap1, b1);
            b2 = mad(y2, tap1, b2);
            b3 = mad(y3, tap1, b3);
        }
    }

#ifdef RELU
    a0 = fmax(a0, (FLOAT4)0);
    a1 = fmax(a1, (FLOAT4)0);
    a2 = fmax(a2, (FLOAT4)0);
    a3 = fmax(a3, (FLOAT4)0);

    b0 = fmax(b0, (FLOAT4)0);
    b1 = fmax(b1, (FLOAT4)0);
    b2 = fmax(b2, (FLOAT4)0);
    b3 = fmax(b3, (FLOAT4)0);
#endif

#ifdef RELU6
    a0 = clamp(a0, (FLOAT4)0, (FLOAT4)6);
    a1 = clamp(a1, (FLOAT4)0, (FLOAT4)6);
    a2 = clamp(a2, (FLOAT4)0, (FLOAT4)6);
    a3 = clamp(a3, (FLOAT4)0, (FLOAT4)6);

    b0 = clamp(b0, (FLOAT4)0, (FLOAT4)6);
    b1 = clamp(b1, (FLOAT4)0, (FLOAT4)6);
    b2 = clamp(b2, (FLOAT4)0, (FLOAT4)6);
    b3 = clamp(b3, (FLOAT4)0, (FLOAT4)6);
#endif

    // Only the columns that exist in the output row are written.
    const int remain = out_hw.y - ow;

    // First channel block.
    const int dst0 = (plane0 * out_hw.x + oh) * out_hw.y + ow;
    if (remain > 0) {
        vstore4(a0, dst0, output);
    }
    if (remain > 1) {
        vstore4(a1, dst0 + 1, output);
    }
    if (remain > 2) {
        vstore4(a2, dst0 + 2, output);
    }
    if (remain > 3) {
        vstore4(a3, dst0 + 3, output);
    }

    // Second channel block, only when it exists.
    if (!has_cb1) {
        return;
    }

    // One full output plane (out_hw.x * out_hw.y pixels) after the first block.
    const int dst1 = dst0 + out_hw.x * out_hw.y;
    if (remain > 0) {
        vstore4(b0, dst1, output);
    }
    if (remain > 1) {
        vstore4(b1, dst1 + 1, output);
    }
    if (remain > 2) {
        vstore4(b2, dst1 + 2, output);
    }
    if (remain > 3) {
        vstore4(b3, dst1 + 3, output);
    }
}
405
406
/*
 * Depthwise 2-D convolution for unit stride and unit dilation. One work-item
 * computes two adjacent channel blocks (2 x 4 channels) of two consecutive
 * output columns in a single output row.
 *
 * Buffer indexing used by this kernel (4 values per pixel, innermost):
 *   input  : [batch][c_blocks][in_hw.x ][in_hw.y ][4]
 *   output : [batch][c_blocks][out_hw.x][out_hw.y][4]
 *   filter : [filter_hw.x * filter_hw.y][c_blocks][4]
 *   bias   : [c_blocks][4]
 * For every int2 argument, .x is the row (height) quantity and .y the column
 * (width) quantity.
 *
 * global id 0 = channel_block_pair * out_w_blocks + width_block (2 columns per block)
 * global id 1 = batch * out_hw.x + output_row
 *
 * `channel`, `dilate_hw` and `stride_hw` are not read; the input coordinates
 * are computed as output coordinate - pad + tap index.
 *
 * When c_blocks is odd the last pair has no second channel block. Its results
 * are never stored, but bias/filter/input for it used to be read at block index
 * c_blocks, i.e. past the last valid block. The second block index is therefore
 * clamped to the first one for reads, which keeps every load inside the ranges
 * described above and leaves all stored values unchanged.
 *
 * Filter taps that fall outside the input contribute zero. A second column at
 * or past out_hw.y is computed but not stored.
 * RELU / RELU6, when defined at build time, clamp the result before the store.
 */
__kernel
void depthwise_conv2d_s1_c8h1w2(GLOBAL_SIZE_2_DIMS __global const FLOAT *input,
                                __global const FLOAT *filter,
                                __global const FLOAT *bias,
                                __global FLOAT *output,
                                __private const int2 in_hw,
                                __private const int channel,
                                __private const int2 out_hw,
                                __private const int2 filter_hw,
                                __private const int2 pad_hw,
                                __private const int2 dilate_hw,
                                __private const int2 stride_hw,
                                __private const int out_w_blocks,
                                __private const int c_blocks) {

    const int gid_cw = get_global_id(0); // channel block pair * out_w_blocks + width block
    const int gid_bh = get_global_id(1); // batch * out_hw.x + output row
    DEAL_NON_UNIFORM_DIM2(gid_cw, gid_bh);

    const int cb0   = (gid_cw / out_w_blocks) << 1; // first channel block of the pair
    const int ow    = (gid_cw % out_w_blocks) << 1; // first of the two output columns
    const int batch = gid_bh / out_hw.x;
    const int oh    = gid_bh % out_hw.x;

    // Second channel block of the pair; falls back to cb0 for reads when it does not exist.
    const int has_cb1 = (cb0 + 1 < c_blocks);
    const int cb1     = has_cb1 ? cb0 + 1 : cb0;

    const int plane0 = batch * c_blocks + cb0; // flat (batch, channel block) indices
    const int plane1 = batch * c_blocks + cb1;

    // Accumulators a* belong to cb0, b* to cb1; each starts from its block's bias.
    FLOAT4 a0 = vload4(cb0, bias);
    FLOAT4 a1 = a0;
    FLOAT4 b0 = vload4(cb1, bias);
    FLOAT4 b1 = b0;

    // Input column / row hit by filter tap (kh, kw) == (0, 0) for the first output column.
    const int iw0 = ow - pad_hw.y;
    const int ih0 = oh - pad_hw.x;

    for (int kh = 0; kh < filter_hw.x; ++kh) {
        const int ih = ih0 + kh;
        if (ih < 0 || ih >= in_hw.x) {
            continue; // this filter row lies entirely in the vertical padding
        }

        // Pixel index of column 0 of the current input row, per channel block.
        const int row0 = (plane0 * in_hw.x + ih) * in_hw.y;
        const int row1 = (plane1 * in_hw.x + ih) * in_hw.y;

        for (int kw = 0; kw < filter_hw.y; ++kw) {
            const int col = iw0 + kw;

            // Column validity is shared by both channel blocks.
            const int ok0 = (col     >= 0 && col     < in_hw.y);
            const int ok1 = (col + 1 >= 0 && col + 1 < in_hw.y);

            // Loads are only evaluated for in-range columns; padding reads as zero.
            const FLOAT4 x0 = ok0 ? vload4(row0 + col,     input) : (FLOAT4)0;
            const FLOAT4 x1 = ok1 ? vload4(row0 + col + 1, input) : (FLOAT4)0;

            const FLOAT4 y0 = ok0 ? vload4(row1 + col,     input) : (FLOAT4)0;
            const FLOAT4 y1 = ok1 ? vload4(row1 + col + 1, input) : (FLOAT4)0;

            // filter is indexed [kh * filter_hw.y + kw][channel block]
            const int tap_base = mad24(kh, filter_hw.y, kw) * c_blocks;
            const FLOAT4 tap0 = vload4(tap_base + cb0, filter);
            const FLOAT4 tap1 = vload4(tap_base + cb1, filter);

            a0 = mad(x0, tap0, a0);
            a1 = mad(x1, tap0, a1);

            b0 = mad(y0, tap1, b0);
            b1 = mad(y1, tap1, b1);
        }
    }

#ifdef RELU
    a0 = fmax(a0, (FLOAT4)0);
    a1 = fmax(a1, (FLOAT4)0);

    b0 = fmax(b0, (FLOAT4)0);
    b1 = fmax(b1, (FLOAT4)0);
#endif

#ifdef RELU6
    a0 = clamp(a0, (FLOAT4)0, (FLOAT4)6);
    a1 = clamp(a1, (FLOAT4)0, (FLOAT4)6);

    b0 = clamp(b0, (FLOAT4)0, (FLOAT4)6);
    b1 = clamp(b1, (FLOAT4)0, (FLOAT4)6);
#endif

    // Only the columns that exist in the output row are written.
    const int remain = out_hw.y - ow;

    // First channel block.
    const int dst0 = (plane0 * out_hw.x + oh) * out_hw.y + ow;
    if (remain > 0) {
        vstore4(a0, dst0, output);
    }
    if (remain > 1) {
        vstore4(a1, dst0 + 1, output);
    }

    // Second channel block, only when it exists.
    if (!has_cb1) {
        return;
    }

    // One full output plane (out_hw.x * out_hw.y pixels) after the first block.
    const int dst1 = dst0 + out_hw.x * out_hw.y;
    if (remain > 0) {
        vstore4(b0, dst1, output);
    }
    if (remain > 1) {
        vstore4(b1, dst1 + 1, output);
    }
}
505
/*
 * Depthwise 2-D convolution for unit stride and unit dilation. One work-item
 * computes one channel block (4 channels) of four consecutive output columns
 * in a single output row.
 *
 * Buffer indexing used by this kernel (4 values per pixel, innermost):
 *   input  : [batch][c_blocks][in_hw.x ][in_hw.y ][4]
 *   output : [batch][c_blocks][out_hw.x][out_hw.y][4]
 *   filter : [filter_hw.x * filter_hw.y][c_blocks][4]
 *   bias   : [c_blocks][4]
 * For every int2 argument, .x is the row (height) quantity and .y the column
 * (width) quantity.
 *
 * global id 0 = channel_block * out_w_blocks + width_block
 * global id 1 = batch * out_hw.x + output_row
 *
 * `channel`, `dilate_hw` and `stride_hw` are not read; the input coordinates
 * are computed as output coordinate - pad + tap index.
 *
 * Filter taps that fall outside the input contribute zero. Columns at or past
 * out_hw.y are computed but not stored.
 * RELU / RELU6, when defined at build time, clamp the result before the store.
 */
__kernel
void depthwise_conv2d_s1_c4h1w4(GLOBAL_SIZE_2_DIMS __global const FLOAT *input,
                                __global const FLOAT *filter,
                                __global const FLOAT *bias,
                                __global FLOAT *output,
                                __private const int2 in_hw,
                                __private const int channel,
                                __private const int2 out_hw,
                                __private const int2 filter_hw,
                                __private const int2 pad_hw,
                                __private const int2 dilate_hw,
                                __private const int2 stride_hw,
                                __private const int out_w_blocks,
                                __private const int c_blocks) {

    const int gid_cw = get_global_id(0); // channel block * out_w_blocks + width block
    const int gid_bh = get_global_id(1); // batch * out_hw.x + output row
    DEAL_NON_UNIFORM_DIM2(gid_cw, gid_bh);

    const int cb    = gid_cw / out_w_blocks;
    const int ow    = (gid_cw % out_w_blocks) << 2; // first of the four output columns
    const int batch = gid_bh / out_hw.x;
    const int oh    = gid_bh % out_hw.x;
    const int plane = batch * c_blocks + cb;        // flat (batch, channel block) index

    // Every accumulator starts from the bias of this channel block.
    FLOAT4 acc0 = vload4(cb, bias);
    FLOAT4 acc1 = acc0;
    FLOAT4 acc2 = acc0;
    FLOAT4 acc3 = acc0;

    // Input column / row hit by filter tap (kh, kw) == (0, 0) for the first output column.
    const int iw0 = ow - pad_hw.y;
    const int ih0 = oh - pad_hw.x;

    for (int kh = 0; kh < filter_hw.x; ++kh) {
        const int ih = ih0 + kh;
        if (ih < 0 || ih >= in_hw.x) {
            continue; // this filter row lies entirely in the vertical padding
        }

        // Pixel index of column 0 of the current input row.
        const int row = (plane * in_hw.x + ih) * in_hw.y;

        for (int kw = 0; kw < filter_hw.y; ++kw) {
            // Adjacent outputs read adjacent input columns because the stride is 1.
            const int c0 = iw0 + kw;
            const int c1 = c0 + 1;
            const int c2 = c0 + 2;
            const int c3 = c0 + 3;

            // The load is only evaluated for in-range columns; padding reads as zero.
            const FLOAT4 in0 = (c0 >= 0 && c0 < in_hw.y) ? vload4(row + c0, input) : (FLOAT4)0;
            const FLOAT4 in1 = (c1 >= 0 && c1 < in_hw.y) ? vload4(row + c1, input) : (FLOAT4)0;
            const FLOAT4 in2 = (c2 >= 0 && c2 < in_hw.y) ? vload4(row + c2, input) : (FLOAT4)0;
            const FLOAT4 in3 = (c3 >= 0 && c3 < in_hw.y) ? vload4(row + c3, input) : (FLOAT4)0;

            // filter is indexed [kh * filter_hw.y + kw][channel block]
            const FLOAT4 tap = vload4(mad24(kh, filter_hw.y, kw) * c_blocks + cb, filter);

            acc0 = mad(in0, tap, acc0);
            acc1 = mad(in1, tap, acc1);
            acc2 = mad(in2, tap, acc2);
            acc3 = mad(in3, tap, acc3);
        }
    }

#ifdef RELU
    acc0 = fmax(acc0, (FLOAT4)0);
    acc1 = fmax(acc1, (FLOAT4)0);
    acc2 = fmax(acc2, (FLOAT4)0);
    acc3 = fmax(acc3, (FLOAT4)0);
#endif

#ifdef RELU6
    acc0 = clamp(acc0, (FLOAT4)0, (FLOAT4)6);
    acc1 = clamp(acc1, (FLOAT4)0, (FLOAT4)6);
    acc2 = clamp(acc2, (FLOAT4)0, (FLOAT4)6);
    acc3 = clamp(acc3, (FLOAT4)0, (FLOAT4)6);
#endif

    // Pixel index of the first output column handled by this work-item.
    const int dst = (plane * out_hw.x + oh) * out_hw.y + ow;

    // Only the columns that exist in the output row are written.
    const int remain = out_hw.y - ow;
    if (remain > 0) {
        vstore4(acc0, dst, output);
    }
    if (remain > 1) {
        vstore4(acc1, dst + 1, output);
    }
    if (remain > 2) {
        vstore4(acc2, dst + 2, output);
    }
    if (remain > 3) {
        vstore4(acc3, dst + 3, output);
    }
}
600
601
602__kernel
603void depthwise_conv2d_k3s1p1_c4h1w2(GLOBAL_SIZE_2_DIMS __global const FLOAT *input,
604                                  __global const FLOAT *filter,
605                                  __global const FLOAT *bias,
606                                  __global FLOAT *output,
607                                  __private const int2 in_hw,
608                                  __private const int channel,
609                                  __private const int2 out_hw,
610                                  __private const int2 filter_hw,
611                                  __private const int2 pad_hw,
612                                  __private const int2 dilate_hw,
613                                  __private const int2 stride_hw,
614                                  __private const int out_w_blocks,
615                                  __private const int c_blocks) {
616
617    const int out_c_w_idx = get_global_id(0);// oc/4 * ow/2
618    const int out_b_h_idx  = get_global_id(1);// b * h
619    DEAL_NON_UNIFORM_DIM2(out_c_w_idx, out_b_h_idx);
620
621    const int c_idx = out_c_w_idx / out_w_blocks;
622    const int out_w_idx = out_c_w_idx % out_w_blocks;
623    const int b_idx = out_b_h_idx / out_hw.x;
624    const int out_h_idx = out_b_h_idx % out_hw.x;
625
626    FLOAT4 outValue0 = vload4(c_idx, bias);
627    FLOAT4 outValue1 = outValue0;
628
629    const int out_w2_idx = out_w_idx << 1;
630    const int in_w_start_0 = out_w2_idx - pad_hw.y;
631
632    const int in_h_start = out_h_idx - pad_hw.x;
633    FLOAT4 inValue0, inValue1, inValue2, inValue3;
634    //first line
635    const int inp_offset = (((b_idx*c_blocks + c_idx)*in_hw.x + in_h_start)* in_hw.y + in_w_start_0)*4;
636    inValue0 = (in_h_start < 0 || in_w_start_0 < 0          ) ? (FLOAT4)0 : vload4(0, input+inp_offset);
637    inValue1 = (in_h_start < 0 || in_w_start_0+1 >=  in_hw.y) ? (FLOAT4)0 : vload4(1, input+inp_offset);
638    inValue2 = (in_h_start < 0 || in_w_start_0+2 >=  in_hw.y) ? (FLOAT4)0 : vload4(2, input+inp_offset);
639    inValue3 = (in_h_start < 0 || in_w_start_0+3 >=  in_hw.y) ? (FLOAT4)0 : vload4(3, input+inp_offset);
640
641    int filter_idx = mad24(0, filter_hw.y, 0);
642    FLOAT4 weights = vload4(0, filter+(filter_idx*c_blocks+c_idx)*4);
643    outValue0 = mad(inValue0, weights, outValue0);
644    outValue1 = mad(inValue1, weights, outValue1);
645
646    filter_idx = mad24(0, filter_hw.y, 1);
647    weights = vload4(0, filter+(filter_idx*c_blocks+c_idx)*4);
648    outValue0 = mad(inValue1, weights, outValue0);
649    outValue1 = mad(inValue2, weights, outValue1);
650
651    filter_idx = mad24(0, filter_hw.y, 2);
652    weights = vload4(0, filter+(filter_idx*c_blocks+c_idx)*4);
653    outValue0 = mad(inValue2, weights, outValue0);
654    outValue1 = mad(inValue3, weights, outValue1);
655
656    //second line
657    inValue0 = (in_h_start+1 >= in_hw.x || in_w_start_0 < 0          ) ? (FLOAT4)0 : vload4(in_hw.y+0, input+inp_offset);
658    inValue1 = (in_h_start+1 >= in_hw.x || in_w_start_0+1 >=  in_hw.y) ? (FLOAT4)0 : vload4(in_hw.y+1, input+inp_offset);
659    inValue2 = (in_h_start+1 >= in_hw.x || in_w_start_0+2 >=  in_hw.y) ? (FLOAT4)0 : vload4(in_hw.y+2, input+inp_offset);
660    inValue3 = (in_h_start+1 >= in_hw.x || in_w_start_0+3 >=  in_hw.y) ? (FLOAT4)0 : vload4(in_hw.y+3, input+inp_offset);
661
662    filter_idx = mad24(1, filter_hw.y, 0);
663    weights = vload4(0, filter+(filter_idx*c_blocks+c_idx)*4);
664    outValue0 = mad(inValue0, weights, outValue0);
665    outValue1 = mad(inValue1, weights, outValue1);
666
667    filter_idx = mad24(1, filter_hw.y, 1);
668    weights = vload4(0, filter+(filter_idx*c_blocks+c_idx)*4);
669    outValue0 = mad(inValue1, weights, outValue0);
670    outValue1 = mad(inValue2, weights, outValue1);
671
672    filter_idx = mad24(1, filter_hw.y, 2);
673    weights = vload4(0, filter+(filter_idx*c_blocks+c_idx)*4);
674    outValue0 = mad(inValue2, weights, outValue0);
675    outValue1 = mad(inValue3, weights, outValue1);
676
677    //third line
678    inValue0 = (in_h_start+2 >= in_hw.x || in_w_start_0 < 0          ) ? (FLOAT4)0 : vload4(2*in_hw.y+0, input+inp_offset);
679    inValue1 = (in_h_start+2 >= in_hw.x || in_w_start_0+1 >=  in_hw.y) ? (FLOAT4)0 : vload4(2*in_hw.y+1, input+inp_offset);
680    inValue2 = (in_h_start+2 >= in_hw.x || in_w_start_0+2 >=  in_hw.y) ? (FLOAT4)0 : vload4(2*in_hw.y+2, input+inp_offset);
681    inValue3 = (in_h_start+2 >= in_hw.x || in_w_start_0+3 >=  in_hw.y) ? (FLOAT4)0 : vload4(2*in_hw.y+3, input+inp_offset);
682
683    filter_idx = mad24(2, filter_hw.y, 0);
684    weights = vload4(0, filter+(filter_idx*c_blocks+c_idx)*4);
685    outValue0 = mad(inValue0, weights, outValue0);
686    outValue1 = mad(inValue1, weights, outValue1);
687
688    filter_idx = mad24(2, filter_hw.y, 1);
689    weights = vload4(0, filter+(filter_idx*c_blocks+c_idx)*4);
690    outValue0 = mad(inValue1, weights, outValue0);
691    outValue1 = mad(inValue2, weights, outValue1);
692
693    filter_idx = mad24(2, filter_hw.y, 2);
694    weights = vload4(0, filter+(filter_idx*c_blocks+c_idx)*4);
695    outValue0 = mad(inValue2, weights, outValue0);
696    outValue1 = mad(inValue3, weights, outValue1);
697
698#ifdef RELU
699    outValue0 = fmax(outValue0, (FLOAT4)0);
700    outValue1 = fmax(outValue1, (FLOAT4)0);
701#endif
702
703#ifdef RELU6
704    outValue0 = clamp(outValue0, (FLOAT4)0, (FLOAT4)6);
705    outValue1 = clamp(outValue1, (FLOAT4)0, (FLOAT4)6);
706#endif
707
708    const int out_offset = (((b_idx*c_blocks + c_idx)*out_hw.x + out_h_idx)*out_hw.y + out_w2_idx)*4;
709
710    const int remain     = out_hw.y - out_w2_idx;
711    if (remain >= 2) {
712        vstore4(outValue0, 0, output+out_offset);
713        vstore4(outValue1, 1, output+out_offset);
714    } else if (remain == 1) {
715        vstore4(outValue0, 0, output+out_offset);
716    }
717}
718
719
720__kernel
721void depthwise_conv2d_k3s1p1_c4h2w2(GLOBAL_SIZE_2_DIMS __global const FLOAT *input,
722                                  __global const FLOAT *filter,
723                                  __global const FLOAT *bias,
724                                  __global FLOAT *output,
725                                  __private const int2 in_hw,
726                                  __private const int channel,
727                                  __private const int2 out_hw,
728                                  __private const int2 filter_hw,
729                                  __private const int2 pad_hw,
730                                  __private const int2 dilate_hw,
731                                  __private const int2 stride_hw,
732                                  __private const int out_w_blocks,
733                                  __private const int c_blocks) {
734
735    const int out_c_w_idx = get_global_id(0);// oc/4 * ow/2
736    const int out_b_h_idx  = get_global_id(1);// b * h/2
737    DEAL_NON_UNIFORM_DIM2(out_c_w_idx, out_b_h_idx);
738
739    const int out_h_blocks = (out_hw.x + 1) / 2;
740    const int c_idx = out_c_w_idx / out_w_blocks;
741    const int out_w_idx = out_c_w_idx % out_w_blocks;
742    const int b_idx = out_b_h_idx / out_h_blocks;
743    const int out_h_idx = out_b_h_idx % out_h_blocks;
744
745    FLOAT4 outValue0 = vload4(c_idx, bias);
746    FLOAT4 outValue1 = outValue0;
747    FLOAT4 outValue2 = outValue0;
748    FLOAT4 outValue3 = outValue0;
749
750    const int out_w2_idx = out_w_idx << 1;
751    const int in_w_start = out_w2_idx - pad_hw.y;
752
753    const int out_h2_idx = out_h_idx << 1;
754    const int in_h_start = out_h2_idx - pad_hw.x;
755    FLOAT4 inValue0, inValue1, inValue2, inValue3;
756    //first line
757    const int inp_offset = (((b_idx*c_blocks + c_idx)*in_hw.x + in_h_start)* in_hw.y + in_w_start)*4;
758    inValue0 = (in_h_start < 0 || in_w_start < 0          ) ? (FLOAT4)0 : vload4(0, input+inp_offset);
759    inValue1 = (in_h_start < 0 || in_w_start+1 >=  in_hw.y) ? (FLOAT4)0 : vload4(1, input+inp_offset);
760    inValue2 = (in_h_start < 0 || in_w_start+2 >=  in_hw.y) ? (FLOAT4)0 : vload4(2, input+inp_offset);
761    inValue3 = (in_h_start < 0 || in_w_start+3 >=  in_hw.y) ? (FLOAT4)0 : vload4(3, input+inp_offset);
762
763    int filter_idx = mad24(0, filter_hw.y, 0);
764    FLOAT4 weights0 = vload4(0, filter+(filter_idx*c_blocks+c_idx)*4);
765    outValue0 = mad(inValue0, weights0, outValue0);
766    outValue1 = mad(inValue1, weights0, outValue1);
767
768    filter_idx = mad24(0, filter_hw.y, 1);
769    FLOAT4 weights1 = vload4(0, filter+(filter_idx*c_blocks+c_idx)*4);
770    outValue0 = mad(inValue1, weights1, outValue0);
771    outValue1 = mad(inValue2, weights1, outValue1);
772
773
774    filter_idx = mad24(0, filter_hw.y, 2);
775    FLOAT4 weights2 = vload4(0, filter+(filter_idx*c_blocks+c_idx)*4);
776    outValue0 = mad(inValue2, weights2, outValue0);
777    outValue1 = mad(inValue3, weights2, outValue1);
778
779    //second line
780    inValue0 = (in_h_start+1 >= in_hw.x || in_w_start < 0          ) ? (FLOAT4)0 : vload4(in_hw.y+0, input+inp_offset);
781    inValue1 = (in_h_start+1 >= in_hw.x || in_w_start+1 >=  in_hw.y) ? (FLOAT4)0 : vload4(in_hw.y+1, input+inp_offset);
782    inValue2 = (in_h_start+1 >= in_hw.x || in_w_start+2 >=  in_hw.y) ? (FLOAT4)0 : vload4(in_hw.y+2, input+inp_offset);
783    inValue3 = (in_h_start+1 >= in_hw.x || in_w_start+3 >=  in_hw.y) ? (FLOAT4)0 : vload4(in_hw.y+3, input+inp_offset);
784
785    DW_CONV_NEXT_LINE_CAL(outValue2, outValue3)
786
787    filter_idx = mad24(1, filter_hw.y, 0);
788    weights0 = vload4(0, filter+(filter_idx*c_blocks+c_idx)*4);
789    outValue0 = mad(inValue0, weights0, outValue0);
790    outValue1 = mad(inValue1, weights0, outValue1);
791
792    filter_idx = mad24(1, filter_hw.y, 1);
793    weights1 = vload4(0, filter+(filter_idx*c_blocks+c_idx)*4);
794    outValue0 = mad(inValue1, weights1, outValue0);
795    outValue1 = mad(inValue2, weights1, outValue1);
796
797    filter_idx = mad24(1, filter_hw.y, 2);
798    weights2 = vload4(0, filter+(filter_idx*c_blocks+c_idx)*4);
799    outValue0 = mad(inValue2, weights2, outValue0);
800    outValue1 = mad(inValue3, weights2, outValue1);
801
802    //third line
803    inValue0 = (in_h_start+2 >= in_hw.x || in_w_start < 0          ) ? (FLOAT4)0 : vload4(2*in_hw.y+0, input+inp_offset);
804    inValue1 = (in_h_start+2 >= in_hw.x || in_w_start+1 >=  in_hw.y) ? (FLOAT4)0 : vload4(2*in_hw.y+1, input+inp_offset);
805    inValue2 = (in_h_start+2 >= in_hw.x || in_w_start+2 >=  in_hw.y) ? (FLOAT4)0 : vload4(2*in_hw.y+2, input+inp_offset);
806    inValue3 = (in_h_start+2 >= in_hw.x || in_w_start+3 >=  in_hw.y) ? (FLOAT4)0 : vload4(2*in_hw.y+3, input+inp_offset);
807
808    DW_CONV_NEXT_LINE_CAL(outValue2, outValue3)
809
810    filter_idx = mad24(2, filter_hw.y, 0);
811    weights0 = vload4(0, filter+(filter_idx*c_blocks+c_idx)*4);
812    outValue0 = mad(inValue0, weights0, outValue0);
813    outValue1 = mad(inValue1, weights0, outValue1);
814
815    filter_idx = mad24(2, filter_hw.y, 1);
816    weights1 = vload4(0, filter+(filter_idx*c_blocks+c_idx)*4);
817    outValue0 = mad(inValue1, weights1, outValue0);
818    outValue1 = mad(inValue2, weights1, outValue1);
819
820    filter_idx = mad24(2, filter_hw.y, 2);
821    weights2 = vload4(0, filter+(filter_idx*c_blocks+c_idx)*4);
822    outValue0 = mad(inValue2, weights2, outValue0);
823    outValue1 = mad(inValue3, weights2, outValue1);
824
825
826    //forth line
827    inValue0 = (in_h_start+3 >= in_hw.x || in_w_start < 0          ) ? (FLOAT4)0 : vload4(3*in_hw.y+0, input+inp_offset);
828    inValue1 = (in_h_start+3 >= in_hw.x || in_w_start+1 >=  in_hw.y) ? (FLOAT4)0 : vload4(3*in_hw.y+1, input+inp_offset);
829    inValue2 = (in_h_start+3 >= in_hw.x || in_w_start+2 >=  in_hw.y) ? (FLOAT4)0 : vload4(3*in_hw.y+2, input+inp_offset);
830    inValue3 = (in_h_start+3 >= in_hw.x || in_w_start+3 >=  in_hw.y) ? (FLOAT4)0 : vload4(3*in_hw.y+3, input+inp_offset);
831
832    DW_CONV_NEXT_LINE_CAL(outValue2, outValue3)
833
834#ifdef RELU
835    outValue0 = fmax(outValue0, (FLOAT4)0);
836    outValue1 = fmax(outValue1, (FLOAT4)0);
837    outValue2 = fmax(outValue2, (FLOAT4)0);
838    outValue3 = fmax(outValue3, (FLOAT4)0);
839#endif
840
841#ifdef RELU6
842    outValue0 = clamp(outValue0, (FLOAT4)0, (FLOAT4)6);
843    outValue1 = clamp(outValue1, (FLOAT4)0, (FLOAT4)6);
844    outValue2 = clamp(outValue2, (FLOAT4)0, (FLOAT4)6);
845    outValue3 = clamp(outValue3, (FLOAT4)0, (FLOAT4)6);
846#endif
847
848    const int out_offset = (((b_idx*c_blocks + c_idx)*out_hw.x + out_h2_idx)*out_hw.y + out_w2_idx)*4;
849
850    const int remain_w     = out_hw.y - out_w2_idx;
851    const int remain_h     = out_hw.x - out_h2_idx;
852
853    if(remain_w >= 2 && remain_h >= 2) {
854        vstore4(outValue0, 0, output+out_offset);
855        vstore4(outValue1, 1, output+out_offset);
856        vstore4(outValue2, out_hw.y+0, output+out_offset);
857        vstore4(outValue3, out_hw.y+1, output+out_offset);
858    } else if(remain_w == 1  && remain_h >= 2) {
859        vstore4(outValue0, 0, output+out_offset);
860        vstore4(outValue2, out_hw.y+0, output+out_offset);
861    } else if(remain_w >= 2  && remain_h == 1) {
862        vstore4(outValue0, 0, output+out_offset);
863        vstore4(outValue1, 1, output+out_offset);
864    } else {
865        vstore4(outValue0, 0, output+out_offset);
866    }
867}
868