1 /*
2  *  Copyright (c) 2014 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 /*
12  * The core AEC algorithm, neon version of speed-critical functions.
13  *
14  * Based on aec_core_sse2.c.
15  */
16 
17 #include <arm_neon.h>
18 #include <math.h>
19 #include <string.h>  // memset
20 
21 #include "webrtc/common_audio/signal_processing/include/signal_processing_library.h"
22 #include "webrtc/modules/audio_processing/aec/aec_common.h"
23 #include "webrtc/modules/audio_processing/aec/aec_core_internal.h"
24 #include "webrtc/modules/audio_processing/aec/aec_rdft.h"
25 
// Bit-manipulation constants used by vpowq_f32().
enum {
  // Right shift that moves the 8-bit exponent field of an IEEE-754 single
  // (bits 23..30) into the top 8 bits of the mantissa field (bits 15..22).
  kShiftExponentIntoTopMantissa = 8,
  // Bit position of the exponent field in an IEEE-754 single.
  kFloatExponentShift = 23
};
28 
// Real part of the complex product (aRe + j*aIm) * (bRe + j*bIm).
// The expression is kept as a single statement on purpose so that the
// compiler's floating-point contraction behaves as in the other AEC variants.
__inline static float MulRe(float aRe,
                            float aIm,
                            float bRe,
                            float bIm) {
  return aRe * bRe - aIm * bIm;
}
32 
// Imaginary part of the complex product (aRe + j*aIm) * (bRe + j*bIm).
// Kept as one expression so floating-point contraction is unchanged.
__inline static float MulIm(float aRe,
                            float aIm,
                            float bRe,
                            float bIm) {
  return aRe * bIm + aIm * bRe;
}
36 
// Accumulates into |yf| the complex, per-bin product of every adaptive-filter
// partition aec->wfBuf[p] with the far-end block it is paired with in the
// circular buffer aec->xfBuf (block index (p + xfBufBlockPos) wrapped by
// num_partitions):
//   yf += xfBuf[block] * wfBuf[p]
static void FilterFarNEON(AecCore* aec, float yf[2][PART_LEN1]) {
  const int num_partitions = aec->num_partitions;
  int p;

  for (p = 0; p < num_partitions; p++) {
    const int w_offset = p * PART_LEN1;
    int x_block = p + aec->xfBufBlockPos;
    int x_offset;
    int k;

    // Wrap around the circular far-end buffer.
    if (x_block >= num_partitions) {
      x_block -= num_partitions;
    }
    x_offset = x_block * PART_LEN1;

    // Four bins per iteration.
    for (k = 0; k + 3 < PART_LEN1; k += 4) {
      const float32x4_t x_re = vld1q_f32(&aec->xfBuf[0][x_offset + k]);
      const float32x4_t x_im = vld1q_f32(&aec->xfBuf[1][x_offset + k]);
      const float32x4_t w_re = vld1q_f32(&aec->wfBuf[0][w_offset + k]);
      const float32x4_t w_im = vld1q_f32(&aec->wfBuf[1][w_offset + k]);
      const float32x4_t acc_re = vld1q_f32(&yf[0][k]);
      const float32x4_t acc_im = vld1q_f32(&yf[1][k]);
      // re(x * w) = x_re * w_re - x_im * w_im
      const float32x4_t prod_re = vmlsq_f32(vmulq_f32(x_re, w_re), x_im, w_im);
      // im(x * w) = x_re * w_im + x_im * w_re
      const float32x4_t prod_im = vmlaq_f32(vmulq_f32(x_re, w_im), x_im, w_re);
      vst1q_f32(&yf[0][k], vaddq_f32(acc_re, prod_re));
      vst1q_f32(&yf[1][k], vaddq_f32(acc_im, prod_im));
    }

    // Remaining bins that do not fill a whole vector.
    for (; k < PART_LEN1; k++) {
      yf[0][k] += MulRe(aec->xfBuf[0][x_offset + k],
                        aec->xfBuf[1][x_offset + k],
                        aec->wfBuf[0][w_offset + k],
                        aec->wfBuf[1][w_offset + k]);
      yf[1][k] += MulIm(aec->xfBuf[0][x_offset + k],
                        aec->xfBuf[1][x_offset + k],
                        aec->wfBuf[0][w_offset + k],
                        aec->wfBuf[1][w_offset + k]);
    }
  }
}
79 
// ARM64's arm_neon.h already defines vdivq_f32() and vsqrtq_f32(). On 32-bit
// ARM they are emulated below with reciprocal estimates refined by
// Newton-Raphson steps.
#if !defined(WEBRTC_ARCH_ARM64)
// Approximates a / b lane-wise as a * (1 / b).
static float32x4_t vdivq_f32(float32x4_t a, float32x4_t b) {
  int i;
  // Initial estimate of 1 / b.
  float32x4_t x = vrecpeq_f32(b);
  // From the ARM documentation, the Newton-Raphson iteration
  //     x[n+1] = x[n] * (2 - d * x[n])
  // converges to (1 / d) if x[0] is the result of VRECPE applied to d.
  // vrecpsq_f32(b, x) evaluates the (2 - b * x) factor.
  //
  // Note: the precision did not improve after 2 iterations.
  for (i = 0; i < 2; i++) {
    x = vmulq_f32(vrecpsq_f32(b, x), x);
  }
  // a / b = a * (1 / b)
  return vmulq_f32(a, x);
}

// Approximates sqrt(s) lane-wise as s * (1 / sqrt(s)). Lanes with s == +0
// yield 0, matching sqrtf().
static float32x4_t vsqrtq_f32(float32x4_t s) {
  int i;
  // Initial estimate of 1 / sqrt(s).
  float32x4_t x = vrsqrteq_f32(s);

  // Code to handle sqrt(0).
  // If the input to sqrtf() is zero, a zero will be returned.
  // If the input to vrsqrteq_f32() is zero, positive infinity is returned.
  const uint32x4_t vec_p_inf = vdupq_n_u32(0x7F800000);
  // Lanes whose estimate is +infinity (i.e. s == +0).
  const uint32x4_t div_by_zero = vceqq_u32(vec_p_inf, vreinterpretq_u32_f32(x));
  // Zero out the positive infinity results so that s * x below is 0.
  x = vreinterpretq_f32_u32(vandq_u32(vmvnq_u32(div_by_zero),
                                      vreinterpretq_u32_f32(x)));
  // From the ARM documentation, the Newton-Raphson iteration
  //     x[n+1] = x[n] * (3 - d * x[n] * x[n]) / 2
  // converges to (1 / sqrt(d)) if x[0] is the result of VRSQRTE applied to d.
  // vrsqrtsq_f32(x * x, s) evaluates the (3 - x * x * s) / 2 factor.
  //
  // Note: the precision did not improve after 2 iterations.
  for (i = 0; i < 2; i++) {
    x = vmulq_f32(vrsqrtsq_f32(vmulq_f32(x, x), s), x);
  }
  // sqrt(s) = s * (1 / sqrt(s))
  return vmulq_f32(s, x);
}
#endif  // !defined(WEBRTC_ARCH_ARM64)
124 
// Scales the error spectrum |ef| in place. Each bin is:
//   1. divided by the far-end power aec->xPow (+1e-10 to avoid division by 0),
//   2. shrunk to magnitude ~error_threshold if its magnitude exceeds it,
//   3. multiplied by the step size mu.
// mu and error_threshold come from the extended-filter constants or from the
// aec->normal_* fields, depending on aec->extended_filter_enabled.
static void ScaleErrorSignalNEON(AecCore* aec, float ef[2][PART_LEN1]) {
  float mu;
  float error_threshold;
  if (aec->extended_filter_enabled) {
    mu = kExtendedMu;
    error_threshold = kExtendedErrorThreshold;
  } else {
    mu = aec->normal_mu;
    error_threshold = aec->normal_error_threshold;
  }

  const float32x4_t vec_eps = vdupq_n_f32(1e-10f);
  const float32x4_t vec_mu = vmovq_n_f32(mu);
  const float32x4_t vec_threshold = vmovq_n_f32(error_threshold);
  int k;

  // Four bins per iteration.
  for (k = 0; k + 3 < PART_LEN1; k += 4) {
    const float32x4_t x_pow_eps =
        vaddq_f32(vld1q_f32(&aec->xPow[k]), vec_eps);
    const float32x4_t norm_re = vdivq_f32(vld1q_f32(&ef[0][k]), x_pow_eps);
    const float32x4_t norm_im = vdivq_f32(vld1q_f32(&ef[1][k]), x_pow_eps);
    const float32x4_t magnitude = vsqrtq_f32(
        vmlaq_f32(vmulq_f32(norm_re, norm_re), norm_im, norm_im));
    const uint32x4_t too_large = vcgtq_f32(magnitude, vec_threshold);
    const float32x4_t limit_gain =
        vdivq_f32(vec_threshold, vaddq_f32(magnitude, vec_eps));
    // Per lane: too_large ? norm * limit_gain : norm (bitwise select).
    const float32x4_t limited_re =
        vbslq_f32(too_large, vmulq_f32(norm_re, limit_gain), norm_re);
    const float32x4_t limited_im =
        vbslq_f32(too_large, vmulq_f32(norm_im, limit_gain), norm_im);
    vst1q_f32(&ef[0][k], vmulq_f32(limited_re, vec_mu));
    vst1q_f32(&ef[1][k], vmulq_f32(limited_im, vec_mu));
  }

  // Remaining bins that do not fill a whole vector.
  for (; k < PART_LEN1; k++) {
    float abs_ef;
    ef[0][k] /= (aec->xPow[k] + 1e-10f);
    ef[1][k] /= (aec->xPow[k] + 1e-10f);
    abs_ef = sqrtf(ef[0][k] * ef[0][k] + ef[1][k] * ef[1][k]);

    if (abs_ef > error_threshold) {
      abs_ef = error_threshold / (abs_ef + 1e-10f);
      ef[0][k] *= abs_ef;
      ef[1][k] *= abs_ef;
    }

    // Stepsize factor
    ef[0][k] *= mu;
    ef[1][k] *= mu;
  }
}
180 
// Adapts every partition of the filter aec->wfBuf using the error spectrum
// |ef|. For each partition p, paired with far-end block
// (p + xfBufBlockPos) wrapped by num_partitions:
//   1. form conj(X) * E in the packed layout expected by aec_rdft_*_128,
//   2. inverse transform, zero the second half and scale by 2 / PART_LEN2,
//   3. forward transform and add the result to wfBuf[.][p * PART_LEN1 ...].
// |fft| is scratch space; the aec_rdft_*_128 calls suggest it must hold 128
// floats (NOTE(review): confirm PART_LEN2 == 128 against the headers).
static void FilterAdaptationNEON(AecCore* aec,
                                 float* fft,
                                 float ef[2][PART_LEN1]) {
  const int num_partitions = aec->num_partitions;
  int p;

  for (p = 0; p < num_partitions; p++) {
    const int w_offset = p * PART_LEN1;
    int x_block = p + aec->xfBufBlockPos;
    int x_offset;
    int k;

    // Wrap around the circular far-end buffer.
    if (x_block >= num_partitions) {
      x_block -= num_partitions;
    }
    x_offset = x_block * PART_LEN1;

    // Bins 0 .. PART_LEN-1 of conj(X) * E, written as interleaved (re, im)
    // pairs into fft.
    for (k = 0; k < PART_LEN; k += 4) {
      const float32x4_t x_re = vld1q_f32(&aec->xfBuf[0][x_offset + k]);
      const float32x4_t x_im = vld1q_f32(&aec->xfBuf[1][x_offset + k]);
      const float32x4_t e_re = vld1q_f32(&ef[0][k]);
      const float32x4_t e_im = vld1q_f32(&ef[1][k]);
      // re(conj(x) * e) = x_re * e_re + x_im * e_im
      const float32x4_t g_re = vmlaq_f32(vmulq_f32(x_re, e_re), x_im, e_im);
      // im(conj(x) * e) = x_re * e_im - x_im * e_re
      const float32x4_t g_im = vmlsq_f32(vmulq_f32(x_re, e_im), x_im, e_re);
      const float32x4x2_t interleaved = vzipq_f32(g_re, g_im);
      vst1q_f32(&fft[2 * k], interleaved.val[0]);
      vst1q_f32(&fft[2 * k + 4], interleaved.val[1]);
    }
    // The packed layout keeps the real part of bin PART_LEN in fft[1]
    // (in place of bin 0's imaginary part).
    fft[1] = MulRe(aec->xfBuf[0][x_offset + PART_LEN],
                   -aec->xfBuf[1][x_offset + PART_LEN],
                   ef[0][PART_LEN],
                   ef[1][PART_LEN]);

    aec_rdft_inverse_128(fft);
    // Keep only the first half of the time-domain gradient.
    memset(fft + PART_LEN, 0, sizeof(float) * PART_LEN);

    // fft scaling
    {
      const float scale = 2.0f / PART_LEN2;
      const float32x4_t vec_scale = vmovq_n_f32(scale);
      for (k = 0; k < PART_LEN; k += 4) {
        vst1q_f32(&fft[k], vmulq_f32(vld1q_f32(&fft[k]), vec_scale));
      }
    }
    aec_rdft_forward_128(fft);

    {
      // The vector loop below adds fft[1] (bin PART_LEN's real part) to bin
      // 0's imaginary coefficient, so that coefficient is saved and restored.
      const float saved_dc_im = aec->wfBuf[1][w_offset];
      aec->wfBuf[0][w_offset + PART_LEN] += fft[1];
      for (k = 0; k < PART_LEN; k += 4) {
        const float32x4_t w_re = vld1q_f32(&aec->wfBuf[0][w_offset + k]);
        const float32x4_t w_im = vld1q_f32(&aec->wfBuf[1][w_offset + k]);
        const float32x4_t packed_lo = vld1q_f32(&fft[2 * k]);
        const float32x4_t packed_hi = vld1q_f32(&fft[2 * k + 4]);
        // De-interleave into real (val[0]) and imaginary (val[1]) parts.
        const float32x4x2_t update = vuzpq_f32(packed_lo, packed_hi);
        vst1q_f32(&aec->wfBuf[0][w_offset + k], vaddq_f32(w_re, update.val[0]));
        vst1q_f32(&aec->wfBuf[1][w_offset + k], vaddq_f32(w_im, update.val[1]));
      }
      aec->wfBuf[1][w_offset] = saved_dc_im;
    }
  }
}
255 
// Approximates a^b lane-wise as exp2(b * log2(a)), evaluating log2() and
// exp2() with polynomial approximations. The sign bit of |a| is ignored (the
// exponent and mantissa masks below both drop it), so this is meant for
// a >= 0. For a == 0 and b > 0 the result is 0, because the clamped exponent
// underflows to +0.
static float32x4_t vpowq_f32(float32x4_t a, float32x4_t b) {
  // a^b = exp2(b * log2(a))
  //   exp2(x) and log2(x) are calculated using polynomial approximations.
  float32x4_t log2_a, b_log2_a, a_exp_b;

  // Calculate log2(x), x = a.
  {
    // To calculate log2(x), we decompose x like this:
    //   x = y * 2^n
    //     n is an integer
    //     y is in the [1.0, 2.0) range
    //
    //   log2(x) = log2(y) + n
    //     n       can be evaluated by playing with float representation.
    //     log2(y) in a small range can be approximated, this code uses an order
    //             five polynomial approximation. The coefficients have been
    //             estimated with the Remez algorithm and the resulting
    //             polynomial has a maximum relative error of 0.00086%.

    // Compute n.
    //    This is done by masking the exponent, shifting it into the top bits
    //    of the mantissa, putting eight into the biased exponent (to
    //    compensate for the exponent having been shifted into the fractional
    //    part) and finally getting rid of the implicit leading one of the
    //    mantissa, together with the exponent bias, by subtracting it out.
    const uint32x4_t vec_float_exponent_mask = vdupq_n_u32(0x7F800000);
    const uint32x4_t vec_eight_biased_exponent = vdupq_n_u32(0x43800000);
    const uint32x4_t vec_implicit_leading_one = vdupq_n_u32(0x43BF8000);
    const uint32x4_t two_n = vandq_u32(vreinterpretq_u32_f32(a),
                                       vec_float_exponent_mask);
    const uint32x4_t n_1 = vshrq_n_u32(two_n, kShiftExponentIntoTopMantissa);
    const uint32x4_t n_0 = vorrq_u32(n_1, vec_eight_biased_exponent);
    const float32x4_t n =
        vsubq_f32(vreinterpretq_f32_u32(n_0),
                  vreinterpretq_f32_u32(vec_implicit_leading_one));
    // Compute y.
    const uint32x4_t vec_mantissa_mask = vdupq_n_u32(0x007FFFFF);
    const uint32x4_t vec_zero_biased_exponent_is_one = vdupq_n_u32(0x3F800000);
    const uint32x4_t mantissa = vandq_u32(vreinterpretq_u32_f32(a),
                                          vec_mantissa_mask);
    const float32x4_t y =
        vreinterpretq_f32_u32(vorrq_u32(mantissa,
                                        vec_zero_biased_exponent_is_one));
    // Approximate log2(y) ~= (y - 1) * pol5(y).
    //    pol5(y) = C5 * y^5 + C4 * y^4 + C3 * y^3 + C2 * y^2 + C1 * y + C0
    const float32x4_t C5 = vdupq_n_f32(-3.4436006e-2f);
    const float32x4_t C4 = vdupq_n_f32(3.1821337e-1f);
    const float32x4_t C3 = vdupq_n_f32(-1.2315303f);
    const float32x4_t C2 = vdupq_n_f32(2.5988452f);
    const float32x4_t C1 = vdupq_n_f32(-3.3241990f);
    const float32x4_t C0 = vdupq_n_f32(3.1157899f);
    float32x4_t pol5_y = C5;
    pol5_y = vmlaq_f32(C4, y, pol5_y);
    pol5_y = vmlaq_f32(C3, y, pol5_y);
    pol5_y = vmlaq_f32(C2, y, pol5_y);
    pol5_y = vmlaq_f32(C1, y, pol5_y);
    pol5_y = vmlaq_f32(C0, y, pol5_y);
    const float32x4_t y_minus_one =
        vsubq_f32(y, vreinterpretq_f32_u32(vec_zero_biased_exponent_is_one));
    const float32x4_t log2_y = vmulq_f32(y_minus_one, pol5_y);

    // Combine parts.
    log2_a = vaddq_f32(n, log2_y);
  }

  // b * log2(a)
  b_log2_a = vmulq_f32(b, log2_a);

  // Calculate exp2(x), x = b * log2(a).
  {
    // To calculate 2^x, we decompose x like this:
    //   x = n + y
    //     n is an integer, floor(x), therefore
    //     y is in the [0, 1) range
    //
    //   2^x = 2^n * 2^y
    //     2^n can be evaluated by playing with float representation.
    //     2^y in a small range can be approximated, this code uses an order two
    //         polynomial approximation. The coefficients have been estimated
    //         with the Remez algorithm over [0, 1] and the resulting
    //         polynomial has a maximum relative error of 0.17% there.
    // To avoid over/underflow, we reduce the range of input to ]-127, 129].
    const float32x4_t max_input = vdupq_n_f32(129.f);
    const float32x4_t min_input = vdupq_n_f32(-126.99999f);
    const float32x4_t x_min = vminq_f32(b_log2_a, max_input);
    const float32x4_t x_max = vmaxq_f32(x_min, min_input);

    // Compute n = floor(x).
    //   vcvtq_s32_f32() rounds toward zero, which differs from floor() for
    //   negative non-integers. Negative x is the common case here (a <= 1,
    //   b > 0), and truncation alone would put y outside [0, 1), where the
    //   quadratic below loses accuracy quickly. Lanes where the truncated
    //   value lies above x are therefore decremented: an all-ones compare
    //   mask reinterpreted as a signed integer is -1.
    const int32x4_t x_trunc = vcvtq_s32_f32(x_max);
    const uint32x4_t trunc_above_x = vcgtq_f32(vcvtq_f32_s32(x_trunc), x_max);
    const int32x4_t x_floor_unclamped =
        vaddq_s32(x_trunc, vreinterpretq_s32_u32(trunc_above_x));
    // x == 129 would give a biased exponent of 256, which does not fit the
    // 8-bit exponent field (it would spill into the sign bit). Capping n at
    // 128 makes 2^n +infinity instead.
    const int32x4_t x_floor = vminq_s32(x_floor_unclamped, vdupq_n_s32(128));

    // Compute 2^n.
    const int32x4_t float_exponent_bias = vdupq_n_s32(127);
    const int32x4_t two_n_exponent = vaddq_s32(x_floor, float_exponent_bias);
    const float32x4_t two_n =
        vreinterpretq_f32_s32(vshlq_n_s32(two_n_exponent, kFloatExponentShift));
    // Compute y.
    const float32x4_t y = vsubq_f32(x_max, vcvtq_f32_s32(x_floor));

    // Approximate 2^y ~= C2 * y^2 + C1 * y + C0.
    const float32x4_t C2 = vdupq_n_f32(3.3718944e-1f);
    const float32x4_t C1 = vdupq_n_f32(6.5763628e-1f);
    const float32x4_t C0 = vdupq_n_f32(1.0017247f);
    float32x4_t exp2_y = C2;
    exp2_y = vmlaq_f32(C1, y, exp2_y);
    exp2_y = vmlaq_f32(C0, y, exp2_y);

    // Combine parts.
    a_exp_b = vmulq_f32(exp2_y, two_n);
  }

  return a_exp_b;
}
370 
// Applies the suppression gain |hNl| to the error spectrum |efw|, in place.
// Per bin:
//   - if hNl > hNlFb, blend it toward hNlFb using WebRtcAec_weightCurve,
//   - raise it to the power aec->overDriveSm * WebRtcAec_overDriveCurve,
//   - store it back into hNl and multiply both parts of efw by it,
//   - negate the imaginary part (Ooura FFT sign convention).
// The vector path uses the vpowq_f32() approximation; the tail uses powf().
static void OverdriveAndSuppressNEON(AecCore* aec,
                                     float hNl[PART_LEN1],
                                     const float hNlFb,
                                     float efw[2][PART_LEN1]) {
  const float32x4_t vec_hNlFb = vmovq_n_f32(hNlFb);
  const float32x4_t vec_one = vdupq_n_f32(1.0f);
  const float32x4_t vec_minus_one = vdupq_n_f32(-1.0f);
  const float32x4_t vec_overDriveSm = vmovq_n_f32(aec->overDriveSm);
  int k;

  // Four bins per iteration.
  for (k = 0; k + 3 < PART_LEN1; k += 4) {
    const float32x4_t gain_in = vld1q_f32(&hNl[k]);
    const float32x4_t weight = vld1q_f32(&WebRtcAec_weightCurve[k]);
    const float32x4_t overdrive_curve = vld1q_f32(&WebRtcAec_overDriveCurve[k]);
    // weight * hNlFb + (1 - weight) * hNl
    const float32x4_t blended =
        vaddq_f32(vmulq_f32(weight, vec_hNlFb),
                  vmulq_f32(vsubq_f32(vec_one, weight), gain_in));
    // Only bins with hNl > hNlFb are blended (bitwise select).
    const float32x4_t weighted =
        vbslq_f32(vcgtq_f32(gain_in, vec_hNlFb), blended, gain_in);
    const float32x4_t gain =
        vpowq_f32(weighted, vmulq_f32(vec_overDriveSm, overdrive_curve));
    const float32x4_t e_re = vmulq_f32(vld1q_f32(&efw[0][k]), gain);
    const float32x4_t e_im = vmulq_f32(vld1q_f32(&efw[1][k]), gain);

    vst1q_f32(&hNl[k], gain);
    vst1q_f32(&efw[0][k], e_re);
    // Ooura fft returns incorrect sign on imaginary component. It matters
    // here because we are making an additive change with comfort noise.
    vst1q_f32(&efw[1][k], vmulq_f32(e_im, vec_minus_one));
  }

  // Remaining bins that do not fill a whole vector.
  for (; k < PART_LEN1; k++) {
    // Weight subbands
    if (hNl[k] > hNlFb) {
      hNl[k] = WebRtcAec_weightCurve[k] * hNlFb +
               (1 - WebRtcAec_weightCurve[k]) * hNl[k];
    }

    hNl[k] = powf(hNl[k], aec->overDriveSm * WebRtcAec_overDriveCurve[k]);

    // Suppress error signal
    efw[0][k] *= hNl[k];
    efw[1][k] *= hNl[k];

    // Ooura fft returns incorrect sign on imaginary component. It matters
    // here because we are making an additive change with comfort noise.
    efw[1][k] *= -1;
  }
}
444 
// Returns the index of the adaptive-filter partition with the largest energy
// (sum over its bins of wfBuf[0]^2 + wfBuf[1]^2). On ties the lowest index
// wins, and 0 is returned if no partition has positive energy.
// TODO(bjornv): Spread computational cost by computing one partition per
// block?
static int PartitionDelay(const AecCore* aec) {
  float max_energy = 0;
  int best_partition = 0;
  int p;

  for (p = 0; p < aec->num_partitions; p++) {
    const float* w_re = &aec->wfBuf[0][p * PART_LEN1];
    const float* w_im = &aec->wfBuf[1][p * PART_LEN1];
    float32x4_t acc = vdupq_n_f32(0.0f);
    float32x2_t pair_sum;
    float energy;
    int k;

    // Four bins per iteration.
    for (k = 0; k + 3 < PART_LEN1; k += 4) {
      const float32x4_t re = vld1q_f32(w_re + k);
      const float32x4_t im = vld1q_f32(w_im + k);
      acc = vmlaq_f32(acc, re, re);
      acc = vmlaq_f32(acc, im, im);
    }

    // Horizontal sum of lanes A B C D as (A + B) + (C + D).
    pair_sum = vpadd_f32(vget_low_f32(acc), vget_high_f32(acc));
    pair_sum = vpadd_f32(pair_sum, pair_sum);
    energy = vget_lane_f32(pair_sum, 0);

    // Remaining bins that do not fill a whole vector.
    for (; k < PART_LEN1; k++) {
      energy += w_re[k] * w_re[k] +
                w_im[k] * w_im[k];
    }

    if (energy > max_energy) {
      max_energy = energy;
      best_partition = p;
    }
  }
  return best_partition;
}
489 
// Recursively smooths four complex cross-PSD bins stored interleaved
// (re, im, re, im, ...) starting at |cross|:
//   re = g_old * re + g_new * (d_re * o_re + d_im * o_im)
//   im = g_old * im + g_new * (d_re * o_im - d_im * o_re)
__inline static void SmoothCrossPSDNEON(float* cross,
                                        float32x4_t d_re,
                                        float32x4_t d_im,
                                        float32x4_t o_re,
                                        float32x4_t o_im,
                                        float g_old,
                                        float g_new) {
  float32x4x2_t acc = vld2q_f32(cross);
  float32x4_t inst_re = vmulq_f32(d_re, o_re);
  float32x4_t inst_im = vmulq_f32(d_re, o_im);
  acc.val[0] = vmulq_n_f32(acc.val[0], g_old);
  acc.val[1] = vmulq_n_f32(acc.val[1], g_old);
  inst_re = vmlaq_f32(inst_re, d_im, o_im);
  inst_im = vmlsq_f32(inst_im, d_im, o_re);
  acc.val[0] = vmlaq_n_f32(acc.val[0], inst_re, g_new);
  acc.val[1] = vmlaq_n_f32(acc.val[1], inst_im, g_new);
  vst2q_f32(cross, acc);
}

// Updates the smoothed Power Spectral Densities (PSDs) stored in |aec|:
//  - sd  : near-end
//  - se  : residual echo
//  - sx  : far-end (instantaneous power floored at WebRtcAec_kMinFarendPSD)
//  - sde : cross-PSD of near-end and residual echo
//  - sxd : cross-PSD of near-end and far-end
// Each follows new = g_old * old + g_new * instantaneous, with (g_old, g_new)
// taken from the smoothing table selected by extended_filter_enabled and
// aec->mult.
//
// It then updates the filter divergence state. While diverged, |efw| is
// overwritten with |dfw|. Unless the extended filter is enabled, the adaptive
// filter aec->wfBuf is also reset when the residual echo exceeds the near-end
// by about 13 dB.
static void SmoothedPSD(AecCore* aec,
                        float efw[2][PART_LEN1],
                        float dfw[2][PART_LEN1],
                        float xfw[2][PART_LEN1]) {
  // Power estimate smoothing coefficients: [0] weights the previous estimate,
  // [1] the current block.
  const float* coefs = aec->extended_filter_enabled
      ? WebRtcAec_kExtendedSmoothingCoefficients[aec->mult - 1]
      : WebRtcAec_kNormalSmoothingCoefficients[aec->mult - 1];
  const float g_old = coefs[0];
  const float g_new = coefs[1];
  const float32x4_t vec_min_farend_psd = vdupq_n_f32(WebRtcAec_kMinFarendPSD);
  float32x4_t acc_sd = vdupq_n_f32(0.0f);
  float32x4_t acc_se = vdupq_n_f32(0.0f);
  float sd_total;
  float se_total;
  int k;

  // Four bins per iteration.
  for (k = 0; k + 3 < PART_LEN1; k += 4) {
    const float32x4_t d_re = vld1q_f32(&dfw[0][k]);
    const float32x4_t d_im = vld1q_f32(&dfw[1][k]);
    const float32x4_t e_re = vld1q_f32(&efw[0][k]);
    const float32x4_t e_im = vld1q_f32(&efw[1][k]);
    const float32x4_t x_re = vld1q_f32(&xfw[0][k]);
    const float32x4_t x_im = vld1q_f32(&xfw[1][k]);
    const float32x4_t sd_old = vmulq_n_f32(vld1q_f32(&aec->sd[k]), g_old);
    const float32x4_t se_old = vmulq_n_f32(vld1q_f32(&aec->se[k]), g_old);
    const float32x4_t sx_old = vmulq_n_f32(vld1q_f32(&aec->sx[k]), g_old);
    // Instantaneous powers |d|^2, |e|^2 and max(|x|^2, kMinFarendPSD).
    const float32x4_t d_pow =
        vmlaq_f32(vmulq_f32(d_re, d_re), d_im, d_im);
    const float32x4_t e_pow =
        vmlaq_f32(vmulq_f32(e_re, e_re), e_im, e_im);
    const float32x4_t x_pow = vmaxq_f32(
        vmlaq_f32(vmulq_f32(x_re, x_re), x_im, x_im), vec_min_farend_psd);
    const float32x4_t sd_new = vmlaq_n_f32(sd_old, d_pow, g_new);
    const float32x4_t se_new = vmlaq_n_f32(se_old, e_pow, g_new);
    const float32x4_t sx_new = vmlaq_n_f32(sx_old, x_pow, g_new);

    vst1q_f32(&aec->sd[k], sd_new);
    vst1q_f32(&aec->se[k], se_new);
    vst1q_f32(&aec->sx[k], sx_new);

    SmoothCrossPSDNEON(&aec->sde[k][0], d_re, d_im, e_re, e_im, g_old, g_new);
    SmoothCrossPSDNEON(&aec->sxd[k][0], d_re, d_im, x_re, x_im, g_old, g_new);

    acc_sd = vaddq_f32(acc_sd, sd_new);
    acc_se = vaddq_f32(acc_se, se_new);
  }

  // Horizontal sums of lanes A B C D as (A + B) + (C + D).
  {
    float32x2_t sd_pair = vpadd_f32(vget_low_f32(acc_sd),
                                    vget_high_f32(acc_sd));
    float32x2_t se_pair = vpadd_f32(vget_low_f32(acc_se),
                                    vget_high_f32(acc_se));
    sd_pair = vpadd_f32(sd_pair, sd_pair);
    se_pair = vpadd_f32(se_pair, se_pair);
    sd_total = vget_lane_f32(sd_pair, 0);
    se_total = vget_lane_f32(se_pair, 0);
  }

  // Remaining bins that do not fill a whole vector.
  for (; k < PART_LEN1; k++) {
    aec->sd[k] = g_old * aec->sd[k] +
                 g_new * (dfw[0][k] * dfw[0][k] + dfw[1][k] * dfw[1][k]);
    aec->se[k] = g_old * aec->se[k] +
                 g_new * (efw[0][k] * efw[0][k] + efw[1][k] * efw[1][k]);
    // We threshold here to protect against the ill-effects of a zero farend.
    // The threshold is not arbitrarily chosen, but balances protection and
    // adverse interaction with the algorithm's tuning.
    // TODO(bjornv): investigate further why this is so sensitive.
    aec->sx[k] =
        g_old * aec->sx[k] +
        g_new * WEBRTC_SPL_MAX(
            xfw[0][k] * xfw[0][k] + xfw[1][k] * xfw[1][k],
            WebRtcAec_kMinFarendPSD);

    aec->sde[k][0] =
        g_old * aec->sde[k][0] +
        g_new * (dfw[0][k] * efw[0][k] + dfw[1][k] * efw[1][k]);
    aec->sde[k][1] =
        g_old * aec->sde[k][1] +
        g_new * (dfw[0][k] * efw[1][k] - dfw[1][k] * efw[0][k]);

    aec->sxd[k][0] =
        g_old * aec->sxd[k][0] +
        g_new * (dfw[0][k] * xfw[0][k] + dfw[1][k] * xfw[1][k]);
    aec->sxd[k][1] =
        g_old * aec->sxd[k][1] +
        g_new * (dfw[0][k] * xfw[1][k] - dfw[1][k] * xfw[0][k]);

    sd_total += aec->sd[k];
    se_total += aec->se[k];
  }

  // Divergent filter safeguard (5% hysteresis once diverged).
  aec->divergeState = (aec->divergeState ? 1.05f : 1.0f) * se_total > sd_total;

  if (aec->divergeState) {
    memcpy(efw, dfw, sizeof(efw[0][0]) * 2 * PART_LEN1);
  }

  // Reset if error is significantly larger than nearend (13 dB).
  if (!aec->extended_filter_enabled && se_total > (19.95f * sd_total)) {
    memset(aec->wfBuf, 0, sizeof(aec->wfBuf));
  }
}
628 
// Windows 2 * PART_LEN time-domain samples of |x| into |x_windowed| for the
// FFT. The first half is multiplied by WebRtcAec_sqrtHanning[0 .. PART_LEN-1];
// the second half by the mirrored window WebRtcAec_sqrtHanning[PART_LEN .. 1].
__inline static void WindowData(float* x_windowed, const float* x) {
  int k;
  for (k = 0; k < PART_LEN; k += 4) {
    const float32x4_t head = vld1q_f32(&x[k]);
    const float32x4_t tail = vld1q_f32(&x[PART_LEN + k]);
    const float32x4_t rising = vld1q_f32(&WebRtcAec_sqrtHanning[k]);
    // Lanes A B C D = sqrtHanning[PART_LEN-k-3 .. PART_LEN-k].
    const float32x4_t window_fwd =
        vld1q_f32(&WebRtcAec_sqrtHanning[PART_LEN - k - 3]);
    // B A D C
    const float32x4_t pairs_swapped = vrev64q_f32(window_fwd);
    // D C B A: lane n now holds sqrtHanning[PART_LEN - k - n].
    const float32x4_t falling = vextq_f32(pairs_swapped, pairs_swapped, 2);

    vst1q_f32(&x_windowed[k], vmulq_f32(head, rising));
    vst1q_f32(&x_windowed[PART_LEN + k], vmulq_f32(tail, falling));
  }
}
649 
// Unpacks the real-FFT output |data| (interleaved re/im pairs, with the real
// parts of bin 0 and bin PART_LEN stored in data[0] and data[1]) into the
// separate real (row 0) and imaginary (row 1) rows of |data_complex|.
// Bins 0 and PART_LEN get a zero imaginary part.
__inline static void StoreAsComplex(const float* data,
                                    float data_complex[2][PART_LEN1]) {
  float* const out_re = data_complex[0];
  float* const out_im = data_complex[1];
  int k;

  for (k = 0; k < PART_LEN; k += 4) {
    // val[0] = real parts, val[1] = imaginary parts of bins k .. k+3.
    const float32x4x2_t deinterleaved = vld2q_f32(data + 2 * k);
    vst1q_f32(out_re + k, deinterleaved.val[0]);
    vst1q_f32(out_im + k, deinterleaved.val[1]);
  }

  // Fix up the two purely real bins.
  out_im[0] = 0;
  out_im[PART_LEN] = 0;
  out_re[0] = data[0];
  out_re[PART_LEN] = data[1];
}
665 
// Computes per-bin coherence between near-end and residual echo (|cohde|) and
// between near-end and delayed far-end (|cohxd|):
//   cohde = |sde|^2 / (sd * se + 1e-10),  cohxd = |sxd|^2 / (sx * sd + 1e-10)
// Before that, it:
//   - refreshes aec->delayIdx via PartitionDelay() when aec->delayEstCtr == 0,
//   - copies the delayed far-end spectrum into |xfw|,
//   - windows and transforms aec->dBuf and aec->eBuf (|fft| is scratch); the
//     error spectrum lands in |efw| and may be replaced by SmoothedPSD(),
//   - updates the smoothed PSDs with SmoothedPSD().
static void SubbandCoherenceNEON(AecCore* aec,
                                 float efw[2][PART_LEN1],
                                 float xfw[2][PART_LEN1],
                                 float* fft,
                                 float* cohde,
                                 float* cohxd) {
  float dfw[2][PART_LEN1];
  const float32x4_t vec_eps = vdupq_n_f32(1e-10f);
  int k;

  if (aec->delayEstCtr == 0) {
    aec->delayIdx = PartitionDelay(aec);
  }

  // Use delayed far.
  memcpy(xfw,
         aec->xfwBuf + aec->delayIdx * PART_LEN1,
         sizeof(xfw[0][0]) * 2 * PART_LEN1);

  // Windowed near-end spectrum.
  WindowData(fft, aec->dBuf);
  aec_rdft_forward_128(fft);
  StoreAsComplex(fft, dfw);

  // Windowed error spectrum.
  WindowData(fft, aec->eBuf);
  aec_rdft_forward_128(fft);
  StoreAsComplex(fft, efw);

  SmoothedPSD(aec, efw, dfw, xfw);

  // Subband coherence, four bins per iteration.
  for (k = 0; k + 3 < PART_LEN1; k += 4) {
    const float32x4_t sd = vld1q_f32(&aec->sd[k]);
    const float32x4_t se = vld1q_f32(&aec->se[k]);
    const float32x4_t sx = vld1q_f32(&aec->sx[k]);
    const float32x4_t denom_de = vmlaq_f32(vec_eps, sd, se);
    const float32x4_t denom_xd = vmlaq_f32(vec_eps, sd, sx);
    // val[0] = real, val[1] = imaginary parts of the cross-PSDs.
    const float32x4x2_t sde = vld2q_f32(&aec->sde[k][0]);
    const float32x4x2_t sxd = vld2q_f32(&aec->sxd[k][0]);
    const float32x4_t sde_pow = vmlaq_f32(
        vmulq_f32(sde.val[0], sde.val[0]), sde.val[1], sde.val[1]);
    const float32x4_t sxd_pow = vmlaq_f32(
        vmulq_f32(sxd.val[0], sxd.val[0]), sxd.val[1], sxd.val[1]);

    vst1q_f32(&cohde[k], vdivq_f32(sde_pow, denom_de));
    vst1q_f32(&cohxd[k], vdivq_f32(sxd_pow, denom_xd));
  }

  // Remaining bins that do not fill a whole vector.
  for (; k < PART_LEN1; k++) {
    cohde[k] =
        (aec->sde[k][0] * aec->sde[k][0] + aec->sde[k][1] * aec->sde[k][1]) /
        (aec->sd[k] * aec->se[k] + 1e-10f);
    cohxd[k] =
        (aec->sxd[k][0] * aec->sxd[k][0] + aec->sxd[k][1] * aec->sxd[k][1]) /
        (aec->sx[k] * aec->sd[k] + 1e-10f);
  }
}
728 
WebRtcAec_InitAec_neon(void)729 void WebRtcAec_InitAec_neon(void) {
730   WebRtcAec_FilterFar = FilterFarNEON;
731   WebRtcAec_ScaleErrorSignal = ScaleErrorSignalNEON;
732   WebRtcAec_FilterAdaptation = FilterAdaptationNEON;
733   WebRtcAec_OverdriveAndSuppress = OverdriveAndSuppressNEON;
734   WebRtcAec_SubbandCoherence = SubbandCoherenceNEON;
735 }
736 
737