1 /* 2 * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 11 #ifndef MODULES_AUDIO_PROCESSING_AEC3_VECTOR_MATH_H_ 12 #define MODULES_AUDIO_PROCESSING_AEC3_VECTOR_MATH_H_ 13 14 #include "typedefs.h" // NOLINT(build/include) 15 #if defined(WEBRTC_HAS_NEON) 16 #include <arm_neon.h> 17 #endif 18 #if defined(WEBRTC_ARCH_X86_FAMILY) 19 #include <emmintrin.h> 20 #endif 21 #include <math.h> 22 #include <algorithm> 23 #include <array> 24 #include <functional> 25 26 #include "api/array_view.h" 27 #include "modules/audio_processing/aec3/aec3_common.h" 28 #include "rtc_base/checks.h" 29 30 namespace webrtc { 31 namespace aec3 { 32 33 // Provides optimizations for mathematical operations based on vectors. 34 class VectorMath { 35 public: VectorMath(Aec3Optimization optimization)36 explicit VectorMath(Aec3Optimization optimization) 37 : optimization_(optimization) {} 38 39 // Elementwise square root. Sqrt(rtc::ArrayView<float> x)40 void Sqrt(rtc::ArrayView<float> x) { 41 switch (optimization_) { 42 #if defined(WEBRTC_ARCH_X86_FAMILY) 43 case Aec3Optimization::kSse2: { 44 const int x_size = static_cast<int>(x.size()); 45 const int vector_limit = x_size >> 2; 46 47 int j = 0; 48 for (; j < vector_limit * 4; j += 4) { 49 __m128 g = _mm_loadu_ps(&x[j]); 50 g = _mm_sqrt_ps(g); 51 _mm_storeu_ps(&x[j], g); 52 } 53 54 for (; j < x_size; ++j) { 55 x[j] = sqrtf(x[j]); 56 } 57 } break; 58 #endif 59 #if defined(WEBRTC_HAS_NEON) 60 case Aec3Optimization::kNeon: { 61 const int x_size = static_cast<int>(x.size()); 62 const int vector_limit = x_size >> 2; 63 64 int j = 0; 65 for (; j < vector_limit * 4; j += 4) { 66 float32x4_t g = vld1q_f32(&x[j]); 67 #if !defined(WEBRTC_ARCH_ARM64) 68 float32x4_t y = vrsqrteq_f32(g); 69 70 // Code to handle sqrt(0). 71 // If the input to sqrtf() is zero, a zero will be returned. 72 // If the input to vrsqrteq_f32() is zero, positive infinity is 73 // returned. 74 const uint32x4_t vec_p_inf = vdupq_n_u32(0x7F800000); 75 // check for divide by zero 76 const uint32x4_t div_by_zero = 77 vceqq_u32(vec_p_inf, vreinterpretq_u32_f32(y)); 78 // zero out the positive infinity results 79 y = vreinterpretq_f32_u32( 80 vandq_u32(vmvnq_u32(div_by_zero), vreinterpretq_u32_f32(y))); 81 // from arm documentation 82 // The Newton-Raphson iteration: 83 // y[n+1] = y[n] * (3 - d * (y[n] * y[n])) / 2) 84 // converges to (1/√d) if y0 is the result of VRSQRTE applied to d. 85 // 86 // Note: The precision did not improve after 2 iterations. 87 for (int i = 0; i < 2; i++) { 88 y = vmulq_f32(vrsqrtsq_f32(vmulq_f32(y, y), g), y); 89 } 90 // sqrt(g) = g * 1/sqrt(g) 91 g = vmulq_f32(g, y); 92 #else 93 g = vsqrtq_f32(g); 94 #endif 95 vst1q_f32(&x[j], g); 96 } 97 98 for (; j < x_size; ++j) { 99 x[j] = sqrtf(x[j]); 100 } 101 } 102 #endif 103 break; 104 default: 105 std::for_each(x.begin(), x.end(), [](float& a) { a = sqrtf(a); }); 106 } 107 } 108 109 // Elementwise vector multiplication z = x * y. Multiply(rtc::ArrayView<const float> x,rtc::ArrayView<const float> y,rtc::ArrayView<float> z)110 void Multiply(rtc::ArrayView<const float> x, 111 rtc::ArrayView<const float> y, 112 rtc::ArrayView<float> z) { 113 RTC_DCHECK_EQ(z.size(), x.size()); 114 RTC_DCHECK_EQ(z.size(), y.size()); 115 switch (optimization_) { 116 #if defined(WEBRTC_ARCH_X86_FAMILY) 117 case Aec3Optimization::kSse2: { 118 const int x_size = static_cast<int>(x.size()); 119 const int vector_limit = x_size >> 2; 120 121 int j = 0; 122 for (; j < vector_limit * 4; j += 4) { 123 const __m128 x_j = _mm_loadu_ps(&x[j]); 124 const __m128 y_j = _mm_loadu_ps(&y[j]); 125 const __m128 z_j = _mm_mul_ps(x_j, y_j); 126 _mm_storeu_ps(&z[j], z_j); 127 } 128 129 for (; j < x_size; ++j) { 130 z[j] = x[j] * y[j]; 131 } 132 } break; 133 #endif 134 #if defined(WEBRTC_HAS_NEON) 135 case Aec3Optimization::kNeon: { 136 const int x_size = static_cast<int>(x.size()); 137 const int vector_limit = x_size >> 2; 138 139 int j = 0; 140 for (; j < vector_limit * 4; j += 4) { 141 const float32x4_t x_j = vld1q_f32(&x[j]); 142 const float32x4_t y_j = vld1q_f32(&y[j]); 143 const float32x4_t z_j = vmulq_f32(x_j, y_j); 144 vst1q_f32(&z[j], z_j); 145 } 146 147 for (; j < x_size; ++j) { 148 z[j] = x[j] * y[j]; 149 } 150 } break; 151 #endif 152 default: 153 std::transform(x.begin(), x.end(), y.begin(), z.begin(), 154 std::multiplies<float>()); 155 } 156 } 157 158 // Elementwise vector accumulation z += x. Accumulate(rtc::ArrayView<const float> x,rtc::ArrayView<float> z)159 void Accumulate(rtc::ArrayView<const float> x, rtc::ArrayView<float> z) { 160 RTC_DCHECK_EQ(z.size(), x.size()); 161 switch (optimization_) { 162 #if defined(WEBRTC_ARCH_X86_FAMILY) 163 case Aec3Optimization::kSse2: { 164 const int x_size = static_cast<int>(x.size()); 165 const int vector_limit = x_size >> 2; 166 167 int j = 0; 168 for (; j < vector_limit * 4; j += 4) { 169 const __m128 x_j = _mm_loadu_ps(&x[j]); 170 __m128 z_j = _mm_loadu_ps(&z[j]); 171 z_j = _mm_add_ps(x_j, z_j); 172 _mm_storeu_ps(&z[j], z_j); 173 } 174 175 for (; j < x_size; ++j) { 176 z[j] += x[j]; 177 } 178 } break; 179 #endif 180 #if defined(WEBRTC_HAS_NEON) 181 case Aec3Optimization::kNeon: { 182 const int x_size = static_cast<int>(x.size()); 183 const int vector_limit = x_size >> 2; 184 185 int j = 0; 186 for (; j < vector_limit * 4; j += 4) { 187 const float32x4_t x_j = vld1q_f32(&x[j]); 188 float32x4_t z_j = vld1q_f32(&z[j]); 189 z_j = vaddq_f32(z_j, x_j); 190 vst1q_f32(&z[j], z_j); 191 } 192 193 for (; j < x_size; ++j) { 194 z[j] += x[j]; 195 } 196 } break; 197 #endif 198 default: 199 std::transform(x.begin(), x.end(), z.begin(), z.begin(), 200 std::plus<float>()); 201 } 202 } 203 204 private: 205 Aec3Optimization optimization_; 206 }; 207 208 } // namespace aec3 209 210 } // namespace webrtc 211 212 #endif // MODULES_AUDIO_PROCESSING_AEC3_VECTOR_MATH_H_ 213