1/* This file is included multiple times, once for each backend instruction set. */ 2 3#if defined(KERNELS_THIS_IS_SSE2) 4 #define CPU_NAME SSE2 5 #define CPU_ATTR INTGEMM_SSE2 6#elif defined(KERNELS_THIS_IS_AVX2) 7 #define CPU_NAME AVX2 8 #define CPU_ATTR INTGEMM_AVX2 9#elif defined(KERNELS_THIS_IS_AVX512BW) 10 #define CPU_NAME AVX512BW 11 #define CPU_ATTR INTGEMM_AVX512BW 12#else 13 #error "Only SSE2, AVX2 and AVX512BW are supported" 14#endif 15 16#define vi vector_t<CPUType::CPU_NAME, int> 17#define vf vector_t<CPUType::CPU_NAME, float> 18#define vd vector_t<CPUType::CPU_NAME, double> 19 20/* 21 * Kernels implementations.... 22 */ 23namespace intgemm { 24namespace kernels { 25 26/* 27 * Write 28 */ 29CPU_ATTR static inline void write(vi input, int8_t* output, Index offset) { 30 *reinterpret_cast<vi*>(output + offset) = input; 31} 32 33CPU_ATTR static inline void write(vi input, int16_t* output, Index offset) { 34 *reinterpret_cast<vi*>(output + offset) = input; 35} 36 37CPU_ATTR static inline void write(vi input, int* output, Index offset) { 38 *reinterpret_cast<vi*>(output + offset) = input; 39} 40 41CPU_ATTR static inline void write(vf input, float* output, Index offset) { 42 *reinterpret_cast<vf*>(output + offset) = input; 43} 44 45CPU_ATTR static inline void write(vd input, double* output, Index offset) { 46 *reinterpret_cast<vd*>(output + offset) = input; 47} 48 49/* 50 * Quantize 51 */ 52CPU_ATTR static inline vi quantize(vf input, vf quant_mult) { 53 return cvtps_epi32(mul_ps(input, quant_mult)); 54} 55 56/* 57 * Unquantize 58 */ 59CPU_ATTR static inline vf unquantize(vi input, vf unquant_mult) { 60 return mul_ps(cvtepi32_ps(input), unquant_mult); 61} 62 63/* 64 * Add a bias term 65 */ 66CPU_ATTR static inline vi add_bias(vi input, const int8_t* bias_addr, Index bias_offset) { 67 auto bias_term = *reinterpret_cast<const vi*>(bias_addr + bias_offset); 68 return add_epi8(input, bias_term); 69} 70 71CPU_ATTR static inline vi add_bias(vi input, const int16_t* bias_addr, Index bias_offset) { 72 auto bias_term = *reinterpret_cast<const vi*>(bias_addr + bias_offset); 73 return add_epi16(input, bias_term); 74} 75 76CPU_ATTR static inline vi add_bias(vi input, const int* bias_addr, Index bias_offset) { 77 auto bias_term = *reinterpret_cast<const vi*>(bias_addr + bias_offset); 78 return add_epi32(input, bias_term); 79} 80 81CPU_ATTR static inline vf add_bias(vf input, const float* bias_addr, Index bias_offset) { 82 auto bias_term = *reinterpret_cast<const vf*>(bias_addr + bias_offset); 83 return add_ps(input, bias_term); 84} 85 86CPU_ATTR static inline vd add_bias(vd input, const double* bias_addr, Index bias_offset) { 87 auto bias_term = *reinterpret_cast<const vd*>(bias_addr + bias_offset); 88 return add_pd(input, bias_term); 89} 90 91/* 92 * ReLU 93 */ 94template <typename Type> 95CPU_ATTR static inline vector_t<CPUType::CPU_NAME, Type> relu(vector_t<CPUType::CPU_NAME, Type> input); 96 97template <> 98CPU_ATTR inline vi relu<int8_t>(vi input) { 99 static const auto vconst_zero = set1_epi8<vi>(0); 100#if defined(KERNELS_THIS_IS_SSE2) 101 return and_si(input, _mm_cmplt_epi8(vconst_zero, input)); 102#elif defined(KERNELS_THIS_IS_AVX2) 103 return _mm256_max_epi8(input, vconst_zero); 104#else 105 return _mm512_max_epi8(input, vconst_zero); 106#endif 107} 108 109template <> 110CPU_ATTR inline vi relu<int16_t>(vi input) { 111 static const auto vconst_zero = set1_epi16<vi>(0); 112 return max_epi16(input, vconst_zero); 113} 114 115template <> 116CPU_ATTR inline vi relu<int>(vi input) { 117 static const auto vconst_zero = set1_epi32<vi>(0); 118#if defined(KERNELS_THIS_IS_SSE2) 119 return and_si(input, _mm_cmplt_epi32(vconst_zero, input)); 120#elif defined(KERNELS_THIS_IS_AVX2) 121 return _mm256_max_epi32(input, vconst_zero); 122#else 123 return _mm512_max_epi32(input, vconst_zero); 124#endif 125} 126 127template <> 128CPU_ATTR inline vf relu<float>(vf input) { 129 static const auto vconst_zero = setzero_ps<vf>(); 130 return max_ps(input, vconst_zero); 131} 132 133template <> 134CPU_ATTR inline vd relu<double>(vd input) { 135 static const auto vconst_zero = setzero_pd<vd>(); 136 return max_pd(input, vconst_zero); 137} 138 139/* 140 * Multiply (elemwise) 141 */ 142template <typename Type> 143CPU_ATTR static inline vector_t<CPUType::CPU_NAME, Type> multiply(vector_t<CPUType::CPU_NAME, Type> a, vector_t<CPUType::CPU_NAME, Type> b); 144 145template <> 146CPU_ATTR inline vi multiply<int8_t>(vi a, vi b) { 147 auto even = mullo_epi16(a, b); 148 auto odd = mullo_epi16(srli_epi16<8>(a), srli_epi16<8>(b)); 149 return or_si(slli_epi16<8>(odd), srli_epi16<8>(slli_epi16<8>(even))); 150} 151 152template <> 153CPU_ATTR inline vi multiply<int16_t>(vi a, vi b) { 154 return mullo_epi16(a, b); 155} 156 157template <> 158CPU_ATTR inline vi multiply<int>(vi a, vi b) { 159#if defined(KERNELS_THIS_IS_SSE2) 160 auto even = mul_epu32(a, b); 161 auto odd = mul_epu32(_mm_srli_si128(a, 4), _mm_srli_si128(b, 4)); 162 return unpacklo_epi32(_mm_shuffle_epi32(even, 0x8 /* = 0 0 2 0 */), _mm_shuffle_epi32(odd, 0x8 /* = 0 0 2 0 */)); 163#elif defined(KERNELS_THIS_IS_AVX2) 164 return _mm256_mullo_epi32(a, b); 165#else 166 return _mm512_mullo_epi32(a, b); 167#endif 168} 169 170template <> 171CPU_ATTR inline vf multiply<float>(vf a, vf b) { 172 return mul_ps(a, b); 173} 174 175template <> 176CPU_ATTR inline vd multiply<double>(vd a, vd b) { 177 return mul_pd(a, b); 178} 179 180/* 181 * Downcast 182 */ 183CPU_ATTR static inline vi downcast32to8(vi input1, vi input2, vi input3, vi input4) { 184 auto result = packs_epi16(packs_epi32(input1, input2), packs_epi32(input3, input4)); 185 186#if defined(KERNELS_THIS_IS_SSE2) 187 return result; 188#elif defined(KERNELS_THIS_IS_AVX2) 189 return _mm256_shuffle_epi32(_mm256_permute4x64_epi64(result, 0xd8 /* = 0 2 1 3 */), 0xd8 /* = 0 2 1 3 */); 190#else 191 static const auto permutation_indices = _mm512_set_epi32(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0); 192 return _mm512_castps_si512(_mm512_permutexvar_ps(permutation_indices, _mm512_castsi512_ps(result))); 193#endif 194} 195 196CPU_ATTR static inline vi downcast32to16(vi input1, vi input2) { 197 auto result = packs_epi32(input1, input2); 198 199#if defined(KERNELS_THIS_IS_SSE2) 200 return result; 201#elif defined(KERNELS_THIS_IS_AVX2) 202 return _mm256_permute4x64_epi64(result, 0xd8 /* = 0 2 1 3 */); 203#else 204 static const auto permutation_indices = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0); 205 return _mm512_castpd_si512(_mm512_permutexvar_pd(permutation_indices, _mm512_castsi512_pd(result))); 206#endif 207} 208 209CPU_ATTR static inline vi downcast16to8(vi input1, vi input2) { 210 auto result = packs_epi16(input1, input2); 211 212#if defined(KERNELS_THIS_IS_SSE2) 213 return result; 214#elif defined(KERNELS_THIS_IS_AVX2) 215 return _mm256_permute4x64_epi64(result, 0xd8 /* = 0 2 1 3 */); 216#else 217 static const auto permutation_indices = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0); 218 return _mm512_castpd_si512(_mm512_permutexvar_pd(permutation_indices, _mm512_castsi512_pd(result))); 219#endif 220} 221 222/* 223 * Upcast 224 */ 225CPU_ATTR static inline dvector_t<CPUType::CPU_NAME, int16_t> upcast8to16(vi input) { 226 static const auto vzero = set1_epi8<vi>(0); 227 228#if defined(KERNELS_THIS_IS_SSE2) 229 auto higher_byte = _mm_cmpgt_epi8(vzero, input); 230#elif defined(KERNELS_THIS_IS_AVX2) 231 input = _mm256_permute4x64_epi64(input, 0xd8 /* = 0 2 1 3 */); 232 auto higher_byte = _mm256_cmpgt_epi8(vzero, input); 233#else 234 static const auto vmax_negative = set1_epi8<vi>(-1 /* 0xff */); 235 static const auto permutation_indices = _mm512_set_epi64(7, 3, 6, 2, 5, 1, 4, 0); 236 237 input = _mm512_castpd_si512(_mm512_permutexvar_pd(permutation_indices, _mm512_castsi512_pd(input))); 238 auto negatives = _mm512_cmp_epi8_mask(input, vzero, 1 /* _MM_CMPINT_LT */); 239 auto higher_byte = _mm512_mask_blend_epi8(negatives, vzero, vmax_negative); 240#endif 241 242 return { 243 unpacklo_epi8(input, higher_byte), 244 unpackhi_epi8(input, higher_byte), 245 }; 246} 247 248CPU_ATTR static inline dvector_t<CPUType::CPU_NAME, int> upcast16to32(vi input) { 249 static const auto vzero = set1_epi16<vi>(0); 250 251#if defined(KERNELS_THIS_IS_SSE2) 252 auto higher_byte = _mm_cmpgt_epi16(vzero, input); 253#elif defined(KERNELS_THIS_IS_AVX2) 254 input = _mm256_permute4x64_epi64(input, 0xd8 /* = 0 2 1 3 */); 255 auto higher_byte = _mm256_cmpgt_epi16(vzero, input); 256#else 257 static const auto vmax_negative = set1_epi16<vi>(-1 /* 0xffff */); 258 static const auto permutation_indices = _mm512_set_epi64(7, 3, 6, 2, 5, 1, 4, 0); 259 260 input = _mm512_castpd_si512(_mm512_permutexvar_pd(permutation_indices, _mm512_castsi512_pd(input))); 261 auto negatives = _mm512_cmp_epi16_mask(input, vzero, 1 /* _MM_CMPINT_LT */); 262 auto higher_byte = _mm512_mask_blend_epi16(negatives, vzero, vmax_negative); 263#endif 264 265 return { 266 unpacklo_epi16(input, higher_byte), 267 unpackhi_epi16(input, higher_byte), 268 }; 269} 270 271CPU_ATTR static inline qvector_t<CPUType::CPU_NAME, int> upcast8to32(vi input) { 272 auto result16 = upcast8to16(input); 273 auto result32a = upcast16to32(result16.first); 274 auto result32b = upcast16to32(result16.second); 275 276 return { 277 result32a.first, 278 result32a.second, 279 result32b.first, 280 result32b.second, 281 }; 282} 283 284/* 285 * Rescale int32 286 */ 287CPU_ATTR static inline vi rescale(vi input, vf scale) { 288 return cvtps_epi32(mul_ps(cvtepi32_ps(input), scale)); 289} 290 291/* 292 * Bitwise not 293 */ 294CPU_ATTR static inline vi bitwise_not(vi v) { 295 return xor_si(v, set1_epi32<vi>(0xffffffff)); 296} 297 298/* 299 * Floor 300 */ 301CPU_ATTR static inline vf floor(vf input) { 302#if defined(KERNELS_THIS_IS_SSE2) 303 static const auto vconst_zero = setzero_ps<vf>(); 304 static const auto vconst_one = set1_ps<vf>(1.f); 305 306 auto result = cvtepi32_ps(cvttps_epi32(input)); 307 auto negatives = _mm_cmplt_ps(input, vconst_zero); 308 auto nonintegers = _mm_cmpneq_ps(input, result); 309 310 return sub_ps(result, and_ps(vconst_one, and_ps(negatives, nonintegers))); 311#elif defined(KERNELS_THIS_IS_AVX2) 312 return _mm256_floor_ps(input); 313#else 314 // TODO: It should work but compiler throw the error "incorrect rounding operand" 315 // return _mm512_roundscale_round_ps(input, 0, _MM_FROUND_FLOOR); 316 317 static const auto vconst_zero = setzero_ps<vf>(); 318 static const auto vconst_one = set1_ps<vf>(1.f); 319 320 auto result = cvtepi32_ps(cvttps_epi32(input)); 321 auto negatives = _mm512_cmp_ps_mask(input, vconst_zero, _CMP_LT_OQ); 322 auto nonintegers = _mm512_cmp_ps_mask(input, result, _CMP_NEQ_OQ); 323 324 return _mm512_mask_blend_ps(_mm512_kand(negatives, nonintegers), result, sub_ps(result, vconst_one)); 325#endif 326} 327 328/* 329 * Calculate approximation of e^x using Taylor series and lookup table 330 */ 331#if defined(KERNELS_THIS_IS_SSE2) 332CPU_ATTR static inline vf exp_approx_taylor(vf) { 333 std::abort(); 334} 335#else 336CPU_ATTR static inline vf exp_approx_taylor(vf x) { 337 static constexpr int EXP_MIN = -20; 338 static constexpr int EXP_MAX = 20; 339 static constexpr float EXP_LOOKUP[EXP_MAX - EXP_MIN + 1] = { 340 expif(-20), expif(-19), expif(-18), expif(-17), expif(-16), expif(-15), 341 expif(-14), expif(-13), expif(-12), expif(-11), expif(-10), expif(-9), 342 expif(-8), expif(-7), expif(-6), expif(-5), expif(-4), expif(-3), expif(-2), 343 expif(-1), expif(0), expif(1), expif(2), expif(3), expif(4), expif(5), 344 expif(6), expif(7), expif(8), expif(9), expif(10), expif(11), expif(12), 345 expif(13), expif(14), expif(15), expif(16), expif(17), expif(18), expif(19), 346 expif(20), 347 }; 348 349 static const vf dividers[] = { 350 set1_ps<vf>(1.f / factorial(7)), 351 set1_ps<vf>(1.f / factorial(6)), 352 set1_ps<vf>(1.f / factorial(5)), 353 set1_ps<vf>(1.f / factorial(4)), 354 set1_ps<vf>(1.f / factorial(3)), 355 set1_ps<vf>(1.f / factorial(2)), 356 set1_ps<vf>(1.f / factorial(1)), 357 }; 358 static const auto const_one = set1_ps<vf>(1.f); 359 static const auto const_min_x = set1_ps<vf>(EXP_MIN); 360 static const auto const_max_x = set1_ps<vf>(EXP_MAX); 361 362 x = max_ps(x, const_min_x); 363 x = min_ps(x, const_max_x); 364 365 auto a = floor(x); 366 auto xa = sub_ps(x, a); 367 368 auto result = mul_ps(dividers[0], xa); 369 370 result = add_ps(result, dividers[1]); 371 result = mul_ps(result, xa); 372 result = add_ps(result, dividers[2]); 373 result = mul_ps(result, xa); 374 result = add_ps(result, dividers[3]); 375 result = mul_ps(result, xa); 376 result = add_ps(result, dividers[4]); 377 result = mul_ps(result, xa); 378 result = add_ps(result, dividers[5]); 379 result = mul_ps(result, xa); 380 result = add_ps(result, dividers[6]); 381 result = mul_ps(result, xa); 382 383 result = add_ps(result, const_one); 384 385 auto ea = i32gather_ps<4>(EXP_LOOKUP + EXP_MAX, cvtps_epi32(a)); 386 return mul_ps(ea, result); 387} 388#endif 389 390/* 391 * Sigmoid 392 */ 393CPU_ATTR static inline vf sigmoid(vf 394#ifndef KERNELS_THIS_IS_SSE2 395 input 396#endif 397 ) { 398#if defined(KERNELS_THIS_IS_SSE2) 399 std::abort(); // TODO: missing exp_approx_taylor for SSE2 400#elif defined(KERNELS_THIS_IS_AVX2) 401 static const auto vconst_zero = setzero_ps<vf>(); 402 static const auto vconst_one = set1_ps<vf>(1.f); 403 404 auto x = input; 405 auto minus_x = sub_ps(vconst_zero, x); 406 auto e_x = exp_approx_taylor(x); 407 auto e_minus_x = exp_approx_taylor(minus_x); 408 409 auto sigmoid_case1 = _mm256_rcp_ps(add_ps(vconst_one, e_minus_x)); 410 auto sigmoid_case2 = mul_ps(e_x, _mm256_rcp_ps(add_ps(vconst_one, e_x))); 411 412 auto nonnegative_x_mask = _mm256_cmp_ps(vconst_zero, x, _CMP_LT_OS); 413 return _mm256_blendv_ps(sigmoid_case1, sigmoid_case2, nonnegative_x_mask); 414#else 415 static const auto vconst_zero = setzero_ps<vf>(); 416 static const auto vconst_one = set1_ps<vf>(1.f); 417 418 auto x = input; 419 auto minus_x = sub_ps(vconst_zero, x); 420 auto e_x = exp_approx_taylor(x); 421 auto e_minus_x = exp_approx_taylor(minus_x); 422 423 auto sigmoid_case1 = _mm512_rcp14_ps(add_ps(vconst_one, e_minus_x)); 424 auto sigmoid_case2 = mul_ps(e_x, _mm512_rcp14_ps(add_ps(vconst_one, e_x))); 425 426 auto nonnegative_x_mask = _mm512_cmp_ps_mask(vconst_zero, x, _CMP_LT_OS); 427 return _mm512_mask_blend_ps(nonnegative_x_mask, sigmoid_case1, sigmoid_case2); 428#endif 429} 430 431/* 432 * Tanh 433 */ 434#if defined(KERNELS_THIS_IS_SSE2) 435CPU_ATTR static inline vf tanh(vf) { 436 std::abort(); // TODO: missing exp_approx_taylor for SSE2 437} 438#else 439CPU_ATTR static inline vf tanh(vf input) { 440 const static auto vconst_zero = setzero_ps<vf>(); 441 442 auto e_x = exp_approx_taylor(input); 443 auto e_minus_x = exp_approx_taylor(sub_ps(vconst_zero, input)); 444 445 return div_ps(sub_ps(e_x, e_minus_x), add_ps(e_x, e_minus_x)); 446} 447#endif 448 449} 450} 451 452#undef CPU_NAME 453#undef CPU_ATTR 454#undef vi 455#undef vf 456#undef vd 457