1 /*
2 * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11 #include "modules/audio_processing/aec3/matched_filter.h"
12
13 #include "typedefs.h" // NOLINT(build/include)
14 #if defined(WEBRTC_ARCH_X86_FAMILY)
15 #include <emmintrin.h>
16 #endif
17 #include <algorithm>
18 #include <sstream>
19 #include <string>
20
21 #include "modules/audio_processing/aec3/aec3_common.h"
22 #include "modules/audio_processing/aec3/decimator.h"
23 #include "modules/audio_processing/aec3/render_delay_buffer.h"
24 #include "modules/audio_processing/logging/apm_data_dumper.h"
25 #include "modules/audio_processing/test/echo_canceller_test_tools.h"
26 #include "rtc_base/random.h"
27 #include "system_wrappers/include/cpu_features_wrapper.h"
28 #include "test/gtest.h"
29
30 namespace webrtc {
31 namespace aec3 {
32 namespace {
33
// Builds a human-readable description of the current test parameters. The
// result is used as the SCOPED_TRACE message so that a failing expectation
// can be attributed to a specific delay / down-sampling combination.
//
// delay: the delay value to report.
// down_sampling_factor: the down-sampling factor to report.
// Returns "Delay: <delay>, Down sampling factor: <down_sampling_factor>".
std::string ProduceDebugText(size_t delay, size_t down_sampling_factor) {
  std::ostringstream message;
  message << "Delay: " << delay
          << ", Down sampling factor: " << down_sampling_factor;
  return message.str();
}
40
// Down-sampling factors that the tests in this file iterate over.
constexpr size_t kDownSamplingFactors[] = {2, 4, 8};

// Number of matched filters requested from MatchedFilter and used when sizing
// the render delay buffer in the tests below.
constexpr size_t kNumMatchedFilters = 10;

// Window size, in sub-blocks, passed to the MatchedFilter constructor.
constexpr size_t kWindowSizeSubBlocks = 32;

// Alignment shift, in sub-blocks, passed to the MatchedFilter constructor:
// three quarters of the window size.
constexpr size_t kAlignmentShiftSubBlocks = (kWindowSizeSubBlocks * 3) / 4;
45
46 } // namespace
47
48 #if defined(WEBRTC_HAS_NEON)
49 // Verifies that the optimized methods for NEON are similar to their reference
50 // counterparts.
TEST(MatchedFilter,TestNeonOptimizations)51 TEST(MatchedFilter, TestNeonOptimizations) {
52 Random random_generator(42U);
53 for (auto down_sampling_factor : kDownSamplingFactors) {
54 const size_t sub_block_size = kBlockSize / down_sampling_factor;
55
56 std::vector<float> x(2000);
57 RandomizeSampleVector(&random_generator, x);
58 std::vector<float> y(sub_block_size);
59 std::vector<float> h_NEON(512);
60 std::vector<float> h(512);
61 int x_index = 0;
62 for (int k = 0; k < 1000; ++k) {
63 RandomizeSampleVector(&random_generator, y);
64
65 bool filters_updated = false;
66 float error_sum = 0.f;
67 bool filters_updated_NEON = false;
68 float error_sum_NEON = 0.f;
69
70 MatchedFilterCore_NEON(x_index, h.size() * 150.f * 150.f, x, y, h_NEON,
71 &filters_updated_NEON, &error_sum_NEON);
72
73 MatchedFilterCore(x_index, h.size() * 150.f * 150.f, x, y, h,
74 &filters_updated, &error_sum);
75
76 EXPECT_EQ(filters_updated, filters_updated_NEON);
77 EXPECT_NEAR(error_sum, error_sum_NEON, error_sum / 100000.f);
78
79 for (size_t j = 0; j < h.size(); ++j) {
80 EXPECT_NEAR(h[j], h_NEON[j], 0.00001f);
81 }
82
83 x_index = (x_index + sub_block_size) % x.size();
84 }
85 }
86 }
87 #endif
88
89 #if defined(WEBRTC_ARCH_X86_FAMILY)
90 // Verifies that the optimized methods for SSE2 are bitexact to their reference
91 // counterparts.
TEST(MatchedFilter,TestSse2Optimizations)92 TEST(MatchedFilter, TestSse2Optimizations) {
93 bool use_sse2 = (WebRtc_GetCPUInfo(kSSE2) != 0);
94 if (use_sse2) {
95 Random random_generator(42U);
96 for (auto down_sampling_factor : kDownSamplingFactors) {
97 const size_t sub_block_size = kBlockSize / down_sampling_factor;
98 std::vector<float> x(2000);
99 RandomizeSampleVector(&random_generator, x);
100 std::vector<float> y(sub_block_size);
101 std::vector<float> h_SSE2(512);
102 std::vector<float> h(512);
103 int x_index = 0;
104 for (int k = 0; k < 1000; ++k) {
105 RandomizeSampleVector(&random_generator, y);
106
107 bool filters_updated = false;
108 float error_sum = 0.f;
109 bool filters_updated_SSE2 = false;
110 float error_sum_SSE2 = 0.f;
111
112 MatchedFilterCore_SSE2(x_index, h.size() * 150.f * 150.f, x, y, h_SSE2,
113 &filters_updated_SSE2, &error_sum_SSE2);
114
115 MatchedFilterCore(x_index, h.size() * 150.f * 150.f, x, y, h,
116 &filters_updated, &error_sum);
117
118 EXPECT_EQ(filters_updated, filters_updated_SSE2);
119 EXPECT_NEAR(error_sum, error_sum_SSE2, error_sum / 100000.f);
120
121 for (size_t j = 0; j < h.size(); ++j) {
122 EXPECT_NEAR(h[j], h_SSE2[j], 0.00001f);
123 }
124
125 x_index = (x_index + sub_block_size) % x.size();
126 }
127 }
128 }
129 }
130
131 #endif
132
133 // Verifies that the matched filter produces proper lag estimates for
134 // artificially
135 // delayed signals.
TEST(MatchedFilter,LagEstimation)136 TEST(MatchedFilter, LagEstimation) {
137 Random random_generator(42U);
138 for (auto down_sampling_factor : kDownSamplingFactors) {
139 const size_t sub_block_size = kBlockSize / down_sampling_factor;
140
141 std::vector<std::vector<float>> render(3,
142 std::vector<float>(kBlockSize, 0.f));
143 std::array<float, kBlockSize> capture;
144 capture.fill(0.f);
145 ApmDataDumper data_dumper(0);
146 for (size_t delay_samples : {5, 64, 150, 200, 800, 1000}) {
147 SCOPED_TRACE(ProduceDebugText(delay_samples, down_sampling_factor));
148 Decimator capture_decimator(down_sampling_factor);
149 DelayBuffer<float> signal_delay_buffer(down_sampling_factor *
150 delay_samples);
151 MatchedFilter filter(&data_dumper, DetectOptimization(), sub_block_size,
152 kWindowSizeSubBlocks, kNumMatchedFilters,
153 kAlignmentShiftSubBlocks, 150);
154 std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
155 RenderDelayBuffer::Create(
156 3, down_sampling_factor,
157 GetDownSampledBufferSize(down_sampling_factor,
158 kNumMatchedFilters),
159 GetRenderDelayBufferSize(down_sampling_factor,
160 kNumMatchedFilters)));
161
162 // Analyze the correlation between render and capture.
163 for (size_t k = 0; k < (300 + delay_samples / sub_block_size); ++k) {
164 RandomizeSampleVector(&random_generator, render[0]);
165 signal_delay_buffer.Delay(render[0], capture);
166 render_delay_buffer->Insert(render);
167 render_delay_buffer->UpdateBuffers();
168 std::array<float, kBlockSize> downsampled_capture_data;
169 rtc::ArrayView<float> downsampled_capture(
170 downsampled_capture_data.data(), sub_block_size);
171 capture_decimator.Decimate(capture, downsampled_capture);
172 filter.Update(render_delay_buffer->GetDownsampledRenderBuffer(),
173 downsampled_capture);
174 }
175
176 // Obtain the lag estimates.
177 auto lag_estimates = filter.GetLagEstimates();
178
179 // Find which lag estimate should be the most accurate.
180 rtc::Optional<size_t> expected_most_accurate_lag_estimate;
181 size_t alignment_shift_sub_blocks = 0;
182 for (size_t k = 0; k < kNumMatchedFilters; ++k) {
183 if ((alignment_shift_sub_blocks + 3 * kWindowSizeSubBlocks / 4) *
184 sub_block_size >
185 delay_samples) {
186 expected_most_accurate_lag_estimate = k > 0 ? k - 1 : 0;
187 break;
188 }
189 alignment_shift_sub_blocks += kAlignmentShiftSubBlocks;
190 }
191 ASSERT_TRUE(expected_most_accurate_lag_estimate);
192
193 // Verify that the expected most accurate lag estimate is the most
194 // accurate estimate.
195 for (size_t k = 0; k < kNumMatchedFilters; ++k) {
196 if (k != *expected_most_accurate_lag_estimate &&
197 k != (*expected_most_accurate_lag_estimate + 1)) {
198 EXPECT_TRUE(
199 lag_estimates[*expected_most_accurate_lag_estimate].accuracy >
200 lag_estimates[k].accuracy ||
201 !lag_estimates[k].reliable ||
202 !lag_estimates[*expected_most_accurate_lag_estimate].reliable);
203 }
204 }
205
206 // Verify that all lag estimates are updated as expected for signals
207 // containing strong noise.
208 for (auto& le : lag_estimates) {
209 EXPECT_TRUE(le.updated);
210 }
211
212 // Verify that the expected most accurate lag estimate is reliable.
213 EXPECT_TRUE(
214 lag_estimates[*expected_most_accurate_lag_estimate].reliable ||
215 lag_estimates[std::min(*expected_most_accurate_lag_estimate + 1,
216 lag_estimates.size() - 1)]
217 .reliable);
218
219 // Verify that the expected most accurate lag estimate is correct.
220 if (lag_estimates[*expected_most_accurate_lag_estimate].reliable) {
221 EXPECT_TRUE(delay_samples ==
222 lag_estimates[*expected_most_accurate_lag_estimate].lag);
223 } else {
224 EXPECT_TRUE(
225 delay_samples ==
226 lag_estimates[std::min(*expected_most_accurate_lag_estimate + 1,
227 lag_estimates.size() - 1)]
228 .lag);
229 }
230 }
231 }
232 }
233
234 // Verifies that the matched filter does not produce reliable and accurate
235 // estimates for uncorrelated render and capture signals.
TEST(MatchedFilter,LagNotReliableForUncorrelatedRenderAndCapture)236 TEST(MatchedFilter, LagNotReliableForUncorrelatedRenderAndCapture) {
237 Random random_generator(42U);
238 for (auto down_sampling_factor : kDownSamplingFactors) {
239 const size_t sub_block_size = kBlockSize / down_sampling_factor;
240
241 std::vector<std::vector<float>> render(3,
242 std::vector<float>(kBlockSize, 0.f));
243 std::array<float, kBlockSize> capture_data;
244 rtc::ArrayView<float> capture(capture_data.data(), sub_block_size);
245 std::fill(capture.begin(), capture.end(), 0.f);
246 ApmDataDumper data_dumper(0);
247 std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
248 RenderDelayBuffer::Create(
249 3, down_sampling_factor,
250 GetDownSampledBufferSize(down_sampling_factor, kNumMatchedFilters),
251 GetRenderDelayBufferSize(down_sampling_factor,
252 kNumMatchedFilters)));
253 MatchedFilter filter(&data_dumper, DetectOptimization(), sub_block_size,
254 kWindowSizeSubBlocks, kNumMatchedFilters,
255 kAlignmentShiftSubBlocks, 150);
256
257 // Analyze the correlation between render and capture.
258 for (size_t k = 0; k < 100; ++k) {
259 RandomizeSampleVector(&random_generator, render[0]);
260 RandomizeSampleVector(&random_generator, capture);
261 render_delay_buffer->Insert(render);
262 filter.Update(render_delay_buffer->GetDownsampledRenderBuffer(), capture);
263 }
264
265 // Obtain the lag estimates.
266 auto lag_estimates = filter.GetLagEstimates();
267 EXPECT_EQ(kNumMatchedFilters, lag_estimates.size());
268
269 // Verify that no lag estimates are reliable.
270 for (auto& le : lag_estimates) {
271 EXPECT_FALSE(le.reliable);
272 }
273 }
274 }
275
276 // Verifies that the matched filter does not produce updated lag estimates for
277 // render signals of low level.
TEST(MatchedFilter,LagNotUpdatedForLowLevelRender)278 TEST(MatchedFilter, LagNotUpdatedForLowLevelRender) {
279 Random random_generator(42U);
280 for (auto down_sampling_factor : kDownSamplingFactors) {
281 const size_t sub_block_size = kBlockSize / down_sampling_factor;
282
283 std::vector<std::vector<float>> render(3,
284 std::vector<float>(kBlockSize, 0.f));
285 std::array<float, kBlockSize> capture;
286 capture.fill(0.f);
287 ApmDataDumper data_dumper(0);
288 MatchedFilter filter(&data_dumper, DetectOptimization(), sub_block_size,
289 kWindowSizeSubBlocks, kNumMatchedFilters,
290 kAlignmentShiftSubBlocks, 150);
291 std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
292 RenderDelayBuffer::Create(
293 3, down_sampling_factor,
294 GetDownSampledBufferSize(down_sampling_factor, kNumMatchedFilters),
295 GetRenderDelayBufferSize(down_sampling_factor,
296 kNumMatchedFilters)));
297 Decimator capture_decimator(down_sampling_factor);
298
299 // Analyze the correlation between render and capture.
300 for (size_t k = 0; k < 100; ++k) {
301 RandomizeSampleVector(&random_generator, render[0]);
302 for (auto& render_k : render[0]) {
303 render_k *= 149.f / 32767.f;
304 }
305 std::copy(render[0].begin(), render[0].end(), capture.begin());
306 std::array<float, kBlockSize> downsampled_capture_data;
307 rtc::ArrayView<float> downsampled_capture(downsampled_capture_data.data(),
308 sub_block_size);
309 capture_decimator.Decimate(capture, downsampled_capture);
310 filter.Update(render_delay_buffer->GetDownsampledRenderBuffer(),
311 downsampled_capture);
312 }
313
314 // Obtain the lag estimates.
315 auto lag_estimates = filter.GetLagEstimates();
316 EXPECT_EQ(kNumMatchedFilters, lag_estimates.size());
317
318 // Verify that no lag estimates are updated and that no lag estimates are
319 // reliable.
320 for (auto& le : lag_estimates) {
321 EXPECT_FALSE(le.updated);
322 EXPECT_FALSE(le.reliable);
323 }
324 }
325 }
326
327 // Verifies that the correct number of lag estimates are produced for a certain
328 // number of alignment shifts.
TEST(MatchedFilter,NumberOfLagEstimates)329 TEST(MatchedFilter, NumberOfLagEstimates) {
330 ApmDataDumper data_dumper(0);
331 for (auto down_sampling_factor : kDownSamplingFactors) {
332 const size_t sub_block_size = kBlockSize / down_sampling_factor;
333 for (size_t num_matched_filters = 0; num_matched_filters < 10;
334 ++num_matched_filters) {
335 MatchedFilter filter(&data_dumper, DetectOptimization(), sub_block_size,
336 32, num_matched_filters, 1, 150);
337 EXPECT_EQ(num_matched_filters, filter.GetLagEstimates().size());
338 }
339 }
340 }
341
342 #if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID)
343
344 // Verifies the check for non-zero windows size.
TEST(MatchedFilter,ZeroWindowSize)345 TEST(MatchedFilter, ZeroWindowSize) {
346 ApmDataDumper data_dumper(0);
347 EXPECT_DEATH(
348 MatchedFilter(&data_dumper, DetectOptimization(), 16, 0, 1, 1, 150), "");
349 }
350
351 // Verifies the check for non-null data dumper.
TEST(MatchedFilter,NullDataDumper)352 TEST(MatchedFilter, NullDataDumper) {
353 EXPECT_DEATH(MatchedFilter(nullptr, DetectOptimization(), 16, 1, 1, 1, 150),
354 "");
355 }
356
357 // Verifies the check for that the sub block size is a multiple of 4.
358 // TODO(peah): Activate the unittest once the required code has been landed.
TEST(MatchedFilter,DISABLED_BlockSizeMultipleOf4)359 TEST(MatchedFilter, DISABLED_BlockSizeMultipleOf4) {
360 ApmDataDumper data_dumper(0);
361 EXPECT_DEATH(
362 MatchedFilter(&data_dumper, DetectOptimization(), 15, 1, 1, 1, 150), "");
363 }
364
365 // Verifies the check for that there is an integer number of sub blocks that add
366 // up to a block size.
367 // TODO(peah): Activate the unittest once the required code has been landed.
TEST(MatchedFilter,DISABLED_SubBlockSizeAddsUpToBlockSize)368 TEST(MatchedFilter, DISABLED_SubBlockSizeAddsUpToBlockSize) {
369 ApmDataDumper data_dumper(0);
370 EXPECT_DEATH(
371 MatchedFilter(&data_dumper, DetectOptimization(), 12, 1, 1, 1, 150), "");
372 }
373
374 #endif
375
376 } // namespace aec3
377 } // namespace webrtc
378