//
//  GemmFunction.hpp
//  MNN
//
//  Created by MNN on 2020/09/22.
//  Copyright © 2018, Alibaba Group Holding Limited
//

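// MNN_UNIT_E is the e-axis packing width of the main kernel: one loop step
// consumes 24 packed source values. TRANPOSE_SAVE extracts 128-bit lane `u`
// from four 8-float accumulators, transposes the resulting 4x4 tile in
// registers, and stores row r of the tile at dst + 8 * (r + 4 * u + 8 * v).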
#define MNN_UNIT_E 24
#define TRANPOSE_SAVE(u, v, z0, z3, z6, z9)         \
    {                                               \
        auto m0 = _mm256_extractf128_ps(z0, u);     \
        auto m1 = _mm256_extractf128_ps(z3, u);     \
        auto m2 = _mm256_extractf128_ps(z6, u);     \
        auto m3 = _mm256_extractf128_ps(z9, u);     \
        _MM_TRANSPOSE4_PS(m0, m1, m2, m3);          \
        STORE_4(dst + 8 * (0 + 4 * u + 8 * v), m0); \
        STORE_4(dst + 8 * (1 + 4 * u + 8 * v), m1); \
        STORE_4(dst + 8 * (2 + 4 * u + 8 * v), m2); \
        STORE_4(dst + 8 * (3 + 4 * u + 8 * v), m3); \
    }

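// Helpers for the lane-packed weight loads below: mm_loadu_si128 reinterprets
// an unaligned 4-float load as a 128-bit integer vector, and
// mm256_broadcastsi128_si256 duplicates that 128-bit value into both lanes of
// a 256-bit register.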
namespace {
static inline __m128i mm_loadu_si128(const void* addr) {
    return _mm_castps_si128(LOAD4((const float*)addr));
}

static inline __m256i mm256_broadcastsi128_si256(const void* addr) {
    return _mm256_broadcastsi128_si256(mm_loadu_si128(addr));
}
}  // namespace
//
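// INIT_MAIN_24_4 / COMPUTE_24_4 form the 24x4 micro-kernel body: s0..s2 hold
// 24 packed source values, w0..w3 each broadcast one weight of the current
// four-channel quad, and z0..z11 accumulate the 24x4 result tile.
// INIT_MAIN_24_4 seeds the accumulators with plain multiplies; COMPUTE_24_4
// is the fused multiply-add step for source row sy.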
#define INIT_MAIN_24_4                         \
    auto s0  = LOAD8(A + 0 * 24);              \
    auto s1  = LOAD8(A + 0 * 24 + 8);          \
    auto s2  = LOAD8(A + 0 * 24 + 16);         \
    auto w0  = BROAD_LOAD(weight + 0 * 4 + 0); \
    auto z0  = _mm256_mul_ps(s0, w0);          \
    auto z1  = _mm256_mul_ps(s1, w0);          \
    auto z2  = _mm256_mul_ps(s2, w0);          \
    auto w1  = BROAD_LOAD(weight + 0 * 4 + 1); \
    auto z3  = _mm256_mul_ps(s0, w1);          \
    auto z4  = _mm256_mul_ps(s1, w1);          \
    auto z5  = _mm256_mul_ps(s2, w1);          \
    auto w2  = BROAD_LOAD(weight + 0 * 4 + 2); \
    auto z6  = _mm256_mul_ps(s0, w2);          \
    auto z7  = _mm256_mul_ps(s1, w2);          \
    auto z8  = _mm256_mul_ps(s2, w2);          \
    auto w3  = BROAD_LOAD(weight + 0 * 4 + 3); \
    auto z9  = _mm256_mul_ps(s0, w3);          \
    auto z10 = _mm256_mul_ps(s1, w3);          \
    auto z11 = _mm256_mul_ps(s2, w3);

#define COMPUTE_24_4                           \
    s0  = LOAD8(A + sy * 24);                  \
    s1  = LOAD8(A + sy * 24 + 8);              \
    s2  = LOAD8(A + sy * 24 + 16);             \
    w0  = BROAD_LOAD(weight + sy * 4 + 0);     \
    z0  = MNNAVXFMA(s0, w0, z0);               \
    z1  = MNNAVXFMA(s1, w0, z1);               \
    z2  = MNNAVXFMA(s2, w0, z2);               \
    w1  = BROAD_LOAD(weight + sy * 4 + 1);     \
    z3  = MNNAVXFMA(s0, w1, z3);               \
    z4  = MNNAVXFMA(s1, w1, z4);               \
    z5  = MNNAVXFMA(s2, w1, z5);               \
    w2  = BROAD_LOAD(weight + sy * 4 + 2);     \
    z6  = MNNAVXFMA(s0, w2, z6);               \
    z7  = MNNAVXFMA(s1, w2, z7);               \
    z8  = MNNAVXFMA(s2, w2, z8);               \
    w3  = BROAD_LOAD(weight + sy * 4 + 3);     \
    z9  = MNNAVXFMA(s0, w3, z9);               \
    z10 = MNNAVXFMA(s1, w3, z10);              \
    z11 = MNNAVXFMA(s2, w3, z11);

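// Main packed matmul kernel for e = 24. The parameter[] slots are used as
// read below (inferred from this file, not a documented contract):
// parameter[1] = l (reduction depth), parameter[2] = h (output channels),
// parameter[3] = C row stride in bytes, parameter[5] = extra B stride in
// bytes. Channels are processed in quads (hC4 groups); two consecutive quads
// share one 8-float packed C row, hence the (y / 2) / (y % 2) addressing.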
template <typename TYPE>
static void _AVX_MNNPackedMatMul_Main(TYPE* C, const TYPE* A, const TYPE* B, const size_t* parameter) {
    auto h            = parameter[2];
    auto l            = parameter[1];
    auto cStride      = parameter[3] / sizeof(TYPE);
    auto bExtraStride = parameter[5] / sizeof(TYPE);
    auto bStride      = bExtraStride + l * 4;
    auto hC4          = UP_DIV(h, 4);
    for (int y = 0; y < hC4; ++y) {
        auto weight = B + y * bStride;
        auto dst    = C + (y / 2) * cStride + 4 * (y % 2);
        INIT_MAIN_24_4;

        for (int sy = 1; sy < l; ++sy) {
            COMPUTE_24_4;
        }
        TRANPOSE_SAVE(0, 0, z0, z3, z6, z9);
        TRANPOSE_SAVE(1, 0, z0, z3, z6, z9);
        TRANPOSE_SAVE(0, 1, z1, z4, z7, z10);
        TRANPOSE_SAVE(1, 1, z1, z4, z7, z10);
        TRANPOSE_SAVE(0, 2, z2, z5, z8, z11);
        TRANPOSE_SAVE(1, 2, z2, z5, z8, z11);
    }
}

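// EXPAND_128 duplicates a 4-float __m128 into both 128-bit lanes of a __m256,
// so the e = 20 kernel can treat its 4-element tail like a full 8-wide load.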
#define EXPAND_128(x) _mm256_castsi256_ps(_mm256_broadcastsi128_si256(_mm_castps_si128((x))))
//
#define INIT_MAIN_20_4                                  \
    auto s0  = LOAD8(A + 0 * aStride);                  \
    auto s1  = LOAD8(A + 0 * aStride + 8);              \
    auto s2  = EXPAND_128(LOAD4(A + 0 * aStride + 16)); \
    auto w0  = BROAD_LOAD(weight + 0 * 4 + 0);          \
    auto z0  = _mm256_mul_ps(s0, w0);                   \
    auto z1  = _mm256_mul_ps(s1, w0);                   \
    auto z2  = _mm256_mul_ps(s2, w0);                   \
    auto w1  = BROAD_LOAD(weight + 0 * 4 + 1);          \
    auto z3  = _mm256_mul_ps(s0, w1);                   \
    auto z4  = _mm256_mul_ps(s1, w1);                   \
    auto z5  = _mm256_mul_ps(s2, w1);                   \
    auto w2  = BROAD_LOAD(weight + 0 * 4 + 2);          \
    auto z6  = _mm256_mul_ps(s0, w2);                   \
    auto z7  = _mm256_mul_ps(s1, w2);                   \
    auto z8  = _mm256_mul_ps(s2, w2);                   \
    auto w3  = BROAD_LOAD(weight + 0 * 4 + 3);          \
    auto z9  = _mm256_mul_ps(s0, w3);                   \
    auto z10 = _mm256_mul_ps(s1, w3);                   \
    auto z11 = _mm256_mul_ps(s2, w3);

#define COMPUTE_20_4                                \
    s0  = LOAD8(A + sy * aStride);                  \
    s1  = LOAD8(A + sy * aStride + 8);              \
    s2  = EXPAND_128(LOAD4(A + sy * aStride + 16)); \
    w0  = BROAD_LOAD(weight + sy * 4 + 0);          \
    z0  = MNNAVXFMA(s0, w0, z0);                    \
    z1  = MNNAVXFMA(s1, w0, z1);                    \
    z2  = MNNAVXFMA(s2, w0, z2);                    \
    w1  = BROAD_LOAD(weight + sy * 4 + 1);          \
    z3  = MNNAVXFMA(s0, w1, z3);                    \
    z4  = MNNAVXFMA(s1, w1, z4);                    \
    z5  = MNNAVXFMA(s2, w1, z5);                    \
    w2  = BROAD_LOAD(weight + sy * 4 + 2);          \
    z6  = MNNAVXFMA(s0, w2, z6);                    \
    z7  = MNNAVXFMA(s1, w2, z7);                    \
    z8  = MNNAVXFMA(s2, w2, z8);                    \
    w3  = BROAD_LOAD(weight + sy * 4 + 3);          \
    z9  = MNNAVXFMA(s0, w3, z9);                    \
    z10 = MNNAVXFMA(s1, w3, z10);                   \
    z11 = MNNAVXFMA(s2, w3, z11);

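// Packed matmul for e = 20: two full 8-wide source loads plus one 4-wide load
// widened via EXPAND_128. Only five 4x4 tiles are stored per channel quad;
// the sixth tile of the 24-wide layout would cover columns 20..23, which do
// not exist here, so the TRANPOSE_SAVE(1, 2, ...) call is omitted.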
template <typename TYPE>
static void _AVX_MNNPackedMatMul_20(TYPE* C, const TYPE* A, const TYPE* B, const size_t* parameter) {
    auto aStride      = parameter[0] / sizeof(TYPE);
    auto h            = parameter[2];
    auto l            = parameter[1];
    auto cStride      = parameter[3] / sizeof(TYPE);
    auto bExtraStride = parameter[5] / sizeof(TYPE);
    auto bStride      = bExtraStride + l * 4;
    auto hC4          = UP_DIV(h, 4);
    for (int y = 0; y < hC4; ++y) {
        auto weight = B + y * bStride;
        auto dst    = C + (y / 2) * cStride + 4 * (y % 2);
        INIT_MAIN_20_4;

        for (int sy = 1; sy < l; ++sy) {
            COMPUTE_20_4;
        }
        TRANPOSE_SAVE(0, 0, z0, z3, z6, z9);
        TRANPOSE_SAVE(1, 0, z0, z3, z6, z9);
        TRANPOSE_SAVE(0, 1, z1, z4, z7, z10);
        TRANPOSE_SAVE(1, 1, z1, z4, z7, z10);
        TRANPOSE_SAVE(0, 2, z2, z5, z8, z11);
    }
}

#define INIT_MAIN_16_4                         \
    auto s0  = LOAD8(A + 0 * aStride);         \
    auto s1  = LOAD8(A + 0 * aStride + 8);     \
    auto w0  = BROAD_LOAD(weight + 0 * 4 + 0); \
    auto z0  = _mm256_mul_ps(s0, w0);          \
    auto z1  = _mm256_mul_ps(s1, w0);          \
    auto w1  = BROAD_LOAD(weight + 0 * 4 + 1); \
    auto z3  = _mm256_mul_ps(s0, w1);          \
    auto z4  = _mm256_mul_ps(s1, w1);          \
    auto w2  = BROAD_LOAD(weight + 0 * 4 + 2); \
    auto z6  = _mm256_mul_ps(s0, w2);          \
    auto z7  = _mm256_mul_ps(s1, w2);          \
    auto w3  = BROAD_LOAD(weight + 0 * 4 + 3); \
    auto z9  = _mm256_mul_ps(s0, w3);          \
    auto z10 = _mm256_mul_ps(s1, w3);

#define COMPUTE_16_4                           \
    s0  = LOAD8(A + sy * aStride);             \
    s1  = LOAD8(A + sy * aStride + 8);         \
    w0  = BROAD_LOAD(weight + sy * 4 + 0);     \
    z0  = MNNAVXFMA(s0, w0, z0);               \
    z1  = MNNAVXFMA(s1, w0, z1);               \
    w1  = BROAD_LOAD(weight + sy * 4 + 1);     \
    z3  = MNNAVXFMA(s0, w1, z3);               \
    z4  = MNNAVXFMA(s1, w1, z4);               \
    w2  = BROAD_LOAD(weight + sy * 4 + 2);     \
    z6  = MNNAVXFMA(s0, w2, z6);               \
    z7  = MNNAVXFMA(s1, w2, z7);               \
    w3  = BROAD_LOAD(weight + sy * 4 + 3);     \
    z9  = MNNAVXFMA(s0, w3, z9);               \
    z10 = MNNAVXFMA(s1, w3, z10);

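// Packed matmul for e = 16: same structure with two 8-wide source loads and
// four transposed 4x4 tiles per channel quad; the s2/z2/z5/z8/z11 column of
// the wider kernels drops out entirely.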
template <typename TYPE>
static void _AVX_MNNPackedMatMul_16(TYPE* C, const TYPE* A, const TYPE* B, const size_t* parameter) {
    auto aStride      = parameter[0] / sizeof(TYPE);
    auto h            = parameter[2];
    auto l            = parameter[1];
    auto cStride      = parameter[3] / sizeof(TYPE);
    auto bExtraStride = parameter[5] / sizeof(TYPE);
    auto bStride      = bExtraStride + l * 4;
    auto hC4          = UP_DIV(h, 4);
    for (int y = 0; y < hC4; ++y) {
        auto weight = B + y * bStride;
        auto dst    = C + (y / 2) * cStride + 4 * (y % 2);
        INIT_MAIN_16_4;

        for (int sy = 1; sy < l; ++sy) {
            COMPUTE_16_4;
        }
        TRANPOSE_SAVE(0, 0, z0, z3, z6, z9);
        TRANPOSE_SAVE(1, 0, z0, z3, z6, z9);
        TRANPOSE_SAVE(0, 1, z1, z4, z7, z10);
        TRANPOSE_SAVE(1, 1, z1, z4, z7, z10);
    }
}

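// DST_ADDR_UNPACK4 computes the four destination pointers for a block of four
// consecutive channel quads: quads 0 and 1 share one packed C row (offsets +0
// and +4), quads 2 and 3 the next row.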
#define DST_ADDR_UNPACK4(x)                                      \
    auto dst0 = C + (hC4Unit * y / 2 + 0) * cStride + x * 8;     \
    auto dst1 = C + (hC4Unit * y / 2 + 0) * cStride + x * 8 + 4; \
    auto dst2 = C + (hC4Unit * y / 2 + 1) * cStride + x * 8;     \
    auto dst3 = C + (hC4Unit * y / 2 + 1) * cStride + x * 8 + 4;

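// Packed matmul for e = 5. Each of the five source values is broadcast across
// a ymm register, while the weight quads of channel pairs 0/1 and 2/3 are
// packed into the low/high lanes of W0 and W1, so a single FMA updates eight
// output channels. The trailing loop finishes the hC4 quads left over after
// the blocks of four using 128-bit arithmetic.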
template <typename TYPE>
static void _AVX_MNNPackedMatMul_5(TYPE* C, const TYPE* A, const TYPE* B, const size_t* parameter) {
    auto aStride      = parameter[0] / sizeof(TYPE);
    auto h            = parameter[2];
    auto l            = parameter[1];
    auto cStride      = parameter[3] / sizeof(TYPE);
    auto bExtraStride = parameter[5] / sizeof(TYPE);
    auto bStride      = bExtraStride + l * 4;
    auto hC4          = UP_DIV(h, 4);
    const int hC4Unit = 4;
    int hC16 = hC4 / hC4Unit;
    int hR = hC16 * hC4Unit;
    auto src = A;
    for (int y = 0; y < hC16; ++y) {
        auto weight0 = B + (hC4Unit * y + 0) * bStride;
        auto weight1 = B + (hC4Unit * y + 1) * bStride;
        auto weight2 = B + (hC4Unit * y + 2) * bStride;
        auto weight3 = B + (hC4Unit * y + 3) * bStride;
        DST_ADDR_UNPACK4(0);
        auto sumAvx00 = _mm256_setzero_ps();
        auto sumAvx01 = _mm256_setzero_ps();

        auto sumAvx10 = _mm256_setzero_ps();
        auto sumAvx11 = _mm256_setzero_ps();

        auto sumAvx20 = _mm256_setzero_ps();
        auto sumAvx21 = _mm256_setzero_ps();

        auto sumAvx30 = _mm256_setzero_ps();
        auto sumAvx31 = _mm256_setzero_ps();

        auto sumAvx40 = _mm256_setzero_ps();
        auto sumAvx41 = _mm256_setzero_ps();

        auto srcUse = src;
        for (int sy = 0; sy < l; ++sy) {
            auto S0 = BROAD_LOAD(srcUse + 0);
            auto S1 = BROAD_LOAD(srcUse + 1);
            auto S2 = BROAD_LOAD(srcUse + 2);
            auto S3 = BROAD_LOAD(srcUse + 3);
            auto S4 = BROAD_LOAD(srcUse + 4);
            auto W0 = _mm256_castsi256_ps(_mm256_insertf128_si256(mm256_broadcastsi128_si256(weight0), mm_loadu_si128(weight1), 1));
            auto W1 = _mm256_castsi256_ps(_mm256_insertf128_si256(mm256_broadcastsi128_si256(weight2), mm_loadu_si128(weight3), 1));

            sumAvx00 = MNNAVXFMA(S0, W0, sumAvx00);
            sumAvx01 = MNNAVXFMA(S0, W1, sumAvx01);

            sumAvx10 = MNNAVXFMA(S1, W0, sumAvx10);
            sumAvx11 = MNNAVXFMA(S1, W1, sumAvx11);

            sumAvx20 = MNNAVXFMA(S2, W0, sumAvx20);
            sumAvx21 = MNNAVXFMA(S2, W1, sumAvx21);

            sumAvx30 = MNNAVXFMA(S3, W0, sumAvx30);
            sumAvx31 = MNNAVXFMA(S3, W1, sumAvx31);

            sumAvx40 = MNNAVXFMA(S4, W0, sumAvx40);
            sumAvx41 = MNNAVXFMA(S4, W1, sumAvx41);

            srcUse += aStride;
            weight0 += 4;
            weight1 += 4;
            weight2 += 4;
            weight3 += 4;
        }
        STORE_8(dst0, sumAvx00);
        STORE_8(dst0 + 8, sumAvx10);
        STORE_8(dst0 + 16, sumAvx20);
        STORE_8(dst0 + 24, sumAvx30);
        STORE_8(dst0 + 32, sumAvx40);

        STORE_8(dst2, sumAvx01);
        STORE_8(dst2 + 8, sumAvx11);
        STORE_8(dst2 + 16, sumAvx21);
        STORE_8(dst2 + 24, sumAvx31);
        STORE_8(dst2 + 32, sumAvx41);
    }
    for (int y = hR; y < hC4; ++y) {
        auto weight = B + y * bStride;
        auto dst    = C + (y / 2) * cStride + 4 * (y % 2);
        auto s0     = BROAD_LOAD_4(A + 0 * aStride + 0);
        auto s1     = BROAD_LOAD_4(A + 0 * aStride + 1);
        auto s2     = BROAD_LOAD_4(A + 0 * aStride + 2);
        auto s3     = BROAD_LOAD_4(A + 0 * aStride + 3);
        auto s4     = BROAD_LOAD_4(A + 0 * aStride + 4);
        auto w0     = LOAD4(weight + 0 * 4);
        auto z0     = _mm_mul_ps(s0, w0);
        auto z1     = _mm_mul_ps(s1, w0);
        auto z2     = _mm_mul_ps(s2, w0);
        auto z3     = _mm_mul_ps(s3, w0);
        auto z4     = _mm_mul_ps(s4, w0);

        for (int sy = 1; sy < l; ++sy) {
            s0 = BROAD_LOAD_4(A + sy * aStride + 0);
            s1 = BROAD_LOAD_4(A + sy * aStride + 1);
            s2 = BROAD_LOAD_4(A + sy * aStride + 2);
            s3 = BROAD_LOAD_4(A + sy * aStride + 3);
            s4 = BROAD_LOAD_4(A + sy * aStride + 4);
            w0 = LOAD4(weight + sy * 4);
            z0 = MNNSSEFMA(s0, w0, z0);
            z1 = MNNSSEFMA(s1, w0, z1);
            z2 = MNNSSEFMA(s2, w0, z2);
            z3 = MNNSSEFMA(s3, w0, z3);
            z4 = MNNSSEFMA(s4, w0, z4);
        }
        STORE_4(dst + 8 * 0, z0);
        STORE_4(dst + 8 * 1, z1);
        STORE_4(dst + 8 * 2, z2);
        STORE_4(dst + 8 * 3, z3);
        STORE_4(dst + 8 * 4, z4);
    }
}

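// Packed matmul for e = 3: the e = 5 scheme with three broadcast source
// values. Results are written back per 128-bit lane: the low lane of each
// accumulator belongs to the first quad of the pair, the high lane to the
// second.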
template <typename TYPE>
static void _AVX_MNNPackedMatMul_3(TYPE* C, const TYPE* A, const TYPE* B, const size_t* parameter) {
    auto aStride      = parameter[0] / sizeof(TYPE);
    auto h            = parameter[2];
    auto l            = parameter[1];
    auto cStride      = parameter[3] / sizeof(TYPE);
    auto bExtraStride = parameter[5] / sizeof(TYPE);
    auto bStride      = bExtraStride + l * 4;
    auto hC4          = UP_DIV(h, 4);
    const int hC4Unit = 4;
    int hC16 = hC4 / hC4Unit;
    int hR = hC16 * hC4Unit;
    auto src = A;
    for (int y = 0; y < hC16; ++y) {
        auto weight0 = B + (hC4Unit * y + 0) * bStride;
        auto weight1 = B + (hC4Unit * y + 1) * bStride;
        auto weight2 = B + (hC4Unit * y + 2) * bStride;
        auto weight3 = B + (hC4Unit * y + 3) * bStride;
        auto sumAvx00 = _mm256_setzero_ps();
        auto sumAvx01 = _mm256_setzero_ps();

        auto sumAvx10 = _mm256_setzero_ps();
        auto sumAvx11 = _mm256_setzero_ps();

        auto sumAvx20 = _mm256_setzero_ps();
        auto sumAvx21 = _mm256_setzero_ps();

        DST_ADDR_UNPACK4(0);

        auto srcUse = src;
        for (int sy = 0; sy < l; ++sy) {
            auto S0 = BROAD_LOAD(srcUse + 0);
            auto S1 = BROAD_LOAD(srcUse + 1);
            auto S2 = BROAD_LOAD(srcUse + 2);
            auto W0 = _mm256_castsi256_ps(_mm256_insertf128_si256(mm256_broadcastsi128_si256(weight0), mm_loadu_si128(weight1), 1));
            auto W1 = _mm256_castsi256_ps(_mm256_insertf128_si256(mm256_broadcastsi128_si256(weight2), mm_loadu_si128(weight3), 1));

            sumAvx00 = MNNAVXFMA(S0, W0, sumAvx00);
            sumAvx01 = MNNAVXFMA(S0, W1, sumAvx01);

            sumAvx10 = MNNAVXFMA(S1, W0, sumAvx10);
            sumAvx11 = MNNAVXFMA(S1, W1, sumAvx11);

            sumAvx20 = MNNAVXFMA(S2, W0, sumAvx20);
            sumAvx21 = MNNAVXFMA(S2, W1, sumAvx21);

            srcUse += aStride;
            weight0 += 4;
            weight1 += 4;
            weight2 += 4;
            weight3 += 4;
        }
        STORE_4(dst0 + 0, _mm256_extractf128_ps(sumAvx00, 0));
        STORE_4(dst0 + 8, _mm256_extractf128_ps(sumAvx10, 0));
        STORE_4(dst0 + 16, _mm256_extractf128_ps(sumAvx20, 0));

        STORE_4(dst1 + 0, _mm256_extractf128_ps(sumAvx00, 1));
        STORE_4(dst1 + 8, _mm256_extractf128_ps(sumAvx10, 1));
        STORE_4(dst1 + 16, _mm256_extractf128_ps(sumAvx20, 1));

        STORE_4(dst2 + 0, _mm256_extractf128_ps(sumAvx01, 0));
        STORE_4(dst2 + 8, _mm256_extractf128_ps(sumAvx11, 0));
        STORE_4(dst2 + 16, _mm256_extractf128_ps(sumAvx21, 0));

        STORE_4(dst3 + 0, _mm256_extractf128_ps(sumAvx01, 1));
        STORE_4(dst3 + 8, _mm256_extractf128_ps(sumAvx11, 1));
        STORE_4(dst3 + 16, _mm256_extractf128_ps(sumAvx21, 1));
    }
    for (int y = hR; y < hC4; ++y) {
        auto weight = B + y * bStride;
        auto dst    = C + (y / 2) * cStride + 4 * (y % 2);
        auto s0     = BROAD_LOAD_4(A + 0 * aStride + 0);
        auto s1     = BROAD_LOAD_4(A + 0 * aStride + 1);
        auto s2     = BROAD_LOAD_4(A + 0 * aStride + 2);
        auto w0     = LOAD4(weight + 0 * 4);
        auto z0     = _mm_mul_ps(s0, w0);
        auto z1     = _mm_mul_ps(s1, w0);
        auto z2     = _mm_mul_ps(s2, w0);

        for (int sy = 1; sy < l; ++sy) {
            s0 = BROAD_LOAD_4(A + sy * aStride + 0);
            s1 = BROAD_LOAD_4(A + sy * aStride + 1);
            s2 = BROAD_LOAD_4(A + sy * aStride + 2);
            w0 = LOAD4(weight + sy * 4);
            z0 = MNNSSEFMA(s0, w0, z0);
            z1 = MNNSSEFMA(s1, w0, z1);
            z2 = MNNSSEFMA(s2, w0, z2);
        }
        STORE_4(dst + 8 * 0, z0);
        STORE_4(dst + 8 * 1, z1);
        STORE_4(dst + 8 * 2, z2);
    }
}
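// Packed matmul for e = 2: two broadcast source values against the same
// lane-packed weight layout, again storing each 128-bit lane to its own
// channel-quad destination.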
template <typename TYPE>
static void _AVX_MNNPackedMatMul_2(TYPE* C, const TYPE* A, const TYPE* B, const size_t* parameter) {
    auto aStride      = parameter[0] / sizeof(TYPE);
    auto h            = parameter[2];
    auto l            = parameter[1];
    auto cStride      = parameter[3] / sizeof(TYPE);
    auto bExtraStride = parameter[5] / sizeof(TYPE);
    auto bStride      = bExtraStride + l * 4;
    auto hC4          = UP_DIV(h, 4);
    const int hC4Unit = 4;
    int hC16 = hC4 / hC4Unit;
    int hR = hC16 * hC4Unit;
    auto src = A;
    for (int y = 0; y < hC16; ++y) {
        auto weight0 = B + (hC4Unit * y + 0) * bStride;
        auto weight1 = B + (hC4Unit * y + 1) * bStride;
        auto weight2 = B + (hC4Unit * y + 2) * bStride;
        auto weight3 = B + (hC4Unit * y + 3) * bStride;
        auto sumAvx00 = _mm256_setzero_ps();
        auto sumAvx01 = _mm256_setzero_ps();
        DST_ADDR_UNPACK4(0);

        auto sumAvx10 = _mm256_setzero_ps();
        auto sumAvx11 = _mm256_setzero_ps();

        auto srcUse = src;
        for (int sy = 0; sy < l; ++sy) {
            auto S0 = BROAD_LOAD(srcUse + 0);
            auto S1 = BROAD_LOAD(srcUse + 1);
            auto W0 = _mm256_castsi256_ps(_mm256_insertf128_si256(mm256_broadcastsi128_si256(weight0), mm_loadu_si128(weight1), 1));
            auto W1 = _mm256_castsi256_ps(_mm256_insertf128_si256(mm256_broadcastsi128_si256(weight2), mm_loadu_si128(weight3), 1));

            sumAvx00 = MNNAVXFMA(S0, W0, sumAvx00);
            sumAvx01 = MNNAVXFMA(S0, W1, sumAvx01);

            sumAvx10 = MNNAVXFMA(S1, W0, sumAvx10);
            sumAvx11 = MNNAVXFMA(S1, W1, sumAvx11);

            srcUse += aStride;
            weight0 += 4;
            weight1 += 4;
            weight2 += 4;
            weight3 += 4;
        }
        STORE_4(dst0 + 0, _mm256_extractf128_ps(sumAvx00, 0));
        STORE_4(dst0 + 8, _mm256_extractf128_ps(sumAvx10, 0));

        STORE_4(dst1 + 0, _mm256_extractf128_ps(sumAvx00, 1));
        STORE_4(dst1 + 8, _mm256_extractf128_ps(sumAvx10, 1));

        STORE_4(dst2 + 0, _mm256_extractf128_ps(sumAvx01, 0));
        STORE_4(dst2 + 8, _mm256_extractf128_ps(sumAvx11, 0));

        STORE_4(dst3 + 0, _mm256_extractf128_ps(sumAvx01, 1));
        STORE_4(dst3 + 8, _mm256_extractf128_ps(sumAvx11, 1));
    }
    for (int y = hR; y < hC4; ++y) {
        auto weight = B + y * bStride;
        auto dst    = C + (y / 2) * cStride + 4 * (y % 2);
        auto s0     = BROAD_LOAD_4(A + 0 * aStride + 0);
        auto s1     = BROAD_LOAD_4(A + 0 * aStride + 1);
        auto w0     = LOAD4(weight + 0 * 4);
        auto z0     = _mm_mul_ps(s0, w0);
        auto z1     = _mm_mul_ps(s1, w0);

        for (int sy = 1; sy < l; ++sy) {
            s0 = BROAD_LOAD_4(A + sy * aStride + 0);
            s1 = BROAD_LOAD_4(A + sy * aStride + 1);
            w0 = LOAD4(weight + sy * 4);
            z0 = MNNSSEFMA(s0, w0, z0);
            z1 = MNNSSEFMA(s1, w0, z1);
        }
        STORE_4(dst + 8 * 0, z0);
        STORE_4(dst + 8 * 1, z1);
    }
}

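// Packed matmul for e = 4. The blocked main loop mirrors the e = 5 kernel
// with one fewer source broadcast. The 128-bit tail instead keeps all four
// source values in one xmm, broadcasts each weight, and transposes the 4x4
// product in registers before storing.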
template <typename TYPE>
static void _AVX_MNNPackedMatMul_4(TYPE* C, const TYPE* A, const TYPE* B, const size_t* parameter) {
    auto aStride      = parameter[0] / sizeof(TYPE);
    auto h            = parameter[2];
    auto l            = parameter[1];
    auto cStride      = parameter[3] / sizeof(TYPE);
    auto bExtraStride = parameter[5] / sizeof(TYPE);
    auto bStride      = bExtraStride + l * 4;
    auto hC4          = UP_DIV(h, 4);
    const int hC4Unit = 4;
    int hC16 = hC4 / hC4Unit;
    int hR = hC16 * hC4Unit;
    auto src = A;
    for (int y = 0; y < hC16; ++y) {
        auto weight0 = B + (hC4Unit * y + 0) * bStride;
        auto weight1 = B + (hC4Unit * y + 1) * bStride;
        auto weight2 = B + (hC4Unit * y + 2) * bStride;
        auto weight3 = B + (hC4Unit * y + 3) * bStride;
        DST_ADDR_UNPACK4(0);

        auto sumAvx00 = _mm256_setzero_ps();
        auto sumAvx01 = _mm256_setzero_ps();

        auto sumAvx10 = _mm256_setzero_ps();
        auto sumAvx11 = _mm256_setzero_ps();

        auto sumAvx20 = _mm256_setzero_ps();
        auto sumAvx21 = _mm256_setzero_ps();

        auto sumAvx30 = _mm256_setzero_ps();
        auto sumAvx31 = _mm256_setzero_ps();

        auto srcUse = src;
        for (int sy = 0; sy < l; ++sy) {
            auto S0 = BROAD_LOAD(srcUse + 0);
            auto S1 = BROAD_LOAD(srcUse + 1);
            auto S2 = BROAD_LOAD(srcUse + 2);
            auto S3 = BROAD_LOAD(srcUse + 3);
            auto W0 = _mm256_castsi256_ps(_mm256_insertf128_si256(mm256_broadcastsi128_si256(weight0), mm_loadu_si128(weight1), 1));
            auto W1 = _mm256_castsi256_ps(_mm256_insertf128_si256(mm256_broadcastsi128_si256(weight2), mm_loadu_si128(weight3), 1));

            sumAvx00 = MNNAVXFMA(S0, W0, sumAvx00);
            sumAvx01 = MNNAVXFMA(S0, W1, sumAvx01);

            sumAvx10 = MNNAVXFMA(S1, W0, sumAvx10);
            sumAvx11 = MNNAVXFMA(S1, W1, sumAvx11);

            sumAvx20 = MNNAVXFMA(S2, W0, sumAvx20);
            sumAvx21 = MNNAVXFMA(S2, W1, sumAvx21);

            sumAvx30 = MNNAVXFMA(S3, W0, sumAvx30);
            sumAvx31 = MNNAVXFMA(S3, W1, sumAvx31);

            srcUse += aStride;
            weight0 += 4;
            weight1 += 4;
            weight2 += 4;
            weight3 += 4;
        }
        STORE_8(dst0, sumAvx00);
        STORE_8(dst0 + 8, sumAvx10);
        STORE_8(dst0 + 16, sumAvx20);
        STORE_8(dst0 + 24, sumAvx30);

        STORE_8(dst2, sumAvx01);
        STORE_8(dst2 + 8, sumAvx11);
        STORE_8(dst2 + 16, sumAvx21);
        STORE_8(dst2 + 24, sumAvx31);
    }
    for (int y = hR; y < hC4; ++y) {
        auto weight = B + y * bStride;
        auto dst    = C + (y / 2) * cStride + 4 * (y % 2);
        auto s0     = LOAD4(A + 0 * aStride);
        auto w0     = BROAD_LOAD_4(weight + 0 * 4 + 0);
        auto w1     = BROAD_LOAD_4(weight + 0 * 4 + 1);
        auto w2     = BROAD_LOAD_4(weight + 0 * 4 + 2);
        auto w3     = BROAD_LOAD_4(weight + 0 * 4 + 3);
        auto z0     = _mm_mul_ps(s0, w0);
        auto z3     = _mm_mul_ps(s0, w1);
        auto z6     = _mm_mul_ps(s0, w2);
        auto z9     = _mm_mul_ps(s0, w3);

        for (int sy = 1; sy < l; ++sy) {
            s0 = LOAD4(A + sy * aStride);
            w0 = BROAD_LOAD_4(weight + sy * 4 + 0);
            w1 = BROAD_LOAD_4(weight + sy * 4 + 1);
            w2 = BROAD_LOAD_4(weight + sy * 4 + 2);
            w3 = BROAD_LOAD_4(weight + sy * 4 + 3);
            z0 = MNNSSEFMA(s0, w0, z0);
            z3 = MNNSSEFMA(s0, w1, z3);
            z6 = MNNSSEFMA(s0, w2, z6);
            z9 = MNNSSEFMA(s0, w3, z9);
        }
        _MM_TRANSPOSE4_PS(z0, z3, z6, z9);
        STORE_4(dst + 8 * 0, z0);
        STORE_4(dst + 8 * 1, z3);
        STORE_4(dst + 8 * 2, z6);
        STORE_4(dst + 8 * 3, z9);
    }
}
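// Remainder dispatcher: peels eSize in chunks of 20, 16, then repeated 5s,
// then 4/3/2, and finally handles eSize == 1 inline. The e = 1 path unrolls
// the reduction by four: each step pairs two broadcast source values per ymm
// against two weight quads, then folds the lanes and finishes the l % 4 tail
// with 128-bit FMAs.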
template <typename TYPE>
static void _AVX_MNNPackednMatMulRemainCommon(TYPE* C, const TYPE* A, const TYPE* B, size_t eSize,
                                              const size_t* parameter) {
    auto h            = parameter[2];
    auto l            = parameter[1];
    auto cStride      = parameter[3] / sizeof(TYPE);
    auto bExtraStride = parameter[5] / sizeof(TYPE);
    auto bStride      = bExtraStride + l * 4;
    auto hC4          = UP_DIV(h, 4);
    auto aStride      = parameter[0] / sizeof(TYPE);
    if (eSize >= 20) {
        _AVX_MNNPackedMatMul_20<TYPE>(C, A, B, parameter);
        eSize -= 20;
        C += 20 * 8;
        A += 20;
    }
    if (eSize >= 16) {
        _AVX_MNNPackedMatMul_16<TYPE>(C, A, B, parameter);
        eSize -= 16;
        C += 16 * 8;
        A += 16;
    }
    while (eSize >= 5) {
        _AVX_MNNPackedMatMul_5<TYPE>(C, A, B, parameter);
        eSize -= 5;
        C += 5 * 8;
        A += 5;
    }
    if (eSize == 4) {
        _AVX_MNNPackedMatMul_4<TYPE>(C, A, B, parameter);
        return;
    }
    if (eSize == 3) {
        _AVX_MNNPackedMatMul_3<TYPE>(C, A, B, parameter);
        return;
    }
    if (eSize == 2) {
        _AVX_MNNPackedMatMul_2<TYPE>(C, A, B, parameter);
        return;
    }
    if (eSize == 0) {
        return;
    }
    // eSize == 1 from here on.
    int lC4 = l / 4;
    int lR = lC4 * 4;
    const int hC4Unit = 4;
    int hC16 = hC4 / hC4Unit;
    int hR = hC16 * hC4Unit;
    auto src = A;
    int x = 0;
    for (int y = 0; y < hC16; ++y) {
        auto weight0 = B + (hC4Unit * y + 0) * bStride;
        auto weight1 = B + (hC4Unit * y + 1) * bStride;
        auto weight2 = B + (hC4Unit * y + 2) * bStride;
        auto weight3 = B + (hC4Unit * y + 3) * bStride;
        auto dst0    = C + (hC4Unit * y / 2 + 0) * cStride + x * 8;
        auto dst1    = C + (hC4Unit * y / 2 + 0) * cStride + x * 8 + 4;
        auto dst2    = C + (hC4Unit * y / 2 + 1) * cStride + x * 8;
        auto dst3    = C + (hC4Unit * y / 2 + 1) * cStride + x * 8 + 4;
        auto sumAvx00 = _mm256_setzero_ps();
        auto sumAvx01 = _mm256_setzero_ps();

        auto sumAvx10 = _mm256_setzero_ps();
        auto sumAvx11 = _mm256_setzero_ps();

        auto sumAvx20 = _mm256_setzero_ps();
        auto sumAvx21 = _mm256_setzero_ps();

        auto sumAvx30 = _mm256_setzero_ps();
        auto sumAvx31 = _mm256_setzero_ps();

        auto srcUse = src;
        for (int sy = 0; sy < lC4; ++sy) {
            auto s0 = _mm256_castps_si256(BROAD_LOAD(srcUse + (0) * aStride));
            auto s1 = _mm_castps_si128(BROAD_LOAD_4(srcUse + (1) * aStride));
            auto S0 = _mm256_castsi256_ps(_mm256_insertf128_si256(s0, s1, 1));
            auto d0 = _mm256_castps_si256(BROAD_LOAD(srcUse + (2) * aStride));
            auto d1 = _mm_castps_si128(BROAD_LOAD_4(srcUse + (3) * aStride));
            auto S1 = _mm256_castsi256_ps(_mm256_insertf128_si256(d0, d1, 1));
            auto W00 = LOAD8(weight0 + 16 * sy + 0);
            auto W01 = LOAD8(weight0 + 16 * sy + 8);
            auto W10 = LOAD8(weight1 + 16 * sy + 0);
            auto W11 = LOAD8(weight1 + 16 * sy + 8);

            auto W20 = LOAD8(weight2 + 16 * sy + 0);
            auto W21 = LOAD8(weight2 + 16 * sy + 8);
            auto W30 = LOAD8(weight3 + 16 * sy + 0);
            auto W31 = LOAD8(weight3 + 16 * sy + 8);

            sumAvx00 = MNNAVXFMA(S0, W00, sumAvx00);
            sumAvx01 = MNNAVXFMA(S1, W01, sumAvx01);

            sumAvx10 = MNNAVXFMA(S0, W10, sumAvx10);
            sumAvx11 = MNNAVXFMA(S1, W11, sumAvx11);

            sumAvx20 = MNNAVXFMA(S0, W20, sumAvx20);
            sumAvx21 = MNNAVXFMA(S1, W21, sumAvx21);

            sumAvx30 = MNNAVXFMA(S0, W30, sumAvx30);
            sumAvx31 = MNNAVXFMA(S1, W31, sumAvx31);
            srcUse += 4 * aStride;
        }
        sumAvx00 = _mm256_add_ps(sumAvx00, sumAvx01);
        sumAvx10 = _mm256_add_ps(sumAvx10, sumAvx11);
        sumAvx20 = _mm256_add_ps(sumAvx20, sumAvx21);
        sumAvx30 = _mm256_add_ps(sumAvx30, sumAvx31);
        auto sum00 = _mm256_extractf128_ps(sumAvx00, 0);
        auto sum01 = _mm256_extractf128_ps(sumAvx00, 1);
        auto sum0 = _mm_add_ps(sum00, sum01);
        auto sum10 = _mm256_extractf128_ps(sumAvx10, 0);
        auto sum11 = _mm256_extractf128_ps(sumAvx10, 1);
        auto sum1 = _mm_add_ps(sum10, sum11);

        auto sum20 = _mm256_extractf128_ps(sumAvx20, 0);
        auto sum21 = _mm256_extractf128_ps(sumAvx20, 1);
        auto sum2 = _mm_add_ps(sum20, sum21);
        auto sum30 = _mm256_extractf128_ps(sumAvx30, 0);
        auto sum31 = _mm256_extractf128_ps(sumAvx30, 1);
        auto sum3 = _mm_add_ps(sum30, sum31);
        for (int sy = lR; sy < l; ++sy) {
            auto s = BROAD_LOAD_4(srcUse);
            auto w0 = LOAD4(weight0 + 4 * sy);
            auto w1 = LOAD4(weight1 + 4 * sy);
            auto w2 = LOAD4(weight2 + 4 * sy);
            auto w3 = LOAD4(weight3 + 4 * sy);
            sum0    = MNNSSEFMA(s, w0, sum0);
            sum1    = MNNSSEFMA(s, w1, sum1);
            sum2    = MNNSSEFMA(s, w2, sum2);
            sum3    = MNNSSEFMA(s, w3, sum3);
            srcUse += aStride;
        }
        STORE_4(dst0, sum0);
        STORE_4(dst1, sum1);
        STORE_4(dst2, sum2);
        STORE_4(dst3, sum3);
    }
    for (int y = hR; y < hC4; ++y) {
        auto weight = B + y * bStride;
        auto dst    = C + (y / 2) * cStride + x * 8 + 4 * (y % 2);
        auto sumAvx0 = _mm256_setzero_ps();
        auto sumAvx1 = _mm256_setzero_ps();
        auto srcUse = src;
        for (int sy = 0; sy < lC4; ++sy) {
            auto s0 = _mm256_castps_si256(BROAD_LOAD(srcUse + (0) * aStride));
            auto s1 = _mm_castps_si128(BROAD_LOAD_4(srcUse + (1) * aStride));
            auto S0 = _mm256_castsi256_ps(_mm256_insertf128_si256(s0, s1, 1));
            auto d0 = _mm256_castps_si256(BROAD_LOAD(srcUse + (2) * aStride));
            auto d1 = _mm_castps_si128(BROAD_LOAD_4(srcUse + (3) * aStride));
            auto S1 = _mm256_castsi256_ps(_mm256_insertf128_si256(d0, d1, 1));
            auto W0 = LOAD8(weight + 16 * sy + 0);
            auto W1 = LOAD8(weight + 16 * sy + 8);
            sumAvx0 = MNNAVXFMA(S0, W0, sumAvx0);
            sumAvx1 = MNNAVXFMA(S1, W1, sumAvx1);
            srcUse += 4 * aStride;
        }
        sumAvx0 = _mm256_add_ps(sumAvx0, sumAvx1);
        auto sum0 = _mm256_extractf128_ps(sumAvx0, 0);
        auto sum1 = _mm256_extractf128_ps(sumAvx0, 1);
        auto sum = _mm_add_ps(sum0, sum1);
        for (int sy = lR; sy < l; ++sy) {
            auto s = BROAD_LOAD_4(srcUse);
            auto w = LOAD4(weight + 4 * sy);
            sum    = MNNSSEFMA(s, w, sum);
            srcUse += aStride;
        }
        STORE_4(dst, sum);
    }
}
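// Usage sketch (illustrative only; the parameter[] layout below is inferred
// from the reads in this file rather than from a documented MNN contract,
// and the variable names eP, cRowBytes, and bExtraBytes are hypothetical):
//
//   size_t parameter[6];
//   parameter[0] = eP * sizeof(float); // A stride between source rows, bytes
//   parameter[1] = l;                  // reduction depth
//   parameter[2] = h;                  // output channel count
//   parameter[3] = cRowBytes;          // C stride per packed row, bytes
//   parameter[4] = 0;                  // not read by the kernels here
//   parameter[5] = bExtraBytes;        // extra B stride per channel quad, bytes
//   _AVX_MNNPackednMatMulRemainCommon<float>(C, A, B, eSize, parameter);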