// Enable the half-precision extension only when the build defines MNN_SUPPORT_FP16.
#ifdef MNN_SUPPORT_FP16
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#endif
// READ_INPUT_IMAGE(i, base): load input texel number i for the current filter tap.
//
// Names that must exist where the macro is expanded:
//   inWidthOffset<i>  - width position of output i before the tap offset is added
//   inValue<i>        - destination variable for the loaded texel
//   inCurIdx          - x base added to every in-range width position
//   inputShape        - .y is the exclusive upper bound for the width position
//   inHeightIdx       - y coordinate of the read
//   input, SAMPLER    - image and sampler passed to RI_F
// The macro declares a new variable inOffset<i> in the expansion scope, so it can be
// expanded at most once per i in a given scope.
//
// Width positions outside [0, inputShape.y) are replaced by x = -1, i.e. an
// out-of-image coordinate whose result is decided by the sampler's addressing mode.
// RI_F is not defined in this file (presumably an image-read wrapper supplied by the
// build options - confirm).
// `base` is parenthesized so that an expression argument cannot regroup with the `+`.
#define READ_INPUT_IMAGE(i, base)                                                                         \
    int inOffset##i = inWidthOffset##i + (base);                                                         \
    inOffset##i =                                                                                   \
        select(inCurIdx + inOffset##i, -1, (inOffset##i < 0 || inOffset##i >= inputShape.y)); \
    inValue##i = RI_F(input, SAMPLER, (int2)(inOffset##i, inHeightIdx));

// CALCULATE_OUTPUT(i): accumulate four multiply-adds into outValue<i>, pairing the
// .x/.y/.z/.w components of inValue<i> with weights0/weights1/weights2/weights3.
// All of those names must exist where the macro is expanded.
// NOTE(review): neither kernel visible in this file expands this macro or declares
// weights0..weights3 - confirm whether it is still needed.
#define CALCULATE_OUTPUT(i)                  \
    outValue##i = mad(inValue##i.x, weights0, outValue##i); \
    outValue##i = mad(inValue##i.y, weights1, outValue##i); \
    outValue##i = mad(inValue##i.z, weights2, outValue##i); \
    outValue##i = mad(inValue##i.w, weights3, outValue##i);

// Leading kernel parameters carrying the two logical global-size extents that
// DEAL_NON_UNIFORM_DIM2 compares against. The expansion ends with a comma on purpose:
// it is meant to be followed directly by the next parameter in a kernel signature.
#define GLOBAL_SIZE_2_DIMS __private const int global_size_dim0, __private const int global_size_dim1,

// Sampler shared by every image read in this file: unnormalized (pixel) coordinates,
// clamp addressing (out-of-image coordinates yield the border color) and nearest
// filtering. The kernels rely on this by passing -1 as a coordinate for padded taps.
__constant sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;

// DEAL_NON_UNIFORM_DIM2(input1, input2): return from the enclosing kernel when either
// index is at or beyond its logical extent (global_size_dim0 / global_size_dim1, which
// must be in scope - see GLOBAL_SIZE_2_DIMS). This lets work-items outside the logical
// range exit before touching any image.
// The arguments are parenthesized so that expression arguments compare as a whole.
#define DEAL_NON_UNIFORM_DIM2(input1, input2)                           \
    if ((input1) >= global_size_dim0 || (input2) >= global_size_dim1) { \
        return;                                                         \
    }

// depthwise_conv2d_s1: depthwise 2D convolution with the step between neighbouring
// outputs and between neighbouring filter taps hard-wired to 1 (no stride or dilation
// parameters). Each work-item produces up to four horizontally adjacent output texels
// of one channel block on one output row.
//
// Work-item mapping:
//   get_global_id(0) = channelBlock * ceil(outputShape.y / 4) + widthBlock
//   get_global_id(1) = image * outputShape.x + outputRow
//
// Parameters:
//   global_size_dim0/1 - logical extents; work-items at or beyond them return at once.
//   input           - read at (channelBlock * inputShape.y + w, image * inputShape.x + h).
//   filter          - read at (kh * filterShape.y + kw, channelBlock).
//   bias            - read at (channelBlock, 0); parameter absent when NO_BIAS is defined,
//                     in which case the accumulators start at zero.
//   output          - written at (channelBlock * outputShape.y + w, get_global_id(1)).
//   inputShape      - .x bounds the input row index, .y bounds the input column index.
//   inChannelBlocks - not used by this kernel.
//   outputShape     - .x output rows per image, .y output columns.
//   filterShape     - .x filter rows, .y filter columns.
//   paddingShape    - .x / .y are subtracted from the starting input row / column.
//
// Compile-time switches: SET_ATTRIBUTE (work-group size hint), NO_BIAS, RELU, RELU6.
// FLOAT4, RI_F and WI_F are not defined in this file (presumably supplied through the
// build options as the vector type and image read/write wrappers - confirm).
__kernel
#if SET_ATTRIBUTE
__attribute__((work_group_size_hint(16, 16, 1)))
#endif
void depthwise_conv2d_s1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only image2d_t filter,
                                  #ifndef NO_BIAS
                                  __read_only image2d_t bias,
                                  #endif
                                  __write_only image2d_t output,
                                  __private const int2 inputShape,
                                  __private const int inChannelBlocks,
                                  __private const int2 outputShape,
                                  __private const int2 filterShape,
                                  __private const int2 paddingShape) {

    const int outChannelWidthIdx = get_global_id(0);
    const int outHeightBlockIdx     = get_global_id(1);
    DEAL_NON_UNIFORM_DIM2(outChannelWidthIdx, outHeightBlockIdx);

    // Split global id 0 into (channel block, block of four output columns).
    int ow4              = (outputShape.y + 3) / 4;
    const int outChannelBlockIdx = outChannelWidthIdx / ow4;
    const int outWidthBlockidx   = outChannelWidthIdx % ow4;

    // Depthwise: the input channel block is the output channel block.
    const int inChannelBlockIdx = outChannelBlockIdx;

    // All four accumulators start from the same value (bias or zero).
    #ifndef NO_BIAS
    FLOAT4 outValue0 = RI_F(bias, SAMPLER, (int2)(outChannelBlockIdx, 0));
    #else
    FLOAT4 outValue0 = (FLOAT4)(0.0f);
    #endif
    FLOAT4 outValue1 = outValue0;
    FLOAT4 outValue2 = outValue0;
    FLOAT4 outValue3 = outValue0;

    // Input columns (before adding kw) feeding output columns 0..3 of this block.
    const int outWidthBlockidx4 = outWidthBlockidx << 2;
    const int inWidthOffset0             = outWidthBlockidx4 - paddingShape.y;
    const int inWidthOffset1             = inWidthOffset0 + 1;
    const int inWidthOffset2             = inWidthOffset0 + 2;
    const int inWidthOffset3             = inWidthOffset0 + 3;

    // Split global id 1 into (image, output row); heightIdx is the first input row.
    int heightIdx            = outHeightBlockIdx % outputShape.x - paddingShape.x;
    const int outBatchIdx = mul24((outHeightBlockIdx / outputShape.x), inputShape.x);
    const int inCurIdx = mul24(inChannelBlockIdx, inputShape.y);

    // x coordinates of the first three columns; -1 marks a column outside the input so
    // the read is resolved by the sampler's clamp addressing instead of real data.
    const int inWidthIdx0 = select(inCurIdx + inWidthOffset0, -1, (inWidthOffset0 < 0 || inWidthOffset0 >= inputShape.y));
    const int inWidthIdx1 = select(inCurIdx + inWidthOffset1, -1, (inWidthOffset1 < 0 || inWidthOffset1 >= inputShape.y));
    const int inWidthIdx2 = select(inCurIdx + inWidthOffset2, -1, (inWidthOffset2 < 0 || inWidthOffset2 >= inputShape.y));

    FLOAT4 inValue0, inValue1, inValue2, inValue3;
    for (int kh = 0; kh < filterShape.x; kh++) {
        // Rows outside the input are also mapped to coordinate -1.
        int inHeightIdx = select(heightIdx + outBatchIdx, -1, (heightIdx < 0 || heightIdx >= inputShape.x));
        heightIdx++;

        // Prime the sliding window one slot "ahead": the first shift inside the kw loop
        // moves these three texels into inValue0..inValue2.
        inValue1       = RI_F(input, SAMPLER, (int2)(inWidthIdx0, inHeightIdx));
        inValue2       = RI_F(input, SAMPLER, (int2)(inWidthIdx1, inHeightIdx));
        inValue3       = RI_F(input, SAMPLER, (int2)(inWidthIdx2, inHeightIdx));
        for (int kw = 0; kw < filterShape.y; kw++) {

            int filterIdx   = mad24(kh, filterShape.y, kw);

            // Shift the window by one column; only the newest column is fetched, so each
            // tap costs one input read instead of four.
            inValue0 = inValue1;
            inValue1 = inValue2;
            inValue2 = inValue3;

            int inWidthIdx = inWidthOffset3 + kw;
            inWidthIdx = select(inCurIdx + inWidthIdx, -1, (inWidthIdx < 0 || inWidthIdx >= inputShape.y));
            inValue3  = RI_F(input, SAMPLER, (int2)(inWidthIdx, inHeightIdx));

            FLOAT4 weights = RI_F(filter, SAMPLER, (int2)(filterIdx, inChannelBlockIdx));

            outValue0 = mad(inValue0, weights, outValue0);
            outValue1 = mad(inValue1, weights, outValue1);
            outValue2 = mad(inValue2, weights, outValue2);
            outValue3 = mad(inValue3, weights, outValue3);
        }
    }

#ifdef RELU
    outValue0 = fmax(outValue0, (FLOAT4)0);
    outValue1 = fmax(outValue1, (FLOAT4)0);
    outValue2 = fmax(outValue2, (FLOAT4)0);
    outValue3 = fmax(outValue3, (FLOAT4)0);
#endif

#ifdef RELU6
    outValue0 = clamp(outValue0, (FLOAT4)0, (FLOAT4)6);
    outValue1 = clamp(outValue1, (FLOAT4)0, (FLOAT4)6);
    outValue2 = clamp(outValue2, (FLOAT4)0, (FLOAT4)6);
    outValue3 = clamp(outValue3, (FLOAT4)0, (FLOAT4)6);
#endif

    // Store only the columns that exist: the last block of a row may hold fewer than
    // four valid outputs when outputShape.y is not a multiple of 4.
    const int remain     = outputShape.y - outWidthBlockidx4;
    int outWidthIdx       = mul24(outChannelBlockIdx, outputShape.y) + outWidthBlockidx4;
    if (remain >= 4) {
        WI_F(output, (int2)(outWidthIdx, outHeightBlockIdx), outValue0);
        WI_F(output, (int2)(outWidthIdx + 1, outHeightBlockIdx), outValue1);
        WI_F(output, (int2)(outWidthIdx + 2, outHeightBlockIdx), outValue2);
        WI_F(output, (int2)(outWidthIdx + 3, outHeightBlockIdx), outValue3);
    } else if (remain == 3) {
        WI_F(output, (int2)(outWidthIdx, outHeightBlockIdx), outValue0);
        WI_F(output, (int2)(outWidthIdx + 1, outHeightBlockIdx), outValue1);
        WI_F(output, (int2)(outWidthIdx + 2, outHeightBlockIdx), outValue2);
    } else if (remain == 2) {
        WI_F(output, (int2)(outWidthIdx, outHeightBlockIdx), outValue0);
        WI_F(output, (int2)(outWidthIdx + 1, outHeightBlockIdx), outValue1);
    } else if (remain == 1) {
        WI_F(output, (int2)(outWidthIdx, outHeightBlockIdx), outValue0);
    }
}

// depthwise_conv2d: depthwise 2D convolution with caller-supplied stride and dilation.
// Each work-item produces up to four horizontally adjacent output texels of one
// channel block on one output row.
//
// Work-item mapping:
//   get_global_id(0) = channelBlock * ceil(outputShape.y / 4) + widthBlock
//   get_global_id(1) = image * outputShape.x + outputRow
//
// Parameters:
//   global_size_dim0/1 - logical extents; work-items at or beyond them return at once.
//   input           - read at (channelBlock * inputShape.y + w, image * inputShape.x + h).
//   filter          - read at (kh * filterShape.y + kw, channelBlock).
//   bias            - read at (channelBlock, 0); parameter absent when NO_BIAS is defined,
//                     in which case the accumulators start at zero.
//   output          - written at (channelBlock * outputShape.y + w, get_global_id(1)).
//   inputShape      - .x bounds the input row index, .y bounds the input column index.
//   inChannelBlocks - not used by this kernel.
//   outputShape     - .x output rows per image, .y output columns.
//   filterShape     - .x filter rows, .y filter columns.
//   paddingShape    - .x / .y are subtracted from the starting input row / column.
//   dilationShape   - .x / .y input step between consecutive filter rows / columns.
//   strideShape     - .x / .y input step between consecutive output rows / columns.
//
// Compile-time switches: SET_ATTRIBUTE (work-group size hint), NO_BIAS, RELU, RELU6.
// FLOAT4, RI_F and WI_F are not defined in this file (presumably supplied through the
// build options as the vector type and image read/write wrappers - confirm).
__kernel
#if SET_ATTRIBUTE
__attribute__((work_group_size_hint(16, 16, 1)))
#endif
void depthwise_conv2d(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only image2d_t filter,
                               #ifndef NO_BIAS
                               __read_only image2d_t bias,
                               #endif
                               __write_only image2d_t output,
                               __private const int2 inputShape,
                               __private const int inChannelBlocks, __private const int2 outputShape,
                               __private const int2 filterShape,
                               __private const int2 paddingShape,
                               __private const int2 dilationShape,
                               __private const int2 strideShape) {

    const int outChannelWidthIdx = get_global_id(0);
    const int outHeightIdx     = get_global_id(1);
    DEAL_NON_UNIFORM_DIM2(outChannelWidthIdx, outHeightIdx);

    // Split global id 0 into (channel block, block of four output columns).
    int ow4              = (outputShape.y + 3) / 4;
    const int outChannelBlockIdx = outChannelWidthIdx / ow4;
    const int outWidthBlockidx   = outChannelWidthIdx % ow4;

    // Depthwise: the input channel block is the output channel block.
    const int inChannelBlockIdx = outChannelBlockIdx;

    // All four accumulators start from the same value (bias or zero).
    #ifndef NO_BIAS
    FLOAT4 outValue0 = RI_F(bias, SAMPLER, (int2)(outChannelBlockIdx, 0));
    #else
    FLOAT4 outValue0 = (FLOAT4)(0.0f);
    #endif
    FLOAT4 outValue1 = outValue0;
    FLOAT4 outValue2 = outValue0;
    FLOAT4 outValue3 = outValue0;

    // First input column for each of the four outputs: the block starts at
    // widthBlock * 4 * stride - padding and consecutive outputs are one stride apart.
    const int inWidthOffset0  = mad24(outWidthBlockidx, strideShape.y << 2, -paddingShape.y);
    const int inWidthOffset1  = inWidthOffset0 + strideShape.y;
    const int inWidthOffset2  = inWidthOffset1 + strideShape.y;
    const int inWidthOffset3  = inWidthOffset2 + strideShape.y;

    // Split global id 1 into (image, output row); heightIdx is the first input row.
    int heightIdx = mad24(outHeightIdx % outputShape.x, strideShape.x, -paddingShape.x);

    const int outBatchIdx = mul24((outHeightIdx / outputShape.x), inputShape.x);

    const int inCurIdx = mul24(inChannelBlockIdx, inputShape.y);
    for (int kh = 0; kh < filterShape.x; kh++) {
        // Rows outside the input are mapped to coordinate -1 so the read is resolved by
        // the sampler's clamp addressing instead of real data.
        int inHeightIdx = select(heightIdx + outBatchIdx, -1, (heightIdx < 0 || heightIdx >= inputShape.x));
        heightIdx += dilationShape.x;
        for (int kw = 0; kw < filterShape.y; kw++) {
            int filterIdx = mad24(kh, filterShape.y, kw);
            FLOAT4 inValue0, inValue1, inValue2, inValue3;
            // Column displacement of this tap relative to each output's first column.
            int inWidthIdx = mul24(kw, dilationShape.y);

            // Each expansion declares inOffset<i>, bounds-checks the column and loads
            // inValue<i>.
            READ_INPUT_IMAGE(0, inWidthIdx);
            READ_INPUT_IMAGE(1, inWidthIdx);
            READ_INPUT_IMAGE(2, inWidthIdx);
            READ_INPUT_IMAGE(3, inWidthIdx);

            FLOAT4 weights = RI_F(filter, SAMPLER, (int2)(filterIdx, inChannelBlockIdx));

            outValue0 = mad(inValue0, weights, outValue0);
            outValue1 = mad(inValue1, weights, outValue1);
            outValue2 = mad(inValue2, weights, outValue2);
            outValue3 = mad(inValue3, weights, outValue3);
        }
    }

#ifdef RELU
    outValue0 = fmax(outValue0, (FLOAT4)0);
    outValue1 = fmax(outValue1, (FLOAT4)0);
    outValue2 = fmax(outValue2, (FLOAT4)0);
    outValue3 = fmax(outValue3, (FLOAT4)0);
#endif

#ifdef RELU6
    outValue0 = clamp(outValue0, (FLOAT4)0, (FLOAT4)6);
    outValue1 = clamp(outValue1, (FLOAT4)0, (FLOAT4)6);
    outValue2 = clamp(outValue2, (FLOAT4)0, (FLOAT4)6);
    outValue3 = clamp(outValue3, (FLOAT4)0, (FLOAT4)6);
#endif

    // Store only the columns that exist: the last block of a row may hold fewer than
    // four valid outputs when outputShape.y is not a multiple of 4.
    const int outWidthBlockidx4        = outWidthBlockidx << 2;
    const int remain = outputShape.y - outWidthBlockidx4;
    int outWidthIdx   = mul24(outChannelBlockIdx, outputShape.y) + outWidthBlockidx4;
    if (remain >= 4) {
        WI_F(output, (int2)(outWidthIdx, outHeightIdx), outValue0);
        WI_F(output, (int2)(outWidthIdx + 1, outHeightIdx), outValue1);
        WI_F(output, (int2)(outWidthIdx + 2, outHeightIdx), outValue2);
        WI_F(output, (int2)(outWidthIdx + 3, outHeightIdx), outValue3);
    } else if (remain == 3) {
        WI_F(output, (int2)(outWidthIdx, outHeightIdx), outValue0);
        WI_F(output, (int2)(outWidthIdx + 1, outHeightIdx), outValue1);
        WI_F(output, (int2)(outWidthIdx + 2, outHeightIdx), outValue2);
    } else if (remain == 2) {
        WI_F(output, (int2)(outWidthIdx, outHeightIdx), outValue0);
        WI_F(output, (int2)(outWidthIdx + 1, outHeightIdx), outValue1);
    } else if (remain == 1) {
        WI_F(output, (int2)(outWidthIdx, outHeightIdx), outValue0);
    }
}
