//
//  StrassenMatmulComputor.cpp
//  MNN
//
//  Created by MNN on 2019/02/11.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include "StrassenMatmulComputor.hpp"
#include "CommonOptFunction.h"
#include "backend/cpu/CPUBackend.hpp"
#include <string.h>
#include <limits.h>
#include "core/AutoStorage.h"
#include "core/Macro.h"
#include "core/Concurrency.h"
//#define MNN_OPEN_TIME_TRACE
#include <MNN/AutoTime.hpp>
#include "math/Vec.hpp"
#include "math/Matrix.hpp"
#include "core/BufferAllocator.hpp"

namespace MNN {
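// RAII helper owning a (base pointer, offset) chunk from a BufferAllocator;
// the chunk is returned on destruction. During encoding, free() only ends the
// chunk's lifetime for the allocator's reuse planning, so addresses captured
// by the recorded lambdas remain valid when they later execute.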
class AutoMemory {
public:
    AutoMemory(int size, BufferAllocator* allocator) {
        mContent = allocator->alloc(size);
        mAllocator = allocator;
    }
    ~AutoMemory() {
        if (nullptr != mContent.first) {
            mAllocator->free(mContent);
        }
    }
    const std::pair<void*, int>& get() const {
        return mContent;
    }
private:
    std::pair<void*, int> mContent;
    BufferAllocator* mAllocator;
};

StrassenMatrixComputor::StrassenMatrixComputor(Backend* bn, bool multithread, int maxDepth) : mBackend(bn) {
    mMaxDepth = maxDepth;
    mSupportMultiThread = multithread;
}
StrassenMatrixComputor::~StrassenMatrixComputor() {
    // Do nothing
}

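// Base case: a plain tiled, packed matrix multiply (no Strassen recursion).
// Dimensions follow MNN's e/l/h convention: e = rows of A and C, l = the
// reduce dimension, h = columns of B and C. A and C use the pack layout, B is
// pre-packed into hP-wide panels. Nothing is computed here; the work is
// recorded as lambdas that onExecute runs later.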
ErrorCode StrassenMatrixComputor::_generateTrivalMatMul(int e, int l, int h, const MatrixInfo& AT, const MatrixInfo& BT, const MatrixInfo& CT, const MatrixInfo& COT, const std::vector<float>& active) {
    // Generate a trivial (non-Strassen) tiled matrix multiply
    MNN_ASSERT(e > 0);
    auto core = static_cast<CPUBackend*>(backend())->functions();
    int bytes    = core->bytes;
    auto aStride = AT.lineStrideBytes;
    auto bStride = BT.lineStrideBytes;
    auto cStride = CT.lineStrideBytes;
    int eP, lP, hP;
    core->MNNGetMatMulPackMode(&eP, &lP, &hP);
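    // eP/lP/hP are the tile sizes the backend kernels expect: eP rows of A per
    // packed tile, lP packing along the reduce axis, hP columns per B panel.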
    auto numberThread = mSupportMultiThread ? ((CPUBackend*)backend())->threadNumber() : 1;
    auto bExtraStride = bStride - UP_DIV(l, lP) * lP * hP * core->bytes;
    MNN_ASSERT(bExtraStride >= 0);
    auto tileBufferBasic = static_cast<CPUBackend*>(backend())->getBufferAllocator()->alloc(numberThread * UP_DIV(l, lP) * eP * lP * bytes);
    if (nullptr == tileBufferBasic.first) {
        return OUT_OF_MEMORY;
    }
    auto tileHostOrigin  = (uint8_t*)tileBufferBasic.first + tileBufferBasic.second;
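    // e is split into unitNumber full tiles of eP rows; the remaining xCount
    // rows are handled separately by MNNPackedMatMulRemain.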
    int unitNumber = e / eP;
    int xCount     = e - unitNumber * eP;
    auto eReal = aStride / core->bytes / core->pack;
    mFunctions.emplace_back(
        std::make_pair([cStride, l, h, xCount, AT, BT, CT, COT, tileHostOrigin, unitNumber, bExtraStride, numberThread, eReal, eP, active, this](int tId) {
            auto core = static_cast<CPUBackend*>(backend())->functions();
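            // Parameter block shared by MNNPackedMatMul / MNNPackedMatMulRemain;
            // the layout appears to be: [0] A stride in bytes for the e-tail,
            // [1] l, [2] h, [3] C line stride in bytes, [4] h remainder (0 here),
            // [5] extra bytes skipping B's padding along l.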
            size_t parameters[6];
            parameters[0] = xCount * core->bytes;
            parameters[1] = l;
            parameters[2] = h;
            parameters[3] = cStride;
            parameters[4] = 0;
            parameters[5] = bExtraStride;
            auto tileHost = tileHostOrigin + eP * parameters[1] * tId * core->bytes;
            const float* postParametersPtr = nullptr;
            if (!active.empty()) {
                postParametersPtr = active.data();
            }
            auto aHost = mStack[AT.stackIndex] + AT.offsetBytes;
            auto bHost = mStack[BT.stackIndex] + BT.offsetBytes;
            auto cHost = mStack[CT.stackIndex] + CT.offsetBytes;
            const uint8_t* biasPtr = nullptr;
            if (-1 != COT.stackIndex) {
                biasPtr = mStack[COT.stackIndex] + COT.offsetBytes;
            }
            auto packUnit = core->bytes * core->pack;
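            // Arguments for MNNPackC4ForMatMul_A: info appears to be
            // {group count, source e stride (eReal), dest e stride (eP), x offset},
            // and stride the per-group {e, l, e offset, l offset} block.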
            int32_t info[4];
            int32_t stride[4];
            stride[0] = eP;
            stride[1] = parameters[1];
            stride[2] = 0;
            stride[3] = 0;
            info[0] = 1;
            info[1] = eReal;
            info[2] = eP;
            info[3] = 1;
            for (int i = tId; i < unitNumber; i += numberThread) {
                int xStart    = i * eP;
                auto aStart   = aHost + xStart * packUnit;
                core->MNNPackC4ForMatMul_A((float*)(tileHost), (const float**)(&aStart), info, stride);
                core->MNNPackedMatMul((float*)(cHost + xStart * packUnit), (float*)tileHost, (float*)bHost, parameters, postParametersPtr, (const float*)biasPtr);
            }
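            // Only the last thread packs and multiplies the xCount tail rows,
            // so the e-remainder is computed exactly once.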
            if (tId != numberThread - 1) {
                return;
            }
            if (xCount > 0) {
                stride[0] = xCount;
                stride[1] = parameters[1];
                info[2] = xCount;

                int xStart    = unitNumber * eP;
                auto aStart   = aHost + xStart * packUnit;
                // Copy
                core->MNNPackC4ForMatMul_A((float*)(tileHost), (const float**)(&aStart), info, stride);
                core->MNNPackedMatMulRemain((float*)(cHost + xStart * packUnit), (float*)tileHost, (float*)bHost, xCount, parameters, postParametersPtr, (const float*)biasPtr);
            }
        }, numberThread));
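    // Freeing immediately is intentional: during encoding the allocator only
    // tracks lifetimes for reuse planning, and tileHostOrigin stays valid when
    // the recorded lambdas run in onExecute.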
    static_cast<CPUBackend*>(backend())->getBufferAllocator()->free(tileBufferBasic);
    return NO_ERROR;
}

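// Spread a row-wise MNNMatrixSub / MNNMatrixAdd over the calling threads:
// thread tId handles rows tId, tId + numberThread, ... Both macros expect
// tId and numberThread to be visible at the expansion site.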
#define MNNMATRIX_SUB_MULTITHREAD(c_, a_, b_, widthC4, cStride, aStride, bStride, lSub, core) \
    {                                                                                         \
        auto c = c_;                                                                          \
        auto b = b_;                                                                          \
        auto a = a_;                                                                          \
        for (int y = tId; y < lSub; y += numberThread) {                                      \
            core->MNNMatrixSub((float*)(c + y * cStride), (float*)(a + y * aStride), (float*)(b + y * bStride), widthC4, 0, 0, 0, 1); \
        }                                                                                     \
    }

#define MNNMATRIX_ADD_MULTITHREAD(c_, a_, b_, widthC4, cStride, aStride, bStride, lSub, core) \
    {                                                                                         \
        auto c = c_;                                                                          \
        auto b = b_;                                                                          \
        auto a = a_;                                                                          \
        for (int y = tId; y < lSub; y += numberThread) {                                      \
            core->MNNMatrixAdd((float*)(c + y * cStride), (float*)(a + y * aStride), (float*)(b + y * bStride), widthC4, 0, 0, 0, 1); \
        }                                                                                     \
    }

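// Large-l fallback: split the reduce dimension into chunks of at most lLimit
// so the packed A tile and B panel stay cache-resident. The first chunk writes
// straight into C; every later chunk multiplies into the temporary CTemp and
// is accumulated into C, with bias/activation applied once at the end.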
ErrorCode StrassenMatrixComputor::_generateBasicMatMul(int e, int l, int h, const MatrixInfo& AT, const MatrixInfo& BT, const MatrixInfo& CT, const MatrixInfo& COT, const std::vector<float>& postParameters) {
    auto core = static_cast<CPUBackend*>(backend())->functions();
    int eP, lP, hP;
    core->MNNGetMatMulPackMode(&eP, &lP, &hP);
    int lLimit = 32768 / (std::min(eP, e) + hP);
    if (l <= lLimit) {
        return _generateTrivalMatMul(e, l, h, AT, BT, CT, COT, postParameters);
    }
    {
        auto lUnit = std::max(lP, core->pack);
        lLimit = lLimit / lUnit * lUnit;
    }
    int unit = UP_DIV(l, lLimit);
    auto allocator = static_cast<CPUBackend*>(backend())->getBufferAllocator();
    AutoMemory CAddr(e * UP_DIV(h, core->pack) * core->pack * core->bytes, allocator);
    MatrixInfo CTemp;
    CTemp.stackIndex = (int)mStack.size();
    CTemp.offsetBytes = 0;
    CTemp.lineStrideBytes = e * core->bytes * core->pack;
    mStack.emplace_back((uint8_t*)CAddr.get().first + CAddr.get().second);

    MatrixInfo Empty;
    Empty.stackIndex = -1;
    auto numberThread = mSupportMultiThread ? ((CPUBackend*)backend())->threadNumber() : 1;
    auto cHeight = UP_DIV(h, core->pack);

    for (int i = 0; i < unit; ++i) {
        int lS = i * lLimit;
        int lE = lS + lLimit;
        if (lE > l) {
            lE = l;
        }
        if (0 == i) {
            // The first chunk writes directly to the output
            auto code = _generateTrivalMatMul(e, lE - lS, h, AT, BT, CT, Empty, {});
            if (NO_ERROR != code) {
                return code;
            }
            continue;
        }
        MatrixInfo tempA = AT;
        MatrixInfo tempB = BT;
        tempA.offsetBytes = AT.offsetBytes + lS / core->pack * AT.lineStrideBytes;
        tempB.offsetBytes = BT.offsetBytes + lS * hP * core->bytes;
        auto code = _generateTrivalMatMul(e, lE - lS, h, tempA, tempB, CTemp, Empty, {});
        if (NO_ERROR != code) {
            return code;
        }
        // Accumulate CTemp into C
        auto f1 = [CT, CTemp, e, cHeight, numberThread, core, this](int tId) {
            auto c11Ptr = mStack[CT.stackIndex] + CT.offsetBytes;
            auto xAddr = mStack[CTemp.stackIndex] + CTemp.offsetBytes;
            MNNMATRIX_ADD_MULTITHREAD(c11Ptr, c11Ptr, xAddr, e, CT.lineStrideBytes, CT.lineStrideBytes, CTemp.lineStrideBytes, cHeight, core);
        };
        mFunctions.emplace_back(std::make_pair(f1, numberThread));
    }
    if (!postParameters.empty() && COT.stackIndex >= 0) {
        if (1 == numberThread) {
            auto postFunction = [CT, COT, e, cHeight, numberThread, postParameters, core, this](int tId) {
                auto biasPtr = (const float*)(mStack[COT.stackIndex] + COT.offsetBytes);
                auto width = e;
                auto height = cHeight;
                auto c11Ptr = mStack[CT.stackIndex] + CT.offsetBytes;
                core->MNNAxByClampBroadcastUnit((float*)c11Ptr, (float*)c11Ptr, biasPtr, width, CT.lineStrideBytes / core->bytes, CT.lineStrideBytes / core->bytes, height, postParameters.data());
            };
            mFunctions.emplace_back(std::make_pair(postFunction, 1));
        } else {
            auto postFunction = [CT, COT, e, cHeight, numberThread, postParameters, core, this](int tId) {
                auto width = e;
                auto height = cHeight;
                auto c11Ptr = mStack[CT.stackIndex] + CT.offsetBytes;
                auto biasPtr = mStack[COT.stackIndex] + COT.offsetBytes;
                for (int y = tId; y < height; y += numberThread) {
                    core->MNNAxByClampBroadcastUnit((float*)(c11Ptr + y * CT.lineStrideBytes), (float*)(c11Ptr + y * CT.lineStrideBytes), (const float*)(biasPtr + y * core->bytes * core->pack), width, 0, 0, 1, postParameters.data());
                }
            };
            mFunctions.emplace_back(std::make_pair(postFunction, numberThread));
        }
    }
    return NO_ERROR;
}

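// Strassen recursion, Winograd's variant: 7 sub-multiplications (P1..P7) and
// 15 additions built from the sums/differences S1..S4 (of A's quadrants) and
// T1..T4 (of B's quadrants), merged into the four quadrants of C. Falls back
// to _generateBasicMatMul when the depth limit is hit, the halves cannot be
// aligned to the pack sizes, or the predicted saving (one multiply of eight)
// does not outweigh the extra memory traffic of the additions.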
ErrorCode StrassenMatrixComputor::_generateMatMul(int e, int l, int h, const MatrixInfo& AT, const MatrixInfo& BT, const MatrixInfo& CT, const MatrixInfo& COT, int currentDepth, const std::vector<float>& postParameters) {
    auto core = static_cast<CPUBackend*>(backend())->functions();
    auto aUnit = core->pack;

    auto numberThread = mSupportMultiThread ? ((CPUBackend*)backend())->threadNumber() : 1;
    int eP, lP, hP;
    core->MNNGetMatMulPackMode(&eP, &lP, &hP);
    MNN_ASSERT(hP % core->pack == 0 || core->pack % hP == 0);
    auto eSub = (e / eP) / 2 * eP;
    auto hSub = (h / std::max(hP, core->pack)) / 2 * std::max(hP, core->pack);
    auto remainH = h - hSub * 2;
    auto remainE = e - eSub * 2;
    int packHUnit = 1;
    if (core->pack > hP) {
        packHUnit = core->pack / hP;
    }
    if (currentDepth >= mMaxDepth || eSub == 0 || hSub == 0 || l % (2 * core->pack) != 0 || l % (2 * lP) != 0 || l % (2 * packHUnit) != 0) {
        return _generateBasicMatMul(e, l, h, AT, BT, CT, COT, postParameters);
    }
    auto lSub = l / 2;
    auto lSubUnit = lSub / core->pack;

    auto bWidth = lSub * hP / core->pack;
    auto aHeight = lSub / core->pack;
    auto cHeight = hSub / core->pack;
    auto bHeight = hSub / hP;
    /*
     Estimate the extra memory read / write traffic of the Strassen expansion
     (the S/T/U additions) against the multiply work saved by dropping one of
     the eight sub-products.
     */
    auto bHSub = bHeight;
    float AComputeCost = 4 * ((float)eSub * lSub);
    float BComputeCost = 4 * (float)lSub * bHSub * hP;
    float CComputeCost = 7 * (float)eSub * hSub;
    float saveMatMulCost = (e / eP) * (aUnit * eP * hSub / core->pack + lSubUnit * eP * aUnit + lSub * bHSub * hP);

    const float penalty = core->penalty; // FIXME: Find a better way to set it
    float saveCost = saveMatMulCost - (AComputeCost + BComputeCost + CComputeCost) * penalty;
    if (saveCost <= 0.0f) {
        return _generateBasicMatMul(e, l, h, AT, BT, CT, COT, postParameters);
    }

    // Strassen construction
    auto allocator = static_cast<CPUBackend*>(backend())->getBufferAllocator();
    currentDepth += 1;
    auto maxlH = std::max(lSub, hSub);
    AutoMemory YAddr(hSub * lSub * core->bytes, allocator);
    AutoMemory XAddr(maxlH * eSub * core->bytes, allocator);
    if (nullptr == XAddr.get().first || nullptr == YAddr.get().first) {
        return OUT_OF_MEMORY;
    }
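    // X buffers the A-side temporaries S1..S4 (and later the product P1 via
    // CX, which aliases X); Y buffers the B-side temporaries T1..T4 in the
    // packed-B layout. Both are pushed onto mStack so the recorded lambdas
    // resolve them by index at execution time.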
    MatrixInfo Y;
    Y.stackIndex = (int)mStack.size();
    mStack.emplace_back((uint8_t*)YAddr.get().first + YAddr.get().second);
    Y.offsetBytes = 0;
    Y.lineStrideBytes = lSub * core->bytes * hP;
    MatrixInfo X;
    X.stackIndex = (int)mStack.size();
    X.offsetBytes = 0;
    X.lineStrideBytes = eSub * core->bytes * core->pack;
    mStack.emplace_back((uint8_t*)XAddr.get().first + XAddr.get().second);

    MatrixInfo CX;
    CX.stackIndex = X.stackIndex;
    CX.offsetBytes = 0;
    CX.lineStrideBytes = eSub * core->bytes * core->pack;

    MatrixInfo a11 = AT;
    MatrixInfo a12 = AT;
    a12.offsetBytes = AT.offsetBytes + AT.lineStrideBytes * lSubUnit;
    MatrixInfo a21 = AT;
    a21.offsetBytes = AT.offsetBytes + eSub * core->pack * core->bytes;
    MatrixInfo a22 = AT;
    a22.offsetBytes = AT.offsetBytes + eSub * core->pack * core->bytes + AT.lineStrideBytes * lSubUnit;

    MatrixInfo b11 = BT;
    MatrixInfo b12 = BT;
    b12.offsetBytes = BT.offsetBytes + BT.lineStrideBytes * (hSub / hP);
    MatrixInfo b21 = BT;
    b21.offsetBytes = BT.offsetBytes + lSub * hP * core->bytes;
    MatrixInfo b22 = BT;
    b22.offsetBytes = BT.offsetBytes + BT.lineStrideBytes * (hSub / hP) + lSub * hP * core->bytes;

    MatrixInfo c11 = CT;
    MatrixInfo c12 = CT;
    c12.offsetBytes = CT.offsetBytes + CT.lineStrideBytes * (hSub / core->pack);
    MatrixInfo c21 = CT;
    c21.offsetBytes = CT.offsetBytes + eSub * core->pack * core->bytes;
    MatrixInfo c22 = CT;
    c22.offsetBytes = CT.offsetBytes + eSub * core->pack * core->bytes + CT.lineStrideBytes * (hSub / core->pack);

    MatrixInfo Empty;
    Empty.stackIndex = -1;

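    // Scheduling order: P7 -> c21, P5 -> c22, P6 -> c12, P3 -> c11, P1 -> X,
    // then MNNStrassenMergeCFunction forms U2..U7 in place, and P4 and P2 fix
    // up c11/c21. Only the X and Y scratch buffers are needed because the C
    // quadrants double as accumulators.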
    {
        // S3 = A11 - A21, T3 = B22 - B12, P7 = S3 * T3 -> c21
        auto f = [a11, a21, b22, b12, X, Y, eSub, lSub, hSub, numberThread, core, hP, this, bWidth, aHeight, bHeight](int tId) {
            auto xAddr = mStack[X.stackIndex] + X.offsetBytes;
            auto yAddr = mStack[Y.stackIndex] + Y.offsetBytes;
            auto a11Ptr = mStack[a11.stackIndex] + a11.offsetBytes;
            auto a21Ptr = mStack[a21.stackIndex] + a21.offsetBytes;
            MNNMATRIX_SUB_MULTITHREAD(xAddr, a11Ptr, a21Ptr, eSub, X.lineStrideBytes, a11.lineStrideBytes, a21.lineStrideBytes, aHeight, core);
            MNNMATRIX_SUB_MULTITHREAD(yAddr, mStack[b22.stackIndex] + b22.offsetBytes, mStack[b12.stackIndex] + b12.offsetBytes, bWidth, Y.lineStrideBytes, b22.lineStrideBytes, b12.lineStrideBytes, bHeight, core);
        };
        mFunctions.emplace_back(std::make_pair(f, numberThread));
        auto code = _generateMatMul(eSub, lSub, hSub, X, Y, c21, Empty, currentDepth, {});
        if (code != NO_ERROR) {
            return code;
        }
    }
    {
        // S1 = A21 + A22, T1 = B12 - B11, P5 = S1 * T1 -> c22
        auto f = [a22, a21, b11, b12, X, Y, eSub, lSub, hSub, numberThread, hP, core, this, bWidth, aHeight, bHeight](int tId) {
            MNNMATRIX_ADD_MULTITHREAD(mStack[X.stackIndex] + X.offsetBytes, mStack[a21.stackIndex] + a21.offsetBytes, mStack[a22.stackIndex] + a22.offsetBytes, eSub, X.lineStrideBytes, a21.lineStrideBytes, a22.lineStrideBytes, aHeight, core);
            MNNMATRIX_SUB_MULTITHREAD(mStack[Y.stackIndex] + Y.offsetBytes, mStack[b12.stackIndex] + b12.offsetBytes, mStack[b11.stackIndex] + b11.offsetBytes, bWidth, Y.lineStrideBytes, b12.lineStrideBytes, b11.lineStrideBytes, bHeight, core);
        };
        mFunctions.emplace_back(std::make_pair(f, numberThread));
        auto code = _generateMatMul(eSub, lSub, hSub, X, Y, c22, Empty, currentDepth, {});
        if (code != NO_ERROR) {
            return code;
        }
    }
    {
        // S2 = S1 - A11, T2 = B22 - T1, P6 = S2 * T2 -> c12
        auto f = [a11, b22, X, Y, eSub, lSub, hSub, numberThread, hP, core, this, bWidth, aHeight, bHeight](int tId) {
            auto xAddr = mStack[X.stackIndex] + X.offsetBytes;
            auto yAddr = mStack[Y.stackIndex] + Y.offsetBytes;
            MNNMATRIX_SUB_MULTITHREAD(xAddr, xAddr, mStack[a11.stackIndex] + a11.offsetBytes, eSub, X.lineStrideBytes, X.lineStrideBytes, a11.lineStrideBytes, aHeight, core);
            MNNMATRIX_SUB_MULTITHREAD(yAddr, mStack[b22.stackIndex] + b22.offsetBytes, yAddr, bWidth, Y.lineStrideBytes, b22.lineStrideBytes, Y.lineStrideBytes, bHeight, core);
        };
        mFunctions.emplace_back(std::make_pair(f, numberThread));
        auto code = _generateMatMul(eSub, lSub, hSub, X, Y, c12, Empty, currentDepth, {});
        if (code != NO_ERROR) {
            return code;
        }
    }
    {
        // S4 = A12 - S2, P3 = S4 * B22 -> c11, P1 = A11 * B11 -> X (via CX)
        auto f = [a12, X, eSub, aHeight, numberThread, core, this](int tId) {
            auto xAddr = mStack[X.stackIndex] + X.offsetBytes;
            MNNMATRIX_SUB_MULTITHREAD(xAddr, mStack[a12.stackIndex] + a12.offsetBytes, xAddr, eSub, X.lineStrideBytes, a12.lineStrideBytes, X.lineStrideBytes, aHeight, core);
        };
        mFunctions.emplace_back(std::make_pair(f, numberThread));
        auto code = _generateMatMul(eSub, lSub, hSub, X, b22, c11, Empty, currentDepth, {});
        if (code != NO_ERROR) {
            return code;
        }
        code = _generateMatMul(eSub, lSub, hSub, a11, b11, CX, Empty, currentDepth, {});
        if (code != NO_ERROR) {
            return code;
        }
    }
    {
        // U2 = P1 + P6, U3 = U2 + P7, U4 = U2 + P5, U7 = U3 + P5
        // U5 = U4 + P3, T4 = T2 - B21, P4 = A22 * T4 -> c11
        auto f = [c11, c12, c21, c22, b21, X, Y, eSub, bWidth, cHeight, bHeight, numberThread, core, this](int tId) {
            for (int y = tId; y < cHeight; y += numberThread) {
                core->MNNStrassenMergeCFunction((float*)(mStack[c11.stackIndex] + c11.offsetBytes + y * c11.lineStrideBytes), (float*)(mStack[c12.stackIndex] + c12.offsetBytes + y * c12.lineStrideBytes), (float*)(mStack[c21.stackIndex] + c21.offsetBytes + y * c21.lineStrideBytes), (float*)(mStack[c22.stackIndex] + c22.offsetBytes + y * c22.lineStrideBytes), (float*)(mStack[X.stackIndex] + X.offsetBytes + y * X.lineStrideBytes), 0, eSub, 1);
            }
            auto yAddr = mStack[Y.stackIndex] + Y.offsetBytes;
            MNNMATRIX_SUB_MULTITHREAD(yAddr, yAddr, mStack[b21.stackIndex] + b21.offsetBytes, bWidth, Y.lineStrideBytes, Y.lineStrideBytes, b21.lineStrideBytes, bHeight, core);
        };
        mFunctions.emplace_back(std::make_pair(f, numberThread));
        auto code = _generateMatMul(eSub, lSub, hSub, a22, Y, c11, Empty, currentDepth, {});
        if (code != NO_ERROR) {
            return code;
        }
    }
    {
        // U6 = U3 - P4 (c21 -= c11), P2 = A12 * B21 -> c11, U1 = P1 + P2 (c11 += X)
        auto f0 = [c11, c21, eSub, cHeight, numberThread, core, this](int tId) {
            auto cw = eSub;
            auto c21Addr = mStack[c21.stackIndex] + c21.offsetBytes;
            MNNMATRIX_SUB_MULTITHREAD(c21Addr, c21Addr, mStack[c11.stackIndex] + c11.offsetBytes, cw, c21.lineStrideBytes, c21.lineStrideBytes, c11.lineStrideBytes, cHeight, core);
        };
        mFunctions.emplace_back(std::make_pair(f0, numberThread));
        auto code = _generateMatMul(eSub, lSub, hSub, a12, b21, c11, Empty, currentDepth, {});
        if (code != NO_ERROR) {
            return code;
        }
        auto f1 = [c11, X, eSub, cHeight, numberThread, core, this](int tId) {
            auto cw = eSub;
            auto c11Ptr = mStack[c11.stackIndex] + c11.offsetBytes;
            auto xAddr = mStack[X.stackIndex] + X.offsetBytes;
            MNNMATRIX_ADD_MULTITHREAD(c11Ptr, c11Ptr, xAddr, cw, c11.lineStrideBytes, c11.lineStrideBytes, X.lineStrideBytes, cHeight, core);
        };
        mFunctions.emplace_back(std::make_pair(f1, numberThread));
        if (!postParameters.empty() && COT.stackIndex >= 0) {
            if (1 == numberThread) {
                auto postFunction = [c11, COT, eSub, cHeight, numberThread, postParameters, core, this](int tId) {
                    auto biasPtr = (const float*)(mStack[COT.stackIndex] + COT.offsetBytes);
                    auto width = eSub * 2;
                    auto height = cHeight * 2;
                    auto c11Ptr = mStack[c11.stackIndex] + c11.offsetBytes;
                    core->MNNAxByClampBroadcastUnit((float*)c11Ptr, (float*)c11Ptr, biasPtr, width, c11.lineStrideBytes / core->bytes, c11.lineStrideBytes / core->bytes, height, postParameters.data());
                };
                mFunctions.emplace_back(std::make_pair(postFunction, numberThread));
            } else {
                auto postFunction = [c11, COT, eSub, cHeight, numberThread, postParameters, core, this](int tId) {
                    auto width = eSub * 2;
                    auto height = cHeight * 2;
                    auto c11Ptr = mStack[c11.stackIndex] + c11.offsetBytes;
                    auto biasPtr = mStack[COT.stackIndex] + COT.offsetBytes;
                    for (int y = tId; y < height; y += numberThread) {
                        core->MNNAxByClampBroadcastUnit((float*)(c11Ptr + y * c11.lineStrideBytes), (float*)(c11Ptr + y * c11.lineStrideBytes), (const float*)(biasPtr + y * core->bytes * core->pack), width, 0, 0, 1, postParameters.data());
                    }
                };
                mFunctions.emplace_back(std::make_pair(postFunction, numberThread));
            }
        }
    }
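    // eSub and hSub were rounded down to multiples of eP and max(hP, pack), so
    // up to one strip of columns (remainH) and one strip of rows (remainE) is
    // left over; both are finished with the non-Strassen basic multiply.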
    if (remainH > 0) {
        auto lastH = hSub * 2;
        MatrixInfo CLast = CT;
        CLast.offsetBytes = CT.offsetBytes + CT.lineStrideBytes * (lastH / core->pack);
        MatrixInfo BLast = BT;
        BLast.offsetBytes = BT.offsetBytes + BT.lineStrideBytes * (lastH / hP);
        MatrixInfo Bias = COT;
        if (Bias.stackIndex >= 0) {
            Bias.offsetBytes = COT.offsetBytes + core->bytes * lastH;
        }
        auto code = _generateBasicMatMul(eSub * 2, l, remainH, AT, BLast, CLast, Bias, postParameters);
        if (NO_ERROR != code) {
            return code;
        }
    }
    if (remainE > 0) {
        MatrixInfo CLast = CT;
        CLast.offsetBytes = CT.offsetBytes + eSub * 2 * core->pack * core->bytes;
        MatrixInfo ALast = AT;
        ALast.offsetBytes = AT.offsetBytes + eSub * 2 * core->pack * core->bytes;

        auto code = _generateBasicMatMul(remainE, l, h, ALast, BT, CLast, COT, postParameters);
        if (NO_ERROR != code) {
            return code;
        }
    }
    return NO_ERROR;
}

void StrassenMatrixComputor::onReset() {
    mStack.clear();
    mFunctions.clear();
}

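// Encodes C = A * B (+ optional bias and activation). A (inputs[0]) and C
// (outputs[0]) are in pack layout (lines of e * pack elements), B (inputs[1])
// is pre-packed into hP panels, and the optional inputs[2] is the bias.
// inputL / inputH, when non-zero, override the l / h deduced from the tensor
// shapes (e.g. when the real sizes are smaller than the padded tensors).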
ErrorCode StrassenMatrixComputor::onEncode(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs, const std::vector<float>& postParameters, int inputL, int inputH) {
    mFunctions.clear();
    auto core = static_cast<CPUBackend*>(backend())->functions();
    MNN_ASSERT(inputs.size() == 2 || inputs.size() == 3);
    MNN_ASSERT(outputs.size() == 1);
    auto A  = inputs[0];
    auto B  = inputs[1];
    auto C  = outputs[0];
    Tensor* CO = nullptr;
    auto l = B->length(1);
    if (inputL != 0) {
        l = inputL;
    }
    auto e = A->length(1);
    auto h = std::min(C->length(0) * core->pack, B->length(0) * B->length(2));
    if (inputH != 0) {
        h = inputH;
    }
    mStack = {A->host<uint8_t>(), B->host<uint8_t>(), C->host<uint8_t>()};
    MatrixInfo a, b, c, bias;
    bias.stackIndex = -1;
    if (inputs.size() > 2) {
        CO = inputs[2];
        bias.stackIndex = 3;
        bias.offsetBytes = 0;
        mStack.emplace_back(CO->host<uint8_t>());
    }
    a.stackIndex = 0;
    a.lineStrideBytes = A->stride(0) * core->bytes;
    a.offsetBytes = 0;
    int eP, lP, hP;
    core->MNNGetMatMulPackMode(&eP, &lP, &hP);

    b.stackIndex = 1;
    b.lineStrideBytes = UP_DIV(l, lP) * lP * hP * core->bytes;
    b.offsetBytes = 0;

    c.stackIndex = 2;
    c.lineStrideBytes = C->stride(0) * core->bytes;
    c.offsetBytes = 0;

    return _generateMatMul(e, l, h, a, b, c, bias, 0, postParameters);
}

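// Replays the functions recorded by onEncode. Non-null arguments rebind the
// corresponding stack slots first, so one encoded plan can be reused with new
// input/output addresses.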
void StrassenMatrixComputor::onExecute(const uint8_t* AT, const uint8_t* BT, const uint8_t* COT, uint8_t* CT) {
    if (nullptr != AT) {
        mStack[0] = (uint8_t*)AT;
    }
    if (nullptr != BT) {
        mStack[1] = (uint8_t*)BT;
    }
    if (nullptr != CT) {
        mStack[2] = (uint8_t*)CT;
    }
    if (nullptr != COT) {
        mStack[3] = (uint8_t*)COT;
    }

    // Everything was recorded during encoding; just run the functions
    for (auto& f : mFunctions) {
        MNN_CONCURRENCY_BEGIN(tId, f.second) {
            f.first(tId);
        }
        MNN_CONCURRENCY_END();
    }
}
} // namespace MNN