1 #pragma once
2 
3 #include <cmath>
4 #include "intrinsics.h"
5 
6 #ifdef _OPENMP
7 #include <omp.h>
8 #endif
9 
10 namespace intgemm {
11 
/* Horizontal (cross-lane) max and sum reductions of packed floats.  TODO make a template argument? */
13 
MaxFloat32(__m128 a)14 INTGEMM_SSE2 static inline float MaxFloat32(__m128 a) {
15   // Fold to just using the first 64 bits.
16   __m128 second_half = _mm_shuffle_ps(a, a, 3 * 4 + 2);
17   a = _mm_max_ps(a, second_half);
18   // Fold to just using the first 32 bits.
19   second_half = _mm_shuffle_ps(a, a, 1);
20   a = _mm_max_ps(a, second_half);
21   // This casting compiles to nothing.
22   return *reinterpret_cast<float*>(&a);
23 }
AddFloat32(__m128 a)24 INTGEMM_SSE2 static inline float AddFloat32(__m128 a) {
25   // Fold to just using the first 64 bits.
26   __m128 second_half = _mm_shuffle_ps(a, a, 3 * 4 + 2);
27   a = _mm_add_ps(a, second_half);
28   // Fold to just using the first 32 bits.
29   second_half = _mm_shuffle_ps(a, a, 1);
30   a = _mm_add_ps(a, second_half);
31   // This casting compiles to nothing.
32   return *reinterpret_cast<float*>(&a);
33 }
34 
35 #ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
MaxFloat32(__m256 a)36 INTGEMM_AVX2 static inline float MaxFloat32(__m256 a) {
37   return MaxFloat32(max_ps(_mm256_castps256_ps128(a), _mm256_extractf128_ps(a, 1)));
38 }
AddFloat32(__m256 a)39 INTGEMM_AVX2 static inline float AddFloat32(__m256 a) {
40   return AddFloat32(add_ps(_mm256_castps256_ps128(a), _mm256_extractf128_ps(a, 1)));
41 }
42 #endif
43 
44 #ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
45 // Find the maximum float.
MaxFloat32(__m512 a)46 INTGEMM_AVX512F static inline float MaxFloat32(__m512 a) {
47   // _mm512_extractf32x8_ps is AVX512DQ but we don't care about masking.
48   // So cast to pd, do AVX512F _mm512_extractf64x4_pd, then cast to ps.
49   __m256 upper = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(a), 1));
50   return MaxFloat32(max_ps(_mm512_castps512_ps256(a), upper));
51 }
AddFloat32(__m512 a)52 INTGEMM_AVX512F static inline float AddFloat32(__m512 a) {
53   __m256 upper = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(a), 1));
54   return AddFloat32(add_ps(_mm512_castps512_ps256(a), upper));
55 }
56 #endif
57 
// All 32 bits set except bit 31, which is the IEEE-754 sign bit of a float.
// Judging by the name, it is meant to be AND-ed with a float's bit pattern to
// get the pattern of its absolute value. Its uses are outside this header view.
constexpr int32_t kFloatAbsoluteMask = 0x7FFFFFFF;
59 
60 } // namespace intgemm
61 
62 #define INTGEMM_THIS_IS_SSE2
63 #include "stats.inl"
64 #undef INTGEMM_THIS_IS_SSE2
65 
66 #ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
67 #define INTGEMM_THIS_IS_AVX2
68 #include "stats.inl"
69 #undef INTGEMM_THIS_IS_AVX2
70 #endif
71 
72 #ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
73 #define INTGEMM_THIS_IS_AVX512DQ
74 #include "stats.inl"
75 #undef INTGEMM_THIS_IS_AVX512DQ
76 #endif
77