//
//  Convolution1x1Strassen.cpp
//  MNN
//
//  Created by MNN on 2019/02/12.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include "Convolution1x1Strassen.hpp"
#include <string.h>
#include "core/BufferAllocator.hpp"
#include "backend/cpu/CPUBackend.hpp"
#include "core/Concurrency.h"
#include "ConvOpt.h"
#include "core/Macro.h"
#include "CommonOptFunction.h"

namespace MNN {
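// A 1x1 convolution is a plain matrix multiply: [oc, ic] weights times an
// [ic, plane] input. The constructor therefore repacks the weights once into
// the backend's matmul B layout and keeps them in a shared Resource so that
// clones (see onClone) can reuse them.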
Convolution1x1Strassen::Convolution1x1Strassen(const Convolution2DCommon *common, Backend *b, const float *originWeight,
                                               size_t originWeightSize, const float *bias, size_t biasSize)
    : CPUConvolution(common, b) {
    auto outputCount = (int)biasSize;
    auto mSrcCount   = (int)originWeightSize / outputCount;
    mResource.reset(new CPUConvolution::Resource);
    mResource->backend = b;
    if (!mResource->copyBiasAlign(bias, biasSize)) {
        MNN_ERROR("Not Enough Memory\n");
        mValid = false;
        return;
    }
    auto core = static_cast<CPUBackend*>(b)->functions();
    int ePack, lPack, hPack;
    core->MNNGetMatMulPackMode(&ePack, &lPack, &hPack);
    mResource->mWeight.reset(Tensor::createDevice<float>(std::vector<int>{UP_DIV(outputCount, hPack), UP_DIV(mSrcCount, lPack) * lPack, hPack}));
    mValid = b->onAcquireBuffer(mResource->mWeight.get(), Backend::STATIC);
    if (!mValid) {
        MNN_ERROR("Not Enough Memory\n");
        return;
    }
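    // When the backend computes in fewer than 4 bytes per element (e.g. fp16),
    // convert the fp32 weights to the low-precision format through a temporary
    // buffer before packing; otherwise pack the fp32 weights directly.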
    if (core->bytes < 4) {
        AutoRelease<Tensor> tempTensor(Tensor::createDevice<float>({outputCount * mSrcCount}));
        mValid = b->onAcquireBuffer(tempTensor.get(), Backend::STATIC);
        if (!mValid) {
            MNN_ERROR("Not Enough Memory\n");
            return;
        }
        core->MNNFp32ToLowp(originWeight, tempTensor->host<int16_t>(), outputCount * mSrcCount);
        core->MNNPackForMatMul_B(mResource->mWeight->host<float>(), tempTensor->host<float>(), outputCount, mSrcCount, true);
        b->onReleaseBuffer(tempTensor.get(), Backend::STATIC);
    } else {
        core->MNNPackForMatMul_B(mResource->mWeight->host<float>(), originWeight, outputCount, mSrcCount, true);
    }
}
Convolution1x1Strassen::Convolution1x1Strassen(std::shared_ptr<CPUConvolution::Resource> resource, const Convolution2DCommon *common, Backend* b) : CPUConvolution(common, b) {
    mResource = resource;
}

Convolution1x1Strassen::~Convolution1x1Strassen() {
    // Do nothing
}

bool Convolution1x1Strassen::onClone(Backend* bn, const Op* op, Execution** dst) {
    if (!mValid) {
        return false;
    }
    if (nullptr == dst) {
        return true;
    }
    *dst = new Convolution1x1Strassen(mResource, op->main_as_Convolution2D()->common(), bn);
    return true;
}

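// onResize builds an execution plan for the current shapes: it optionally sets
// up a pretreat pass that gathers the (possibly batched, padded, or strided)
// input into one contiguous matrix, then partitions the matmul across threads
// and encodes a StrassenMatrixComputor for each partition.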
ErrorCode Convolution1x1Strassen::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    CPUConvolution::onResize(inputs, outputs);
    auto core = static_cast<CPUBackend*>(backend())->functions();
    int ePack, lPack, hPack;
    core->MNNGetMatMulPackMode(&ePack, &lPack, &hPack);
    int bytes = core->bytes;
    auto CONVOLUTION_TILED_NUMBER = ePack;
    auto input       = inputs[0];
    auto output      = outputs[0];
    int numberThread = ((CPUBackend *)backend())->threadNumber();
    auto ic = input->channel();
    auto oc = output->channel();
    auto icC4        = UP_DIV(ic, core->pack);
    auto ocC4        = UP_DIV(oc, core->pack);
    auto batch       = input->batch();
    auto matrixSizeE = output->height() * output->width() * input->batch();
    auto outputPlane = output->height() * output->width();
    mUnits.clear();
    auto inputPtr  = input->host<uint8_t>();
    auto outputPtr = output->host<uint8_t>();
    mTempOutputBatch.reset();
    mTempInputBatch.reset();
    std::shared_ptr<char> __autoFunction;
    auto padY     = mPadY;
    auto padX     = mPadX;
    auto strideX  = mCommon->strideX();
    auto strideY  = mCommon->strideY();
    mNeedPretreat = input->batch() > 1 || (!(padX == 0 && padY == 0 && strideY == 1 && strideX == 1));
    auto postParameters = getPostParameters();
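    // Pretreatment is needed when the input cannot feed the matmul as-is:
    // either batch > 1 (the per-batch planes must be gathered into one E-major
    // buffer) or padding/stride make the output plane differ from the input
    // plane. __autoFunction releases the temporaries when onResize returns.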
    if (mNeedPretreat) {
        mTempInputBatch.reset(Tensor::createDevice<float>(std::vector<int>{icC4, matrixSizeE, core->pack}));
        mTempOutputBatch.reset(Tensor::createDevice<float>(std::vector<int>{ocC4, matrixSizeE, core->pack}));
        bool success = backend()->onAcquireBuffer(mTempOutputBatch.get(), Backend::DYNAMIC);
        success      = success && backend()->onAcquireBuffer(mTempInputBatch.get(), Backend::DYNAMIC);
        if (!success) {
            return OUT_OF_MEMORY;
        }
        inputPtr       = mTempInputBatch->host<uint8_t>();
        outputPtr      = mTempOutputBatch->host<uint8_t>();
        __autoFunction = std::shared_ptr<char>(nullptr, [this](void *ptr) {
            backend()->onReleaseBuffer(mTempOutputBatch.get(), Backend::DYNAMIC);
            backend()->onReleaseBuffer(mTempInputBatch.get(), Backend::DYNAMIC);
        });
        auto ow        = output->width();
        auto oh        = output->height();
        auto iw        = input->width();
        auto ih        = input->height();
        if (padX == 0 && padY == 0 && strideY == 1 && strideX == 1) {
            // Case 1: no padding, unit stride. The spatial data is already
            // contiguous; only the batch dimension has to be interleaved, from
            // [batch, icC4, plane, pack] to [icC4, batch, plane, pack].
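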
            mPretreatFunction = [outputPlane, icC4, batch, numberThread, this, core](const uint8_t *srcBatch, uint8_t *dstBatch) {
                MNN_CONCURRENCY_BEGIN(y, icC4) {
                    auto srcY = srcBatch + outputPlane * y * core->pack * core->bytes;
                    auto dstY = dstBatch + y * outputPlane * batch * core->pack * core->bytes;
                    for (int x = 0; x < batch; ++x) {
                        auto srcX = srcY + x * outputPlane * icC4 * core->pack * core->bytes;
                        auto dstX = dstY + x * outputPlane * core->pack * core->bytes;
                        ::memcpy(dstX, srcX, outputPlane * core->pack * core->bytes);
                    }
                }
                MNN_CONCURRENCY_END();
            };
        } else if (strideY == 1 && strideX == 1) {
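            // Case 2: unit stride with padding. Zero-fill the destination, then
            // copy each input row into its padded position; the zero border is
            // what the 1x1 kernel reads from the padding area.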
            mPretreatFunction = [outputPlane, padY, padX, ow, oh, iw, ih, icC4, batch, this, core](const uint8_t *srcOrigin,
                                                                                                    uint8_t *dstOrigin) {
                auto unitBytes = core->bytes * core->pack;
                ::memset(dstOrigin, 0, outputPlane * batch * unitBytes * icC4);
                MNN_CONCURRENCY_BEGIN(z, icC4) {
                    auto srcZ = srcOrigin + z * iw * ih * unitBytes;
                    auto dstZ = dstOrigin + z * ow * oh * batch * unitBytes;
                    for (int b = 0; b < batch; ++b) {
                        auto srcBatch = srcZ + b * iw * ih * icC4 * unitBytes;
                        auto dstBatch = dstZ + b * ow * oh * unitBytes;
                        for (int y = 0; y < ih; ++y) {
                            auto src = srcBatch + iw * y * unitBytes;
                            auto dst = dstBatch + (ow * (y + padY) + padX) * unitBytes;
                            ::memcpy(dst, src, iw * unitBytes);
                        }
                    }
                }
                MNN_CONCURRENCY_END();
            };
        } else {
            int oyStart, oyEnd, oxStart, oxEnd;
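            // Case 3: general stride. First find the output rectangle
            // [oyStart, oyEnd] x [oxStart, oxEnd] whose sampling points
            // (o * stride - pad) fall inside the input; everything outside it
            // reads only padding and stays zero.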
            for (oyStart = 0; oyStart * strideY - padY < 0; ++oyStart) {
                // do nothing
            }
            for (oyEnd = oh - 1; oyEnd * strideY - padY >= ih; --oyEnd) {
                // do nothing
            }
            for (oxStart = 0; oxStart * strideX - padX < 0; ++oxStart) {
                // do nothing
            }
            for (oxEnd = ow - 1; oxEnd * strideX - padX >= iw; --oxEnd) {
                // do nothing
            }
            int oyCount       = oyEnd - oyStart + 1;
            int oxCount       = oxEnd - oxStart + 1;
            mPretreatFunction = [outputPlane, padY, padX, strideX, strideY, ow, oh, iw, ih, icC4, oxStart, oyStart,
                                 oxCount, oyCount, batch, this, core](const uint8_t *srcOrigin, uint8_t *dstOrigin) {
                ::memset(dstOrigin, 0, outputPlane * batch * core->bytes * core->pack * icC4);
                auto srcStride = strideX;
                auto dstStride = 1;
                int syStart    = oyStart * strideY - padY;
                int sxStart    = oxStart * strideX - padX;
                MNN_CONCURRENCY_BEGIN(z, icC4) {
                    auto srcZ = srcOrigin + (z * iw * ih + syStart * iw + sxStart) * core->bytes * core->pack;
                    auto dstZ = dstOrigin + (z * ow * oh * batch + oyStart * ow + oxStart) * core->bytes * core->pack;
                    for (int b = 0; b < batch; ++b) {
                        auto srcBatch = srcZ + b * iw * ih * icC4 * core->bytes * core->pack;
                        auto dstBatch = dstZ + b * ow * oh * core->bytes * core->pack;
                        for (int y = 0; y < oyCount; ++y) {
                            auto dstY = dstBatch + y * ow * core->bytes * core->pack;
                            auto srcY = srcBatch + y * strideY * iw * core->bytes * core->pack;
                            core->MNNCopyC4WithStride((const float*)(srcY), (float*)(dstY), strideX * core->pack, core->pack, oxCount);
                        }
                    }
                }
                MNN_CONCURRENCY_END();
            };
        }
    }
    auto memoryPool = ((CPUBackend *)backend())->getBufferAllocator();
    memoryPool->barrierBegin();
    std::shared_ptr<void> __a(nullptr, [memoryPool](void *) { memoryPool->barrierEnd(); });
    int maxDepth = 5;
    auto icAlign = UP_DIV(ic, lPack) * lPack;
    auto weightTensor = mResource->mWeight.get();
    AutoRelease<Tensor> tempWeight;
    if (icAlign != ic) {
        tempWeight.reset(Tensor::create<float>(std::vector<int>{oc, ic, hPack}, mResource->mWeight->host<uint8_t>()));
        weightTensor = tempWeight.get();
    }
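    // Two partition strategies: if the E dimension (plane * batch) is large
    // relative to the thread count and the output-channel blocks, split the
    // work along E; otherwise split along output channels so each thread owns
    // a band of oc together with the matching slice of the packed weights.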
    if (matrixSizeE > CONVOLUTION_TILED_NUMBER * 8 * numberThread && matrixSizeE > ocC4) {
        // Divide along the plane; the number of partitions equals numberThread.
        int divideStep = UP_DIV(matrixSizeE, numberThread);
        mUnits.resize(numberThread);
        for (int i = 0; i < numberThread; ++i) {
            int planeStart = i * divideStep;
            int planeEnd   = std::min(planeStart + divideStep, matrixSizeE);
            int planeSize  = planeEnd - planeStart;
            Unit &unit     = mUnits[i];
            if (planeSize <= 0) {
                unit.mValid = false;
                continue;
            }
            unit.mStracssenComputor.reset(new StrassenMatrixComputor(backend(), false, maxDepth));
            AutoRelease<Tensor> mTempInput(
                Tensor::create<float>(std::vector<int>{icC4, planeSize, core->pack}, inputPtr + core->pack * planeStart * bytes));
            mTempInput->setStride(0, matrixSizeE * core->pack);
            AutoRelease<Tensor> mTempOutput(
                Tensor::create<float>(std::vector<int>{ocC4, planeSize, core->pack}, outputPtr + core->pack * planeStart * bytes));
            mTempOutput->setStride(0, matrixSizeE * core->pack);
            unit.mTempInputVector  = std::vector<Tensor *>{mTempInput.get(), weightTensor, mResource->mBias.get()};
            unit.mTempOutputVector = std::vector<Tensor *>{mTempOutput.get()};
            memoryPool->beginGroup();
            auto code = unit.mStracssenComputor->onEncode(unit.mTempInputVector, unit.mTempOutputVector, postParameters, ic, oc);
            if (NO_ERROR != code) {
                memoryPool->endGroup();
                return code;
            }
            memoryPool->endGroup();
        }
    } else {
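        // Split along output channels. ocC4 is grouped so that each unit's
        // band covers a whole number of hPack weight blocks, and the matching
        // slices of weight and bias are wrapped as tensor views over the
        // shared buffers.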
        // Divide in ocC4
        auto hDiv = 1;
        if (hPack > core->pack) {
            hDiv = hPack / core->pack;
        }
        auto ocDiv = UP_DIV(ocC4, hDiv);
        numberThread   = std::min(numberThread, ocDiv);
        int divideStep = (ocDiv / numberThread) * hDiv;
        mUnits.resize(numberThread);
        for (int i = 0; i < numberThread; ++i) {
            int ocStart = i * divideStep;
            int ocSize  = divideStep;
            if (i == numberThread - 1) {
                ocSize = ocC4 - i * divideStep;
            }
            Unit &unit  = mUnits[i];
            if (ocSize <= 0) {
                unit.mValid = false;
                continue;
            }
            auto ocStartWeight = (ocStart * core->pack) / hPack;
            auto ocWeightSize = std::min(UP_DIV((ocSize * core->pack), hPack), mResource->mWeight->length(0) - ocStartWeight);
            unit.mStracssenComputor.reset(new StrassenMatrixComputor(backend(), false, maxDepth));
            AutoRelease<Tensor> mTempInput(Tensor::create<float>(std::vector<int>{icC4, matrixSizeE, core->pack}, inputPtr));
            AutoRelease<Tensor> mTempBias(Tensor::create<float>({ocSize, 1, core->pack}, mResource->mBias->host<uint8_t>() + core->pack * ocStart * bytes));
            AutoRelease<Tensor> mTempOutput(
                Tensor::create<float>(std::vector<int>{ocSize, matrixSizeE, core->pack}, outputPtr + core->pack * matrixSizeE * ocStart * bytes));
            AutoRelease<Tensor> mTempWeight(Tensor::create<float>(std::vector<int>{ocWeightSize, ic, hPack},
                                                         mResource->mWeight->host<uint8_t>() + hPack * icAlign * ocStartWeight * bytes));
            unit.mTempInputVector  = std::vector<Tensor *>{mTempInput.get(), mTempWeight.get(), mTempBias.get()};
            unit.mTempOutputVector = std::vector<Tensor *>{mTempOutput.get()};
            memoryPool->beginGroup();
            auto code = unit.mStracssenComputor->onEncode(unit.mTempInputVector, unit.mTempOutputVector, postParameters, ic);
            if (NO_ERROR != code) {
                memoryPool->endGroup();
                return code;
            }
            memoryPool->endGroup();
        }
    }
    return NO_ERROR;
}
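
// onExecute runs the plan built in onResize: gather the input if needed, run
// each unit's pre-encoded matmul in parallel, then scatter the result back to
// the output's per-batch layout.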
ErrorCode Convolution1x1Strassen::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    auto size   = mUnits.size();
    auto input  = inputs[0];
    auto output = outputs[0];
    auto core = static_cast<CPUBackend*>(backend())->functions();

    if (!mNeedPretreat) {
        MNN_CONCURRENCY_BEGIN(tId, size) {
            auto &unit = mUnits[tId];
            if (unit.mValid) {
                unit.mStracssenComputor->onExecute();
            }
        }
        MNN_CONCURRENCY_END();
        return NO_ERROR;
    }
    int bytes = core->bytes;
    mPretreatFunction(input->host<uint8_t>(), mTempInputBatch->host<uint8_t>());
    MNN_CONCURRENCY_BEGIN(tId, size) {
        auto &unit = mUnits[tId];
        if (unit.mValid) {
            unit.mStracssenComputor->onExecute();
        }
    }
    MNN_CONCURRENCY_END();

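    // Scatter the result from the E-major temp buffer, laid out as
    // [ocC4, batch, plane, pack], back to the output's [batch, ocC4, plane, pack]
    // layout: the inverse of the batch interleaving done in the pretreat pass.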
    auto batch       = input->batch();
    auto outputPlane = output->height() * output->width();
    auto ocC4        = UP_DIV(output->channel(), core->pack);
    MNN_CONCURRENCY_BEGIN(y, ocC4) {
        auto srcY = mTempOutputBatch->host<uint8_t>() + outputPlane * y * core->pack * batch * bytes;
        auto dstY = output->host<uint8_t>() + y * outputPlane * core->pack * bytes;
        for (int x = 0; x < batch; ++x) {
            auto srcX = srcY + x * outputPlane * core->pack * bytes;
            auto dstX = dstY + x * outputPlane * ocC4 * core->pack * bytes;
            ::memcpy(dstX, srcX, outputPlane * core->pack * bytes);
        }
    }
    MNN_CONCURRENCY_END();
    return NO_ERROR;
}
} // namespace MNN