//
//  ConvWinograd.cpp
//  MNN
//
//  Created by MNN on 2019/01/08.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include "backend/opencl/execution/image/ConvWinograd.hpp"
#include <string.h>
#include "core/Backend.hpp"
#include "core/ConvolutionCommon.hpp"
#include "math/WingoradGenerater.hpp"
#include "backend/opencl/core/OpenCLRunningUtils.hpp"
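// Winograd F(2x2, kxk): UNIT is the output tile size per dimension, and INTERP is the
// interpolation-point spacing handed to Math::WinogradGenerater when it builds the
// transform matrices.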
#define UNIT 2
#define INTERP 1
namespace MNN {
namespace OpenCL {
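// Eligibility check: Winograd is only applied to stride-1, dilation-1 3x3 or 5x5
// convolutions with at least 8 input and 8 output channels.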
bool ConvWinograd::valid(const Convolution2DCommon* common, const Tensor* input, int limit) {
    if (common->strideX() != 1 || common->strideY() != 1) {
        return false;
    }
    if (common->dilateX() != 1 || common->dilateY() != 1) {
        return false;
    }
    if (input->channel() < 8 || common->outputCount() < 8) {
        return false;
    }

    return (common->kernelX() == 3 && common->kernelY() == 3) || (common->kernelX() == 5 && common->kernelY() == 5);
}

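// The constructor does all weight preprocessing on the host: it dequantizes the filter
// back to float if the model stores quantized (IDST) weights, applies the Winograd
// weight transform G * g * G^T through Math::WinogradGenerater, and uploads the
// transformed weights and the bias into OpenCL images.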
ConvWinograd::ConvWinograd(const MNN::Convolution2D* op, Backend* backend) : Execution(backend) {
    mOpenCLBackend = static_cast<OpenCLBackend*>(backend);
    mCommon        = op->common();
    MNN_ASSERT((3 == mCommon->kernelY() && 3 == mCommon->kernelX()) ||
               (5 == mCommon->kernelX() && 5 == mCommon->kernelY()));
    MNN_ASSERT(1 == mCommon->strideX() && 1 == mCommon->strideY());
    MNN_ASSERT(1 == mCommon->dilateX() && 1 == mCommon->dilateY());
    auto runTime = mOpenCLBackend->getOpenCLRuntime();
    int ky       = mCommon->kernelY();
    int kx       = mCommon->kernelX();

    int weightSize             = 0;
    const float* filterDataPtr = nullptr;

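    // If the model carries quantized (IDST) weights, decode them back to float first.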
    std::shared_ptr<MNN::ConvolutionCommon::Int8Common> quanCommon;
    if (nullptr != op->quanParameter()) {
        quanCommon = ConvolutionCommon::load(op->quanParameter(), true);
        if (nullptr == quanCommon) {
            MNN_ERROR("Memory not Enough, can't extract IDST Convolution \n");
        }
        if (quanCommon->weightFloat.get() == nullptr) {
            MNN_PRINT("quanCommon->weightFloat.get() == nullptr \n");
        }
        // Back to float
        filterDataPtr = quanCommon->weightFloat.get();
        weightSize    = quanCommon->weightFloat.size();
    }

    if (nullptr == filterDataPtr) {
        weightSize    = op->weight()->size();
        filterDataPtr = op->weight()->data();
    }

    int co     = mCommon->outputCount();
    int ci     = weightSize / co / mCommon->kernelX() / mCommon->kernelY();
    auto coC4  = UP_DIV(co, 4);
    auto ciC4  = UP_DIV(ci, 4);
    auto queue = runTime->commandQueue();

    auto imageChannelType = CL_HALF_FLOAT;
    if (mOpenCLBackend->getPrecision() == BackendConfig::Precision_High) {
        imageChannelType = CL_FLOAT;
    }
    // Create Image
    {
        mBias.reset(new cl::Image2D(runTime->context(), CL_MEM_READ_WRITE, cl::ImageFormat(CL_RGBA, imageChannelType),
                                    UP_DIV(co, 4), 1, 0, nullptr, nullptr));

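        // Stage the bias through a host-mapped buffer, in half or float depending on the
        // runtime's weight precision, then copy it into the one-row bias image.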
        int buffer_size = ALIGN_UP4(co);
        if(mOpenCLBackend->getOpenCLRuntime()->isWeightCpuTransHalf()) {
            buffer_size *= sizeof(half_float::half);
        } else {
            buffer_size *= sizeof(float);
        }
        std::shared_ptr<cl::Buffer> biasBuffer(
            new cl::Buffer(runTime->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size));

        cl_int error;
        auto biasC = queue.enqueueMapBuffer(*biasBuffer, CL_TRUE, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &error);
        if(biasC != nullptr && error == CL_SUCCESS){
            if(mOpenCLBackend->getOpenCLRuntime()->isWeightCpuTransHalf()){
                for(int i=0; i<co; i++) {
                    ((half_float::half*)biasC)[i] = (half_float::half)(op->bias()->data()[i]);
                }
                for(int i=co; i<ALIGN_UP4(co); i++) {
                    ((half_float::half*)biasC)[i] = (half_float::half)(0.0f);
                }
            }else{
                ::memset(biasC, 0, buffer_size);
                ::memcpy(biasC, op->bias()->data(), co * sizeof(float));
            }
        }else{
            MNN_ERROR("Map error biasC == nullptr \n");
        }
        queue.enqueueUnmapMemObject(*biasBuffer, biasC);
        copyBufferToImage(runTime, *biasBuffer, *mBias, coC4, 1);

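        // Wrap the raw filter in an OIHW tensor and run the Winograd weight transform.
        // alpha = UNIT + kernelSize - 1 is the transformed tile size, so every
        // (co, ci) pair expands into alpha * alpha transformed values.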
        std::shared_ptr<Tensor> sourceWeight(
            Tensor::create<float>(std::vector<int>{co, ci, ky, kx}, (void*)(filterDataPtr), Tensor::CAFFE));

        int unit       = UNIT;
        int kernelSize = kx;
        Math::WinogradGenerater generator(unit, kernelSize, INTERP);
        int alpha       = unit + kernelSize - 1;
        auto weightDest = generator.allocTransformWeight(sourceWeight.get());
        generator.transformWeight(weightDest.get(), sourceWeight.get());
        auto weightDestSize = weightDest->size();

        buffer_size = weightDest->elementSize();
        if(mOpenCLBackend->getOpenCLRuntime()->isWeightCpuTransHalf()) {
            buffer_size *= sizeof(half_float::half);
        } else {
            buffer_size *= sizeof(float);
        }
        cl::Buffer weightBuffer(runTime->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size);
        {
            cl_int error;
            auto weightPtr = queue.enqueueMapBuffer(weightBuffer, CL_TRUE, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &error);
            if(weightPtr != nullptr && error == CL_SUCCESS){
                if(mOpenCLBackend->getOpenCLRuntime()->isWeightCpuTransHalf()){
                    for(int i=0; i<weightDest->elementSize(); i++) {
                        ((half_float::half*)weightPtr)[i] = (half_float::half)(weightDest->host<float>()[i]);
                    }
                }else{
                    ::memcpy(weightPtr, weightDest->host<float>(), buffer_size);
                }
            } else{
                MNN_ERROR("Map error weightPtr == nullptr \n");
            }

            queue.enqueueUnmapMemObject(weightBuffer, weightPtr);
        }
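        // The transformed weights go into an image of width ciC4 * 4 and height
        // coC4 * alpha * alpha, which is the layout the gemm kernel reads below.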
        mWeight.reset(new cl::Image2D(runTime->context(), CL_MEM_READ_WRITE, cl::ImageFormat(CL_RGBA, imageChannelType),
                                      ciC4 * 4, coC4 * alpha * alpha, 0, nullptr, nullptr));
        copyBufferToImage(runTime, weightBuffer, *mWeight, ciC4 * 4, coC4 * alpha * alpha);
    }
}

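// onResize precomputes everything that depends only on the shapes: the effective
// padding, how many slices the output tile grid must be split into to respect the
// device's image2D size limits, the intermediate source/dest tensors, and one set of
// kernels per (batch, slice) with most arguments already bound.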
ErrorCode ConvWinograd::onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
    auto input  = inputs[0];
    auto output = outputs[0];
    mKernelX    = mCommon->kernelX();
    mKernelY    = mCommon->kernelY();
    mPadX       = mCommon->padX();
    mPadY       = mCommon->padY();
    mStrideX    = mCommon->strideX();
    mStrideY    = mCommon->strideY();
    mPadMode    = mCommon->padMode();

    int alpha  = mCommon->kernelX() + UNIT - 1;
    auto wUnit = UP_DIV(output->width(), UNIT);
    auto hUnit = UP_DIV(output->height(), UNIT);
    int padX   = mPadX;
    int padY   = mPadY;
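    // For SAME padding, derive the pad amounts from the input/output sizes instead of
    // using the pads stored in the op.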
    if (mPadMode == PadMode_SAME) {
        int kernelWidthSize  = (mKernelX - 1) * mCommon->dilateX() + 1;
        int kernelHeightSize = (mKernelY - 1) * mCommon->dilateY() + 1;
        int padNeededWidth   = (output->width() - 1) * mStrideX + kernelWidthSize - input->width();
        int padNeededHeight  = (output->height() - 1) * mStrideY + kernelHeightSize - input->height();
        padX                 = padNeededWidth / 2;
        padY                 = padNeededHeight / 2;
    }

    auto runTime = mOpenCLBackend->getOpenCLRuntime();

    int maxWidth  = runTime->getMaxImage2DSize()[0];
    int maxHeight = runTime->getMaxImage2DSize()[1];

    int sourceWidth  = UP_DIV(input->channel(), 4) * 4;
    int sourceHeight = alpha * alpha * UP_DIV(wUnit * hUnit, 4);

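    // The transformed-source image would be sourceWidth x sourceHeight. If that exceeds
    // the device's image2D limits, split the tile grid into sliceNumber x sliceNumber
    // pieces until one piece fits.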
    int sliceNumber    = 1;
    const int maxSlice = 100;

    if (maxWidth < sourceWidth || maxHeight < sourceHeight) {
        for (int i = 2; i < maxSlice; ++i) {
            int realWidth  = (size_t)UP_DIV(input->channel(), 4) * 4;
            int realHeight = (size_t)alpha * alpha * UP_DIV(UP_DIV(wUnit, i) * UP_DIV(hUnit, i), 4);

            if (realWidth < maxWidth && realHeight < maxHeight) {
                sliceNumber = i;
                break;
            }
        }
    }

    mSliceNumber = sliceNumber;

    int wPiece = UP_DIV(wUnit, sliceNumber);
    int hPiece = UP_DIV(hUnit, sliceNumber);

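    // Per-slice intermediates: mSource holds the Winograd-transformed input tiles and
    // mDest holds the gemm output before the inverse transform. Both are acquired and
    // released immediately so the backend's memory planner can reuse the space.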
    auto bn = backend();
    mSource.reset(Tensor::createDevice<float>(
        std::vector<int>{alpha * alpha, input->channel(), UP_DIV(wPiece * hPiece, 4), 4}, Tensor::CAFFE_C4));
    mDest.reset(Tensor::createDevice<float>(
        std::vector<int>{4, wPiece * hPiece, UP_DIV(output->channel(), 4), alpha * alpha}, Tensor::CAFFE_C4));

    bn->onAcquireBuffer(mSource.get(), Backend::DYNAMIC);
    bn->onAcquireBuffer(mDest.get(), Backend::DYNAMIC);
    bn->onReleaseBuffer(mSource.get(), Backend::DYNAMIC);
    bn->onReleaseBuffer(mDest.get(), Backend::DYNAMIC);

    auto icC4 = UP_DIV(input->channel(), 4);
    auto ocC4 = UP_DIV(output->channel(), 4);

    uint32_t total_num = input->batch()*mSliceNumber*mSliceNumber;
    mSourceTransform.resize(total_num);
    mMatMul.resize(total_num);
    mDestTransform.resize(total_num);
    mMaxWGS_S.resize(total_num);
    mMaxWGS_D.resize(total_num);
    mMaxWGS_M.resize(total_num);

    std::set<std::string> basic;
    /*Create Kernel*/
    for(int i = 0; i < input->batch()*mSliceNumber*mSliceNumber; i++) {
        char format[20];
        ::memset(format, 0, sizeof(format));
        sprintf(format, "%d_%d_%d", UNIT, mKernelX, INTERP);
        auto formatStr = std::string(format);
        mSourceTransform[i] =
            runTime->buildKernel("winogradTransformSource" + formatStr,
                                 "winogradTransformSource", basic);
        mMaxWGS_S[i] = static_cast<uint32_t>(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(mSourceTransform[i]));
        {
            std::set<std::string> buildOptions = basic;
            if (mCommon->relu()) {
                buildOptions.emplace("-DRELU");
            }
            if (mCommon->relu6()) {
                buildOptions.emplace("-DRELU6");
            }
            mDestTransform[i] =
                runTime->buildKernel("winogradTransformDest" + formatStr,
                                     "winogradTransformDest", buildOptions);
            mMaxWGS_D[i] = static_cast<uint32_t>(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(mDestTransform[i]));
        }
        mMatMul[i] = runTime->buildKernel("gemm", "gemm", basic);
        mMaxWGS_M[i] = static_cast<uint32_t>(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(mMatMul[i]));
    }

    mGWS_S.resize(total_num);
    mGWS_D.resize(total_num);
    mGWS_M.resize(total_num);
    mLWS_S.resize(total_num);
    mLWS_D.resize(total_num);
    mLWS_M.resize(total_num);

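    // Bind the kernel arguments for every (batch, sliceY, sliceX) combination and tune
    // the local work sizes up front, so onExecute only has to enqueue the kernels.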
    for (int b = 0; b < input->batch(); ++b) {
        std::vector<int> offsetData;
        offsetData.push_back(0);
        offsetData.push_back(0);

        for (int y = 0; y < mSliceNumber; ++y) {
            int hCount = hPiece;
            if (y == mSliceNumber - 1) {
                hCount = hUnit - (mSliceNumber - 1) * hPiece;
            }
            offsetData[1] = y * hPiece;

            for (int x = 0; x < mSliceNumber; ++x) {
                int wCount = wPiece;
                if (x == mSliceNumber - 1) {
                    wCount = wUnit - (mSliceNumber - 1) * wPiece;
                }
                offsetData[0] = x * wPiece;

                auto dest = mDest.get();
                int index = b*mSliceNumber*mSliceNumber + y*mSliceNumber + x;

                mSourceTransform[index].setArg(0, openCLImage(input));
                mSourceTransform[index].setArg(1, openCLImage(mSource.get()));
                mSourceTransform[index].setArg(4, padX);
                mSourceTransform[index].setArg(5, padY);
                mSourceTransform[index].setArg(6, input->width());
                mSourceTransform[index].setArg(7, input->height());
                mSourceTransform[index].setArg(8, icC4);

                mMatMul[index].setArg(0, openCLImage(mSource.get()));
                mMatMul[index].setArg(1, *mWeight);
                mMatMul[index].setArg(4, ocC4);
                mMatMul[index].setArg(5, icC4);
                mMatMul[index].setArg(6, alpha*alpha);

                mDestTransform[index].setArg(1, *mBias);
                mDestTransform[index].setArg(2, openCLImage(output));
                mDestTransform[index].setArg(5, output->width());
                mDestTransform[index].setArg(6, output->height());
                mDestTransform[index].setArg(7, ocC4);

                mSourceTransform[index].setArg(2, wCount);
                mSourceTransform[index].setArg(3, hCount);
                mSourceTransform[index].setArg(9, offsetData[0]);
                mSourceTransform[index].setArg(10, offsetData[1]);
                mSourceTransform[index].setArg(11, b);

                auto gemmWidth = UP_DIV(wCount * hCount, 4);
                mMatMul[index].setArg(2, openCLImage(dest));
                mMatMul[index].setArg(3, gemmWidth);

                mDestTransform[index].setArg(0, openCLImage(dest));
                mDestTransform[index].setArg(3, wCount);
                mDestTransform[index].setArg(4, hCount);
                mDestTransform[index].setArg(8, offsetData[0]);
                mDestTransform[index].setArg(9, offsetData[1]);
                mDestTransform[index].setArg(10, b);

                /*Source Transform*/
                {
                    mGWS_S[index] = {static_cast<uint32_t>(wCount * hCount), static_cast<uint32_t>(icC4)};
                    std::string kernelName = "winogradTransformSource";
                    mLWS_S[index] = localWS2DDefault(mGWS_S[index], mMaxWGS_S[index], mOpenCLBackend->getOpenCLRuntime(), kernelName, mSourceTransform[index]).first;
                }

                /*MatMul*/
                {
                    auto gemmHeight = ocC4;
                    mGWS_M[index] = {static_cast<uint32_t>(gemmWidth*gemmHeight), static_cast<uint32_t>(alpha * alpha)};
                    std::string kernelName = "gemm";
                    mLWS_M[index] = localWS2DDefault(mGWS_M[index], mMaxWGS_M[index], mOpenCLBackend->getOpenCLRuntime(), kernelName, mMatMul[index]).first;
                }

                // Dest Transform
                {
                    mGWS_D[index] = {static_cast<uint32_t>(wCount*hCount), static_cast<uint32_t>(ocC4)};
                    std::string kernelName = "winogradTransformDest";
                    mLWS_D[index] = localWS2DDefault(mGWS_D[index], mMaxWGS_D[index], mOpenCLBackend->getOpenCLRuntime(), kernelName, mDestTransform[index]).first;
                }

            }
        }
    }

    return NO_ERROR;
}

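// onExecute replays the prebuilt kernels: for each batch and slice it enqueues the
// source transform, the gemm, and the dest transform, optionally timing each kernel
// when ENABLE_OPENCL_TIME_PROFILER is defined.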
ErrorCode ConvWinograd::onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
    auto input  = inputs[0];
    auto output = outputs[0];

    #ifdef ENABLE_OPENCL_TIME_PROFILER
    int costTime = 0;
    #endif
    for (int b = 0; b < input->batch(); ++b) {
        for (int y = 0; y < mSliceNumber; ++y) {
            for (int x = 0; x < mSliceNumber; ++x) {
                int index = b*mSliceNumber*mSliceNumber + y*mSliceNumber + x;

                /*Source Transform*/
                {
                #ifdef ENABLE_OPENCL_TIME_PROFILER
                    cl::Event event;
                    runKernel2D(mSourceTransform[index], mGWS_S[index], mLWS_S[index],
                                mOpenCLBackend->getOpenCLRuntime(), &event);

                    int costTime0 = (int)mOpenCLBackend->getOpenCLRuntime()->getCostTime(&event);
                    costTime += costTime0;
                    MNN_PRINT("kernel cost:%d    us ConvWino0\n",costTime0);
                #else
                    runKernel2D(mSourceTransform[index], mGWS_S[index], mLWS_S[index],
                                mOpenCLBackend->getOpenCLRuntime());
                #endif
                }

                /*MatMul*/
                {
                #ifdef ENABLE_OPENCL_TIME_PROFILER
                    cl::Event event;
                    runKernel2D(mMatMul[index], mGWS_M[index], mLWS_M[index],
                                mOpenCLBackend->getOpenCLRuntime(), &event);

                    int costTime1 = (int)mOpenCLBackend->getOpenCLRuntime()->getCostTime(&event);
                    costTime += costTime1;
                    MNN_PRINT("kernel cost:%d    us ConvWino1\n",costTime1);
                #else
                    runKernel2D(mMatMul[index], mGWS_M[index], mLWS_M[index],
                                mOpenCLBackend->getOpenCLRuntime());
                #endif
                }

                // Dest Transform
                {
                #ifdef ENABLE_OPENCL_TIME_PROFILER
                    cl::Event event;
                    runKernel2D(mDestTransform[index], mGWS_D[index], mLWS_D[index],
                                mOpenCLBackend->getOpenCLRuntime(), &event);

                    int costTime2 = (int)mOpenCLBackend->getOpenCLRuntime()->getCostTime(&event);
                    costTime += costTime2;
                    MNN_PRINT("kernel cost:%d    us ConvWino2\n",costTime2);
                #else
                    runKernel2D(mDestTransform[index], mGWS_D[index], mLWS_D[index],
                                mOpenCLBackend->getOpenCLRuntime());
                #endif
                }
            }
        }
    }
    #ifdef ENABLE_OPENCL_TIME_PROFILER
    MNN_PRINT("kernel cost:%d    us ConvWino total\n",costTime);
    #endif

    return NO_ERROR;
}

} // namespace OpenCL
} // namespace MNN