1 //
2 //  CommonOptFunction.cpp
3 //  MNN
4 //
5 //  Created by MNN on 2019/08/25.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #include <float.h>
10 #include <string.h>
11 #include <algorithm>
12 #include <limits>
13 #include <vector>
14 #include "FunctionSummary.hpp"
15 #include "core/Macro.h"
16 #include "backend/cpu/CPUPool.hpp"
17 #include "backend/cpu/BinaryUtils.hpp"
18 #include "Vec8.hpp"
19 
// Copies `count` vectors of 8 floats. Vector k is read from source + k * srcStride and
// written to dest + k * dstStride; both strides are measured in floats and no alignment
// is required (unaligned loads/stores).
void _AVX_MNNCopyC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count) {
    size_t srcOffset = 0;
    size_t dstOffset = 0;
    for (size_t k = 0; k < count; ++k) {
        const __m256 lanes = _mm256_loadu_ps(source + srcOffset);
        _mm256_storeu_ps(dest + dstOffset, lanes);
        srcOffset += srcStride;
        dstOffset += dstStride;
    }
}
// In-place accumulation of `count` vectors of 8 floats:
//     dest[k * dstStride + j] += source[k * srcStride + j]   for j in [0, 8)
// Strides are in floats; pointers need not be aligned.
void _AVX_MNNAddC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count) {
    size_t srcOffset = 0;
    size_t dstOffset = 0;
    for (size_t k = 0; k < count; ++k) {
        float* target    = dest + dstOffset;
        const __m256 sum = _mm256_add_ps(_mm256_loadu_ps(source + srcOffset), _mm256_loadu_ps(target));
        _mm256_storeu_ps(target, sum);
        srcOffset += srcStride;
        dstOffset += dstStride;
    }
}
34 
35 #define PACK_UNIT 8
_AVX_MNNPackCUnit(float * dst,const float * src,size_t area,size_t depth)36 void _AVX_MNNPackCUnit(float* dst, const float* src, size_t area, size_t depth) {
37     auto areaC4  = area / PACK_UNIT;
38     auto depthC4 = depth / PACK_UNIT;
39     __m256 t0, t1, t2, t3, t4, t5, t6, t7;
40     for (int z = 0; z < depthC4; ++z) {
41         auto dstPlane = dst + z * area * PACK_UNIT;
42         auto srcPlane = src + z * area * PACK_UNIT;
43         for (int x = 0; x < areaC4; ++x) {
44             auto s  = srcPlane + PACK_UNIT * x;
45             auto d  = dstPlane + PACK_UNIT * PACK_UNIT * x;
46             auto r0 = _mm256_loadu_ps(s + 0 * area);
47             auto r1 = _mm256_loadu_ps(s + 1 * area);
48             auto r2 = _mm256_loadu_ps(s + 2 * area);
49             auto r3 = _mm256_loadu_ps(s + 3 * area);
50             auto r4 = _mm256_loadu_ps(s + 4 * area);
51             auto r5 = _mm256_loadu_ps(s + 5 * area);
52             auto r6 = _mm256_loadu_ps(s + 6 * area);
53             auto r7 = _mm256_loadu_ps(s + 7 * area);
54 
55             TRANSPOSE_8x8;
56 
57             _mm256_storeu_ps(d + PACK_UNIT * 0, t0);
58             _mm256_storeu_ps(d + PACK_UNIT * 1, t1);
59             _mm256_storeu_ps(d + PACK_UNIT * 2, t2);
60             _mm256_storeu_ps(d + PACK_UNIT * 3, t3);
61             _mm256_storeu_ps(d + PACK_UNIT * 4, t4);
62             _mm256_storeu_ps(d + PACK_UNIT * 5, t5);
63             _mm256_storeu_ps(d + PACK_UNIT * 6, t6);
64             _mm256_storeu_ps(d + PACK_UNIT * 7, t7);
65         }
66     }
67     auto areaRemain  = areaC4 * PACK_UNIT;
68     auto depthRemain = depthC4 * PACK_UNIT;
69     // Down
70     int remain = depth - depthRemain;
71     if (remain > 0) {
72         float* dstPlane       = depthC4 * area * PACK_UNIT + dst;
73         const float* srcPlane = src + depthC4 * area * PACK_UNIT;
74         {
75             for (int x = 0; x < areaC4; ++x) {
76                 auto s  = srcPlane + PACK_UNIT * x;
77                 auto d  = dstPlane + PACK_UNIT * PACK_UNIT * x;
78                 auto r0 = _mm256_loadu_ps(s + 0 * area);
79                 auto r1 = _mm256_setzero_ps();
80                 auto r2 = _mm256_setzero_ps();
81                 auto r3 = _mm256_setzero_ps();
82                 auto r4 = _mm256_setzero_ps();
83                 auto r5 = _mm256_setzero_ps();
84                 auto r6 = _mm256_setzero_ps();
85                 auto r7 = _mm256_setzero_ps();
86                 switch (remain) {
87                     case 7:
88                         r6 = _mm256_loadu_ps(s + 6 * area);
89                     case 6:
90                         r5 = _mm256_loadu_ps(s + 5 * area);
91                     case 5:
92                         r4 = _mm256_loadu_ps(s + 4 * area);
93                     case 4:
94                         r3 = _mm256_loadu_ps(s + 3 * area);
95                     case 3:
96                         r2 = _mm256_loadu_ps(s + 2 * area);
97                     case 2:
98                         r1 = _mm256_loadu_ps(s + 1 * area);
99                     default:
100                         break;
101                 }
102 
103                 TRANSPOSE_8x8;
104 
105                 _mm256_storeu_ps(d + PACK_UNIT * 7, t7);
106                 _mm256_storeu_ps(d + PACK_UNIT * 6, t6);
107                 _mm256_storeu_ps(d + PACK_UNIT * 5, t5);
108                 _mm256_storeu_ps(d + PACK_UNIT * 4, t4);
109                 _mm256_storeu_ps(d + PACK_UNIT * 3, t3);
110                 _mm256_storeu_ps(d + PACK_UNIT * 2, t2);
111                 _mm256_storeu_ps(d + PACK_UNIT * 1, t1);
112                 _mm256_storeu_ps(d + PACK_UNIT * 0, t0);
113             }
114         }
115         for (int x = areaRemain; x < area; ++x) {
116             for (int y = 0; y < remain; y++) {
117                 dstPlane[PACK_UNIT * x + y] = srcPlane[y * area + x];
118             }
119             for (int y = remain; y < PACK_UNIT; y++) {
120                 dstPlane[PACK_UNIT * x + y] = 0;
121             }
122         }
123     }
124     // Right
125     for (int z = 0; z < depthC4; ++z) {
126         float* dstPlane       = z * area * PACK_UNIT + dst;
127         const float* srcPlane = src + z * area * PACK_UNIT;
128         for (int x = areaRemain; x < area; ++x) {
129             float s0 = srcPlane[x];
130             float s1 = srcPlane[x + area];
131             float s2 = srcPlane[x + area * 2];
132             float s3 = srcPlane[x + area * 3];
133             float s4 = srcPlane[x + area * 4];
134             float s5 = srcPlane[x + area * 5];
135             float s6 = srcPlane[x + area * 6];
136             float s7 = srcPlane[x + area * 7];
137             _mm256_storeu_ps(dstPlane + PACK_UNIT * x, _mm256_set_ps(s7, s6, s5, s4, s3, s2, s1, s0));
138         }
139     }
140 }
_AVX_MNNUnpackCUnit(float * dst,const float * src,size_t area,size_t depth)141 void _AVX_MNNUnpackCUnit(float* dst, const float* src, size_t area, size_t depth) {
142     auto areaC4  = area / PACK_UNIT;
143     auto depthC4 = depth / PACK_UNIT;
144     __m256 t0, t1, t2, t3, t4, t5, t6, t7;
145     for (int z = 0; z < depthC4; ++z) {
146         auto dstPlane = dst + z * area * PACK_UNIT;
147         auto srcPlane = src + z * area * PACK_UNIT;
148         for (int x = 0; x < areaC4; ++x) {
149             auto s  = srcPlane + PACK_UNIT * PACK_UNIT * x;
150             auto d  = dstPlane + PACK_UNIT * x;
151             auto r0 = _mm256_loadu_ps(s + 0 * PACK_UNIT);
152             auto r1 = _mm256_loadu_ps(s + 1 * PACK_UNIT);
153             auto r2 = _mm256_loadu_ps(s + 2 * PACK_UNIT);
154             auto r3 = _mm256_loadu_ps(s + 3 * PACK_UNIT);
155             auto r4 = _mm256_loadu_ps(s + 4 * PACK_UNIT);
156             auto r5 = _mm256_loadu_ps(s + 5 * PACK_UNIT);
157             auto r6 = _mm256_loadu_ps(s + 6 * PACK_UNIT);
158             auto r7 = _mm256_loadu_ps(s + 7 * PACK_UNIT);
159 
160             TRANSPOSE_8x8;
161 
162             _mm256_storeu_ps(d + 0 * area, t0);
163             _mm256_storeu_ps(d + 1 * area, t1);
164             _mm256_storeu_ps(d + 2 * area, t2);
165             _mm256_storeu_ps(d + 3 * area, t3);
166             _mm256_storeu_ps(d + 4 * area, t4);
167             _mm256_storeu_ps(d + 5 * area, t5);
168             _mm256_storeu_ps(d + 6 * area, t6);
169             _mm256_storeu_ps(d + 7 * area, t7);
170         }
171     }
172     auto areaRemain  = areaC4 * PACK_UNIT;
173     auto depthRemain = depthC4 * PACK_UNIT;
174     // Down
175     int remain = depth - depthRemain;
176     if (remain > 0) {
177         float* dstPlane       = depthC4 * area * PACK_UNIT + dst;
178         const float* srcPlane = src + depthC4 * area * PACK_UNIT;
179         for (int x = 0; x < areaC4; ++x) {
180             auto s  = srcPlane + PACK_UNIT * PACK_UNIT * x;
181             auto d  = dstPlane + PACK_UNIT * x;
182             auto r0 = _mm256_loadu_ps(s + 0 * PACK_UNIT);
183             auto r1 = _mm256_loadu_ps(s + 1 * PACK_UNIT);
184             auto r2 = _mm256_loadu_ps(s + 2 * PACK_UNIT);
185             auto r3 = _mm256_loadu_ps(s + 3 * PACK_UNIT);
186             auto r4 = _mm256_loadu_ps(s + 4 * PACK_UNIT);
187             auto r5 = _mm256_loadu_ps(s + 5 * PACK_UNIT);
188             auto r6 = _mm256_loadu_ps(s + 6 * PACK_UNIT);
189             auto r7 = _mm256_loadu_ps(s + 7 * PACK_UNIT);
190 
191             TRANSPOSE_8x8;
192 
193             switch (remain) {
194                 case 7:
195                     _mm256_storeu_ps(d + 6 * area, t6);
196                 case 6:
197                     _mm256_storeu_ps(d + 5 * area, t5);
198                 case 5:
199                     _mm256_storeu_ps(d + 4 * area, t4);
200                 case 4:
201                     _mm256_storeu_ps(d + 3 * area, t3);
202                 case 3:
203                     _mm256_storeu_ps(d + 2 * area, t2);
204                 case 2:
205                     _mm256_storeu_ps(d + 1 * area, t1);
206                 case 1:
207                     _mm256_storeu_ps(d + 0 * area, t0);
208                 default:
209                     break;
210             }
211         }
212         for (int x = areaRemain; x < area; ++x) {
213             for (int y = 0; y < remain; y++) {
214                 dstPlane[y * area + x] = srcPlane[PACK_UNIT * x + y];
215             }
216         }
217     }
218     // Right
219     for (int z = 0; z < depthC4; ++z) {
220         const float* srcPlane = z * area * PACK_UNIT + src;
221         float* dstPlane       = dst + z * area * PACK_UNIT;
222         for (int x = areaRemain; x < area; ++x) {
223             for (int y = 0; y < PACK_UNIT; y++) {
224                 dstPlane[y * area + x] = srcPlane[PACK_UNIT * x + y];
225             }
226         }
227     }
228 }
_AVX_MNNPackCUnitTranspose(float * dst,const float * src,size_t area,size_t depth)229 void _AVX_MNNPackCUnitTranspose(float* dst, const float* src, size_t area, size_t depth) {
230     int c      = (int)depth;
231     int cDiv4  = c / PACK_UNIT;
232     int cAlign = cDiv4 * PACK_UNIT;
233     for (int hi = 0; hi < area; ++hi) {
234         const float* srcHeight = src + hi * c;
235         float* dstHeight       = dst + hi * PACK_UNIT;
236         for (int ci = 0; ci < cDiv4; ++ci) {
237             _mm256_storeu_ps(dstHeight + PACK_UNIT * ci * area, _mm256_loadu_ps(srcHeight + PACK_UNIT * ci));
238         }
239     }
240 
241     if (cAlign == c) {
242         return;
243     }
244 
245     int cReamin   = c - cAlign;
246     auto srcAlign = src + cAlign;
247     auto dstAlign = dst + area * cAlign;
248 
249     for (int hi = 0; hi < area; ++hi) {
250         const float* srcHeight = srcAlign + hi * c;
251         float* dstHeight       = dstAlign + hi * PACK_UNIT;
252         for (int i = 0; i < PACK_UNIT; ++i) {
253             dstHeight[i] = 0;
254         }
255         for (int ci = 0; ci < cReamin; ++ci) {
256             dstHeight[ci] = srcHeight[ci];
257         }
258     }
259 
260 }
_AVX_MNNUnpackCUnitTranspose(float * dst,const float * src,size_t area,size_t depth)261 void _AVX_MNNUnpackCUnitTranspose(float* dst, const float* src, size_t area, size_t depth) {
262     int c      = (int)depth;
263     int cDiv4  = c / PACK_UNIT;
264     int cAlign = cDiv4 * PACK_UNIT;
265     for (int hi = 0; hi < area; ++hi) {
266         const float* srcHeight = src + hi * PACK_UNIT;
267         float* dstHeight       = dst + hi * c;
268         for (int ci = 0; ci < cDiv4; ++ci) {
269             _mm256_storeu_ps(dstHeight + PACK_UNIT * ci, _mm256_loadu_ps(srcHeight + PACK_UNIT * ci * area));
270         }
271     }
272 
273     if (cAlign == c) {
274         return;
275     }
276 
277     int cReamin   = c - cAlign;
278     auto srcAlign = src + area * cAlign;
279     auto dstAlign = dst + cAlign;
280 
281     for (int hi = 0; hi < area; ++hi) {
282         const float* srcHeight = srcAlign + hi * PACK_UNIT;
283         float* dstHeight       = dstAlign + hi * c;
284 
285         for (int ci = 0; ci < cReamin; ++ci) {
286             dstHeight[ci] = srcHeight[ci];
287         }
288     }
289 }
290 
// Channel-wise leaky ReLU / PReLU on data packed in blocks of 8 channels.
// Channel block j occupies src[8 * j * sizeQuad .. 8 * (j + 1) * sizeQuad) (sizeQuad pixels
// of 8 lanes) and uses slope[8 * j .. 8 * j + 8):
//     dst = src              where src >= 0
//     dst = src * slope      where src <  0
// Both comparisons are ordered, so a NaN input matches neither mask and yields +0.
void _AVX_MNNReluWithSlopeChannel(float* dst, const float* src, const float* slope, size_t sizeQuad, size_t depthQuad) {
    const __m256 zero = _mm256_setzero_ps();
    for (size_t j = 0; j < depthQuad; ++j) {
        const __m256 slopeZ = _mm256_loadu_ps(slope + 8 * j);
        const float* srcZ   = src + 8 * j * sizeQuad;
        float* dstZ         = dst + 8 * j * sizeQuad;
        for (size_t i = 0; i < sizeQuad; ++i) {
            const __m256 value    = _mm256_loadu_ps(srcZ + 8 * i);
            const __m256 negative = _mm256_cmp_ps(value, zero, _CMP_LT_OS); // value <  0
            const __m256 positive = _mm256_cmp_ps(value, zero, _CMP_GE_OS); // value >= 0
            const __m256 scaled   = _mm256_mul_ps(value, slopeZ);
            // Masks are disjoint, so adding the two masked terms selects one per lane.
            _mm256_storeu_ps(dstZ + 8 * i,
                             _mm256_add_ps(_mm256_and_ps(scaled, negative), _mm256_and_ps(value, positive)));
        }
    }
}
310 
// GELU (tanh approximation) over `size` vectors of 8 floats:
//     y = 0.5 * x * (1 + tanh(0.79788458 * (x + 0.044715 * x^3)))
// tanh(t) is evaluated with the rational approximation
//     t * (t^6 + 378 t^4 + 17325 t^2 + 135135) / (28 t^6 + 3150 t^4 + 62370 t^2 + 135135)
// and clamped to [-1, 1]. src and dst need no particular alignment.
void _AVX_MNNGelu(float *dst, const float *src, size_t size) {
    auto var1 = _mm256_set1_ps(0.044715f);
    auto var2 = _mm256_set1_ps(0.79788458f);
    auto var3 = _mm256_set1_ps(378.f);
    auto var4 = _mm256_set1_ps(17325.f);
    auto var5 = _mm256_set1_ps(135135.f);
    auto var6 = _mm256_set1_ps(28.f);
    auto var7 = _mm256_set1_ps(3150.f);
    auto var8 = _mm256_set1_ps(62370.f);
    auto var9 = _mm256_set1_ps(135135.f);
    auto var10 = _mm256_set1_ps(0.5);
    auto varOne = _mm256_set1_ps(1.f);
    auto varNegOne = _mm256_set1_ps(-1.f);
    for (size_t i = 0; i < size; i++) {
        // Unaligned load: nothing here guarantees 32-byte alignment of src (the store below
        // is already unaligned), and _mm256_load_ps faults on a misaligned address.
        auto x = _mm256_loadu_ps(src + i * 8);
        // y = 0.79788458 * (x + 0.044715 * x^3)
        auto y = _mm256_mul_ps(x, x);
        y = _mm256_mul_ps(y, x);
        y = _mm256_mul_ps(y, var1);
        y = _mm256_add_ps(y, x);
        y = _mm256_mul_ps(y, var2);
        // y = tanh(y), rational approximation evaluated in Horner form, then clamped.
        {
            auto y2 = _mm256_mul_ps(y, y);
            auto w = _mm256_add_ps(y2, var3);
            w = _mm256_mul_ps(w, y2);
            w = _mm256_add_ps(w, var4);
            w = _mm256_mul_ps(w, y2);
            w = _mm256_add_ps(w, var5);
            w = _mm256_mul_ps(w, y);
            auto z = _mm256_mul_ps(y2, var6);
            z = _mm256_add_ps(z, var7);
            z = _mm256_mul_ps(z, y2);
            z = _mm256_add_ps(z, var8);
            z = _mm256_mul_ps(z, y2);
            z = _mm256_add_ps(z, var9);
            z = _mm256_div_ps(w, z);
            z = _mm256_max_ps(z, varNegOne);
            y = _mm256_min_ps(z, varOne);
        }
        // 0.5 * x * (1 + tanh(...))
        y = _mm256_add_ps(y, varOne);
        y = _mm256_mul_ps(y, x);
        y = _mm256_mul_ps(y, var10);
        _mm256_storeu_ps(dst + i * 8, y);
    }
}
356 
// For each of `height` rows y and each of `width` 8-float vectors x in that row:
//     C[y * cStride + 8x ..] = clamp(A[y * aStride + 8x ..] + B[8y ..], parameters[2], parameters[3])
// B supplies one 8-float vector per row that is broadcast along the row.
// Only parameters[2] (lower bound) and parameters[3] (upper bound) are read here.
void _AVX_MNNAxByClampBroadcastUnit(float* C, const float* A, const float* B, size_t width, size_t cStride, size_t aStride, size_t height, const float* parameters) {
    const __m256 lower = _mm256_broadcast_ss(parameters + 2);
    const __m256 upper = _mm256_broadcast_ss(parameters + 3);
    for (size_t row = 0; row < height; ++row) {
        const float* aRow     = A + row * aStride;
        float* cRow           = C + row * cStride;
        const __m256 rowAddend = _mm256_loadu_ps(B + 8 * row);
        for (size_t col = 0; col < width; ++col) {
            __m256 sum = _mm256_add_ps(_mm256_loadu_ps(aRow + 8 * col), rowAddend);
            sum        = _mm256_max_ps(_mm256_min_ps(sum, upper), lower);
            _mm256_storeu_ps(cRow + 8 * col, sum);
        }
    }
}
376 
// dest[i] = exp(-source[i]) for countC8 vectors of 8 floats, with p_k = parameters[k].
// The negated input x is clamped to [-87, 87] and split as x = n * p0 + r, where n is
// x * p1 converted to int (current rounding mode, round-to-nearest by default). The result
// is 2^n * P(r), with 2^n built by writing n + 127 into the float exponent field and
//     P(r) = ((((p7 r + p6) r + p5) r + p4) r + p3) r + p2.
// NOTE(review): presumably p0 = ln2, p1 = 1/ln2 and p2..p7 are Taylor coefficients of exp;
// confirm against the caller.
void _AVX_MNNExpC8(float* dest, const float* source, const float* parameters, size_t countC8) {
    const __m256 p0 = _mm256_set1_ps(parameters[0]);
    const __m256 p1 = _mm256_set1_ps(parameters[1]);
    const __m256 p2 = _mm256_set1_ps(parameters[2]);
    const __m256 p3 = _mm256_set1_ps(parameters[3]);
    const __m256 p4 = _mm256_set1_ps(parameters[4]);
    const __m256 p5 = _mm256_set1_ps(parameters[5]);
    const __m256 p6 = _mm256_set1_ps(parameters[6]);
    const __m256 p7 = _mm256_set1_ps(parameters[7]);
    const __m256 upper         = _mm256_set1_ps(87);
    const __m256 lower         = _mm256_set1_ps(-87);
    const __m256 signBit       = _mm256_set1_ps(-0.f);
    const __m256i exponentBias = _mm256_set1_epi32(127);
    const __m256i exponentUnit = _mm256_set1_epi32(1 << 23);

    for (size_t v = 0; v < countC8; ++v) {
        // Flip the sign bit (x = -source), then clamp to keep 2^n representable.
        __m256 x = _mm256_xor_ps(_mm256_loadu_ps(source + 8 * v), signBit);
        x        = _mm256_min_ps(_mm256_max_ps(x, lower), upper);

        const __m256i n    = _mm256_cvtps_epi32(_mm256_mul_ps(x, p1));
        const __m256 nF    = _mm256_cvtepi32_ps(n);
        const __m256 pow2n = _mm256_castsi256_ps(_mm256_mullo_epi32(_mm256_add_epi32(n, exponentBias), exponentUnit));
        const __m256 r     = _mm256_sub_ps(x, _mm256_mul_ps(nF, p0));

        __m256 poly = _mm256_mul_ps(p7, r);
        poly        = _mm256_mul_ps(_mm256_add_ps(poly, p6), r);
        poly        = _mm256_mul_ps(_mm256_add_ps(poly, p5), r);
        poly        = _mm256_mul_ps(_mm256_add_ps(poly, p4), r);
        poly        = _mm256_mul_ps(_mm256_add_ps(poly, p3), r);
        poly        = _mm256_add_ps(poly, p2);

        _mm256_storeu_ps(dest + 8 * v, _mm256_mul_ps(pow2n, poly));
    }
}
418 
// Depthwise convolution for one output pixel of one 8-channel block:
//     dst[lane] = sum over fy < fh, fx < fw of
//                 src[fy * dilateY_step + fx * dilateX_step + lane] * weight[fy * weight_y_step + 8 * fx + lane]
// accumulated with fy as the outer and fx as the inner loop. All steps are in floats.
void _AVX_MNNConvRunForUnitDepthWise(float* dst, const float* src, const float* weight, size_t fw, size_t fh,
                                  size_t weight_y_step, size_t dilateX_step, size_t dilateY_step) {
    __m256 sum = _mm256_setzero_ps();
    for (size_t ky = 0; ky < fh; ++ky) {
        const float* srcRow    = src + ky * dilateY_step;
        const float* weightRow = weight + ky * weight_y_step;
        for (size_t kx = 0; kx < fw; ++kx) {
            const __m256 pixel = _mm256_loadu_ps(srcRow + kx * dilateX_step);
            const __m256 tap   = _mm256_loadu_ps(weightRow + 8 * kx);
            sum                = _mm256_add_ps(sum, _mm256_mul_ps(pixel, tap));
        }
    }
    _mm256_storeu_ps(dst, sum);
}
436 
// Shared body of _AVX_MNNConvRunForLineDepthwise. When kStrideIs8 is true the caller has
// checked src_w_setup == 8, so the per-pixel source step is a compile-time constant in the
// unrolled loop (this replaces a hand-duplicated copy of the whole kernel).
template <bool kStrideIs8>
static void _AVX_MNNConvRunForLineDepthwiseImpl(float* dst, const float* src, const float* weight, size_t width,
                                                size_t src_w_setup, size_t fw, size_t fh, size_t dilateX_step,
                                                size_t dilateY_step, size_t height, size_t srcHStep,
                                                size_t dstHStep) {
    // Distance in floats between the sources of two neighbouring output pixels.
    const size_t srcStep  = kStrideIs8 ? 8 : src_w_setup;
    const int unit        = 4;
    const int widthUnit   = width / unit;
    const int widthRemain = width - widthUnit * unit;
    for (size_t y = 0; y < height; ++y) {
        const float* srcY = src + y * srcHStep;
        float* dstY       = dst + y * dstHStep;
        // Four output pixels at a time, sharing each weight load.
        for (int dx = 0; dx < widthUnit; ++dx) {
            __m256 acc0 = _mm256_setzero_ps();
            __m256 acc1 = _mm256_setzero_ps();
            __m256 acc2 = _mm256_setzero_ps();
            __m256 acc3 = _mm256_setzero_ps();
            for (size_t fy = 0; fy < fh; ++fy) {
                const float* srcRow    = srcY + fy * dilateY_step;
                const float* weightRow = weight + fy * fw * 8;
                for (size_t fx = 0; fx < fw; ++fx) {
                    const float* s = srcRow + fx * dilateX_step;
                    const __m256 w = _mm256_loadu_ps(weightRow + 8 * fx);
                    acc0 = _mm256_add_ps(acc0, _mm256_mul_ps(_mm256_loadu_ps(s + 0 * srcStep), w));
                    acc1 = _mm256_add_ps(acc1, _mm256_mul_ps(_mm256_loadu_ps(s + 1 * srcStep), w));
                    acc2 = _mm256_add_ps(acc2, _mm256_mul_ps(_mm256_loadu_ps(s + 2 * srcStep), w));
                    acc3 = _mm256_add_ps(acc3, _mm256_mul_ps(_mm256_loadu_ps(s + 3 * srcStep), w));
                }
            }
            _mm256_storeu_ps(dstY + 8 * 0, acc0);
            _mm256_storeu_ps(dstY + 8 * 1, acc1);
            _mm256_storeu_ps(dstY + 8 * 2, acc2);
            _mm256_storeu_ps(dstY + 8 * 3, acc3);
            dstY += 8 * unit;
            srcY += unit * srcStep;
        }
        // Leftover output pixels (width % 4), one at a time.
        for (int dx = 0; dx < widthRemain; ++dx) {
            const float* srcX = srcY + srcStep * dx;
            __m256 acc        = _mm256_setzero_ps();
            for (size_t fy = 0; fy < fh; ++fy) {
                const float* srcRow    = srcX + fy * dilateY_step;
                const float* weightRow = weight + fy * fw * 8;
                for (size_t fx = 0; fx < fw; ++fx) {
                    acc = _mm256_add_ps(acc, _mm256_mul_ps(_mm256_loadu_ps(srcRow + fx * dilateX_step),
                                                           _mm256_loadu_ps(weightRow + 8 * fx)));
                }
            }
            _mm256_storeu_ps(dstY + dx * 8, acc);
        }
    }
}

// Depthwise convolution over `height` output rows of `width` pixels for one 8-channel block.
// Output pixel x of row y (dst + y * dstHStep + 8 * x) accumulates, over the fh x fw window,
//     src[y * srcHStep + x * src_w_setup + fy * dilateY_step + fx * dilateX_step + lane]
//       * weight[(fy * fw + fx) * 8 + lane]
// All steps/strides are in floats; loads and stores are unaligned.
void _AVX_MNNConvRunForLineDepthwise(float* dst, const float* src, const float* weight, size_t width, size_t src_w_setup,
                                size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, size_t height,
                                     size_t srcHStep, size_t dstHStep) {
    if (src_w_setup == 8) {
        // Common dense case: let the compiler fold the source step into the addressing.
        _AVX_MNNConvRunForLineDepthwiseImpl<true>(dst, src, weight, width, src_w_setup, fw, fh, dilateX_step,
                                                  dilateY_step, height, srcHStep, dstHStep);
        return;
    }
    _AVX_MNNConvRunForLineDepthwiseImpl<false>(dst, src, weight, width, src_w_setup, fw, fh, dilateX_step,
                                               dilateY_step, height, srcHStep, dstHStep);
}
539 
_AVX_MNNMultiAndDestTransformCommon23(float ** cacheLine,const float * weigth,float * dest,int cacheLineSize,int ow,const float * bias,const float * parameter)540 void _AVX_MNNMultiAndDestTransformCommon23(float **cacheLine, const float *weigth, float *dest, int cacheLineSize, int ow, const float* bias, const float* parameter) {
541     int unit = ow / 2;
542     MNN_ASSERT(cacheLineSize >= 1);
543     auto biasF = Vec8::load(bias);
544     auto minF = Vec8(parameter[2]);
545     auto maxF = Vec8(parameter[3]);
546     for (int x = 0; x < unit; ++x) {
547         auto offset = 4 * 8 * x;
548         int i = 0;
549         Vec8 m0     = Vec8::load(weigth + i * 32 + 8 * 0) * Vec8::load(cacheLine[i] + offset + 8 * 0);
550         Vec8 m1     = Vec8::load(weigth + i * 32 + 8 * 1) * Vec8::load(cacheLine[i] + offset + 8 * 1);
551         Vec8 m2     = Vec8::load(weigth + i * 32 + 8 * 2) * Vec8::load(cacheLine[i] + offset + 8 * 2);
552         Vec8 m3     = Vec8::load(weigth + i * 32 + 8 * 3) * Vec8::load(cacheLine[i] + offset + 8 * 3);
553 
554         for (i = 1; i < cacheLineSize; ++i) {
555             m0 = m0 + Vec8::load(weigth + i * 32 + 8 * 0) * Vec8::load(cacheLine[i] + offset + 8 * 0);
556             m1 = m1 + Vec8::load(weigth + i * 32 + 8 * 1) * Vec8::load(cacheLine[i] + offset + 8 * 1);
557             m2 = m2 + Vec8::load(weigth + i * 32 + 8 * 2) * Vec8::load(cacheLine[i] + offset + 8 * 2);
558             m3 = m3 + Vec8::load(weigth + i * 32 + 8 * 3) * Vec8::load(cacheLine[i] + offset + 8 * 3);
559         }
560         auto o0 = m0 + m1 + m2 + biasF;
561         auto o1 = m1 - m2 + m3 + biasF;
562         o0 = Vec8::min(maxF, o0);
563         o1 = Vec8::min(maxF, o1);
564         o0 = Vec8::max(minF, o0);
565         o1 = Vec8::max(minF, o1);
566 
567         Vec8::save(dest + 16 * x + 0 * 8, o0);
568         Vec8::save(dest + 16 * x + 1 * 8, o1);
569     }
570     if (unit * 2 < ow) {
571         auto offset = 8 * 4 * unit;
572         int i = 0;
573         Vec8 m0     = Vec8::load(weigth + i * 32 + 8 * 0) * Vec8::load(cacheLine[i] + offset + 8 * 0);
574         Vec8 m1     = Vec8::load(weigth + i * 32 + 8 * 1) * Vec8::load(cacheLine[i] + offset + 8 * 1);
575         Vec8 m2     = Vec8::load(weigth + i * 32 + 8 * 2) * Vec8::load(cacheLine[i] + offset + 8 * 2);
576 
577         for (i = 1; i < cacheLineSize; ++i) {
578             m0 = m0 + Vec8::load(weigth + i * 32 + 8 * 0) * Vec8::load(cacheLine[i] + offset + 8 * 0);
579             m1 = m1 + Vec8::load(weigth + i * 32 + 8 * 1) * Vec8::load(cacheLine[i] + offset + 8 * 1);
580             m2 = m2 + Vec8::load(weigth + i * 32 + 8 * 2) * Vec8::load(cacheLine[i] + offset + 8 * 2);
581         }
582         auto o0 = m0 + m1 + m2 + biasF;
583         o0 = Vec8::min(maxF, o0);
584         o0 = Vec8::max(minF, o0);
585         Vec8::save(dest + 16 * unit + 0 * 8, o0);
586     }
587 }
_AVX_MNNConvDwF23SourceTransUnit(const float * source,float * dest,size_t unit)588 static void _AVX_MNNConvDwF23SourceTransUnit(const float *source, float *dest, size_t unit) {
589     if (unit <= 0) {
590         return;
591     }
592     Vec8 v0 = Vec8::load(source + 8 * 0);
593     Vec8 v1 = Vec8::load(source + 8 * 1);
594     Vec8 v2;
595     Vec8 v3;
596     source += 16;
597 
598     for (int x = 0; x < unit; ++x) {
599         v2 = Vec8::load(source + 0 * 8);
600         v3 = Vec8::load(source + 1 * 8);
601         auto m0 = v0 - v2;
602         auto m1 = v1 + v2;
603         auto m2 = v2 - v1;
604         auto m3 = v3 - v1;
605 
606         Vec8::save(dest + 8 * 0, m0);
607         Vec8::save(dest + 8 * 1, m1);
608         Vec8::save(dest + 8 * 2, m2);
609         Vec8::save(dest + 8 * 3, m3);
610 
611         source += 16;
612         dest += 32;
613 
614         v0 = v2;
615         v1 = v3;
616     }
617 }
618 
_AVX_MNNSourceTransformCommonF23(const float * source,float * dest,int unit,int iw,int pad,int su,int eu)619 void _AVX_MNNSourceTransformCommonF23(const float *source, float *dest, int unit, int iw, int pad, int su, int eu) {
620     for (int x = 0; x < su; ++x) {
621         auto dstX = dest + 4 * 8 * x;
622         auto sx   = x * 2 - (int)pad;
623         auto ex   = sx + 4;
624 
625         auto clampSx = std::max(sx, 0);
626         auto clampEx = std::min(ex, (int)iw);
627 
628         Vec8 v[4] = {0.0f, 0.0f, 0.0f, 0.0f};
629         for (int i = clampSx; i < clampEx; ++i) {
630             v[i - sx] = Vec8::load(source + 8 * i);
631         }
632         auto m0 = v[0] - v[2];
633         auto m1 = v[1] + v[2];
634         auto m2 = v[2] - v[1];
635         auto m3 = v[3] - v[1];
636 
637         Vec8::save(dstX + 8 * 0, m0);
638         Vec8::save(dstX + 8 * 1, m1);
639         Vec8::save(dstX + 8 * 2, m2);
640         Vec8::save(dstX + 8 * 3, m3);
641     }
642     _AVX_MNNConvDwF23SourceTransUnit(source + 8 * (su * 2 - pad), dest + 8 * 4 * su, eu - su);
643 
644     for (int x = eu; x < unit; ++x) {
645         auto dstX = dest + 8 * 4 * x;
646         auto sx   = x * 2 - (int)pad;
647         auto ex   = sx + 4;
648 
649         auto clampSx = std::max(sx, 0);
650         auto clampEx = std::min(ex, (int)iw);
651 
652         Vec8 v[4] = {0.0f, 0.0f, 0.0f, 0.0f};
653         for (int i = clampSx; i < clampEx; ++i) {
654             v[i - sx] = Vec8::load(source + 8 * i);
655         }
656         auto m0 = v[0] - v[2];
657         auto m1 = v[1] + v[2];
658         auto m2 = v[2] - v[1];
659         auto m3 = v[3] - v[1];
660 
661         Vec8::save(dstX + 8 * 0, m0);
662         Vec8::save(dstX + 8 * 1, m1);
663         Vec8::save(dstX + 8 * 2, m2);
664         Vec8::save(dstX + 8 * 3, m3);
665     }
666 }
667 
_AVX_MNNConvDwF23MulTransUnit(float ** cacheLine,const float * weigth,float * dest,size_t ow,const float * bias,const float * parameter)668 void _AVX_MNNConvDwF23MulTransUnit(float **cacheLine, const float *weigth, float *dest, size_t ow, const float* bias, const float* parameter) {
669     int unit = ow / 2;
670     auto w00 = Vec8::load(weigth + 0 * 32 + 8 * 0);
671     auto w01 = Vec8::load(weigth + 0 * 32 + 8 * 1);
672     auto w02 = Vec8::load(weigth + 0 * 32 + 8 * 2);
673     auto w03 = Vec8::load(weigth + 0 * 32 + 8 * 3);
674     auto w10 = Vec8::load(weigth + 1 * 32 + 8 * 0);
675     auto w11 = Vec8::load(weigth + 1 * 32 + 8 * 1);
676     auto w12 = Vec8::load(weigth + 1 * 32 + 8 * 2);
677     auto w13 = Vec8::load(weigth + 1 * 32 + 8 * 3);
678     auto w20 = Vec8::load(weigth + 2 * 32 + 8 * 0);
679     auto w21 = Vec8::load(weigth + 2 * 32 + 8 * 1);
680     auto w22 = Vec8::load(weigth + 2 * 32 + 8 * 2);
681     auto w23 = Vec8::load(weigth + 2 * 32 + 8 * 3);
682     auto biasF = Vec8::load(bias);
683     auto minF = Vec8(parameter[2]);
684     auto maxF = Vec8(parameter[3]);
685 
686     for (int x = 0; x < unit; ++x) {
687         auto offset = 8 * 4 * x;
688         int i = 0;
689         Vec8 m0     = w00 * Vec8::load(cacheLine[0] + offset + 8 * 0);
690         Vec8 m1     = w01 * Vec8::load(cacheLine[0] + offset + 8 * 1);
691         Vec8 m2     = w02 * Vec8::load(cacheLine[0] + offset + 8 * 2);
692         Vec8 m3     = w03 * Vec8::load(cacheLine[0] + offset + 8 * 3);
693 
694         m0 = m0 + w10 * Vec8::load(cacheLine[1] + offset + 8 * 0);
695         m1 = m1 + w11 * Vec8::load(cacheLine[1] + offset + 8 * 1);
696         m2 = m2 + w12 * Vec8::load(cacheLine[1] + offset + 8 * 2);
697         m3 = m3 + w13 * Vec8::load(cacheLine[1] + offset + 8 * 3);
698 
699         m0 = m0 + w20 * Vec8::load(cacheLine[2] + offset + 8 * 0);
700         m1 = m1 + w21 * Vec8::load(cacheLine[2] + offset + 8 * 1);
701         m2 = m2 + w22 * Vec8::load(cacheLine[2] + offset + 8 * 2);
702         m3 = m3 + w23 * Vec8::load(cacheLine[2] + offset + 8 * 3);
703 
704         auto o0 = m0 + m1 + m2 + biasF;
705         auto o1 = m1 - m2 + m3 + biasF;
706         o0 = Vec8::min(maxF, o0);
707         o1 = Vec8::min(maxF, o1);
708         o0 = Vec8::max(minF, o0);
709         o1 = Vec8::max(minF, o1);
710         Vec8::save(dest + 16 * x + 0 * 8, o0);
711         Vec8::save(dest + 16 * x + 1 * 8, o1);
712     }
713     if (unit * 2 < ow) {
714         auto offset = 8 * 4 * unit;
715         Vec8 m0     = w00 * Vec8::load(cacheLine[0] + offset + 8 * 0);
716         Vec8 m1     = w01 * Vec8::load(cacheLine[0] + offset + 8 * 1);
717         Vec8 m2     = w02 * Vec8::load(cacheLine[0] + offset + 8 * 2);
718 
719         m0 = m0 + w10 * Vec8::load(cacheLine[1] + offset + 8 * 0);
720         m1 = m1 + w11 * Vec8::load(cacheLine[1] + offset + 8 * 1);
721         m2 = m2 + w12 * Vec8::load(cacheLine[1] + offset + 8 * 2);
722 
723         m0 = m0 + w20 * Vec8::load(cacheLine[2] + offset + 8 * 0);
724         m1 = m1 + w21 * Vec8::load(cacheLine[2] + offset + 8 * 1);
725         m2 = m2 + w22 * Vec8::load(cacheLine[2] + offset + 8 * 2);
726         auto o0 = m0 + m1 + m2 + biasF;
727         o0 = Vec8::min(maxF, o0);
728         o0 = Vec8::max(minF, o0);
729         Vec8::save(dest + 16 * unit + 0 * 8, o0);
730     }
731 }
732 
_AVX2_MNNSelectBinaryFunctionForFloat(int opType)733 static MNNBinaryExecute _AVX2_MNNSelectBinaryFunctionForFloat(int opType) {
734     auto vecF = MNN::selectVector<Vec8, 8>(opType);
735     if (nullptr != vecF) {
736         return vecF;
737     }
738     return MNN::MNNGetCoreFunctions()->MNNSelectBinaryFunctionForFloat(opType);
739 }
740 
741 
_AVX_ExtraInit(void * functions)742 void _AVX_ExtraInit(void* functions) {
743     auto coreFunction = static_cast<MNN::CoreFunctions*>(functions);
744     coreFunction->MNNPoolingAvg = (decltype(coreFunction->MNNPoolingAvg))(MNN::poolingAvg<float, Vec8, 8>);
745     // Set min value as 1 << 24
746     coreFunction->MNNPoolingMax = (decltype(coreFunction->MNNPoolingMax))(MNN::poolingMax<float, Vec8, 8, -16777216>);
747     coreFunction->MNNSelectBinaryFunctionForFloat = _AVX2_MNNSelectBinaryFunctionForFloat;
748 }
749 
_AVX_MNNScaleAndAddBias(float * dst,const float * src,const float * bias,const float * alpha,size_t planeNumber,size_t biasNumber)750 void _AVX_MNNScaleAndAddBias(float* dst, const float* src, const float* bias, const float* alpha, size_t planeNumber,
751                         size_t biasNumber) {
752     for (int z = 0; z < biasNumber; ++z) {
753         float* dstZ         = dst + planeNumber * 8 * z;
754         const float* srcZ   = src + planeNumber * 8 * z;
755         auto biasZ = Vec8::load(bias + 8 * z);
756         auto alphaZ = Vec8::load(alpha + 8 * z);
757         for (int p = 0; p < planeNumber; ++p) {
758             float* dstX       = dstZ + 8 * p;
759             const float* srcX = srcZ + 8 * p;
760             Vec8::save(dstX, (Vec8::load(srcX) * alphaZ) + biasZ);
761         }
762     }
763 }
764 
_AVX_MNNDeconvRunForUnitDepthWise(const float * dst,float * src,const float * weight,size_t fw,size_t fh,size_t weight_y_step,size_t dilateX_step,size_t dilateY_step)765 void _AVX_MNNDeconvRunForUnitDepthWise(const float* dst, float* src, const float* weight, size_t fw, size_t fh,
766                                   size_t weight_y_step, size_t dilateX_step, size_t dilateY_step) {
767     int fx, fy;
768     float* src_z          = src;
769     const float* weight_z = weight;
770     Vec8 dstV             = Vec8::load(dst);
771     for (fy = 0; fy < fh; ++fy) {
772         float* src_y          = src_z + fy * dilateY_step;
773         const float* weight_y = weight_z + fy * weight_y_step;
774         for (fx = 0; fx < fw; ++fx) {
775             Vec8 weight_x = Vec8::load(weight_y + 8 * fx);
776             Vec8 src_x    = Vec8::load(src_y + fx * dilateX_step);
777             Vec8::save(src_y + fx * dilateX_step, src_x + weight_x * dstV);
778         }
779     }
780 }
_AVX_MNNDeconvRunForLineDepthwise(const float * dst,float * src,const float * weight,size_t width,size_t src_w_setup,size_t fw,size_t fh,size_t dilateX_step,size_t dilateY_step)781 void _AVX_MNNDeconvRunForLineDepthwise(const float* dst, float* src, const float* weight, size_t width, size_t src_w_setup,
782                                   size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step) {
783     int dx;
784     for (dx = 0; dx < width; ++dx) {
785         const float* dst_x = dst + dx * 8;
786         float* src_dx      = src + src_w_setup * dx;
787         _AVX_MNNDeconvRunForUnitDepthWise(dst_x, src_dx, weight, fw, fh, fw * 8, dilateX_step, dilateY_step);
788     }
789 }
790