//
// GemmCommon.cpp
// MNN
//
// Created by MNN on 2020/09/22.
// Copyright © 2018, Alibaba Group Holding Limited
//

#ifdef MNN_AVX512_VNNI
#include "FunctionSummary.hpp"
#include "core/Macro.h"

namespace {
static inline __m128i mm_loadu_si128(const void* addr) {
    return _mm_loadu_si128((__m128i const*)addr);
}
} // namespace
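
// AVX-512 VNNI int8 GEMM micro-kernel: for each of dst_depth_quad output
// channel groups it computes realDst (1..4) output pixels, GEMM_INT8_UNIT
// channels per group. A rough scalar sketch of the math (the actual data is
// packed, so the indices below are illustrative only):
//
//   acc[j][x] = sum over sz, i of src[x][sz][i] * weight[dz][sz][j][i]
//   output    = post->scale ? int8(clamp(round((acc + bias[j]) * scale[j])))
//                           : float(acc + bias[j])
//
// VPDPBUSDS treats the src bytes as unsigned and the weight bytes as signed,
// accumulating into int32 with saturation.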

void _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDst) {
    const auto dst_step_tmp = dst_step / sizeof(int8_t);
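    // Clamp bounds for requantization; plus/minus (±0.5) combine with the
    // truncating _mm_round_ps(x, 3) below to round half away from zero.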
    __m128 zero128 = _mm_set1_ps(0.0f);
    __m128 minValue = _mm_set1_ps(post->minValue);
    __m128 maxValue = _mm_set1_ps(post->maxValue);
    __m128 plus = _mm_set1_ps(0.5f);
    __m128 minus = _mm_set1_ps(-0.5f);
    if (realDst == 4) {
        for (int dz = 0; dz < dst_depth_quad; ++dz) {
            const auto weight_dz = weight + dz * src_depth_quad * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT);
            const auto bias_dz = post->bias + dz * GEMM_INT8_UNIT;
            const float* scale_dz = nullptr;
            if (post->scale != nullptr) {
                scale_dz = post->scale + dz * GEMM_INT8_UNIT;
            }
            auto dst_z = dst + dz * dst_step_tmp;
            const auto src_x = src;
            auto dst_x = dst_z;
            __m512i D0 = _mm512_set1_epi32(0);
            __m512i D1 = _mm512_set1_epi32(0);
            __m512i D2 = _mm512_set1_epi32(0);
            __m512i D3 = _mm512_set1_epi32(0);

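            // Accumulate over the source depth. w0 packs the weights of
            // GEMM_INT8_UNIT output channels (GEMM_INT8_SRC_UNIT int8 each);
            // s0..s3 broadcast one pixel's inputs into every 128-bit lane, so
            // a single VPDPBUSDS yields partial sums for all channels of that
            // pixel.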
            for (int sz = 0; sz < src_depth_quad; ++sz) {
                const auto weight_sz = weight_dz + (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT) * sz;
                const auto src_z = src_x + sz * GEMM_INT8_DST_XUNIT * GEMM_INT8_SRC_UNIT;
                auto w0 = _mm512_loadu_si512(weight_sz + GEMM_INT8_SRC_UNIT * 0);

                auto s0 = _mm512_broadcast_i32x4(mm_loadu_si128(src_z + GEMM_INT8_SRC_UNIT * 0));
                auto s1 = _mm512_broadcast_i32x4(mm_loadu_si128(src_z + GEMM_INT8_SRC_UNIT * 1));
                auto s2 = _mm512_broadcast_i32x4(mm_loadu_si128(src_z + GEMM_INT8_SRC_UNIT * 2));
                auto s3 = _mm512_broadcast_i32x4(mm_loadu_si128(src_z + GEMM_INT8_SRC_UNIT * 3));

                D0 = _mm512_dpbusds_epi32(D0, s0, w0);
                D1 = _mm512_dpbusds_epi32(D1, s1, w0);
                D2 = _mm512_dpbusds_epi32(D2, s2, w0);
                D3 = _mm512_dpbusds_epi32(D3, s3, w0);
            }
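            // Each 128-bit lane of Dk holds four partial int32 sums for one
            // output channel of pixel k; extract the lanes and hadd twice to
            // reduce to one int32 per channel.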
            auto d00 = _mm512_extracti32x4_epi32(D0, 0);
            auto d01 = _mm512_extracti32x4_epi32(D0, 1);
            auto d02 = _mm512_extracti32x4_epi32(D0, 2);
            auto d03 = _mm512_extracti32x4_epi32(D0, 3);

            auto d10 = _mm512_extracti32x4_epi32(D1, 0);
            auto d11 = _mm512_extracti32x4_epi32(D1, 1);
            auto d12 = _mm512_extracti32x4_epi32(D1, 2);
            auto d13 = _mm512_extracti32x4_epi32(D1, 3);

            auto d20 = _mm512_extracti32x4_epi32(D2, 0);
            auto d21 = _mm512_extracti32x4_epi32(D2, 1);
            auto d22 = _mm512_extracti32x4_epi32(D2, 2);
            auto d23 = _mm512_extracti32x4_epi32(D2, 3);

            auto d30 = _mm512_extracti32x4_epi32(D3, 0);
            auto d31 = _mm512_extracti32x4_epi32(D3, 1);
            auto d32 = _mm512_extracti32x4_epi32(D3, 2);
            auto d33 = _mm512_extracti32x4_epi32(D3, 3);

            d00 = _mm_hadd_epi32(d00, d01);
            d02 = _mm_hadd_epi32(d02, d03);
            d10 = _mm_hadd_epi32(d10, d11);
            d12 = _mm_hadd_epi32(d12, d13);
            d20 = _mm_hadd_epi32(d20, d21);
            d22 = _mm_hadd_epi32(d22, d23);
            d30 = _mm_hadd_epi32(d30, d31);
            d32 = _mm_hadd_epi32(d32, d33);

            auto d0 = _mm_hadd_epi32(d00, d02);
            auto d1 = _mm_hadd_epi32(d10, d12);
            auto d2 = _mm_hadd_epi32(d20, d22);
            auto d3 = _mm_hadd_epi32(d30, d32);

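            // Post-treatment: with a scale, requantize (bias, scale, clamp,
            // round half away from zero, saturating pack to int8); without
            // one, add the bias and store float32 sums.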
            if (post->scale != nullptr) {
                auto biasValue = _mm_loadu_si128((__m128i*)(bias_dz));
                d0 = _mm_add_epi32(d0, biasValue);
                d1 = _mm_add_epi32(d1, biasValue);
                d2 = _mm_add_epi32(d2, biasValue);
                d3 = _mm_add_epi32(d3, biasValue);
                auto scaleValue = _mm_loadu_ps(scale_dz);
                __m128 f0 = _mm_cvtepi32_ps(d0);
                __m128 f1 = _mm_cvtepi32_ps(d1);
                __m128 f2 = _mm_cvtepi32_ps(d2);
                __m128 f3 = _mm_cvtepi32_ps(d3);
                f0 = _mm_mul_ps(f0, scaleValue);
                f1 = _mm_mul_ps(f1, scaleValue);
                f2 = _mm_mul_ps(f2, scaleValue);
                f3 = _mm_mul_ps(f3, scaleValue);
                f0 = _mm_min_ps(f0, maxValue);
                f1 = _mm_min_ps(f1, maxValue);
                f2 = _mm_min_ps(f2, maxValue);
                f3 = _mm_min_ps(f3, maxValue);
                f0 = _mm_max_ps(f0, minValue);
                f1 = _mm_max_ps(f1, minValue);
                f2 = _mm_max_ps(f2, minValue);
                f3 = _mm_max_ps(f3, minValue);
                auto m0 = _mm_cmplt_ps(f0, zero128);
                auto m1 = _mm_cmplt_ps(f1, zero128);
                auto m2 = _mm_cmplt_ps(f2, zero128);
                auto m3 = _mm_cmplt_ps(f3, zero128);
                m0 = _mm_blendv_ps(plus, minus, m0);
                m1 = _mm_blendv_ps(plus, minus, m1);
                m2 = _mm_blendv_ps(plus, minus, m2);
                m3 = _mm_blendv_ps(plus, minus, m3);
                f0 = _mm_add_ps(f0, m0);
                f1 = _mm_add_ps(f1, m1);
                f2 = _mm_add_ps(f2, m2);
                f3 = _mm_add_ps(f3, m3);
                // 3 = _MM_FROUND_TO_ZERO: truncate; with the ±0.5 added above, this rounds half away from zero.
                d0 = _mm_cvtps_epi32(_mm_round_ps(f0, 3));
                d1 = _mm_cvtps_epi32(_mm_round_ps(f1, 3));
                d2 = _mm_cvtps_epi32(_mm_round_ps(f2, 3));
                d3 = _mm_cvtps_epi32(_mm_round_ps(f3, 3));

                // Int32 -> Int8
                d0 = _mm_packs_epi32(d0, d1);
                d2 = _mm_packs_epi32(d2, d3);
                d0 = _mm_packs_epi16(d0, d2);
                _mm_storeu_ps((float*)dst_x, _mm_castsi128_ps(d0));
            } else {
                auto biasValue = _mm_loadu_si128((__m128i*)(bias_dz));
                __m128 f0 = _mm_cvtepi32_ps(_mm_add_epi32(d0, biasValue));
                __m128 f1 = _mm_cvtepi32_ps(_mm_add_epi32(d1, biasValue));
                __m128 f2 = _mm_cvtepi32_ps(_mm_add_epi32(d2, biasValue));
                __m128 f3 = _mm_cvtepi32_ps(_mm_add_epi32(d3, biasValue));
                _mm_storeu_ps(((float*)dst_x), f0);
                _mm_storeu_ps(((float*)dst_x) + 4, f1);
                _mm_storeu_ps(((float*)dst_x) + 8, f2);
                _mm_storeu_ps(((float*)dst_x) + 12, f3);
            }
        }
    }
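    // The realDst == 3 / 2 / 1 paths below mirror the 4-pixel path with fewer
    // accumulators; their int8 stores go through a small temporary buffer so
    // only realDst 4-byte groups reach the destination row.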
    if (realDst == 3) {
        for (int dz = 0; dz < dst_depth_quad; ++dz) {
            const auto weight_dz = weight + dz * src_depth_quad * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT);
            const auto bias_dz = post->bias + dz * GEMM_INT8_UNIT;
            const float* scale_dz = nullptr;
            if (post->scale != nullptr) {
                scale_dz = post->scale + dz * GEMM_INT8_UNIT;
            }
            auto dst_z = dst + dz * dst_step_tmp;
            const auto src_x = src;
            auto dst_x = dst_z;
            __m512i D0 = _mm512_set1_epi32(0);
            __m512i D1 = _mm512_set1_epi32(0);
            __m512i D2 = _mm512_set1_epi32(0);

            for (int sz = 0; sz < src_depth_quad; ++sz) {
                const auto weight_sz = weight_dz + (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT) * sz;
                const auto src_z = src_x + sz * GEMM_INT8_DST_XUNIT * GEMM_INT8_SRC_UNIT;
                auto w0 = _mm512_loadu_si512(weight_sz + GEMM_INT8_SRC_UNIT * 0);

                auto s0 = _mm512_broadcast_i32x4(mm_loadu_si128(src_z + GEMM_INT8_SRC_UNIT * 0));
                auto s1 = _mm512_broadcast_i32x4(mm_loadu_si128(src_z + GEMM_INT8_SRC_UNIT * 1));
                auto s2 = _mm512_broadcast_i32x4(mm_loadu_si128(src_z + GEMM_INT8_SRC_UNIT * 2));

                D0 = _mm512_dpbusds_epi32(D0, s0, w0);
                D1 = _mm512_dpbusds_epi32(D1, s1, w0);
                D2 = _mm512_dpbusds_epi32(D2, s2, w0);
            }
            auto d00 = _mm512_extracti32x4_epi32(D0, 0);
            auto d01 = _mm512_extracti32x4_epi32(D0, 1);
            auto d02 = _mm512_extracti32x4_epi32(D0, 2);
            auto d03 = _mm512_extracti32x4_epi32(D0, 3);

            auto d10 = _mm512_extracti32x4_epi32(D1, 0);
            auto d11 = _mm512_extracti32x4_epi32(D1, 1);
            auto d12 = _mm512_extracti32x4_epi32(D1, 2);
            auto d13 = _mm512_extracti32x4_epi32(D1, 3);

            auto d20 = _mm512_extracti32x4_epi32(D2, 0);
            auto d21 = _mm512_extracti32x4_epi32(D2, 1);
            auto d22 = _mm512_extracti32x4_epi32(D2, 2);
            auto d23 = _mm512_extracti32x4_epi32(D2, 3);

            d00 = _mm_hadd_epi32(d00, d01);
            d02 = _mm_hadd_epi32(d02, d03);
            d10 = _mm_hadd_epi32(d10, d11);
            d12 = _mm_hadd_epi32(d12, d13);
            d20 = _mm_hadd_epi32(d20, d21);
            d22 = _mm_hadd_epi32(d22, d23);

            auto d0 = _mm_hadd_epi32(d00, d02);
            auto d1 = _mm_hadd_epi32(d10, d12);
            auto d2 = _mm_hadd_epi32(d20, d22);

            if (post->scale != nullptr) {
                auto biasValue = _mm_loadu_si128((__m128i*)(bias_dz));
                d0 = _mm_add_epi32(d0, biasValue);
                d1 = _mm_add_epi32(d1, biasValue);
                d2 = _mm_add_epi32(d2, biasValue);
                auto scaleValue = _mm_loadu_ps(scale_dz);
                __m128 f0 = _mm_cvtepi32_ps(d0);
                __m128 f1 = _mm_cvtepi32_ps(d1);
                __m128 f2 = _mm_cvtepi32_ps(d2);
                f0 = _mm_mul_ps(f0, scaleValue);
                f1 = _mm_mul_ps(f1, scaleValue);
                f2 = _mm_mul_ps(f2, scaleValue);
                f0 = _mm_min_ps(f0, maxValue);
                f1 = _mm_min_ps(f1, maxValue);
                f2 = _mm_min_ps(f2, maxValue);
                f0 = _mm_max_ps(f0, minValue);
                f1 = _mm_max_ps(f1, minValue);
                f2 = _mm_max_ps(f2, minValue);
                auto m0 = _mm_cmplt_ps(f0, zero128);
                auto m1 = _mm_cmplt_ps(f1, zero128);
                auto m2 = _mm_cmplt_ps(f2, zero128);
                m0 = _mm_blendv_ps(plus, minus, m0);
                m1 = _mm_blendv_ps(plus, minus, m1);
                m2 = _mm_blendv_ps(plus, minus, m2);
                f0 = _mm_add_ps(f0, m0);
                f1 = _mm_add_ps(f1, m1);
                f2 = _mm_add_ps(f2, m2);
                // 3 = _MM_FROUND_TO_ZERO: truncate; with the ±0.5 added above, this rounds half away from zero.
                d0 = _mm_cvtps_epi32(_mm_round_ps(f0, 3));
                d1 = _mm_cvtps_epi32(_mm_round_ps(f1, 3));
                d2 = _mm_cvtps_epi32(_mm_round_ps(f2, 3));

                // Int32 -> Int8
                d0 = _mm_packs_epi32(d0, d1);
                d2 = _mm_packs_epi32(d2, d2);
                d0 = _mm_packs_epi16(d0, d2);
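                // Tail-safe store: stage the packed bytes and copy only
                // realDst 32-bit groups to avoid writing past the row end.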
                int32_t tempV[4];
                _mm_storeu_si128((__m128i*)tempV, d0);
                for (int j = 0; j < realDst; ++j) {
                    ((int32_t*)dst_x)[j] = tempV[j];
                }
            } else {
                auto biasValue = _mm_loadu_si128((__m128i*)(bias_dz));
                __m128 f0 = _mm_cvtepi32_ps(_mm_add_epi32(d0, biasValue));
                __m128 f1 = _mm_cvtepi32_ps(_mm_add_epi32(d1, biasValue));
                __m128 f2 = _mm_cvtepi32_ps(_mm_add_epi32(d2, biasValue));
                _mm_storeu_ps(((float*)dst_x), f0);
                _mm_storeu_ps(((float*)dst_x) + 4, f1);
                _mm_storeu_ps(((float*)dst_x) + 8, f2);
            }
        }
    }
    if (realDst == 2) {
        for (int dz = 0; dz < dst_depth_quad; ++dz) {
            const auto weight_dz = weight + dz * src_depth_quad * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT);
            const auto bias_dz = post->bias + dz * GEMM_INT8_UNIT;
            const float* scale_dz = nullptr;
            if (post->scale != nullptr) {
                scale_dz = post->scale + dz * GEMM_INT8_UNIT;
            }
            auto dst_z = dst + dz * dst_step_tmp;
            const auto src_x = src;
            auto dst_x = dst_z;
            __m512i D0 = _mm512_set1_epi32(0);
            __m512i D1 = _mm512_set1_epi32(0);

            for (int sz = 0; sz < src_depth_quad; ++sz) {
                const auto weight_sz = weight_dz + (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT) * sz;
                const auto src_z = src_x + sz * GEMM_INT8_DST_XUNIT * GEMM_INT8_SRC_UNIT;
                auto w0 = _mm512_loadu_si512(weight_sz + GEMM_INT8_SRC_UNIT * 0);

                auto s0 = _mm512_broadcast_i32x4(mm_loadu_si128(src_z + GEMM_INT8_SRC_UNIT * 0));
                auto s1 = _mm512_broadcast_i32x4(mm_loadu_si128(src_z + GEMM_INT8_SRC_UNIT * 1));

                D0 = _mm512_dpbusds_epi32(D0, s0, w0);
                D1 = _mm512_dpbusds_epi32(D1, s1, w0);
            }
            auto d00 = _mm512_extracti32x4_epi32(D0, 0);
            auto d01 = _mm512_extracti32x4_epi32(D0, 1);
            auto d02 = _mm512_extracti32x4_epi32(D0, 2);
            auto d03 = _mm512_extracti32x4_epi32(D0, 3);

            auto d10 = _mm512_extracti32x4_epi32(D1, 0);
            auto d11 = _mm512_extracti32x4_epi32(D1, 1);
            auto d12 = _mm512_extracti32x4_epi32(D1, 2);
            auto d13 = _mm512_extracti32x4_epi32(D1, 3);

            d00 = _mm_hadd_epi32(d00, d01);
            d02 = _mm_hadd_epi32(d02, d03);
            d10 = _mm_hadd_epi32(d10, d11);
            d12 = _mm_hadd_epi32(d12, d13);
            auto d0 = _mm_hadd_epi32(d00, d02);
            auto d1 = _mm_hadd_epi32(d10, d12);

            if (post->scale != nullptr) {
                auto biasValue = _mm_loadu_si128((__m128i*)(bias_dz));
                d0 = _mm_add_epi32(d0, biasValue);
                d1 = _mm_add_epi32(d1, biasValue);
                auto scaleValue = _mm_loadu_ps(scale_dz);
                __m128 f0 = _mm_cvtepi32_ps(d0);
                __m128 f1 = _mm_cvtepi32_ps(d1);
                f0 = _mm_mul_ps(f0, scaleValue);
                f1 = _mm_mul_ps(f1, scaleValue);
                f0 = _mm_min_ps(f0, maxValue);
                f1 = _mm_min_ps(f1, maxValue);
                f0 = _mm_max_ps(f0, minValue);
                f1 = _mm_max_ps(f1, minValue);
                auto m0 = _mm_cmplt_ps(f0, zero128);
                auto m1 = _mm_cmplt_ps(f1, zero128);
                m0 = _mm_blendv_ps(plus, minus, m0);
                m1 = _mm_blendv_ps(plus, minus, m1);
                f0 = _mm_add_ps(f0, m0);
                f1 = _mm_add_ps(f1, m1);
                // 3 = _MM_FROUND_TO_ZERO: truncate; with the ±0.5 added above, this rounds half away from zero.
                d0 = _mm_cvtps_epi32(_mm_round_ps(f0, 3));
                d1 = _mm_cvtps_epi32(_mm_round_ps(f1, 3));

                // Int32 -> Int8
                d0 = _mm_packs_epi32(d0, d1);
                d0 = _mm_packs_epi16(d0, d0);
                int32_t tempV[4];
                _mm_storeu_si128((__m128i*)tempV, d0);
                for (int j = 0; j < realDst; ++j) {
                    ((int32_t*)dst_x)[j] = tempV[j];
                }
            } else {
                auto biasValue = _mm_loadu_si128((__m128i*)(bias_dz));
                __m128 f0 = _mm_cvtepi32_ps(_mm_add_epi32(d0, biasValue));
                __m128 f1 = _mm_cvtepi32_ps(_mm_add_epi32(d1, biasValue));
                _mm_storeu_ps(((float*)dst_x), f0);
                _mm_storeu_ps(((float*)dst_x) + 4, f1);
            }
        }
    }
    if (realDst == 1) {
        for (int dz = 0; dz < dst_depth_quad; ++dz) {
            const auto weight_dz = weight + dz * src_depth_quad * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT);
            const auto bias_dz = post->bias + dz * GEMM_INT8_UNIT;
            const float* scale_dz = nullptr;
            if (post->scale != nullptr) {
                scale_dz = post->scale + dz * GEMM_INT8_UNIT;
            }
            auto dst_z = dst + dz * dst_step_tmp;
            const auto src_x = src;
            auto dst_x = dst_z;
            __m512i D0 = _mm512_set1_epi32(0);

            for (int sz = 0; sz < src_depth_quad; ++sz) {
                const auto weight_sz = weight_dz + (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT) * sz;
                const auto src_z = src_x + sz * GEMM_INT8_DST_XUNIT * GEMM_INT8_SRC_UNIT;
                auto w0 = _mm512_loadu_si512(weight_sz + GEMM_INT8_SRC_UNIT * 0);

                auto s0 = _mm512_broadcast_i32x4(mm_loadu_si128(src_z + GEMM_INT8_SRC_UNIT * 0));

                D0 = _mm512_dpbusds_epi32(D0, s0, w0);
            }
            auto d00 = _mm512_extracti32x4_epi32(D0, 0);
            auto d01 = _mm512_extracti32x4_epi32(D0, 1);
            auto d02 = _mm512_extracti32x4_epi32(D0, 2);
            auto d03 = _mm512_extracti32x4_epi32(D0, 3);

            d00 = _mm_hadd_epi32(d00, d01);
            d02 = _mm_hadd_epi32(d02, d03);

            auto d0 = _mm_hadd_epi32(d00, d02);

            if (post->scale != nullptr) {
                auto biasValue = _mm_loadu_si128((__m128i*)(bias_dz));
                d0 = _mm_add_epi32(d0, biasValue);
                auto scaleValue = _mm_loadu_ps(scale_dz);
                __m128 f0 = _mm_cvtepi32_ps(d0);
                f0 = _mm_mul_ps(f0, scaleValue);
                f0 = _mm_min_ps(f0, maxValue);
                f0 = _mm_max_ps(f0, minValue);
                auto m0 = _mm_cmplt_ps(f0, zero128);
                m0 = _mm_blendv_ps(plus, minus, m0);
                f0 = _mm_add_ps(f0, m0);
                // 3 = _MM_FROUND_TO_ZERO: truncate; with the ±0.5 added above, this rounds half away from zero.
                d0 = _mm_cvtps_epi32(_mm_round_ps(f0, 3));

                // Int32 -> Int8
                d0 = _mm_packs_epi32(d0, d0);
                d0 = _mm_packs_epi16(d0, d0);
                int32_t tempV[4];
                _mm_storeu_si128((__m128i*)tempV, d0);
                for (int j = 0; j < realDst; ++j) {
                    ((int32_t*)dst_x)[j] = tempV[j];
                }
            } else {
                auto biasValue = _mm_loadu_si128((__m128i*)(bias_dz));
                __m128 f0 = _mm_cvtepi32_ps(_mm_add_epi32(d0, biasValue));
                _mm_storeu_ps(((float*)dst_x), f0);
            }
        }
    }
}
#endif