1 /*
2 * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11 #include "modules/audio_processing/aec3/matched_filter_lag_aggregator.h"
12
#include <algorithm>
#include <sstream>
#include <string>
#include <vector>
16
17 #include "api/array_view.h"
18 #include "modules/audio_processing/aec3/aec3_common.h"
19 #include "modules/audio_processing/logging/apm_data_dumper.h"
20 #include "test/gtest.h"
21
22 namespace webrtc {
namespace {

// Number of identical, reliable lag estimates the tests feed to the aggregator
// while still expecting no aggregated lag; the tests below expect the lag to
// be reported on the call immediately after this many.
constexpr size_t kNumLagsBeforeDetection{25};

}  // namespace
28
29 // Verifies that the most accurate lag estimate is chosen.
TEST(MatchedFilterLagAggregator,MostAccurateLagChosen)30 TEST(MatchedFilterLagAggregator, MostAccurateLagChosen) {
31 constexpr size_t kLag1 = 5;
32 constexpr size_t kLag2 = 10;
33 ApmDataDumper data_dumper(0);
34 std::vector<MatchedFilter::LagEstimate> lag_estimates(2);
35 MatchedFilterLagAggregator aggregator(&data_dumper, std::max(kLag1, kLag2));
36 lag_estimates[0] = MatchedFilter::LagEstimate(1.f, true, kLag1, true);
37 lag_estimates[1] = MatchedFilter::LagEstimate(0.5f, true, kLag2, true);
38
39 for (size_t k = 0; k < kNumLagsBeforeDetection; ++k) {
40 EXPECT_FALSE(aggregator.Aggregate(lag_estimates));
41 }
42
43 rtc::Optional<size_t> aggregated_lag = aggregator.Aggregate(lag_estimates);
44 EXPECT_TRUE(aggregated_lag);
45 EXPECT_EQ(kLag1, *aggregated_lag);
46
47 lag_estimates[0] = MatchedFilter::LagEstimate(0.5f, true, kLag1, true);
48 lag_estimates[1] = MatchedFilter::LagEstimate(1.f, true, kLag2, true);
49
50 for (size_t k = 0; k < kNumLagsBeforeDetection; ++k) {
51 aggregated_lag = aggregator.Aggregate(lag_estimates);
52 EXPECT_TRUE(aggregated_lag);
53 EXPECT_EQ(kLag1, *aggregated_lag);
54 }
55
56 aggregated_lag = aggregator.Aggregate(lag_estimates);
57 aggregated_lag = aggregator.Aggregate(lag_estimates);
58 EXPECT_TRUE(aggregated_lag);
59 EXPECT_EQ(kLag2, *aggregated_lag);
60 }
61
62 // Verifies that varying lag estimates causes lag estimates to not be deemed
63 // reliable.
TEST(MatchedFilterLagAggregator,LagEstimateInvarianceRequiredForAggregatedLag)64 TEST(MatchedFilterLagAggregator,
65 LagEstimateInvarianceRequiredForAggregatedLag) {
66 ApmDataDumper data_dumper(0);
67 std::vector<MatchedFilter::LagEstimate> lag_estimates(1);
68 MatchedFilterLagAggregator aggregator(&data_dumper, 100);
69 for (size_t k = 0; k < kNumLagsBeforeDetection * 100; ++k) {
70 lag_estimates[0] = MatchedFilter::LagEstimate(1.f, true, k % 100, true);
71 rtc::Optional<size_t> aggregated_lag = aggregator.Aggregate(lag_estimates);
72 EXPECT_FALSE(aggregated_lag);
73 }
74 }
75
76 // Verifies that lag estimate updates are required to produce an updated lag
77 // aggregate.
TEST(MatchedFilterLagAggregator,DISABLED_LagEstimateUpdatesRequiredForAggregatedLag)78 TEST(MatchedFilterLagAggregator,
79 DISABLED_LagEstimateUpdatesRequiredForAggregatedLag) {
80 constexpr size_t kLag = 5;
81 ApmDataDumper data_dumper(0);
82 std::vector<MatchedFilter::LagEstimate> lag_estimates(1);
83 MatchedFilterLagAggregator aggregator(&data_dumper, kLag);
84 for (size_t k = 0; k < kNumLagsBeforeDetection * 10; ++k) {
85 lag_estimates[0] = MatchedFilter::LagEstimate(1.f, true, kLag, false);
86 rtc::Optional<size_t> aggregated_lag = aggregator.Aggregate(lag_estimates);
87 EXPECT_FALSE(aggregated_lag);
88 EXPECT_EQ(kLag, *aggregated_lag);
89 }
90 }
91
92 // Verifies that an aggregated lag is persistent if the lag estimates do not
93 // change and that an aggregated lag is not produced without gaining lag
94 // estimate confidence.
TEST(MatchedFilterLagAggregator,DISABLED_PersistentAggregatedLag)95 TEST(MatchedFilterLagAggregator, DISABLED_PersistentAggregatedLag) {
96 constexpr size_t kLag1 = 5;
97 constexpr size_t kLag2 = 10;
98 ApmDataDumper data_dumper(0);
99 std::vector<MatchedFilter::LagEstimate> lag_estimates(1);
100 MatchedFilterLagAggregator aggregator(&data_dumper, std::max(kLag1, kLag2));
101 rtc::Optional<size_t> aggregated_lag;
102 for (size_t k = 0; k < kNumLagsBeforeDetection; ++k) {
103 lag_estimates[0] = MatchedFilter::LagEstimate(1.f, true, kLag1, true);
104 aggregated_lag = aggregator.Aggregate(lag_estimates);
105 }
106 EXPECT_TRUE(aggregated_lag);
107 EXPECT_EQ(kLag1, *aggregated_lag);
108
109 for (size_t k = 0; k < kNumLagsBeforeDetection * 40; ++k) {
110 lag_estimates[0] = MatchedFilter::LagEstimate(1.f, false, kLag2, true);
111 aggregated_lag = aggregator.Aggregate(lag_estimates);
112 EXPECT_TRUE(aggregated_lag);
113 EXPECT_EQ(kLag1, *aggregated_lag);
114 }
115 }
116
#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID)

// Verifies that constructing the aggregator without a data dumper trips the
// debug check for a non-null data dumper.
TEST(MatchedFilterLagAggregator, NullDataDumper) {
  EXPECT_DEATH(
      {
        MatchedFilterLagAggregator aggregator(nullptr, 10);
      },
      "");
}

#endif  // RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID)
125
126 } // namespace webrtc
127