/* This file is included multiple times, once per architecture.
 *
 * Before each inclusion the includer defines one of the INTGEMM_THIS_IS_*
 * macros. That choice sets two helper macros for the rest of the file:
 *   INTGEMM_ARCH   - name of the inner namespace the functions are placed in.
 *   INTGEMM_TARGET - prefix put in front of every function definition.
 * Both helper macros are #undef'd again at the bottom of the file.
 */
#if defined(INTGEMM_THIS_IS_AVX512DQ)
/* NOTE(review): the AVX512DQ build uses the AVX512BW namespace while the
 * target is INTGEMM_AVX512DQ; this looks deliberate -- confirm. */
#define INTGEMM_ARCH AVX512BW
#define INTGEMM_TARGET INTGEMM_AVX512DQ
#elif defined(INTGEMM_THIS_IS_AVX2)
#define INTGEMM_ARCH AVX2
#define INTGEMM_TARGET INTGEMM_AVX2
#elif defined(INTGEMM_THIS_IS_SSE2)
#define INTGEMM_ARCH SSE2
#define INTGEMM_TARGET INTGEMM_SSE2
#else
/* None of the known architecture selectors was defined: refuse to compile. */
#error Included with unexpected architecture
#endif
14
15namespace intgemm {
16namespace INTGEMM_ARCH {
17
18/* Compute the maximum absolute value over floats aligned to register size.
19 * Do not call this function directly; it's a subroutine of MaxAbsolute.
20 */
21INTGEMM_TARGET static inline float MaxAbsoluteThread(const FRegister *begin, const FRegister *end) {
22  FRegister highest = setzero_ps<FRegister>();
23  const FRegister abs_mask = cast_ps(set1_epi32<Register>(kFloatAbsoluteMask));
24#pragma omp for
25  for (const FRegister *i = begin; i < end; ++i) {
26    FRegister reg = and_ps(abs_mask, *i);
27    highest = max_ps(highest, reg);
28  }
29  return MaxFloat32(highest);
30}
31
32/* Compute the maximum absolute value of an array of floats.
33 * begin_float must be aligned to a multiple of the register size.
34*/
35INTGEMM_TARGET static inline float MaxAbsolute(const float *begin_float, const float *end_float) {
36  assert(reinterpret_cast<uintptr_t>(begin_float) % sizeof(FRegister) == 0);
37  const float *end_reg = end_float - (reinterpret_cast<uintptr_t>(end_float) % sizeof(FRegister)) / sizeof(float);
38  float ret = 0.0;
39#pragma omp parallel reduction(max:ret) num_threads(std::max<int>(1, std::min<int>(omp_get_max_threads(), (end_float - begin_float) / 16384)))
40  {
41    float shard_max = MaxAbsoluteThread(
42        reinterpret_cast<const FRegister*>(begin_float),
43        reinterpret_cast<const FRegister*>(end_reg));
44    ret = std::max(ret, shard_max);
45  }
46  /* Overhang. The beginning was aligned so if there's any overhang we're
47   * allowed to read the next full register.  Then mask that to 0. */
48#if defined(INTGEMM_THIS_IS_AVX512DQ)
49  if (end_float != end_reg) {
50    const FRegister abs_mask = cast_ps(set1_epi32<Register>(kFloatAbsoluteMask));
51    __mmask16 mask = (1 << (end_float - end_reg)) - 1;
52    FRegister masked = _mm512_maskz_and_ps(mask, abs_mask, *reinterpret_cast<const FRegister*>(end_reg));
53    ret = std::max(ret, MaxFloat32(masked));
54  }
55#else
56  for (const float *i = end_reg; i < end_float; ++i) {
57    ret = std::max(ret, std::fabs(*i));
58  }
59#endif
60  return ret;
61}
62
63/* Computes the euclidean norm and returns the mean and the standard deviation. Optionally it can be the mean and standard deviation in absolute terms. */
64INTGEMM_TARGET static inline MeanStd VectorMeanStd(const float *begin_float, const float *end_float, bool absolute) {
65  assert(end_float > begin_float);
66  assert((end_float - begin_float) % (sizeof(FRegister) / sizeof(float)) == 0);
67  size_t num_items = end_float - begin_float;
68  const FRegister *begin = reinterpret_cast<const FRegister*>(begin_float);
69  const FRegister *end = reinterpret_cast<const FRegister*>(end_float);
70  FRegister squares = set1_ps<FRegister>(0);
71  FRegister sums = set1_ps<FRegister>(0);
72  if (absolute) {
73    const FRegister abs_mask = cast_ps(set1_epi32<Register>(kFloatAbsoluteMask));
74    for (; begin != end; begin++) {
75      FRegister vec = and_ps(abs_mask, *begin);
76      squares = add_ps(squares, mul_ps(vec, vec));
77      sums = add_ps(sums, vec);
78    }
79  } else {
80    for (; begin != end; begin++) {
81      FRegister vec = *begin;
82      squares = add_ps(squares, mul_ps(vec, vec));
83      sums = add_ps(sums, vec);
84    }
85  }
86  float squares_sum = AddFloat32(squares);
87  float normal_sums = AddFloat32(sums);
88  MeanStd ret;
89  ret.mean = normal_sums/num_items;
90  ret.stddev = std::sqrt((squares_sum/num_items) - (ret.mean*ret.mean));
91  return ret;
92}
93
94} // namespace INTGEMM_ARCH
95} // namespace intgemm
96
97#undef INTGEMM_ARCH
98#undef INTGEMM_TARGET
99