/*
 * READ_INPUT_IMAGE(i, base): load one texel of `input` into `in<i>`.
 *
 * Names that must be in scope at the expansion site:
 *   in_width<i>   - column index before `base` is added
 *   in_idx        - x offset added to an in-range column
 *   in_hb_value   - y coordinate used for the read
 *   input_shape   - .y is the exclusive upper bound for the column
 *   input, SAMPLER, in<i>
 *
 * The macro declares `int in_width_value<i>` in the enclosing scope, so it
 * can be expanded only once per scope for a given `i`.
 *
 * A column outside [0, input_shape.y) is replaced by x = -1, which leaves the
 * result to the sampler's out-of-range addressing (presumably a zero border
 * texel with CLK_ADDRESS_CLAMP - confirm against the image format in use).
 *
 * `base` is parenthesized so that expression arguments bind as one operand.
 */
#define READ_INPUT_IMAGE(i, base)                                                    \
    int in_width_value##i = in_width##i + (base);                                    \
    in_width_value##i =                                                              \
        select(in_idx + in_width_value##i, -1,                                       \
               (in_width_value##i < 0 || in_width_value##i >= input_shape.y));       \
    in##i = read_imagef(input, SAMPLER, (int2)(in_width_value##i, in_hb_value));

/*
 * CALCULATE_OUTPUT(i): accumulate four multiply-adds into `out<i>`.
 *
 * Each component of `in<i>` (.x, .y, .z, .w) is multiplied by the matching
 * `weights0` .. `weights3` and added to `out<i>`, in that order. The names
 * `in<i>`, `out<i>` and `weights0..3` must be in scope at the expansion site.
 */
#define CALCULATE_OUTPUT(i)                  \
    out##i = mad(in##i.x, weights0, out##i); \
    out##i = mad(in##i.y, weights1, out##i); \
    out##i = mad(in##i.z, weights2, out##i); \
    out##i = mad(in##i.w, weights3, out##i);
/*
 * DEAL_NON_UNIFORM_DIM3(input1, input2, input3): return from the enclosing
 * function when any of the three indices is at or beyond the matching
 * `global_size_dim0/1/2` value. Those three names must be in scope (see
 * GLOBAL_SIZE_3_DIMS). Arguments are parenthesized so expressions are safe.
 */
#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3)                                                   \
    if ((input1) >= global_size_dim0 || (input2) >= global_size_dim1 || (input3) >= global_size_dim2) { \
        return;                                                                                         \
    }
/*
 * GLOBAL_SIZE_3_DIMS: leading kernel parameters carrying the three bounds
 * that DEAL_NON_UNIFORM_DIM3 compares against. The expansion ends with a
 * comma, so it must be followed directly by the next parameter.
 */
#define GLOBAL_SIZE_3_DIMS \
    __private const int global_size_dim0, __private const int global_size_dim1, __private const int global_size_dim2,

/*
 * Sampler shared by every image read in this file: unnormalized integer
 * coordinates, clamp addressing, nearest filtering. The kernels pass -1 as a
 * coordinate for out-of-range taps and rely on the clamp addressing mode to
 * supply the value for those reads.
 */
__constant sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;


/*
 * Depthwise transposed convolution over image-backed tensors.
 *
 * Work-item mapping:
 *   get_global_id(0) - output channel block
 *   get_global_id(1) - output column
 *   get_global_id(2) - combined batch/row index, split using output_shape.x
 * Work-items outside global_size_dim0/1/2 return immediately.
 *
 * Parameters:
 *   input         - source image; x = channel_block * input_shape.y + column,
 *                   y = batch * input_shape.x + row
 *   weights       - filter image; x = k_y * kernel_shape.y + k_x,
 *                   y = channel block
 *   bias          - (only without NO_BIAS) read at (channel_block, 0) as the
 *                   initial accumulator value
 *   output        - destination image; x = channel_block * output_shape.y +
 *                   column, y = combined batch/row index
 *   input_shape   - .x = input rows, .y = input columns
 *   output_shape  - .x = output rows, .y = output columns
 *   stride_shape  - .x = vertical stride, .y = horizontal stride
 *   align_shape   - offset added before dividing by the stride to find the
 *                   first contributing input row (.x) / column (.y)
 *   padding_shape - .x = vertical padding, .y = horizontal padding
 *   kernel_shape  - .x = filter rows, .y = filter columns
 *   kernel_size, out_channel_blocks - not read by this kernel; kept so the
 *                   argument list matches what the host sets
 *
 * Each float4 presumably packs four consecutive channels - confirm against
 * the host-side image layout.
 *
 * Compile-time switches: NO_BIAS, RELU, RELU6.
 */
__kernel void depthwise_deconv2d(GLOBAL_SIZE_3_DIMS __read_only image2d_t input,
                                 __read_only image2d_t weights,
                                 #ifndef NO_BIAS
                                 __read_only image2d_t bias,
                                 #endif
                                 __write_only image2d_t output,
                                 __private const int2 input_shape,
                                 __private const int2 output_shape,
                                 __private const int2 stride_shape,
                                 __private const int2 align_shape,
                                 __private const int2 padding_shape,
                                 __private const int2 kernel_shape,
                                 __private const int kernel_size, __private const int out_channel_blocks) {
    const int out_channel_blocks_idx = get_global_id(0);
    const int out_width_idx          = get_global_id(1);
    const int out_batch_height_idx   = get_global_id(2);

    DEAL_NON_UNIFORM_DIM3(out_channel_blocks_idx, out_width_idx, out_batch_height_idx);

    /* The accumulator starts from the bias, or from zero when NO_BIAS is set. */
#ifndef NO_BIAS
    float4 out0 = read_imagef(bias, SAMPLER, (int2)(out_channel_blocks_idx, 0));
#else
    float4 out0 = (float4)(0.0f);
#endif

    const int out_batch_idx  = out_batch_height_idx / output_shape.x;
    const int out_height_idx = out_batch_height_idx % output_shape.x;

    const int out_width_fill_idx  = out_width_idx - (stride_shape.y - 1);
    const int out_height_fill_idx = out_height_idx - (stride_shape.x - 1);

    /* First input column / row visited for this output position. */
    const int kernel_start_x = (out_width_fill_idx + align_shape.y) / stride_shape.y;
    const int kernel_start_y = (out_height_fill_idx + align_shape.x) / stride_shape.x;

    /* Filter column / row paired with that first input column / row. */
    const int deal_kernel_width =
        kernel_shape.y - mad24(kernel_start_x, stride_shape.y, padding_shape.y) + out_width_fill_idx - 1;
    const int deal_kernel_height =
        kernel_shape.x - mad24(kernel_start_y, stride_shape.x, padding_shape.x) + out_height_fill_idx - 1;

    /* in0, in_width0 and in_idx are the names READ_INPUT_IMAGE(0, ...) expects. */
    float4 in0;
    int in_width0;
    /* x offset of this channel block in the input image; loop invariant. */
    const int in_idx = mul24(out_channel_blocks_idx, input_shape.y);

    /* Filter indices step down by the stride while input indices step up by one. */
    for (int k_y = deal_kernel_height, idx_h = kernel_start_y; k_y >= 0; k_y -= stride_shape.x, idx_h++) {
        const int in_idy = mad24(out_batch_idx, input_shape.x, idx_h);
        /* Rows outside the input are read at y = -1 (out-of-range coordinate). */
        const int in_hb_value = select(in_idy, -1, idx_h < 0 || idx_h >= input_shape.x);

        for (int k_x = deal_kernel_width, in_width_idx = kernel_start_x; k_x >= 0;
             k_x -= stride_shape.y, in_width_idx++) {
            in_width0 = in_width_idx;
            READ_INPUT_IMAGE(0, 0);

            const int kernel_image_x = mad24(k_y, kernel_shape.y, k_x);
            const float4 weight =
                read_imagef(weights, SAMPLER, (int2)(kernel_image_x, out_channel_blocks_idx));
            out0 = mad(in0, weight, out0);
        }
    }

    /*
     * Activation and store run once, after every tap has been accumulated.
     * Applying the activation per filter row would clamp partial sums, and
     * storing inside the loop would skip the write when no tap contributes.
     */
#ifdef RELU
    out0 = fmax(out0, (float4)0);
#endif

#ifdef RELU6
    out0 = clamp(out0, (float4)0, (float4)6);
#endif

    const int output_image_x = mad24(out_channel_blocks_idx, output_shape.y, out_width_idx);
    write_imagef(output, (int2)(output_image_x, out_batch_height_idx), out0);
}