1 //
2 // CommonOptFunction.cpp
3 // MNN
4 //
5 // Created by MNN on 2019/08/25.
6 // Copyright © 2018, Alibaba Group Holding Limited
7 //
8
9 #include <float.h>
10 #include <string.h>
11 #include <algorithm>
12 #include <limits>
13 #include <vector>
14 #include "FunctionSummary.hpp"
15 #include "core/Macro.h"
16 #include "backend/cpu/CPUPool.hpp"
17 #include "backend/cpu/BinaryUtils.hpp"
18 #include "Vec8.hpp"
19
_AVX_MNNCopyC4WithStride(const float * source,float * dest,size_t srcStride,size_t dstStride,size_t count)20 void _AVX_MNNCopyC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count) {
21 for (int i = 0; i < count; ++i) {
22 auto s = source + i * srcStride;
23 auto d = dest + i * dstStride;
24 _mm256_storeu_ps(d, _mm256_loadu_ps(s));
25 }
26 }
_AVX_MNNAddC4WithStride(const float * source,float * dest,size_t srcStride,size_t dstStride,size_t count)27 void _AVX_MNNAddC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count) {
28 for (int i = 0; i < count; ++i) {
29 auto s = source + i * srcStride;
30 auto d = dest + i * dstStride;
31 _mm256_storeu_ps(d, _mm256_add_ps(_mm256_loadu_ps(s), _mm256_loadu_ps(d)));
32 }
33 }
34
35 #define PACK_UNIT 8
// Pack planar channel-major data (depth rows of `area` floats) into the C8
// interleaved layout used by this AVX backend: depth/8 planes, each holding
// `area` packets of 8 consecutive channels. Channels past the last full group
// of 8 are zero-filled in the output.
void _AVX_MNNPackCUnit(float* dst, const float* src, size_t area, size_t depth) {
    auto areaC4 = area / PACK_UNIT;   // number of full 8-position tiles per row
    auto depthC4 = depth / PACK_UNIT; // number of full 8-channel groups
    // t0..t7 are written by name by the TRANSPOSE_8x8 macro below.
    __m256 t0, t1, t2, t3, t4, t5, t6, t7;
    // Main body: transpose full 8x8 tiles (8 channels x 8 positions).
    for (int z = 0; z < depthC4; ++z) {
        auto dstPlane = dst + z * area * PACK_UNIT;
        auto srcPlane = src + z * area * PACK_UNIT;
        for (int x = 0; x < areaC4; ++x) {
            auto s = srcPlane + PACK_UNIT * x;
            auto d = dstPlane + PACK_UNIT * PACK_UNIT * x;
            auto r0 = _mm256_loadu_ps(s + 0 * area);
            auto r1 = _mm256_loadu_ps(s + 1 * area);
            auto r2 = _mm256_loadu_ps(s + 2 * area);
            auto r3 = _mm256_loadu_ps(s + 3 * area);
            auto r4 = _mm256_loadu_ps(s + 4 * area);
            auto r5 = _mm256_loadu_ps(s + 5 * area);
            auto r6 = _mm256_loadu_ps(s + 6 * area);
            auto r7 = _mm256_loadu_ps(s + 7 * area);

            TRANSPOSE_8x8;

            _mm256_storeu_ps(d + PACK_UNIT * 0, t0);
            _mm256_storeu_ps(d + PACK_UNIT * 1, t1);
            _mm256_storeu_ps(d + PACK_UNIT * 2, t2);
            _mm256_storeu_ps(d + PACK_UNIT * 3, t3);
            _mm256_storeu_ps(d + PACK_UNIT * 4, t4);
            _mm256_storeu_ps(d + PACK_UNIT * 5, t5);
            _mm256_storeu_ps(d + PACK_UNIT * 6, t6);
            _mm256_storeu_ps(d + PACK_UNIT * 7, t7);
        }
    }
    auto areaRemain = areaC4 * PACK_UNIT;   // first position not covered by the tiles above
    auto depthRemain = depthC4 * PACK_UNIT; // first channel not covered by the tiles above
    // Down: leftover channels (fewer than 8) — load what exists, zero the rest.
    int remain = depth - depthRemain;
    if (remain > 0) {
        float* dstPlane = depthC4 * area * PACK_UNIT + dst;
        const float* srcPlane = src + depthC4 * area * PACK_UNIT;
        {
            for (int x = 0; x < areaC4; ++x) {
                auto s = srcPlane + PACK_UNIT * x;
                auto d = dstPlane + PACK_UNIT * PACK_UNIT * x;
                // remain >= 1, so channel row 0 is always valid.
                auto r0 = _mm256_loadu_ps(s + 0 * area);
                auto r1 = _mm256_setzero_ps();
                auto r2 = _mm256_setzero_ps();
                auto r3 = _mm256_setzero_ps();
                auto r4 = _mm256_setzero_ps();
                auto r5 = _mm256_setzero_ps();
                auto r6 = _mm256_setzero_ps();
                auto r7 = _mm256_setzero_ps();
                // Intentional fall-through: load exactly `remain - 1` extra rows.
                // remain == 1 falls to default (only r0 needed, loaded above).
                switch (remain) {
                    case 7:
                        r6 = _mm256_loadu_ps(s + 6 * area);
                        // fall through
                    case 6:
                        r5 = _mm256_loadu_ps(s + 5 * area);
                        // fall through
                    case 5:
                        r4 = _mm256_loadu_ps(s + 4 * area);
                        // fall through
                    case 4:
                        r3 = _mm256_loadu_ps(s + 3 * area);
                        // fall through
                    case 3:
                        r2 = _mm256_loadu_ps(s + 2 * area);
                        // fall through
                    case 2:
                        r1 = _mm256_loadu_ps(s + 1 * area);
                        // fall through
                    default:
                        break;
                }

                TRANSPOSE_8x8;

                // Stores in descending order; functionally equivalent to the
                // ascending order used in the main body.
                _mm256_storeu_ps(d + PACK_UNIT * 7, t7);
                _mm256_storeu_ps(d + PACK_UNIT * 6, t6);
                _mm256_storeu_ps(d + PACK_UNIT * 5, t5);
                _mm256_storeu_ps(d + PACK_UNIT * 4, t4);
                _mm256_storeu_ps(d + PACK_UNIT * 3, t3);
                _mm256_storeu_ps(d + PACK_UNIT * 2, t2);
                _mm256_storeu_ps(d + PACK_UNIT * 1, t1);
                _mm256_storeu_ps(d + PACK_UNIT * 0, t0);
            }
        }
        // Scalar corner: trailing positions of the tail-channel plane.
        for (int x = areaRemain; x < area; ++x) {
            for (int y = 0; y < remain; y++) {
                dstPlane[PACK_UNIT * x + y] = srcPlane[y * area + x];
            }
            // Zero-pad the missing channels.
            for (int y = remain; y < PACK_UNIT; y++) {
                dstPlane[PACK_UNIT * x + y] = 0;
            }
        }
    }
    // Right: trailing positions (fewer than 8 per row) in the full-channel planes.
    for (int z = 0; z < depthC4; ++z) {
        float* dstPlane = z * area * PACK_UNIT + dst;
        const float* srcPlane = src + z * area * PACK_UNIT;
        for (int x = areaRemain; x < area; ++x) {
            float s0 = srcPlane[x];
            float s1 = srcPlane[x + area];
            float s2 = srcPlane[x + area * 2];
            float s3 = srcPlane[x + area * 3];
            float s4 = srcPlane[x + area * 4];
            float s5 = srcPlane[x + area * 5];
            float s6 = srcPlane[x + area * 6];
            float s7 = srcPlane[x + area * 7];
            // _mm256_set_ps takes lanes high-to-low, hence reversed argument order.
            _mm256_storeu_ps(dstPlane + PACK_UNIT * x, _mm256_set_ps(s7, s6, s5, s4, s3, s2, s1, s0));
        }
    }
}
// Inverse of _AVX_MNNPackCUnit: unpack C8 interleaved data back to planar
// channel-major layout (depth rows of `area` floats). Only the `depth` real
// channels are written; pack-time zero padding is dropped.
void _AVX_MNNUnpackCUnit(float* dst, const float* src, size_t area, size_t depth) {
    auto areaC4 = area / PACK_UNIT;   // number of full 8-position tiles per row
    auto depthC4 = depth / PACK_UNIT; // number of full 8-channel groups
    // t0..t7 are written by name by the TRANSPOSE_8x8 macro below.
    __m256 t0, t1, t2, t3, t4, t5, t6, t7;
    // Main body: transpose full 8x8 tiles (8 positions x 8 channels).
    for (int z = 0; z < depthC4; ++z) {
        auto dstPlane = dst + z * area * PACK_UNIT;
        auto srcPlane = src + z * area * PACK_UNIT;
        for (int x = 0; x < areaC4; ++x) {
            auto s = srcPlane + PACK_UNIT * PACK_UNIT * x;
            auto d = dstPlane + PACK_UNIT * x;
            auto r0 = _mm256_loadu_ps(s + 0 * PACK_UNIT);
            auto r1 = _mm256_loadu_ps(s + 1 * PACK_UNIT);
            auto r2 = _mm256_loadu_ps(s + 2 * PACK_UNIT);
            auto r3 = _mm256_loadu_ps(s + 3 * PACK_UNIT);
            auto r4 = _mm256_loadu_ps(s + 4 * PACK_UNIT);
            auto r5 = _mm256_loadu_ps(s + 5 * PACK_UNIT);
            auto r6 = _mm256_loadu_ps(s + 6 * PACK_UNIT);
            auto r7 = _mm256_loadu_ps(s + 7 * PACK_UNIT);

            TRANSPOSE_8x8;

            _mm256_storeu_ps(d + 0 * area, t0);
            _mm256_storeu_ps(d + 1 * area, t1);
            _mm256_storeu_ps(d + 2 * area, t2);
            _mm256_storeu_ps(d + 3 * area, t3);
            _mm256_storeu_ps(d + 4 * area, t4);
            _mm256_storeu_ps(d + 5 * area, t5);
            _mm256_storeu_ps(d + 6 * area, t6);
            _mm256_storeu_ps(d + 7 * area, t7);
        }
    }
    auto areaRemain = areaC4 * PACK_UNIT;   // first position not covered by the tiles above
    auto depthRemain = depthC4 * PACK_UNIT; // first channel not covered by the tiles above
    // Down: leftover channels (fewer than 8) — transpose a full tile but store
    // only the `remain` valid channel rows.
    int remain = depth - depthRemain;
    if (remain > 0) {
        float* dstPlane = depthC4 * area * PACK_UNIT + dst;
        const float* srcPlane = src + depthC4 * area * PACK_UNIT;
        for (int x = 0; x < areaC4; ++x) {
            auto s = srcPlane + PACK_UNIT * PACK_UNIT * x;
            auto d = dstPlane + PACK_UNIT * x;
            auto r0 = _mm256_loadu_ps(s + 0 * PACK_UNIT);
            auto r1 = _mm256_loadu_ps(s + 1 * PACK_UNIT);
            auto r2 = _mm256_loadu_ps(s + 2 * PACK_UNIT);
            auto r3 = _mm256_loadu_ps(s + 3 * PACK_UNIT);
            auto r4 = _mm256_loadu_ps(s + 4 * PACK_UNIT);
            auto r5 = _mm256_loadu_ps(s + 5 * PACK_UNIT);
            auto r6 = _mm256_loadu_ps(s + 6 * PACK_UNIT);
            auto r7 = _mm256_loadu_ps(s + 7 * PACK_UNIT);

            TRANSPOSE_8x8;

            // Intentional fall-through: store exactly `remain` channel rows.
            switch (remain) {
                case 7:
                    _mm256_storeu_ps(d + 6 * area, t6);
                    // fall through
                case 6:
                    _mm256_storeu_ps(d + 5 * area, t5);
                    // fall through
                case 5:
                    _mm256_storeu_ps(d + 4 * area, t4);
                    // fall through
                case 4:
                    _mm256_storeu_ps(d + 3 * area, t3);
                    // fall through
                case 3:
                    _mm256_storeu_ps(d + 2 * area, t2);
                    // fall through
                case 2:
                    _mm256_storeu_ps(d + 1 * area, t1);
                    // fall through
                case 1:
                    _mm256_storeu_ps(d + 0 * area, t0);
                    // fall through
                default:
                    break;
            }
        }
        // Scalar corner: trailing positions of the tail-channel plane.
        for (int x = areaRemain; x < area; ++x) {
            for (int y = 0; y < remain; y++) {
                dstPlane[y * area + x] = srcPlane[PACK_UNIT * x + y];
            }
        }
    }
    // Right: trailing positions (fewer than 8 per row) in the full-channel planes.
    for (int z = 0; z < depthC4; ++z) {
        const float* srcPlane = z * area * PACK_UNIT + src;
        float* dstPlane = dst + z * area * PACK_UNIT;
        for (int x = areaRemain; x < area; ++x) {
            for (int y = 0; y < PACK_UNIT; y++) {
                dstPlane[y * area + x] = srcPlane[PACK_UNIT * x + y];
            }
        }
    }
}
_AVX_MNNPackCUnitTranspose(float * dst,const float * src,size_t area,size_t depth)229 void _AVX_MNNPackCUnitTranspose(float* dst, const float* src, size_t area, size_t depth) {
230 int c = (int)depth;
231 int cDiv4 = c / PACK_UNIT;
232 int cAlign = cDiv4 * PACK_UNIT;
233 for (int hi = 0; hi < area; ++hi) {
234 const float* srcHeight = src + hi * c;
235 float* dstHeight = dst + hi * PACK_UNIT;
236 for (int ci = 0; ci < cDiv4; ++ci) {
237 _mm256_storeu_ps(dstHeight + PACK_UNIT * ci * area, _mm256_loadu_ps(srcHeight + PACK_UNIT * ci));
238 }
239 }
240
241 if (cAlign == c) {
242 return;
243 }
244
245 int cReamin = c - cAlign;
246 auto srcAlign = src + cAlign;
247 auto dstAlign = dst + area * cAlign;
248
249 for (int hi = 0; hi < area; ++hi) {
250 const float* srcHeight = srcAlign + hi * c;
251 float* dstHeight = dstAlign + hi * PACK_UNIT;
252 for (int i = 0; i < PACK_UNIT; ++i) {
253 dstHeight[i] = 0;
254 }
255 for (int ci = 0; ci < cReamin; ++ci) {
256 dstHeight[ci] = srcHeight[ci];
257 }
258 }
259
260 }
_AVX_MNNUnpackCUnitTranspose(float * dst,const float * src,size_t area,size_t depth)261 void _AVX_MNNUnpackCUnitTranspose(float* dst, const float* src, size_t area, size_t depth) {
262 int c = (int)depth;
263 int cDiv4 = c / PACK_UNIT;
264 int cAlign = cDiv4 * PACK_UNIT;
265 for (int hi = 0; hi < area; ++hi) {
266 const float* srcHeight = src + hi * PACK_UNIT;
267 float* dstHeight = dst + hi * c;
268 for (int ci = 0; ci < cDiv4; ++ci) {
269 _mm256_storeu_ps(dstHeight + PACK_UNIT * ci, _mm256_loadu_ps(srcHeight + PACK_UNIT * ci * area));
270 }
271 }
272
273 if (cAlign == c) {
274 return;
275 }
276
277 int cReamin = c - cAlign;
278 auto srcAlign = src + area * cAlign;
279 auto dstAlign = dst + cAlign;
280
281 for (int hi = 0; hi < area; ++hi) {
282 const float* srcHeight = srcAlign + hi * PACK_UNIT;
283 float* dstHeight = dstAlign + hi * c;
284
285 for (int ci = 0; ci < cReamin; ++ci) {
286 dstHeight[ci] = srcHeight[ci];
287 }
288 }
289 }
290
_AVX_MNNReluWithSlopeChannel(float * dst,const float * src,const float * slope,size_t sizeQuad,size_t depthQuad)291 void _AVX_MNNReluWithSlopeChannel(float* dst, const float* src, const float* slope, size_t sizeQuad, size_t depthQuad) {
292 auto zero = _mm_set1_ps(0.0f);
293 auto zero2 = _mm256_set1_ps(0.0f);
294 int sizeC8 = sizeQuad;
295 for (int j = 0; j < depthQuad; j++) {
296 auto slopeZ = _mm256_loadu_ps(slope + 8 * j);
297 const float* srcZ = src + 8 * j * sizeQuad;
298 float* dstZ = dst + 8 * j * sizeQuad;
299 for (int i = 0; i < sizeC8; i++) {
300 auto src = _mm256_loadu_ps(srcZ);
301 auto mask0 = _mm256_cmp_ps(src, zero2, 0x01);
302 auto mask1 = _mm256_cmp_ps(src, zero2, 0x0D);
303 auto other = _mm256_mul_ps(src, slopeZ);
304 _mm256_storeu_ps(dstZ, _mm256_add_ps(_mm256_and_ps(other, mask0), _mm256_and_ps(src, mask1)));
305 srcZ += 8;
306 dstZ += 8;
307 }
308 }
309 }
310
_AVX_MNNGelu(float * dst,const float * src,size_t size)311 void _AVX_MNNGelu(float *dst, const float *src, size_t size) {
312 auto var1 = _mm256_set1_ps(0.044715f);
313 auto var2 = _mm256_set1_ps(0.79788458f);
314 auto var3 = _mm256_set1_ps(378.f);
315 auto var4 = _mm256_set1_ps(17325.f);
316 auto var5 = _mm256_set1_ps(135135.f);
317 auto var6 = _mm256_set1_ps(28.f);
318 auto var7 = _mm256_set1_ps(3150.f);
319 auto var8 = _mm256_set1_ps(62370.f);
320 auto var9 = _mm256_set1_ps(135135.f);
321 auto var10 = _mm256_set1_ps(0.5);
322 auto varOne = _mm256_set1_ps(1.f);
323 auto varNegOne = _mm256_set1_ps(-1.f);
324 for (int i = 0; i < size; i++) {
325 auto x = _mm256_load_ps(src + i * 8);
326 auto y = _mm256_mul_ps(x, x);
327 y = _mm256_mul_ps(y, x);
328 y = _mm256_mul_ps(y, var1);
329 y = _mm256_add_ps(y, x);
330 y = _mm256_mul_ps(y, var2);
331 // y = tanh(y)
332 {
333 auto y2 = _mm256_mul_ps(y, y);
334 auto w = _mm256_add_ps(y2, var3);
335 w = _mm256_mul_ps(w, y2);
336 w = _mm256_add_ps(w, var4);
337 w = _mm256_mul_ps(w, y2);
338 w = _mm256_add_ps(w, var5);
339 w = _mm256_mul_ps(w, y);
340 auto z = _mm256_mul_ps(y2, var6);
341 z = _mm256_add_ps(z, var7);
342 z = _mm256_mul_ps(z, y2);
343 z = _mm256_add_ps(z, var8);
344 z = _mm256_mul_ps(z, y2);
345 z = _mm256_add_ps(z, var9);
346 z = _mm256_div_ps(w, z);
347 z = _mm256_max_ps(z, varNegOne);
348 y = _mm256_min_ps(z, varOne);
349 }
350 y = _mm256_add_ps(y, varOne);
351 y = _mm256_mul_ps(y, x);
352 y = _mm256_mul_ps(y, var10);
353 _mm256_storeu_ps(dst + i * 8, y);
354 }
355 }
356
_AVX_MNNAxByClampBroadcastUnit(float * C,const float * A,const float * B,size_t width,size_t cStride,size_t aStride,size_t height,const float * parameters)357 void _AVX_MNNAxByClampBroadcastUnit(float* C, const float* A, const float* B, size_t width, size_t cStride, size_t aStride, size_t height, const float* parameters) {
358 auto minF = _mm256_broadcast_ss(parameters + 2);
359 auto maxF = _mm256_broadcast_ss(parameters + 3);
360 for (int y = 0; y < height; ++y) {
361 auto a = A + aStride * y;
362 auto b = B + 8 * y;
363 auto bv = _mm256_loadu_ps(b);
364 auto c = C + cStride * y;
365 for (int x = 0; x < width; ++x) {
366 auto av = _mm256_loadu_ps(a);
367 auto cv = _mm256_add_ps(av, bv);
368 cv = _mm256_min_ps(cv, maxF);
369 cv = _mm256_max_ps(cv, minF);
370 _mm256_storeu_ps(c, cv);
371 a += 8;
372 c += 8;
373 }
374 }
375 }
376
// Vectorized exponential over countC8 packets of 8 floats.
// NOTE(review): the input is negated up front (xor with -0.f), so this
// computes exp(-source[i]) — presumably matching the scalar MNNExpC8
// contract; confirm against the caller that builds `parameters`.
// `parameters` is a caller-supplied table: [0]/[1] look like the
// range-reduction pair (step and 1/step), [2..7] the polynomial
// coefficients — TODO confirm against the core's table construction.
void _AVX_MNNExpC8(float* dest, const float* source, const float* parameters, size_t countC8) {
    auto count = countC8;
    auto p0 = _mm256_set1_ps(parameters[0]);
    auto p1 = _mm256_set1_ps(parameters[1]);
    auto p2 = _mm256_set1_ps(parameters[2]);
    auto p3 = _mm256_set1_ps(parameters[3]);
    auto p4 = _mm256_set1_ps(parameters[4]);
    auto p5 = _mm256_set1_ps(parameters[5]);
    auto p6 = _mm256_set1_ps(parameters[6]);
    auto p7 = _mm256_set1_ps(parameters[7]);
    auto xMax = _mm256_set1_ps(87);   // clamp |x|: expf overflows past ~88.7
    auto xMin = _mm256_set1_ps(-87);
    auto basic = _mm256_set1_epi32(1 << 23);  // one unit of the float exponent field
    auto temp127 = _mm256_set1_epi32(127);    // IEEE-754 single exponent bias
    auto negZero = _mm256_set1_ps(-0.f);      // sign-bit mask used to negate via xor
    for (int i = 0; i < count; ++i) {
        // Negate the input and clamp it to the representable range.
        auto x = _mm256_xor_ps(_mm256_loadu_ps(source + i * 8), negZero);
        x = _mm256_max_ps(x, xMin);
        x = _mm256_min_ps(x, xMax);
        // Range reduction: x = div * p0 + remainder, with div rounded to int.
        auto div = _mm256_mul_ps(x, p1);
        auto divInt = _mm256_cvtps_epi32(div);
        div = _mm256_cvtepi32_ps(divInt);
        // Build 2^div directly in the float bit pattern: (div + 127) << 23.
        // mullo by (1 << 23) is the shift into the exponent field.
        auto div2 = _mm256_add_epi32(divInt, temp127);
        div2 = _mm256_mullo_epi32(div2, basic);
        auto expBasic = _mm256_castsi256_ps(div2);
        // Degree-5 polynomial of the remainder, evaluated Horner-style
        // with coefficients p7 (highest) down to p2 (constant term).
        auto xReamin = _mm256_sub_ps(x, _mm256_mul_ps(div, p0));
        auto t = xReamin;
        auto c0 = _mm256_mul_ps(p7, t);
        auto c1 = _mm256_add_ps(c0, p6);
        auto c2 = _mm256_mul_ps(c1, t);
        auto c3 = _mm256_add_ps(c2, p5);
        auto c4 = _mm256_mul_ps(c3, t);
        auto c5 = _mm256_add_ps(c4, p4);
        auto c6 = _mm256_mul_ps(c5, t);
        auto c7 = _mm256_add_ps(c6, p3);
        auto c8 = _mm256_mul_ps(c7, t);
        auto c9 = _mm256_add_ps(c8, p2);
        auto expRemain = c9;
        // exp(x) = 2^div * poly(remainder)
        _mm256_storeu_ps(dest + 8 * i, _mm256_mul_ps(expBasic, expRemain));
    }
}
418
_AVX_MNNConvRunForUnitDepthWise(float * dst,const float * src,const float * weight,size_t fw,size_t fh,size_t weight_y_step,size_t dilateX_step,size_t dilateY_step)419 void _AVX_MNNConvRunForUnitDepthWise(float* dst, const float* src, const float* weight, size_t fw, size_t fh,
420 size_t weight_y_step, size_t dilateX_step, size_t dilateY_step) {
421 int fx, fy;
422 __m256 dstValue = _mm256_setzero_ps();
423 const float* src_z = src;
424 const float* weight_z = weight;
425 for (fy = 0; fy < fh; ++fy) {
426 const float* src_y = src_z + fy * dilateY_step;
427 const float* weight_y = weight_z + fy * weight_y_step;
428 for (fx = 0; fx < fw; ++fx) {
429 const float* weight_x = weight_y + 8 * fx;
430 const float* src_x = src_y + fx * dilateX_step;
431 dstValue = _mm256_add_ps(dstValue, _mm256_mul_ps(_mm256_loadu_ps(src_x), _mm256_loadu_ps(weight_x)));
432 }
433 }
434 _mm256_storeu_ps(dst, dstValue);
435 }
436
// Depthwise convolution over `height` output lines of `width` C8 packets.
// The inner loop is unrolled 4 outputs at a time; the leftover (< 4) outputs
// fall back to a single-output accumulation. Two copies of the same logic
// exist: a specialization for src_w_setup == 8 (stride-1 input, constant
// offsets) and the generic-stride path below it. All steps are in floats.
void _AVX_MNNConvRunForLineDepthwise(float* dst, const float* src, const float* weight, size_t width, size_t src_w_setup,
                                     size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, size_t height,
                                     size_t srcHStep, size_t dstHStep) {
    int dx, fx, fy;
    const int unit = 4;                        // outputs produced per unrolled step
    int widthUnit = width / unit;
    int widthRemain = width - widthUnit * unit;
    const float* weight_z = weight;
    // Fast path: input packets are contiguous (stride exactly one packet),
    // so the 4 source loads use constant offsets 0/8/16/24.
    if (src_w_setup == 8) {
        for (int y = 0; y < height; ++y) {
            auto srcY = src + y * srcHStep;
            auto dstY = dst + y * dstHStep;
            for (dx = 0; dx < widthUnit; ++dx) {
                auto dstValue0 = _mm256_setzero_ps();
                auto dstValue1 = _mm256_setzero_ps();
                auto dstValue2 = _mm256_setzero_ps();
                auto dstValue3 = _mm256_setzero_ps();
                for (fy = 0; fy < fh; ++fy) {
                    const float* src_y = srcY + fy * dilateY_step;
                    const float* weight_y = weight_z + fy * fw * 8;
                    for (fx = 0; fx < fw; ++fx) {
                        const float* src_x = src_y + fx * dilateX_step;
                        const float* weight_x = weight_y + 8 * fx;
                        // One weight packet feeds all 4 unrolled outputs.
                        auto weightValue = _mm256_loadu_ps(weight_x);
                        dstValue0 = _mm256_add_ps(dstValue0, _mm256_mul_ps(_mm256_loadu_ps(src_x + 0 * 8), weightValue));
                        dstValue1 = _mm256_add_ps(dstValue1, _mm256_mul_ps(_mm256_loadu_ps(src_x + 1 * 8), weightValue));
                        dstValue2 = _mm256_add_ps(dstValue2, _mm256_mul_ps(_mm256_loadu_ps(src_x + 2 * 8), weightValue));
                        dstValue3 = _mm256_add_ps(dstValue3, _mm256_mul_ps(_mm256_loadu_ps(src_x + 3 * 8), weightValue));
                    }
                }
                _mm256_storeu_ps(dstY + 8 * 0, dstValue0);
                _mm256_storeu_ps(dstY + 8 * 1, dstValue1);
                _mm256_storeu_ps(dstY + 8 * 2, dstValue2);
                _mm256_storeu_ps(dstY + 8 * 3, dstValue3);
                dstY += 8 * unit;
                srcY += unit * src_w_setup;
            }
            // Leftover outputs (< 4): one window accumulation each.
            for (dx = 0; dx < widthRemain; ++dx) {
                float* dst_x = dstY + dx * 8;
                auto dstValue = _mm256_setzero_ps();
                const float* src_z = srcY + src_w_setup * dx;
                const float* weight_z = weight;
                for (fy = 0; fy < fh; ++fy) {
                    const float* src_y = src_z + fy * dilateY_step;
                    const float* weight_y = weight_z + fy * fw * 8;
                    for (fx = 0; fx < fw; ++fx) {
                        const float* weight_x = weight_y + 8 * fx;
                        const float* src_x = src_y + fx * dilateX_step;
                        dstValue = _mm256_add_ps(dstValue, _mm256_mul_ps(_mm256_loadu_ps(src_x), _mm256_loadu_ps(weight_x)));
                    }
                }
                _mm256_storeu_ps(dst_x, dstValue);
            }
        }
        return;
    }
    // Generic path: same structure, but the 4 unrolled source loads step by
    // the caller-provided src_w_setup.
    for (int y = 0; y < height; ++y) {
        auto srcY = src + y * srcHStep;
        auto dstY = dst + y * dstHStep;
        for (dx = 0; dx < widthUnit; ++dx) {
            auto dstValue0 = _mm256_setzero_ps();
            auto dstValue1 = _mm256_setzero_ps();
            auto dstValue2 = _mm256_setzero_ps();
            auto dstValue3 = _mm256_setzero_ps();
            for (fy = 0; fy < fh; ++fy) {
                const float* src_y = srcY + fy * dilateY_step;
                const float* weight_y = weight_z + fy * fw * 8;
                for (fx = 0; fx < fw; ++fx) {
                    const float* src_x = src_y + fx * dilateX_step;
                    const float* weight_x = weight_y + 8 * fx;
                    auto weightValue = _mm256_loadu_ps(weight_x);
                    dstValue0 = _mm256_add_ps(dstValue0, _mm256_mul_ps(_mm256_loadu_ps(src_x + 0 * src_w_setup), weightValue));
                    dstValue1 = _mm256_add_ps(dstValue1, _mm256_mul_ps(_mm256_loadu_ps(src_x + 1 * src_w_setup), weightValue));
                    dstValue2 = _mm256_add_ps(dstValue2, _mm256_mul_ps(_mm256_loadu_ps(src_x + 2 * src_w_setup), weightValue));
                    dstValue3 = _mm256_add_ps(dstValue3, _mm256_mul_ps(_mm256_loadu_ps(src_x + 3 * src_w_setup), weightValue));
                }
            }
            _mm256_storeu_ps(dstY + 8 * 0, dstValue0);
            _mm256_storeu_ps(dstY + 8 * 1, dstValue1);
            _mm256_storeu_ps(dstY + 8 * 2, dstValue2);
            _mm256_storeu_ps(dstY + 8 * 3, dstValue3);
            dstY += 8 * unit;
            srcY += unit * src_w_setup;
        }
        // Leftover outputs (< 4): one window accumulation each.
        for (dx = 0; dx < widthRemain; ++dx) {
            float* dst_x = dstY + dx * 8;
            auto dstValue = _mm256_setzero_ps();
            const float* src_z = srcY + src_w_setup * dx;
            const float* weight_z = weight;
            for (fy = 0; fy < fh; ++fy) {
                const float* src_y = src_z + fy * dilateY_step;
                const float* weight_y = weight_z + fy * fw * 8;
                for (fx = 0; fx < fw; ++fx) {
                    const float* weight_x = weight_y + 8 * fx;
                    const float* src_x = src_y + fx * dilateX_step;
                    dstValue = _mm256_add_ps(dstValue, _mm256_mul_ps(_mm256_loadu_ps(src_x), _mm256_loadu_ps(weight_x)));
                }
            }
            _mm256_storeu_ps(dst_x, dstValue);
        }
    }
}
539
_AVX_MNNMultiAndDestTransformCommon23(float ** cacheLine,const float * weigth,float * dest,int cacheLineSize,int ow,const float * bias,const float * parameter)540 void _AVX_MNNMultiAndDestTransformCommon23(float **cacheLine, const float *weigth, float *dest, int cacheLineSize, int ow, const float* bias, const float* parameter) {
541 int unit = ow / 2;
542 MNN_ASSERT(cacheLineSize >= 1);
543 auto biasF = Vec8::load(bias);
544 auto minF = Vec8(parameter[2]);
545 auto maxF = Vec8(parameter[3]);
546 for (int x = 0; x < unit; ++x) {
547 auto offset = 4 * 8 * x;
548 int i = 0;
549 Vec8 m0 = Vec8::load(weigth + i * 32 + 8 * 0) * Vec8::load(cacheLine[i] + offset + 8 * 0);
550 Vec8 m1 = Vec8::load(weigth + i * 32 + 8 * 1) * Vec8::load(cacheLine[i] + offset + 8 * 1);
551 Vec8 m2 = Vec8::load(weigth + i * 32 + 8 * 2) * Vec8::load(cacheLine[i] + offset + 8 * 2);
552 Vec8 m3 = Vec8::load(weigth + i * 32 + 8 * 3) * Vec8::load(cacheLine[i] + offset + 8 * 3);
553
554 for (i = 1; i < cacheLineSize; ++i) {
555 m0 = m0 + Vec8::load(weigth + i * 32 + 8 * 0) * Vec8::load(cacheLine[i] + offset + 8 * 0);
556 m1 = m1 + Vec8::load(weigth + i * 32 + 8 * 1) * Vec8::load(cacheLine[i] + offset + 8 * 1);
557 m2 = m2 + Vec8::load(weigth + i * 32 + 8 * 2) * Vec8::load(cacheLine[i] + offset + 8 * 2);
558 m3 = m3 + Vec8::load(weigth + i * 32 + 8 * 3) * Vec8::load(cacheLine[i] + offset + 8 * 3);
559 }
560 auto o0 = m0 + m1 + m2 + biasF;
561 auto o1 = m1 - m2 + m3 + biasF;
562 o0 = Vec8::min(maxF, o0);
563 o1 = Vec8::min(maxF, o1);
564 o0 = Vec8::max(minF, o0);
565 o1 = Vec8::max(minF, o1);
566
567 Vec8::save(dest + 16 * x + 0 * 8, o0);
568 Vec8::save(dest + 16 * x + 1 * 8, o1);
569 }
570 if (unit * 2 < ow) {
571 auto offset = 8 * 4 * unit;
572 int i = 0;
573 Vec8 m0 = Vec8::load(weigth + i * 32 + 8 * 0) * Vec8::load(cacheLine[i] + offset + 8 * 0);
574 Vec8 m1 = Vec8::load(weigth + i * 32 + 8 * 1) * Vec8::load(cacheLine[i] + offset + 8 * 1);
575 Vec8 m2 = Vec8::load(weigth + i * 32 + 8 * 2) * Vec8::load(cacheLine[i] + offset + 8 * 2);
576
577 for (i = 1; i < cacheLineSize; ++i) {
578 m0 = m0 + Vec8::load(weigth + i * 32 + 8 * 0) * Vec8::load(cacheLine[i] + offset + 8 * 0);
579 m1 = m1 + Vec8::load(weigth + i * 32 + 8 * 1) * Vec8::load(cacheLine[i] + offset + 8 * 1);
580 m2 = m2 + Vec8::load(weigth + i * 32 + 8 * 2) * Vec8::load(cacheLine[i] + offset + 8 * 2);
581 }
582 auto o0 = m0 + m1 + m2 + biasF;
583 o0 = Vec8::min(maxF, o0);
584 o0 = Vec8::max(minF, o0);
585 Vec8::save(dest + 16 * unit + 0 * 8, o0);
586 }
587 }
_AVX_MNNConvDwF23SourceTransUnit(const float * source,float * dest,size_t unit)588 static void _AVX_MNNConvDwF23SourceTransUnit(const float *source, float *dest, size_t unit) {
589 if (unit <= 0) {
590 return;
591 }
592 Vec8 v0 = Vec8::load(source + 8 * 0);
593 Vec8 v1 = Vec8::load(source + 8 * 1);
594 Vec8 v2;
595 Vec8 v3;
596 source += 16;
597
598 for (int x = 0; x < unit; ++x) {
599 v2 = Vec8::load(source + 0 * 8);
600 v3 = Vec8::load(source + 1 * 8);
601 auto m0 = v0 - v2;
602 auto m1 = v1 + v2;
603 auto m2 = v2 - v1;
604 auto m3 = v3 - v1;
605
606 Vec8::save(dest + 8 * 0, m0);
607 Vec8::save(dest + 8 * 1, m1);
608 Vec8::save(dest + 8 * 2, m2);
609 Vec8::save(dest + 8 * 3, m3);
610
611 source += 16;
612 dest += 32;
613
614 v0 = v2;
615 v1 = v3;
616 }
617 }
618
_AVX_MNNSourceTransformCommonF23(const float * source,float * dest,int unit,int iw,int pad,int su,int eu)619 void _AVX_MNNSourceTransformCommonF23(const float *source, float *dest, int unit, int iw, int pad, int su, int eu) {
620 for (int x = 0; x < su; ++x) {
621 auto dstX = dest + 4 * 8 * x;
622 auto sx = x * 2 - (int)pad;
623 auto ex = sx + 4;
624
625 auto clampSx = std::max(sx, 0);
626 auto clampEx = std::min(ex, (int)iw);
627
628 Vec8 v[4] = {0.0f, 0.0f, 0.0f, 0.0f};
629 for (int i = clampSx; i < clampEx; ++i) {
630 v[i - sx] = Vec8::load(source + 8 * i);
631 }
632 auto m0 = v[0] - v[2];
633 auto m1 = v[1] + v[2];
634 auto m2 = v[2] - v[1];
635 auto m3 = v[3] - v[1];
636
637 Vec8::save(dstX + 8 * 0, m0);
638 Vec8::save(dstX + 8 * 1, m1);
639 Vec8::save(dstX + 8 * 2, m2);
640 Vec8::save(dstX + 8 * 3, m3);
641 }
642 _AVX_MNNConvDwF23SourceTransUnit(source + 8 * (su * 2 - pad), dest + 8 * 4 * su, eu - su);
643
644 for (int x = eu; x < unit; ++x) {
645 auto dstX = dest + 8 * 4 * x;
646 auto sx = x * 2 - (int)pad;
647 auto ex = sx + 4;
648
649 auto clampSx = std::max(sx, 0);
650 auto clampEx = std::min(ex, (int)iw);
651
652 Vec8 v[4] = {0.0f, 0.0f, 0.0f, 0.0f};
653 for (int i = clampSx; i < clampEx; ++i) {
654 v[i - sx] = Vec8::load(source + 8 * i);
655 }
656 auto m0 = v[0] - v[2];
657 auto m1 = v[1] + v[2];
658 auto m2 = v[2] - v[1];
659 auto m3 = v[3] - v[1];
660
661 Vec8::save(dstX + 8 * 0, m0);
662 Vec8::save(dstX + 8 * 1, m1);
663 Vec8::save(dstX + 8 * 2, m2);
664 Vec8::save(dstX + 8 * 3, m3);
665 }
666 }
667
_AVX_MNNConvDwF23MulTransUnit(float ** cacheLine,const float * weigth,float * dest,size_t ow,const float * bias,const float * parameter)668 void _AVX_MNNConvDwF23MulTransUnit(float **cacheLine, const float *weigth, float *dest, size_t ow, const float* bias, const float* parameter) {
669 int unit = ow / 2;
670 auto w00 = Vec8::load(weigth + 0 * 32 + 8 * 0);
671 auto w01 = Vec8::load(weigth + 0 * 32 + 8 * 1);
672 auto w02 = Vec8::load(weigth + 0 * 32 + 8 * 2);
673 auto w03 = Vec8::load(weigth + 0 * 32 + 8 * 3);
674 auto w10 = Vec8::load(weigth + 1 * 32 + 8 * 0);
675 auto w11 = Vec8::load(weigth + 1 * 32 + 8 * 1);
676 auto w12 = Vec8::load(weigth + 1 * 32 + 8 * 2);
677 auto w13 = Vec8::load(weigth + 1 * 32 + 8 * 3);
678 auto w20 = Vec8::load(weigth + 2 * 32 + 8 * 0);
679 auto w21 = Vec8::load(weigth + 2 * 32 + 8 * 1);
680 auto w22 = Vec8::load(weigth + 2 * 32 + 8 * 2);
681 auto w23 = Vec8::load(weigth + 2 * 32 + 8 * 3);
682 auto biasF = Vec8::load(bias);
683 auto minF = Vec8(parameter[2]);
684 auto maxF = Vec8(parameter[3]);
685
686 for (int x = 0; x < unit; ++x) {
687 auto offset = 8 * 4 * x;
688 int i = 0;
689 Vec8 m0 = w00 * Vec8::load(cacheLine[0] + offset + 8 * 0);
690 Vec8 m1 = w01 * Vec8::load(cacheLine[0] + offset + 8 * 1);
691 Vec8 m2 = w02 * Vec8::load(cacheLine[0] + offset + 8 * 2);
692 Vec8 m3 = w03 * Vec8::load(cacheLine[0] + offset + 8 * 3);
693
694 m0 = m0 + w10 * Vec8::load(cacheLine[1] + offset + 8 * 0);
695 m1 = m1 + w11 * Vec8::load(cacheLine[1] + offset + 8 * 1);
696 m2 = m2 + w12 * Vec8::load(cacheLine[1] + offset + 8 * 2);
697 m3 = m3 + w13 * Vec8::load(cacheLine[1] + offset + 8 * 3);
698
699 m0 = m0 + w20 * Vec8::load(cacheLine[2] + offset + 8 * 0);
700 m1 = m1 + w21 * Vec8::load(cacheLine[2] + offset + 8 * 1);
701 m2 = m2 + w22 * Vec8::load(cacheLine[2] + offset + 8 * 2);
702 m3 = m3 + w23 * Vec8::load(cacheLine[2] + offset + 8 * 3);
703
704 auto o0 = m0 + m1 + m2 + biasF;
705 auto o1 = m1 - m2 + m3 + biasF;
706 o0 = Vec8::min(maxF, o0);
707 o1 = Vec8::min(maxF, o1);
708 o0 = Vec8::max(minF, o0);
709 o1 = Vec8::max(minF, o1);
710 Vec8::save(dest + 16 * x + 0 * 8, o0);
711 Vec8::save(dest + 16 * x + 1 * 8, o1);
712 }
713 if (unit * 2 < ow) {
714 auto offset = 8 * 4 * unit;
715 Vec8 m0 = w00 * Vec8::load(cacheLine[0] + offset + 8 * 0);
716 Vec8 m1 = w01 * Vec8::load(cacheLine[0] + offset + 8 * 1);
717 Vec8 m2 = w02 * Vec8::load(cacheLine[0] + offset + 8 * 2);
718
719 m0 = m0 + w10 * Vec8::load(cacheLine[1] + offset + 8 * 0);
720 m1 = m1 + w11 * Vec8::load(cacheLine[1] + offset + 8 * 1);
721 m2 = m2 + w12 * Vec8::load(cacheLine[1] + offset + 8 * 2);
722
723 m0 = m0 + w20 * Vec8::load(cacheLine[2] + offset + 8 * 0);
724 m1 = m1 + w21 * Vec8::load(cacheLine[2] + offset + 8 * 1);
725 m2 = m2 + w22 * Vec8::load(cacheLine[2] + offset + 8 * 2);
726 auto o0 = m0 + m1 + m2 + biasF;
727 o0 = Vec8::min(maxF, o0);
728 o0 = Vec8::max(minF, o0);
729 Vec8::save(dest + 16 * unit + 0 * 8, o0);
730 }
731 }
732
_AVX2_MNNSelectBinaryFunctionForFloat(int opType)733 static MNNBinaryExecute _AVX2_MNNSelectBinaryFunctionForFloat(int opType) {
734 auto vecF = MNN::selectVector<Vec8, 8>(opType);
735 if (nullptr != vecF) {
736 return vecF;
737 }
738 return MNN::MNNGetCoreFunctions()->MNNSelectBinaryFunctionForFloat(opType);
739 }
740
741
_AVX_ExtraInit(void * functions)742 void _AVX_ExtraInit(void* functions) {
743 auto coreFunction = static_cast<MNN::CoreFunctions*>(functions);
744 coreFunction->MNNPoolingAvg = (decltype(coreFunction->MNNPoolingAvg))(MNN::poolingAvg<float, Vec8, 8>);
745 // Set min value as 1 << 24
746 coreFunction->MNNPoolingMax = (decltype(coreFunction->MNNPoolingMax))(MNN::poolingMax<float, Vec8, 8, -16777216>);
747 coreFunction->MNNSelectBinaryFunctionForFloat = _AVX2_MNNSelectBinaryFunctionForFloat;
748 }
749
_AVX_MNNScaleAndAddBias(float * dst,const float * src,const float * bias,const float * alpha,size_t planeNumber,size_t biasNumber)750 void _AVX_MNNScaleAndAddBias(float* dst, const float* src, const float* bias, const float* alpha, size_t planeNumber,
751 size_t biasNumber) {
752 for (int z = 0; z < biasNumber; ++z) {
753 float* dstZ = dst + planeNumber * 8 * z;
754 const float* srcZ = src + planeNumber * 8 * z;
755 auto biasZ = Vec8::load(bias + 8 * z);
756 auto alphaZ = Vec8::load(alpha + 8 * z);
757 for (int p = 0; p < planeNumber; ++p) {
758 float* dstX = dstZ + 8 * p;
759 const float* srcX = srcZ + 8 * p;
760 Vec8::save(dstX, (Vec8::load(srcX) * alphaZ) + biasZ);
761 }
762 }
763 }
764
_AVX_MNNDeconvRunForUnitDepthWise(const float * dst,float * src,const float * weight,size_t fw,size_t fh,size_t weight_y_step,size_t dilateX_step,size_t dilateY_step)765 void _AVX_MNNDeconvRunForUnitDepthWise(const float* dst, float* src, const float* weight, size_t fw, size_t fh,
766 size_t weight_y_step, size_t dilateX_step, size_t dilateY_step) {
767 int fx, fy;
768 float* src_z = src;
769 const float* weight_z = weight;
770 Vec8 dstV = Vec8::load(dst);
771 for (fy = 0; fy < fh; ++fy) {
772 float* src_y = src_z + fy * dilateY_step;
773 const float* weight_y = weight_z + fy * weight_y_step;
774 for (fx = 0; fx < fw; ++fx) {
775 Vec8 weight_x = Vec8::load(weight_y + 8 * fx);
776 Vec8 src_x = Vec8::load(src_y + fx * dilateX_step);
777 Vec8::save(src_y + fx * dilateX_step, src_x + weight_x * dstV);
778 }
779 }
780 }
_AVX_MNNDeconvRunForLineDepthwise(const float * dst,float * src,const float * weight,size_t width,size_t src_w_setup,size_t fw,size_t fh,size_t dilateX_step,size_t dilateY_step)781 void _AVX_MNNDeconvRunForLineDepthwise(const float* dst, float* src, const float* weight, size_t width, size_t src_w_setup,
782 size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step) {
783 int dx;
784 for (dx = 0; dx < width; ++dx) {
785 const float* dst_x = dst + dx * 8;
786 float* src_dx = src + src_w_setup * dx;
787 _AVX_MNNDeconvRunForUnitDepthWise(dst_x, src_dx, weight, fw, fh, fw * 8, dilateX_step, dilateY_step);
788 }
789 }
790