1 //
2 //  GemmCommon.cpp
3 //  MNN
4 //
5 //  Created by MNN on b'2020/09/22'.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #ifdef MNN_AVX512_VNNI
10 #include "FunctionSummary.hpp"
11 #include "core/Macro.h"
12 
13 namespace {
mm_loadu_si128(const void * addr)14 static inline __m128i mm_loadu_si128(const void* addr) {
15     return _mm_loadu_si128((__m128i const*)addr);
16 }
17 }  // namespace
18 
_AVX512_MNNGemmInt8AddBiasScale_16x4_Unit(int8_t * dst,const int8_t * src,const int8_t * weight,size_t src_depth_quad,size_t dst_step,size_t dst_depth_quad,const QuanPostTreatParameters * post,size_t realDst)19 void _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDst) {
20     const auto dst_step_tmp = dst_step / sizeof(int8_t);
21     __m128 zero128 = _mm_set1_ps(0.0f);
22     __m128 minValue = _mm_set1_ps(post->minValue);
23     __m128 maxValue = _mm_set1_ps(post->maxValue);
24     __m128 plus = _mm_set1_ps(0.5f);
25     __m128 minus = _mm_set1_ps(-0.5f);
26     if (realDst == 4) {
27         for (int dz = 0; dz < dst_depth_quad; ++dz) {
28             const auto weight_dz = weight + dz * src_depth_quad * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT);
29             const auto bias_dz = post->bias + dz * GEMM_INT8_UNIT;
30             const float* scale_dz = nullptr;
31             if (post->scale != nullptr) {
32                 scale_dz  = post->scale + dz * GEMM_INT8_UNIT;
33             }
34             auto dst_z           = dst + dz * dst_step_tmp;
35             const auto src_x   = src;
36             auto dst_x         = dst_z;
37             __m512i D0 = _mm512_set1_epi32(0);
38             __m512i D1 = _mm512_set1_epi32(0);
39             __m512i D2 = _mm512_set1_epi32(0);
40             __m512i D3 = _mm512_set1_epi32(0);
41 
42             for (int sz = 0; sz < src_depth_quad; ++sz) {
43                 const auto weight_sz = weight_dz + (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT) * sz;
44                 const auto src_z     = src_x + sz * GEMM_INT8_DST_XUNIT * GEMM_INT8_SRC_UNIT;
45                 auto w0 = _mm512_loadu_si512(weight_sz + GEMM_INT8_SRC_UNIT * 0);
46 
47                 auto s0 = _mm512_broadcast_i32x4(mm_loadu_si128(src_z + GEMM_INT8_SRC_UNIT * 0));
48                 auto s1 = _mm512_broadcast_i32x4(mm_loadu_si128(src_z + GEMM_INT8_SRC_UNIT * 1));
49                 auto s2 = _mm512_broadcast_i32x4(mm_loadu_si128(src_z + GEMM_INT8_SRC_UNIT * 2));
50                 auto s3 = _mm512_broadcast_i32x4(mm_loadu_si128(src_z + GEMM_INT8_SRC_UNIT * 3));
51 
52                 D0 = _mm512_dpbusds_epi32(D0, s0, w0);
53                 D1 = _mm512_dpbusds_epi32(D1, s1, w0);
54                 D2 = _mm512_dpbusds_epi32(D2, s2, w0);
55                 D3 = _mm512_dpbusds_epi32(D3, s3, w0);
56             }
57             auto d00 = _mm512_extracti32x4_epi32(D0, 0);
58             auto d01 = _mm512_extracti32x4_epi32(D0, 1);
59             auto d02 = _mm512_extracti32x4_epi32(D0, 2);
60             auto d03 = _mm512_extracti32x4_epi32(D0, 3);
61 
62             auto d10 = _mm512_extracti32x4_epi32(D1, 0);
63             auto d11 = _mm512_extracti32x4_epi32(D1, 1);
64             auto d12 = _mm512_extracti32x4_epi32(D1, 2);
65             auto d13 = _mm512_extracti32x4_epi32(D1, 3);
66 
67             auto d20 = _mm512_extracti32x4_epi32(D2, 0);
68             auto d21 = _mm512_extracti32x4_epi32(D2, 1);
69             auto d22 = _mm512_extracti32x4_epi32(D2, 2);
70             auto d23 = _mm512_extracti32x4_epi32(D2, 3);
71 
72             auto d30 = _mm512_extracti32x4_epi32(D3, 0);
73             auto d31 = _mm512_extracti32x4_epi32(D3, 1);
74             auto d32 = _mm512_extracti32x4_epi32(D3, 2);
75             auto d33 = _mm512_extracti32x4_epi32(D3, 3);
76 
77             d00 = _mm_hadd_epi32(d00, d01);
78             d02 = _mm_hadd_epi32(d02, d03);
79             d10 = _mm_hadd_epi32(d10, d11);
80             d12 = _mm_hadd_epi32(d12, d13);
81             d20 = _mm_hadd_epi32(d20, d21);
82             d22 = _mm_hadd_epi32(d22, d23);
83             d30 = _mm_hadd_epi32(d30, d31);
84             d32 = _mm_hadd_epi32(d32, d33);
85 
86             auto d0 = _mm_hadd_epi32(d00, d02);
87             auto d1 = _mm_hadd_epi32(d10, d12);
88             auto d2 = _mm_hadd_epi32(d20, d22);
89             auto d3 = _mm_hadd_epi32(d30, d32);
90 
91             if (post->scale != nullptr) {
92                 auto biasValue = _mm_loadu_si128((__m128i*)(bias_dz));
93                 d0 = _mm_add_epi32(d0, biasValue);
94                 d1 = _mm_add_epi32(d1, biasValue);
95                 d2 = _mm_add_epi32(d2, biasValue);
96                 d3 = _mm_add_epi32(d3, biasValue);
97                 auto scaleValue = _mm_loadu_ps(scale_dz);
98                 __m128 f0 = _mm_cvtepi32_ps(d0);
99                 __m128 f1 = _mm_cvtepi32_ps(d1);
100                 __m128 f2 = _mm_cvtepi32_ps(d2);
101                 __m128 f3 = _mm_cvtepi32_ps(d3);
102                 f0 = _mm_mul_ps(f0, scaleValue);
103                 f1 = _mm_mul_ps(f1, scaleValue);
104                 f2 = _mm_mul_ps(f2, scaleValue);
105                 f3 = _mm_mul_ps(f3, scaleValue);
106                 f0 = _mm_min_ps(f0, maxValue);
107                 f1 = _mm_min_ps(f1, maxValue);
108                 f2 = _mm_min_ps(f2, maxValue);
109                 f3 = _mm_min_ps(f3, maxValue);
110                 f0 = _mm_max_ps(f0, minValue);
111                 f1 = _mm_max_ps(f1, minValue);
112                 f2 = _mm_max_ps(f2, minValue);
113                 f3 = _mm_max_ps(f3, minValue);
114                 auto m0 = _mm_cmplt_ps(f0, zero128);
115                 auto m1 = _mm_cmplt_ps(f1, zero128);
116                 auto m2 = _mm_cmplt_ps(f2, zero128);
117                 auto m3 = _mm_cmplt_ps(f3, zero128);
118                 m0 = _mm_blendv_ps(plus, minus, m0);
119                 m1 = _mm_blendv_ps(plus, minus, m1);
120                 m2 = _mm_blendv_ps(plus, minus, m2);
121                 m3 = _mm_blendv_ps(plus, minus, m3);
122                 f0 = _mm_add_ps(f0, m0);
123                 f1 = _mm_add_ps(f1, m1);
124                 f2 = _mm_add_ps(f2, m2);
125                 f3 = _mm_add_ps(f3, m3);
126                 // 3: _MM_FROUND_TO_ZERO
127                 d0 = _mm_cvtps_epi32(_mm_round_ps(f0, 3));
128                 d1 = _mm_cvtps_epi32(_mm_round_ps(f1, 3));
129                 d2 = _mm_cvtps_epi32(_mm_round_ps(f2, 3));
130                 d3 = _mm_cvtps_epi32(_mm_round_ps(f3, 3));
131 
132                 // Int32 -> Int8
133                 d0 = _mm_packs_epi32(d0, d1);
134                 d2 = _mm_packs_epi32(d2, d3);
135                 d0 = _mm_packs_epi16(d0, d2);
136                 _mm_storeu_ps((float*)dst_x, _mm_castsi128_ps(d0));
137             } else {
138                 auto biasValue = _mm_loadu_si128((__m128i*)(bias_dz));
139                 __m128 f0 = _mm_cvtepi32_ps(_mm_add_epi32(d0, biasValue));
140                 __m128 f1 = _mm_cvtepi32_ps(_mm_add_epi32(d1, biasValue));
141                 __m128 f2 = _mm_cvtepi32_ps(_mm_add_epi32(d2, biasValue));
142                 __m128 f3 = _mm_cvtepi32_ps(_mm_add_epi32(d3, biasValue));
143                 _mm_storeu_ps(((float*)dst_x), f0);
144                 _mm_storeu_ps(((float*)dst_x) + 4, f1);
145                 _mm_storeu_ps(((float*)dst_x) + 8, f2);
146                 _mm_storeu_ps(((float*)dst_x) + 12, f3);
147             }
148         }
149     }
150     if (realDst == 3) {
151         for (int dz = 0; dz < dst_depth_quad; ++dz) {
152             const auto weight_dz = weight + dz * src_depth_quad * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT);
153             const auto bias_dz = post->bias + dz * GEMM_INT8_UNIT;
154             const float* scale_dz = nullptr;
155             if (post->scale != nullptr) {
156                 scale_dz  = post->scale + dz * GEMM_INT8_UNIT;
157             }
158             auto dst_z           = dst + dz * dst_step_tmp;
159             const auto src_x   = src;
160             auto dst_x         = dst_z;
161             __m512i D0 = _mm512_set1_epi32(0);
162             __m512i D1 = _mm512_set1_epi32(0);
163             __m512i D2 = _mm512_set1_epi32(0);
164 
165             for (int sz = 0; sz < src_depth_quad; ++sz) {
166                 const auto weight_sz = weight_dz + (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT) * sz;
167                 const auto src_z     = src_x + sz * GEMM_INT8_DST_XUNIT * GEMM_INT8_SRC_UNIT;
168                 auto w0 = _mm512_loadu_si512(weight_sz + GEMM_INT8_SRC_UNIT * 0);
169 
170                 auto s0 = _mm512_broadcast_i32x4(mm_loadu_si128(src_z + GEMM_INT8_SRC_UNIT * 0));
171                 auto s1 = _mm512_broadcast_i32x4(mm_loadu_si128(src_z + GEMM_INT8_SRC_UNIT * 1));
172                 auto s2 = _mm512_broadcast_i32x4(mm_loadu_si128(src_z + GEMM_INT8_SRC_UNIT * 2));
173 
174                 D0 = _mm512_dpbusds_epi32(D0, s0, w0);
175                 D1 = _mm512_dpbusds_epi32(D1, s1, w0);
176                 D2 = _mm512_dpbusds_epi32(D2, s2, w0);
177             }
178             auto d00 = _mm512_extracti32x4_epi32(D0, 0);
179             auto d01 = _mm512_extracti32x4_epi32(D0, 1);
180             auto d02 = _mm512_extracti32x4_epi32(D0, 2);
181             auto d03 = _mm512_extracti32x4_epi32(D0, 3);
182 
183             auto d10 = _mm512_extracti32x4_epi32(D1, 0);
184             auto d11 = _mm512_extracti32x4_epi32(D1, 1);
185             auto d12 = _mm512_extracti32x4_epi32(D1, 2);
186             auto d13 = _mm512_extracti32x4_epi32(D1, 3);
187 
188             auto d20 = _mm512_extracti32x4_epi32(D2, 0);
189             auto d21 = _mm512_extracti32x4_epi32(D2, 1);
190             auto d22 = _mm512_extracti32x4_epi32(D2, 2);
191             auto d23 = _mm512_extracti32x4_epi32(D2, 3);
192 
193             d00 = _mm_hadd_epi32(d00, d01);
194             d02 = _mm_hadd_epi32(d02, d03);
195             d10 = _mm_hadd_epi32(d10, d11);
196             d12 = _mm_hadd_epi32(d12, d13);
197             d20 = _mm_hadd_epi32(d20, d21);
198             d22 = _mm_hadd_epi32(d22, d23);
199 
200             auto d0 = _mm_hadd_epi32(d00, d02);
201             auto d1 = _mm_hadd_epi32(d10, d12);
202             auto d2 = _mm_hadd_epi32(d20, d22);
203 
204             if (post->scale != nullptr) {
205                 auto biasValue = _mm_loadu_si128((__m128i*)(bias_dz));
206                 d0 = _mm_add_epi32(d0, biasValue);
207                 d1 = _mm_add_epi32(d1, biasValue);
208                 d2 = _mm_add_epi32(d2, biasValue);
209                 auto scaleValue = _mm_loadu_ps(scale_dz);
210                 __m128 f0 = _mm_cvtepi32_ps(d0);
211                 __m128 f1 = _mm_cvtepi32_ps(d1);
212                 __m128 f2 = _mm_cvtepi32_ps(d2);
213                 f0 = _mm_mul_ps(f0, scaleValue);
214                 f1 = _mm_mul_ps(f1, scaleValue);
215                 f2 = _mm_mul_ps(f2, scaleValue);
216                 f0 = _mm_min_ps(f0, maxValue);
217                 f1 = _mm_min_ps(f1, maxValue);
218                 f2 = _mm_min_ps(f2, maxValue);
219                 f0 = _mm_max_ps(f0, minValue);
220                 f1 = _mm_max_ps(f1, minValue);
221                 f2 = _mm_max_ps(f2, minValue);
222                 auto m0 = _mm_cmplt_ps(f0, zero128);
223                 auto m1 = _mm_cmplt_ps(f1, zero128);
224                 auto m2 = _mm_cmplt_ps(f2, zero128);
225                 m0 = _mm_blendv_ps(plus, minus, m0);
226                 m1 = _mm_blendv_ps(plus, minus, m1);
227                 m2 = _mm_blendv_ps(plus, minus, m2);
228                 f0 = _mm_add_ps(f0, m0);
229                 f1 = _mm_add_ps(f1, m1);
230                 f2 = _mm_add_ps(f2, m2);
231                 // 3: _MM_FROUND_TO_ZERO
232                 d0 = _mm_cvtps_epi32(_mm_round_ps(f0, 3));
233                 d1 = _mm_cvtps_epi32(_mm_round_ps(f1, 3));
234                 d2 = _mm_cvtps_epi32(_mm_round_ps(f2, 3));
235 
236                 // Int32 -> Int8
237                 d0 = _mm_packs_epi32(d0, d1);
238                 d2 = _mm_packs_epi32(d2, d2);
239                 d0 = _mm_packs_epi16(d0, d2);
240                 int32_t tempV[4];
241                 _mm_storeu_si128((__m128i*)tempV, d0);
242                 for (int j=0; j<realDst; ++j) {
243                     ((int32_t*)dst_x)[j] = tempV[j];
244                 }
245             } else {
246                 auto biasValue = _mm_loadu_si128((__m128i*)(bias_dz));
247                 __m128 f0 = _mm_cvtepi32_ps(_mm_add_epi32(d0, biasValue));
248                 __m128 f1 = _mm_cvtepi32_ps(_mm_add_epi32(d1, biasValue));
249                 __m128 f2 = _mm_cvtepi32_ps(_mm_add_epi32(d2, biasValue));
250                 _mm_storeu_ps(((float*)dst_x), f0);
251                 _mm_storeu_ps(((float*)dst_x) + 4, f1);
252                 _mm_storeu_ps(((float*)dst_x) + 8, f2);
253             }
254 
255         }
256     }
257     if (realDst == 2) {
258         for (int dz = 0; dz < dst_depth_quad; ++dz) {
259             const auto weight_dz = weight + dz * src_depth_quad * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT);
260             const auto bias_dz = post->bias + dz * GEMM_INT8_UNIT;
261             const float* scale_dz = nullptr;
262             if (post->scale != nullptr) {
263                 scale_dz  = post->scale + dz * GEMM_INT8_UNIT;
264             }
265             auto dst_z           = dst + dz * dst_step_tmp;
266             const auto src_x   = src;
267             auto dst_x         = dst_z;
268             __m512i D0 = _mm512_set1_epi32(0);
269             __m512i D1 = _mm512_set1_epi32(0);
270 
271             for (int sz = 0; sz < src_depth_quad; ++sz) {
272                 const auto weight_sz = weight_dz + (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT) * sz;
273                 const auto src_z     = src_x + sz * GEMM_INT8_DST_XUNIT * GEMM_INT8_SRC_UNIT;
274                 auto w0 = _mm512_loadu_si512(weight_sz + GEMM_INT8_SRC_UNIT * 0);
275 
276                 auto s0 = _mm512_broadcast_i32x4(mm_loadu_si128(src_z + GEMM_INT8_SRC_UNIT * 0));
277                 auto s1 = _mm512_broadcast_i32x4(mm_loadu_si128(src_z + GEMM_INT8_SRC_UNIT * 1));
278 
279                 D0 = _mm512_dpbusds_epi32(D0, s0, w0);
280                 D1 = _mm512_dpbusds_epi32(D1, s1, w0);
281             }
282             auto d00 = _mm512_extracti32x4_epi32(D0, 0);
283             auto d01 = _mm512_extracti32x4_epi32(D0, 1);
284             auto d02 = _mm512_extracti32x4_epi32(D0, 2);
285             auto d03 = _mm512_extracti32x4_epi32(D0, 3);
286 
287             auto d10 = _mm512_extracti32x4_epi32(D1, 0);
288             auto d11 = _mm512_extracti32x4_epi32(D1, 1);
289             auto d12 = _mm512_extracti32x4_epi32(D1, 2);
290             auto d13 = _mm512_extracti32x4_epi32(D1, 3);
291 
292             d00 = _mm_hadd_epi32(d00, d01);
293             d02 = _mm_hadd_epi32(d02, d03);
294             d10 = _mm_hadd_epi32(d10, d11);
295             d12 = _mm_hadd_epi32(d12, d13);
296             auto d0 = _mm_hadd_epi32(d00, d02);
297             auto d1 = _mm_hadd_epi32(d10, d12);
298 
299             if (post->scale != nullptr) {
300                 auto biasValue = _mm_loadu_si128((__m128i*)(bias_dz));
301                 d0 = _mm_add_epi32(d0, biasValue);
302                 d1 = _mm_add_epi32(d1, biasValue);
303                 auto scaleValue = _mm_loadu_ps(scale_dz);
304                 __m128 f0 = _mm_cvtepi32_ps(d0);
305                 __m128 f1 = _mm_cvtepi32_ps(d1);
306                 f0 = _mm_mul_ps(f0, scaleValue);
307                 f1 = _mm_mul_ps(f1, scaleValue);
308                 f0 = _mm_min_ps(f0, maxValue);
309                 f1 = _mm_min_ps(f1, maxValue);
310                 f0 = _mm_max_ps(f0, minValue);
311                 f1 = _mm_max_ps(f1, minValue);
312                 auto m0 = _mm_cmplt_ps(f0, zero128);
313                 auto m1 = _mm_cmplt_ps(f1, zero128);
314                 m0 = _mm_blendv_ps(plus, minus, m0);
315                 m1 = _mm_blendv_ps(plus, minus, m1);
316                 f0 = _mm_add_ps(f0, m0);
317                 f1 = _mm_add_ps(f1, m1);
318                 // 3: _MM_FROUND_TO_ZERO
319                 d0 = _mm_cvtps_epi32(_mm_round_ps(f0, 3));
320                 d1 = _mm_cvtps_epi32(_mm_round_ps(f1, 3));
321 
322                 // Int32 -> Int8
323                 d0 = _mm_packs_epi32(d0, d1);
324                 d0 = _mm_packs_epi16(d0, d0);
325                 int32_t tempV[4];
326                 _mm_storeu_si128((__m128i*)tempV, d0);
327                 for (int j=0; j<realDst; ++j) {
328                     ((int32_t*)dst_x)[j] = tempV[j];
329                 }
330             } else {
331                 auto biasValue = _mm_loadu_si128((__m128i*)(bias_dz));
332                 __m128 f0 = _mm_cvtepi32_ps(_mm_add_epi32(d0, biasValue));
333                 __m128 f1 = _mm_cvtepi32_ps(_mm_add_epi32(d1, biasValue));
334                 _mm_storeu_ps(((float*)dst_x), f0);
335                 _mm_storeu_ps(((float*)dst_x) + 4, f1);
336             }
337         }
338     }
339     if (realDst == 1) {
340         for (int dz = 0; dz < dst_depth_quad; ++dz) {
341             const auto weight_dz = weight + dz * src_depth_quad * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT);
342             const auto bias_dz = post->bias + dz * GEMM_INT8_UNIT;
343             const float* scale_dz = nullptr;
344             if (post->scale != nullptr) {
345                 scale_dz  = post->scale + dz * GEMM_INT8_UNIT;
346             }
347             auto dst_z           = dst + dz * dst_step_tmp;
348             const auto src_x   = src;
349             auto dst_x         = dst_z;
350             __m512i D0 = _mm512_set1_epi32(0);
351 
352             for (int sz = 0; sz < src_depth_quad; ++sz) {
353                 const auto weight_sz = weight_dz + (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT) * sz;
354                 const auto src_z     = src_x + sz * GEMM_INT8_DST_XUNIT * GEMM_INT8_SRC_UNIT;
355                 auto w0 = _mm512_loadu_si512(weight_sz + GEMM_INT8_SRC_UNIT * 0);
356 
357                 auto s0 = _mm512_broadcast_i32x4(mm_loadu_si128(src_z + GEMM_INT8_SRC_UNIT * 0));
358 
359                 D0 = _mm512_dpbusds_epi32(D0, s0, w0);
360             }
361             auto d00 = _mm512_extracti32x4_epi32(D0, 0);
362             auto d01 = _mm512_extracti32x4_epi32(D0, 1);
363             auto d02 = _mm512_extracti32x4_epi32(D0, 2);
364             auto d03 = _mm512_extracti32x4_epi32(D0, 3);
365 
366             d00 = _mm_hadd_epi32(d00, d01);
367             d02 = _mm_hadd_epi32(d02, d03);
368 
369             auto d0 = _mm_hadd_epi32(d00, d02);
370 
371             if (post->scale != nullptr) {
372                 auto biasValue = _mm_loadu_si128((__m128i*)(bias_dz));
373                 d0 = _mm_add_epi32(d0, biasValue);
374                 auto scaleValue = _mm_loadu_ps(scale_dz);
375                 __m128 f0 = _mm_cvtepi32_ps(d0);
376                 f0 = _mm_mul_ps(f0, scaleValue);
377                 f0 = _mm_min_ps(f0, maxValue);
378                 f0 = _mm_max_ps(f0, minValue);
379                 auto m0 = _mm_cmplt_ps(f0, zero128);
380                 m0 = _mm_blendv_ps(plus, minus, m0);
381                 f0 = _mm_add_ps(f0, m0);
382                 // 3: _MM_FROUND_TO_ZERO
383                 d0 = _mm_cvtps_epi32(_mm_round_ps(f0, 3));
384 
385                 // Int32 -> Int8
386                 d0 = _mm_packs_epi32(d0, d0);
387                 d0 = _mm_packs_epi16(d0, d0);
388                 int32_t tempV[4];
389                 _mm_storeu_si128((__m128i*)tempV, d0);
390                 for (int j=0; j<realDst; ++j) {
391                     ((int32_t*)dst_x)[j] = tempV[j];
392                 }
393             } else {
394                 auto biasValue = _mm_loadu_si128((__m128i*)(bias_dz));
395                 __m128 f0 = _mm_cvtepi32_ps(_mm_add_epi32(d0, biasValue));
396                 _mm_storeu_ps(((float*)dst_x), f0);
397             }
398         }
399     }
400 }
401 #endif
402