1#define GLOBAL_SIZE_3_DIMS \
2    __private const int global_size_dim0, __private const int global_size_dim1, __private const int global_size_dim2,
3#ifdef MNN_SUPPORT_FP16
4#pragma OPENCL EXTENSION cl_khr_fp16 : enable
5#endif
6#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3)                                             \
7    if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { \
8        return;                                                                                   \
9    }
10
11__constant sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
12
13__kernel void deconv_2d(GLOBAL_SIZE_3_DIMS
14                    #ifdef USE_BUFFER
15                        __global FLOAT* input,
16                        __global FLOAT* weights,
17                        #ifdef BIAS
18                        __global FLOAT* bias,
19                        #endif
20                        __global FLOAT* output,
21                    #else
22                        __read_only image2d_t input,
23                        __read_only image2d_t weights,
24                        #ifdef BIAS
25                        __read_only image2d_t bias,
26                        #endif
27                        __write_only image2d_t output,
28                    #endif
29                        __private const int2 input_shape,
30                        __private const int2 output_shape,
31                        __private const int2 stride_shape,
32                        __private const int2 align_shape,
33                        __private const int2 padding_shape,
34                        __private const int2 kernel_shape,
35                        __private const int kernel_size,
36                        __private const int in_channel_blocks, __private const int out_channel_blocks) {
37
38    const int out_channel_blocks_idx = get_global_id(0);
39    const int out_w_idx          = get_global_id(1);
40    const int out_batch_height_idx   = get_global_id(2);
41
42    DEAL_NON_UNIFORM_DIM3(out_channel_blocks_idx, out_w_idx, out_batch_height_idx);
43
44#ifdef BIAS
45    #ifdef USE_BUFFER
46    FLOAT4 out0 = vload4(out_channel_blocks_idx, bias);
47    #else
48    FLOAT4 out0 = RI_F(bias, SAMPLER, (int2)(out_channel_blocks_idx, 0));
49    #endif
50#else
51    FLOAT4 out0 = (FLOAT4)0;
52#endif
53
54    const int out_b_idx  = out_batch_height_idx / output_shape.x;
55    const int out_h_idx = out_batch_height_idx % output_shape.x;
56
57    int kernel_start_x = max(0, (out_w_idx + align_shape.y) / stride_shape.y);
58    int kernel_start_y = max(0, (out_h_idx + align_shape.x) / stride_shape.x);
59    int deal_kernel_width  = kernel_shape.y - mad24(kernel_start_x, stride_shape.y, padding_shape.y) + out_w_idx - 1;
60    int deal_kernel_height = kernel_shape.x - mad24(kernel_start_y, stride_shape.x, padding_shape.x) + out_h_idx - 1;
61
62
63    int kernel_x_0, kernel_x_1, kernel_x_2, kernel_x_3, kernel_y;
64    FLOAT4 in0;
65    FLOAT4 weights0, weights1, weights2, weights3;
66    for (int ic = 0; ic < in_channel_blocks; ic++) {
67        kernel_x_0 = ic << 2;
68        kernel_x_1 = kernel_x_0 + 1;
69        kernel_x_2 = kernel_x_0 + 2;
70        kernel_x_3 = kernel_x_0 + 3;
71        for (int k_y = deal_kernel_height, idx_h = kernel_start_y; k_y >= 0; k_y -= stride_shape.x, idx_h++) {
72            #ifdef USE_BUFFER
73            int in_width0   = kernel_start_x;
74            for (int k_x = deal_kernel_width; k_x >= 0; k_x -= stride_shape.y) {
75                kernel_y = mad24(k_y, kernel_shape.y, k_x);
76                kernel_y = mad24(out_channel_blocks_idx, kernel_size, kernel_y);
77                //weights  NC4HW4  [1,  4*icC4,  ocC4*kh*kw,  1] xic4
78                //index:   [0, kernel_x_0, kernel_y, 0]
79                weights0 = vload4(kernel_x_0*(out_channel_blocks*kernel_shape.x*kernel_shape.y)+kernel_y, weights);
80                weights1 = vload4(kernel_x_1*(out_channel_blocks*kernel_shape.x*kernel_shape.y)+kernel_y, weights);
81                weights2 = vload4(kernel_x_2*(out_channel_blocks*kernel_shape.x*kernel_shape.y)+kernel_y, weights);
82                weights3 = vload4(kernel_x_3*(out_channel_blocks*kernel_shape.x*kernel_shape.y)+kernel_y, weights);
83
84                bool outBoundry = (idx_h < 0 || idx_h >= input_shape.x || kernel_start_x < 0 || in_width0 >= input_shape.y);
85                int inp_offset = (((out_b_idx * in_channel_blocks + ic) * input_shape.x + idx_h) * input_shape.y + in_width0) * 4;
86                in0 = outBoundry ? (FLOAT4)0 : vload4(0, input+inp_offset);
87
88                out0 = mad(in0.x, weights0, out0);
89                out0 = mad(in0.y, weights1, out0);
90                out0 = mad(in0.z, weights2, out0);
91                out0 = mad(in0.w, weights3, out0);
92                in_width0++;
93            }
94            #else
95            int in_idy      = mad24(out_b_idx, input_shape.x, idx_h);
96            int in_hb_value = select(in_idy, -1, idx_h < 0 || idx_h >= input_shape.x);
97            int in_width0   = kernel_start_x;
98            for (int k_x = deal_kernel_width; k_x >= 0; k_x -= stride_shape.y) {
99                kernel_y = mad24(k_y, kernel_shape.y, k_x);
100                kernel_y = mad24(out_channel_blocks_idx, kernel_size, kernel_y);
101                weights0 = RI_F(weights, SAMPLER, (int2)(kernel_x_0, kernel_y));
102                weights1 = RI_F(weights, SAMPLER, (int2)(kernel_x_1, kernel_y));
103                weights2 = RI_F(weights, SAMPLER, (int2)(kernel_x_2, kernel_y));
104                weights3 = RI_F(weights, SAMPLER, (int2)(kernel_x_3, kernel_y));
105
106                int in_idx = mul24(ic, input_shape.y);
107                int in_width_value0 = in_width0;                                                           \
108                in_width_value0 =                                                                                   \
109                    select(in_idx + in_width_value0, -1, (in_width_value0 < 0 || in_width_value0 >= input_shape.y)); \
110                in0 = RI_F(input, SAMPLER, (int2)(in_width_value0, in_hb_value));
111
112                out0 = mad(in0.x, weights0, out0);
113                out0 = mad(in0.y, weights1, out0);
114                out0 = mad(in0.z, weights2, out0);
115                out0 = mad(in0.w, weights3, out0);
116                in_width0++;
117            }
118            #endif
119        }
120    }
121#ifdef RELU
122    out0 = fmax(out0, (FLOAT4)0);
123#endif
124
125#ifdef RELU6
126    out0 = clamp(out0, (FLOAT4)0, (FLOAT4)6);
127#endif
128
129#ifdef USE_BUFFER
130    const int out_offset = (((out_b_idx*out_channel_blocks + out_channel_blocks_idx)*output_shape.x + out_h_idx)*output_shape.y + out_w_idx)*4;
131    vstore4(out0, 0, output+out_offset);
132#else
133    int out_image_width_idx = mad24(out_channel_blocks_idx, output_shape.y, out_w_idx);
134    WI_F(output, (int2)(out_image_width_idx, out_batch_height_idx), out0);
135#endif
136}
137
138__kernel void iohw2oihw(__global const float* input_ptr, __global float* output_ptr, int plane_number, int input_channel, int output_channel) {
139    const int ic_index = get_global_id(0), oc_index = get_global_id(1);
140    if (ic_index >= input_channel || oc_index >= output_channel) {
141        return;
142    }
143    const int input_offset = (ic_index * output_channel + oc_index) * plane_number;
144    const int output_offset = (oc_index * input_channel + ic_index) * plane_number;
145    for (int i = 0; i < plane_number; ++i) {
146        output_ptr[output_offset + i] = input_ptr[input_offset + i];
147    }
148}
149