/* This file is included multiple times, once for each backend instruction set. */

#if defined(KERNELS_THIS_IS_SSE2)
  #define CPU_NAME SSE2
  #define CPU_ATTR INTGEMM_SSE2
#elif defined(KERNELS_THIS_IS_AVX2)
  #define CPU_NAME AVX2
  #define CPU_ATTR INTGEMM_AVX2
#elif defined(KERNELS_THIS_IS_AVX512BW)
  #define CPU_NAME AVX512BW
  #define CPU_ATTR INTGEMM_AVX512BW
#else
  #error "Only SSE2, AVX2 and AVX512BW are supported"
#endif

#define vi vector_t<CPUType::CPU_NAME, int>
#define vf vector_t<CPUType::CPU_NAME, float>
#define vd vector_t<CPUType::CPU_NAME, double>

/*
 * Kernel implementations.
 */
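
/*
 * Hypothetical usage sketch (the register values and the `out` pointer below
 * are made up for illustration and are not part of the library): quantize a
 * register of floats and store it as int32 at an aligned destination.
 *
 *   vf values     = ...;                          // loaded elsewhere
 *   vf quant_mult = set1_ps<vf>(127.0f);
 *   write(quantize(values, quant_mult), out, 0);  // out is an aligned int*
 */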
namespace intgemm {
namespace kernels {

/*
 * Write
 */
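// The destination is accessed as a whole SIMD vector, so `output + offset` is
// assumed to satisfy the alignment requirement of the vector type.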
CPU_ATTR static inline void write(vi input, int8_t* output, Index offset) {
  *reinterpret_cast<vi*>(output + offset) = input;
}

CPU_ATTR static inline void write(vi input, int16_t* output, Index offset) {
  *reinterpret_cast<vi*>(output + offset) = input;
}

CPU_ATTR static inline void write(vi input, int* output, Index offset) {
  *reinterpret_cast<vi*>(output + offset) = input;
}

CPU_ATTR static inline void write(vf input, float* output, Index offset) {
  *reinterpret_cast<vf*>(output + offset) = input;
}

CPU_ATTR static inline void write(vd input, double* output, Index offset) {
  *reinterpret_cast<vd*>(output + offset) = input;
}

/*
 * Quantize
 */
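// cvtps_epi32 uses the current MXCSR rounding mode (round to nearest, ties to
// even, by default), so this computes approximately round(input * quant_mult).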
CPU_ATTR static inline vi quantize(vf input, vf quant_mult) {
  return cvtps_epi32(mul_ps(input, quant_mult));
}

/*
 * Unquantize
 */
CPU_ATTR static inline vf unquantize(vi input, vf unquant_mult) {
  return mul_ps(cvtepi32_ps(input), unquant_mult);
}

/*
 * Add a bias term
 */
CPU_ATTR static inline vi add_bias(vi input, const int8_t* bias_addr, Index bias_offset) {
  auto bias_term = *reinterpret_cast<const vi*>(bias_addr + bias_offset);
  return add_epi8(input, bias_term);
}

CPU_ATTR static inline vi add_bias(vi input, const int16_t* bias_addr, Index bias_offset) {
  auto bias_term = *reinterpret_cast<const vi*>(bias_addr + bias_offset);
  return add_epi16(input, bias_term);
}

CPU_ATTR static inline vi add_bias(vi input, const int* bias_addr, Index bias_offset) {
  auto bias_term = *reinterpret_cast<const vi*>(bias_addr + bias_offset);
  return add_epi32(input, bias_term);
}

CPU_ATTR static inline vf add_bias(vf input, const float* bias_addr, Index bias_offset) {
  auto bias_term = *reinterpret_cast<const vf*>(bias_addr + bias_offset);
  return add_ps(input, bias_term);
}

CPU_ATTR static inline vd add_bias(vd input, const double* bias_addr, Index bias_offset) {
  auto bias_term = *reinterpret_cast<const vd*>(bias_addr + bias_offset);
  return add_pd(input, bias_term);
}

/*
 * ReLU
 */
template <typename Type>
CPU_ATTR static inline vector_t<CPUType::CPU_NAME, Type> relu(vector_t<CPUType::CPU_NAME, Type> input);

template <>
CPU_ATTR inline vi relu<int8_t>(vi input) {
  static const auto vconst_zero = set1_epi8<vi>(0);
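  // _mm_max_epi8 only exists from SSE4.1, so the SSE2 path keeps an element
  // only where it compares greater than zero (the compare produces the mask).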
#if defined(KERNELS_THIS_IS_SSE2)
  return and_si(input, _mm_cmplt_epi8(vconst_zero, input));
#elif defined(KERNELS_THIS_IS_AVX2)
  return _mm256_max_epi8(input, vconst_zero);
#else
  return _mm512_max_epi8(input, vconst_zero);
#endif
}

template <>
CPU_ATTR inline vi relu<int16_t>(vi input) {
  static const auto vconst_zero = set1_epi16<vi>(0);
  return max_epi16(input, vconst_zero);
}

template <>
CPU_ATTR inline vi relu<int>(vi input) {
  static const auto vconst_zero = set1_epi32<vi>(0);
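  // _mm_max_epi32 is also SSE4.1, so SSE2 falls back to the same compare-and-mask trick.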
#if defined(KERNELS_THIS_IS_SSE2)
  return and_si(input, _mm_cmplt_epi32(vconst_zero, input));
#elif defined(KERNELS_THIS_IS_AVX2)
  return _mm256_max_epi32(input, vconst_zero);
#else
  return _mm512_max_epi32(input, vconst_zero);
#endif
}

template <>
CPU_ATTR inline vf relu<float>(vf input) {
  static const auto vconst_zero = setzero_ps<vf>();
  return max_ps(input, vconst_zero);
}

template <>
CPU_ATTR inline vd relu<double>(vd input) {
  static const auto vconst_zero = setzero_pd<vd>();
  return max_pd(input, vconst_zero);
}

/*
 * Multiply (elemwise)
 */
template <typename Type>
CPU_ATTR static inline vector_t<CPUType::CPU_NAME, Type> multiply(vector_t<CPUType::CPU_NAME, Type> a, vector_t<CPUType::CPU_NAME, Type> b);

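// x86 has no 8-bit lane multiply, so the int8_t specialization multiplies even
// and odd bytes in 16-bit lanes and recombines the low byte of each product.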
template <>
CPU_ATTR inline vi multiply<int8_t>(vi a, vi b) {
  auto even = mullo_epi16(a, b);
  auto odd = mullo_epi16(srli_epi16<8>(a), srli_epi16<8>(b));
  return or_si(slli_epi16<8>(odd), srli_epi16<8>(slli_epi16<8>(even)));
}

template <>
CPU_ATTR inline vi multiply<int16_t>(vi a, vi b) {
  return mullo_epi16(a, b);
}

template <>
CPU_ATTR inline vi multiply<int>(vi a, vi b) {
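  // _mm_mullo_epi32 is SSE4.1, so the SSE2 path multiplies even and odd 32-bit
  // lanes with mul_epu32 and interleaves the low 32 bits of each product.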
#if defined(KERNELS_THIS_IS_SSE2)
  auto even = mul_epu32(a, b);
  auto odd = mul_epu32(_mm_srli_si128(a, 4), _mm_srli_si128(b, 4));
  return unpacklo_epi32(_mm_shuffle_epi32(even, 0x8 /* = 0 0 2 0 */), _mm_shuffle_epi32(odd, 0x8 /* = 0 0 2 0 */));
#elif defined(KERNELS_THIS_IS_AVX2)
  return _mm256_mullo_epi32(a, b);
#else
  return _mm512_mullo_epi32(a, b);
#endif
}

template <>
CPU_ATTR inline vf multiply<float>(vf a, vf b) {
  return mul_ps(a, b);
}

template <>
CPU_ATTR inline vd multiply<double>(vd a, vd b) {
  return mul_pd(a, b);
}

/*
 * Downcast
 */
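// The packs_* instructions saturate to the narrower type but operate within
// 128-bit lanes on AVX2/AVX512, so the wider backends permute afterwards to
// restore the original element order.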
CPU_ATTR static inline vi downcast32to8(vi input1, vi input2, vi input3, vi input4) {
  auto result = packs_epi16(packs_epi32(input1, input2), packs_epi32(input3, input4));

#if defined(KERNELS_THIS_IS_SSE2)
  return result;
#elif defined(KERNELS_THIS_IS_AVX2)
  return _mm256_shuffle_epi32(_mm256_permute4x64_epi64(result, 0xd8 /* = 0 2 1 3 */), 0xd8 /* = 0 2 1 3 */);
#else
  static const auto permutation_indices = _mm512_set_epi32(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0);
  return _mm512_castps_si512(_mm512_permutexvar_ps(permutation_indices, _mm512_castsi512_ps(result)));
#endif
}

CPU_ATTR static inline vi downcast32to16(vi input1, vi input2) {
  auto result = packs_epi32(input1, input2);

#if defined(KERNELS_THIS_IS_SSE2)
  return result;
#elif defined(KERNELS_THIS_IS_AVX2)
  return _mm256_permute4x64_epi64(result, 0xd8 /* = 0 2 1 3 */);
#else
  static const auto permutation_indices = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0);
  return _mm512_castpd_si512(_mm512_permutexvar_pd(permutation_indices, _mm512_castsi512_pd(result)));
#endif
}

CPU_ATTR static inline vi downcast16to8(vi input1, vi input2) {
  auto result = packs_epi16(input1, input2);

#if defined(KERNELS_THIS_IS_SSE2)
  return result;
#elif defined(KERNELS_THIS_IS_AVX2)
  return _mm256_permute4x64_epi64(result, 0xd8 /* = 0 2 1 3 */);
#else
  static const auto permutation_indices = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0);
  return _mm512_castpd_si512(_mm512_permutexvar_pd(permutation_indices, _mm512_castsi512_pd(result)));
#endif
}

/*
 * Upcast
 */
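// Sign extension pairs each element with a "high part" taken from its sign
// (all ones for negative values, zero otherwise) and interleaves the two with
// unpacklo/unpackhi; AVX2/AVX512 permute first so the outputs stay in order.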
CPU_ATTR static inline dvector_t<CPUType::CPU_NAME, int16_t> upcast8to16(vi input) {
  static const auto vzero = set1_epi8<vi>(0);

#if defined(KERNELS_THIS_IS_SSE2)
  auto higher_byte = _mm_cmpgt_epi8(vzero, input);
#elif defined(KERNELS_THIS_IS_AVX2)
  input = _mm256_permute4x64_epi64(input, 0xd8 /* = 0 2 1 3 */);
  auto higher_byte = _mm256_cmpgt_epi8(vzero, input);
#else
  static const auto vmax_negative = set1_epi8<vi>(-1 /* 0xff */);
  static const auto permutation_indices = _mm512_set_epi64(7, 3, 6, 2, 5, 1, 4, 0);

  input = _mm512_castpd_si512(_mm512_permutexvar_pd(permutation_indices, _mm512_castsi512_pd(input)));
  auto negatives = _mm512_cmp_epi8_mask(input, vzero, 1 /* _MM_CMPINT_LT */);
  auto higher_byte = _mm512_mask_blend_epi8(negatives, vzero, vmax_negative);
#endif

  return {
    unpacklo_epi8(input, higher_byte),
    unpackhi_epi8(input, higher_byte),
  };
}

CPU_ATTR static inline dvector_t<CPUType::CPU_NAME, int> upcast16to32(vi input) {
  static const auto vzero = set1_epi16<vi>(0);

#if defined(KERNELS_THIS_IS_SSE2)
  auto higher_byte = _mm_cmpgt_epi16(vzero, input);
#elif defined(KERNELS_THIS_IS_AVX2)
  input = _mm256_permute4x64_epi64(input, 0xd8 /* = 0 2 1 3 */);
  auto higher_byte = _mm256_cmpgt_epi16(vzero, input);
#else
  static const auto vmax_negative = set1_epi16<vi>(-1 /* 0xffff */);
  static const auto permutation_indices = _mm512_set_epi64(7, 3, 6, 2, 5, 1, 4, 0);

  input = _mm512_castpd_si512(_mm512_permutexvar_pd(permutation_indices, _mm512_castsi512_pd(input)));
  auto negatives = _mm512_cmp_epi16_mask(input, vzero, 1 /* _MM_CMPINT_LT */);
  auto higher_byte = _mm512_mask_blend_epi16(negatives, vzero, vmax_negative);
#endif

  return {
    unpacklo_epi16(input, higher_byte),
    unpackhi_epi16(input, higher_byte),
  };
}

CPU_ATTR static inline qvector_t<CPUType::CPU_NAME, int> upcast8to32(vi input) {
  auto result16 = upcast8to16(input);
  auto result32a = upcast16to32(result16.first);
  auto result32b = upcast16to32(result16.second);

  return {
    result32a.first,
    result32a.second,
    result32b.first,
    result32b.second,
  };
}

/*
 * Rescale int32
 */
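// Converts int32 to float, multiplies by scale, and rounds back to int32
// (nearest, ties to even under the default rounding mode).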
CPU_ATTR static inline vi rescale(vi input, vf scale) {
  return cvtps_epi32(mul_ps(cvtepi32_ps(input), scale));
}

/*
 * Bitwise not
 */
CPU_ATTR static inline vi bitwise_not(vi v) {
  return xor_si(v, set1_epi32<vi>(0xffffffff));
}

/*
 * Floor
 */
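// SSE2 and AVX512BW emulate floor by truncating toward zero and subtracting 1
// for negative non-integer inputs; AVX2 can use _mm256_floor_ps directly.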
CPU_ATTR static inline vf floor(vf input) {
#if defined(KERNELS_THIS_IS_SSE2)
  static const auto vconst_zero = setzero_ps<vf>();
  static const auto vconst_one = set1_ps<vf>(1.f);

  auto result = cvtepi32_ps(cvttps_epi32(input));
  auto negatives = _mm_cmplt_ps(input, vconst_zero);
  auto nonintegers = _mm_cmpneq_ps(input, result);

  return sub_ps(result, and_ps(vconst_one, and_ps(negatives, nonintegers)));
#elif defined(KERNELS_THIS_IS_AVX2)
  return _mm256_floor_ps(input);
#else
  // TODO: This should work, but the compiler throws the error "incorrect rounding operand":
  // return _mm512_roundscale_round_ps(input, 0, _MM_FROUND_FLOOR);

  static const auto vconst_zero = setzero_ps<vf>();
  static const auto vconst_one = set1_ps<vf>(1.f);

  auto result = cvtepi32_ps(cvttps_epi32(input));
  auto negatives = _mm512_cmp_ps_mask(input, vconst_zero, _CMP_LT_OQ);
  auto nonintegers = _mm512_cmp_ps_mask(input, result, _CMP_NEQ_OQ);

  return _mm512_mask_blend_ps(_mm512_kand(negatives, nonintegers), result, sub_ps(result, vconst_one));
#endif
}

/*
 * Calculate an approximation of e^x using a Taylor series and a lookup table
 */
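// Splits x into a = floor(x) and xa = x - a in [0, 1): e^a comes from the
// lookup table, e^xa from a degree-7 Taylor polynomial evaluated by Horner's
// scheme, and the result is their product. Inputs are clamped to
// [EXP_MIN, EXP_MAX] first.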
#if defined(KERNELS_THIS_IS_SSE2)
CPU_ATTR static inline vf exp_approx_taylor(vf) {
  std::abort();
}
#else
CPU_ATTR static inline vf exp_approx_taylor(vf x) {
  static constexpr int EXP_MIN = -20;
  static constexpr int EXP_MAX = 20;
  static constexpr float EXP_LOOKUP[EXP_MAX - EXP_MIN + 1] = {
    expif(-20), expif(-19), expif(-18), expif(-17), expif(-16), expif(-15),
    expif(-14), expif(-13), expif(-12), expif(-11), expif(-10), expif(-9),
    expif(-8), expif(-7), expif(-6), expif(-5), expif(-4), expif(-3), expif(-2),
    expif(-1), expif(0), expif(1), expif(2), expif(3), expif(4), expif(5),
    expif(6), expif(7), expif(8), expif(9), expif(10), expif(11), expif(12),
    expif(13), expif(14), expif(15), expif(16), expif(17), expif(18), expif(19),
    expif(20),
  };

  static const vf dividers[] = {
    set1_ps<vf>(1.f / factorial(7)),
    set1_ps<vf>(1.f / factorial(6)),
    set1_ps<vf>(1.f / factorial(5)),
    set1_ps<vf>(1.f / factorial(4)),
    set1_ps<vf>(1.f / factorial(3)),
    set1_ps<vf>(1.f / factorial(2)),
    set1_ps<vf>(1.f / factorial(1)),
  };
  static const auto const_one = set1_ps<vf>(1.f);
  static const auto const_min_x = set1_ps<vf>(EXP_MIN);
  static const auto const_max_x = set1_ps<vf>(EXP_MAX);

  x = max_ps(x, const_min_x);
  x = min_ps(x, const_max_x);

  auto a = floor(x);
  auto xa = sub_ps(x, a);

  auto result = mul_ps(dividers[0], xa);

  result = add_ps(result, dividers[1]);
  result = mul_ps(result, xa);
  result = add_ps(result, dividers[2]);
  result = mul_ps(result, xa);
  result = add_ps(result, dividers[3]);
  result = mul_ps(result, xa);
  result = add_ps(result, dividers[4]);
  result = mul_ps(result, xa);
  result = add_ps(result, dividers[5]);
  result = mul_ps(result, xa);
  result = add_ps(result, dividers[6]);
  result = mul_ps(result, xa);

  result = add_ps(result, const_one);

  auto ea = i32gather_ps<4>(EXP_LOOKUP + EXP_MAX, cvtps_epi32(a));
  return mul_ps(ea, result);
}
#endif

/*
 * Sigmoid
 */
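// Evaluates 1 / (1 + e^-x) and e^x / (1 + e^x) with exp_approx_taylor and the
// hardware reciprocal approximation (_mm256_rcp_ps / _mm512_rcp14_ps), then
// selects per lane: the former for x <= 0, the latter for x > 0.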
CPU_ATTR static inline vf sigmoid(vf
#ifndef KERNELS_THIS_IS_SSE2
    input
#endif
    ) {
#if defined(KERNELS_THIS_IS_SSE2)
  std::abort(); // TODO: missing exp_approx_taylor for SSE2
#elif defined(KERNELS_THIS_IS_AVX2)
  static const auto vconst_zero = setzero_ps<vf>();
  static const auto vconst_one = set1_ps<vf>(1.f);

  auto x = input;
  auto minus_x = sub_ps(vconst_zero, x);
  auto e_x = exp_approx_taylor(x);
  auto e_minus_x = exp_approx_taylor(minus_x);

  auto sigmoid_case1 = _mm256_rcp_ps(add_ps(vconst_one, e_minus_x));
  auto sigmoid_case2 = mul_ps(e_x, _mm256_rcp_ps(add_ps(vconst_one, e_x)));

  auto nonnegative_x_mask = _mm256_cmp_ps(vconst_zero, x, _CMP_LT_OS);
  return _mm256_blendv_ps(sigmoid_case1, sigmoid_case2, nonnegative_x_mask);
#else
  static const auto vconst_zero = setzero_ps<vf>();
  static const auto vconst_one = set1_ps<vf>(1.f);

  auto x = input;
  auto minus_x = sub_ps(vconst_zero, x);
  auto e_x = exp_approx_taylor(x);
  auto e_minus_x = exp_approx_taylor(minus_x);

  auto sigmoid_case1 = _mm512_rcp14_ps(add_ps(vconst_one, e_minus_x));
  auto sigmoid_case2 = mul_ps(e_x, _mm512_rcp14_ps(add_ps(vconst_one, e_x)));

  auto nonnegative_x_mask = _mm512_cmp_ps_mask(vconst_zero, x, _CMP_LT_OS);
  return _mm512_mask_blend_ps(nonnegative_x_mask, sigmoid_case1, sigmoid_case2);
#endif
}

/*
 * Tanh
 */
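// tanh(x) = (e^x - e^-x) / (e^x + e^-x), with both exponentials computed by
// exp_approx_taylor (hence unavailable on SSE2).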
#if defined(KERNELS_THIS_IS_SSE2)
CPU_ATTR static inline vf tanh(vf) {
  std::abort(); // TODO: missing exp_approx_taylor for SSE2
}
#else
CPU_ATTR static inline vf tanh(vf input) {
  static const auto vconst_zero = setzero_ps<vf>();

  auto e_x = exp_approx_taylor(input);
  auto e_minus_x = exp_approx_taylor(sub_ps(vconst_zero, input));

  return div_ps(sub_ps(e_x, e_minus_x), add_ps(e_x, e_minus_x));
}
#endif

} // namespace kernels
} // namespace intgemm

#undef CPU_NAME
#undef CPU_ATTR
#undef vi
#undef vf
#undef vd
