/*
 *  Copyright (c) 2014 The WebRTC project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

/*
 * The core AEC algorithm, NEON version of speed-critical functions.
 *
 * Based on aec_core_sse2.c.
 */

#include <arm_neon.h>
#include <math.h>
#include <string.h>  // memset

extern "C" {
#include "common_audio/signal_processing/include/signal_processing_library.h"
}
#include "modules/audio_processing/aec/aec_common.h"
#include "modules/audio_processing/aec/aec_core_optimized_methods.h"
#include "modules/audio_processing/utility/ooura_fft.h"

namespace webrtc {

enum { kShiftExponentIntoTopMantissa = 8 };
enum { kFloatExponentShift = 23 };

__inline static float MulRe(float aRe, float aIm, float bRe, float bIm) {
  return aRe * bRe - aIm * bIm;
}

__inline static float MulIm(float aRe, float aIm, float bRe, float bIm) {
  return aRe * bIm + aIm * bRe;
}

static void FilterFarNEON(int num_partitions,
                          int x_fft_buf_block_pos,
                          float x_fft_buf[2]
                                         [kExtendedNumPartitions * PART_LEN1],
                          float h_fft_buf[2]
                                         [kExtendedNumPartitions * PART_LEN1],
                          float y_fft[2][PART_LEN1]) {
  int i;
  for (i = 0; i < num_partitions; i++) {
    int j;
    int xPos = (i + x_fft_buf_block_pos) * PART_LEN1;
    int pos = i * PART_LEN1;
    // Check for wrap
    if (i + x_fft_buf_block_pos >= num_partitions) {
      xPos -= num_partitions * PART_LEN1;
    }

    // vectorized code (four at once)
    for (j = 0; j + 3 < PART_LEN1; j += 4) {
      const float32x4_t x_fft_buf_re = vld1q_f32(&x_fft_buf[0][xPos + j]);
      const float32x4_t x_fft_buf_im = vld1q_f32(&x_fft_buf[1][xPos + j]);
      const float32x4_t h_fft_buf_re = vld1q_f32(&h_fft_buf[0][pos + j]);
      const float32x4_t h_fft_buf_im = vld1q_f32(&h_fft_buf[1][pos + j]);
      const float32x4_t y_fft_re = vld1q_f32(&y_fft[0][j]);
      const float32x4_t y_fft_im = vld1q_f32(&y_fft[1][j]);
      const float32x4_t a = vmulq_f32(x_fft_buf_re, h_fft_buf_re);
      const float32x4_t e = vmlsq_f32(a, x_fft_buf_im, h_fft_buf_im);
      const float32x4_t c = vmulq_f32(x_fft_buf_re, h_fft_buf_im);
      const float32x4_t f = vmlaq_f32(c, x_fft_buf_im, h_fft_buf_re);
      const float32x4_t g = vaddq_f32(y_fft_re, e);
      const float32x4_t h = vaddq_f32(y_fft_im, f);
      vst1q_f32(&y_fft[0][j], g);
      vst1q_f32(&y_fft[1][j], h);
    }
    // scalar code for the remaining items.
    for (; j < PART_LEN1; j++) {
      y_fft[0][j] += MulRe(x_fft_buf[0][xPos + j], x_fft_buf[1][xPos + j],
                           h_fft_buf[0][pos + j], h_fft_buf[1][pos + j]);
      y_fft[1][j] += MulIm(x_fft_buf[0][xPos + j], x_fft_buf[1][xPos + j],
                           h_fft_buf[0][pos + j], h_fft_buf[1][pos + j]);
    }
  }
}
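
// In equation form, the loop above accumulates the partitioned-block
// frequency-domain filter output (the scalar remainder loop spells out the
// same complex multiply per bin):
//
//   Y[k] += X_p[k] * H_p[k]
//        =  (xRe * hRe - xIm * hIm) + j * (xRe * hIm + xIm * hRe)
//
// summed over all filter partitions p.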

// ARM64's arm_neon.h already defines vdivq_f32 and vsqrtq_f32.
#if !defined(WEBRTC_ARCH_ARM64)
static float32x4_t vdivq_f32(float32x4_t a, float32x4_t b) {
  int i;
  float32x4_t x = vrecpeq_f32(b);
  // from arm documentation
  // The Newton-Raphson iteration:
  //     x[n+1] = x[n] * (2 - d * x[n])
  // converges to (1/d) if x0 is the result of VRECPE applied to d.
  //
  // Note: The precision did not improve after 2 iterations.
  for (i = 0; i < 2; i++) {
    x = vmulq_f32(vrecpsq_f32(b, x), x);
  }
  // a/b = a*(1/b)
  return vmulq_f32(a, x);
}
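
// For illustration, a scalar sketch of the same refinement (not part of the
// build); vrecpsq_f32(b, x) computes (2 - b * x):
//
//   float x = reciprocal_estimate(b);  // ~ vrecpeq_f32
//   x = x * (2.f - b * x);             // first Newton-Raphson step
//   x = x * (2.f - b * x);             // second (and last) step
//   return a * x;                      // a/b = a * (1/b)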

static float32x4_t vsqrtq_f32(float32x4_t s) {
  int i;
  float32x4_t x = vrsqrteq_f32(s);

  // Code to handle sqrt(0).
  // If the input to sqrtf() is zero, a zero will be returned.
  // If the input to vrsqrteq_f32() is zero, positive infinity is returned.
  const uint32x4_t vec_p_inf = vdupq_n_u32(0x7F800000);
  // check for divide by zero
  const uint32x4_t div_by_zero = vceqq_u32(vec_p_inf, vreinterpretq_u32_f32(x));
  // zero out the positive infinity results
  x = vreinterpretq_f32_u32(
      vandq_u32(vmvnq_u32(div_by_zero), vreinterpretq_u32_f32(x)));
  // from arm documentation
  // The Newton-Raphson iteration:
  //     x[n+1] = x[n] * (3 - d * x[n] * x[n]) / 2
  // converges to (1/√d) if x0 is the result of VRSQRTE applied to d.
  //
  // Note: The precision did not improve after 2 iterations.
  for (i = 0; i < 2; i++) {
    x = vmulq_f32(vrsqrtsq_f32(vmulq_f32(x, x), s), x);
  }
  // sqrt(s) = s * 1/sqrt(s)
  return vmulq_f32(s, x);
}
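
// Scalar sketch of the sequence above (illustration only);
// vrsqrtsq_f32(x * x, s) computes (3 - s * x * x) / 2:
//
//   float x = rsqrt_estimate(s);        // ~ vrsqrteq_f32; +Inf when s == 0
//   if (isinf(x)) x = 0.f;              // so that s * x yields sqrt(0) == 0
//   x = x * (3.f - s * x * x) * 0.5f;   // two Newton-Raphson steps
//   x = x * (3.f - s * x * x) * 0.5f;
//   return s * x;                       // sqrt(s) = s * (1/sqrt(s))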
#endif  // WEBRTC_ARCH_ARM64

static void ScaleErrorSignalNEON(float mu,
                                 float error_threshold,
                                 float x_pow[PART_LEN1],
                                 float ef[2][PART_LEN1]) {
  const float32x4_t k1e_10f = vdupq_n_f32(1e-10f);
  const float32x4_t kMu = vmovq_n_f32(mu);
  const float32x4_t kThresh = vmovq_n_f32(error_threshold);
  int i;
  // vectorized code (four at once)
  for (i = 0; i + 3 < PART_LEN1; i += 4) {
    const float32x4_t x_pow_local = vld1q_f32(&x_pow[i]);
    const float32x4_t ef_re_base = vld1q_f32(&ef[0][i]);
    const float32x4_t ef_im_base = vld1q_f32(&ef[1][i]);
    const float32x4_t xPowPlus = vaddq_f32(x_pow_local, k1e_10f);
    float32x4_t ef_re = vdivq_f32(ef_re_base, xPowPlus);
    float32x4_t ef_im = vdivq_f32(ef_im_base, xPowPlus);
    const float32x4_t ef_re2 = vmulq_f32(ef_re, ef_re);
    const float32x4_t ef_sum2 = vmlaq_f32(ef_re2, ef_im, ef_im);
    const float32x4_t absEf = vsqrtq_f32(ef_sum2);
    const uint32x4_t bigger = vcgtq_f32(absEf, kThresh);
    const float32x4_t absEfPlus = vaddq_f32(absEf, k1e_10f);
    const float32x4_t absEfInv = vdivq_f32(kThresh, absEfPlus);
    uint32x4_t ef_re_if = vreinterpretq_u32_f32(vmulq_f32(ef_re, absEfInv));
    uint32x4_t ef_im_if = vreinterpretq_u32_f32(vmulq_f32(ef_im, absEfInv));
    uint32x4_t ef_re_u32 =
        vandq_u32(vmvnq_u32(bigger), vreinterpretq_u32_f32(ef_re));
    uint32x4_t ef_im_u32 =
        vandq_u32(vmvnq_u32(bigger), vreinterpretq_u32_f32(ef_im));
    ef_re_if = vandq_u32(bigger, ef_re_if);
    ef_im_if = vandq_u32(bigger, ef_im_if);
    ef_re_u32 = vorrq_u32(ef_re_u32, ef_re_if);
    ef_im_u32 = vorrq_u32(ef_im_u32, ef_im_if);
    ef_re = vmulq_f32(vreinterpretq_f32_u32(ef_re_u32), kMu);
    ef_im = vmulq_f32(vreinterpretq_f32_u32(ef_im_u32), kMu);
    vst1q_f32(&ef[0][i], ef_re);
    vst1q_f32(&ef[1][i], ef_im);
  }
  // scalar code for the remaining items.
  for (; i < PART_LEN1; i++) {
    float abs_ef;
    ef[0][i] /= (x_pow[i] + 1e-10f);
    ef[1][i] /= (x_pow[i] + 1e-10f);
    abs_ef = sqrtf(ef[0][i] * ef[0][i] + ef[1][i] * ef[1][i]);

    if (abs_ef > error_threshold) {
      abs_ef = error_threshold / (abs_ef + 1e-10f);
      ef[0][i] *= abs_ef;
      ef[1][i] *= abs_ef;
    }

    // Stepsize factor
    ef[0][i] *= mu;
    ef[1][i] *= mu;
  }
}
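
// The conditional clamping above uses the standard NEON mask-select idiom
// (sketched here; a single vbslq_f32 would express the same select):
//
//   uint32x4_t mask = vcgtq_f32(absEf, kThresh);   // all-ones where true
//   result = (mask & clamped) | (~mask & original);
//
// which is the branchless counterpart of the scalar
// `if (abs_ef > error_threshold)` branch in the remainder loop.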

static void FilterAdaptationNEON(
    const OouraFft& ooura_fft,
    int num_partitions,
    int x_fft_buf_block_pos,
    float x_fft_buf[2][kExtendedNumPartitions * PART_LEN1],
    float e_fft[2][PART_LEN1],
    float h_fft_buf[2][kExtendedNumPartitions * PART_LEN1]) {
  float fft[PART_LEN2];
  int i;
  for (i = 0; i < num_partitions; i++) {
    int xPos = (i + x_fft_buf_block_pos) * PART_LEN1;
    int pos = i * PART_LEN1;
    int j;
    // Check for wrap
    if (i + x_fft_buf_block_pos >= num_partitions) {
      xPos -= num_partitions * PART_LEN1;
    }

    // Process the whole array...
    for (j = 0; j < PART_LEN; j += 4) {
      // Load x_fft_buf and e_fft.
      const float32x4_t x_fft_buf_re = vld1q_f32(&x_fft_buf[0][xPos + j]);
      const float32x4_t x_fft_buf_im = vld1q_f32(&x_fft_buf[1][xPos + j]);
      const float32x4_t e_fft_re = vld1q_f32(&e_fft[0][j]);
      const float32x4_t e_fft_im = vld1q_f32(&e_fft[1][j]);
      // Calculate the product of conjugate(x_fft_buf) by e_fft.
      //   re(conjugate(a) * b) = aRe * bRe + aIm * bIm
      //   im(conjugate(a) * b) = aRe * bIm - aIm * bRe
      const float32x4_t a = vmulq_f32(x_fft_buf_re, e_fft_re);
      const float32x4_t e = vmlaq_f32(a, x_fft_buf_im, e_fft_im);
      const float32x4_t c = vmulq_f32(x_fft_buf_re, e_fft_im);
      const float32x4_t f = vmlsq_f32(c, x_fft_buf_im, e_fft_re);
      // Interleave real and imaginary parts.
      const float32x4x2_t g_n_h = vzipq_f32(e, f);
      // Store
      vst1q_f32(&fft[2 * j + 0], g_n_h.val[0]);
      vst1q_f32(&fft[2 * j + 4], g_n_h.val[1]);
    }
    // ... and fixup the first imaginary entry.
    fft[1] =
        MulRe(x_fft_buf[0][xPos + PART_LEN], -x_fft_buf[1][xPos + PART_LEN],
              e_fft[0][PART_LEN], e_fft[1][PART_LEN]);

    ooura_fft.InverseFft(fft);
    memset(fft + PART_LEN, 0, sizeof(float) * PART_LEN);

    // fft scaling
    {
      const float scale = 2.0f / PART_LEN2;
      const float32x4_t scale_ps = vmovq_n_f32(scale);
      for (j = 0; j < PART_LEN; j += 4) {
        const float32x4_t fft_ps = vld1q_f32(&fft[j]);
        const float32x4_t fft_scale = vmulq_f32(fft_ps, scale_ps);
        vst1q_f32(&fft[j], fft_scale);
      }
    }
    ooura_fft.Fft(fft);

    {
      const float wt1 = h_fft_buf[1][pos];
      h_fft_buf[0][pos + PART_LEN] += fft[1];
      for (j = 0; j < PART_LEN; j += 4) {
        float32x4_t wtBuf_re = vld1q_f32(&h_fft_buf[0][pos + j]);
        float32x4_t wtBuf_im = vld1q_f32(&h_fft_buf[1][pos + j]);
        const float32x4_t fft0 = vld1q_f32(&fft[2 * j + 0]);
        const float32x4_t fft4 = vld1q_f32(&fft[2 * j + 4]);
        const float32x4x2_t fft_re_im = vuzpq_f32(fft0, fft4);
        wtBuf_re = vaddq_f32(wtBuf_re, fft_re_im.val[0]);
        wtBuf_im = vaddq_f32(wtBuf_im, fft_re_im.val[1]);

        vst1q_f32(&h_fft_buf[0][pos + j], wtBuf_re);
        vst1q_f32(&h_fft_buf[1][pos + j], wtBuf_im);
      }
      h_fft_buf[1][pos] = wt1;
    }
  }
}
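
// Each pass of the partition loop above is one constrained filter update,
// sketched in equation form:
//
//   H_p <- H_p + FFT( constrain( IFFT( conj(X_p) * E ) ) )
//
// where constrain() zeroes the second half of the time-domain gradient
// (the memset above) to remove circular-convolution wrap-around, and
// 2 / PART_LEN2 compensates for the scaling of the Ooura FFT pair.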

static float32x4_t vpowq_f32(float32x4_t a, float32x4_t b) {
  // a^b = exp2(b * log2(a))
  //   exp2(x) and log2(x) are calculated using polynomial approximations.
  float32x4_t log2_a, b_log2_a, a_exp_b;

  // Calculate log2(x), x = a.
  {
    // To calculate log2(x), we decompose x like this:
    //   x = y * 2^n
    //     n is an integer
    //     y is in the [1.0, 2.0) range
    //
    //   log2(x) = log2(y) + n
    //     n       can be evaluated by playing with float representation.
    //     log2(y) in a small range can be approximated, this code uses an order
    //             five polynomial approximation. The coefficients have been
    //             estimated with the Remez algorithm and the resulting
    //             polynomial has a maximum relative error of 0.00086%.

    // Compute n.
    //    This is done by masking the exponent, shifting it into the top bit of
    //    the mantissa, putting eight into the biased exponent (to compensate
    //    for the fact that the exponent has been shifted into the top/
    //    fractional part), and finally getting rid of the implicit leading one
    //    from the mantissa by subtracting it out.
    const uint32x4_t vec_float_exponent_mask = vdupq_n_u32(0x7F800000);
    const uint32x4_t vec_eight_biased_exponent = vdupq_n_u32(0x43800000);
    const uint32x4_t vec_implicit_leading_one = vdupq_n_u32(0x43BF8000);
    const uint32x4_t two_n =
        vandq_u32(vreinterpretq_u32_f32(a), vec_float_exponent_mask);
    const uint32x4_t n_1 = vshrq_n_u32(two_n, kShiftExponentIntoTopMantissa);
    const uint32x4_t n_0 = vorrq_u32(n_1, vec_eight_biased_exponent);
    const float32x4_t n =
        vsubq_f32(vreinterpretq_f32_u32(n_0),
                  vreinterpretq_f32_u32(vec_implicit_leading_one));
    // Compute y.
    const uint32x4_t vec_mantissa_mask = vdupq_n_u32(0x007FFFFF);
    const uint32x4_t vec_zero_biased_exponent_is_one = vdupq_n_u32(0x3F800000);
    const uint32x4_t mantissa =
        vandq_u32(vreinterpretq_u32_f32(a), vec_mantissa_mask);
    const float32x4_t y = vreinterpretq_f32_u32(
        vorrq_u32(mantissa, vec_zero_biased_exponent_is_one));
    // Approximate log2(y) ~= (y - 1) * pol5(y).
    //    pol5(y) = C5 * y^5 + C4 * y^4 + C3 * y^3 + C2 * y^2 + C1 * y + C0
    const float32x4_t C5 = vdupq_n_f32(-3.4436006e-2f);
    const float32x4_t C4 = vdupq_n_f32(3.1821337e-1f);
    const float32x4_t C3 = vdupq_n_f32(-1.2315303f);
    const float32x4_t C2 = vdupq_n_f32(2.5988452f);
    const float32x4_t C1 = vdupq_n_f32(-3.3241990f);
    const float32x4_t C0 = vdupq_n_f32(3.1157899f);
    float32x4_t pol5_y = C5;
    pol5_y = vmlaq_f32(C4, y, pol5_y);
    pol5_y = vmlaq_f32(C3, y, pol5_y);
    pol5_y = vmlaq_f32(C2, y, pol5_y);
    pol5_y = vmlaq_f32(C1, y, pol5_y);
    pol5_y = vmlaq_f32(C0, y, pol5_y);
    const float32x4_t y_minus_one =
        vsubq_f32(y, vreinterpretq_f32_u32(vec_zero_biased_exponent_is_one));
    const float32x4_t log2_y = vmulq_f32(y_minus_one, pol5_y);

    // Combine parts.
    log2_a = vaddq_f32(n, log2_y);
  }

  // b * log2(a)
  b_log2_a = vmulq_f32(b, log2_a);

  // Calculate exp2(x), x = b * log2(a).
  {
    // To calculate 2^x, we decompose x like this:
    //   x = n + y
    //     n is an integer, the value of x - 0.5 rounded down, therefore
    //     y is in the [0.5, 1.5) range
    //
    //   2^x = 2^n * 2^y
    //     2^n can be evaluated by playing with float representation.
    //     2^y in a small range can be approximated, this code uses an order two
    //         polynomial approximation. The coefficients have been estimated
    //         with the Remez algorithm and the resulting polynomial has a
    //         maximum relative error of 0.17%.
    // To avoid over/underflow, we reduce the range of input to ]-127, 129].
    const float32x4_t max_input = vdupq_n_f32(129.f);
    const float32x4_t min_input = vdupq_n_f32(-126.99999f);
    const float32x4_t x_min = vminq_f32(b_log2_a, max_input);
    const float32x4_t x_max = vmaxq_f32(x_min, min_input);
    // Compute n.
    const float32x4_t half = vdupq_n_f32(0.5f);
    const float32x4_t x_minus_half = vsubq_f32(x_max, half);
    const int32x4_t x_minus_half_floor = vcvtq_s32_f32(x_minus_half);

    // Compute 2^n.
    const int32x4_t float_exponent_bias = vdupq_n_s32(127);
    const int32x4_t two_n_exponent =
        vaddq_s32(x_minus_half_floor, float_exponent_bias);
    const float32x4_t two_n =
        vreinterpretq_f32_s32(vshlq_n_s32(two_n_exponent, kFloatExponentShift));
    // Compute y.
    const float32x4_t y = vsubq_f32(x_max, vcvtq_f32_s32(x_minus_half_floor));

    // Approximate 2^y ~= C2 * y^2 + C1 * y + C0.
    const float32x4_t C2 = vdupq_n_f32(3.3718944e-1f);
    const float32x4_t C1 = vdupq_n_f32(6.5763628e-1f);
    const float32x4_t C0 = vdupq_n_f32(1.0017247f);
    float32x4_t exp2_y = C2;
    exp2_y = vmlaq_f32(C1, y, exp2_y);
    exp2_y = vmlaq_f32(C0, y, exp2_y);

    // Combine parts.
    a_exp_b = vmulq_f32(exp2_y, two_n);
  }

  return a_exp_b;
}
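
// A scalar sketch of the same approximation (illustration only;
// split_exponent_mantissa(), pol5() and two_to_the() are hypothetical
// helpers standing in for the IEEE-754 bit manipulations done above):
//
//   float n1, y1;                          // a = y1 * 2^n1, y1 in [1, 2)
//   split_exponent_mantissa(a, &n1, &y1);  // read from the exponent bits
//   float log2_a = n1 + (y1 - 1.f) * pol5(y1);
//   float x = b * log2_a;
//   int n2 = (int)(x - 0.5f);              // truncation, like vcvtq_s32_f32
//   float y2 = x - (float)n2;              // y2 in roughly [0.5, 1.5)
//   float exp2_y2 = (C2 * y2 + C1) * y2 + C0;
//   return exp2_y2 * two_to_the(n2);       // 2^n2 written into exponent bits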

static void OverdriveNEON(float overdrive_scaling,
                          float hNlFb,
                          float hNl[PART_LEN1]) {
  int i;
  const float32x4_t vec_hNlFb = vmovq_n_f32(hNlFb);
  const float32x4_t vec_one = vdupq_n_f32(1.0f);
  const float32x4_t vec_overdrive_scaling = vmovq_n_f32(overdrive_scaling);

  // vectorized code (four at once)
  for (i = 0; i + 3 < PART_LEN1; i += 4) {
    // Weight subbands
    float32x4_t vec_hNl = vld1q_f32(&hNl[i]);
    const float32x4_t vec_weightCurve = vld1q_f32(&WebRtcAec_weightCurve[i]);
    const uint32x4_t bigger = vcgtq_f32(vec_hNl, vec_hNlFb);
    const float32x4_t vec_weightCurve_hNlFb =
        vmulq_f32(vec_weightCurve, vec_hNlFb);
    const float32x4_t vec_one_weightCurve = vsubq_f32(vec_one, vec_weightCurve);
    const float32x4_t vec_one_weightCurve_hNl =
        vmulq_f32(vec_one_weightCurve, vec_hNl);
    const uint32x4_t vec_if0 =
        vandq_u32(vmvnq_u32(bigger), vreinterpretq_u32_f32(vec_hNl));
    const float32x4_t vec_one_weightCurve_add =
        vaddq_f32(vec_weightCurve_hNlFb, vec_one_weightCurve_hNl);
    const uint32x4_t vec_if1 =
        vandq_u32(bigger, vreinterpretq_u32_f32(vec_one_weightCurve_add));

    vec_hNl = vreinterpretq_f32_u32(vorrq_u32(vec_if0, vec_if1));

    const float32x4_t vec_overDriveCurve =
        vld1q_f32(&WebRtcAec_overDriveCurve[i]);
    const float32x4_t vec_overDriveSm_overDriveCurve =
        vmulq_f32(vec_overdrive_scaling, vec_overDriveCurve);
    vec_hNl = vpowq_f32(vec_hNl, vec_overDriveSm_overDriveCurve);
    vst1q_f32(&hNl[i], vec_hNl);
  }

  // scalar code for the remaining items.
  for (; i < PART_LEN1; i++) {
    // Weight subbands
    if (hNl[i] > hNlFb) {
      hNl[i] = WebRtcAec_weightCurve[i] * hNlFb +
               (1 - WebRtcAec_weightCurve[i]) * hNl[i];
    }

    hNl[i] = powf(hNl[i], overdrive_scaling * WebRtcAec_overDriveCurve[i]);
  }
}

static void SuppressNEON(const float hNl[PART_LEN1], float efw[2][PART_LEN1]) {
  int i;
  const float32x4_t vec_minus_one = vdupq_n_f32(-1.0f);
  // vectorized code (four at once)
  for (i = 0; i + 3 < PART_LEN1; i += 4) {
    float32x4_t vec_hNl = vld1q_f32(&hNl[i]);
    float32x4_t vec_efw_re = vld1q_f32(&efw[0][i]);
    float32x4_t vec_efw_im = vld1q_f32(&efw[1][i]);
    vec_efw_re = vmulq_f32(vec_efw_re, vec_hNl);
    vec_efw_im = vmulq_f32(vec_efw_im, vec_hNl);

    // Ooura fft returns incorrect sign on imaginary component. It matters
    // here because we are making an additive change with comfort noise.
    vec_efw_im = vmulq_f32(vec_efw_im, vec_minus_one);
    vst1q_f32(&efw[0][i], vec_efw_re);
    vst1q_f32(&efw[1][i], vec_efw_im);
  }

  // scalar code for the remaining items.
  for (; i < PART_LEN1; i++) {
    efw[0][i] *= hNl[i];
    efw[1][i] *= hNl[i];

    // Ooura fft returns incorrect sign on imaginary component. It matters
    // here because we are making an additive change with comfort noise.
    efw[1][i] *= -1;
  }
}
static int PartitionDelayNEON(
    int num_partitions,
    float h_fft_buf[2][kExtendedNumPartitions * PART_LEN1]) {
  // Measures the energy in each filter partition and returns the partition
  // with the highest energy.
  // TODO(bjornv): Spread computational cost by computing one partition per
  // block?
  float wfEnMax = 0;
  int i;
  int delay = 0;

  for (i = 0; i < num_partitions; i++) {
    int j;
    int pos = i * PART_LEN1;
    float wfEn = 0;
    float32x4_t vec_wfEn = vdupq_n_f32(0.0f);
    // vectorized code (four at once)
    for (j = 0; j + 3 < PART_LEN1; j += 4) {
      const float32x4_t vec_wfBuf0 = vld1q_f32(&h_fft_buf[0][pos + j]);
      const float32x4_t vec_wfBuf1 = vld1q_f32(&h_fft_buf[1][pos + j]);
      vec_wfEn = vmlaq_f32(vec_wfEn, vec_wfBuf0, vec_wfBuf0);
      vec_wfEn = vmlaq_f32(vec_wfEn, vec_wfBuf1, vec_wfBuf1);
    }
    {
      float32x2_t vec_total;
      // A B C D
      vec_total = vpadd_f32(vget_low_f32(vec_wfEn), vget_high_f32(vec_wfEn));
      // A+B C+D
      vec_total = vpadd_f32(vec_total, vec_total);
      // A+B+C+D A+B+C+D
      wfEn = vget_lane_f32(vec_total, 0);
    }

    // scalar code for the remaining items.
    for (; j < PART_LEN1; j++) {
      wfEn += h_fft_buf[0][pos + j] * h_fft_buf[0][pos + j] +
              h_fft_buf[1][pos + j] * h_fft_buf[1][pos + j];
    }

    if (wfEn > wfEnMax) {
      wfEnMax = wfEn;
      delay = i;
    }
  }
  return delay;
}

// Updates the following smoothed Power Spectral Densities (PSD):
//  - sd  : near-end
//  - se  : residual echo
//  - sx  : far-end
//  - sde : cross-PSD of near-end and residual echo
//  - sxd : cross-PSD of near-end and far-end
//
// In addition to updating the PSDs, the filter divergence state is
// determined, and corresponding actions are taken.
static void UpdateCoherenceSpectraNEON(int mult,
                                       bool extended_filter_enabled,
                                       float efw[2][PART_LEN1],
                                       float dfw[2][PART_LEN1],
                                       float xfw[2][PART_LEN1],
                                       CoherenceState* coherence_state,
                                       short* filter_divergence_state,
                                       int* extreme_filter_divergence) {
  // Power estimate smoothing coefficients.
  const float* ptrGCoh =
      extended_filter_enabled
          ? WebRtcAec_kExtendedSmoothingCoefficients[mult - 1]
          : WebRtcAec_kNormalSmoothingCoefficients[mult - 1];
  int i;
  float sdSum = 0, seSum = 0;
  const float32x4_t vec_15 = vdupq_n_f32(WebRtcAec_kMinFarendPSD);
  float32x4_t vec_sdSum = vdupq_n_f32(0.0f);
  float32x4_t vec_seSum = vdupq_n_f32(0.0f);

  for (i = 0; i + 3 < PART_LEN1; i += 4) {
    const float32x4_t vec_dfw0 = vld1q_f32(&dfw[0][i]);
    const float32x4_t vec_dfw1 = vld1q_f32(&dfw[1][i]);
    const float32x4_t vec_efw0 = vld1q_f32(&efw[0][i]);
    const float32x4_t vec_efw1 = vld1q_f32(&efw[1][i]);
    const float32x4_t vec_xfw0 = vld1q_f32(&xfw[0][i]);
    const float32x4_t vec_xfw1 = vld1q_f32(&xfw[1][i]);
    float32x4_t vec_sd =
        vmulq_n_f32(vld1q_f32(&coherence_state->sd[i]), ptrGCoh[0]);
    float32x4_t vec_se =
        vmulq_n_f32(vld1q_f32(&coherence_state->se[i]), ptrGCoh[0]);
    float32x4_t vec_sx =
        vmulq_n_f32(vld1q_f32(&coherence_state->sx[i]), ptrGCoh[0]);
    float32x4_t vec_dfw_sumsq = vmulq_f32(vec_dfw0, vec_dfw0);
    float32x4_t vec_efw_sumsq = vmulq_f32(vec_efw0, vec_efw0);
    float32x4_t vec_xfw_sumsq = vmulq_f32(vec_xfw0, vec_xfw0);

    vec_dfw_sumsq = vmlaq_f32(vec_dfw_sumsq, vec_dfw1, vec_dfw1);
    vec_efw_sumsq = vmlaq_f32(vec_efw_sumsq, vec_efw1, vec_efw1);
    vec_xfw_sumsq = vmlaq_f32(vec_xfw_sumsq, vec_xfw1, vec_xfw1);
    vec_xfw_sumsq = vmaxq_f32(vec_xfw_sumsq, vec_15);
    vec_sd = vmlaq_n_f32(vec_sd, vec_dfw_sumsq, ptrGCoh[1]);
    vec_se = vmlaq_n_f32(vec_se, vec_efw_sumsq, ptrGCoh[1]);
    vec_sx = vmlaq_n_f32(vec_sx, vec_xfw_sumsq, ptrGCoh[1]);

    vst1q_f32(&coherence_state->sd[i], vec_sd);
    vst1q_f32(&coherence_state->se[i], vec_se);
    vst1q_f32(&coherence_state->sx[i], vec_sx);

    {
      float32x4x2_t vec_sde = vld2q_f32(&coherence_state->sde[i][0]);
      float32x4_t vec_dfwefw0011 = vmulq_f32(vec_dfw0, vec_efw0);
      float32x4_t vec_dfwefw0110 = vmulq_f32(vec_dfw0, vec_efw1);
      vec_sde.val[0] = vmulq_n_f32(vec_sde.val[0], ptrGCoh[0]);
      vec_sde.val[1] = vmulq_n_f32(vec_sde.val[1], ptrGCoh[0]);
      vec_dfwefw0011 = vmlaq_f32(vec_dfwefw0011, vec_dfw1, vec_efw1);
      vec_dfwefw0110 = vmlsq_f32(vec_dfwefw0110, vec_dfw1, vec_efw0);
      vec_sde.val[0] = vmlaq_n_f32(vec_sde.val[0], vec_dfwefw0011, ptrGCoh[1]);
      vec_sde.val[1] = vmlaq_n_f32(vec_sde.val[1], vec_dfwefw0110, ptrGCoh[1]);
      vst2q_f32(&coherence_state->sde[i][0], vec_sde);
    }

    {
      float32x4x2_t vec_sxd = vld2q_f32(&coherence_state->sxd[i][0]);
      float32x4_t vec_dfwxfw0011 = vmulq_f32(vec_dfw0, vec_xfw0);
      float32x4_t vec_dfwxfw0110 = vmulq_f32(vec_dfw0, vec_xfw1);
      vec_sxd.val[0] = vmulq_n_f32(vec_sxd.val[0], ptrGCoh[0]);
      vec_sxd.val[1] = vmulq_n_f32(vec_sxd.val[1], ptrGCoh[0]);
      vec_dfwxfw0011 = vmlaq_f32(vec_dfwxfw0011, vec_dfw1, vec_xfw1);
      vec_dfwxfw0110 = vmlsq_f32(vec_dfwxfw0110, vec_dfw1, vec_xfw0);
      vec_sxd.val[0] = vmlaq_n_f32(vec_sxd.val[0], vec_dfwxfw0011, ptrGCoh[1]);
      vec_sxd.val[1] = vmlaq_n_f32(vec_sxd.val[1], vec_dfwxfw0110, ptrGCoh[1]);
      vst2q_f32(&coherence_state->sxd[i][0], vec_sxd);
    }

    vec_sdSum = vaddq_f32(vec_sdSum, vec_sd);
    vec_seSum = vaddq_f32(vec_seSum, vec_se);
  }
  {
    float32x2_t vec_sdSum_total;
    float32x2_t vec_seSum_total;
    // A B C D
    vec_sdSum_total =
        vpadd_f32(vget_low_f32(vec_sdSum), vget_high_f32(vec_sdSum));
    vec_seSum_total =
        vpadd_f32(vget_low_f32(vec_seSum), vget_high_f32(vec_seSum));
    // A+B C+D
    vec_sdSum_total = vpadd_f32(vec_sdSum_total, vec_sdSum_total);
    vec_seSum_total = vpadd_f32(vec_seSum_total, vec_seSum_total);
    // A+B+C+D A+B+C+D
    sdSum = vget_lane_f32(vec_sdSum_total, 0);
    seSum = vget_lane_f32(vec_seSum_total, 0);
  }

  // scalar code for the remaining items.
  for (; i < PART_LEN1; i++) {
    coherence_state->sd[i] =
        ptrGCoh[0] * coherence_state->sd[i] +
        ptrGCoh[1] * (dfw[0][i] * dfw[0][i] + dfw[1][i] * dfw[1][i]);
    coherence_state->se[i] =
        ptrGCoh[0] * coherence_state->se[i] +
        ptrGCoh[1] * (efw[0][i] * efw[0][i] + efw[1][i] * efw[1][i]);
    // We threshold here to protect against the ill-effects of a zero farend.
    // The threshold is not arbitrarily chosen, but balances protection and
    // adverse interaction with the algorithm's tuning.
    // TODO(bjornv): investigate further why this is so sensitive.
    coherence_state->sx[i] =
        ptrGCoh[0] * coherence_state->sx[i] +
        ptrGCoh[1] *
            WEBRTC_SPL_MAX(xfw[0][i] * xfw[0][i] + xfw[1][i] * xfw[1][i],
                           WebRtcAec_kMinFarendPSD);

    coherence_state->sde[i][0] =
        ptrGCoh[0] * coherence_state->sde[i][0] +
        ptrGCoh[1] * (dfw[0][i] * efw[0][i] + dfw[1][i] * efw[1][i]);
    coherence_state->sde[i][1] =
        ptrGCoh[0] * coherence_state->sde[i][1] +
        ptrGCoh[1] * (dfw[0][i] * efw[1][i] - dfw[1][i] * efw[0][i]);

    coherence_state->sxd[i][0] =
        ptrGCoh[0] * coherence_state->sxd[i][0] +
        ptrGCoh[1] * (dfw[0][i] * xfw[0][i] + dfw[1][i] * xfw[1][i]);
    coherence_state->sxd[i][1] =
        ptrGCoh[0] * coherence_state->sxd[i][1] +
        ptrGCoh[1] * (dfw[0][i] * xfw[1][i] - dfw[1][i] * xfw[0][i]);

    sdSum += coherence_state->sd[i];
    seSum += coherence_state->se[i];
  }

  // Divergent filter safeguard update.
  *filter_divergence_state =
      (*filter_divergence_state ? 1.05f : 1.0f) * seSum > sdSum;

  // Signal extreme filter divergence if the error is significantly larger
  // than the nearend (13 dB).
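  // (10^(13/10) is approximately 19.95, hence the constant below.)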
  *extreme_filter_divergence = (seSum > (19.95f * sdSum));
}

// Window time domain data to be used by the fft.
static void WindowDataNEON(float* x_windowed, const float* x) {
  int i;
  for (i = 0; i < PART_LEN; i += 4) {
    const float32x4_t vec_Buf1 = vld1q_f32(&x[i]);
    const float32x4_t vec_Buf2 = vld1q_f32(&x[PART_LEN + i]);
    const float32x4_t vec_sqrtHanning = vld1q_f32(&WebRtcAec_sqrtHanning[i]);
    // A B C D
    float32x4_t vec_sqrtHanning_rev =
        vld1q_f32(&WebRtcAec_sqrtHanning[PART_LEN - i - 3]);
    // B A D C
    vec_sqrtHanning_rev = vrev64q_f32(vec_sqrtHanning_rev);
    // D C B A
    vec_sqrtHanning_rev = vcombine_f32(vget_high_f32(vec_sqrtHanning_rev),
                                       vget_low_f32(vec_sqrtHanning_rev));
    vst1q_f32(&x_windowed[i], vmulq_f32(vec_Buf1, vec_sqrtHanning));
    vst1q_f32(&x_windowed[PART_LEN + i],
              vmulq_f32(vec_Buf2, vec_sqrtHanning_rev));
  }
}
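
// The reversed load above exploits the symmetry of the square-root Hanning
// table; per element, the second half computes the scalar equivalent
//
//   x_windowed[PART_LEN + i] = x[PART_LEN + i] * WebRtcAec_sqrtHanning[PART_LEN - i];
//
// i.e. the window is applied mirror-imaged to the second half-block.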

// Puts fft output data into a complex valued array.
static void StoreAsComplexNEON(const float* data,
                               float data_complex[2][PART_LEN1]) {
  int i;
  for (i = 0; i < PART_LEN; i += 4) {
    const float32x4x2_t vec_data = vld2q_f32(&data[2 * i]);
    vst1q_f32(&data_complex[0][i], vec_data.val[0]);
    vst1q_f32(&data_complex[1][i], vec_data.val[1]);
  }
  // fix beginning/end values
  data_complex[1][0] = 0;
  data_complex[1][PART_LEN] = 0;
  data_complex[0][0] = data[0];
  data_complex[0][PART_LEN] = data[1];
}

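// Computes the subband coherence between near-end and residual echo (cohde)
// and between near-end and far-end (cohxd) from the smoothed (cross-)PSDs:
// the magnitude-squared coherence |S_ab[k]|^2 / (S_a[k] * S_b[k]), which
// lies in [0, 1]; the 1e-10 terms guard against division by zero.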
static void ComputeCoherenceNEON(const CoherenceState* coherence_state,
                                 float* cohde,
                                 float* cohxd) {
  int i;

  {
    const float32x4_t vec_1eminus10 = vdupq_n_f32(1e-10f);

    // Subband coherence
    for (i = 0; i + 3 < PART_LEN1; i += 4) {
      const float32x4_t vec_sd = vld1q_f32(&coherence_state->sd[i]);
      const float32x4_t vec_se = vld1q_f32(&coherence_state->se[i]);
      const float32x4_t vec_sx = vld1q_f32(&coherence_state->sx[i]);
      const float32x4_t vec_sdse = vmlaq_f32(vec_1eminus10, vec_sd, vec_se);
      const float32x4_t vec_sdsx = vmlaq_f32(vec_1eminus10, vec_sd, vec_sx);
      float32x4x2_t vec_sde = vld2q_f32(&coherence_state->sde[i][0]);
      float32x4x2_t vec_sxd = vld2q_f32(&coherence_state->sxd[i][0]);
      float32x4_t vec_cohde = vmulq_f32(vec_sde.val[0], vec_sde.val[0]);
      float32x4_t vec_cohxd = vmulq_f32(vec_sxd.val[0], vec_sxd.val[0]);
      vec_cohde = vmlaq_f32(vec_cohde, vec_sde.val[1], vec_sde.val[1]);
      vec_cohde = vdivq_f32(vec_cohde, vec_sdse);
      vec_cohxd = vmlaq_f32(vec_cohxd, vec_sxd.val[1], vec_sxd.val[1]);
      vec_cohxd = vdivq_f32(vec_cohxd, vec_sdsx);

      vst1q_f32(&cohde[i], vec_cohde);
      vst1q_f32(&cohxd[i], vec_cohxd);
    }
  }
  // scalar code for the remaining items.
  for (; i < PART_LEN1; i++) {
    cohde[i] = (coherence_state->sde[i][0] * coherence_state->sde[i][0] +
                coherence_state->sde[i][1] * coherence_state->sde[i][1]) /
               (coherence_state->sd[i] * coherence_state->se[i] + 1e-10f);
    cohxd[i] = (coherence_state->sxd[i][0] * coherence_state->sxd[i][0] +
                coherence_state->sxd[i][1] * coherence_state->sxd[i][1]) /
               (coherence_state->sx[i] * coherence_state->sd[i] + 1e-10f);
  }
}

void WebRtcAec_InitAec_neon(void) {
  WebRtcAec_FilterFar = FilterFarNEON;
  WebRtcAec_ScaleErrorSignal = ScaleErrorSignalNEON;
  WebRtcAec_FilterAdaptation = FilterAdaptationNEON;
  WebRtcAec_Overdrive = OverdriveNEON;
  WebRtcAec_Suppress = SuppressNEON;
  WebRtcAec_ComputeCoherence = ComputeCoherenceNEON;
  WebRtcAec_UpdateCoherenceSpectra = UpdateCoherenceSpectraNEON;
  WebRtcAec_StoreAsComplex = StoreAsComplexNEON;
  WebRtcAec_PartitionDelay = PartitionDelayNEON;
  WebRtcAec_WindowData = WindowDataNEON;
}
}  // namespace webrtc