1 #pragma once
2
3 #include <cmath>
4 #include "intrinsics.h"
5
6 #ifdef _OPENMP
7 #include <omp.h>
8 #endif
9
10 namespace intgemm {
11
12 /* Horizontal max and sums. TODO make a template argument? */
13
MaxFloat32(__m128 a)14 INTGEMM_SSE2 static inline float MaxFloat32(__m128 a) {
15 // Fold to just using the first 64 bits.
16 __m128 second_half = _mm_shuffle_ps(a, a, 3 * 4 + 2);
17 a = _mm_max_ps(a, second_half);
18 // Fold to just using the first 32 bits.
19 second_half = _mm_shuffle_ps(a, a, 1);
20 a = _mm_max_ps(a, second_half);
21 // This casting compiles to nothing.
22 return *reinterpret_cast<float*>(&a);
23 }
AddFloat32(__m128 a)24 INTGEMM_SSE2 static inline float AddFloat32(__m128 a) {
25 // Fold to just using the first 64 bits.
26 __m128 second_half = _mm_shuffle_ps(a, a, 3 * 4 + 2);
27 a = _mm_add_ps(a, second_half);
28 // Fold to just using the first 32 bits.
29 second_half = _mm_shuffle_ps(a, a, 1);
30 a = _mm_add_ps(a, second_half);
31 // This casting compiles to nothing.
32 return *reinterpret_cast<float*>(&a);
33 }
34
35 #ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
MaxFloat32(__m256 a)36 INTGEMM_AVX2 static inline float MaxFloat32(__m256 a) {
37 return MaxFloat32(max_ps(_mm256_castps256_ps128(a), _mm256_extractf128_ps(a, 1)));
38 }
AddFloat32(__m256 a)39 INTGEMM_AVX2 static inline float AddFloat32(__m256 a) {
40 return AddFloat32(add_ps(_mm256_castps256_ps128(a), _mm256_extractf128_ps(a, 1)));
41 }
42 #endif
43
44 #ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
45 // Find the maximum float.
MaxFloat32(__m512 a)46 INTGEMM_AVX512F static inline float MaxFloat32(__m512 a) {
47 // _mm512_extractf32x8_ps is AVX512DQ but we don't care about masking.
48 // So cast to pd, do AVX512F _mm512_extractf64x4_pd, then cast to ps.
49 __m256 upper = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(a), 1));
50 return MaxFloat32(max_ps(_mm512_castps512_ps256(a), upper));
51 }
AddFloat32(__m512 a)52 INTGEMM_AVX512F static inline float AddFloat32(__m512 a) {
53 __m256 upper = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(a), 1));
54 return AddFloat32(add_ps(_mm512_castps512_ps256(a), upper));
55 }
56 #endif
57
// 32-bit mask with every bit set except the most significant one.
// NOTE(review): going by the name, this is presumably ANDed with float bit
// patterns to clear the sign bit (absolute value); the users are not visible
// in this header -- confirm against stats.inl.
constexpr int32_t kFloatAbsoluteMask = 0x7fffffff;
59
60 } // namespace intgemm
61
62 #define INTGEMM_THIS_IS_SSE2
63 #include "stats.inl"
64 #undef INTGEMM_THIS_IS_SSE2
65
66 #ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
67 #define INTGEMM_THIS_IS_AVX2
68 #include "stats.inl"
69 #undef INTGEMM_THIS_IS_AVX2
70 #endif
71
72 #ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
73 #define INTGEMM_THIS_IS_AVX512DQ
74 #include "stats.inl"
75 #undef INTGEMM_THIS_IS_AVX512DQ
76 #endif
77