1 //
2 // ConvWinograd.cpp
3 // MNN
4 //
5 // Created by MNN on 2019/01/08.
6 // Copyright © 2018, Alibaba Group Holding Limited
7 //
8
#include "backend/opencl/execution/image/ConvWinograd.hpp"
#include <stdio.h>
#include <string.h>
#include "backend/opencl/core/OpenCLRunningUtils.hpp"
#include "core/Backend.hpp"
#include "core/ConvolutionCommon.hpp"
#include "math/WingoradGenerater.hpp"
15 #define UNIT 2
16 #define INTERP 1
17 namespace MNN {
18 namespace OpenCL {
valid(const Convolution2DCommon * common,const Tensor * input,int limit)19 bool ConvWinograd::valid(const Convolution2DCommon* common, const Tensor* input, int limit) {
20 if (common->strideX() != 1 || common->strideY() != 1) {
21 return false;
22 }
23 if (common->dilateX() != 1 || common->dilateY() != 1) {
24 return false;
25 }
26 if (input->channel() < 8 || common->outputCount() < 8) {
27 return false;
28 }
29
30 return (common->kernelX() == 3 && common->kernelY() == 3) || (common->kernelX() == 5 && common->kernelY() == 5);
31 }
32
33
ConvWinograd(const MNN::Convolution2D * op,Backend * backend)34 ConvWinograd::ConvWinograd(const MNN::Convolution2D* op, Backend* backend) : Execution(backend) {
35 mOpenCLBackend = static_cast<OpenCLBackend*>(backend);
36 mCommon = op->common();
37 MNN_ASSERT((3 == mCommon->kernelY() && 3 == mCommon->kernelX()) ||
38 (5 == mCommon->kernelX() && 5 == mCommon->kernelY()));
39 MNN_ASSERT(1 == mCommon->strideX() && 1 == mCommon->strideY());
40 MNN_ASSERT(1 == mCommon->dilateX() && 1 == mCommon->dilateY());
41 auto runTime = mOpenCLBackend->getOpenCLRuntime();
42 int ky = mCommon->kernelY();
43 int kx = mCommon->kernelX();
44
45 int weightSize = 0;
46 const float* filterDataPtr = nullptr;
47
48 std::shared_ptr<MNN::ConvolutionCommon::Int8Common> quanCommon;
49 if (nullptr != op->quanParameter()) {
50 quanCommon = ConvolutionCommon::load(op->quanParameter(), true);
51 if (nullptr == quanCommon) {
52 MNN_ERROR("Memory not Enough, can't extract IDST Convolution \n");
53 }
54 if (quanCommon->weightFloat.get() == nullptr) {
55 MNN_PRINT("quanCommon->weightFloat.get() == nullptr \n");
56 }
57 // Back to float
58 filterDataPtr = quanCommon->weightFloat.get();
59 weightSize = quanCommon->weightFloat.size();
60 }
61
62 if (nullptr == filterDataPtr) {
63 weightSize = op->weight()->size();
64 filterDataPtr = op->weight()->data();
65 }
66
67 int co = mCommon->outputCount();
68 int ci = weightSize / co / mCommon->kernelX() / mCommon->kernelY();
69 auto coC4 = UP_DIV(co, 4);
70 auto ciC4 = UP_DIV(ci, 4);
71 auto queue = runTime->commandQueue();
72
73 auto imageChannelType = CL_HALF_FLOAT;
74 if (mOpenCLBackend->getPrecision() == BackendConfig::Precision_High) {
75 imageChannelType = CL_FLOAT;
76 }
77 // Create Image
78 {
79 mBias.reset(new cl::Image2D(runTime->context(), CL_MEM_READ_WRITE, cl::ImageFormat(CL_RGBA, imageChannelType),
80 UP_DIV(co, 4), 1, 0, nullptr, nullptr));
81
82 int buffer_size = ALIGN_UP4(co);
83 if(mOpenCLBackend->getOpenCLRuntime()->isWeightCpuTransHalf()) {
84 buffer_size *= sizeof(half_float::half);
85 } else {
86 buffer_size *= sizeof(float);
87 }
88 std::shared_ptr<cl::Buffer> biasBuffer(
89 new cl::Buffer(runTime->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size));
90
91 cl_int error;
92 auto biasC = queue.enqueueMapBuffer(*biasBuffer, CL_TRUE, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &error);
93 if(biasC != nullptr && error == CL_SUCCESS){
94 if(mOpenCLBackend->getOpenCLRuntime()->isWeightCpuTransHalf()){
95 for(int i=0; i<co; i++) {
96 ((half_float::half*)biasC)[i] = (half_float::half)(op->bias()->data()[i]);
97 }
98 for(int i=co; i<ALIGN_UP4(co); i++) {
99 ((half_float::half*)biasC)[i] = (half_float::half)(0.0f);
100 }
101 }else{
102 ::memset(biasC, 0, buffer_size);
103 ::memcpy(biasC, op->bias()->data(), co * sizeof(float));
104 }
105 }else{
106 MNN_ERROR("Map error biasC == nullptr \n");
107 }
108 queue.enqueueUnmapMemObject(*biasBuffer, biasC);
109 copyBufferToImage(runTime, *biasBuffer, *mBias, coC4, 1);
110
111 std::shared_ptr<Tensor> sourceWeight(
112 Tensor::create<float>(std::vector<int>{co, ci, ky, kx}, (void*)(filterDataPtr), Tensor::CAFFE));
113
114 int unit = UNIT;
115 int kernelSize = kx;
116 Math::WinogradGenerater generator(unit, kernelSize, INTERP);
117 int alpha = unit + kernelSize - 1;
118 auto weightDest = generator.allocTransformWeight(sourceWeight.get());
119 generator.transformWeight(weightDest.get(), sourceWeight.get());
120 auto weightDestSize = weightDest->size();
121
122 buffer_size = weightDest->elementSize();
123 if(mOpenCLBackend->getOpenCLRuntime()->isWeightCpuTransHalf()) {
124 buffer_size *= sizeof(half_float::half);
125 } else {
126 buffer_size *= sizeof(float);
127 }
128 cl::Buffer weightBuffer(runTime->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size);
129 {
130 cl_int error;
131 auto weightPtr = queue.enqueueMapBuffer(weightBuffer, CL_TRUE, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &error);
132 if(weightPtr != nullptr && error == CL_SUCCESS){
133 if(mOpenCLBackend->getOpenCLRuntime()->isWeightCpuTransHalf()){
134 for(int i=0; i<weightDest->elementSize(); i++) {
135 ((half_float::half*)weightPtr)[i] = (half_float::half)(weightDest->host<float>()[i]);
136 }
137 }else{
138 ::memcpy(weightPtr, weightDest->host<float>(), buffer_size);
139 }
140 } else{
141 MNN_ERROR("Map error weightPtr == nullptr \n");
142 }
143
144 queue.enqueueUnmapMemObject(weightBuffer, weightPtr);
145 }
146 mWeight.reset(new cl::Image2D(runTime->context(), CL_MEM_READ_WRITE, cl::ImageFormat(CL_RGBA, imageChannelType),
147 ciC4 * 4, coC4 * alpha * alpha, 0, nullptr, nullptr));
148 copyBufferToImage(runTime, weightBuffer, *mWeight, ciC4 * 4, coC4 * alpha * alpha);
149 }
150 }
151
onResize(const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs)152 ErrorCode ConvWinograd::onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
153 auto input = inputs[0];
154 auto output = outputs[0];
155 mKernelX = mCommon->kernelX();
156 mKernelY = mCommon->kernelY();
157 mPadX = mCommon->padX();
158 mPadY = mCommon->padY();
159 mStrideX = mCommon->strideX();
160 mStrideY = mCommon->strideY();
161 mPadMode = mCommon->padMode();
162
163 int alpha = mCommon->kernelX() + UNIT - 1;
164 auto wUnit = UP_DIV(output->width(), UNIT);
165 auto hUnit = UP_DIV(output->height(), UNIT);
166 int padX = mPadX;
167 int padY = mPadY;
168 if (mPadMode == PadMode_SAME) {
169 int kernelWidthSize = (mKernelX - 1) * mCommon->dilateX() + 1;
170 int kernelHeightSize = (mKernelY - 1) * mCommon->dilateY() + 1;
171 int padNeededWidth = (output->width() - 1) * mStrideX + kernelWidthSize - input->width();
172 int padNeededHeight = (output->height() - 1) * mStrideY + kernelHeightSize - input->height();
173 padX = padNeededWidth / 2;
174 padY = padNeededHeight / 2;
175 }
176
177 auto runTime = mOpenCLBackend->getOpenCLRuntime();
178
179 int maxWidth = runTime->getMaxImage2DSize()[0];
180 int maxHeight = runTime->getMaxImage2DSize()[1];
181
182 int sourceWidth = UP_DIV(input->channel(), 4) * 4;
183 int sourceHeight = alpha * alpha * UP_DIV(wUnit * hUnit, 4);
184
185 int sliceNumber = 1;
186 const int maxSlice = 100;
187
188 if (maxWidth < sourceWidth || maxHeight < sourceHeight) {
189 for (int i = 2; i < maxSlice; ++i) {
190 int realWidth = (size_t)UP_DIV(input->channel(), 4) * 4;
191 int readHeight = (size_t)alpha * alpha * UP_DIV(UP_DIV(wUnit, i) * UP_DIV(hUnit, i), 4);
192
193 if (realWidth < maxWidth && readHeight < maxHeight) {
194 sliceNumber = i;
195 break;
196 }
197 }
198 }
199
200 mSliceNumber = sliceNumber;
201
202 int wPiece = UP_DIV(wUnit, sliceNumber);
203 int hPiece = UP_DIV(hUnit, sliceNumber);
204
205 auto bn = backend();
206 mSource.reset(Tensor::createDevice<float>(
207 std::vector<int>{alpha * alpha, input->channel(), UP_DIV(wPiece * hPiece, 4), 4}, Tensor::CAFFE_C4));
208 mDest.reset(Tensor::createDevice<float>(
209 std::vector<int>{4, wPiece * hPiece, UP_DIV(output->channel(), 4), alpha * alpha}, Tensor::CAFFE_C4));
210
211 bn->onAcquireBuffer(mSource.get(), Backend::DYNAMIC);
212 bn->onAcquireBuffer(mDest.get(), Backend::DYNAMIC);
213 bn->onReleaseBuffer(mSource.get(), Backend::DYNAMIC);
214 bn->onReleaseBuffer(mDest.get(), Backend::DYNAMIC);
215
216 auto icC4 = UP_DIV(input->channel(), 4);
217 auto ocC4 = UP_DIV(output->channel(), 4);
218
219 uint32_t total_num = input->batch()*mSliceNumber*mSliceNumber;
220 mSourceTransform.resize(total_num);
221 mMatMul.resize(total_num);
222 mDestTransform.resize(total_num);
223 mMaxWGS_S.resize(total_num);
224 mMaxWGS_D.resize(total_num);
225 mMaxWGS_M.resize(total_num);
226
227 std::set<std::string> basic;
228 /*Create Kernel*/
229 for(int i = 0; i < input->batch()*mSliceNumber*mSliceNumber; i++) {
230 char format[20];
231 ::memset(format, 0, sizeof(format));
232 sprintf(format, "%d_%d_%d", UNIT, mKernelX, INTERP);
233 auto formatStr = std::string(format);
234 mSourceTransform[i] =
235 runTime->buildKernel("winogradTransformSource" + formatStr,
236 "winogradTransformSource", basic);
237 mMaxWGS_S[i] = static_cast<uint32_t>(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(mSourceTransform[i]));
238 {
239 std::set<std::string> buildOptions = basic;
240 if (mCommon->relu()) {
241 buildOptions.emplace("-DRELU");
242 }
243 if (mCommon->relu6()) {
244 buildOptions.emplace("-DRELU6");
245 }
246 mDestTransform[i] =
247 runTime->buildKernel("winogradTransformDest" + formatStr,
248 "winogradTransformDest", buildOptions);
249 mMaxWGS_D[i] = static_cast<uint32_t>(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(mDestTransform[i]));
250 }
251 mMatMul[i] = runTime->buildKernel("gemm", "gemm", basic);
252 mMaxWGS_M[i] = static_cast<uint32_t>(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(mMatMul[i]));
253 }
254
255 mGWS_S.resize(total_num);
256 mGWS_D.resize(total_num);
257 mGWS_M.resize(total_num);
258 mLWS_S.resize(total_num);
259 mLWS_D.resize(total_num);
260 mLWS_M.resize(total_num);
261
262 for (int b = 0; b < input->batch(); ++b) {
263 std::vector<int> offsetData;
264 offsetData.push_back(0);
265 offsetData.push_back(0);
266
267 for (int y = 0; y < mSliceNumber; ++y) {
268 int hCount = hPiece;
269 if (y == mSliceNumber - 1) {
270 hCount = hUnit - (mSliceNumber - 1) * hPiece;
271 }
272 offsetData[1] = y * hPiece;
273
274 for (int x = 0; x < mSliceNumber; ++x) {
275 int wCount = wPiece;
276 if (x == mSliceNumber - 1) {
277 wCount = wUnit - (mSliceNumber - 1) * wPiece;
278 }
279 offsetData[0] = x * wPiece;
280
281 auto dest = mDest.get();
282 int index = b*mSliceNumber*mSliceNumber + y*mSliceNumber + x;
283
284 mSourceTransform[index].setArg(0, openCLImage(input));
285 mSourceTransform[index].setArg(1, openCLImage(mSource.get()));
286 mSourceTransform[index].setArg(4, padX);
287 mSourceTransform[index].setArg(5, padY);
288 mSourceTransform[index].setArg(6, input->width());
289 mSourceTransform[index].setArg(7, input->height());
290 mSourceTransform[index].setArg(8, icC4);
291
292 mMatMul[index].setArg(0, openCLImage(mSource.get()));
293 mMatMul[index].setArg(1, *mWeight);
294 mMatMul[index].setArg(4, ocC4);
295 mMatMul[index].setArg(5, icC4);
296 mMatMul[index].setArg(6, alpha*alpha);
297
298 mDestTransform[index].setArg(1, *mBias);
299 mDestTransform[index].setArg(2, openCLImage(output));
300 mDestTransform[index].setArg(5, output->width());
301 mDestTransform[index].setArg(6, output->height());
302 mDestTransform[index].setArg(7, ocC4);
303
304
305 mSourceTransform[index].setArg(2, wCount);
306 mSourceTransform[index].setArg(3, hCount);
307 mSourceTransform[index].setArg(9, offsetData[0]);
308 mSourceTransform[index].setArg(10, offsetData[1]);
309 mSourceTransform[index].setArg(11, b);
310
311 auto gemmWidth = UP_DIV(wCount * hCount, 4);
312 mMatMul[index].setArg(2, openCLImage(dest));
313 mMatMul[index].setArg(3, gemmWidth);
314
315 mDestTransform[index].setArg(0, openCLImage(dest));
316 mDestTransform[index].setArg(3, wCount);
317 mDestTransform[index].setArg(4, hCount);
318 mDestTransform[index].setArg(8, offsetData[0]);
319 mDestTransform[index].setArg(9, offsetData[1]);
320 mDestTransform[index].setArg(10, b);
321
322 /*Source Transform*/
323 {
324 mGWS_S[index] = {static_cast<uint32_t>(wCount * hCount), static_cast<uint32_t>(icC4)};
325 std::string kernelName = "winogradTransformSource";
326 mLWS_S[index] = localWS2DDefault(mGWS_S[index], mMaxWGS_S[index], mOpenCLBackend->getOpenCLRuntime(), kernelName, mSourceTransform[index]).first;
327 }
328
329 /*MatMul*/
330 {
331 auto gemmHeight = ocC4;
332 mGWS_M[index] = {static_cast<uint32_t>(gemmWidth*gemmHeight), static_cast<uint32_t>(alpha * alpha)};
333 std::string kernelName = "gemm";
334 mLWS_M[index] = localWS2DDefault(mGWS_M[index], mMaxWGS_M[index], mOpenCLBackend->getOpenCLRuntime(), kernelName, mMatMul[index]).first;
335 }
336
337 // Dest Transform
338 {
339 mGWS_D[index] = {static_cast<uint32_t>(wCount*hCount), static_cast<uint32_t>(ocC4)};
340 std::string kernelName = "winogradTransformDest";
341 mLWS_D[index] = localWS2DDefault(mGWS_D[index], mMaxWGS_D[index], mOpenCLBackend->getOpenCLRuntime(), kernelName, mDestTransform[index]).first;
342 }
343
344 }
345 }
346 }
347
348 return NO_ERROR;
349 }
350
onExecute(const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs)351 ErrorCode ConvWinograd::onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
352 auto input = inputs[0];
353 auto output = outputs[0];
354
355 #ifdef ENABLE_OPENCL_TIME_PROFILER
356 int costTime = 0;
357 #endif
358 for (int b = 0; b < input->batch(); ++b) {
359 for (int y = 0; y < mSliceNumber; ++y) {
360 for (int x = 0; x < mSliceNumber; ++x) {
361 int index = b*mSliceNumber*mSliceNumber + y*mSliceNumber + x;
362
363 /*Source Transform*/
364 {
365 #ifdef ENABLE_OPENCL_TIME_PROFILER
366 cl::Event event;
367 runKernel2D(mSourceTransform[index], mGWS_S[index], mLWS_S[index],
368 mOpenCLBackend->getOpenCLRuntime(), &event);
369
370 int costTime0 = (int)mOpenCLBackend->getOpenCLRuntime()->getCostTime(&event);
371 costTime += costTime0;
372 MNN_PRINT("kernel cost:%d us ConvWino0\n",costTime0);
373 #else
374 runKernel2D(mSourceTransform[index], mGWS_S[index], mLWS_S[index],
375 mOpenCLBackend->getOpenCLRuntime());
376 #endif
377 }
378
379 /*MatMul*/
380 {
381 #ifdef ENABLE_OPENCL_TIME_PROFILER
382 cl::Event event;
383 runKernel2D(mMatMul[index], mGWS_M[index], mLWS_M[index],
384 mOpenCLBackend->getOpenCLRuntime(), &event);
385
386 int costTime1 = (int)mOpenCLBackend->getOpenCLRuntime()->getCostTime(&event);
387 costTime += costTime1;
388 MNN_PRINT("kernel cost:%d us ConvWino1\n",costTime1);
389 #else
390 runKernel2D(mMatMul[index], mGWS_M[index], mLWS_M[index],
391 mOpenCLBackend->getOpenCLRuntime());
392 #endif
393 }
394
395 // Dest Transform
396 {
397 #ifdef ENABLE_OPENCL_TIME_PROFILER
398 cl::Event event;
399 runKernel2D(mDestTransform[index], mGWS_D[index], mLWS_D[index],
400 mOpenCLBackend->getOpenCLRuntime(), &event);
401
402 int costTime2 = (int)mOpenCLBackend->getOpenCLRuntime()->getCostTime(&event);
403 costTime += costTime2;
404 MNN_PRINT("kernel cost:%d us ConvWino2\n",costTime2);
405 #else
406 runKernel2D(mDestTransform[index], mGWS_D[index], mLWS_D[index],
407 mOpenCLBackend->getOpenCLRuntime());
408 #endif
409 }
410 }
411 }
412 }
413 #ifdef ENABLE_OPENCL_TIME_PROFILER
414 MNN_PRINT("kernel cost:%d us ConvWino total\n",costTime);
415 #endif
416
417 return NO_ERROR;
418 }
419
420 } // namespace OpenCL
421 } // namespace MNN
422