1#ifdef MNN_SUPPORT_FP16
2#pragma OPENCL EXTENSION cl_khr_fp16 : enable
3#endif
/* Load one FLOAT4 input pixel into in##i from an NC4HW4 input buffer, or
 * zero when its column is outside the input.
 *   column = in_width##i + base
 *   If column < 0 or column >= input_shape.y it is replaced by -1 (padding);
 *   otherwise it becomes in_idx + column, and the width index is that value
 *   modulo input_shape.y.
 *   presumably in_idx is a multiple of input_shape.y, so the modulo recovers
 *   `column` — TODO confirm against the caller.
 * Requires these caller-scope names:
 *   in_width##i, in_idx, input_shape (.x = height, .y = width, as used
 *   below), in_w_idx, inp_offset, in_b_idx, in_channel_block_length,
 *   in_channel_block_idx, in_h_idx, input, and the destination in##i.
 * Declares `int in_width_value##i`, so it can be expanded only once per `i`
 * per scope. Not a single statement, so it must not be the unbraced body of
 * an if/for.
 * NOTE(review): not referenced by any kernel visible in this chunk. */
#define READ_INPUT_IMAGE(i, base)                                                                         \
    int in_width_value##i = in_width##i + base;                                                           \
    in_width_value##i =                                                                                   \
        select(in_idx + in_width_value##i, -1, (in_width_value##i < 0 || in_width_value##i >= input_shape.y)); \
    in_w_idx = in_width_value##i % input_shape.y; \
    inp_offset = (((in_b_idx*in_channel_block_length + in_channel_block_idx)*input_shape.x + in_h_idx)* input_shape.y + in_w_idx)*4; \
    in##i = (in_width_value##i)==-1 ? (FLOAT4)0 : vload4(0, input+inp_offset);
11
/* Accumulate input pixel in##i into accumulator out##i:
 *   out##i += in##i.x * weights0 + in##i.y * weights1
 *           + in##i.z * weights2 + in##i.w * weights3
 * This is applied as four sequential mad() calls in exactly that order.
 * Expects in##i, out##i and weights0..weights3 in the caller's scope
 * (presumably all FLOAT4).
 * NOTE(review): expands to four statements with no do/while(0) wrapper, so it
 * must not be used as the unbraced body of an if/for. */
#define CALCULATE_OUTPUT(i)                  \
    out##i = mad(in##i.x, weights0, out##i); \
    out##i = mad(in##i.y, weights1, out##i); \
    out##i = mad(in##i.z, weights2, out##i); \
    out##i = mad(in##i.w, weights3, out##i);
17
/* Leading kernel parameters carrying the logical 2-D problem size, which
 * DEAL_NON_UNIFORM_DIM2 checks against. The expansion ends with a comma, so it
 * must be followed by further parameters (as in `GLOBAL_SIZE_2_DIMS __global ...`). */
#define GLOBAL_SIZE_2_DIMS __private const int global_size_dim0, __private const int global_size_dim1,
19
/* Return early from the kernel for work-items beyond the logical 2-D problem
 * size (global_size_dim0 x global_size_dim1, declared by GLOBAL_SIZE_2_DIMS).
 * NOTE(review): presumably needed because the host rounds the global work
 * size up to a multiple of the work-group size — confirm.
 * The do { } while (0) wrapper makes `DEAL_NON_UNIFORM_DIM2(a, b);` a single
 * statement, so it composes safely with an unbraced if/else. The arguments are
 * parenthesized so expression arguments bind correctly. */
#define DEAL_NON_UNIFORM_DIM2(input1, input2)                                 \
    do {                                                                      \
        if ((input1) >= global_size_dim0 || (input2) >= global_size_dim1) {   \
            return;                                                           \
        }                                                                     \
    } while (0)
24
/*
 * Direct 2-D convolution on NC4HW4 buffers. Each work-item produces one output
 * pixel for one block of 4 output channels.
 *
 * Work-item mapping:
 *   dim0 = out_c_block * out_hw.y + out_w
 *   dim1 = batch * out_hw.x + out_h
 * Filter taps whose input position falls outside the input are removed by
 * clipping the row/column ranges up front, not by testing each tap.
 *
 * Layouts, as indexed below:
 *   input  [batch][in_c_blocks][in_hw.x][in_hw.y][4]
 *   weight [4*in_c_blocks input channels][out_c_blocks][kh][kw][4 output channels]
 *   output [batch][out_c_blocks][out_hw.x][out_hw.y][4]
 * `inChannel` and `out_w_blocks` are not used by this kernel.
 * NOTE(review): the loop bound is the compile-time IN_C_BLOCK while strides use
 * in_c_blocks; presumably they are equal — confirm on the host side.
 */
__kernel
void conv_2d_c4h1w1(GLOBAL_SIZE_2_DIMS
                      __global const FLOAT *input,
                      __global const FLOAT *weight,
                      __global const FLOAT *bias,
                      __global FLOAT *output,
                      __private const int2 in_hw,
                      __private const int inChannel,
                      __private const int in_c_blocks,
                      __private const int2 out_hw,
                      __private const int2 filter_hw,
                      __private const int2 stride_hw,
                      __private const int2 pad_hw,
                      __private const int2 dilate_hw,
                      __private const int out_w_blocks,
                      __private const int out_c_blocks) {
    const int gid_cw = get_global_id(0);
    const int gid_bh = get_global_id(1);

    DEAL_NON_UNIFORM_DIM2(gid_cw, gid_bh);

    const int oc_blk = gid_cw / out_hw.y;
    const int ow     = gid_cw % out_hw.y;
    const int batch  = gid_bh / out_hw.x;
    const int oh     = gid_bh % out_hw.x;

    FLOAT4 acc = vload4(oc_blk, bias);

    // Top-left corner of the receptive field (negative when it overlaps padding).
    const int iw_origin = mad24(ow, stride_hw.y, -pad_hw.y);
    const int ih_origin = mad24(oh, stride_hw.x, -pad_hw.x);

    // First filter tap that lands inside the input: ceil(-origin / dilation).
    const int kw_first = (iw_origin < 0) ? (-iw_origin + dilate_hw.y - 1) / dilate_hw.y : 0;
    const int kh_first = (ih_origin < 0) ? (-ih_origin + dilate_hw.x - 1) / dilate_hw.x : 0;

    // Half-open input ranges [begin, end) visited with the dilation step.
    const int iw_begin = mad24(kw_first, dilate_hw.y, iw_origin);
    const int iw_end   = min(mad24(filter_hw.y, dilate_hw.y, iw_origin), in_hw.y);
    const int ih_begin = mad24(kh_first, dilate_hw.x, ih_origin);
    const int ih_end   = min(mad24(filter_hw.x, dilate_hw.x, ih_origin), in_hw.x);

    // Distance in FLOATs between the weights of consecutive input channels.
    const int ic_stride = out_c_blocks * filter_hw.x * filter_hw.y * 4;

    for (ushort icb = 0; icb < (ushort)IN_C_BLOCK; icb++) {
        // Weights of input channel 4*icb for this output block, at tap (kh_first, kw_first).
        __global const FLOAT *w_row = weight
            + ((((4 * icb) * out_c_blocks + oc_blk) * filter_hw.x + kh_first) * filter_hw.y + kw_first) * 4;

        for (int iy = ih_begin; iy < ih_end; iy += dilate_hw.x) {
            __global const FLOAT *in_row = input + ((batch * in_c_blocks + icb) * in_hw.x + iy) * in_hw.y * 4;
            __global const FLOAT *w_tap = w_row;

            for (int ix = iw_begin; ix < iw_end; ix += dilate_hw.y) {
                const FLOAT4 px   = vload4(ix, in_row);
                const FLOAT4 w_c0 = vload4(0, w_tap);
                const FLOAT4 w_c1 = vload4(0, w_tap + ic_stride);
                const FLOAT4 w_c2 = vload4(0, w_tap + 2 * ic_stride);
                const FLOAT4 w_c3 = vload4(0, w_tap + 3 * ic_stride);

                acc = mad(px.x, w_c0, acc);
                acc = mad(px.y, w_c1, acc);
                acc = mad(px.z, w_c2, acc);
                acc = mad(px.w, w_c3, acc);

                w_tap += 4;   // next kw tap
            }
            w_row += 4 * filter_hw.y;   // next kh row
        }
    }
#ifdef RELU
    acc = fmax(acc, (FLOAT4)0);
#endif

#ifdef RELU6
    acc = clamp(acc, (FLOAT4)0, (FLOAT4)6);
#endif

    __global FLOAT *dst = output + (((batch * out_c_blocks + oc_blk) * out_hw.x + oh) * out_hw.y + ow) * 4;
    vstore4(acc, 0, dst);
}
102
103
/*
 * Direct 2-D convolution on NC4HW4 buffers. Each work-item produces one output
 * pixel for two consecutive blocks of 4 output channels (8 channels).
 *
 * Work-item mapping:
 *   dim0 = out_c_pair * out_hw.y + out_w   (first block = 2 * out_c_pair)
 *   dim1 = batch * out_hw.x + out_h
 * Out-of-input filter taps are removed by clipping the row/column ranges.
 * Weight layout, as indexed below: [4*in_c_blocks][out_c_blocks][kh][kw][4].
 * The second block is stored only when it exists (out_c_blocks may be odd).
 * NOTE(review): its bias and weights are still loaded unconditionally;
 * presumably the host pads those buffers so the reads stay in bounds — confirm.
 * `inChannel` and `out_w_blocks` are not used by this kernel.
 */
__kernel
void conv_2d_c8h1w1(GLOBAL_SIZE_2_DIMS
                      __global const FLOAT *input,
                      __global const FLOAT *weight,
                      __global const FLOAT *bias,
                      __global FLOAT *output,
                      __private const int2 in_hw,
                      __private const int inChannel,
                      __private const int in_c_blocks,
                      __private const int2 out_hw,
                      __private const int2 filter_hw,
                      __private const int2 stride_hw,
                      __private const int2 pad_hw,
                      __private const int2 dilate_hw,
                      __private const int out_w_blocks,
                      __private const int out_c_blocks) {
    const int gid_cw = get_global_id(0);
    const int gid_bh = get_global_id(1);

    DEAL_NON_UNIFORM_DIM2(gid_cw, gid_bh);

    const int oc_blk = 2 * (gid_cw / out_hw.y);
    const int ow     = gid_cw % out_hw.y;
    const int batch  = gid_bh / out_hw.x;
    const int oh     = gid_bh % out_hw.x;

    FLOAT4 acc_lo = vload4(oc_blk, bias);
    FLOAT4 acc_hi = vload4(oc_blk + 1, bias);

    // Top-left corner of the receptive field (negative when it overlaps padding).
    const int iw_origin = mad24(ow, stride_hw.y, -pad_hw.y);
    const int ih_origin = mad24(oh, stride_hw.x, -pad_hw.x);

    // First filter tap that lands inside the input: ceil(-origin / dilation).
    const int kw_first = (iw_origin < 0) ? (-iw_origin + dilate_hw.y - 1) / dilate_hw.y : 0;
    const int kh_first = (ih_origin < 0) ? (-ih_origin + dilate_hw.x - 1) / dilate_hw.x : 0;

    const int iw_begin = mad24(kw_first, dilate_hw.y, iw_origin);
    const int iw_end   = min(mad24(filter_hw.y, dilate_hw.y, iw_origin), in_hw.y);
    const int ih_begin = mad24(kh_first, dilate_hw.x, ih_origin);
    const int ih_end   = min(mad24(filter_hw.x, dilate_hw.x, ih_origin), in_hw.x);

    // Strides in FLOATs: to the next output block, and to the next input channel.
    const int oc_stride = filter_hw.x * filter_hw.y * 4;
    const int ic_stride = out_c_blocks * oc_stride;

    for (ushort icb = 0; icb < (ushort)IN_C_BLOCK; icb++) {
        // Weights of input channel 4*icb for block oc_blk, at tap (kh_first, kw_first).
        __global const FLOAT *w_row = weight
            + ((((4 * icb) * out_c_blocks + oc_blk) * filter_hw.x + kh_first) * filter_hw.y + kw_first) * 4;

        for (int iy = ih_begin; iy < ih_end; iy += dilate_hw.x) {
            __global const FLOAT *in_row = input + ((batch * in_c_blocks + icb) * in_hw.x + iy) * in_hw.y * 4;
            __global const FLOAT *w_tap = w_row;

            for (int ix = iw_begin; ix < iw_end; ix += dilate_hw.y) {
                const FLOAT4 px = vload4(ix, in_row);

                __global const FLOAT *w_lo = w_tap;
                __global const FLOAT *w_hi = w_tap + oc_stride;

                acc_lo = mad(px.x, vload4(0, w_lo), acc_lo);
                acc_lo = mad(px.y, vload4(0, w_lo + ic_stride), acc_lo);
                acc_lo = mad(px.z, vload4(0, w_lo + 2 * ic_stride), acc_lo);
                acc_lo = mad(px.w, vload4(0, w_lo + 3 * ic_stride), acc_lo);

                acc_hi = mad(px.x, vload4(0, w_hi), acc_hi);
                acc_hi = mad(px.y, vload4(0, w_hi + ic_stride), acc_hi);
                acc_hi = mad(px.z, vload4(0, w_hi + 2 * ic_stride), acc_hi);
                acc_hi = mad(px.w, vload4(0, w_hi + 3 * ic_stride), acc_hi);

                w_tap += 4;   // next kw tap
            }
            w_row += 4 * filter_hw.y;   // next kh row
        }
    }
#ifdef RELU
    acc_lo = fmax(acc_lo, (FLOAT4)0);
    acc_hi = fmax(acc_hi, (FLOAT4)0);
#endif

#ifdef RELU6
    acc_lo = clamp(acc_lo, (FLOAT4)0, (FLOAT4)6);
    acc_hi = clamp(acc_hi, (FLOAT4)0, (FLOAT4)6);
#endif

    __global FLOAT *dst = output + (((batch * out_c_blocks + oc_blk) * out_hw.x + oh) * out_hw.y + ow) * 4;
    vstore4(acc_lo, 0, dst);
    if (oc_blk + 1 < out_c_blocks) {
        vstore4(acc_hi, 0, dst + out_hw.x * out_hw.y * 4);
    }
}
196
/*
 * Direct 2-D convolution on NC4HW4 buffers. Each work-item computes two
 * horizontally adjacent output pixels (w and w+1) for one block of 4 output
 * channels.
 *
 * Work-item mapping:
 *   dim0 = out_c_block * out_w_blocks + w_pair   (w = 2 * w_pair)
 *   dim1 = batch * out_hw.x + out_h
 * Out-of-input rows are clipped up front. Columns are checked per tap and
 * contribute zero when out of range. The second pixel is stored only when it
 * lies inside the output row.
 * Weight layout, as indexed below: [4*in_c_blocks][out_c_blocks][kh][kw][4].
 * `inChannel` is not used by this kernel.
 */
__kernel
void conv_2d_c4h1w2(GLOBAL_SIZE_2_DIMS
                      __global const FLOAT *input,
                      __global const FLOAT *weight,
                      __global const FLOAT *bias,
                      __global FLOAT *output,
                      __private const int2 in_hw,
                      __private const int inChannel,
                      __private const int in_c_blocks,
                      __private const int2 out_hw,
                      __private const int2 filter_hw,
                      __private const int2 stride_hw,
                      __private const int2 pad_hw,
                      __private const int2 dilate_hw,
                      __private const int out_w_blocks,//generate width's num
                      __private const int out_c_blocks) {
    const int gid_cw = get_global_id(0);
    const int gid_bh = get_global_id(1);

    DEAL_NON_UNIFORM_DIM2(gid_cw, gid_bh);

    const int oc_blk = gid_cw / out_w_blocks;
    const int ow0    = 2 * (gid_cw % out_w_blocks);
    const int batch  = gid_bh / out_hw.x;
    const int oh     = gid_bh % out_hw.x;

    FLOAT4 acc0 = vload4(oc_blk, bias);
    FLOAT4 acc1 = acc0;

    // Leftmost input column of each pixel's receptive field (may be negative).
    const int iw0_origin = mad24(ow0, stride_hw.y, -pad_hw.y);
    const int iw1_origin = iw0_origin + stride_hw.y;
    const int ih_origin  = mad24(oh, stride_hw.x, -pad_hw.x);

    // Clip the filter rows to those that land inside the input.
    const int kh_first = (ih_origin < 0) ? (-ih_origin + dilate_hw.x - 1) / dilate_hw.x : 0;
    const int ih_begin = mad24(kh_first, dilate_hw.x, ih_origin);
    const int ih_end   = min(mad24(filter_hw.x, dilate_hw.x, ih_origin), in_hw.x);

    // Distance in FLOATs between the weights of consecutive input channels.
    const int ic_stride = out_c_blocks * filter_hw.x * filter_hw.y * 4;

    for (ushort icb = 0; icb < (ushort)IN_C_BLOCK; icb++) {
        // Walks all kw taps of every visited row, starting at (kh_first, 0).
        __global const FLOAT *w_tap = weight
            + (((4 * icb) * out_c_blocks + oc_blk) * filter_hw.x + kh_first) * filter_hw.y * 4;

        for (int iy = ih_begin; iy < ih_end; iy += dilate_hw.x) {
            __global const FLOAT *in_row = input + ((batch * in_c_blocks + icb) * in_hw.x + iy) * in_hw.y * 4;

            for (int kw = 0; kw < filter_hw.y; kw++) {
                const int x0 = kw * dilate_hw.y + iw0_origin;
                const int x1 = kw * dilate_hw.y + iw1_origin;

                const FLOAT4 px0 = (x0 >= 0 && x0 < in_hw.y) ? vload4(x0, in_row) : (FLOAT4)0;
                const FLOAT4 px1 = (x1 >= 0 && x1 < in_hw.y) ? vload4(x1, in_row) : (FLOAT4)0;

                const FLOAT4 w_c0 = vload4(0, w_tap);
                const FLOAT4 w_c1 = vload4(0, w_tap + ic_stride);
                const FLOAT4 w_c2 = vload4(0, w_tap + 2 * ic_stride);
                const FLOAT4 w_c3 = vload4(0, w_tap + 3 * ic_stride);

                acc0 = mad(px0.x, w_c0, acc0);
                acc0 = mad(px0.y, w_c1, acc0);
                acc0 = mad(px0.z, w_c2, acc0);
                acc0 = mad(px0.w, w_c3, acc0);

                acc1 = mad(px1.x, w_c0, acc1);
                acc1 = mad(px1.y, w_c1, acc1);
                acc1 = mad(px1.z, w_c2, acc1);
                acc1 = mad(px1.w, w_c3, acc1);

                w_tap += 4;   // next tap (rows are contiguous, so this also crosses rows)
            }
        }
    }
#ifdef RELU
    acc0 = fmax(acc0, (FLOAT4)0);
    acc1 = fmax(acc1, (FLOAT4)0);
#endif

#ifdef RELU6
    acc0 = clamp(acc0, (FLOAT4)0, (FLOAT4)6);
    acc1 = clamp(acc1, (FLOAT4)0, (FLOAT4)6);
#endif

    __global FLOAT *dst = output + (((batch * out_c_blocks + oc_blk) * out_hw.x + oh) * out_hw.y + ow0) * 4;
    vstore4(acc0, 0, dst);
    if (ow0 + 1 < out_hw.y) {
        vstore4(acc1, 1, dst);
    }
}
285
/*
 * Direct 2-D convolution on NC4HW4 buffers. Each work-item computes four
 * horizontally adjacent output pixels (w .. w+3) for one block of 4 output
 * channels.
 *
 * Work-item mapping:
 *   dim0 = out_c_block * out_w_blocks + w_quad   (w = 4 * w_quad)
 *   dim1 = batch * out_hw.x + out_h
 * Out-of-input rows are clipped up front. Columns are checked per tap and
 * contribute zero when out of range. Pixels past the end of the output row are
 * not stored.
 * Weight layout, as indexed below: [4*in_c_blocks][out_c_blocks][kh][kw][4].
 * `inChannel` is not used by this kernel.
 */
__kernel
void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS
                      __global const FLOAT *input,
                      __global const FLOAT *weight,
                      __global const FLOAT *bias,
                      __global FLOAT *output,
                      __private const int2 in_hw,
                      __private const int inChannel,
                      __private const int in_c_blocks,
                      __private const int2 out_hw,
                      __private const int2 filter_hw,
                      __private const int2 stride_hw,
                      __private const int2 pad_hw,
                      __private const int2 dilate_hw,
                      __private const int out_w_blocks,
                      __private const int out_c_blocks) {
    const int gid_cw = get_global_id(0);
    const int gid_bh = get_global_id(1);

    DEAL_NON_UNIFORM_DIM2(gid_cw, gid_bh);

    const int oc_blk = gid_cw / out_w_blocks;
    const int ow0    = 4 * (gid_cw % out_w_blocks);
    const int batch  = gid_bh / out_hw.x;
    const int oh     = gid_bh % out_hw.x;

    FLOAT4 acc0 = vload4(oc_blk, bias);
    FLOAT4 acc1 = acc0;
    FLOAT4 acc2 = acc0;
    FLOAT4 acc3 = acc0;

    // Leftmost input column of each pixel's receptive field (may be negative).
    const int iw0_origin = mad24(ow0, stride_hw.y, -pad_hw.y);
    const int iw1_origin = iw0_origin + stride_hw.y;
    const int iw2_origin = iw1_origin + stride_hw.y;
    const int iw3_origin = iw2_origin + stride_hw.y;
    const int ih_origin  = mad24(oh, stride_hw.x, -pad_hw.x);

    // Clip the filter rows to those that land inside the input.
    const int kh_first = (ih_origin < 0) ? (-ih_origin + dilate_hw.x - 1) / dilate_hw.x : 0;
    const int ih_begin = mad24(kh_first, dilate_hw.x, ih_origin);
    const int ih_end   = min(mad24(filter_hw.x, dilate_hw.x, ih_origin), in_hw.x);

    // Distance in FLOATs between the weights of consecutive input channels.
    const int ic_stride = out_c_blocks * filter_hw.x * filter_hw.y * 4;

    for (ushort icb = 0; icb < (ushort)IN_C_BLOCK; icb++) {
        // Walks all kw taps of every visited row, starting at (kh_first, 0).
        __global const FLOAT *w_tap = weight
            + (((4 * icb) * out_c_blocks + oc_blk) * filter_hw.x + kh_first) * filter_hw.y * 4;

        for (int iy = ih_begin; iy < ih_end; iy += dilate_hw.x) {
            __global const FLOAT *in_row = input + ((batch * in_c_blocks + icb) * in_hw.x + iy) * in_hw.y * 4;

            for (int kw = 0; kw < filter_hw.y; kw++) {
                const int tap_dx = kw * dilate_hw.y;
                const int x0 = tap_dx + iw0_origin;
                const int x1 = tap_dx + iw1_origin;
                const int x2 = tap_dx + iw2_origin;
                const int x3 = tap_dx + iw3_origin;

                const FLOAT4 px0 = (x0 >= 0 && x0 < in_hw.y) ? vload4(x0, in_row) : (FLOAT4)0;
                const FLOAT4 px1 = (x1 >= 0 && x1 < in_hw.y) ? vload4(x1, in_row) : (FLOAT4)0;
                const FLOAT4 px2 = (x2 >= 0 && x2 < in_hw.y) ? vload4(x2, in_row) : (FLOAT4)0;
                const FLOAT4 px3 = (x3 >= 0 && x3 < in_hw.y) ? vload4(x3, in_row) : (FLOAT4)0;

                const FLOAT4 w_c0 = vload4(0, w_tap);
                const FLOAT4 w_c1 = vload4(0, w_tap + ic_stride);
                const FLOAT4 w_c2 = vload4(0, w_tap + 2 * ic_stride);
                const FLOAT4 w_c3 = vload4(0, w_tap + 3 * ic_stride);

                acc0 = mad(px0.x, w_c0, acc0);
                acc0 = mad(px0.y, w_c1, acc0);
                acc0 = mad(px0.z, w_c2, acc0);
                acc0 = mad(px0.w, w_c3, acc0);

                acc1 = mad(px1.x, w_c0, acc1);
                acc1 = mad(px1.y, w_c1, acc1);
                acc1 = mad(px1.z, w_c2, acc1);
                acc1 = mad(px1.w, w_c3, acc1);

                acc2 = mad(px2.x, w_c0, acc2);
                acc2 = mad(px2.y, w_c1, acc2);
                acc2 = mad(px2.z, w_c2, acc2);
                acc2 = mad(px2.w, w_c3, acc2);

                acc3 = mad(px3.x, w_c0, acc3);
                acc3 = mad(px3.y, w_c1, acc3);
                acc3 = mad(px3.z, w_c2, acc3);
                acc3 = mad(px3.w, w_c3, acc3);

                w_tap += 4;   // next tap (rows are contiguous, so this also crosses rows)
            }
        }
    }
#ifdef RELU
    acc0 = fmax(acc0, (FLOAT4)0);
    acc1 = fmax(acc1, (FLOAT4)0);
    acc2 = fmax(acc2, (FLOAT4)0);
    acc3 = fmax(acc3, (FLOAT4)0);
#endif

#ifdef RELU6
    acc0 = clamp(acc0, (FLOAT4)0, (FLOAT4)6);
    acc1 = clamp(acc1, (FLOAT4)0, (FLOAT4)6);
    acc2 = clamp(acc2, (FLOAT4)0, (FLOAT4)6);
    acc3 = clamp(acc3, (FLOAT4)0, (FLOAT4)6);
#endif

    __global FLOAT *dst = output + (((batch * out_c_blocks + oc_blk) * out_hw.x + oh) * out_hw.y + ow0) * 4;
    vstore4(acc0, 0, dst);
    if (ow0 + 1 < out_hw.y) {
        vstore4(acc1, 1, dst);
    }
    if (ow0 + 2 < out_hw.y) {
        vstore4(acc2, 2, dst);
    }
    if (ow0 + 3 < out_hw.y) {
        vstore4(acc3, 3, dst);
    }
}
400
/*
 * 1x1 convolution on NC4HW4 buffers. The input is indexed with the output's
 * spatial size (out_h, out_w), i.e. no stride or padding. Each work-item
 * computes up to 4 consecutive output pixels along w for one block of 4 output
 * channels.
 *
 * Work-item mapping:
 *   dim0 = out_c_blk * out_w_blocks + w_quad   (pixels 4*w_quad .. 4*w_quad+3)
 *   dim1 = batch * out_h + h
 * Weight layout, in FLOAT4 units: [out_c_block][in_c_block][4]. Vector k of
 * each group holds the 4 input-channel weights of output channel k, hence the
 * dot products below.
 * Pixels past the end of the row are neither loaded nor stored. This keeps
 * the tail work-item from reading past the end of the input buffer when out_w
 * is not a multiple of 4.
 * NOTE(review): the loop bound is the compile-time IN_C_BLOCK while strides use
 * in_c_block; presumably they are equal — confirm on the host side.
 */
__kernel
void conv_2d_1x1_c4h1w4(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks,
                          __global const FLOAT *input,
                          __global const FLOAT *kernel_ptr,
                          __global const FLOAT *bias_ptr,
                          __global FLOAT *output,
                          __private const int in_c_block,
                          __private const int out_h,
                          __private const int out_w,
                          __private const int out_c_block) {

    const int out_c_w_idx = get_global_id(0); //c/4 w/4
    const int out_b_h_idx  = get_global_id(1); //b h

    DEAL_NON_UNIFORM_DIM2(out_c_w_idx, out_b_h_idx);

    const int out_c_idx = out_c_w_idx / out_w_blocks;
    const int out_w_idx = out_c_w_idx % out_w_blocks;
    const int out_b_idx = out_b_h_idx / out_h;//equal to in_b_idx
    const int out_h_idx = out_b_h_idx % out_h;//equal to in_h_idx

    const int out_w4_idx = mul24(out_w_idx, 4);
    // Valid pixels from out_w4_idx to the end of the row; only the first
    // min(remain, 4) of the 4 computed pixels are real.
    const int remain = out_w - out_w4_idx;

    FLOAT4 out0 = vload4(out_c_idx, bias_ptr);
    FLOAT4 out1 = out0;
    FLOAT4 out2 = out0;
    FLOAT4 out3 = out0;

    // Weight offset in FLOAT4 units (4 vectors per input channel block).
    int offset = mul24(out_c_idx, in_c_block) << 2;
    // FLOAT offset of the first pixel of input channel block 0.
    int inp_offset = (((out_b_idx * in_c_block) * out_h + out_h_idx) * out_w + out_w4_idx) << 2;
    // FLOAT distance between consecutive input channel blocks.
    const int inp_add = out_h * out_w * 4;

    for (ushort in_channel_block_idx = 0; in_channel_block_idx < (ushort)IN_C_BLOCK; ++in_channel_block_idx) {
        // Out-of-row pixels are never stored; zero them instead of loading.
        FLOAT4 in0 = (remain > 0) ? vload4(0, input + inp_offset) : (FLOAT4)0;
        FLOAT4 in1 = (remain > 1) ? vload4(1, input + inp_offset) : (FLOAT4)0;
        FLOAT4 in2 = (remain > 2) ? vload4(2, input + inp_offset) : (FLOAT4)0;
        FLOAT4 in3 = (remain > 3) ? vload4(3, input + inp_offset) : (FLOAT4)0;

        FLOAT4 weights0 = vload4(offset, kernel_ptr);
        FLOAT4 weights1 = vload4(offset + 1, kernel_ptr);
        FLOAT4 weights2 = vload4(offset + 2, kernel_ptr);
        FLOAT4 weights3 = vload4(offset + 3, kernel_ptr);

        out0.x += dot(weights0, in0);
        out0.y += dot(weights1, in0);
        out0.z += dot(weights2, in0);
        out0.w += dot(weights3, in0);

        out1.x += dot(weights0, in1);
        out1.y += dot(weights1, in1);
        out1.z += dot(weights2, in1);
        out1.w += dot(weights3, in1);

        out2.x += dot(weights0, in2);
        out2.y += dot(weights1, in2);
        out2.z += dot(weights2, in2);
        out2.w += dot(weights3, in2);

        out3.x += dot(weights0, in3);
        out3.y += dot(weights1, in3);
        out3.z += dot(weights2, in3);
        out3.w += dot(weights3, in3);

        offset += 4;
        inp_offset += inp_add;
    }

#ifdef RELU
    out0 = fmax(out0, (FLOAT4)0);
    out1 = fmax(out1, (FLOAT4)0);
    out2 = fmax(out2, (FLOAT4)0);
    out3 = fmax(out3, (FLOAT4)0);
#endif

#ifdef RELU6
    out0 = clamp(out0, (FLOAT4)0, (FLOAT4)6);
    out1 = clamp(out1, (FLOAT4)0, (FLOAT4)6);
    out2 = clamp(out2, (FLOAT4)0, (FLOAT4)6);
    out3 = clamp(out3, (FLOAT4)0, (FLOAT4)6);
#endif

    __global FLOAT *dst = output + (((out_b_idx * out_c_block + out_c_idx) * out_h + out_h_idx) * out_w + out_w4_idx) * 4;

    // Store only the pixels that lie inside the output row.
    if (remain > 0) {
        vstore4(out0, 0, dst);
    }
    if (remain > 1) {
        vstore4(out1, 1, dst);
    }
    if (remain > 2) {
        vstore4(out2, 2, dst);
    }
    if (remain > 3) {
        vstore4(out3, 3, dst);
    }
}
507
508
509
/*
 * 1x1 convolution on NC4HW4 buffers. The input is indexed with the output's
 * spatial size (out_h, out_w), i.e. no stride or padding. Each work-item
 * computes up to 4 consecutive output pixels along w for two consecutive
 * blocks of 4 output channels (8 channels).
 *
 * Work-item mapping:
 *   dim0 = out_c_pair * out_w_blocks + w_quad   (first block = 2 * out_c_pair)
 *   dim1 = batch * out_h + h
 * Weight layout, in FLOAT4 units: [out_c_pair][in_c_block][8]. Vector k of
 * each group holds the 4 input-channel weights of output channel k of the pair.
 * Pixels past the end of the row are neither loaded nor stored, so the tail
 * work-item cannot read past the end of the input buffer. When out_c_block is
 * odd, the second block of the last pair does not exist: its bias is not
 * loaded and its results are not stored.
 * NOTE(review): the loop bound is the compile-time IN_C_BLOCK while strides use
 * in_c_block; presumably they are equal — confirm on the host side.
 */
__kernel
void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks,
                          __global const FLOAT *input,
                          __global const FLOAT *kernel_ptr,
                          __global const FLOAT *bias_ptr,
                          __global FLOAT *output,
                          __private const int in_c_block,
                          __private const int out_h,
                          __private const int out_w,
                          __private const int out_c_block) {

    const int out_c_w_idx = get_global_id(0); //c/8 w/4
    const int out_b_h_idx  = get_global_id(1); //b h

    DEAL_NON_UNIFORM_DIM2(out_c_w_idx, out_b_h_idx);

    const int out_c_idx = out_c_w_idx / out_w_blocks;
    const int out_w_idx = out_c_w_idx % out_w_blocks;
    const int out_b_idx = out_b_h_idx / out_h;//equal to in_b_idx
    const int out_h_idx = out_b_h_idx % out_h;//equal to in_h_idx

    const int out_w4_idx = mul24(out_w_idx, 4);
    // Valid pixels from out_w4_idx to the end of the row.
    const int remain = out_w - out_w4_idx;
    // Whether the second 4-channel block of this pair exists.
    const bool has_second = (out_c_idx << 1) + 1 < out_c_block;

    FLOAT4 out0 = vload4(out_c_idx << 1, bias_ptr);
    FLOAT4 out1 = out0;
    FLOAT4 out2 = out0;
    FLOAT4 out3 = out0;

    FLOAT4 out4 = has_second ? vload4((out_c_idx << 1) + 1, bias_ptr) : (FLOAT4)0;
    FLOAT4 out5 = out4;
    FLOAT4 out6 = out4;
    FLOAT4 out7 = out4;

    for (int in_channel_block_idx = 0; in_channel_block_idx < IN_C_BLOCK; ++in_channel_block_idx) {
        // Weight offset in FLOAT4 units (8 vectors per input channel block).
        const int offset = mad24(out_c_idx, in_c_block, in_channel_block_idx) * 8;
        const int inp_offset =
            (((out_b_idx * in_c_block + in_channel_block_idx) * out_h + out_h_idx) * out_w + out_w4_idx) * 4;

        // Out-of-row pixels are never stored; zero them instead of loading.
        FLOAT4 in0 = (remain > 0) ? vload4(0, input + inp_offset) : (FLOAT4)0;
        FLOAT4 in1 = (remain > 1) ? vload4(1, input + inp_offset) : (FLOAT4)0;
        FLOAT4 in2 = (remain > 2) ? vload4(2, input + inp_offset) : (FLOAT4)0;
        FLOAT4 in3 = (remain > 3) ? vload4(3, input + inp_offset) : (FLOAT4)0;

        FLOAT4 weights0 = vload4(offset, kernel_ptr);
        FLOAT4 weights1 = vload4(offset + 1, kernel_ptr);
        FLOAT4 weights2 = vload4(offset + 2, kernel_ptr);
        FLOAT4 weights3 = vload4(offset + 3, kernel_ptr);
        FLOAT4 weights4 = vload4(offset + 4, kernel_ptr);
        FLOAT4 weights5 = vload4(offset + 5, kernel_ptr);
        FLOAT4 weights6 = vload4(offset + 6, kernel_ptr);
        FLOAT4 weights7 = vload4(offset + 7, kernel_ptr);

        out0.x += dot(weights0, in0);
        out0.y += dot(weights1, in0);
        out0.z += dot(weights2, in0);
        out0.w += dot(weights3, in0);

        out1.x += dot(weights0, in1);
        out1.y += dot(weights1, in1);
        out1.z += dot(weights2, in1);
        out1.w += dot(weights3, in1);

        out2.x += dot(weights0, in2);
        out2.y += dot(weights1, in2);
        out2.z += dot(weights2, in2);
        out2.w += dot(weights3, in2);

        out3.x += dot(weights0, in3);
        out3.y += dot(weights1, in3);
        out3.z += dot(weights2, in3);
        out3.w += dot(weights3, in3);

        out4.x += dot(weights4, in0);
        out4.y += dot(weights5, in0);
        out4.z += dot(weights6, in0);
        out4.w += dot(weights7, in0);

        out5.x += dot(weights4, in1);
        out5.y += dot(weights5, in1);
        out5.z += dot(weights6, in1);
        out5.w += dot(weights7, in1);

        out6.x += dot(weights4, in2);
        out6.y += dot(weights5, in2);
        out6.z += dot(weights6, in2);
        out6.w += dot(weights7, in2);

        out7.x += dot(weights4, in3);
        out7.y += dot(weights5, in3);
        out7.z += dot(weights6, in3);
        out7.w += dot(weights7, in3);
    }

#ifdef RELU
    out0 = fmax(out0, (FLOAT4)0);
    out1 = fmax(out1, (FLOAT4)0);
    out2 = fmax(out2, (FLOAT4)0);
    out3 = fmax(out3, (FLOAT4)0);

    out4 = fmax(out4, (FLOAT4)0);
    out5 = fmax(out5, (FLOAT4)0);
    out6 = fmax(out6, (FLOAT4)0);
    out7 = fmax(out7, (FLOAT4)0);
#endif

#ifdef RELU6
    out0 = clamp(out0, (FLOAT4)0, (FLOAT4)6);
    out1 = clamp(out1, (FLOAT4)0, (FLOAT4)6);
    out2 = clamp(out2, (FLOAT4)0, (FLOAT4)6);
    out3 = clamp(out3, (FLOAT4)0, (FLOAT4)6);

    out4 = clamp(out4, (FLOAT4)0, (FLOAT4)6);
    out5 = clamp(out5, (FLOAT4)0, (FLOAT4)6);
    out6 = clamp(out6, (FLOAT4)0, (FLOAT4)6);
    out7 = clamp(out7, (FLOAT4)0, (FLOAT4)6);
#endif

    const int out_offset = (((out_b_idx * out_c_block + out_c_idx * 2) * out_h + out_h_idx) * out_w + out_w4_idx) * 4;
    __global FLOAT *dst0 = output + out_offset;

    // First block: store only the pixels inside the output row.
    if (remain > 0) {
        vstore4(out0, 0, dst0);
    }
    if (remain > 1) {
        vstore4(out1, 1, dst0);
    }
    if (remain > 2) {
        vstore4(out2, 2, dst0);
    }
    if (remain > 3) {
        vstore4(out3, 3, dst0);
    }

    if (!has_second) {
        return;
    }

    // Second block lives one channel-block plane (out_h * out_w pixels) further on.
    __global FLOAT *dst1 = dst0 + 4 * out_h * out_w;
    if (remain > 0) {
        vstore4(out4, 0, dst1);
    }
    if (remain > 1) {
        vstore4(out5, 1, dst1);
    }
    if (remain > 2) {
        vstore4(out6, 2, dst1);
    }
    if (remain > 3) {
        vstore4(out7, 3, dst1);
    }
}
672
673
/*
 * 1x1 convolution over buffers: each work-item produces 8 output channels
 * (two 4-channel blocks, 2*out_c_idx and 2*out_c_idx+1) at 2 adjacent output
 * columns of a single (batch, row).
 *
 * Input and output are indexed as (((b * C4 + c4) * H + h) * W + w) * 4
 * (4 channels innermost). The input is addressed with the output's H and W,
 * so this kernel assumes stride 1 and no padding.
 *
 * Work decomposition:
 *   global id 0 -> out_c_idx = id / out_w_blocks, 2-column tile = id % out_w_blocks
 *   global id 1 -> out_b_idx = id / out_h,        out_h_idx     = id % out_h
 *
 * Weights: 8 FLOAT4 per (out_c_idx, input channel block). Vectors 0..3 produce
 * lanes x..w of the lower output block, vectors 4..7 those of the upper block;
 * each vector is dotted with the 4 input channels of the current input block.
 *
 * out_w_blocks: number of 2-column tiles per row (presumably ceil(out_w / 2);
 *               NOTE(review): confirm with the host code).
 * out_c_block:  number of 4-channel output blocks; when it is odd, the upper
 *               block of the last 8-channel tile is neither read nor stored.
 * Build options:
 *   IN_C_BLOCK (optional) compile-time input block count; must equal in_c_block.
 *   RELU / RELU6 apply the matching activation before the store.
 */
__kernel
void conv_2d_1x1_c8h1w2(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks,
                          __global const FLOAT *input,
                          __global const FLOAT *kernel_ptr,
                          __global const FLOAT *bias_ptr,
                          __global FLOAT *output,
                          __private const int in_c_block,
                          __private const int out_h,
                          __private const int out_w,
                          __private const int out_c_block) {

    const int out_c_w_idx = get_global_id(0); // (c/8, w/2)
    const int out_b_h_idx = get_global_id(1); // (b, h)

    DEAL_NON_UNIFORM_DIM2(out_c_w_idx, out_b_h_idx);

    const int out_c_idx = out_c_w_idx / out_w_blocks;
    const int out_w_idx = out_c_w_idx % out_w_blocks;
    const int out_b_idx = out_b_h_idx / out_h; // equal to in_b_idx
    const int out_h_idx = out_b_h_idx % out_h; // equal to in_h_idx

    const int out_w2_idx = mul24(out_w_idx, 2);
    // Number of output columns left in this row starting at out_w2_idx.
    const int remain = out_w - out_w2_idx;
    const bool has_second_col = remain >= 2;
    // The upper 4-channel block does not exist when out_c_block is odd.
    const bool has_upper_block = (out_c_idx << 1) + 1 < out_c_block;

#ifdef IN_C_BLOCK
    // A compile-time trip count lets the compiler unroll the loop.
    const int in_c_loop = IN_C_BLOCK;
#else
    const int in_c_loop = in_c_block;
#endif

    FLOAT4 out0 = vload4(out_c_idx << 1, bias_ptr);
    FLOAT4 out1 = out0;

    // Only read the upper bias when that block exists; otherwise the read may
    // fall past the bias buffer (NOTE(review): depends on host padding) and the
    // results are never stored anyway.
    FLOAT4 out4 = (FLOAT4)0;
    if (has_upper_block) {
        out4 = vload4((out_c_idx << 1) + 1, bias_ptr);
    }
    FLOAT4 out5 = out4;

    for (int in_channel_block_idx = 0; in_channel_block_idx < in_c_loop; ++in_channel_block_idx) {
        const int offset = mad24(out_c_idx, in_c_block, in_channel_block_idx) * 8;
        const int inp_offset =
            (((out_b_idx * in_c_block + in_channel_block_idx) * out_h + out_h_idx) * out_w + out_w2_idx) * 4;

        FLOAT4 in0 = vload4(0, input + inp_offset);
        // On an odd-width tail, column out_w2_idx + 1 does not exist: reading it
        // would step into the next row and, for the last row, past the buffer.
        // out1/out5 are not stored in that case, so zero is harmless.
        FLOAT4 in1 = (FLOAT4)0;
        if (has_second_col) {
            in1 = vload4(1, input + inp_offset);
        }

        FLOAT4 weights0 = vload4(offset, kernel_ptr);
        FLOAT4 weights1 = vload4(offset + 1, kernel_ptr);
        FLOAT4 weights2 = vload4(offset + 2, kernel_ptr);
        FLOAT4 weights3 = vload4(offset + 3, kernel_ptr);
        FLOAT4 weights4 = vload4(offset + 4, kernel_ptr);
        FLOAT4 weights5 = vload4(offset + 5, kernel_ptr);
        FLOAT4 weights6 = vload4(offset + 6, kernel_ptr);
        FLOAT4 weights7 = vload4(offset + 7, kernel_ptr);

        // Lower output block (channels 0..3 of this 8-channel tile).
        out0.x += dot(weights0, in0);
        out0.y += dot(weights1, in0);
        out0.z += dot(weights2, in0);
        out0.w += dot(weights3, in0);

        out1.x += dot(weights0, in1);
        out1.y += dot(weights1, in1);
        out1.z += dot(weights2, in1);
        out1.w += dot(weights3, in1);

        // Upper output block (channels 4..7 of this 8-channel tile).
        out4.x += dot(weights4, in0);
        out4.y += dot(weights5, in0);
        out4.z += dot(weights6, in0);
        out4.w += dot(weights7, in0);

        out5.x += dot(weights4, in1);
        out5.y += dot(weights5, in1);
        out5.z += dot(weights6, in1);
        out5.w += dot(weights7, in1);
    }

#ifdef RELU
    out0 = fmax(out0, (FLOAT4)0);
    out1 = fmax(out1, (FLOAT4)0);

    out4 = fmax(out4, (FLOAT4)0);
    out5 = fmax(out5, (FLOAT4)0);
#endif

#ifdef RELU6
    out0 = clamp(out0, (FLOAT4)0, (FLOAT4)6);
    out1 = clamp(out1, (FLOAT4)0, (FLOAT4)6);

    out4 = clamp(out4, (FLOAT4)0, (FLOAT4)6);
    out5 = clamp(out5, (FLOAT4)0, (FLOAT4)6);
#endif

    const int out_offset = (((out_b_idx * out_c_block + out_c_idx * 2) * out_h + out_h_idx) * out_w + out_w2_idx) * 4;

    __global FLOAT *_tempoutput = output + out_offset;

    if (remain >= 2) {
        vstore4(out0, 0, _tempoutput);
        vstore4(out1, 1, _tempoutput);
    } else if (remain == 1) {
        vstore4(out0, 0, _tempoutput);
    }

    if (!has_upper_block) {
        return;
    }

    // The upper block is the next 4-channel plane: one H*W*4 stride further.
    __global FLOAT *_tempoutput1 = _tempoutput + 4 * out_h * out_w;
    if (remain >= 2) {
        vstore4(out4, 0, _tempoutput1);
        vstore4(out5, 1, _tempoutput1);
    } else if (remain == 1) {
        vstore4(out4, 0, _tempoutput1);
    }
}
783
/*
 * 1x1 convolution over buffers: each work-item produces a single 4-channel
 * output block at one (batch, row, column) position.
 *
 * Input and output are indexed as (((b * C4 + c4) * H + h) * W + w) * 4
 * (4 channels innermost). The input is addressed with the output's H and W,
 * so this kernel assumes stride 1 and no padding.
 *
 * Work decomposition:
 *   global id 0 -> output channel block = id / out_w, column = id % out_w
 *   global id 1 -> batch = id / out_h,                row    = id % out_h
 *
 * Weights: 16 FLOATs (4 FLOAT4) per (output block, input block); vector k
 * holds the weights of output lane k for the 4 input channels of that block.
 *
 * out_w_blocks is part of the shared signature but is not read by this kernel.
 * Build options RELU / RELU6 apply the matching activation before the store.
 */
__kernel
void conv_2d_1x1_c4h1w1(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks,
                          __global const FLOAT *input,
                          __global const FLOAT *kernel_ptr,
                          __global const FLOAT *bias_ptr,
                          __global FLOAT *output,
                          __private const int in_c_block,
                          __private const int out_h,
                          __private const int out_w,
                          __private const int out_c_block) {

    const int cw = get_global_id(0); // (c/4, w)
    const int bh = get_global_id(1); // (b, h)

    DEAL_NON_UNIFORM_DIM2(cw, bh);

    const int oc = cw / out_w;
    const int ow = cw % out_w;
    const int ob = bh / out_h; // same batch index on the input side
    const int oh = bh % out_h; // same row index on the input side

    // FLOATs between consecutive 4-channel input planes.
    const int plane = out_h * out_w * 4;
    // Input pixel (ob, block 0, oh, ow); block k lies k * plane further on.
    __global const FLOAT *src = input + ((ob * in_c_block * out_h + oh) * out_w + ow) * 4;
    // First 16-FLOAT weight group of this output block.
    __global const FLOAT *wgt = kernel_ptr + oc * in_c_block * 16;

    FLOAT4 acc = vload4(oc, bias_ptr);

    for (int ic = 0; ic < in_c_block; ++ic) {
        const FLOAT4 x = vload4(0, src + ic * plane);

        __global const FLOAT *w = wgt + ic * 16;
        const FLOAT4 w0 = vload4(0, w);
        const FLOAT4 w1 = vload4(1, w);
        const FLOAT4 w2 = vload4(2, w);
        const FLOAT4 w3 = vload4(3, w);

        acc.x += dot(w0, x);
        acc.y += dot(w1, x);
        acc.z += dot(w2, x);
        acc.w += dot(w3, x);
    }

#ifdef RELU
    acc = fmax(acc, (FLOAT4)0);
#endif

#ifdef RELU6
    acc = clamp(acc, (FLOAT4)0, (FLOAT4)6);
#endif

    __global FLOAT *dst = output + (((ob * out_c_block + oc) * out_h + oh) * out_w + ow) * 4;
    vstore4(acc, 0, dst);
}
840
841
/*
 * 1x1 convolution over buffers: each work-item produces one 4-channel output
 * block at 2 adjacent output columns of a single (batch, row).
 *
 * Input and output are indexed as (((b * C4 + c4) * H + h) * W + w) * 4
 * (4 channels innermost). The input is addressed with the output's H and W,
 * so this kernel assumes stride 1 and no padding.
 *
 * Work decomposition:
 *   global id 0 -> out_c_idx = id / out_w_blocks, 2-column tile = id % out_w_blocks
 *   global id 1 -> out_b_idx = id / out_h,        out_h_idx     = id % out_h
 *
 * Weights: 4 FLOAT4 per (out_c_idx, input channel block); vector k holds the
 * weights of output lane k for the 4 input channels of that block.
 *
 * out_w_blocks: number of 2-column tiles per row (presumably ceil(out_w / 2);
 *               NOTE(review): confirm with the host code).
 * Build options RELU / RELU6 apply the matching activation before the store.
 */
__kernel
void conv_2d_1x1_c4h1w2(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks,
                          __global const FLOAT *input,
                          __global const FLOAT *kernel_ptr,
                          __global const FLOAT *bias_ptr,
                          __global FLOAT *output,
                          __private const int in_c_block,
                          __private const int out_h,
                          __private const int out_w,
                          __private const int out_c_block) {

    const int out_c_w_idx = get_global_id(0); // (c/4, w/2)
    const int out_b_h_idx = get_global_id(1); // (b, h)

    DEAL_NON_UNIFORM_DIM2(out_c_w_idx, out_b_h_idx);

    const int out_c_idx = out_c_w_idx / out_w_blocks;
    const int out_w_idx = out_c_w_idx % out_w_blocks;
    const int out_b_idx = out_b_h_idx / out_h; // equal to in_b_idx
    const int out_h_idx = out_b_h_idx % out_h; // equal to in_h_idx

    const int out_w2_idx = mul24(out_w_idx, 2);
    // Number of output columns left in this row starting at out_w2_idx.
    const int remain = out_w - out_w2_idx;
    const bool has_second_col = remain >= 2;

    FLOAT4 out0 = vload4(out_c_idx, bias_ptr);
    FLOAT4 out1 = out0;

    for (int in_channel_block_idx = 0; in_channel_block_idx < in_c_block; ++in_channel_block_idx) {
        const int offset = mad24(out_c_idx, in_c_block, in_channel_block_idx) * 4;
        const int inp_offset =
            (((out_b_idx * in_c_block + in_channel_block_idx) * out_h + out_h_idx) * out_w + out_w2_idx) * 4;

        FLOAT4 in0 = vload4(0, input + inp_offset);
        // On an odd-width tail, column out_w2_idx + 1 does not exist: reading it
        // would step into the next row and, for the last row, past the buffer.
        // out1 is not stored in that case, so zero is harmless.
        FLOAT4 in1 = (FLOAT4)0;
        if (has_second_col) {
            in1 = vload4(1, input + inp_offset);
        }

        FLOAT4 weights0 = vload4(offset, kernel_ptr);
        FLOAT4 weights1 = vload4(offset + 1, kernel_ptr);
        FLOAT4 weights2 = vload4(offset + 2, kernel_ptr);
        FLOAT4 weights3 = vload4(offset + 3, kernel_ptr);

        out0.x += dot(weights0, in0);
        out0.y += dot(weights1, in0);
        out0.z += dot(weights2, in0);
        out0.w += dot(weights3, in0);

        out1.x += dot(weights0, in1);
        out1.y += dot(weights1, in1);
        out1.z += dot(weights2, in1);
        out1.w += dot(weights3, in1);
    }

#ifdef RELU
    out0 = fmax(out0, (FLOAT4)0);
    out1 = fmax(out1, (FLOAT4)0);
#endif

#ifdef RELU6
    out0 = clamp(out0, (FLOAT4)0, (FLOAT4)6);
    out1 = clamp(out1, (FLOAT4)0, (FLOAT4)6);
#endif

    const int out_offset = (((out_b_idx * out_c_block + out_c_idx) * out_h + out_h_idx) * out_w + out_w2_idx) * 4;

    if (remain >= 2) {
        vstore4(out0, 0, output + out_offset);
        vstore4(out1, 1, output + out_offset);
    } else if (remain == 1) {
        vstore4(out0, 0, output + out_offset);
    }
}
918