1 //
2 // GemmFunction.hpp
3 // MNN
4 //
5 // Created by MNN on 2020/09/22.
6 // Copyright © 2018, Alibaba Group Holding Limited
7 //
8
// Width (in floats) of the A tile consumed by the main packed kernel.
#define MNN_UNIT_E 24
// Transpose one 4x4 float tile taken from 128-bit half `u` (0 = low, 1 = high)
// of the four accumulators z0/z3/z6/z9 and store the four rows to `dst` with a
// stride of 8 floats; `v` selects which 8-column group of `dst` is written.
// NOTE: relies on a variable named `dst` being in scope at the expansion site.
#define TRANPOSE_SAVE(u, v, z0, z3, z6, z9) \
{ \
auto m0 = _mm256_extractf128_ps(z0, u); \
auto m1 = _mm256_extractf128_ps(z3, u); \
auto m2 = _mm256_extractf128_ps(z6, u); \
auto m3 = _mm256_extractf128_ps(z9, u); \
_MM_TRANSPOSE4_PS(m0, m1, m2, m3); \
STORE_4(dst + 8 * (0 + 4 * u + 8 * v), m0); \
STORE_4(dst + 8 * (1 + 4 * u + 8 * v), m1); \
STORE_4(dst + 8 * (2 + 4 * u + 8 * v), m2); \
STORE_4(dst + 8 * (3 + 4 * u + 8 * v), m3); \
}
22
23 namespace {
mm_loadu_si128(const void * addr)24 static inline __m128i mm_loadu_si128(const void* addr) {
25 return _mm_castps_si128(LOAD4((const float*)addr));
26 }
27
mm256_broadcastsi128_si256(const void * addr)28 static inline __m256i mm256_broadcastsi128_si256(const void* addr) {
29 return _mm256_broadcastsi128_si256(mm_loadu_si128(addr));
30 }
31 } // namespace
32 //
// First reduce step (sy == 0) of the 24x4 kernel: loads 24 A values into
// s0..s2 (A is packed with a fixed stride of 24), broadcasts the 4 weights
// w0..w3, and initializes the twelve accumulators z0..z11 with plain
// multiplies (no prior value to accumulate into yet).
// Expects `A` and `weight` pointers in scope at the expansion site.
#define INIT_MAIN_24_4 \
auto s0 = LOAD8(A + 0 * 24); \
auto s1 = LOAD8(A + 0 * 24 + 8); \
auto s2 = LOAD8(A + 0 * 24 + 16); \
auto w0 = BROAD_LOAD(weight + 0 * 4 + 0); \
auto z0 = _mm256_mul_ps(s0, w0); \
auto z1 = _mm256_mul_ps(s1, w0); \
auto z2 = _mm256_mul_ps(s2, w0); \
auto w1 = BROAD_LOAD(weight + 0 * 4 + 1); \
auto z3 = _mm256_mul_ps(s0, w1); \
auto z4 = _mm256_mul_ps(s1, w1); \
auto z5 = _mm256_mul_ps(s2, w1); \
auto w2 = BROAD_LOAD(weight + 0 * 4 + 2); \
auto z6 = _mm256_mul_ps(s0, w2); \
auto z7 = _mm256_mul_ps(s1, w2); \
auto z8 = _mm256_mul_ps(s2, w2); \
auto w3 = BROAD_LOAD(weight + 0 * 4 + 3); \
auto z9 = _mm256_mul_ps(s0, w3); \
auto z10 = _mm256_mul_ps(s1, w3); \
auto z11 = _mm256_mul_ps(s2, w3);
53
// One reduce step (sy >= 1) of the 24x4 kernel: reloads s0..s2 and w0..w3 for
// row `sy` and fma-accumulates into the accumulators z0..z11 declared by
// INIT_MAIN_24_4. Expects `A`, `weight` and loop index `sy` in scope.
#define COMPUTE_24_4 \
s0 = LOAD8(A + sy * 24); \
s1 = LOAD8(A + sy * 24 + 8); \
s2 = LOAD8(A + sy * 24 + 16); \
w0 = BROAD_LOAD(weight + sy * 4 + 0); \
z0 = MNNAVXFMA(s0, w0, z0); \
z1 = MNNAVXFMA(s1, w0, z1); \
z2 = MNNAVXFMA(s2, w0, z2); \
w1 = BROAD_LOAD(weight + sy * 4 + 1); \
z3 = MNNAVXFMA(s0, w1, z3); \
z4 = MNNAVXFMA(s1, w1, z4); \
z5 = MNNAVXFMA(s2, w1, z5); \
w2 = BROAD_LOAD(weight + sy * 4 + 2); \
z6 = MNNAVXFMA(s0, w2, z6); \
z7 = MNNAVXFMA(s1, w2, z7); \
z8 = MNNAVXFMA(s2, w2, z8); \
w3 = BROAD_LOAD(weight + sy * 4 + 3); \
z9 = MNNAVXFMA(s0, w3, z9); \
z10 = MNNAVXFMA(s1, w3, z10); \
z11 = MNNAVXFMA(s2, w3, z11);
74
// Main packed matmul kernel: C = A * B for a full tile of e = MNN_UNIT_E (24)
// columns of A. A is packed as [l][24]; B as [hC4][l][4] plus bExtraStride
// padding per channel group; C is written in 8-channel rows (two 4-packs share
// one cStride row).
// parameter: [1] = l (reduce length, elements), [2] = h (output channels),
//            [3] = C stride in bytes, [5] = extra B stride in bytes.
template <typename TYPE>
static void _AVX_MNNPackedMatMul_Main(TYPE* C, const TYPE* A, const TYPE* B, const size_t* parameter) {
    auto h = parameter[2];
    auto l = parameter[1];
    auto cStride = parameter[3] / sizeof(TYPE);
    auto bExtraStride = parameter[5] / sizeof(TYPE);
    auto bStride = bExtraStride + l * 4;
    auto hC4 = UP_DIV(h, 4);
    for (int y = 0; y < hC4; ++y) {
        auto weight = B + y * bStride;
        // Even y -> offset 0, odd y -> offset 4 within the same cStride row.
        auto dst = C + (y / 2) * cStride + 4 * (y % 2);
        INIT_MAIN_24_4; // sy == 0: initialize accumulators z0..z11

        for (int sy = 1; sy < l; ++sy) {
            COMPUTE_24_4; // fma-accumulate along the reduce dimension
        }
        // Transpose the accumulators 4x4 at a time and store all 24 columns.
        TRANPOSE_SAVE(0, 0, z0, z3, z6, z9);
        TRANPOSE_SAVE(1, 0, z0, z3, z6, z9);
        TRANPOSE_SAVE(0, 1, z1, z4, z7, z10);
        TRANPOSE_SAVE(1, 1, z1, z4, z7, z10);
        TRANPOSE_SAVE(0, 2, z2, z5, z8, z11);
        TRANPOSE_SAVE(1, 2, z2, z5, z8, z11);
    }
}
99
100
// Broadcast the low 128 bits of a float vector into both lanes of a 256-bit float vector.
#define EXPAND_128(x) _mm256_castsi256_ps(_mm256_broadcastsi128_si256(_mm_castps_si128((x))))
102 //
// First reduce step (sy == 0) of the 20x4 kernel. Like INIT_MAIN_24_4 but A
// is addressed with a runtime `aStride`, and s2 holds only 4 valid values
// (columns 16..19) broadcast to both lanes via EXPAND_128 — the upper-lane
// results of z2/z5/z8/z11 are never stored.
#define INIT_MAIN_20_4 \
auto s0 = LOAD8(A + 0 * aStride); \
auto s1 = LOAD8(A + 0 * aStride + 8); \
auto s2 = EXPAND_128(LOAD4(A + 0 * aStride + 16)); \
auto w0 = BROAD_LOAD(weight + 0 * 4 + 0); \
auto z0 = _mm256_mul_ps(s0, w0); \
auto z1 = _mm256_mul_ps(s1, w0); \
auto z2 = _mm256_mul_ps(s2, w0); \
auto w1 = BROAD_LOAD(weight + 0 * 4 + 1); \
auto z3 = _mm256_mul_ps(s0, w1); \
auto z4 = _mm256_mul_ps(s1, w1); \
auto z5 = _mm256_mul_ps(s2, w1); \
auto w2 = BROAD_LOAD(weight + 0 * 4 + 2); \
auto z6 = _mm256_mul_ps(s0, w2); \
auto z7 = _mm256_mul_ps(s1, w2); \
auto z8 = _mm256_mul_ps(s2, w2); \
auto w3 = BROAD_LOAD(weight + 0 * 4 + 3); \
auto z9 = _mm256_mul_ps(s0, w3); \
auto z10 = _mm256_mul_ps(s1, w3); \
auto z11 = _mm256_mul_ps(s2, w3);
123
// One reduce step (sy >= 1) of the 20x4 kernel: reload and fma-accumulate
// into the accumulators declared by INIT_MAIN_20_4.
#define COMPUTE_20_4 \
s0 = LOAD8(A + sy * aStride); \
s1 = LOAD8(A + sy * aStride + 8); \
s2 = EXPAND_128(LOAD4(A + sy * aStride + 16)); \
w0 = BROAD_LOAD(weight + sy * 4 + 0); \
z0 = MNNAVXFMA(s0, w0, z0); \
z1 = MNNAVXFMA(s1, w0, z1); \
z2 = MNNAVXFMA(s2, w0, z2); \
w1 = BROAD_LOAD(weight + sy * 4 + 1); \
z3 = MNNAVXFMA(s0, w1, z3); \
z4 = MNNAVXFMA(s1, w1, z4); \
z5 = MNNAVXFMA(s2, w1, z5); \
w2 = BROAD_LOAD(weight + sy * 4 + 2); \
z6 = MNNAVXFMA(s0, w2, z6); \
z7 = MNNAVXFMA(s1, w2, z7); \
z8 = MNNAVXFMA(s2, w2, z8); \
w3 = BROAD_LOAD(weight + sy * 4 + 3); \
z9 = MNNAVXFMA(s0, w3, z9); \
z10 = MNNAVXFMA(s1, w3, z10); \
z11 = MNNAVXFMA(s2, w3, z11);
144
145
// Matmul kernel for an e = 20 column tile of A. Same structure as the main
// 24-wide kernel, but A is addressed with parameter[0] as stride and only
// 20 result columns are stored: the final TRANPOSE_SAVE(1, 2, ...) of the
// 24-wide kernel is deliberately omitted (columns 20..23 do not exist).
template <typename TYPE>
static void _AVX_MNNPackedMatMul_20(TYPE* C, const TYPE* A, const TYPE* B, const size_t* parameter) {
    auto aStride = parameter[0] / sizeof(TYPE);
    auto h = parameter[2];
    auto l = parameter[1];
    auto cStride = parameter[3] / sizeof(TYPE);
    auto bExtraStride = parameter[5] / sizeof(TYPE);
    auto bStride = bExtraStride + l * 4;
    auto hC4 = UP_DIV(h, 4);
    for (int y = 0; y < hC4; ++y) {
        auto weight = B + y * bStride;
        // Even y -> offset 0, odd y -> offset 4 within the same cStride row.
        auto dst = C + (y / 2) * cStride + 4 * (y % 2);
        INIT_MAIN_20_4; // sy == 0: initialize accumulators

        for (int sy = 1; sy < l; ++sy) {
            COMPUTE_20_4; // fma-accumulate along the reduce dimension
        }
        TRANPOSE_SAVE(0, 0, z0, z3, z6, z9);
        TRANPOSE_SAVE(1, 0, z0, z3, z6, z9);
        TRANPOSE_SAVE(0, 1, z1, z4, z7, z10);
        TRANPOSE_SAVE(1, 1, z1, z4, z7, z10);
        TRANPOSE_SAVE(0, 2, z2, z5, z8, z11);
    }
}
170
// First reduce step (sy == 0) of the 16x4 kernel: 16 A values in s0/s1, four
// broadcast weights, eight accumulators (z2/z5/z8/z11 of the wider kernels
// are not needed). Accumulator names keep the 24-wide numbering so the same
// TRANPOSE_SAVE invocations apply.
#define INIT_MAIN_16_4 \
auto s0 = LOAD8(A + 0 * aStride); \
auto s1 = LOAD8(A + 0 * aStride + 8); \
auto w0 = BROAD_LOAD(weight + 0 * 4 + 0); \
auto z0 = _mm256_mul_ps(s0, w0); \
auto z1 = _mm256_mul_ps(s1, w0); \
auto w1 = BROAD_LOAD(weight + 0 * 4 + 1); \
auto z3 = _mm256_mul_ps(s0, w1); \
auto z4 = _mm256_mul_ps(s1, w1); \
auto w2 = BROAD_LOAD(weight + 0 * 4 + 2); \
auto z6 = _mm256_mul_ps(s0, w2); \
auto z7 = _mm256_mul_ps(s1, w2); \
auto w3 = BROAD_LOAD(weight + 0 * 4 + 3); \
auto z9 = _mm256_mul_ps(s0, w3); \
auto z10 = _mm256_mul_ps(s1, w3);
186
// One reduce step (sy >= 1) of the 16x4 kernel: reload and fma-accumulate
// into the accumulators declared by INIT_MAIN_16_4.
#define COMPUTE_16_4 \
s0 = LOAD8(A + sy * aStride); \
s1 = LOAD8(A + sy * aStride + 8); \
w0 = BROAD_LOAD(weight + sy * 4 + 0); \
z0 = MNNAVXFMA(s0, w0, z0); \
z1 = MNNAVXFMA(s1, w0, z1); \
w1 = BROAD_LOAD(weight + sy * 4 + 1); \
z3 = MNNAVXFMA(s0, w1, z3); \
z4 = MNNAVXFMA(s1, w1, z4); \
w2 = BROAD_LOAD(weight + sy * 4 + 2); \
z6 = MNNAVXFMA(s0, w2, z6); \
z7 = MNNAVXFMA(s1, w2, z7); \
w3 = BROAD_LOAD(weight + sy * 4 + 3); \
z9 = MNNAVXFMA(s0, w3, z9); \
z10 = MNNAVXFMA(s1, w3, z10);
202
// Matmul kernel for an e = 16 column tile of A: two 8-wide accumulator
// columns per channel group, stored as four transposed 4x4 tiles.
// parameter layout as in _AVX_MNNPackedMatMul_20.
template <typename TYPE>
static void _AVX_MNNPackedMatMul_16(TYPE* C, const TYPE* A, const TYPE* B, const size_t* parameter) {
    auto aStride = parameter[0] / sizeof(TYPE);
    auto h = parameter[2];
    auto l = parameter[1];
    auto cStride = parameter[3] / sizeof(TYPE);
    auto bExtraStride = parameter[5] / sizeof(TYPE);
    auto bStride = bExtraStride + l * 4;
    auto hC4 = UP_DIV(h, 4);
    for (int y = 0; y < hC4; ++y) {
        auto weight = B + y * bStride;
        // Even y -> offset 0, odd y -> offset 4 within the same cStride row.
        auto dst = C + (y / 2) * cStride + 4 * (y % 2);
        INIT_MAIN_16_4; // sy == 0: initialize accumulators

        for (int sy = 1; sy < l; ++sy) {
            COMPUTE_16_4; // fma-accumulate along the reduce dimension
        }
        TRANPOSE_SAVE(0, 0, z0, z3, z6, z9);
        TRANPOSE_SAVE(1, 0, z0, z3, z6, z9);
        TRANPOSE_SAVE(0, 1, z1, z4, z7, z10);
        TRANPOSE_SAVE(1, 1, z1, z4, z7, z10);
    }
}
226
// Compute the four destination pointers for a fused group of four 4-channel
// packs: dst0/dst1 are the two 4-packs of one cStride row, dst2/dst3 of the
// next. `x` is the column-tile index (dst offset in 8-float units).
// Expects `C`, `cStride`, `hC4Unit` and loop index `y` in scope.
#define DST_ADDR_UNPACK4(x)\
auto dst0 = C + (hC4Unit * y / 2 + 0) * cStride + x * 8;\
auto dst1 = C + (hC4Unit * y / 2 + 0) * cStride + x * 8 + 4;\
auto dst2 = C + (hC4Unit * y / 2 + 1) * cStride + x * 8;\
auto dst3 = C + (hC4Unit * y / 2 + 1) * cStride + x * 8 + 4;\
232
233 template <typename TYPE>
_AVX_MNNPackedMatMul_5(TYPE * C,const TYPE * A,const TYPE * B,const size_t * parameter)234 static void _AVX_MNNPackedMatMul_5(TYPE* C, const TYPE* A, const TYPE* B, const size_t* parameter) {
235 auto aStride = parameter[0] / sizeof(TYPE);
236 auto h = parameter[2];
237 auto l = parameter[1];
238 auto cStride = parameter[3] / sizeof(TYPE);
239 auto bExtraStride = parameter[5] / sizeof(TYPE);
240 auto bStride = bExtraStride + l * 4;
241 auto hC4 = UP_DIV(h, 4);
242 int lC4 = l / 4;
243 int lR = lC4 * 4;
244 const int hC4Unit = 4;
245 int hC16 = hC4 / hC4Unit;
246 int hR = hC16 * hC4Unit;
247 auto src = A;
248 for (int y = 0; y < hC16; ++y) {
249 auto weight0 = B + (hC4Unit * y + 0) * bStride;
250 auto weight1 = B + (hC4Unit * y + 1) * bStride;
251 auto weight2 = B + (hC4Unit * y + 2) * bStride;
252 auto weight3 = B + (hC4Unit * y + 3) * bStride;
253 DST_ADDR_UNPACK4(0);
254 auto sumAvx00 = _mm256_setzero_ps();
255 auto sumAvx01 = _mm256_setzero_ps();
256
257 auto sumAvx10 = _mm256_setzero_ps();
258 auto sumAvx11 = _mm256_setzero_ps();
259
260 auto sumAvx20 = _mm256_setzero_ps();
261 auto sumAvx21 = _mm256_setzero_ps();
262
263 auto sumAvx30 = _mm256_setzero_ps();
264 auto sumAvx31 = _mm256_setzero_ps();
265
266 auto sumAvx40 = _mm256_setzero_ps();
267 auto sumAvx41 = _mm256_setzero_ps();
268
269 auto srcUse = src;
270 for (int sy = 0; sy < l; ++sy) {
271 auto S0 = BROAD_LOAD(srcUse + 0);
272 auto S1 = BROAD_LOAD(srcUse + 1);
273 auto S2 = BROAD_LOAD(srcUse + 2);
274 auto S3 = BROAD_LOAD(srcUse + 3);
275 auto S4 = BROAD_LOAD(srcUse + 4);
276 auto W0 = _mm256_castsi256_ps(_mm256_insertf128_si256(mm256_broadcastsi128_si256(weight0), mm_loadu_si128(weight1), 1));
277 auto W1 = _mm256_castsi256_ps(_mm256_insertf128_si256(mm256_broadcastsi128_si256(weight2), mm_loadu_si128(weight3), 1));
278
279 sumAvx00 = MNNAVXFMA(S0, W0, sumAvx00);
280 sumAvx01 = MNNAVXFMA(S0, W1, sumAvx01);
281
282 sumAvx10 = MNNAVXFMA(S1, W0, sumAvx10);
283 sumAvx11 = MNNAVXFMA(S1, W1, sumAvx11);
284
285 sumAvx20 = MNNAVXFMA(S2, W0, sumAvx20);
286 sumAvx21 = MNNAVXFMA(S2, W1, sumAvx21);
287
288 sumAvx30 = MNNAVXFMA(S3, W0, sumAvx30);
289 sumAvx31 = MNNAVXFMA(S3, W1, sumAvx31);
290
291 sumAvx40 = MNNAVXFMA(S4, W0, sumAvx40);
292 sumAvx41 = MNNAVXFMA(S4, W1, sumAvx41);
293
294 srcUse += aStride;
295 weight0 += 4;
296 weight1 += 4;
297 weight2 += 4;
298 weight3 += 4;
299 }
300 STORE_8(dst0, sumAvx00);
301 STORE_8(dst0 + 8, sumAvx10);
302 STORE_8(dst0 + 16, sumAvx20);
303 STORE_8(dst0 + 24, sumAvx30);
304 STORE_8(dst0 + 32, sumAvx40);
305
306 STORE_8(dst2, sumAvx01);
307 STORE_8(dst2 + 8, sumAvx11);
308 STORE_8(dst2 + 16, sumAvx21);
309 STORE_8(dst2 + 24, sumAvx31);
310 STORE_8(dst2 + 32, sumAvx41);
311 }
312 for (int y = hR; y < hC4; ++y) {
313 auto weight = B + y * bStride;
314 auto dst = C + (y / 2) * cStride + 4 * (y % 2);
315 auto s0 = BROAD_LOAD_4(A + 0 * aStride + 0);
316 auto s1 = BROAD_LOAD_4(A + 0 * aStride + 1);
317 auto s2 = BROAD_LOAD_4(A + 0 * aStride + 2);
318 auto s3 = BROAD_LOAD_4(A + 0 * aStride + 3);
319 auto s4 = BROAD_LOAD_4(A + 0 * aStride + 4);
320 auto w0 = LOAD4(weight + 0 * 4);
321 auto z0 = _mm_mul_ps(s0, w0);
322 auto z1 = _mm_mul_ps(s1, w0);
323 auto z2 = _mm_mul_ps(s2, w0);
324 auto z3 = _mm_mul_ps(s3, w0);
325 auto z4 = _mm_mul_ps(s4, w0);
326
327 for (int sy = 1; sy < l; ++sy) {
328 s0 = BROAD_LOAD_4(A + sy * aStride + 0);
329 s1 = BROAD_LOAD_4(A + sy * aStride + 1);
330 s2 = BROAD_LOAD_4(A + sy * aStride + 2);
331 s3 = BROAD_LOAD_4(A + sy * aStride + 3);
332 s4 = BROAD_LOAD_4(A + sy * aStride + 4);
333 w0 = LOAD4(weight + sy * 4);
334 z0 = MNNSSEFMA(s0, w0, z0);
335 z1 = MNNSSEFMA(s1, w0, z1);
336 z2 = MNNSSEFMA(s2, w0, z2);
337 z3 = MNNSSEFMA(s3, w0, z3);
338 z4 = MNNSSEFMA(s4, w0, z4);
339 }
340 STORE_4(dst + 8 * 0, z0);
341 STORE_4(dst + 8 * 1, z1);
342 STORE_4(dst + 8 * 2, z2);
343 STORE_4(dst + 8 * 3, z3);
344 STORE_4(dst + 8 * 4, z4);
345 }
346 }
347
348 template <typename TYPE>
_AVX_MNNPackedMatMul_3(TYPE * C,const TYPE * A,const TYPE * B,const size_t * parameter)349 static void _AVX_MNNPackedMatMul_3(TYPE* C, const TYPE* A, const TYPE* B, const size_t* parameter) {
350 auto aStride = parameter[0] / sizeof(TYPE);
351 auto h = parameter[2];
352 auto l = parameter[1];
353 auto cStride = parameter[3] / sizeof(TYPE);
354 auto bExtraStride = parameter[5] / sizeof(TYPE);
355 auto bStride = bExtraStride + l * 4;
356 auto hC4 = UP_DIV(h, 4);
357 int lC4 = l / 4;
358 int lR = lC4 * 4;
359 const int hC4Unit = 4;
360 int hC16 = hC4 / hC4Unit;
361 int hR = hC16 * hC4Unit;
362 auto src = A;
363 for (int y = 0; y < hC16; ++y) {
364 auto weight0 = B + (hC4Unit * y + 0) * bStride;
365 auto weight1 = B + (hC4Unit * y + 1) * bStride;
366 auto weight2 = B + (hC4Unit * y + 2) * bStride;
367 auto weight3 = B + (hC4Unit * y + 3) * bStride;
368 auto sumAvx00 = _mm256_setzero_ps();
369 auto sumAvx01 = _mm256_setzero_ps();
370
371 auto sumAvx10 = _mm256_setzero_ps();
372 auto sumAvx11 = _mm256_setzero_ps();
373
374 auto sumAvx20 = _mm256_setzero_ps();
375 auto sumAvx21 = _mm256_setzero_ps();
376
377 DST_ADDR_UNPACK4(0);
378
379 auto srcUse = src;
380 for (int sy = 0; sy < l; ++sy) {
381 auto S0 = BROAD_LOAD(srcUse + 0);
382 auto S1 = BROAD_LOAD(srcUse + 1);
383 auto S2 = BROAD_LOAD(srcUse + 2);
384 auto W0 = _mm256_castsi256_ps(_mm256_insertf128_si256(mm256_broadcastsi128_si256(weight0), mm_loadu_si128(weight1), 1));
385 auto W1 = _mm256_castsi256_ps(_mm256_insertf128_si256(mm256_broadcastsi128_si256(weight2), mm_loadu_si128(weight3), 1));
386
387 sumAvx00 = MNNAVXFMA(S0, W0, sumAvx00);
388 sumAvx01 = MNNAVXFMA(S0, W1, sumAvx01);
389
390 sumAvx10 = MNNAVXFMA(S1, W0, sumAvx10);
391 sumAvx11 = MNNAVXFMA(S1, W1, sumAvx11);
392
393 sumAvx20 = MNNAVXFMA(S2, W0, sumAvx20);
394 sumAvx21 = MNNAVXFMA(S2, W1, sumAvx21);
395
396 srcUse += aStride;
397 weight0 += 4;
398 weight1 += 4;
399 weight2 += 4;
400 weight3 += 4;
401 }
402 STORE_4(dst0 + 0, _mm256_extractf128_ps(sumAvx00, 0));
403 STORE_4(dst0 + 8, _mm256_extractf128_ps(sumAvx10, 0));
404 STORE_4(dst0 + 16, _mm256_extractf128_ps(sumAvx20, 0));
405
406 STORE_4(dst1 + 0, _mm256_extractf128_ps(sumAvx00, 1));
407 STORE_4(dst1 + 8, _mm256_extractf128_ps(sumAvx10, 1));
408 STORE_4(dst1 + 16, _mm256_extractf128_ps(sumAvx20, 1));
409
410 STORE_4(dst2 + 0, _mm256_extractf128_ps(sumAvx01, 0));
411 STORE_4(dst2 + 8, _mm256_extractf128_ps(sumAvx11, 0));
412 STORE_4(dst2 + 16, _mm256_extractf128_ps(sumAvx21, 0));
413
414 STORE_4(dst3 + 0, _mm256_extractf128_ps(sumAvx01, 1));
415 STORE_4(dst3 + 8, _mm256_extractf128_ps(sumAvx11, 1));
416 STORE_4(dst3 + 16, _mm256_extractf128_ps(sumAvx21, 1));
417
418 }
419 for (int y = hR; y < hC4; ++y) {
420 auto weight = B + y * bStride;
421 auto dst = C + (y / 2) * cStride + 4 * (y % 2);
422 auto s0 = BROAD_LOAD_4(A + 0 * aStride + 0);
423 auto s1 = BROAD_LOAD_4(A + 0 * aStride + 1);
424 auto s2 = BROAD_LOAD_4(A + 0 * aStride + 2);
425 auto w0 = LOAD4(weight + 0 * 4);
426 auto z0 = _mm_mul_ps(s0, w0);
427 auto z1 = _mm_mul_ps(s1, w0);
428 auto z2 = _mm_mul_ps(s2, w0);
429
430 for (int sy = 1; sy < l; ++sy) {
431 s0 = BROAD_LOAD_4(A + sy * aStride + 0);
432 s1 = BROAD_LOAD_4(A + sy * aStride + 1);
433 s2 = BROAD_LOAD_4(A + sy * aStride + 2);
434 w0 = LOAD4(weight + sy * 4);
435 z0 = MNNSSEFMA(s0, w0, z0);
436 z1 = MNNSSEFMA(s1, w0, z1);
437 z2 = MNNSSEFMA(s2, w0, z2);
438 }
439 STORE_4(dst + 8 * 0, z0);
440 STORE_4(dst + 8 * 1, z1);
441 STORE_4(dst + 8 * 2, z2);
442 }
443 }
444 template <typename TYPE>
_AVX_MNNPackedMatMul_2(TYPE * C,const TYPE * A,const TYPE * B,const size_t * parameter)445 static void _AVX_MNNPackedMatMul_2(TYPE* C, const TYPE* A, const TYPE* B, const size_t* parameter) {
446 auto aStride = parameter[0] / sizeof(TYPE);
447 auto h = parameter[2];
448 auto l = parameter[1];
449 auto cStride = parameter[3] / sizeof(TYPE);
450 auto bExtraStride = parameter[5] / sizeof(TYPE);
451 auto bStride = bExtraStride + l * 4;
452 auto hC4 = UP_DIV(h, 4);
453 int lC4 = l / 4;
454 int lR = lC4 * 4;
455 const int hC4Unit = 4;
456 int hC16 = hC4 / hC4Unit;
457 int hR = hC16 * hC4Unit;
458 auto src = A;
459 for (int y = 0; y < hC16; ++y) {
460 auto weight0 = B + (hC4Unit * y + 0) * bStride;
461 auto weight1 = B + (hC4Unit * y + 1) * bStride;
462 auto weight2 = B + (hC4Unit * y + 2) * bStride;
463 auto weight3 = B + (hC4Unit * y + 3) * bStride;
464 auto sumAvx00 = _mm256_setzero_ps();
465 auto sumAvx01 = _mm256_setzero_ps();
466 DST_ADDR_UNPACK4(0);
467
468 auto sumAvx10 = _mm256_setzero_ps();
469 auto sumAvx11 = _mm256_setzero_ps();
470
471 auto srcUse = src;
472 for (int sy = 0; sy < l; ++sy) {
473 auto S0 = BROAD_LOAD(srcUse + 0);
474 auto S1 = BROAD_LOAD(srcUse + 1);
475 auto W0 = _mm256_castsi256_ps(_mm256_insertf128_si256(mm256_broadcastsi128_si256(weight0), mm_loadu_si128(weight1), 1));
476 auto W1 = _mm256_castsi256_ps(_mm256_insertf128_si256(mm256_broadcastsi128_si256(weight2), mm_loadu_si128(weight3), 1));
477
478 sumAvx00 = MNNAVXFMA(S0, W0, sumAvx00);
479 sumAvx01 = MNNAVXFMA(S0, W1, sumAvx01);
480
481 sumAvx10 = MNNAVXFMA(S1, W0, sumAvx10);
482 sumAvx11 = MNNAVXFMA(S1, W1, sumAvx11);
483
484 srcUse += aStride;
485 weight0 += 4;
486 weight1 += 4;
487 weight2 += 4;
488 weight3 += 4;
489 }
490 STORE_4(dst0 + 0, _mm256_extractf128_ps(sumAvx00, 0));
491 STORE_4(dst0 + 8, _mm256_extractf128_ps(sumAvx10, 0));
492
493 STORE_4(dst1 + 0, _mm256_extractf128_ps(sumAvx00, 1));
494 STORE_4(dst1 + 8, _mm256_extractf128_ps(sumAvx10, 1));
495
496 STORE_4(dst2 + 0, _mm256_extractf128_ps(sumAvx01, 0));
497 STORE_4(dst2 + 8, _mm256_extractf128_ps(sumAvx11, 0));
498
499 STORE_4(dst3 + 0, _mm256_extractf128_ps(sumAvx01, 1));
500 STORE_4(dst3 + 8, _mm256_extractf128_ps(sumAvx11, 1));
501
502 }
503 for (int y = hR; y < hC4; ++y) {
504 auto weight = B + y * bStride;
505 auto dst = C + (y / 2) * cStride + 4 * (y % 2);
506 auto s0 = BROAD_LOAD_4(A + 0 * aStride + 0);
507 auto s1 = BROAD_LOAD_4(A + 0 * aStride + 1);
508 auto w0 = LOAD4(weight + 0 * 4);
509 auto z0 = _mm_mul_ps(s0, w0);
510 auto z1 = _mm_mul_ps(s1, w0);
511
512 for (int sy = 1; sy < l; ++sy) {
513 s0 = BROAD_LOAD_4(A + sy * aStride + 0);
514 s1 = BROAD_LOAD_4(A + sy * aStride + 1);
515 w0 = LOAD4(weight + sy * 4);
516 z0 = MNNSSEFMA(s0, w0, z0);
517 z1 = MNNSSEFMA(s1, w0, z1);
518 }
519 STORE_4(dst + 8 * 0, z0);
520 STORE_4(dst + 8 * 1, z1);
521 }
522 }
523
524 template <typename TYPE>
_AVX_MNNPackedMatMul_4(TYPE * C,const TYPE * A,const TYPE * B,const size_t * parameter)525 static void _AVX_MNNPackedMatMul_4(TYPE* C, const TYPE* A, const TYPE* B, const size_t* parameter) {
526 auto aStride = parameter[0] / sizeof(TYPE);
527 auto h = parameter[2];
528 auto l = parameter[1];
529 auto cStride = parameter[3] / sizeof(TYPE);
530 auto bExtraStride = parameter[5] / sizeof(TYPE);
531 auto bStride = bExtraStride + l * 4;
532 auto hC4 = UP_DIV(h, 4);
533 int lC4 = l / 4;
534 int lR = lC4 * 4;
535 const int hC4Unit = 4;
536 int hC16 = hC4 / hC4Unit;
537 int hR = hC16 * hC4Unit;
538 auto src = A;
539 for (int y = 0; y < hC16; ++y) {
540 auto weight0 = B + (hC4Unit * y + 0) * bStride;
541 auto weight1 = B + (hC4Unit * y + 1) * bStride;
542 auto weight2 = B + (hC4Unit * y + 2) * bStride;
543 auto weight3 = B + (hC4Unit * y + 3) * bStride;
544 DST_ADDR_UNPACK4(0);
545
546 auto sumAvx00 = _mm256_setzero_ps();
547 auto sumAvx01 = _mm256_setzero_ps();
548
549 auto sumAvx10 = _mm256_setzero_ps();
550 auto sumAvx11 = _mm256_setzero_ps();
551
552 auto sumAvx20 = _mm256_setzero_ps();
553 auto sumAvx21 = _mm256_setzero_ps();
554
555 auto sumAvx30 = _mm256_setzero_ps();
556 auto sumAvx31 = _mm256_setzero_ps();
557
558 auto srcUse = src;
559 for (int sy = 0; sy < l; ++sy) {
560 auto S0 = BROAD_LOAD(srcUse + 0);
561 auto S1 = BROAD_LOAD(srcUse + 1);
562 auto S2 = BROAD_LOAD(srcUse + 2);
563 auto S3 = BROAD_LOAD(srcUse + 3);
564 auto W0 = _mm256_castsi256_ps(_mm256_insertf128_si256(mm256_broadcastsi128_si256(weight0), mm_loadu_si128(weight1), 1));
565 auto W1 = _mm256_castsi256_ps(_mm256_insertf128_si256(mm256_broadcastsi128_si256(weight2), mm_loadu_si128(weight3), 1));
566
567 sumAvx00 = MNNAVXFMA(S0, W0, sumAvx00);
568 sumAvx01 = MNNAVXFMA(S0, W1, sumAvx01);
569
570 sumAvx10 = MNNAVXFMA(S1, W0, sumAvx10);
571 sumAvx11 = MNNAVXFMA(S1, W1, sumAvx11);
572
573 sumAvx20 = MNNAVXFMA(S2, W0, sumAvx20);
574 sumAvx21 = MNNAVXFMA(S2, W1, sumAvx21);
575
576 sumAvx30 = MNNAVXFMA(S3, W0, sumAvx30);
577 sumAvx31 = MNNAVXFMA(S3, W1, sumAvx31);
578
579 srcUse += aStride;
580 weight0 += 4;
581 weight1 += 4;
582 weight2 += 4;
583 weight3 += 4;
584 }
585 STORE_8(dst0, sumAvx00);
586 STORE_8(dst0 + 8, sumAvx10);
587 STORE_8(dst0 + 16, sumAvx20);
588 STORE_8(dst0 + 24, sumAvx30);
589
590 STORE_8(dst2, sumAvx01);
591 STORE_8(dst2 + 8, sumAvx11);
592 STORE_8(dst2 + 16, sumAvx21);
593 STORE_8(dst2 + 24, sumAvx31);
594 }
595 for (int y = hR; y < hC4; ++y) {
596 auto weight = B + y * bStride;
597 auto dst = C + (y / 2) * cStride + 4 * (y % 2);
598 auto s0 = LOAD4(A + 0 * aStride);
599 auto w0 = BROAD_LOAD_4(weight + 0 * 4 + 0);
600 auto w1 = BROAD_LOAD_4(weight + 0 * 4 + 1);
601 auto w2 = BROAD_LOAD_4(weight + 0 * 4 + 2);
602 auto w3 = BROAD_LOAD_4(weight + 0 * 4 + 3);
603 auto z0 = _mm_mul_ps(s0, w0);
604 auto z3 = _mm_mul_ps(s0, w1);
605 auto z6 = _mm_mul_ps(s0, w2);
606 auto z9 = _mm_mul_ps(s0, w3);
607
608 for (int sy = 1; sy < l; ++sy) {
609 s0 = LOAD4(A + sy * aStride);
610 w0 = BROAD_LOAD_4(weight + sy * 4 + 0);
611 w1 = BROAD_LOAD_4(weight + sy * 4 + 1);
612 w2 = BROAD_LOAD_4(weight + sy * 4 + 2);
613 w3 = BROAD_LOAD_4(weight + sy * 4 + 3);
614 z0 = MNNSSEFMA(s0, w0, z0);
615 z3 = MNNSSEFMA(s0, w1, z3);
616 z6 = MNNSSEFMA(s0, w2, z6);
617 z9 = MNNSSEFMA(s0, w3, z9);
618 }
619 _MM_TRANSPOSE4_PS(z0, z3, z6, z9);
620 STORE_4(dst + 8 * 0, z0);
621 STORE_4(dst + 8 * 1, z3);
622 STORE_4(dst + 8 * 2, z6);
623 STORE_4(dst + 8 * 3, z9);
624 }
625 }
626 template <typename TYPE>
_AVX_MNNPackednMatMulRemainCommon(TYPE * C,const TYPE * A,const TYPE * B,size_t eSize,const size_t * parameter)627 static void _AVX_MNNPackednMatMulRemainCommon(TYPE* C, const TYPE* A, const TYPE* B, size_t eSize,
628 const size_t* parameter) {
629 auto h = parameter[2];
630 auto l = parameter[1];
631 auto cStride = parameter[3] / sizeof(TYPE);
632 auto bExtraStride = parameter[5] / sizeof(TYPE);
633 auto bStride = bExtraStride + l * 4;
634 auto hC4 = UP_DIV(h, 4);
635 auto es = eSize;
636 auto oC = C;
637 auto aStride = parameter[0] / sizeof(TYPE);
638 if (eSize >= 20) {
639 _AVX_MNNPackedMatMul_20<TYPE>(C, A, B, parameter);
640 eSize -= 20;
641 C += 20 * 8;
642 A += 20;
643 }
644 if (eSize >= 16) {
645 _AVX_MNNPackedMatMul_16<TYPE>(C, A, B, parameter);
646 eSize -= 16;
647 C += 16 * 8;
648 A += 16;
649 }
650 while (eSize >= 5) {
651 _AVX_MNNPackedMatMul_5<TYPE>(C, A, B, parameter);
652 eSize -= 5;
653 C += 5 * 8;
654 A += 5;
655 }
656 if (eSize == 4) {
657 _AVX_MNNPackedMatMul_4<TYPE>(C, A, B, parameter);
658 return;
659 }
660 if (eSize == 3) {
661 _AVX_MNNPackedMatMul_3<TYPE>(C, A, B, parameter);
662 return;
663 }
664 if (eSize == 2) {
665 _AVX_MNNPackedMatMul_2<TYPE>(C, A, B, parameter);
666 return;
667 }
668 if (eSize == 0) {
669 return;
670 }
671 int lC4 = l / 4;
672 int lR = lC4 * 4;
673 const int hC4Unit = 4;
674 int hC16 = hC4 / hC4Unit;
675 int hR = hC16 * hC4Unit;
676 auto src = A;
677 int x = 0;
678 for (int y = 0; y < hC16; ++y) {
679 auto weight0 = B + (hC4Unit * y + 0) * bStride;
680 auto weight1 = B + (hC4Unit * y + 1) * bStride;
681 auto weight2 = B + (hC4Unit * y + 2) * bStride;
682 auto weight3 = B + (hC4Unit * y + 3) * bStride;
683 auto dst0 = C + (hC4Unit * y / 2 + 0) * cStride + x * 8;
684 auto dst1 = C + (hC4Unit * y / 2 + 0) * cStride + x * 8 + 4;
685 auto dst2 = C + (hC4Unit * y / 2 + 1) * cStride + x * 8;
686 auto dst3 = C + (hC4Unit * y / 2 + 1) * cStride + x * 8 + 4;
687 auto sumAvx00 = _mm256_setzero_ps();
688 auto sumAvx01 = _mm256_setzero_ps();
689
690 auto sumAvx10 = _mm256_setzero_ps();
691 auto sumAvx11 = _mm256_setzero_ps();
692
693 auto sumAvx20 = _mm256_setzero_ps();
694 auto sumAvx21 = _mm256_setzero_ps();
695
696 auto sumAvx30 = _mm256_setzero_ps();
697 auto sumAvx31 = _mm256_setzero_ps();
698
699 auto srcUse = src;
700 for (int sy = 0; sy < lC4; ++sy) {
701 auto s0 = _mm256_castps_si256(BROAD_LOAD(srcUse + (0) * aStride));
702 auto s1 = _mm_castps_si128(BROAD_LOAD_4(srcUse + (1) * aStride));
703 auto S0 = _mm256_castsi256_ps(_mm256_insertf128_si256(s0, s1, 1));
704 auto d0 = _mm256_castps_si256(BROAD_LOAD(srcUse + (2) * aStride));
705 auto d1 = _mm_castps_si128(BROAD_LOAD_4(srcUse + (3) * aStride));
706 auto S1 = _mm256_castsi256_ps(_mm256_insertf128_si256(d0, d1, 1));
707 auto W00 = LOAD8(weight0 + 16 * sy + 0);
708 auto W01 = LOAD8(weight0 + 16 * sy + 8);
709 auto W10 = LOAD8(weight1 + 16 * sy + 0);
710 auto W11 = LOAD8(weight1 + 16 * sy + 8);
711
712 auto W20 = LOAD8(weight2 + 16 * sy + 0);
713 auto W21 = LOAD8(weight2 + 16 * sy + 8);
714 auto W30 = LOAD8(weight3 + 16 * sy + 0);
715 auto W31 = LOAD8(weight3 + 16 * sy + 8);
716
717 sumAvx00 = MNNAVXFMA(S0, W00, sumAvx00);
718 sumAvx01 = MNNAVXFMA(S1, W01, sumAvx01);
719
720 sumAvx10 = MNNAVXFMA(S0, W10, sumAvx10);
721 sumAvx11 = MNNAVXFMA(S1, W11, sumAvx11);
722
723 sumAvx20 = MNNAVXFMA(S0, W20, sumAvx20);
724 sumAvx21 = MNNAVXFMA(S1, W21, sumAvx21);
725
726 sumAvx30 = MNNAVXFMA(S0, W30, sumAvx30);
727 sumAvx31 = MNNAVXFMA(S1, W31, sumAvx31);
728 srcUse += 4 * aStride;
729 }
730 sumAvx00 = _mm256_add_ps(sumAvx00, sumAvx01);
731 sumAvx10 = _mm256_add_ps(sumAvx10, sumAvx11);
732 sumAvx20 = _mm256_add_ps(sumAvx20, sumAvx21);
733 sumAvx30 = _mm256_add_ps(sumAvx30, sumAvx31);
734 auto sum00 = _mm256_extractf128_ps(sumAvx00, 0);
735 auto sum01 = _mm256_extractf128_ps(sumAvx00, 1);
736 auto sum0 = _mm_add_ps(sum00, sum01);
737 auto sum10 = _mm256_extractf128_ps(sumAvx10, 0);
738 auto sum11 = _mm256_extractf128_ps(sumAvx10, 1);
739 auto sum1 = _mm_add_ps(sum10, sum11);
740
741 auto sum20 = _mm256_extractf128_ps(sumAvx20, 0);
742 auto sum21 = _mm256_extractf128_ps(sumAvx20, 1);
743 auto sum2 = _mm_add_ps(sum20, sum21);
744 auto sum30 = _mm256_extractf128_ps(sumAvx30, 0);
745 auto sum31 = _mm256_extractf128_ps(sumAvx30, 1);
746 auto sum3 = _mm_add_ps(sum30, sum31);
747 for (int sy = lR; sy < l; ++sy) {
748 auto s = BROAD_LOAD_4(srcUse);
749 auto w0 = LOAD4(weight0 + 4 * sy);
750 auto w1 = LOAD4(weight1 + 4 * sy);
751 auto w2 = LOAD4(weight2 + 4 * sy);
752 auto w3 = LOAD4(weight3 + 4 * sy);
753 sum0 = MNNSSEFMA(s, w0, sum0);
754 sum1 = MNNSSEFMA(s, w1, sum1);
755 sum2 = MNNSSEFMA(s, w2, sum2);
756 sum3 = MNNSSEFMA(s, w3, sum3);
757 srcUse += aStride;
758 }
759 STORE_4(dst0, sum0);
760 STORE_4(dst1, sum1);
761 STORE_4(dst2, sum2);
762 STORE_4(dst3, sum3);
763 }
764 for (int y = hR; y < hC4; ++y) {
765 auto weight = B + y * bStride;
766 auto dst = C + (y / 2) * cStride + x * 8 + 4 * (y % 2);
767 auto sumAvx0 = _mm256_setzero_ps();
768 auto sumAvx1 = _mm256_setzero_ps();
769 auto srcUse = src;
770 for (int sy = 0; sy < lC4; ++sy) {
771 auto s0 = _mm256_castps_si256(BROAD_LOAD(srcUse + (0) * aStride));
772 auto s1 = _mm_castps_si128(BROAD_LOAD_4(srcUse + (1) * aStride));
773 auto S0 = _mm256_castsi256_ps(_mm256_insertf128_si256(s0, s1, 1));
774 auto d0 = _mm256_castps_si256(BROAD_LOAD(srcUse + (2) * aStride));
775 auto d1 = _mm_castps_si128(BROAD_LOAD_4(srcUse + (3) * aStride));
776 auto S1 = _mm256_castsi256_ps(_mm256_insertf128_si256(d0, d1, 1));
777 auto W0 = LOAD8(weight + 16 * sy + 0);
778 auto W1 = LOAD8(weight + 16 * sy + 8);
779 sumAvx0 = MNNAVXFMA(S0, W0, sumAvx0);
780 sumAvx1 = MNNAVXFMA(S1, W1, sumAvx1);
781 srcUse += 4 * aStride;
782 }
783 sumAvx0 = _mm256_add_ps(sumAvx0, sumAvx1);
784 auto sum0 = _mm256_extractf128_ps(sumAvx0, 0);
785 auto sum1 = _mm256_extractf128_ps(sumAvx0, 1);
786 auto sum = _mm_add_ps(sum0, sum1);
787 for (int sy = lR; sy < l; ++sy) {
788 auto s = BROAD_LOAD_4(srcUse);
789 auto w = LOAD4(weight + 4 * sy);
790 sum = MNNSSEFMA(s, w, sum);
791 srcUse += aStride;
792 }
793 STORE_4(dst, sum);
794 }
795 }
796