1 /*
2 * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11 #include "modules/audio_processing/aec3/vector_math.h"
12
13 #include <immintrin.h>
14 #include <math.h>
15
16 #include "api/array_view.h"
17 #include "rtc_base/checks.h"
18
19 namespace webrtc {
20 namespace aec3 {
21
22 // Elementwise square root.
SqrtAVX2(rtc::ArrayView<float> x)23 void VectorMath::SqrtAVX2(rtc::ArrayView<float> x) {
24 const int x_size = static_cast<int>(x.size());
25 const int vector_limit = x_size >> 3;
26
27 int j = 0;
28 for (; j < vector_limit * 8; j += 8) {
29 __m256 g = _mm256_loadu_ps(&x[j]);
30 g = _mm256_sqrt_ps(g);
31 _mm256_storeu_ps(&x[j], g);
32 }
33
34 for (; j < x_size; ++j) {
35 x[j] = sqrtf(x[j]);
36 }
37 }
38
39 // Elementwise vector multiplication z = x * y.
MultiplyAVX2(rtc::ArrayView<const float> x,rtc::ArrayView<const float> y,rtc::ArrayView<float> z)40 void VectorMath::MultiplyAVX2(rtc::ArrayView<const float> x,
41 rtc::ArrayView<const float> y,
42 rtc::ArrayView<float> z) {
43 RTC_DCHECK_EQ(z.size(), x.size());
44 RTC_DCHECK_EQ(z.size(), y.size());
45 const int x_size = static_cast<int>(x.size());
46 const int vector_limit = x_size >> 3;
47
48 int j = 0;
49 for (; j < vector_limit * 8; j += 8) {
50 const __m256 x_j = _mm256_loadu_ps(&x[j]);
51 const __m256 y_j = _mm256_loadu_ps(&y[j]);
52 const __m256 z_j = _mm256_mul_ps(x_j, y_j);
53 _mm256_storeu_ps(&z[j], z_j);
54 }
55
56 for (; j < x_size; ++j) {
57 z[j] = x[j] * y[j];
58 }
59 }
60
61 // Elementwise vector accumulation z += x.
AccumulateAVX2(rtc::ArrayView<const float> x,rtc::ArrayView<float> z)62 void VectorMath::AccumulateAVX2(rtc::ArrayView<const float> x,
63 rtc::ArrayView<float> z) {
64 RTC_DCHECK_EQ(z.size(), x.size());
65 const int x_size = static_cast<int>(x.size());
66 const int vector_limit = x_size >> 3;
67
68 int j = 0;
69 for (; j < vector_limit * 8; j += 8) {
70 const __m256 x_j = _mm256_loadu_ps(&x[j]);
71 __m256 z_j = _mm256_loadu_ps(&z[j]);
72 z_j = _mm256_add_ps(x_j, z_j);
73 _mm256_storeu_ps(&z[j], z_j);
74 }
75
76 for (; j < x_size; ++j) {
77 z[j] += x[j];
78 }
79 }
80
81 } // namespace aec3
82 } // namespace webrtc
83