1 //
2 //  MNNMatrixAdd.cpp
3 //  MNN
4 //
5 //  Created by MNN on 2019/08/25.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #include "FunctionSummary.hpp"
/**
 * Element-wise C = A + B over a strided 2-D region, 8 floats (one __m256) at a time.
 *
 * @param C        destination; row y starts at C + y * cStride.
 * @param A        first source; row y starts at A + y * aStride.
 * @param B        second source; row y starts at B + y * bStride.
 * @param widthC4  number of 8-float packs per row. Despite the "C4" in the name, each unit
 *                 in this AVX variant is 8 consecutive floats.
 * @param cStride  row stride of C, in floats.
 * @param aStride  row stride of A, in floats.
 * @param bStride  row stride of B, in floats.
 * @param height   number of rows.
 *
 * Unaligned loads/stores (_mm256_loadu_ps / _mm256_storeu_ps) are used, so no pointer
 * needs 32-byte alignment. Floats between the last pack and the next row are not touched.
 */
void _AVX_MNNMatrixAdd(float* C, const float* A, const float* B, size_t widthC4, size_t cStride, size_t aStride,
                       size_t bStride, size_t height) {
    // size_t indices match the unsigned bounds (no signed/unsigned comparison) and keep the
    // element offsets from overflowing a 32-bit int on very wide/tall inputs.
    for (size_t y = 0; y < height; ++y) {
        const float* a = A + aStride * y;
        const float* b = B + bStride * y;
        float* c       = C + cStride * y;
        for (size_t x = 0; x < widthC4; ++x) {
            const size_t offset = 8 * x;
            _mm256_storeu_ps(c + offset, _mm256_add_ps(_mm256_loadu_ps(b + offset), _mm256_loadu_ps(a + offset)));
        }
    }
}
21 
/**
 * Merge step for a Strassen-style block product, vectorized 8 floats (one __m256) at a time.
 *
 * For every row y < hSub and every 8-float pack x < eSub, it computes per lane
 * (with the inputs as they were before the call):
 *   t    = c12 + X
 *   c21' = t + c21
 *   c22' = c22 + c21'
 *   c12' = c11 + (c22 + t)
 * It then writes c12', c22' and c21' back in place. c11 is only read.
 * All five loads of a pack happen before any of its stores.
 *
 * @param c11, c12, c21, c22  quadrant buffers; row y of each starts at base + y * cStride.
 * @param xAddr    X buffer; its rows are packed back to back (row y starts at xAddr + y * eSub * 8).
 * @param cStride  row stride of the four quadrant buffers, in floats.
 * @param eSub     number of 8-float packs per row.
 * @param hSub     number of rows.
 */
void _AVX_MNNStrassenMergeCFunction(float* c11, float* c12, float* c21, float* c22, float* xAddr, size_t cStride,
                                    size_t eSub, size_t hSub) {
    constexpr size_t unit = 8; // floats per __m256
    // size_t indices match the unsigned bounds and keep offsets from overflowing int.
    for (size_t y = 0; y < hSub; ++y) {
        const float* c11Y = c11 + y * cStride;
        float* c12Y       = c12 + y * cStride;
        float* c21Y       = c21 + y * cStride;
        float* c22Y       = c22 + y * cStride;
        const float* xY   = xAddr + y * eSub * unit;
        for (size_t x = 0; x < eSub; ++x) {
            const size_t offset = unit * x;
            const __m256 xv   = _mm256_loadu_ps(xY + offset);
            const __m256 c11v = _mm256_loadu_ps(c11Y + offset);
            const __m256 c12v = _mm256_loadu_ps(c12Y + offset);
            const __m256 c21v = _mm256_loadu_ps(c21Y + offset);
            const __m256 c22v = _mm256_loadu_ps(c22Y + offset);

            // Operand order/pairing is kept exactly so results stay bit-identical.
            const __m256 t      = _mm256_add_ps(c12v, xv);
            const __m256 newC21 = _mm256_add_ps(t, c21v);
            const __m256 newC22 = _mm256_add_ps(c22v, newC21);
            const __m256 newC12 = _mm256_add_ps(c11v, _mm256_add_ps(c22v, t));

            _mm256_storeu_ps(c12Y + offset, newC12);
            _mm256_storeu_ps(c22Y + offset, newC22);
            _mm256_storeu_ps(c21Y + offset, newC21);
        }
    }
}
47