1 /*
2  *  Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9   */
10 
11 #include "modules/audio_processing/aec3/matched_filter_lag_aggregator.h"
12 
#include <algorithm>
#include <sstream>
#include <string>
#include <vector>
16 
17 #include "api/array_view.h"
18 #include "modules/audio_processing/aec3/aec3_common.h"
19 #include "modules/audio_processing/logging/apm_data_dumper.h"
20 #include "test/gtest.h"
21 
22 namespace webrtc {
namespace {

// Number of consecutive Aggregate() calls with consistent lag estimates that
// the tests below expect to report no aggregated lag before a lag is detected.
constexpr size_t kNumLagsBeforeDetection{25};

}  // namespace
28 
29 // Verifies that the most accurate lag estimate is chosen.
TEST(MatchedFilterLagAggregator,MostAccurateLagChosen)30 TEST(MatchedFilterLagAggregator, MostAccurateLagChosen) {
31   constexpr size_t kLag1 = 5;
32   constexpr size_t kLag2 = 10;
33   ApmDataDumper data_dumper(0);
34   std::vector<MatchedFilter::LagEstimate> lag_estimates(2);
35   MatchedFilterLagAggregator aggregator(&data_dumper, std::max(kLag1, kLag2));
36   lag_estimates[0] = MatchedFilter::LagEstimate(1.f, true, kLag1, true);
37   lag_estimates[1] = MatchedFilter::LagEstimate(0.5f, true, kLag2, true);
38 
39   for (size_t k = 0; k < kNumLagsBeforeDetection; ++k) {
40     EXPECT_FALSE(aggregator.Aggregate(lag_estimates));
41   }
42 
43   rtc::Optional<size_t> aggregated_lag = aggregator.Aggregate(lag_estimates);
44   EXPECT_TRUE(aggregated_lag);
45   EXPECT_EQ(kLag1, *aggregated_lag);
46 
47   lag_estimates[0] = MatchedFilter::LagEstimate(0.5f, true, kLag1, true);
48   lag_estimates[1] = MatchedFilter::LagEstimate(1.f, true, kLag2, true);
49 
50   for (size_t k = 0; k < kNumLagsBeforeDetection; ++k) {
51     aggregated_lag = aggregator.Aggregate(lag_estimates);
52     EXPECT_TRUE(aggregated_lag);
53     EXPECT_EQ(kLag1, *aggregated_lag);
54   }
55 
56   aggregated_lag = aggregator.Aggregate(lag_estimates);
57   aggregated_lag = aggregator.Aggregate(lag_estimates);
58   EXPECT_TRUE(aggregated_lag);
59   EXPECT_EQ(kLag2, *aggregated_lag);
60 }
61 
62 // Verifies that varying lag estimates causes lag estimates to not be deemed
63 // reliable.
TEST(MatchedFilterLagAggregator,LagEstimateInvarianceRequiredForAggregatedLag)64 TEST(MatchedFilterLagAggregator,
65      LagEstimateInvarianceRequiredForAggregatedLag) {
66   ApmDataDumper data_dumper(0);
67   std::vector<MatchedFilter::LagEstimate> lag_estimates(1);
68   MatchedFilterLagAggregator aggregator(&data_dumper, 100);
69   for (size_t k = 0; k < kNumLagsBeforeDetection * 100; ++k) {
70     lag_estimates[0] = MatchedFilter::LagEstimate(1.f, true, k % 100, true);
71     rtc::Optional<size_t> aggregated_lag = aggregator.Aggregate(lag_estimates);
72     EXPECT_FALSE(aggregated_lag);
73   }
74 }
75 
76 // Verifies that lag estimate updates are required to produce an updated lag
77 // aggregate.
TEST(MatchedFilterLagAggregator,DISABLED_LagEstimateUpdatesRequiredForAggregatedLag)78 TEST(MatchedFilterLagAggregator,
79      DISABLED_LagEstimateUpdatesRequiredForAggregatedLag) {
80   constexpr size_t kLag = 5;
81   ApmDataDumper data_dumper(0);
82   std::vector<MatchedFilter::LagEstimate> lag_estimates(1);
83   MatchedFilterLagAggregator aggregator(&data_dumper, kLag);
84   for (size_t k = 0; k < kNumLagsBeforeDetection * 10; ++k) {
85     lag_estimates[0] = MatchedFilter::LagEstimate(1.f, true, kLag, false);
86     rtc::Optional<size_t> aggregated_lag = aggregator.Aggregate(lag_estimates);
87     EXPECT_FALSE(aggregated_lag);
88     EXPECT_EQ(kLag, *aggregated_lag);
89   }
90 }
91 
92 // Verifies that an aggregated lag is persistent if the lag estimates do not
93 // change and that an aggregated lag is not produced without gaining lag
94 // estimate confidence.
TEST(MatchedFilterLagAggregator,DISABLED_PersistentAggregatedLag)95 TEST(MatchedFilterLagAggregator, DISABLED_PersistentAggregatedLag) {
96   constexpr size_t kLag1 = 5;
97   constexpr size_t kLag2 = 10;
98   ApmDataDumper data_dumper(0);
99   std::vector<MatchedFilter::LagEstimate> lag_estimates(1);
100   MatchedFilterLagAggregator aggregator(&data_dumper, std::max(kLag1, kLag2));
101   rtc::Optional<size_t> aggregated_lag;
102   for (size_t k = 0; k < kNumLagsBeforeDetection; ++k) {
103     lag_estimates[0] = MatchedFilter::LagEstimate(1.f, true, kLag1, true);
104     aggregated_lag = aggregator.Aggregate(lag_estimates);
105   }
106   EXPECT_TRUE(aggregated_lag);
107   EXPECT_EQ(kLag1, *aggregated_lag);
108 
109   for (size_t k = 0; k < kNumLagsBeforeDetection * 40; ++k) {
110     lag_estimates[0] = MatchedFilter::LagEstimate(1.f, false, kLag2, true);
111     aggregated_lag = aggregator.Aggregate(lag_estimates);
112     EXPECT_TRUE(aggregated_lag);
113     EXPECT_EQ(kLag1, *aggregated_lag);
114   }
115 }
116 
117 #if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID)
118 
119 // Verifies the check for non-null data dumper.
TEST(MatchedFilterLagAggregator,NullDataDumper)120 TEST(MatchedFilterLagAggregator, NullDataDumper) {
121   EXPECT_DEATH(MatchedFilterLagAggregator(nullptr, 10), "");
122 }
123 
124 #endif
125 
126 }  // namespace webrtc
127