1 /*
2  *  Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9   */
10 
11 #include "modules/audio_processing/aec3/matched_filter.h"
12 
13 #include "typedefs.h"  // NOLINT(build/include)
14 #if defined(WEBRTC_ARCH_X86_FAMILY)
15 #include <emmintrin.h>
16 #endif
17 #include <algorithm>
18 #include <sstream>
19 #include <string>
20 
21 #include "modules/audio_processing/aec3/aec3_common.h"
22 #include "modules/audio_processing/aec3/decimator.h"
23 #include "modules/audio_processing/aec3/render_delay_buffer.h"
24 #include "modules/audio_processing/logging/apm_data_dumper.h"
25 #include "modules/audio_processing/test/echo_canceller_test_tools.h"
26 #include "rtc_base/random.h"
27 #include "system_wrappers/include/cpu_features_wrapper.h"
28 #include "test/gtest.h"
29 
30 namespace webrtc {
31 namespace aec3 {
32 namespace {
33 
// Builds the label that the tests attach through SCOPED_TRACE, so that a
// failing expectation reports the delay and the down-sampling factor that were
// being exercised.
std::string ProduceDebugText(size_t delay, size_t down_sampling_factor) {
  std::ostringstream description;
  description << "Delay: " << delay
              << ", Down sampling factor: " << down_sampling_factor;
  return description.str();
}
40 
// Number of matched filters requested when the tests construct a
// MatchedFilter and size the render delay buffer.
constexpr size_t kNumMatchedFilters = 10;

// Down-sampling factors that every test iterates over.
constexpr size_t kDownSamplingFactors[] = {2, 4, 8};

// Window size, in sub blocks, handed to the MatchedFilter constructor.
constexpr size_t kWindowSizeSubBlocks = 32;

// Alignment shift, in sub blocks, handed to the MatchedFilter constructor:
// three quarters of the window size.
constexpr size_t kAlignmentShiftSubBlocks = 3 * kWindowSizeSubBlocks / 4;
45 
46 }  // namespace
47 
48 #if defined(WEBRTC_HAS_NEON)
49 // Verifies that the optimized methods for NEON are similar to their reference
50 // counterparts.
TEST(MatchedFilter,TestNeonOptimizations)51 TEST(MatchedFilter, TestNeonOptimizations) {
52   Random random_generator(42U);
53   for (auto down_sampling_factor : kDownSamplingFactors) {
54     const size_t sub_block_size = kBlockSize / down_sampling_factor;
55 
56     std::vector<float> x(2000);
57     RandomizeSampleVector(&random_generator, x);
58     std::vector<float> y(sub_block_size);
59     std::vector<float> h_NEON(512);
60     std::vector<float> h(512);
61     int x_index = 0;
62     for (int k = 0; k < 1000; ++k) {
63       RandomizeSampleVector(&random_generator, y);
64 
65       bool filters_updated = false;
66       float error_sum = 0.f;
67       bool filters_updated_NEON = false;
68       float error_sum_NEON = 0.f;
69 
70       MatchedFilterCore_NEON(x_index, h.size() * 150.f * 150.f, x, y, h_NEON,
71                              &filters_updated_NEON, &error_sum_NEON);
72 
73       MatchedFilterCore(x_index, h.size() * 150.f * 150.f, x, y, h,
74                         &filters_updated, &error_sum);
75 
76       EXPECT_EQ(filters_updated, filters_updated_NEON);
77       EXPECT_NEAR(error_sum, error_sum_NEON, error_sum / 100000.f);
78 
79       for (size_t j = 0; j < h.size(); ++j) {
80         EXPECT_NEAR(h[j], h_NEON[j], 0.00001f);
81       }
82 
83       x_index = (x_index + sub_block_size) % x.size();
84     }
85   }
86 }
87 #endif
88 
89 #if defined(WEBRTC_ARCH_X86_FAMILY)
90 // Verifies that the optimized methods for SSE2 are bitexact to their reference
91 // counterparts.
TEST(MatchedFilter,TestSse2Optimizations)92 TEST(MatchedFilter, TestSse2Optimizations) {
93   bool use_sse2 = (WebRtc_GetCPUInfo(kSSE2) != 0);
94   if (use_sse2) {
95     Random random_generator(42U);
96     for (auto down_sampling_factor : kDownSamplingFactors) {
97       const size_t sub_block_size = kBlockSize / down_sampling_factor;
98       std::vector<float> x(2000);
99       RandomizeSampleVector(&random_generator, x);
100       std::vector<float> y(sub_block_size);
101       std::vector<float> h_SSE2(512);
102       std::vector<float> h(512);
103       int x_index = 0;
104       for (int k = 0; k < 1000; ++k) {
105         RandomizeSampleVector(&random_generator, y);
106 
107         bool filters_updated = false;
108         float error_sum = 0.f;
109         bool filters_updated_SSE2 = false;
110         float error_sum_SSE2 = 0.f;
111 
112         MatchedFilterCore_SSE2(x_index, h.size() * 150.f * 150.f, x, y, h_SSE2,
113                                &filters_updated_SSE2, &error_sum_SSE2);
114 
115         MatchedFilterCore(x_index, h.size() * 150.f * 150.f, x, y, h,
116                           &filters_updated, &error_sum);
117 
118         EXPECT_EQ(filters_updated, filters_updated_SSE2);
119         EXPECT_NEAR(error_sum, error_sum_SSE2, error_sum / 100000.f);
120 
121         for (size_t j = 0; j < h.size(); ++j) {
122           EXPECT_NEAR(h[j], h_SSE2[j], 0.00001f);
123         }
124 
125         x_index = (x_index + sub_block_size) % x.size();
126       }
127     }
128   }
129 }
130 
131 #endif
132 
133 // Verifies that the matched filter produces proper lag estimates for
134 // artificially
135 // delayed signals.
TEST(MatchedFilter,LagEstimation)136 TEST(MatchedFilter, LagEstimation) {
137   Random random_generator(42U);
138   for (auto down_sampling_factor : kDownSamplingFactors) {
139     const size_t sub_block_size = kBlockSize / down_sampling_factor;
140 
141     std::vector<std::vector<float>> render(3,
142                                            std::vector<float>(kBlockSize, 0.f));
143     std::array<float, kBlockSize> capture;
144     capture.fill(0.f);
145     ApmDataDumper data_dumper(0);
146     for (size_t delay_samples : {5, 64, 150, 200, 800, 1000}) {
147       SCOPED_TRACE(ProduceDebugText(delay_samples, down_sampling_factor));
148       Decimator capture_decimator(down_sampling_factor);
149       DelayBuffer<float> signal_delay_buffer(down_sampling_factor *
150                                              delay_samples);
151       MatchedFilter filter(&data_dumper, DetectOptimization(), sub_block_size,
152                            kWindowSizeSubBlocks, kNumMatchedFilters,
153                            kAlignmentShiftSubBlocks, 150);
154       std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
155           RenderDelayBuffer::Create(
156               3, down_sampling_factor,
157               GetDownSampledBufferSize(down_sampling_factor,
158                                        kNumMatchedFilters),
159               GetRenderDelayBufferSize(down_sampling_factor,
160                                        kNumMatchedFilters)));
161 
162       // Analyze the correlation between render and capture.
163       for (size_t k = 0; k < (300 + delay_samples / sub_block_size); ++k) {
164         RandomizeSampleVector(&random_generator, render[0]);
165         signal_delay_buffer.Delay(render[0], capture);
166         render_delay_buffer->Insert(render);
167         render_delay_buffer->UpdateBuffers();
168         std::array<float, kBlockSize> downsampled_capture_data;
169         rtc::ArrayView<float> downsampled_capture(
170             downsampled_capture_data.data(), sub_block_size);
171         capture_decimator.Decimate(capture, downsampled_capture);
172         filter.Update(render_delay_buffer->GetDownsampledRenderBuffer(),
173                       downsampled_capture);
174       }
175 
176       // Obtain the lag estimates.
177       auto lag_estimates = filter.GetLagEstimates();
178 
179       // Find which lag estimate should be the most accurate.
180       rtc::Optional<size_t> expected_most_accurate_lag_estimate;
181       size_t alignment_shift_sub_blocks = 0;
182       for (size_t k = 0; k < kNumMatchedFilters; ++k) {
183         if ((alignment_shift_sub_blocks + 3 * kWindowSizeSubBlocks / 4) *
184                 sub_block_size >
185             delay_samples) {
186           expected_most_accurate_lag_estimate = k > 0 ? k - 1 : 0;
187           break;
188         }
189         alignment_shift_sub_blocks += kAlignmentShiftSubBlocks;
190       }
191       ASSERT_TRUE(expected_most_accurate_lag_estimate);
192 
193       // Verify that the expected most accurate lag estimate is the most
194       // accurate estimate.
195       for (size_t k = 0; k < kNumMatchedFilters; ++k) {
196         if (k != *expected_most_accurate_lag_estimate &&
197             k != (*expected_most_accurate_lag_estimate + 1)) {
198           EXPECT_TRUE(
199               lag_estimates[*expected_most_accurate_lag_estimate].accuracy >
200                   lag_estimates[k].accuracy ||
201               !lag_estimates[k].reliable ||
202               !lag_estimates[*expected_most_accurate_lag_estimate].reliable);
203         }
204       }
205 
206       // Verify that all lag estimates are updated as expected for signals
207       // containing strong noise.
208       for (auto& le : lag_estimates) {
209         EXPECT_TRUE(le.updated);
210       }
211 
212       // Verify that the expected most accurate lag estimate is reliable.
213       EXPECT_TRUE(
214           lag_estimates[*expected_most_accurate_lag_estimate].reliable ||
215           lag_estimates[std::min(*expected_most_accurate_lag_estimate + 1,
216                                  lag_estimates.size() - 1)]
217               .reliable);
218 
219       // Verify that the expected most accurate lag estimate is correct.
220       if (lag_estimates[*expected_most_accurate_lag_estimate].reliable) {
221         EXPECT_TRUE(delay_samples ==
222                     lag_estimates[*expected_most_accurate_lag_estimate].lag);
223       } else {
224         EXPECT_TRUE(
225             delay_samples ==
226             lag_estimates[std::min(*expected_most_accurate_lag_estimate + 1,
227                                    lag_estimates.size() - 1)]
228                 .lag);
229       }
230     }
231   }
232 }
233 
234 // Verifies that the matched filter does not produce reliable and accurate
235 // estimates for uncorrelated render and capture signals.
TEST(MatchedFilter,LagNotReliableForUncorrelatedRenderAndCapture)236 TEST(MatchedFilter, LagNotReliableForUncorrelatedRenderAndCapture) {
237   Random random_generator(42U);
238   for (auto down_sampling_factor : kDownSamplingFactors) {
239     const size_t sub_block_size = kBlockSize / down_sampling_factor;
240 
241     std::vector<std::vector<float>> render(3,
242                                            std::vector<float>(kBlockSize, 0.f));
243     std::array<float, kBlockSize> capture_data;
244     rtc::ArrayView<float> capture(capture_data.data(), sub_block_size);
245     std::fill(capture.begin(), capture.end(), 0.f);
246     ApmDataDumper data_dumper(0);
247     std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
248         RenderDelayBuffer::Create(
249             3, down_sampling_factor,
250             GetDownSampledBufferSize(down_sampling_factor, kNumMatchedFilters),
251             GetRenderDelayBufferSize(down_sampling_factor,
252                                      kNumMatchedFilters)));
253     MatchedFilter filter(&data_dumper, DetectOptimization(), sub_block_size,
254                          kWindowSizeSubBlocks, kNumMatchedFilters,
255                          kAlignmentShiftSubBlocks, 150);
256 
257     // Analyze the correlation between render and capture.
258     for (size_t k = 0; k < 100; ++k) {
259       RandomizeSampleVector(&random_generator, render[0]);
260       RandomizeSampleVector(&random_generator, capture);
261       render_delay_buffer->Insert(render);
262       filter.Update(render_delay_buffer->GetDownsampledRenderBuffer(), capture);
263     }
264 
265     // Obtain the lag estimates.
266     auto lag_estimates = filter.GetLagEstimates();
267     EXPECT_EQ(kNumMatchedFilters, lag_estimates.size());
268 
269     // Verify that no lag estimates are reliable.
270     for (auto& le : lag_estimates) {
271       EXPECT_FALSE(le.reliable);
272     }
273   }
274 }
275 
276 // Verifies that the matched filter does not produce updated lag estimates for
277 // render signals of low level.
TEST(MatchedFilter,LagNotUpdatedForLowLevelRender)278 TEST(MatchedFilter, LagNotUpdatedForLowLevelRender) {
279   Random random_generator(42U);
280   for (auto down_sampling_factor : kDownSamplingFactors) {
281     const size_t sub_block_size = kBlockSize / down_sampling_factor;
282 
283     std::vector<std::vector<float>> render(3,
284                                            std::vector<float>(kBlockSize, 0.f));
285     std::array<float, kBlockSize> capture;
286     capture.fill(0.f);
287     ApmDataDumper data_dumper(0);
288     MatchedFilter filter(&data_dumper, DetectOptimization(), sub_block_size,
289                          kWindowSizeSubBlocks, kNumMatchedFilters,
290                          kAlignmentShiftSubBlocks, 150);
291     std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
292         RenderDelayBuffer::Create(
293             3, down_sampling_factor,
294             GetDownSampledBufferSize(down_sampling_factor, kNumMatchedFilters),
295             GetRenderDelayBufferSize(down_sampling_factor,
296                                      kNumMatchedFilters)));
297     Decimator capture_decimator(down_sampling_factor);
298 
299     // Analyze the correlation between render and capture.
300     for (size_t k = 0; k < 100; ++k) {
301       RandomizeSampleVector(&random_generator, render[0]);
302       for (auto& render_k : render[0]) {
303         render_k *= 149.f / 32767.f;
304       }
305       std::copy(render[0].begin(), render[0].end(), capture.begin());
306       std::array<float, kBlockSize> downsampled_capture_data;
307       rtc::ArrayView<float> downsampled_capture(downsampled_capture_data.data(),
308                                                 sub_block_size);
309       capture_decimator.Decimate(capture, downsampled_capture);
310       filter.Update(render_delay_buffer->GetDownsampledRenderBuffer(),
311                     downsampled_capture);
312     }
313 
314     // Obtain the lag estimates.
315     auto lag_estimates = filter.GetLagEstimates();
316     EXPECT_EQ(kNumMatchedFilters, lag_estimates.size());
317 
318     // Verify that no lag estimates are updated and that no lag estimates are
319     // reliable.
320     for (auto& le : lag_estimates) {
321       EXPECT_FALSE(le.updated);
322       EXPECT_FALSE(le.reliable);
323     }
324   }
325 }
326 
327 // Verifies that the correct number of lag estimates are produced for a certain
328 // number of alignment shifts.
TEST(MatchedFilter,NumberOfLagEstimates)329 TEST(MatchedFilter, NumberOfLagEstimates) {
330   ApmDataDumper data_dumper(0);
331   for (auto down_sampling_factor : kDownSamplingFactors) {
332     const size_t sub_block_size = kBlockSize / down_sampling_factor;
333     for (size_t num_matched_filters = 0; num_matched_filters < 10;
334          ++num_matched_filters) {
335       MatchedFilter filter(&data_dumper, DetectOptimization(), sub_block_size,
336                            32, num_matched_filters, 1, 150);
337       EXPECT_EQ(num_matched_filters, filter.GetLagEstimates().size());
338     }
339   }
340 }
341 
342 #if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID)
343 
344 // Verifies the check for non-zero windows size.
TEST(MatchedFilter,ZeroWindowSize)345 TEST(MatchedFilter, ZeroWindowSize) {
346   ApmDataDumper data_dumper(0);
347   EXPECT_DEATH(
348       MatchedFilter(&data_dumper, DetectOptimization(), 16, 0, 1, 1, 150), "");
349 }
350 
351 // Verifies the check for non-null data dumper.
TEST(MatchedFilter,NullDataDumper)352 TEST(MatchedFilter, NullDataDumper) {
353   EXPECT_DEATH(MatchedFilter(nullptr, DetectOptimization(), 16, 1, 1, 1, 150),
354                "");
355 }
356 
357 // Verifies the check for that the sub block size is a multiple of 4.
358 // TODO(peah): Activate the unittest once the required code has been landed.
TEST(MatchedFilter,DISABLED_BlockSizeMultipleOf4)359 TEST(MatchedFilter, DISABLED_BlockSizeMultipleOf4) {
360   ApmDataDumper data_dumper(0);
361   EXPECT_DEATH(
362       MatchedFilter(&data_dumper, DetectOptimization(), 15, 1, 1, 1, 150), "");
363 }
364 
365 // Verifies the check for that there is an integer number of sub blocks that add
366 // up to a block size.
367 // TODO(peah): Activate the unittest once the required code has been landed.
TEST(MatchedFilter,DISABLED_SubBlockSizeAddsUpToBlockSize)368 TEST(MatchedFilter, DISABLED_SubBlockSizeAddsUpToBlockSize) {
369   ApmDataDumper data_dumper(0);
370   EXPECT_DEATH(
371       MatchedFilter(&data_dumper, DetectOptimization(), 12, 1, 1, 1, 150), "");
372 }
373 
374 #endif
375 
376 }  // namespace aec3
377 }  // namespace webrtc
378