1#ifdef MNN_SUPPORT_FP16
2#pragma OPENCL EXTENSION cl_khr_fp16 : enable
3#endif
4
// Expands to the first two kernel parameters: the logical 2-D work range.
// Note the trailing comma: the next kernel parameter follows directly.
#define GLOBAL_SIZE_2_DIMS __private const int global_size_dim0, __private const int global_size_dim1,
// Early-exit for work-items outside the logical range declared by
// GLOBAL_SIZE_2_DIMS (presumably the launched global size may be rounded up
// beyond it by the host -- hence "non-uniform"). Wrapped in do/while(0) so it
// behaves as a single statement (safe inside an unbraced if/else), and the
// arguments are parenthesized so expression arguments are evaluated as a unit.
#define DEAL_NON_UNIFORM_DIM2(input1, input2)                                   \
    do {                                                                        \
        if ((input1) >= global_size_dim0 || (input2) >= global_size_dim1) {     \
            return;                                                             \
        }                                                                       \
    } while (0)
10
// Convert data from buffer(NHWC) to buffer(NC4HW4).
//
// Work-item (x, y) = (channel_block * width + w, batch * height + h); each one
// produces a single 4-channel output vector. Lanes beyond `channels` in the
// last channel block are written as 0.
//
// When fewer than 4 channels remain, only the existing channels are read: an
// unconditional vload4 there would read past the end of the input buffer for
// the final pixel whenever `channels` is not a multiple of 4 (unless the host
// pads the buffer).
__kernel void nhwc_buffer_to_nc4hw4_buffer(GLOBAL_SIZE_2_DIMS
                                   #ifdef BUFFER_FORMAT_INP_TRANS
                                   __global const float *input_ptr,
                                   #else
                                   __global const FLOAT *input_ptr,
                                   #endif
                                   __private const int height,
                                   __private const int width, __private const int channels,
                                   __global FLOAT *output) {
    int image_width_idx  = get_global_id(0);
    int image_height_idx = get_global_id(1);

    DEAL_NON_UNIFORM_DIM2(image_width_idx, image_height_idx);

    const int batch_idx     = image_height_idx / height;
    const int height_idx    = image_height_idx % height;
    const int width_idx     = image_width_idx % width;
    const int channel_4_idx = (image_width_idx / width) << 2;
    const int buffer_offset = ((batch_idx * height + height_idx) * width + width_idx) * channels + channel_4_idx;

    const int remain_channel = channels - channel_4_idx;
    FLOAT4 values            = 0;

    if (remain_channel >= 4) {
        // Four contiguous channels are available: one vector load.
        #ifdef BUFFER_FORMAT_INP_TRANS
        values = CONVERT_FLOAT4(vload4(0, input_ptr + buffer_offset));
        #else
        values = vload4(0, input_ptr + buffer_offset);
        #endif
    } else {
        // Tail block: read only the channels that exist; other lanes stay 0.
        if (remain_channel > 0) {
            values.x = (FLOAT)input_ptr[buffer_offset];
        }
        if (remain_channel > 1) {
            values.y = (FLOAT)input_ptr[buffer_offset + 1];
        }
        if (remain_channel > 2) {
            values.z = (FLOAT)input_ptr[buffer_offset + 2];
        }
    }

    const int out_offset = (((batch_idx * ((channels + 3) / 4) + channel_4_idx / 4) * height + height_idx) * width + width_idx) * 4;
    vstore4(values, 0, output + out_offset);
}
56
// Convert data from buffer(NCHW) to buffer(NC4HW4).
//
// Work-item (x, y) = (channel_block * width + w, batch * height + h). Each
// work-item gathers up to four channels of one pixel -- consecutive channels
// are height*width elements apart in NCHW -- and stores them as one 4-vector.
// Lanes beyond `channels` are written as 0.
__kernel void nchw_buffer_to_nc4hw4_buffer(GLOBAL_SIZE_2_DIMS
                                   #ifdef BUFFER_FORMAT_INP_TRANS
                                   __global const float *input_ptr,
                                   #else
                                   __global const FLOAT *input_ptr,
                                   #endif
                                   __private const int height, __private const int width, __private const int channels,
                                   __global FLOAT *output) {
    const int gx = get_global_id(0);
    const int gy = get_global_id(1);

    DEAL_NON_UNIFORM_DIM2(gx, gy);

    const int n       = gy / height;
    const int h       = gy % height;
    const int w       = gx % width;
    const int c_start = (gx / width) * 4;

    const int plane   = height * width;
    const int src_idx = ((n * channels + c_start) * height + h) * width + w;
    const int c_left  = channels - c_start;

    FLOAT4 packed = 0;
    if (c_left > 0) {
        packed.x = (FLOAT)input_ptr[src_idx];
    }
    if (c_left > 1) {
        packed.y = (FLOAT)input_ptr[src_idx + plane];
    }
    if (c_left > 2) {
        packed.z = (FLOAT)input_ptr[src_idx + 2 * plane];
    }
    if (c_left > 3) {
        packed.w = (FLOAT)input_ptr[src_idx + 3 * plane];
    }

    const int c_blocks = (channels + 3) / 4;
    const int dst_idx  = (((n * c_blocks + c_start / 4) * height + h) * width + w) * 4;
    vstore4(packed, 0, output + dst_idx);
}
110
// Convert data from buffer(NC4HW4) to buffer(NHWC).
//
// Work-item (x, y) = (channel_block * width + w, batch * height + h). Each
// work-item loads one 4-channel vector and writes only the channels that
// exist (`channels` may not be a multiple of 4). With BUFFER_FORMAT_OUT_TRANS
// the values are converted to float and the output buffer is float.
__kernel void nc4hw4_buffer_to_nhwc_buffer(GLOBAL_SIZE_2_DIMS
                                    #ifdef BUFFER_FORMAT_OUT_TRANS
                                    __global float *output,
                                    #else
                                    __global FLOAT *output,
                                    #endif
                                    __private const int height, __private const int width,
                                    __private const int channels,
                                    __global FLOAT *input_ptr) {
    const int gx = get_global_id(0);
    const int gy = get_global_id(1);

    DEAL_NON_UNIFORM_DIM2(gx, gy);

    const int n       = gy / height;
    const int h       = gy % height;
    const int w       = gx % width;
    const int c_block = gx / width;
    const int c_start = c_block * 4;

    const int src_idx = (((n * ((channels + 3) / 4) + c_block) * height + h) * width + w) * 4;
    const int dst_idx = ((n * height + h) * width + w) * channels + c_start;

    #ifdef BUFFER_FORMAT_OUT_TRANS
    const float4 v = convert_float4(vload4(0, input_ptr + src_idx));
    #else
    const FLOAT4 v = vload4(0, input_ptr + src_idx);
    #endif

    const int c_left = channels - c_start;
    if (c_left >= 4) {
        // A whole channel block fits: one contiguous vector store.
        vstore4(v, 0, output + dst_idx);
        return;
    }

    // Tail block: store only the valid channels, contiguous in NHWC.
    if (c_left > 0) {
        output[dst_idx] = v.x;
    }
    if (c_left > 1) {
        output[dst_idx + 1] = v.y;
    }
    if (c_left > 2) {
        output[dst_idx + 2] = v.z;
    }
}
159
// Convert data from buffer(NC4HW4) to buffer(NCHW).
//
// Work-item (x, y) = (channel_block * width + w, batch * height + h). Each
// work-item loads one 4-channel vector and scatters the valid lanes into the
// NCHW output, where consecutive channels are height*width elements apart.
// With BUFFER_FORMAT_OUT_TRANS the values are converted to float.
__kernel void nc4hw4_buffer_to_nchw_buffer(GLOBAL_SIZE_2_DIMS
                                    #ifdef BUFFER_FORMAT_OUT_TRANS
                                    __global float *output,
                                    #else
                                    __global FLOAT *output,
                                    #endif
                                    __private const int height, __private const int width,
                                    __private const int channels,
                                    __global FLOAT *input_ptr) {
    const int gx = get_global_id(0);
    const int gy = get_global_id(1);

    DEAL_NON_UNIFORM_DIM2(gx, gy);

    const int n       = gy / height;
    const int h       = gy % height;
    const int w       = gx % width;
    const int c_block = gx / width;
    const int c_start = c_block * 4;

    const int src_idx = (((n * ((channels + 3) / 4) + c_block) * height + h) * width + w) * 4;

    #ifdef BUFFER_FORMAT_OUT_TRANS
    const float4 v = convert_float4(vload4(0, input_ptr + src_idx));
    #else
    const FLOAT4 v = vload4(0, input_ptr + src_idx);
    #endif

    const int plane   = height * width;
    const int dst_idx = ((n * channels + c_start) * height + h) * width + w;
    const int c_left  = channels - c_start;

    if (c_left > 0) {
        output[dst_idx] = v.x;
    }
    if (c_left > 1) {
        output[dst_idx + plane] = v.y;
    }
    if (c_left > 2) {
        output[dst_idx + 2 * plane] = v.z;
    }
    if (c_left > 3) {
        output[dst_idx + 3 * plane] = v.w;
    }
}
218
// Copy buffer(NC4HW4) to buffer(NC4HW4), optionally changing precision.
//
// output_shape = (height, width); channel_4 = number of 4-channel blocks.
// Work-item (x, y) = (channel_block * width + w, batch * height + h).
// BUFFER_FORMAT_INP_TRANS: the source is float and is converted to FLOAT.
// BUFFER_FORMAT_OUT_TRANS: the destination is float.
__kernel void nc4hw4_buffer_to_nc4hw4_buffer(GLOBAL_SIZE_2_DIMS
                                    #ifdef BUFFER_FORMAT_INP_TRANS
                                    __global const float *input_ptr,
                                    #else
                                    __global const FLOAT *input_ptr,
                                    #endif
                                    __private const int2 output_shape,
                                    __private const int channel_4,
                                    #ifdef BUFFER_FORMAT_OUT_TRANS
                                    __global float *output
                                    #else
                                    __global FLOAT *output
                                    #endif
) {
    const int gx = get_global_id(0);
    const int gy = get_global_id(1);

    DEAL_NON_UNIFORM_DIM2(gx, gy);

    const int out_h   = output_shape.x;
    const int out_w   = output_shape.y;
    const int n       = gy / out_h;
    const int h       = gy % out_h;
    const int w       = gx % out_w;
    const int c_block = gx / out_w;

    // Same element index in source and destination: only precision changes.
    const int idx = (((n * channel_4 + c_block) * out_h + h) * out_w + w) * 4;

    #ifdef BUFFER_FORMAT_INP_TRANS
    const FLOAT4 v = CONVERT_FLOAT4(vload4(0, input_ptr + idx));
    #else
    const FLOAT4 v = vload4(0, input_ptr + idx);
    #endif

    #ifdef BUFFER_FORMAT_OUT_TRANS
    vstore4(convert_float4(v), 0, output + idx);
    #else
    vstore4(v, 0, output + idx);
    #endif
}
257
// Convert a conv2d weight from buffer(OIHW) to the packed buffer layout
// [ic][oc/4][kh][kw][4] (4 output channels per vector).
//
// Work-item x = input channel; y = oc_block * (kh*kw) + kh_idx * kw + kw_idx.
// kernel_shape = (kh, kw). Presumably ic_h_w_size == ic*kh*kw (the stride
// between output channels in OIHW) and height_width_size == kh*kw -- confirm
// against the host side. Output-channel lanes beyond `output_channel` are 0.
__kernel void conv2d_filter_buffer_to_nc4hw4_buffer(GLOBAL_SIZE_2_DIMS
                                            #ifdef BUFFER_FORMAT_INP_TRANS
                                            __global const float *input_ptr,
                                            #else
                                            __global const FLOAT *input_ptr,
                                            #endif
                                            __private const int output_channel,
                                            __private const int2 kernel_shape,
                                            __private const int ic_h_w_size,
                                            __private const int height_width_size,
                                            __global FLOAT *output) {
    int image_width_idx  = get_global_id(0); // input channel
    int image_height_idx = get_global_id(1); // oc_block * kh*kw + spatial index

    DEAL_NON_UNIFORM_DIM2(image_width_idx, image_height_idx);

    const int input_channel_idx    = image_width_idx;
    const int output_channel_4_idx = (image_height_idx / height_width_size) * 4;
    const int height_width_idx     = image_height_idx % height_width_size;
    const int buffer_height_idx    = height_width_idx / kernel_shape.y;
    const int buffer_width_idx     = height_width_idx % kernel_shape.y;

    const int buffer_offset = output_channel_4_idx * ic_h_w_size + input_channel_idx * height_width_size +
                              buffer_height_idx * kernel_shape.y + buffer_width_idx;

    // Consecutive output channels are ic_h_w_size elements apart in OIHW.
    // Ordinary int arithmetic is used (not mad24, whose result is only
    // defined when its multiplicands fit in 24 bits).
    FLOAT4 output_values     = 0;
    const int remain_channel = output_channel - output_channel_4_idx;
    if (remain_channel > 0) {
        output_values.x = (FLOAT)input_ptr[buffer_offset];
    }
    if (remain_channel > 1) {
        output_values.y = (FLOAT)input_ptr[buffer_offset + ic_h_w_size];
    }
    if (remain_channel > 2) {
        output_values.z = (FLOAT)input_ptr[buffer_offset + 2 * ic_h_w_size];
    }
    if (remain_channel > 3) {
        output_values.w = (FLOAT)input_ptr[buffer_offset + 3 * ic_h_w_size];
    }

    // Output index: [ic][oc_block * kh*kw + spatial][4].
    const int out_offset = (image_width_idx * height_width_size * ((output_channel + 3) / 4) + image_height_idx) * 4;
    vstore4(output_values, 0, output + out_offset);
}
317
// Convert a depthwise conv weight from buffer [m, Cout, fh, fw] to the packed
// buffer layout [fh*fw][Cout/4][4].
//
// Only the channel multiplier m == 1 (kernel_shape.x) is supported; for any
// other m the output vectors are written as 0.
// Work-item x = kh_idx * fw + kw_idx; y = output-channel block.
// Presumably height_width_size == fh*fw (stride between channels) -- confirm
// against the host side. Channel lanes beyond Cout are 0.
__kernel void dw_filter_buffer_to_nc4hw4_buffer(GLOBAL_SIZE_2_DIMS
                                        #ifdef BUFFER_FORMAT_INP_TRANS
                                        __global const float *input_ptr,
                                        #else
                                        __global const FLOAT *input_ptr,
                                        #endif
                                        __private const int4 kernel_shape,//[1, Cout, fh, fw]
                                        __private const int height_width_size,
                                        __global FLOAT *output) {
    const int spatial_idx = get_global_id(0); // kh_idx * fw + kw_idx
    const int c_block     = get_global_id(1); // output-channel block

    DEAL_NON_UNIFORM_DIM2(spatial_idx, c_block);

    const int cout = kernel_shape.y;
    FLOAT4 packed  = 0;

    if (kernel_shape.x == 1) {
        const int c_start = c_block * 4;
        const int kh_idx  = spatial_idx / kernel_shape.w;
        const int kw_idx  = spatial_idx % kernel_shape.w;
        // Flat index of [c_start, kh_idx, kw_idx] in the [Cout, fh, fw] input.
        const int src     = mad24(mad24(c_start, kernel_shape.z, kh_idx), kernel_shape.w, kw_idx);
        const int c_left  = cout - c_start;

        if (c_left > 0) {
            packed.x = (FLOAT)input_ptr[src];
        }
        if (c_left > 1) {
            packed.y = (FLOAT)input_ptr[src + height_width_size];
        }
        if (c_left > 2) {
            packed.z = (FLOAT)input_ptr[src + 2 * height_width_size];
        }
        if (c_left > 3) {
            packed.w = (FLOAT)input_ptr[src + 3 * height_width_size];
        }
    }

    // Output index: [spatial_idx][c_block][4].
    const int dst = (spatial_idx * ((cout + 3) / 4) + c_block) * 4;
    vstore4(packed, 0, output + dst);
}
381