//
//  StrassenMatmulComputor.cpp
//  MNN
//
//  Created by MNN on 2019/02/11.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include "StrassenMatmulComputor.hpp"
#include "CommonOptFunction.h"
#include "backend/cpu/CPUBackend.hpp"
#include <string.h>
#include <limits.h>
#include "core/AutoStorage.h"
#include "core/Macro.h"
#include "core/Concurrency.h"
//#define MNN_OPEN_TIME_TRACE
#include <MNN/AutoTime.hpp>
#include "math/Vec.hpp"
#include "math/Matrix.hpp"
#include "core/BufferAllocator.hpp"

namespace MNN {
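// Small RAII helper: grabs a chunk from the backend's BufferAllocator on
// construction and returns it on destruction, so the temporary buffers used
// by the Strassen decomposition are released even on early error returns.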
class AutoMemory {
public:
    AutoMemory(int size, BufferAllocator* allocator) {
        mContent = allocator->alloc(size);
        mAllocator = allocator;
    }
    ~AutoMemory() {
        if (nullptr != mContent.first) {
            mAllocator->free(mContent);
        }
    }
    const std::pair<void*, int>& get() const {
        return mContent;
    }
private:
    std::pair<void*, int> mContent;
    BufferAllocator* mAllocator;
};

StrassenMatrixComputor::StrassenMatrixComputor(Backend* bn, bool multithread, int maxDepth) : mBackend(bn) {
    mMaxDepth = maxDepth;
    mSupportMultiThread = multithread;
}
StrassenMatrixComputor::~StrassenMatrixComputor() {
    // Do nothing
}

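// Direct (non-Strassen) packed matrix multiply of size e x l x h.
// A, B and C are described by MatrixInfo entries indexing into mStack; the
// work is recorded as a lambda in mFunctions and only runs later in
// onExecute. Each thread packs eP-wide tiles of A with MNNPackC4ForMatMul_A
// and multiplies them against the pre-packed B via MNNPackedMatMul; the last
// thread also handles the e % eP remainder with MNNPackedMatMulRemain.
// The tile buffer is allocated and then immediately returned to the
// BufferAllocator so later encode-time allocations can reuse the same scratch
// region; the recorded base pointer stays usable when the lambda executes.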
ErrorCode StrassenMatrixComputor::_generateTrivalMatMul(int e, int l, int h, const MatrixInfo& AT, const MatrixInfo& BT, const MatrixInfo& CT, const MatrixInfo& COT, const std::vector<float>& active) {
    // Generate trivial (non-Strassen) matrix multiply
    MNN_ASSERT(e > 0);
    auto core = static_cast<CPUBackend*>(backend())->functions();
    int bytes = core->bytes;
    auto aStride = AT.lineStrideBytes;
    auto bStride = BT.lineStrideBytes;
    auto cStride = CT.lineStrideBytes;
    int eP, lP, hP;
    core->MNNGetMatMulPackMode(&eP, &lP, &hP);
    auto numberThread = mSupportMultiThread ? ((CPUBackend*)backend())->threadNumber() : 1;
    auto bExtraStride = bStride - UP_DIV(l, lP) * lP * hP * core->bytes;
    MNN_ASSERT(bExtraStride >= 0);
    auto tileBufferBasic = static_cast<CPUBackend*>(backend())->getBufferAllocator()->alloc(numberThread * UP_DIV(l, lP) * eP * lP * bytes);
    if (nullptr == tileBufferBasic.first) {
        return OUT_OF_MEMORY;
    }
    auto tileHostOrigin = (uint8_t*)tileBufferBasic.first + tileBufferBasic.second;
    int unitNumber = e / eP;
    int xCount = e - unitNumber * eP;
    auto eReal = aStride / core->bytes / core->pack;
    mFunctions.emplace_back(
        std::make_pair([cStride, l, h, xCount, AT, BT, CT, COT, tileHostOrigin, unitNumber, bExtraStride, numberThread, eReal, eP, active, this](int tId) {
            auto core = static_cast<CPUBackend*>(backend())->functions();
            size_t parameters[6];
            parameters[0] = xCount * core->bytes;
            parameters[1] = l;
            parameters[2] = h;
            parameters[3] = cStride;
            parameters[4] = 0;
            parameters[5] = bExtraStride;
            auto tileHost = tileHostOrigin + eP * parameters[1] * tId * core->bytes;
            const float* postParametersPtr = nullptr;
            if (!active.empty()) {
                postParametersPtr = active.data();
            }
            auto aHost = mStack[AT.stackIndex] + AT.offsetBytes;
            auto bHost = mStack[BT.stackIndex] + BT.offsetBytes;
            auto cHost = mStack[CT.stackIndex] + CT.offsetBytes;
            const uint8_t* biasPtr = nullptr;
            if (-1 != COT.stackIndex) {
                biasPtr = mStack[COT.stackIndex] + COT.offsetBytes;
            }
            auto packUnit = core->bytes * core->pack;
            int32_t info[4];
            int32_t stride[4];
            stride[0] = eP;
            stride[1] = parameters[1];
            stride[2] = 0;
            stride[3] = 0;
            info[0] = 1;
            info[1] = eReal;
            info[2] = eP;
            info[3] = 1;
            for (int i = tId; i < unitNumber; i += numberThread) {
                int xStart = i * eP;
                auto aStart = aHost + xStart * packUnit;
                core->MNNPackC4ForMatMul_A((float*)(tileHost), (const float**)(&aStart), info, stride);
                core->MNNPackedMatMul((float*)(cHost + xStart * packUnit), (float*)tileHost, (float*)bHost, parameters, postParametersPtr, (const float*)biasPtr);
            }
            if (tId != numberThread - 1) {
                return;
            }
            if (xCount > 0) {
                stride[0] = xCount;
                stride[1] = parameters[1];
                info[2] = xCount;

                int xStart = unitNumber * eP;
                auto aStart = aHost + xStart * packUnit;
                // Copy
                core->MNNPackC4ForMatMul_A((float*)(tileHost), (const float**)(&aStart), info, stride);
                core->MNNPackedMatMulRemain((float*)(cHost + xStart * packUnit), (float*)tileHost, (float*)bHost, xCount, parameters, postParametersPtr, (const float*)biasPtr);
            }
        }, numberThread));
    static_cast<CPUBackend*>(backend())->getBufferAllocator()->free(tileBufferBasic);
    return NO_ERROR;
}

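// Row-parallel element-wise helpers used by the Strassen combination steps.
// Both macros expect `tId` and `numberThread` to be in scope at the expansion
// site and split the `lSub` rows across the calling threads.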
#define MNNMATRIX_SUB_MULTITHREAD(c_, a_, b_, widthC4, cStride, aStride, bStride, lSub, core) \
{\
    auto c = c_;\
    auto b = b_;\
    auto a = a_;\
    for (int y = tId; y < lSub; y += numberThread) {\
        core->MNNMatrixSub((float*)(c + y * cStride), (float*)(a + y * aStride), (float*)(b + y * bStride), widthC4, 0, 0, 0, 1);\
    }\
}

#define MNNMATRIX_ADD_MULTITHREAD(c_, a_, b_, widthC4, cStride, aStride, bStride, lSub, core) \
{\
    auto c = c_;\
    auto b = b_;\
    auto a = a_;\
    for (int y = tId; y < lSub; y += numberThread) {\
        core->MNNMatrixAdd((float*)(c + y * cStride), (float*)(a + y * aStride), (float*)(b + y * bStride), widthC4, 0, 0, 0, 1);\
    }\
}

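// Matrix multiply without Strassen recursion. When l exceeds lLimit the
// reduction axis is split into chunks: the first chunk writes directly into C,
// each later chunk is computed into the temporary CTemp and accumulated into C
// with MNNMatrixAdd, and the optional bias/activation post step is applied
// once at the end.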
ErrorCode StrassenMatrixComputor::_generateBasicMatMul(int e, int l, int h, const MatrixInfo& AT, const MatrixInfo& BT, const MatrixInfo& CT, const MatrixInfo& COT, const std::vector<float>& postParameters) {
    auto core = static_cast<CPUBackend*>(backend())->functions();
    int eP, lP, hP;
    core->MNNGetMatMulPackMode(&eP, &lP, &hP);
    int lLimit = 32768 / (std::min(eP, e) + hP);
    if (l <= lLimit) {
        return _generateTrivalMatMul(e, l, h, AT, BT, CT, COT, postParameters);
    }
    {
        auto lUnit = std::max(lP, core->pack);
        lLimit = lLimit / lUnit * lUnit;
    }
    int unit = UP_DIV(l, lLimit);
    auto allocator = static_cast<CPUBackend*>(backend())->getBufferAllocator();
    AutoMemory CAddr(e * UP_DIV(h, core->pack) * core->pack * core->bytes, allocator);
    MatrixInfo CTemp;
    CTemp.stackIndex = (int)mStack.size();
    CTemp.offsetBytes = 0;
    CTemp.lineStrideBytes = e * core->bytes * core->pack;
    mStack.emplace_back((uint8_t*)CAddr.get().first + CAddr.get().second);

    MatrixInfo Empty;
    Empty.stackIndex = -1;
    auto numberThread = mSupportMultiThread ? ((CPUBackend*)backend())->threadNumber() : 1;
    auto cHeight = UP_DIV(h, core->pack);

    for (int i = 0; i < unit; ++i) {
        int lS = i * lLimit;
        int lE = lS + lLimit;
        if (lE > l) {
            lE = l;
        }
        if (0 == i) {
            // First write to output
            auto code = _generateTrivalMatMul(e, lE - lS, h, AT, BT, CT, Empty, {});
            if (NO_ERROR != code) {
                return code;
            }
            continue;
        }
        MatrixInfo tempA = AT;
        MatrixInfo tempB = BT;
        tempA.offsetBytes = AT.offsetBytes + lS / core->pack * AT.lineStrideBytes;
        tempB.offsetBytes = BT.offsetBytes + lS * hP * core->bytes;
        auto code = _generateTrivalMatMul(e, lE - lS, h, tempA, tempB, CTemp, Empty, {});
        if (NO_ERROR != code) {
            return code;
        }
        // Add CTemp to C
        auto f1 = [CT, CTemp, e, cHeight, numberThread, core, this](int tId) {
            auto c11Ptr = mStack[CT.stackIndex] + CT.offsetBytes;
            auto xAddr = mStack[CTemp.stackIndex] + CTemp.offsetBytes;
            MNNMATRIX_ADD_MULTITHREAD(c11Ptr, c11Ptr, xAddr, e, CT.lineStrideBytes, CT.lineStrideBytes, CTemp.lineStrideBytes, cHeight, core);
        };
        mFunctions.emplace_back(std::make_pair(f1, numberThread));
    }
    if (!postParameters.empty() && COT.stackIndex >= 0) {
        if (1 == numberThread) {
            auto postFunction = [CT, COT, e, cHeight, numberThread, postParameters, core, this](int tId) {
                auto biasPtr = (const float*)(mStack[COT.stackIndex] + COT.offsetBytes);
                auto width = e;
                auto height = cHeight;
                auto c11Ptr = mStack[CT.stackIndex] + CT.offsetBytes;
                core->MNNAxByClampBroadcastUnit((float*)c11Ptr, (float*)c11Ptr, biasPtr, width, CT.lineStrideBytes / core->bytes, CT.lineStrideBytes / core->bytes, height, postParameters.data());
            };
            mFunctions.emplace_back(std::make_pair(postFunction, 1));
        } else {
            auto postFunction = [CT, COT, e, cHeight, numberThread, postParameters, core, this](int tId) {
                auto width = e;
                auto height = cHeight;
                auto c11Ptr = mStack[CT.stackIndex] + CT.offsetBytes;
                auto biasPtr = mStack[COT.stackIndex] + COT.offsetBytes;
                for (int y = tId; y < height; y += numberThread) {
                    core->MNNAxByClampBroadcastUnit((float*)(c11Ptr + y * CT.lineStrideBytes), (float*)(c11Ptr + y * CT.lineStrideBytes), (const float*)(biasPtr + y * core->bytes * core->pack), width, 0, 0, 1, postParameters.data());
                }
            };
            mFunctions.emplace_back(std::make_pair(postFunction, numberThread));
        }
    }
    return NO_ERROR;
}

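// Recursive entry point. Splits e and h into even halves and decides, per
// level, whether a Strassen step pays off: recursion stops when the depth
// limit is reached, when a dimension cannot be halved cleanly for the packing
// units (eP, lP, hP, core->pack), or when the estimated cost of the extra
// matrix additions (scaled by core->penalty) outweighs the multiply work saved
// by doing 7 sub-multiplies instead of 8. When any check fails it falls back
// to _generateBasicMatMul.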
ErrorCode StrassenMatrixComputor::_generateMatMul(int e, int l, int h, const MatrixInfo& AT, const MatrixInfo& BT, const MatrixInfo& CT, const MatrixInfo& COT, int currentDepth, const std::vector<float>& postParameters) {
    auto core = static_cast<CPUBackend*>(backend())->functions();
    auto aUnit = core->pack;

    auto numberThread = mSupportMultiThread ? ((CPUBackend*)backend())->threadNumber() : 1;
    int eP, lP, hP;
    core->MNNGetMatMulPackMode(&eP, &lP, &hP);
    MNN_ASSERT(hP % core->pack == 0 || core->pack % hP == 0);
    auto eSub = (e / eP) / 2 * eP;
    auto lMinDiv = std::max(core->pack * 2, 2 * lP);
    auto hSub = (h / std::max(hP, core->pack)) / 2 * std::max(hP, core->pack);
    auto remainH = h - hSub * 2;
    auto remainE = e - eSub * 2;
    int packHUnit = 1;
    if (core->pack > hP) {
        packHUnit = core->pack / hP;
    }
    if (currentDepth >= mMaxDepth || eSub == 0 || hSub == 0 || l % (2 * core->pack) != 0 || l % (2 * lP) != 0 || l % (2 * packHUnit) != 0) {
        return _generateBasicMatMul(e, l, h, AT, BT, CT, COT, postParameters);
    }
    auto lSub = l / 2;
    auto lSubUnit = lSub / core->pack;

    auto bWidth = lSub * hP / core->pack;
    auto aHeight = lSub / core->pack;
    auto cHeight = hSub / core->pack;
    auto bHeight = hSub / hP;
    /*
     Compute the memory read / write cost for expand
     */
    auto bHSub = bHeight;
    float AComputeCost = 4 * ((float)eSub * lSub);
    float BComputeCost = 4 * (float)lSub * bHSub * hP;
    float CComputeCost = 7 * (float)eSub * hSub;
    float saveMatMulCost = (e / eP) * (aUnit * eP * hSub / core->pack + lSubUnit * eP * aUnit + lSub * bHSub * hP);

    const float penalty = core->penalty; // FIXME: Find a better way to set it
    float saveCost = saveMatMulCost - (AComputeCost + BComputeCost + CComputeCost) * penalty;
    if (saveCost <= 0.0f) {
        return _generateBasicMatMul(e, l, h, AT, BT, CT, COT, postParameters);
    }

    // Strassen Construct
    auto bn = backend();
    auto allocator = static_cast<CPUBackend*>(backend())->getBufferAllocator();
    currentDepth += 1;
    auto maxlH = std::max(lSub, hSub);
    AutoMemory YAddr(hSub * lSub * core->bytes, allocator);
    AutoMemory XAddr(maxlH * eSub * core->bytes, allocator);
    if (nullptr == XAddr.get().first || nullptr == YAddr.get().first) {
        return OUT_OF_MEMORY;
    }
    MatrixInfo Y;
    Y.stackIndex = (int)mStack.size();
    mStack.emplace_back((uint8_t*)YAddr.get().first + YAddr.get().second);
    Y.offsetBytes = 0;
    Y.lineStrideBytes = lSub * core->bytes * hP;
    MatrixInfo X;
    X.stackIndex = (int)mStack.size();
    X.offsetBytes = 0;
    X.lineStrideBytes = eSub * core->bytes * core->pack;
    mStack.emplace_back((uint8_t*)XAddr.get().first + XAddr.get().second);

    MatrixInfo CX;
    CX.stackIndex = X.stackIndex;
    CX.offsetBytes = 0;
    CX.lineStrideBytes = eSub * core->bytes * core->pack;

    MatrixInfo a11 = AT;
    MatrixInfo a12 = AT;
    a12.offsetBytes = AT.offsetBytes + AT.lineStrideBytes * lSubUnit;
    MatrixInfo a21 = AT;
    a21.offsetBytes = AT.offsetBytes + eSub * core->pack * core->bytes;
    MatrixInfo a22 = AT;
    a22.offsetBytes = AT.offsetBytes + eSub * core->pack * core->bytes + AT.lineStrideBytes * lSubUnit;

    MatrixInfo b11 = BT;
    MatrixInfo b12 = BT;
    b12.offsetBytes = BT.offsetBytes + BT.lineStrideBytes * (hSub / hP);
    MatrixInfo b21 = BT;
    b21.offsetBytes = BT.offsetBytes + lSub * hP * core->bytes;
    MatrixInfo b22 = BT;
    b22.offsetBytes = BT.offsetBytes + BT.lineStrideBytes * (hSub / hP) + lSub * hP * core->bytes;

    MatrixInfo c11 = CT;
    MatrixInfo c12 = CT;
    c12.offsetBytes = CT.offsetBytes + CT.lineStrideBytes * (hSub / core->pack);
    MatrixInfo c21 = CT;
    c21.offsetBytes = CT.offsetBytes + eSub * core->pack * core->bytes;
    MatrixInfo c22 = CT;
    c22.offsetBytes = CT.offsetBytes + eSub * core->pack * core->bytes + CT.lineStrideBytes * (hSub / core->pack);

    MatrixInfo Empty;
    Empty.stackIndex = -1;

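    // The seven blocks below implement one level of the Strassen-Winograd
    // schedule: the S*/T* operands named in the per-block comments are formed
    // in the X/Y scratch buffers, each product P1..P7 is computed by a
    // recursive _generateMatMul call, and the U* combinations merge the
    // results into c11/c12/c21/c22 (MNNStrassenMergeCFunction does the bulk
    // of that merge).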
    {
        // S3=A11-A21, T3=B22-B12, P7=S3*T3
        auto f = [a11, a21, b22, b12, X, Y, eSub, lSub, hSub, numberThread, core, hP, this, bWidth, aHeight, bHeight](int tId) {
            auto xAddr = mStack[X.stackIndex] + X.offsetBytes;
            auto yAddr = mStack[Y.stackIndex] + Y.offsetBytes;
            auto a11Ptr = mStack[a11.stackIndex] + a11.offsetBytes;
            auto a21Ptr = mStack[a21.stackIndex] + a21.offsetBytes;
            MNNMATRIX_SUB_MULTITHREAD(xAddr, a11Ptr, a21Ptr, eSub, X.lineStrideBytes, a11.lineStrideBytes, a21.lineStrideBytes, aHeight, core);
            MNNMATRIX_SUB_MULTITHREAD(yAddr, mStack[b22.stackIndex] + b22.offsetBytes, mStack[b12.stackIndex] + b12.offsetBytes, bWidth, Y.lineStrideBytes, b22.lineStrideBytes, b12.lineStrideBytes, bHeight, core);
        };
        mFunctions.emplace_back(std::make_pair(f, numberThread));
        auto code = _generateMatMul(eSub, lSub, hSub, X, Y, c21, Empty, currentDepth, {});
        if (code != NO_ERROR) {
            return code;
        }
    }
    {
        // S1=A21+A22, T1=B12-B11, P5=S1*T1
        auto f = [a22, a21, b11, b12, X, Y, eSub, lSub, hSub, numberThread, hP, core, this, bWidth, aHeight, bHeight](int tId) {
            MNNMATRIX_ADD_MULTITHREAD(mStack[X.stackIndex] + X.offsetBytes, mStack[a21.stackIndex] + a21.offsetBytes, mStack[a22.stackIndex] + a22.offsetBytes, eSub, X.lineStrideBytes, a21.lineStrideBytes, a22.lineStrideBytes, aHeight, core);
            MNNMATRIX_SUB_MULTITHREAD(mStack[Y.stackIndex] + Y.offsetBytes, mStack[b12.stackIndex] + b12.offsetBytes, mStack[b11.stackIndex] + b11.offsetBytes, bWidth, Y.lineStrideBytes, b12.lineStrideBytes, b11.lineStrideBytes, bHeight, core);
        };
        mFunctions.emplace_back(std::make_pair(f, numberThread));
        auto code = _generateMatMul(eSub, lSub, hSub, X, Y, c22, Empty, currentDepth, {});
        if (code != NO_ERROR) {
            return code;
        }
    }
    {
        // S2=S1-A11, T2=B22-T1, P6=S2*T2
        auto f = [a11, b22, X, Y, eSub, lSub, hSub, numberThread, hP, core, this, bWidth, aHeight, bHeight](int tId) {
            auto xAddr = mStack[X.stackIndex] + X.offsetBytes;
            auto yAddr = mStack[Y.stackIndex] + Y.offsetBytes;
            MNNMATRIX_SUB_MULTITHREAD(xAddr, xAddr, mStack[a11.stackIndex] + a11.offsetBytes, eSub, X.lineStrideBytes, X.lineStrideBytes, a11.lineStrideBytes, aHeight, core);
            MNNMATRIX_SUB_MULTITHREAD(yAddr, mStack[b22.stackIndex] + b22.offsetBytes, yAddr, bWidth, Y.lineStrideBytes, b22.lineStrideBytes, Y.lineStrideBytes, bHeight, core);
        };
        mFunctions.emplace_back(std::make_pair(f, numberThread));
        auto code = _generateMatMul(eSub, lSub, hSub, X, Y, c12, Empty, currentDepth, {});
        if (code != NO_ERROR) {
            return code;
        }
    }
    {
        // S4=A12-S2, P3=S4*B22, P1=A11*B11
        auto f = [a12, X, eSub, aHeight, numberThread, core, this](int tId) {
            auto xAddr = mStack[X.stackIndex] + X.offsetBytes;
            MNNMATRIX_SUB_MULTITHREAD(xAddr, mStack[a12.stackIndex] + a12.offsetBytes, xAddr, eSub, X.lineStrideBytes, a12.lineStrideBytes, X.lineStrideBytes, aHeight, core);
        };
        mFunctions.emplace_back(std::make_pair(f, numberThread));
        auto code = _generateMatMul(eSub, lSub, hSub, X, b22, c11, Empty, currentDepth, {});
        if (code != NO_ERROR) {
            return code;
        }
        code = _generateMatMul(eSub, lSub, hSub, a11, b11, CX, Empty, currentDepth, {});
        if (code != NO_ERROR) {
            return code;
        }
    }
    {
        // U2=P1+P6, U3=U2+P7, U4=U2+P5, U7=U3+P5
        // U5=U4+P3, T4=T2-B21, P4=A22*T4
        auto f = [c11, c12, c21, c22, b21, X, Y, eSub, bWidth, cHeight, bHeight, numberThread, core, this](int tId) {
            for (int y = tId; y < cHeight; y += numberThread) {
                core->MNNStrassenMergeCFunction((float*)(mStack[c11.stackIndex] + c11.offsetBytes + y * c11.lineStrideBytes), (float*)(mStack[c12.stackIndex] + c12.offsetBytes + y * c12.lineStrideBytes), (float*)(mStack[c21.stackIndex] + c21.offsetBytes + y * c21.lineStrideBytes), (float*)(mStack[c22.stackIndex] + c22.offsetBytes + y * c22.lineStrideBytes), (float*)(mStack[X.stackIndex] + X.offsetBytes + y * X.lineStrideBytes), 0, eSub, 1);
            }
            auto yAddr = mStack[Y.stackIndex] + Y.offsetBytes;
            MNNMATRIX_SUB_MULTITHREAD(yAddr, yAddr, mStack[b21.stackIndex] + b21.offsetBytes, bWidth, Y.lineStrideBytes, Y.lineStrideBytes, b21.lineStrideBytes, bHeight, core);
        };
        mFunctions.emplace_back(std::make_pair(f, numberThread));
        auto code = _generateMatMul(eSub, lSub, hSub, a22, Y, c11, Empty, currentDepth, {});
        if (code != NO_ERROR) {
            return code;
        }
    }
    {
        // U6=U3-P4, P2=A12*B21, U1=P1+P2
        auto f0 = [c11, c21, eSub, cHeight, numberThread, core, this](int tId) {
            auto cw = eSub;
            auto c21Addr = mStack[c21.stackIndex] + c21.offsetBytes;
            MNNMATRIX_SUB_MULTITHREAD(c21Addr, c21Addr, mStack[c11.stackIndex] + c11.offsetBytes, cw, c21.lineStrideBytes, c21.lineStrideBytes, c11.lineStrideBytes, cHeight, core);
        };
        mFunctions.emplace_back(std::make_pair(f0, numberThread));
        auto code = _generateMatMul(eSub, lSub, hSub, a12, b21, c11, Empty, currentDepth, {});
        if (code != NO_ERROR) {
            return code;
        }
        auto f1 = [c11, X, eSub, cHeight, numberThread, core, this](int tId) {
            auto cw = eSub;
            auto c11Ptr = mStack[c11.stackIndex] + c11.offsetBytes;
            auto xAddr = mStack[X.stackIndex] + X.offsetBytes;
            MNNMATRIX_ADD_MULTITHREAD(c11Ptr, c11Ptr, xAddr, cw, c11.lineStrideBytes, c11.lineStrideBytes, X.lineStrideBytes, cHeight, core);
        };
        mFunctions.emplace_back(std::make_pair(f1, numberThread));
        if (!postParameters.empty() && COT.stackIndex >= 0) {
            if (1 == numberThread) {
                auto postFunction = [c11, COT, eSub, cHeight, numberThread, postParameters, core, this](int tId) {
                    auto biasPtr = (const float*)(mStack[COT.stackIndex] + COT.offsetBytes);
                    auto width = eSub * 2;
                    auto height = cHeight * 2;
                    auto c11Ptr = mStack[c11.stackIndex] + c11.offsetBytes;
                    core->MNNAxByClampBroadcastUnit((float*)c11Ptr, (float*)c11Ptr, biasPtr, width, c11.lineStrideBytes / core->bytes, c11.lineStrideBytes / core->bytes, height, postParameters.data());
                };
                mFunctions.emplace_back(std::make_pair(postFunction, numberThread));
            } else {
                auto postFunction = [c11, COT, eSub, cHeight, numberThread, postParameters, core, this](int tId) {
                    auto width = eSub * 2;
                    auto height = cHeight * 2;
                    auto c11Ptr = mStack[c11.stackIndex] + c11.offsetBytes;
                    auto biasPtr = mStack[COT.stackIndex] + COT.offsetBytes;
                    for (int y = tId; y < height; y += numberThread) {
                        core->MNNAxByClampBroadcastUnit((float*)(c11Ptr + y * c11.lineStrideBytes), (float*)(c11Ptr + y * c11.lineStrideBytes), (const float*)(biasPtr + y * core->bytes * core->pack), width, 0, 0, 1, postParameters.data());
                    }
                };
                mFunctions.emplace_back(std::make_pair(postFunction, numberThread));
            }
        }
    }
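    // e and h were rounded down to even multiples of the packing units above;
    // the tail of the h dimension (remainH) and the tail of the e dimension
    // (remainE) are finished with the non-Strassen path.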
    if (remainH > 0) {
        auto lastH = hSub * 2;
        MatrixInfo CLast = CT;
        CLast.offsetBytes = CT.offsetBytes + CT.lineStrideBytes * (lastH / core->pack);
        MatrixInfo BLast = BT;
        BLast.offsetBytes = BT.offsetBytes + BT.lineStrideBytes * (lastH / hP);
        MatrixInfo Bias = COT;
        if (Bias.stackIndex >= 0) {
            Bias.offsetBytes = COT.offsetBytes + core->bytes * lastH;
        }
        auto code = _generateBasicMatMul(eSub * 2, l, remainH, AT, BLast, CLast, Bias, postParameters);
        if (NO_ERROR != code) {
            return code;
        }
    }
    if (remainE > 0) {
        MatrixInfo CLast = CT;
        CLast.offsetBytes = CT.offsetBytes + eSub * 2 * core->pack * core->bytes;
        MatrixInfo ALast = AT;
        ALast.offsetBytes = AT.offsetBytes + eSub * 2 * core->pack * core->bytes;

        auto code = _generateBasicMatMul(remainE, l, h, ALast, BT, CLast, COT, postParameters);
        if (NO_ERROR != code) {
            return code;
        }
    }
    return NO_ERROR;
}

void StrassenMatrixComputor::onReset() {
    mStack.clear();
    mFunctions.clear();
}

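// Builds the execution plan. inputs are {A, B} or {A, B, bias}; outputs is {C}.
// B is expected in the packed layout consumed by MNNPackedMatMul (a line stride
// of UP_DIV(l, lP) * lP * hP elements per group of hP output channels). e, l
// and h are the matmul sizes derived from the tensor shapes; inputL/inputH,
// when non-zero, override l and h. mStack records the base pointers that every
// encoded lambda dereferences at execution time.
//
// Illustrative usage sketch (argument defaults live in the header, so actual
// call sites may pass fewer arguments):
//   StrassenMatrixComputor computor(backend, /*multithread*/ true, /*maxDepth*/ 5);
//   computor.onEncode({A, B, bias}, {C}, postParameters, 0, 0); // at resize time
//   computor.onExecute(nullptr, nullptr, nullptr, nullptr);     // at run time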
ErrorCode StrassenMatrixComputor::onEncode(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs, const std::vector<float>& postParameters, int inputL, int inputH) {
    mFunctions.clear();
    auto core = static_cast<CPUBackend*>(backend())->functions();
    MNN_ASSERT(inputs.size() == 2 || inputs.size() == 3);
    MNN_ASSERT(outputs.size() == 1);
    auto A = inputs[0];
    auto B = inputs[1];
    auto C = outputs[0];
    Tensor* CO = nullptr;
    auto l = B->length(1);
    if (inputL != 0) {
        l = inputL;
    }
    auto e = A->length(1);
    auto h = std::min(C->length(0) * core->pack, B->length(0) * B->length(2));
    if (inputH != 0) {
        h = inputH;
    }
    mStack = {A->host<uint8_t>(), B->host<uint8_t>(), C->host<uint8_t>()};
    MatrixInfo a, b, c, bias;
    bias.stackIndex = -1;
    if (inputs.size() > 2) {
        CO = inputs[2];
        bias.stackIndex = 3;
        bias.offsetBytes = 0;
        mStack.emplace_back(CO->host<uint8_t>());
    }
    a.stackIndex = 0;
    a.lineStrideBytes = A->stride(0) * core->bytes;
    a.offsetBytes = 0;
    int eP, lP, hP;
    core->MNNGetMatMulPackMode(&eP, &lP, &hP);

    b.stackIndex = 1;
    b.lineStrideBytes = UP_DIV(l, lP) * lP * hP * core->bytes;
    b.offsetBytes = 0;

    c.stackIndex = 2;
    c.lineStrideBytes = C->stride(0) * core->bytes;
    c.offsetBytes = 0;

    return _generateMatMul(e, l, h, a, b, c, bias, 0, postParameters);
}

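// Replays the plan built by onEncode. Non-null arguments rebase the
// corresponding mStack entries, which lets the caller reuse one encoded plan
// with different A/B/bias/C buffers of identical shapes and strides.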
void StrassenMatrixComputor::onExecute(const uint8_t* AT, const uint8_t* BT, const uint8_t* COT, uint8_t* CT) {
    if (nullptr != AT) {
        mStack[0] = (uint8_t*)AT;
    }
    if (nullptr != BT) {
        mStack[1] = (uint8_t*)BT;
    }
    if (nullptr != CT) {
        mStack[2] = (uint8_t*)CT;
    }
    if (nullptr != COT) {
        mStack[3] = (uint8_t*)COT;
    }

    // Everything was encoded in onEncode (at resize time); just run the recorded functions
    for (auto& f : mFunctions) {
        MNN_CONCURRENCY_BEGIN(tId, f.second) {
            f.first(tId);
        }
        MNN_CONCURRENCY_END();
    }
}
} // namespace MNN