1 /*
2  *  Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "modules/audio_processing/aec3/adaptive_fir_filter.h"
12 
13 #if defined(WEBRTC_HAS_NEON)
14 #include <arm_neon.h>
15 #endif
16 #include "typedefs.h"  // NOLINT(build/include)
17 #if defined(WEBRTC_ARCH_X86_FAMILY)
18 #include <emmintrin.h>
19 #endif
20 #include <algorithm>
21 #include <functional>
22 
23 #include "modules/audio_processing/aec3/fft_data.h"
24 #include "rtc_base/checks.h"
25 
26 namespace webrtc {
27 
28 namespace aec3 {
29 
30 // Computes and stores the frequency response of the filter.
UpdateFrequencyResponse(rtc::ArrayView<const FftData> H,std::vector<std::array<float,kFftLengthBy2Plus1>> * H2)31 void UpdateFrequencyResponse(
32     rtc::ArrayView<const FftData> H,
33     std::vector<std::array<float, kFftLengthBy2Plus1>>* H2) {
34   RTC_DCHECK_EQ(H.size(), H2->size());
35   for (size_t k = 0; k < H.size(); ++k) {
36     std::transform(H[k].re.begin(), H[k].re.end(), H[k].im.begin(),
37                    (*H2)[k].begin(),
38                    [](float a, float b) { return a * a + b * b; });
39   }
40 }
41 
42 #if defined(WEBRTC_HAS_NEON)
43 // Computes and stores the frequency response of the filter.
UpdateFrequencyResponse_NEON(rtc::ArrayView<const FftData> H,std::vector<std::array<float,kFftLengthBy2Plus1>> * H2)44 void UpdateFrequencyResponse_NEON(
45     rtc::ArrayView<const FftData> H,
46     std::vector<std::array<float, kFftLengthBy2Plus1>>* H2) {
47   RTC_DCHECK_EQ(H.size(), H2->size());
48   for (size_t k = 0; k < H.size(); ++k) {
49     for (size_t j = 0; j < kFftLengthBy2; j += 4) {
50       const float32x4_t re = vld1q_f32(&H[k].re[j]);
51       const float32x4_t im = vld1q_f32(&H[k].im[j]);
52       float32x4_t H2_k_j = vmulq_f32(re, re);
53       H2_k_j = vmlaq_f32(H2_k_j, im, im);
54       vst1q_f32(&(*H2)[k][j], H2_k_j);
55     }
56     (*H2)[k][kFftLengthBy2] = H[k].re[kFftLengthBy2] * H[k].re[kFftLengthBy2] +
57                               H[k].im[kFftLengthBy2] * H[k].im[kFftLengthBy2];
58   }
59 }
60 #endif
61 
62 #if defined(WEBRTC_ARCH_X86_FAMILY)
63 // Computes and stores the frequency response of the filter.
UpdateFrequencyResponse_SSE2(rtc::ArrayView<const FftData> H,std::vector<std::array<float,kFftLengthBy2Plus1>> * H2)64 void UpdateFrequencyResponse_SSE2(
65     rtc::ArrayView<const FftData> H,
66     std::vector<std::array<float, kFftLengthBy2Plus1>>* H2) {
67   RTC_DCHECK_EQ(H.size(), H2->size());
68   for (size_t k = 0; k < H.size(); ++k) {
69     for (size_t j = 0; j < kFftLengthBy2; j += 4) {
70       const __m128 re = _mm_loadu_ps(&H[k].re[j]);
71       const __m128 re2 = _mm_mul_ps(re, re);
72       const __m128 im = _mm_loadu_ps(&H[k].im[j]);
73       const __m128 im2 = _mm_mul_ps(im, im);
74       const __m128 H2_k_j = _mm_add_ps(re2, im2);
75       _mm_storeu_ps(&(*H2)[k][j], H2_k_j);
76     }
77     (*H2)[k][kFftLengthBy2] = H[k].re[kFftLengthBy2] * H[k].re[kFftLengthBy2] +
78                               H[k].im[kFftLengthBy2] * H[k].im[kFftLengthBy2];
79   }
80 }
81 #endif
82 
83 // Computes and stores the echo return loss estimate of the filter, which is the
84 // sum of the partition frequency responses.
UpdateErlEstimator(const std::vector<std::array<float,kFftLengthBy2Plus1>> & H2,std::array<float,kFftLengthBy2Plus1> * erl)85 void UpdateErlEstimator(
86     const std::vector<std::array<float, kFftLengthBy2Plus1>>& H2,
87     std::array<float, kFftLengthBy2Plus1>* erl) {
88   erl->fill(0.f);
89   for (auto& H2_j : H2) {
90     std::transform(H2_j.begin(), H2_j.end(), erl->begin(), erl->begin(),
91                    std::plus<float>());
92   }
93 }
94 
95 #if defined(WEBRTC_HAS_NEON)
96 // Computes and stores the echo return loss estimate of the filter, which is the
97 // sum of the partition frequency responses.
UpdateErlEstimator_NEON(const std::vector<std::array<float,kFftLengthBy2Plus1>> & H2,std::array<float,kFftLengthBy2Plus1> * erl)98 void UpdateErlEstimator_NEON(
99     const std::vector<std::array<float, kFftLengthBy2Plus1>>& H2,
100     std::array<float, kFftLengthBy2Plus1>* erl) {
101   erl->fill(0.f);
102   for (auto& H2_j : H2) {
103     for (size_t k = 0; k < kFftLengthBy2; k += 4) {
104       const float32x4_t H2_j_k = vld1q_f32(&H2_j[k]);
105       float32x4_t erl_k = vld1q_f32(&(*erl)[k]);
106       erl_k = vaddq_f32(erl_k, H2_j_k);
107       vst1q_f32(&(*erl)[k], erl_k);
108     }
109     (*erl)[kFftLengthBy2] += H2_j[kFftLengthBy2];
110   }
111 }
112 #endif
113 
114 #if defined(WEBRTC_ARCH_X86_FAMILY)
115 // Computes and stores the echo return loss estimate of the filter, which is the
116 // sum of the partition frequency responses.
UpdateErlEstimator_SSE2(const std::vector<std::array<float,kFftLengthBy2Plus1>> & H2,std::array<float,kFftLengthBy2Plus1> * erl)117 void UpdateErlEstimator_SSE2(
118     const std::vector<std::array<float, kFftLengthBy2Plus1>>& H2,
119     std::array<float, kFftLengthBy2Plus1>* erl) {
120   erl->fill(0.f);
121   for (auto& H2_j : H2) {
122     for (size_t k = 0; k < kFftLengthBy2; k += 4) {
123       const __m128 H2_j_k = _mm_loadu_ps(&H2_j[k]);
124       __m128 erl_k = _mm_loadu_ps(&(*erl)[k]);
125       erl_k = _mm_add_ps(erl_k, H2_j_k);
126       _mm_storeu_ps(&(*erl)[k], erl_k);
127     }
128     (*erl)[kFftLengthBy2] += H2_j[kFftLengthBy2];
129   }
130 }
131 #endif
132 
133 // Adapts the filter partitions as H(t+1)=H(t)+G(t)*conj(X(t)).
AdaptPartitions(const RenderBuffer & render_buffer,const FftData & G,rtc::ArrayView<FftData> H)134 void AdaptPartitions(const RenderBuffer& render_buffer,
135                      const FftData& G,
136                      rtc::ArrayView<FftData> H) {
137   rtc::ArrayView<const FftData> render_buffer_data = render_buffer.Buffer();
138   size_t index = render_buffer.Position();
139   for (auto& H_j : H) {
140     const FftData& X = render_buffer_data[index];
141     for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
142       H_j.re[k] += X.re[k] * G.re[k] + X.im[k] * G.im[k];
143       H_j.im[k] += X.re[k] * G.im[k] - X.im[k] * G.re[k];
144     }
145 
146     index = index < (render_buffer_data.size() - 1) ? index + 1 : 0;
147   }
148 }
149 
150 #if defined(WEBRTC_HAS_NEON)
151 // Adapts the filter partitions. (NEON variant)
AdaptPartitions_NEON(const RenderBuffer & render_buffer,const FftData & G,rtc::ArrayView<FftData> H)152 void AdaptPartitions_NEON(const RenderBuffer& render_buffer,
153                           const FftData& G,
154                           rtc::ArrayView<FftData> H) {
155   rtc::ArrayView<const FftData> render_buffer_data = render_buffer.Buffer();
156   const int lim1 =
157       std::min(render_buffer_data.size() - render_buffer.Position(), H.size());
158   const int lim2 = H.size();
159   constexpr int kNumFourBinBands = kFftLengthBy2 / 4;
160   FftData* H_j = &H[0];
161   const FftData* X = &render_buffer_data[render_buffer.Position()];
162   int limit = lim1;
163   int j = 0;
164   do {
165     for (; j < limit; ++j, ++H_j, ++X) {
166       for (int k = 0, n = 0; n < kNumFourBinBands; ++n, k += 4) {
167         const float32x4_t G_re = vld1q_f32(&G.re[k]);
168         const float32x4_t G_im = vld1q_f32(&G.im[k]);
169         const float32x4_t X_re = vld1q_f32(&X->re[k]);
170         const float32x4_t X_im = vld1q_f32(&X->im[k]);
171         const float32x4_t H_re = vld1q_f32(&H_j->re[k]);
172         const float32x4_t H_im = vld1q_f32(&H_j->im[k]);
173         const float32x4_t a = vmulq_f32(X_re, G_re);
174         const float32x4_t e = vmlaq_f32(a, X_im, G_im);
175         const float32x4_t c = vmulq_f32(X_re, G_im);
176         const float32x4_t f = vmlsq_f32(c, X_im, G_re);
177         const float32x4_t g = vaddq_f32(H_re, e);
178         const float32x4_t h = vaddq_f32(H_im, f);
179 
180         vst1q_f32(&H_j->re[k], g);
181         vst1q_f32(&H_j->im[k], h);
182       }
183     }
184 
185     X = &render_buffer_data[0];
186     limit = lim2;
187   } while (j < lim2);
188 
189   H_j = &H[0];
190   X = &render_buffer_data[render_buffer.Position()];
191   limit = lim1;
192   j = 0;
193   do {
194     for (; j < limit; ++j, ++H_j, ++X) {
195       H_j->re[kFftLengthBy2] += X->re[kFftLengthBy2] * G.re[kFftLengthBy2] +
196                                 X->im[kFftLengthBy2] * G.im[kFftLengthBy2];
197       H_j->im[kFftLengthBy2] += X->re[kFftLengthBy2] * G.im[kFftLengthBy2] -
198                                 X->im[kFftLengthBy2] * G.re[kFftLengthBy2];
199     }
200 
201     X = &render_buffer_data[0];
202     limit = lim2;
203   } while (j < lim2);
204 }
205 #endif
206 
207 #if defined(WEBRTC_ARCH_X86_FAMILY)
208 // Adapts the filter partitions. (SSE2 variant)
AdaptPartitions_SSE2(const RenderBuffer & render_buffer,const FftData & G,rtc::ArrayView<FftData> H)209 void AdaptPartitions_SSE2(const RenderBuffer& render_buffer,
210                           const FftData& G,
211                           rtc::ArrayView<FftData> H) {
212   rtc::ArrayView<const FftData> render_buffer_data = render_buffer.Buffer();
213   const int lim1 =
214       std::min(render_buffer_data.size() - render_buffer.Position(), H.size());
215   const int lim2 = H.size();
216   constexpr int kNumFourBinBands = kFftLengthBy2 / 4;
217   FftData* H_j;
218   const FftData* X;
219   int limit;
220   int j;
221   for (int k = 0, n = 0; n < kNumFourBinBands; ++n, k += 4) {
222     const __m128 G_re = _mm_loadu_ps(&G.re[k]);
223     const __m128 G_im = _mm_loadu_ps(&G.im[k]);
224 
225     H_j = &H[0];
226     X = &render_buffer_data[render_buffer.Position()];
227     limit = lim1;
228     j = 0;
229     do {
230       for (; j < limit; ++j, ++H_j, ++X) {
231         const __m128 X_re = _mm_loadu_ps(&X->re[k]);
232         const __m128 X_im = _mm_loadu_ps(&X->im[k]);
233         const __m128 H_re = _mm_loadu_ps(&H_j->re[k]);
234         const __m128 H_im = _mm_loadu_ps(&H_j->im[k]);
235         const __m128 a = _mm_mul_ps(X_re, G_re);
236         const __m128 b = _mm_mul_ps(X_im, G_im);
237         const __m128 c = _mm_mul_ps(X_re, G_im);
238         const __m128 d = _mm_mul_ps(X_im, G_re);
239         const __m128 e = _mm_add_ps(a, b);
240         const __m128 f = _mm_sub_ps(c, d);
241         const __m128 g = _mm_add_ps(H_re, e);
242         const __m128 h = _mm_add_ps(H_im, f);
243         _mm_storeu_ps(&H_j->re[k], g);
244         _mm_storeu_ps(&H_j->im[k], h);
245       }
246 
247       X = &render_buffer_data[0];
248       limit = lim2;
249     } while (j < lim2);
250   }
251 
252   H_j = &H[0];
253   X = &render_buffer_data[render_buffer.Position()];
254   limit = lim1;
255   j = 0;
256   do {
257     for (; j < limit; ++j, ++H_j, ++X) {
258       H_j->re[kFftLengthBy2] += X->re[kFftLengthBy2] * G.re[kFftLengthBy2] +
259                                 X->im[kFftLengthBy2] * G.im[kFftLengthBy2];
260       H_j->im[kFftLengthBy2] += X->re[kFftLengthBy2] * G.im[kFftLengthBy2] -
261                                 X->im[kFftLengthBy2] * G.re[kFftLengthBy2];
262     }
263 
264     X = &render_buffer_data[0];
265     limit = lim2;
266   } while (j < lim2);
267 }
268 #endif
269 
270 // Produces the filter output.
ApplyFilter(const RenderBuffer & render_buffer,rtc::ArrayView<const FftData> H,FftData * S)271 void ApplyFilter(const RenderBuffer& render_buffer,
272                  rtc::ArrayView<const FftData> H,
273                  FftData* S) {
274   S->re.fill(0.f);
275   S->im.fill(0.f);
276 
277   rtc::ArrayView<const FftData> render_buffer_data = render_buffer.Buffer();
278   size_t index = render_buffer.Position();
279   for (auto& H_j : H) {
280     const FftData& X = render_buffer_data[index];
281     for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
282       S->re[k] += X.re[k] * H_j.re[k] - X.im[k] * H_j.im[k];
283       S->im[k] += X.re[k] * H_j.im[k] + X.im[k] * H_j.re[k];
284     }
285     index = index < (render_buffer_data.size() - 1) ? index + 1 : 0;
286   }
287 }
288 
289 #if defined(WEBRTC_HAS_NEON)
290 // Produces the filter output (NEON variant).
ApplyFilter_NEON(const RenderBuffer & render_buffer,rtc::ArrayView<const FftData> H,FftData * S)291 void ApplyFilter_NEON(const RenderBuffer& render_buffer,
292                       rtc::ArrayView<const FftData> H,
293                       FftData* S) {
294   RTC_DCHECK_GE(H.size(), H.size() - 1);
295   S->re.fill(0.f);
296   S->im.fill(0.f);
297 
298   rtc::ArrayView<const FftData> render_buffer_data = render_buffer.Buffer();
299   const int lim1 =
300       std::min(render_buffer_data.size() - render_buffer.Position(), H.size());
301   const int lim2 = H.size();
302   constexpr int kNumFourBinBands = kFftLengthBy2 / 4;
303   const FftData* H_j = &H[0];
304   const FftData* X = &render_buffer_data[render_buffer.Position()];
305 
306   int j = 0;
307   int limit = lim1;
308   do {
309     for (; j < limit; ++j, ++H_j, ++X) {
310       for (int k = 0, n = 0; n < kNumFourBinBands; ++n, k += 4) {
311         const float32x4_t X_re = vld1q_f32(&X->re[k]);
312         const float32x4_t X_im = vld1q_f32(&X->im[k]);
313         const float32x4_t H_re = vld1q_f32(&H_j->re[k]);
314         const float32x4_t H_im = vld1q_f32(&H_j->im[k]);
315         const float32x4_t S_re = vld1q_f32(&S->re[k]);
316         const float32x4_t S_im = vld1q_f32(&S->im[k]);
317         const float32x4_t a = vmulq_f32(X_re, H_re);
318         const float32x4_t e = vmlsq_f32(a, X_im, H_im);
319         const float32x4_t c = vmulq_f32(X_re, H_im);
320         const float32x4_t f = vmlaq_f32(c, X_im, H_re);
321         const float32x4_t g = vaddq_f32(S_re, e);
322         const float32x4_t h = vaddq_f32(S_im, f);
323         vst1q_f32(&S->re[k], g);
324         vst1q_f32(&S->im[k], h);
325       }
326     }
327     limit = lim2;
328     X = &render_buffer_data[0];
329   } while (j < lim2);
330 
331   H_j = &H[0];
332   X = &render_buffer_data[render_buffer.Position()];
333   j = 0;
334   limit = lim1;
335   do {
336     for (; j < limit; ++j, ++H_j, ++X) {
337       S->re[kFftLengthBy2] += X->re[kFftLengthBy2] * H_j->re[kFftLengthBy2] -
338                               X->im[kFftLengthBy2] * H_j->im[kFftLengthBy2];
339       S->im[kFftLengthBy2] += X->re[kFftLengthBy2] * H_j->im[kFftLengthBy2] +
340                               X->im[kFftLengthBy2] * H_j->re[kFftLengthBy2];
341     }
342     limit = lim2;
343     X = &render_buffer_data[0];
344   } while (j < lim2);
345 }
346 #endif
347 
348 #if defined(WEBRTC_ARCH_X86_FAMILY)
349 // Produces the filter output (SSE2 variant).
ApplyFilter_SSE2(const RenderBuffer & render_buffer,rtc::ArrayView<const FftData> H,FftData * S)350 void ApplyFilter_SSE2(const RenderBuffer& render_buffer,
351                       rtc::ArrayView<const FftData> H,
352                       FftData* S) {
353   RTC_DCHECK_GE(H.size(), H.size() - 1);
354   S->re.fill(0.f);
355   S->im.fill(0.f);
356 
357   rtc::ArrayView<const FftData> render_buffer_data = render_buffer.Buffer();
358   const int lim1 =
359       std::min(render_buffer_data.size() - render_buffer.Position(), H.size());
360   const int lim2 = H.size();
361   constexpr int kNumFourBinBands = kFftLengthBy2 / 4;
362   const FftData* H_j = &H[0];
363   const FftData* X = &render_buffer_data[render_buffer.Position()];
364 
365   int j = 0;
366   int limit = lim1;
367   do {
368     for (; j < limit; ++j, ++H_j, ++X) {
369       for (int k = 0, n = 0; n < kNumFourBinBands; ++n, k += 4) {
370         const __m128 X_re = _mm_loadu_ps(&X->re[k]);
371         const __m128 X_im = _mm_loadu_ps(&X->im[k]);
372         const __m128 H_re = _mm_loadu_ps(&H_j->re[k]);
373         const __m128 H_im = _mm_loadu_ps(&H_j->im[k]);
374         const __m128 S_re = _mm_loadu_ps(&S->re[k]);
375         const __m128 S_im = _mm_loadu_ps(&S->im[k]);
376         const __m128 a = _mm_mul_ps(X_re, H_re);
377         const __m128 b = _mm_mul_ps(X_im, H_im);
378         const __m128 c = _mm_mul_ps(X_re, H_im);
379         const __m128 d = _mm_mul_ps(X_im, H_re);
380         const __m128 e = _mm_sub_ps(a, b);
381         const __m128 f = _mm_add_ps(c, d);
382         const __m128 g = _mm_add_ps(S_re, e);
383         const __m128 h = _mm_add_ps(S_im, f);
384         _mm_storeu_ps(&S->re[k], g);
385         _mm_storeu_ps(&S->im[k], h);
386       }
387     }
388     limit = lim2;
389     X = &render_buffer_data[0];
390   } while (j < lim2);
391 
392   H_j = &H[0];
393   X = &render_buffer_data[render_buffer.Position()];
394   j = 0;
395   limit = lim1;
396   do {
397     for (; j < limit; ++j, ++H_j, ++X) {
398       S->re[kFftLengthBy2] += X->re[kFftLengthBy2] * H_j->re[kFftLengthBy2] -
399                               X->im[kFftLengthBy2] * H_j->im[kFftLengthBy2];
400       S->im[kFftLengthBy2] += X->re[kFftLengthBy2] * H_j->im[kFftLengthBy2] +
401                               X->im[kFftLengthBy2] * H_j->re[kFftLengthBy2];
402     }
403     limit = lim2;
404     X = &render_buffer_data[0];
405   } while (j < lim2);
406 }
407 #endif
408 
409 }  // namespace aec3
410 
AdaptiveFirFilter(size_t size_partitions,Aec3Optimization optimization,ApmDataDumper * data_dumper)411 AdaptiveFirFilter::AdaptiveFirFilter(size_t size_partitions,
412                                      Aec3Optimization optimization,
413                                      ApmDataDumper* data_dumper)
414     : data_dumper_(data_dumper),
415       fft_(),
416       optimization_(optimization),
417       H_(size_partitions),
418       H2_(size_partitions, std::array<float, kFftLengthBy2Plus1>()) {
419   RTC_DCHECK(data_dumper_);
420 
421   h_.fill(0.f);
422   for (auto& H_j : H_) {
423     H_j.Clear();
424   }
425   for (auto& H2_k : H2_) {
426     H2_k.fill(0.f);
427   }
428   erl_.fill(0.f);
429 }
430 
431 AdaptiveFirFilter::~AdaptiveFirFilter() = default;
432 
HandleEchoPathChange()433 void AdaptiveFirFilter::HandleEchoPathChange() {
434   h_.fill(0.f);
435   for (auto& H_j : H_) {
436     H_j.Clear();
437   }
438   for (auto& H2_k : H2_) {
439     H2_k.fill(0.f);
440   }
441   erl_.fill(0.f);
442 }
443 
Filter(const RenderBuffer & render_buffer,FftData * S) const444 void AdaptiveFirFilter::Filter(const RenderBuffer& render_buffer,
445                                FftData* S) const {
446   RTC_DCHECK(S);
447   switch (optimization_) {
448 #if defined(WEBRTC_ARCH_X86_FAMILY)
449     case Aec3Optimization::kSse2:
450       aec3::ApplyFilter_SSE2(render_buffer, H_, S);
451       break;
452 #endif
453 #if defined(WEBRTC_HAS_NEON)
454     case Aec3Optimization::kNeon:
455       aec3::ApplyFilter_NEON(render_buffer, H_, S);
456       break;
457 #endif
458     default:
459       aec3::ApplyFilter(render_buffer, H_, S);
460   }
461 }
462 
Adapt(const RenderBuffer & render_buffer,const FftData & G)463 void AdaptiveFirFilter::Adapt(const RenderBuffer& render_buffer,
464                               const FftData& G) {
465   // Adapt the filter.
466   switch (optimization_) {
467 #if defined(WEBRTC_ARCH_X86_FAMILY)
468     case Aec3Optimization::kSse2:
469       aec3::AdaptPartitions_SSE2(render_buffer, G, H_);
470       break;
471 #endif
472 #if defined(WEBRTC_HAS_NEON)
473     case Aec3Optimization::kNeon:
474       aec3::AdaptPartitions_NEON(render_buffer, G, H_);
475       break;
476 #endif
477     default:
478       aec3::AdaptPartitions(render_buffer, G, H_);
479   }
480 
481   // Constrain the filter partitions in a cyclic manner.
482   Constrain();
483 
484   // Update the frequency response and echo return loss for the filter.
485   switch (optimization_) {
486 #if defined(WEBRTC_ARCH_X86_FAMILY)
487     case Aec3Optimization::kSse2:
488       aec3::UpdateFrequencyResponse_SSE2(H_, &H2_);
489       aec3::UpdateErlEstimator_SSE2(H2_, &erl_);
490       break;
491 #endif
492 #if defined(WEBRTC_HAS_NEON)
493     case Aec3Optimization::kNeon:
494       aec3::UpdateFrequencyResponse_NEON(H_, &H2_);
495       aec3::UpdateErlEstimator_NEON(H2_, &erl_);
496       break;
497 #endif
498     default:
499       aec3::UpdateFrequencyResponse(H_, &H2_);
500       aec3::UpdateErlEstimator(H2_, &erl_);
501   }
502 }
503 
504 // Constrains the a partiton of the frequency domain filter to be limited in
505 // time via setting the relevant time-domain coefficients to zero.
Constrain()506 void AdaptiveFirFilter::Constrain() {
507   std::array<float, kFftLength> h;
508   fft_.Ifft(H_[partition_to_constrain_], &h);
509 
510   static constexpr float kScale = 1.0f / kFftLengthBy2;
511   std::for_each(h.begin(), h.begin() + kFftLengthBy2,
512                 [](float& a) { a *= kScale; });
513   std::fill(h.begin() + kFftLengthBy2, h.end(), 0.f);
514 
515   std::copy(h.begin(), h.begin() + kFftLengthBy2,
516             h_.begin() + partition_to_constrain_ * kFftLengthBy2);
517 
518   fft_.Fft(&h, &H_[partition_to_constrain_]);
519 
520   partition_to_constrain_ = partition_to_constrain_ < (H_.size() - 1)
521                                 ? partition_to_constrain_ + 1
522                                 : 0;
523 }
524 
525 }  // namespace webrtc
526