1 /*
2 * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10 #include "modules/audio_processing/aec3/matched_filter.h"
11
12 #if defined(WEBRTC_HAS_NEON)
13 #include <arm_neon.h>
14 #endif
15 #include "typedefs.h" // NOLINT(build/include)
16 #if defined(WEBRTC_ARCH_X86_FAMILY)
17 #include <emmintrin.h>
18 #endif
19 #include <algorithm>
20 #include <numeric>
21
22 #include "modules/audio_processing/include/audio_processing.h"
23 #include "modules/audio_processing/logging/apm_data_dumper.h"
24 #include "rtc_base/logging.h"
25
26 namespace webrtc {
27 namespace aec3 {
28
29 #if defined(WEBRTC_HAS_NEON)
30
MatchedFilterCore_NEON(size_t x_start_index,float x2_sum_threshold,rtc::ArrayView<const float> x,rtc::ArrayView<const float> y,rtc::ArrayView<float> h,bool * filters_updated,float * error_sum)31 void MatchedFilterCore_NEON(size_t x_start_index,
32 float x2_sum_threshold,
33 rtc::ArrayView<const float> x,
34 rtc::ArrayView<const float> y,
35 rtc::ArrayView<float> h,
36 bool* filters_updated,
37 float* error_sum) {
38 const int h_size = static_cast<int>(h.size());
39 const int x_size = static_cast<int>(x.size());
40 RTC_DCHECK_EQ(0, h_size % 4);
41
42 // Process for all samples in the sub-block.
43 for (size_t i = 0; i < y.size(); ++i) {
44 // Apply the matched filter as filter * x, and compute x * x.
45
46 RTC_DCHECK_GT(x_size, x_start_index);
47 const float* x_p = &x[x_start_index];
48 const float* h_p = &h[0];
49
50 // Initialize values for the accumulation.
51 float32x4_t s_128 = vdupq_n_f32(0);
52 float32x4_t x2_sum_128 = vdupq_n_f32(0);
53 float x2_sum = 0.f;
54 float s = 0;
55
56 // Compute loop chunk sizes until, and after, the wraparound of the circular
57 // buffer for x.
58 const int chunk1 =
59 std::min(h_size, static_cast<int>(x_size - x_start_index));
60
61 // Perform the loop in two chunks.
62 const int chunk2 = h_size - chunk1;
63 for (int limit : {chunk1, chunk2}) {
64 // Perform 128 bit vector operations.
65 const int limit_by_4 = limit >> 2;
66 for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
67 // Load the data into 128 bit vectors.
68 const float32x4_t x_k = vld1q_f32(x_p);
69 const float32x4_t h_k = vld1q_f32(h_p);
70 // Compute and accumulate x * x and h * x.
71 x2_sum_128 = vmlaq_f32(x2_sum_128, x_k, x_k);
72 s_128 = vmlaq_f32(s_128, h_k, x_k);
73 }
74
75 // Perform non-vector operations for any remaining items.
76 for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) {
77 const float x_k = *x_p;
78 x2_sum += x_k * x_k;
79 s += *h_p * x_k;
80 }
81
82 x_p = &x[0];
83 }
84
85 // Combine the accumulated vector and scalar values.
86 float* v = reinterpret_cast<float*>(&x2_sum_128);
87 x2_sum += v[0] + v[1] + v[2] + v[3];
88 v = reinterpret_cast<float*>(&s_128);
89 s += v[0] + v[1] + v[2] + v[3];
90
91 // Compute the matched filter error.
92 const float e = std::min(32767.f, std::max(-32768.f, y[i] - s));
93 *error_sum += e * e;
94
95 // Update the matched filter estimate in an NLMS manner.
96 if (x2_sum > x2_sum_threshold) {
97 RTC_DCHECK_LT(0.f, x2_sum);
98 const float alpha = 0.7f * e / x2_sum;
99 const float32x4_t alpha_128 = vmovq_n_f32(alpha);
100
101 // filter = filter + 0.7 * (y - filter * x) / x * x.
102 float* h_p = &h[0];
103 x_p = &x[x_start_index];
104
105 // Perform the loop in two chunks.
106 for (int limit : {chunk1, chunk2}) {
107 // Perform 128 bit vector operations.
108 const int limit_by_4 = limit >> 2;
109 for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
110 // Load the data into 128 bit vectors.
111 float32x4_t h_k = vld1q_f32(h_p);
112 const float32x4_t x_k = vld1q_f32(x_p);
113 // Compute h = h + alpha * x.
114 h_k = vmlaq_f32(h_k, alpha_128, x_k);
115
116 // Store the result.
117 vst1q_f32(h_p, h_k);
118 }
119
120 // Perform non-vector operations for any remaining items.
121 for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) {
122 *h_p += alpha * *x_p;
123 }
124
125 x_p = &x[0];
126 }
127
128 *filters_updated = true;
129 }
130
131 x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
132 }
133 }
134
135 #endif
136
137 #if defined(WEBRTC_ARCH_X86_FAMILY)
138
MatchedFilterCore_SSE2(size_t x_start_index,float x2_sum_threshold,rtc::ArrayView<const float> x,rtc::ArrayView<const float> y,rtc::ArrayView<float> h,bool * filters_updated,float * error_sum)139 void MatchedFilterCore_SSE2(size_t x_start_index,
140 float x2_sum_threshold,
141 rtc::ArrayView<const float> x,
142 rtc::ArrayView<const float> y,
143 rtc::ArrayView<float> h,
144 bool* filters_updated,
145 float* error_sum) {
146 const int h_size = static_cast<int>(h.size());
147 const int x_size = static_cast<int>(x.size());
148 RTC_DCHECK_EQ(0, h_size % 4);
149
150 // Process for all samples in the sub-block.
151 for (size_t i = 0; i < y.size(); ++i) {
152 // Apply the matched filter as filter * x, and compute x * x.
153
154 RTC_DCHECK_GT(x_size, x_start_index);
155 const float* x_p = &x[x_start_index];
156 const float* h_p = &h[0];
157
158 // Initialize values for the accumulation.
159 __m128 s_128 = _mm_set1_ps(0);
160 __m128 x2_sum_128 = _mm_set1_ps(0);
161 float x2_sum = 0.f;
162 float s = 0;
163
164 // Compute loop chunk sizes until, and after, the wraparound of the circular
165 // buffer for x.
166 const int chunk1 =
167 std::min(h_size, static_cast<int>(x_size - x_start_index));
168
169 // Perform the loop in two chunks.
170 const int chunk2 = h_size - chunk1;
171 for (int limit : {chunk1, chunk2}) {
172 // Perform 128 bit vector operations.
173 const int limit_by_4 = limit >> 2;
174 for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
175 // Load the data into 128 bit vectors.
176 const __m128 x_k = _mm_loadu_ps(x_p);
177 const __m128 h_k = _mm_loadu_ps(h_p);
178 const __m128 xx = _mm_mul_ps(x_k, x_k);
179 // Compute and accumulate x * x and h * x.
180 x2_sum_128 = _mm_add_ps(x2_sum_128, xx);
181 const __m128 hx = _mm_mul_ps(h_k, x_k);
182 s_128 = _mm_add_ps(s_128, hx);
183 }
184
185 // Perform non-vector operations for any remaining items.
186 for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) {
187 const float x_k = *x_p;
188 x2_sum += x_k * x_k;
189 s += *h_p * x_k;
190 }
191
192 x_p = &x[0];
193 }
194
195 // Combine the accumulated vector and scalar values.
196 float* v = reinterpret_cast<float*>(&x2_sum_128);
197 x2_sum += v[0] + v[1] + v[2] + v[3];
198 v = reinterpret_cast<float*>(&s_128);
199 s += v[0] + v[1] + v[2] + v[3];
200
201 // Compute the matched filter error.
202 const float e = std::min(32767.f, std::max(-32768.f, y[i] - s));
203 *error_sum += e * e;
204
205 // Update the matched filter estimate in an NLMS manner.
206 if (x2_sum > x2_sum_threshold) {
207 RTC_DCHECK_LT(0.f, x2_sum);
208 const float alpha = 0.7f * e / x2_sum;
209 const __m128 alpha_128 = _mm_set1_ps(alpha);
210
211 // filter = filter + 0.7 * (y - filter * x) / x * x.
212 float* h_p = &h[0];
213 x_p = &x[x_start_index];
214
215 // Perform the loop in two chunks.
216 for (int limit : {chunk1, chunk2}) {
217 // Perform 128 bit vector operations.
218 const int limit_by_4 = limit >> 2;
219 for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
220 // Load the data into 128 bit vectors.
221 __m128 h_k = _mm_loadu_ps(h_p);
222 const __m128 x_k = _mm_loadu_ps(x_p);
223
224 // Compute h = h + alpha * x.
225 const __m128 alpha_x = _mm_mul_ps(alpha_128, x_k);
226 h_k = _mm_add_ps(h_k, alpha_x);
227
228 // Store the result.
229 _mm_storeu_ps(h_p, h_k);
230 }
231
232 // Perform non-vector operations for any remaining items.
233 for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) {
234 *h_p += alpha * *x_p;
235 }
236
237 x_p = &x[0];
238 }
239
240 *filters_updated = true;
241 }
242
243 x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
244 }
245 }
246 #endif
247
MatchedFilterCore(size_t x_start_index,float x2_sum_threshold,rtc::ArrayView<const float> x,rtc::ArrayView<const float> y,rtc::ArrayView<float> h,bool * filters_updated,float * error_sum)248 void MatchedFilterCore(size_t x_start_index,
249 float x2_sum_threshold,
250 rtc::ArrayView<const float> x,
251 rtc::ArrayView<const float> y,
252 rtc::ArrayView<float> h,
253 bool* filters_updated,
254 float* error_sum) {
255 // Process for all samples in the sub-block.
256 for (size_t i = 0; i < y.size(); ++i) {
257 // Apply the matched filter as filter * x, and compute x * x.
258 float x2_sum = 0.f;
259 float s = 0;
260 size_t x_index = x_start_index;
261 for (size_t k = 0; k < h.size(); ++k) {
262 x2_sum += x[x_index] * x[x_index];
263 s += h[k] * x[x_index];
264 x_index = x_index < (x.size() - 1) ? x_index + 1 : 0;
265 }
266
267 // Compute the matched filter error.
268 const float e = std::min(32767.f, std::max(-32768.f, y[i] - s));
269 (*error_sum) += e * e;
270
271 // Update the matched filter estimate in an NLMS manner.
272 if (x2_sum > x2_sum_threshold) {
273 RTC_DCHECK_LT(0.f, x2_sum);
274 const float alpha = 0.7f * e / x2_sum;
275
276 // filter = filter + 0.7 * (y - filter * x) / x * x.
277 size_t x_index = x_start_index;
278 for (size_t k = 0; k < h.size(); ++k) {
279 h[k] += alpha * x[x_index];
280 x_index = x_index < (x.size() - 1) ? x_index + 1 : 0;
281 }
282 *filters_updated = true;
283 }
284
285 x_start_index = x_start_index > 0 ? x_start_index - 1 : x.size() - 1;
286 }
287 }
288
289 } // namespace aec3
290
MatchedFilter(ApmDataDumper * data_dumper,Aec3Optimization optimization,size_t sub_block_size,size_t window_size_sub_blocks,int num_matched_filters,size_t alignment_shift_sub_blocks,float excitation_limit)291 MatchedFilter::MatchedFilter(ApmDataDumper* data_dumper,
292 Aec3Optimization optimization,
293 size_t sub_block_size,
294 size_t window_size_sub_blocks,
295 int num_matched_filters,
296 size_t alignment_shift_sub_blocks,
297 float excitation_limit)
298 : data_dumper_(data_dumper),
299 optimization_(optimization),
300 sub_block_size_(sub_block_size),
301 filter_intra_lag_shift_(alignment_shift_sub_blocks * sub_block_size_),
302 filters_(
303 num_matched_filters,
304 std::vector<float>(window_size_sub_blocks * sub_block_size_, 0.f)),
305 lag_estimates_(num_matched_filters),
306 filters_offsets_(num_matched_filters, 0),
307 excitation_limit_(excitation_limit) {
308 RTC_DCHECK(data_dumper);
309 RTC_DCHECK_LT(0, window_size_sub_blocks);
310 RTC_DCHECK((kBlockSize % sub_block_size) == 0);
311 RTC_DCHECK((sub_block_size % 4) == 0);
312 }
313
314 MatchedFilter::~MatchedFilter() = default;
315
Reset()316 void MatchedFilter::Reset() {
317 for (auto& f : filters_) {
318 std::fill(f.begin(), f.end(), 0.f);
319 }
320
321 for (auto& l : lag_estimates_) {
322 l = MatchedFilter::LagEstimate();
323 }
324 }
325
Update(const DownsampledRenderBuffer & render_buffer,rtc::ArrayView<const float> capture)326 void MatchedFilter::Update(const DownsampledRenderBuffer& render_buffer,
327 rtc::ArrayView<const float> capture) {
328 RTC_DCHECK_EQ(sub_block_size_, capture.size());
329 auto& y = capture;
330
331 const float x2_sum_threshold =
332 filters_[0].size() * excitation_limit_ * excitation_limit_;
333
334 // Apply all matched filters.
335 size_t alignment_shift = 0;
336 for (size_t n = 0; n < filters_.size(); ++n) {
337 float error_sum = 0.f;
338 bool filters_updated = false;
339
340 size_t x_start_index =
341 (render_buffer.position + alignment_shift + sub_block_size_ - 1) %
342 render_buffer.buffer.size();
343
344 switch (optimization_) {
345 #if defined(WEBRTC_ARCH_X86_FAMILY)
346 case Aec3Optimization::kSse2:
347 aec3::MatchedFilterCore_SSE2(x_start_index, x2_sum_threshold,
348 render_buffer.buffer, y, filters_[n],
349 &filters_updated, &error_sum);
350 break;
351 #endif
352 #if defined(WEBRTC_HAS_NEON)
353 case Aec3Optimization::kNeon:
354 aec3::MatchedFilterCore_NEON(x_start_index, x2_sum_threshold,
355 render_buffer.buffer, y, filters_[n],
356 &filters_updated, &error_sum);
357 break;
358 #endif
359 default:
360 aec3::MatchedFilterCore(x_start_index, x2_sum_threshold,
361 render_buffer.buffer, y, filters_[n],
362 &filters_updated, &error_sum);
363 }
364
365 // Compute anchor for the matched filter error.
366 const float error_sum_anchor =
367 std::inner_product(y.begin(), y.end(), y.begin(), 0.f);
368
369 // Estimate the lag in the matched filter as the distance to the portion in
370 // the filter that contributes the most to the matched filter output. This
371 // is detected as the peak of the matched filter.
372 const size_t lag_estimate = std::distance(
373 filters_[n].begin(),
374 std::max_element(
375 filters_[n].begin(), filters_[n].end(),
376 [](float a, float b) -> bool { return a * a < b * b; }));
377
378 // Update the lag estimates for the matched filter.
379 const float kMatchingFilterThreshold = 0.2f;
380 lag_estimates_[n] = LagEstimate(
381 error_sum_anchor - error_sum,
382 (lag_estimate > 2 && lag_estimate < (filters_[n].size() - 10) &&
383 error_sum < kMatchingFilterThreshold * error_sum_anchor),
384 lag_estimate + alignment_shift, filters_updated);
385
386 RTC_DCHECK_GE(10, filters_.size());
387 switch (n) {
388 case 0:
389 data_dumper_->DumpRaw("aec3_correlator_0_h", filters_[0]);
390 break;
391 case 1:
392 data_dumper_->DumpRaw("aec3_correlator_1_h", filters_[1]);
393 break;
394 case 2:
395 data_dumper_->DumpRaw("aec3_correlator_2_h", filters_[2]);
396 break;
397 case 3:
398 data_dumper_->DumpRaw("aec3_correlator_3_h", filters_[3]);
399 break;
400 case 4:
401 data_dumper_->DumpRaw("aec3_correlator_4_h", filters_[4]);
402 break;
403 case 5:
404 data_dumper_->DumpRaw("aec3_correlator_5_h", filters_[5]);
405 break;
406 case 6:
407 data_dumper_->DumpRaw("aec3_correlator_6_h", filters_[6]);
408 break;
409 case 7:
410 data_dumper_->DumpRaw("aec3_correlator_7_h", filters_[7]);
411 break;
412 case 8:
413 data_dumper_->DumpRaw("aec3_correlator_8_h", filters_[8]);
414 break;
415 case 9:
416 data_dumper_->DumpRaw("aec3_correlator_9_h", filters_[9]);
417 break;
418 default:
419 RTC_NOTREACHED();
420 }
421
422 alignment_shift += filter_intra_lag_shift_;
423 }
424 }
425
LogFilterProperties(int sample_rate_hz,size_t shift,size_t downsampling_factor) const426 void MatchedFilter::LogFilterProperties(int sample_rate_hz,
427 size_t shift,
428 size_t downsampling_factor) const {
429 size_t alignment_shift = 0;
430 const int fs_by_1000 = LowestBandRate(sample_rate_hz) / 1000;
431 for (size_t k = 0; k < filters_.size(); ++k) {
432 int start = static_cast<int>(alignment_shift * downsampling_factor);
433 int end = static_cast<int>((alignment_shift + filters_[k].size()) *
434 downsampling_factor);
435 RTC_LOG(LS_INFO) << "Filter " << k << ": start: "
436 << (start - static_cast<int>(shift)) / fs_by_1000
437 << " ms, end: "
438 << (end - static_cast<int>(shift)) / fs_by_1000 << " ms.";
439 alignment_shift += filter_intra_lag_shift_;
440 }
441 }
442
443 } // namespace webrtc
444