1 /*
2  *  Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "modules/audio_processing/aec3/vector_math.h"
12 
13 #include <immintrin.h>
14 #include <math.h>
15 
16 #include "api/array_view.h"
17 #include "rtc_base/checks.h"
18 
19 namespace webrtc {
20 namespace aec3 {
21 
22 // Elementwise square root.
SqrtAVX2(rtc::ArrayView<float> x)23 void VectorMath::SqrtAVX2(rtc::ArrayView<float> x) {
24   const int x_size = static_cast<int>(x.size());
25   const int vector_limit = x_size >> 3;
26 
27   int j = 0;
28   for (; j < vector_limit * 8; j += 8) {
29     __m256 g = _mm256_loadu_ps(&x[j]);
30     g = _mm256_sqrt_ps(g);
31     _mm256_storeu_ps(&x[j], g);
32   }
33 
34   for (; j < x_size; ++j) {
35     x[j] = sqrtf(x[j]);
36   }
37 }
38 
39 // Elementwise vector multiplication z = x * y.
MultiplyAVX2(rtc::ArrayView<const float> x,rtc::ArrayView<const float> y,rtc::ArrayView<float> z)40 void VectorMath::MultiplyAVX2(rtc::ArrayView<const float> x,
41                               rtc::ArrayView<const float> y,
42                               rtc::ArrayView<float> z) {
43   RTC_DCHECK_EQ(z.size(), x.size());
44   RTC_DCHECK_EQ(z.size(), y.size());
45   const int x_size = static_cast<int>(x.size());
46   const int vector_limit = x_size >> 3;
47 
48   int j = 0;
49   for (; j < vector_limit * 8; j += 8) {
50     const __m256 x_j = _mm256_loadu_ps(&x[j]);
51     const __m256 y_j = _mm256_loadu_ps(&y[j]);
52     const __m256 z_j = _mm256_mul_ps(x_j, y_j);
53     _mm256_storeu_ps(&z[j], z_j);
54   }
55 
56   for (; j < x_size; ++j) {
57     z[j] = x[j] * y[j];
58   }
59 }
60 
61 // Elementwise vector accumulation z += x.
AccumulateAVX2(rtc::ArrayView<const float> x,rtc::ArrayView<float> z)62 void VectorMath::AccumulateAVX2(rtc::ArrayView<const float> x,
63                                 rtc::ArrayView<float> z) {
64   RTC_DCHECK_EQ(z.size(), x.size());
65   const int x_size = static_cast<int>(x.size());
66   const int vector_limit = x_size >> 3;
67 
68   int j = 0;
69   for (; j < vector_limit * 8; j += 8) {
70     const __m256 x_j = _mm256_loadu_ps(&x[j]);
71     __m256 z_j = _mm256_loadu_ps(&z[j]);
72     z_j = _mm256_add_ps(x_j, z_j);
73     _mm256_storeu_ps(&z[j], z_j);
74   }
75 
76   for (; j < x_size; ++j) {
77     z[j] += x[j];
78   }
79 }
80 
81 }  // namespace aec3
82 }  // namespace webrtc
83