//
// MNNMatrixAdd.cpp
// MNN
//
// Created by MNN on 2019/08/25.
// Copyright © 2018, Alibaba Group Holding Limited
//
8
9 #include "FunctionSummary.hpp"
/**
 * Element-wise addition of two row-strided float matrices: C = A + B.
 *
 * Each of the `height` rows contains `widthC4` packs of 8 floats (one AVX
 * register per pack), so 8 * widthC4 floats are written per row. The parameter
 * name presumably comes from the 4-lane ("C4") variant of this kernel; in this
 * AVX version a pack is 8 floats wide.
 *
 * @param C        destination; row y starts at C + cStride * y
 * @param A        first operand; row y starts at A + aStride * y
 * @param B        second operand; row y starts at B + bStride * y
 * @param widthC4  number of 8-float packs per row
 * @param cStride  distance between consecutive rows of C, in floats
 * @param aStride  distance between consecutive rows of A, in floats
 * @param bStride  distance between consecutive rows of B, in floats
 * @param height   number of rows
 *
 * Unaligned loads/stores are used, so no particular alignment is required.
 */
void _AVX_MNNMatrixAdd(float* C, const float* A, const float* B, size_t widthC4, size_t cStride, size_t aStride,
                       size_t bStride, size_t height) {
    // size_t counters match the size_t bounds: int counters cause sign-compare
    // mismatches and could overflow (UB) in the offset computation for huge rows.
    for (size_t y = 0; y < height; ++y) {
        const float* a = A + aStride * y;
        const float* b = B + bStride * y;
        float* c = C + cStride * y;
        for (size_t x = 0; x < widthC4; ++x) {
            const size_t offset = 8 * x;
            // Operand order (b, a) kept as originally written so results, including
            // NaN propagation, stay bit-identical.
            _mm256_storeu_ps(c + offset, _mm256_add_ps(_mm256_loadu_ps(b + offset), _mm256_loadu_ps(a + offset)));
        }
    }
}
21
/**
 * Merges intermediate sums into the four C quadrants, 8 floats per AVX pack
 * (per the function name, this is the merge step of a Strassen multiply).
 *
 * For every 8-float pack, with t = c12 + x where x is read from xAddr:
 *   c21 <- t + c21
 *   c22 <- c22 + (t + c21)     (uses the freshly updated c21)
 *   c12 <- c11 + (c22 + t)     (uses the original c22)
 * c11 and xAddr are only read; c12, c21 and c22 are updated in place.
 *
 * @param c11, c12, c21, c22  quadrant base pointers; row y of each starts at base + y * cStride
 * @param xAddr    buffer whose rows are tightly packed: row y starts at xAddr + y * eSub * 8
 * @param cStride  row stride of the quadrants, in floats
 * @param eSub     number of 8-float packs per row
 * @param hSub     number of rows
 */
void _AVX_MNNStrassenMergeCFunction(float* c11, float* c12, float* c21, float* c22, float* xAddr, size_t cStride,
                                    size_t eSub, size_t hSub) {
    constexpr size_t unit = 8; // floats per __m256
    // size_t counters match the size_t bounds and keep `unit * x` from overflowing int.
    for (size_t y = 0; y < hSub; ++y) {
        const float* c11Y = c11 + y * cStride;
        float* c12Y = c12 + y * cStride;
        float* c22Y = c22 + y * cStride;
        float* c21Y = c21 + y * cStride;
        const float* xY = xAddr + y * eSub * unit;
        for (size_t x = 0; x < eSub; ++x) {
            const size_t offset = unit * x;
            // All loads happen before any store of this pack.
            auto xv = _mm256_loadu_ps(xY + offset);
            auto c21v = _mm256_loadu_ps(c21Y + offset);
            auto c11v = _mm256_loadu_ps(c11Y + offset);
            auto c22v = _mm256_loadu_ps(c22Y + offset);
            auto c12v = _mm256_loadu_ps(c12Y + offset);
            // The exact addition order below matters: float addition is not associative.
            c12v = _mm256_add_ps(c12v, xv);   // t = c12 + x
            c21v = _mm256_add_ps(c12v, c21v); // c21' = t + c21
            c12v = _mm256_add_ps(c22v, c12v); // c22 + t
            c22v = _mm256_add_ps(c22v, c21v); // c22' = c22 + c21'
            c12v = _mm256_add_ps(c11v, c12v); // c12' = c11 + (c22 + t)
            _mm256_storeu_ps(c12Y + offset, c12v);
            _mm256_storeu_ps(c22Y + offset, c22v);
            _mm256_storeu_ps(c21Y + offset, c21v);
        }
    }
}
47