1#ifdef MNN_SUPPORT_FP16 2#pragma OPENCL EXTENSION cl_khr_fp16 : enable 3#endif 4#define READ_INPUT_IMAGE(i, base) \ 5 int inOffset##i = inWidthOffset##i + base; \ 6 inOffset##i = \ 7 select(inCurIdx + inOffset##i, -1, (inOffset##i < 0 || inOffset##i >= inputShape.y)); \ 8 inValue##i = RI_F(input, SAMPLER, (int2)(inOffset##i, inHeightIdx)); 9 10#define CALCULATE_OUTPUT(i) \ 11 outValue##i = mad(inValue##i.x, weights0, outValue##i); \ 12 outValue##i = mad(inValue##i.y, weights1, outValue##i); \ 13 outValue##i = mad(inValue##i.z, weights2, outValue##i); \ 14 outValue##i = mad(inValue##i.w, weights3, outValue##i); 15 16#define GLOBAL_SIZE_2_DIMS __private const int global_size_dim0, __private const int global_size_dim1, 17 18__constant sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; 19 20#define DEAL_NON_UNIFORM_DIM2(input1, input2) \ 21 if (input1 >= global_size_dim0 || input2 >= global_size_dim1) { \ 22 return; \ 23 } 24 25__kernel 26#if SET_ATTRIBUTE 27__attribute__((work_group_size_hint(16, 16, 1))) 28#endif 29void depthwise_conv2d_s1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only image2d_t filter, 30 #ifndef NO_BIAS 31 __read_only image2d_t bias, 32 #endif 33 __write_only image2d_t output, 34 __private const int2 inputShape, 35 __private const int inChannelBlocks, 36 __private const int2 outputShape, 37 __private const int2 filterShape, 38 __private const int2 paddingShape) { 39 40 const int outChannelWidthIdx = get_global_id(0); 41 const int outHeightBlockIdx = get_global_id(1); 42 DEAL_NON_UNIFORM_DIM2(outChannelWidthIdx, outHeightBlockIdx); 43 int ow4 = (outputShape.y + 3) / 4; 44 const int outChannelBlockIdx = outChannelWidthIdx / ow4; 45 const int outWidthBlockidx = outChannelWidthIdx % ow4; 46 47 const int inChannelBlockIdx = outChannelBlockIdx; 48 49 #ifndef NO_BIAS 50 FLOAT4 outValue0 = RI_F(bias, SAMPLER, (int2)(outChannelBlockIdx, 0)); 51 #else 52 FLOAT4 outValue0 = (FLOAT4)(0.0f); 53 #endif 54 FLOAT4 outValue1 = outValue0; 55 FLOAT4 outValue2 = outValue0; 56 FLOAT4 outValue3 = outValue0; 57 58 const int outWidthBlockidx4 = outWidthBlockidx << 2; 59 const int inWidthOffset0 = outWidthBlockidx4 - paddingShape.y; 60 const int inWidthOffset1 = inWidthOffset0 + 1; 61 const int inWidthOffset2 = inWidthOffset0 + 2; 62 const int inWidthOffset3 = inWidthOffset0 + 3; 63 64 int heightIdx = outHeightBlockIdx % outputShape.x - paddingShape.x; 65 const int outBatchIdx = mul24((outHeightBlockIdx / outputShape.x), inputShape.x); 66 const int inCurIdx = mul24(inChannelBlockIdx, inputShape.y); 67 68 const int inWidthIdx0 = select(inCurIdx + inWidthOffset0, -1, (inWidthOffset0 < 0 || inWidthOffset0 >= inputShape.y)); 69 const int inWidthIdx1 = select(inCurIdx + inWidthOffset1, -1, (inWidthOffset1 < 0 || inWidthOffset1 >= inputShape.y)); 70 const int inWidthIdx2 = select(inCurIdx + inWidthOffset2, -1, (inWidthOffset2 < 0 || inWidthOffset2 >= inputShape.y)); 71 72 FLOAT4 inValue0, inValue1, inValue2, inValue3; 73 for (int kh = 0; kh < filterShape.x; kh++) { 74 int inHeightIdx = select(heightIdx + outBatchIdx, -1, (heightIdx < 0 || heightIdx >= inputShape.x)); 75 heightIdx++; 76 inValue1 = RI_F(input, SAMPLER, (int2)(inWidthIdx0, inHeightIdx)); 77 inValue2 = RI_F(input, SAMPLER, (int2)(inWidthIdx1, inHeightIdx)); 78 inValue3 = RI_F(input, SAMPLER, (int2)(inWidthIdx2, inHeightIdx)); 79 for (int kw = 0; kw < filterShape.y; kw++) { 80 81 int filterIdx = mad24(kh, filterShape.y, kw); 82 inValue0 = inValue1; 83 inValue1 = inValue2; 84 inValue2 = inValue3; 85 86 int inWidthIdx = inWidthOffset3 + kw; 87 inWidthIdx = select(inCurIdx + inWidthIdx, -1, (inWidthIdx < 0 || inWidthIdx >= inputShape.y)); 88 inValue3 = RI_F(input, SAMPLER, (int2)(inWidthIdx, inHeightIdx)); 89 90 FLOAT4 weights = RI_F(filter, SAMPLER, (int2)(filterIdx, inChannelBlockIdx)); 91 92 outValue0 = mad(inValue0, weights, outValue0); 93 outValue1 = mad(inValue1, weights, outValue1); 94 outValue2 = mad(inValue2, weights, outValue2); 95 outValue3 = mad(inValue3, weights, outValue3); 96 } 97 } 98 99#ifdef RELU 100 outValue0 = fmax(outValue0, (FLOAT4)0); 101 outValue1 = fmax(outValue1, (FLOAT4)0); 102 outValue2 = fmax(outValue2, (FLOAT4)0); 103 outValue3 = fmax(outValue3, (FLOAT4)0); 104#endif 105 106#ifdef RELU6 107 outValue0 = clamp(outValue0, (FLOAT4)0, (FLOAT4)6); 108 outValue1 = clamp(outValue1, (FLOAT4)0, (FLOAT4)6); 109 outValue2 = clamp(outValue2, (FLOAT4)0, (FLOAT4)6); 110 outValue3 = clamp(outValue3, (FLOAT4)0, (FLOAT4)6); 111#endif 112 113 const int remain = outputShape.y - outWidthBlockidx4; 114 int outWidthIdx = mul24(outChannelBlockIdx, outputShape.y) + outWidthBlockidx4; 115 if (remain >= 4) { 116 WI_F(output, (int2)(outWidthIdx, outHeightBlockIdx), outValue0); 117 WI_F(output, (int2)(outWidthIdx + 1, outHeightBlockIdx), outValue1); 118 WI_F(output, (int2)(outWidthIdx + 2, outHeightBlockIdx), outValue2); 119 WI_F(output, (int2)(outWidthIdx + 3, outHeightBlockIdx), outValue3); 120 } else if (remain == 3) { 121 WI_F(output, (int2)(outWidthIdx, outHeightBlockIdx), outValue0); 122 WI_F(output, (int2)(outWidthIdx + 1, outHeightBlockIdx), outValue1); 123 WI_F(output, (int2)(outWidthIdx + 2, outHeightBlockIdx), outValue2); 124 } else if (remain == 2) { 125 WI_F(output, (int2)(outWidthIdx, outHeightBlockIdx), outValue0); 126 WI_F(output, (int2)(outWidthIdx + 1, outHeightBlockIdx), outValue1); 127 } else if (remain == 1) { 128 WI_F(output, (int2)(outWidthIdx, outHeightBlockIdx), outValue0); 129 } 130} 131 132__kernel 133#if SET_ATTRIBUTE 134__attribute__((work_group_size_hint(16, 16, 1))) 135#endif 136void depthwise_conv2d(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only image2d_t filter, 137 #ifndef NO_BIAS 138 __read_only image2d_t bias, 139 #endif 140 __write_only image2d_t output, 141 __private const int2 inputShape, 142 __private const int inChannelBlocks, __private const int2 outputShape, 143 __private const int2 filterShape, 144 __private const int2 paddingShape, 145 __private const int2 dilationShape, 146 __private const int2 strideShape) { 147 148 const int outChannelWidthIdx = get_global_id(0); 149 const int outHeightIdx = get_global_id(1); 150 DEAL_NON_UNIFORM_DIM2(outChannelWidthIdx, outHeightIdx); 151 152 int ow4 = (outputShape.y + 3) / 4; 153 const int outChannelBlockIdx = outChannelWidthIdx / ow4; 154 const int outWidthBlockidx = outChannelWidthIdx % ow4; 155 156 const int inChannelBlockIdx = outChannelBlockIdx; 157 158 #ifndef NO_BIAS 159 FLOAT4 outValue0 = RI_F(bias, SAMPLER, (int2)(outChannelBlockIdx, 0)); 160 #else 161 FLOAT4 outValue0 = (FLOAT4)(0.0f); 162 #endif 163 FLOAT4 outValue1 = outValue0; 164 FLOAT4 outValue2 = outValue0; 165 FLOAT4 outValue3 = outValue0; 166 167 const int inWidthOffset0 = mad24(outWidthBlockidx, strideShape.y << 2, -paddingShape.y); 168 const int inWidthOffset1 = inWidthOffset0 + strideShape.y; 169 const int inWidthOffset2 = inWidthOffset1 + strideShape.y; 170 const int inWidthOffset3 = inWidthOffset2 + strideShape.y; 171 int heightIdx = mad24(outHeightIdx % outputShape.x, strideShape.x, -paddingShape.x); 172 173 const int outBatchIdx = mul24((outHeightIdx / outputShape.x), inputShape.x); 174 175 const int inCurIdx = mul24(inChannelBlockIdx, inputShape.y); 176 for (int kh = 0; kh < filterShape.x; kh++) { 177 int inHeightIdx = select(heightIdx + outBatchIdx, -1, (heightIdx < 0 || heightIdx >= inputShape.x)); 178 heightIdx += dilationShape.x; 179 for (int kw = 0; kw < filterShape.y; kw++) { 180 int filterIdx = mad24(kh, filterShape.y, kw); 181 FLOAT4 inValue0, inValue1, inValue2, inValue3; 182 int inWidthIdx = mul24(kw, dilationShape.y); 183 184 READ_INPUT_IMAGE(0, inWidthIdx); 185 READ_INPUT_IMAGE(1, inWidthIdx); 186 READ_INPUT_IMAGE(2, inWidthIdx); 187 READ_INPUT_IMAGE(3, inWidthIdx); 188 189 FLOAT4 weights = RI_F(filter, SAMPLER, (int2)(filterIdx, inChannelBlockIdx)); 190 191 outValue0 = mad(inValue0, weights, outValue0); 192 outValue1 = mad(inValue1, weights, outValue1); 193 outValue2 = mad(inValue2, weights, outValue2); 194 outValue3 = mad(inValue3, weights, outValue3); 195 } 196 } 197 198#ifdef RELU 199 outValue0 = fmax(outValue0, (FLOAT4)0); 200 outValue1 = fmax(outValue1, (FLOAT4)0); 201 outValue2 = fmax(outValue2, (FLOAT4)0); 202 outValue3 = fmax(outValue3, (FLOAT4)0); 203#endif 204 205#ifdef RELU6 206 outValue0 = clamp(outValue0, (FLOAT4)0, (FLOAT4)6); 207 outValue1 = clamp(outValue1, (FLOAT4)0, (FLOAT4)6); 208 outValue2 = clamp(outValue2, (FLOAT4)0, (FLOAT4)6); 209 outValue3 = clamp(outValue3, (FLOAT4)0, (FLOAT4)6); 210#endif 211 212 const int outWidthBlockidx4 = outWidthBlockidx << 2; 213 const int remain = outputShape.y - outWidthBlockidx4; 214 int outWidthIdx = mul24(outChannelBlockIdx, outputShape.y) + outWidthBlockidx4; 215 if (remain >= 4) { 216 WI_F(output, (int2)(outWidthIdx, outHeightIdx), outValue0); 217 WI_F(output, (int2)(outWidthIdx + 1, outHeightIdx), outValue1); 218 WI_F(output, (int2)(outWidthIdx + 2, outHeightIdx), outValue2); 219 WI_F(output, (int2)(outWidthIdx + 3, outHeightIdx), outValue3); 220 } else if (remain == 3) { 221 WI_F(output, (int2)(outWidthIdx, outHeightIdx), outValue0); 222 WI_F(output, (int2)(outWidthIdx + 1, outHeightIdx), outValue1); 223 WI_F(output, (int2)(outWidthIdx + 2, outHeightIdx), outValue2); 224 } else if (remain == 2) { 225 WI_F(output, (int2)(outWidthIdx, outHeightIdx), outValue0); 226 WI_F(output, (int2)(outWidthIdx + 1, outHeightIdx), outValue1); 227 } else if (remain == 1) { 228 WI_F(output, (int2)(outWidthIdx, outHeightIdx), outValue0); 229 } 230} 231