1 /*
2  *  Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #ifndef MODULES_AUDIO_PROCESSING_AEC3_VECTOR_MATH_H_
12 #define MODULES_AUDIO_PROCESSING_AEC3_VECTOR_MATH_H_
13 
14 #include "typedefs.h"  // NOLINT(build/include)
15 #if defined(WEBRTC_HAS_NEON)
16 #include <arm_neon.h>
17 #endif
18 #if defined(WEBRTC_ARCH_X86_FAMILY)
19 #include <emmintrin.h>
20 #endif
21 #include <math.h>
22 #include <algorithm>
23 #include <array>
24 #include <functional>
25 
26 #include "api/array_view.h"
27 #include "modules/audio_processing/aec3/aec3_common.h"
28 #include "rtc_base/checks.h"
29 
30 namespace webrtc {
31 namespace aec3 {
32 
33 // Provides optimizations for mathematical operations based on vectors.
34 class VectorMath {
35  public:
VectorMath(Aec3Optimization optimization)36   explicit VectorMath(Aec3Optimization optimization)
37       : optimization_(optimization) {}
38 
39   // Elementwise square root.
Sqrt(rtc::ArrayView<float> x)40   void Sqrt(rtc::ArrayView<float> x) {
41     switch (optimization_) {
42 #if defined(WEBRTC_ARCH_X86_FAMILY)
43       case Aec3Optimization::kSse2: {
44         const int x_size = static_cast<int>(x.size());
45         const int vector_limit = x_size >> 2;
46 
47         int j = 0;
48         for (; j < vector_limit * 4; j += 4) {
49           __m128 g = _mm_loadu_ps(&x[j]);
50           g = _mm_sqrt_ps(g);
51           _mm_storeu_ps(&x[j], g);
52         }
53 
54         for (; j < x_size; ++j) {
55           x[j] = sqrtf(x[j]);
56         }
57       } break;
58 #endif
59 #if defined(WEBRTC_HAS_NEON)
60       case Aec3Optimization::kNeon: {
61         const int x_size = static_cast<int>(x.size());
62         const int vector_limit = x_size >> 2;
63 
64         int j = 0;
65         for (; j < vector_limit * 4; j += 4) {
66           float32x4_t g = vld1q_f32(&x[j]);
67 #if !defined(WEBRTC_ARCH_ARM64)
68           float32x4_t y = vrsqrteq_f32(g);
69 
70           // Code to handle sqrt(0).
71           // If the input to sqrtf() is zero, a zero will be returned.
72           // If the input to vrsqrteq_f32() is zero, positive infinity is
73           // returned.
74           const uint32x4_t vec_p_inf = vdupq_n_u32(0x7F800000);
75           // check for divide by zero
76           const uint32x4_t div_by_zero =
77               vceqq_u32(vec_p_inf, vreinterpretq_u32_f32(y));
78           // zero out the positive infinity results
79           y = vreinterpretq_f32_u32(
80               vandq_u32(vmvnq_u32(div_by_zero), vreinterpretq_u32_f32(y)));
81           // from arm documentation
82           // The Newton-Raphson iteration:
83           //     y[n+1] = y[n] * (3 - d * (y[n] * y[n])) / 2)
84           // converges to (1/√d) if y0 is the result of VRSQRTE applied to d.
85           //
86           // Note: The precision did not improve after 2 iterations.
87           for (int i = 0; i < 2; i++) {
88             y = vmulq_f32(vrsqrtsq_f32(vmulq_f32(y, y), g), y);
89           }
90           // sqrt(g) = g * 1/sqrt(g)
91           g = vmulq_f32(g, y);
92 #else
93           g = vsqrtq_f32(g);
94 #endif
95           vst1q_f32(&x[j], g);
96         }
97 
98         for (; j < x_size; ++j) {
99           x[j] = sqrtf(x[j]);
100         }
101       }
102 #endif
103       break;
104       default:
105         std::for_each(x.begin(), x.end(), [](float& a) { a = sqrtf(a); });
106     }
107   }
108 
109   // Elementwise vector multiplication z = x * y.
Multiply(rtc::ArrayView<const float> x,rtc::ArrayView<const float> y,rtc::ArrayView<float> z)110   void Multiply(rtc::ArrayView<const float> x,
111                 rtc::ArrayView<const float> y,
112                 rtc::ArrayView<float> z) {
113     RTC_DCHECK_EQ(z.size(), x.size());
114     RTC_DCHECK_EQ(z.size(), y.size());
115     switch (optimization_) {
116 #if defined(WEBRTC_ARCH_X86_FAMILY)
117       case Aec3Optimization::kSse2: {
118         const int x_size = static_cast<int>(x.size());
119         const int vector_limit = x_size >> 2;
120 
121         int j = 0;
122         for (; j < vector_limit * 4; j += 4) {
123           const __m128 x_j = _mm_loadu_ps(&x[j]);
124           const __m128 y_j = _mm_loadu_ps(&y[j]);
125           const __m128 z_j = _mm_mul_ps(x_j, y_j);
126           _mm_storeu_ps(&z[j], z_j);
127         }
128 
129         for (; j < x_size; ++j) {
130           z[j] = x[j] * y[j];
131         }
132       } break;
133 #endif
134 #if defined(WEBRTC_HAS_NEON)
135       case Aec3Optimization::kNeon: {
136         const int x_size = static_cast<int>(x.size());
137         const int vector_limit = x_size >> 2;
138 
139         int j = 0;
140         for (; j < vector_limit * 4; j += 4) {
141           const float32x4_t x_j = vld1q_f32(&x[j]);
142           const float32x4_t y_j = vld1q_f32(&y[j]);
143           const float32x4_t z_j = vmulq_f32(x_j, y_j);
144           vst1q_f32(&z[j], z_j);
145         }
146 
147         for (; j < x_size; ++j) {
148           z[j] = x[j] * y[j];
149         }
150       } break;
151 #endif
152       default:
153         std::transform(x.begin(), x.end(), y.begin(), z.begin(),
154                        std::multiplies<float>());
155     }
156   }
157 
158   // Elementwise vector accumulation z += x.
Accumulate(rtc::ArrayView<const float> x,rtc::ArrayView<float> z)159   void Accumulate(rtc::ArrayView<const float> x, rtc::ArrayView<float> z) {
160     RTC_DCHECK_EQ(z.size(), x.size());
161     switch (optimization_) {
162 #if defined(WEBRTC_ARCH_X86_FAMILY)
163       case Aec3Optimization::kSse2: {
164         const int x_size = static_cast<int>(x.size());
165         const int vector_limit = x_size >> 2;
166 
167         int j = 0;
168         for (; j < vector_limit * 4; j += 4) {
169           const __m128 x_j = _mm_loadu_ps(&x[j]);
170           __m128 z_j = _mm_loadu_ps(&z[j]);
171           z_j = _mm_add_ps(x_j, z_j);
172           _mm_storeu_ps(&z[j], z_j);
173         }
174 
175         for (; j < x_size; ++j) {
176           z[j] += x[j];
177         }
178       } break;
179 #endif
180 #if defined(WEBRTC_HAS_NEON)
181       case Aec3Optimization::kNeon: {
182         const int x_size = static_cast<int>(x.size());
183         const int vector_limit = x_size >> 2;
184 
185         int j = 0;
186         for (; j < vector_limit * 4; j += 4) {
187           const float32x4_t x_j = vld1q_f32(&x[j]);
188           float32x4_t z_j = vld1q_f32(&z[j]);
189           z_j = vaddq_f32(z_j, x_j);
190           vst1q_f32(&z[j], z_j);
191         }
192 
193         for (; j < x_size; ++j) {
194           z[j] += x[j];
195         }
196       } break;
197 #endif
198       default:
199         std::transform(x.begin(), x.end(), z.begin(), z.begin(),
200                        std::plus<float>());
201     }
202   }
203 
204  private:
205   Aec3Optimization optimization_;
206 };
207 
208 }  // namespace aec3
209 
210 }  // namespace webrtc
211 
212 #endif  // MODULES_AUDIO_PROCESSING_AEC3_VECTOR_MATH_H_
213