1 /*
2  *  Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 #include "modules/audio_processing/aec3/matched_filter.h"
11 
12 #if defined(WEBRTC_HAS_NEON)
13 #include <arm_neon.h>
14 #endif
15 #include "typedefs.h"  // NOLINT(build/include)
16 #if defined(WEBRTC_ARCH_X86_FAMILY)
17 #include <emmintrin.h>
18 #endif
19 #include <algorithm>
20 #include <numeric>
21 
22 #include "modules/audio_processing/include/audio_processing.h"
23 #include "modules/audio_processing/logging/apm_data_dumper.h"
24 #include "rtc_base/logging.h"
25 
26 namespace webrtc {
27 namespace aec3 {
28 
29 #if defined(WEBRTC_HAS_NEON)
30 
MatchedFilterCore_NEON(size_t x_start_index,float x2_sum_threshold,rtc::ArrayView<const float> x,rtc::ArrayView<const float> y,rtc::ArrayView<float> h,bool * filters_updated,float * error_sum)31 void MatchedFilterCore_NEON(size_t x_start_index,
32                             float x2_sum_threshold,
33                             rtc::ArrayView<const float> x,
34                             rtc::ArrayView<const float> y,
35                             rtc::ArrayView<float> h,
36                             bool* filters_updated,
37                             float* error_sum) {
38   const int h_size = static_cast<int>(h.size());
39   const int x_size = static_cast<int>(x.size());
40   RTC_DCHECK_EQ(0, h_size % 4);
41 
42   // Process for all samples in the sub-block.
43   for (size_t i = 0; i < y.size(); ++i) {
44     // Apply the matched filter as filter * x, and compute x * x.
45 
46     RTC_DCHECK_GT(x_size, x_start_index);
47     const float* x_p = &x[x_start_index];
48     const float* h_p = &h[0];
49 
50     // Initialize values for the accumulation.
51     float32x4_t s_128 = vdupq_n_f32(0);
52     float32x4_t x2_sum_128 = vdupq_n_f32(0);
53     float x2_sum = 0.f;
54     float s = 0;
55 
56     // Compute loop chunk sizes until, and after, the wraparound of the circular
57     // buffer for x.
58     const int chunk1 =
59         std::min(h_size, static_cast<int>(x_size - x_start_index));
60 
61     // Perform the loop in two chunks.
62     const int chunk2 = h_size - chunk1;
63     for (int limit : {chunk1, chunk2}) {
64       // Perform 128 bit vector operations.
65       const int limit_by_4 = limit >> 2;
66       for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
67         // Load the data into 128 bit vectors.
68         const float32x4_t x_k = vld1q_f32(x_p);
69         const float32x4_t h_k = vld1q_f32(h_p);
70         // Compute and accumulate x * x and h * x.
71         x2_sum_128 = vmlaq_f32(x2_sum_128, x_k, x_k);
72         s_128 = vmlaq_f32(s_128, h_k, x_k);
73       }
74 
75       // Perform non-vector operations for any remaining items.
76       for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) {
77         const float x_k = *x_p;
78         x2_sum += x_k * x_k;
79         s += *h_p * x_k;
80       }
81 
82       x_p = &x[0];
83     }
84 
85     // Combine the accumulated vector and scalar values.
86     float* v = reinterpret_cast<float*>(&x2_sum_128);
87     x2_sum += v[0] + v[1] + v[2] + v[3];
88     v = reinterpret_cast<float*>(&s_128);
89     s += v[0] + v[1] + v[2] + v[3];
90 
91     // Compute the matched filter error.
92     const float e = std::min(32767.f, std::max(-32768.f, y[i] - s));
93     *error_sum += e * e;
94 
95     // Update the matched filter estimate in an NLMS manner.
96     if (x2_sum > x2_sum_threshold) {
97       RTC_DCHECK_LT(0.f, x2_sum);
98       const float alpha = 0.7f * e / x2_sum;
99       const float32x4_t alpha_128 = vmovq_n_f32(alpha);
100 
101       // filter = filter + 0.7 * (y - filter * x) / x * x.
102       float* h_p = &h[0];
103       x_p = &x[x_start_index];
104 
105       // Perform the loop in two chunks.
106       for (int limit : {chunk1, chunk2}) {
107         // Perform 128 bit vector operations.
108         const int limit_by_4 = limit >> 2;
109         for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
110           // Load the data into 128 bit vectors.
111           float32x4_t h_k = vld1q_f32(h_p);
112           const float32x4_t x_k = vld1q_f32(x_p);
113           // Compute h = h + alpha * x.
114           h_k = vmlaq_f32(h_k, alpha_128, x_k);
115 
116           // Store the result.
117           vst1q_f32(h_p, h_k);
118         }
119 
120         // Perform non-vector operations for any remaining items.
121         for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) {
122           *h_p += alpha * *x_p;
123         }
124 
125         x_p = &x[0];
126       }
127 
128       *filters_updated = true;
129     }
130 
131     x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
132   }
133 }
134 
135 #endif
136 
137 #if defined(WEBRTC_ARCH_X86_FAMILY)
138 
MatchedFilterCore_SSE2(size_t x_start_index,float x2_sum_threshold,rtc::ArrayView<const float> x,rtc::ArrayView<const float> y,rtc::ArrayView<float> h,bool * filters_updated,float * error_sum)139 void MatchedFilterCore_SSE2(size_t x_start_index,
140                             float x2_sum_threshold,
141                             rtc::ArrayView<const float> x,
142                             rtc::ArrayView<const float> y,
143                             rtc::ArrayView<float> h,
144                             bool* filters_updated,
145                             float* error_sum) {
146   const int h_size = static_cast<int>(h.size());
147   const int x_size = static_cast<int>(x.size());
148   RTC_DCHECK_EQ(0, h_size % 4);
149 
150   // Process for all samples in the sub-block.
151   for (size_t i = 0; i < y.size(); ++i) {
152     // Apply the matched filter as filter * x, and compute x * x.
153 
154     RTC_DCHECK_GT(x_size, x_start_index);
155     const float* x_p = &x[x_start_index];
156     const float* h_p = &h[0];
157 
158     // Initialize values for the accumulation.
159     __m128 s_128 = _mm_set1_ps(0);
160     __m128 x2_sum_128 = _mm_set1_ps(0);
161     float x2_sum = 0.f;
162     float s = 0;
163 
164     // Compute loop chunk sizes until, and after, the wraparound of the circular
165     // buffer for x.
166     const int chunk1 =
167         std::min(h_size, static_cast<int>(x_size - x_start_index));
168 
169     // Perform the loop in two chunks.
170     const int chunk2 = h_size - chunk1;
171     for (int limit : {chunk1, chunk2}) {
172       // Perform 128 bit vector operations.
173       const int limit_by_4 = limit >> 2;
174       for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
175         // Load the data into 128 bit vectors.
176         const __m128 x_k = _mm_loadu_ps(x_p);
177         const __m128 h_k = _mm_loadu_ps(h_p);
178         const __m128 xx = _mm_mul_ps(x_k, x_k);
179         // Compute and accumulate x * x and h * x.
180         x2_sum_128 = _mm_add_ps(x2_sum_128, xx);
181         const __m128 hx = _mm_mul_ps(h_k, x_k);
182         s_128 = _mm_add_ps(s_128, hx);
183       }
184 
185       // Perform non-vector operations for any remaining items.
186       for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) {
187         const float x_k = *x_p;
188         x2_sum += x_k * x_k;
189         s += *h_p * x_k;
190       }
191 
192       x_p = &x[0];
193     }
194 
195     // Combine the accumulated vector and scalar values.
196     float* v = reinterpret_cast<float*>(&x2_sum_128);
197     x2_sum += v[0] + v[1] + v[2] + v[3];
198     v = reinterpret_cast<float*>(&s_128);
199     s += v[0] + v[1] + v[2] + v[3];
200 
201     // Compute the matched filter error.
202     const float e = std::min(32767.f, std::max(-32768.f, y[i] - s));
203     *error_sum += e * e;
204 
205     // Update the matched filter estimate in an NLMS manner.
206     if (x2_sum > x2_sum_threshold) {
207       RTC_DCHECK_LT(0.f, x2_sum);
208       const float alpha = 0.7f * e / x2_sum;
209       const __m128 alpha_128 = _mm_set1_ps(alpha);
210 
211       // filter = filter + 0.7 * (y - filter * x) / x * x.
212       float* h_p = &h[0];
213       x_p = &x[x_start_index];
214 
215       // Perform the loop in two chunks.
216       for (int limit : {chunk1, chunk2}) {
217         // Perform 128 bit vector operations.
218         const int limit_by_4 = limit >> 2;
219         for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
220           // Load the data into 128 bit vectors.
221           __m128 h_k = _mm_loadu_ps(h_p);
222           const __m128 x_k = _mm_loadu_ps(x_p);
223 
224           // Compute h = h + alpha * x.
225           const __m128 alpha_x = _mm_mul_ps(alpha_128, x_k);
226           h_k = _mm_add_ps(h_k, alpha_x);
227 
228           // Store the result.
229           _mm_storeu_ps(h_p, h_k);
230         }
231 
232         // Perform non-vector operations for any remaining items.
233         for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) {
234           *h_p += alpha * *x_p;
235         }
236 
237         x_p = &x[0];
238       }
239 
240       *filters_updated = true;
241     }
242 
243     x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
244   }
245 }
246 #endif
247 
MatchedFilterCore(size_t x_start_index,float x2_sum_threshold,rtc::ArrayView<const float> x,rtc::ArrayView<const float> y,rtc::ArrayView<float> h,bool * filters_updated,float * error_sum)248 void MatchedFilterCore(size_t x_start_index,
249                        float x2_sum_threshold,
250                        rtc::ArrayView<const float> x,
251                        rtc::ArrayView<const float> y,
252                        rtc::ArrayView<float> h,
253                        bool* filters_updated,
254                        float* error_sum) {
255   // Process for all samples in the sub-block.
256   for (size_t i = 0; i < y.size(); ++i) {
257     // Apply the matched filter as filter * x, and compute x * x.
258     float x2_sum = 0.f;
259     float s = 0;
260     size_t x_index = x_start_index;
261     for (size_t k = 0; k < h.size(); ++k) {
262       x2_sum += x[x_index] * x[x_index];
263       s += h[k] * x[x_index];
264       x_index = x_index < (x.size() - 1) ? x_index + 1 : 0;
265     }
266 
267     // Compute the matched filter error.
268     const float e = std::min(32767.f, std::max(-32768.f, y[i] - s));
269     (*error_sum) += e * e;
270 
271     // Update the matched filter estimate in an NLMS manner.
272     if (x2_sum > x2_sum_threshold) {
273       RTC_DCHECK_LT(0.f, x2_sum);
274       const float alpha = 0.7f * e / x2_sum;
275 
276       // filter = filter + 0.7 * (y - filter * x) / x * x.
277       size_t x_index = x_start_index;
278       for (size_t k = 0; k < h.size(); ++k) {
279         h[k] += alpha * x[x_index];
280         x_index = x_index < (x.size() - 1) ? x_index + 1 : 0;
281       }
282       *filters_updated = true;
283     }
284 
285     x_start_index = x_start_index > 0 ? x_start_index - 1 : x.size() - 1;
286   }
287 }
288 
289 }  // namespace aec3
290 
MatchedFilter(ApmDataDumper * data_dumper,Aec3Optimization optimization,size_t sub_block_size,size_t window_size_sub_blocks,int num_matched_filters,size_t alignment_shift_sub_blocks,float excitation_limit)291 MatchedFilter::MatchedFilter(ApmDataDumper* data_dumper,
292                              Aec3Optimization optimization,
293                              size_t sub_block_size,
294                              size_t window_size_sub_blocks,
295                              int num_matched_filters,
296                              size_t alignment_shift_sub_blocks,
297                              float excitation_limit)
298     : data_dumper_(data_dumper),
299       optimization_(optimization),
300       sub_block_size_(sub_block_size),
301       filter_intra_lag_shift_(alignment_shift_sub_blocks * sub_block_size_),
302       filters_(
303           num_matched_filters,
304           std::vector<float>(window_size_sub_blocks * sub_block_size_, 0.f)),
305       lag_estimates_(num_matched_filters),
306       filters_offsets_(num_matched_filters, 0),
307       excitation_limit_(excitation_limit) {
308   RTC_DCHECK(data_dumper);
309   RTC_DCHECK_LT(0, window_size_sub_blocks);
310   RTC_DCHECK((kBlockSize % sub_block_size) == 0);
311   RTC_DCHECK((sub_block_size % 4) == 0);
312 }
313 
314 MatchedFilter::~MatchedFilter() = default;
315 
Reset()316 void MatchedFilter::Reset() {
317   for (auto& f : filters_) {
318     std::fill(f.begin(), f.end(), 0.f);
319   }
320 
321   for (auto& l : lag_estimates_) {
322     l = MatchedFilter::LagEstimate();
323   }
324 }
325 
Update(const DownsampledRenderBuffer & render_buffer,rtc::ArrayView<const float> capture)326 void MatchedFilter::Update(const DownsampledRenderBuffer& render_buffer,
327                            rtc::ArrayView<const float> capture) {
328   RTC_DCHECK_EQ(sub_block_size_, capture.size());
329   auto& y = capture;
330 
331   const float x2_sum_threshold =
332       filters_[0].size() * excitation_limit_ * excitation_limit_;
333 
334   // Apply all matched filters.
335   size_t alignment_shift = 0;
336   for (size_t n = 0; n < filters_.size(); ++n) {
337     float error_sum = 0.f;
338     bool filters_updated = false;
339 
340     size_t x_start_index =
341         (render_buffer.position + alignment_shift + sub_block_size_ - 1) %
342         render_buffer.buffer.size();
343 
344     switch (optimization_) {
345 #if defined(WEBRTC_ARCH_X86_FAMILY)
346       case Aec3Optimization::kSse2:
347         aec3::MatchedFilterCore_SSE2(x_start_index, x2_sum_threshold,
348                                      render_buffer.buffer, y, filters_[n],
349                                      &filters_updated, &error_sum);
350         break;
351 #endif
352 #if defined(WEBRTC_HAS_NEON)
353       case Aec3Optimization::kNeon:
354         aec3::MatchedFilterCore_NEON(x_start_index, x2_sum_threshold,
355                                      render_buffer.buffer, y, filters_[n],
356                                      &filters_updated, &error_sum);
357         break;
358 #endif
359       default:
360         aec3::MatchedFilterCore(x_start_index, x2_sum_threshold,
361                                 render_buffer.buffer, y, filters_[n],
362                                 &filters_updated, &error_sum);
363     }
364 
365     // Compute anchor for the matched filter error.
366     const float error_sum_anchor =
367         std::inner_product(y.begin(), y.end(), y.begin(), 0.f);
368 
369     // Estimate the lag in the matched filter as the distance to the portion in
370     // the filter that contributes the most to the matched filter output. This
371     // is detected as the peak of the matched filter.
372     const size_t lag_estimate = std::distance(
373         filters_[n].begin(),
374         std::max_element(
375             filters_[n].begin(), filters_[n].end(),
376             [](float a, float b) -> bool { return a * a < b * b; }));
377 
378     // Update the lag estimates for the matched filter.
379     const float kMatchingFilterThreshold = 0.2f;
380     lag_estimates_[n] = LagEstimate(
381         error_sum_anchor - error_sum,
382         (lag_estimate > 2 && lag_estimate < (filters_[n].size() - 10) &&
383          error_sum < kMatchingFilterThreshold * error_sum_anchor),
384         lag_estimate + alignment_shift, filters_updated);
385 
386     RTC_DCHECK_GE(10, filters_.size());
387     switch (n) {
388       case 0:
389         data_dumper_->DumpRaw("aec3_correlator_0_h", filters_[0]);
390         break;
391       case 1:
392         data_dumper_->DumpRaw("aec3_correlator_1_h", filters_[1]);
393         break;
394       case 2:
395         data_dumper_->DumpRaw("aec3_correlator_2_h", filters_[2]);
396         break;
397       case 3:
398         data_dumper_->DumpRaw("aec3_correlator_3_h", filters_[3]);
399         break;
400       case 4:
401         data_dumper_->DumpRaw("aec3_correlator_4_h", filters_[4]);
402         break;
403       case 5:
404         data_dumper_->DumpRaw("aec3_correlator_5_h", filters_[5]);
405         break;
406       case 6:
407         data_dumper_->DumpRaw("aec3_correlator_6_h", filters_[6]);
408         break;
409       case 7:
410         data_dumper_->DumpRaw("aec3_correlator_7_h", filters_[7]);
411         break;
412       case 8:
413         data_dumper_->DumpRaw("aec3_correlator_8_h", filters_[8]);
414         break;
415       case 9:
416         data_dumper_->DumpRaw("aec3_correlator_9_h", filters_[9]);
417         break;
418       default:
419         RTC_NOTREACHED();
420     }
421 
422     alignment_shift += filter_intra_lag_shift_;
423   }
424 }
425 
LogFilterProperties(int sample_rate_hz,size_t shift,size_t downsampling_factor) const426 void MatchedFilter::LogFilterProperties(int sample_rate_hz,
427                                         size_t shift,
428                                         size_t downsampling_factor) const {
429   size_t alignment_shift = 0;
430   const int fs_by_1000 = LowestBandRate(sample_rate_hz) / 1000;
431   for (size_t k = 0; k < filters_.size(); ++k) {
432     int start = static_cast<int>(alignment_shift * downsampling_factor);
433     int end = static_cast<int>((alignment_shift + filters_[k].size()) *
434                                downsampling_factor);
435     RTC_LOG(LS_INFO) << "Filter " << k << ": start: "
436                      << (start - static_cast<int>(shift)) / fs_by_1000
437                      << " ms, end: "
438                      << (end - static_cast<int>(shift)) / fs_by_1000 << " ms.";
439     alignment_shift += filter_intra_lag_shift_;
440   }
441 }
442 
443 }  // namespace webrtc
444