1 /*
2  *  Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "modules/audio_processing/aec3/suppression_gain.h"
12 
13 #include "modules/audio_processing/aec3/aec_state.h"
14 #include "modules/audio_processing/aec3/render_buffer.h"
15 #include "modules/audio_processing/aec3/subtractor.h"
16 #include "modules/audio_processing/logging/apm_data_dumper.h"
17 #include "rtc_base/checks.h"
18 #include "system_wrappers/include/cpu_features_wrapper.h"
19 #include "test/gtest.h"
20 #include "typedefs.h"  // NOLINT(build/include)
21 
22 namespace webrtc {
23 namespace aec3 {
24 
25 #if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID)
26 
27 // Verifies that the check for non-null output gains works.
TEST(SuppressionGain, NullOutputGains) {
  using FftArray = std::array<float, kFftLengthBy2Plus1>;

  // Zero-valued inputs for the three array arguments of GetGain(); value
  // initialization sets every element to 0.f.
  FftArray E2{};
  FftArray R2{};
  FftArray N2{};

  // The remaining GetGain() arguments, named here instead of being built
  // inline inside the death-test statement.
  const EchoCanceller3Config config{};
  AecState aec_state(config);
  RenderSignalAnalyzer analyzer;
  const std::vector<std::vector<float>> render(
      3, std::vector<float>(kBlockSize, 0.f));
  float high_bands_gain;

  // The last argument (the output gains) is deliberately null; the call is
  // expected to terminate the process. The empty regex accepts any message.
  EXPECT_DEATH(SuppressionGain(config, DetectOptimization())
                   .GetGain(E2, R2, N2, analyzer, aec_state, render,
                            &high_bands_gain, nullptr),
               "");
}
44 
45 #endif
46 
47 // Does a sanity check that the gains are correctly computed.
TEST(SuppressionGain,BasicGainComputation)48 TEST(SuppressionGain, BasicGainComputation) {
49   SuppressionGain suppression_gain(EchoCanceller3Config(),
50                                    DetectOptimization());
51   RenderSignalAnalyzer analyzer;
52   float high_bands_gain;
53   std::array<float, kFftLengthBy2Plus1> E2;
54   std::array<float, kFftLengthBy2Plus1> Y2;
55   std::array<float, kFftLengthBy2Plus1> R2;
56   std::array<float, kFftLengthBy2Plus1> N2;
57   std::array<float, kFftLengthBy2Plus1> g;
58   std::array<float, kBlockSize> s;
59   std::vector<std::vector<float>> x(1, std::vector<float>(kBlockSize, 0.f));
60   AecState aec_state(EchoCanceller3Config{});
61   ApmDataDumper data_dumper(42);
62   Subtractor subtractor(&data_dumper, DetectOptimization());
63   RenderBuffer render_buffer(
64       DetectOptimization(), 1,
65       std::max(kUnknownDelayRenderWindowSize, kAdaptiveFilterLength),
66       std::vector<size_t>(1, kAdaptiveFilterLength));
67 
68   // Verify the functionality for forcing a zero gain.
69   E2.fill(1000000000.f);
70   R2.fill(10000000000000.f);
71   N2.fill(0.f);
72   s.fill(10.f);
73   aec_state.Update(
74       subtractor.FilterFrequencyResponse(), subtractor.FilterImpulseResponse(),
75       subtractor.ConvergedFilter(), 10, render_buffer, E2, Y2, x[0], s, false);
76   suppression_gain.GetGain(E2, R2, N2, analyzer, aec_state, x, &high_bands_gain,
77                            &g);
78   std::for_each(g.begin(), g.end(), [](float a) { EXPECT_FLOAT_EQ(0.f, a); });
79   EXPECT_FLOAT_EQ(0.f, high_bands_gain);
80 
81   // Ensure that a strong noise is detected to mask any echoes.
82   E2.fill(10.f);
83   Y2.fill(10.f);
84   R2.fill(0.1f);
85   N2.fill(100.f);
86   // Ensure that the gain is no longer forced to zero.
87   for (int k = 0; k <= kNumBlocksPerSecond / 5 + 1; ++k) {
88     aec_state.Update(subtractor.FilterFrequencyResponse(),
89                      subtractor.FilterImpulseResponse(),
90                      subtractor.ConvergedFilter(), 10, render_buffer, E2, Y2,
91                      x[0], s, false);
92   }
93 
94   for (int k = 0; k < 100; ++k) {
95     aec_state.Update(subtractor.FilterFrequencyResponse(),
96                      subtractor.FilterImpulseResponse(),
97                      subtractor.ConvergedFilter(), 10, render_buffer, E2, Y2,
98                      x[0], s, false);
99     suppression_gain.GetGain(E2, R2, N2, analyzer, aec_state, x,
100                              &high_bands_gain, &g);
101   }
102   std::for_each(g.begin(), g.end(),
103                 [](float a) { EXPECT_NEAR(1.f, a, 0.001); });
104 
105   // Ensure that a strong nearend is detected to mask any echoes.
106   E2.fill(100.f);
107   Y2.fill(100.f);
108   R2.fill(0.1f);
109   N2.fill(0.f);
110   for (int k = 0; k < 100; ++k) {
111     aec_state.Update(subtractor.FilterFrequencyResponse(),
112                      subtractor.FilterImpulseResponse(),
113                      subtractor.ConvergedFilter(), 10, render_buffer, E2, Y2,
114                      x[0], s, false);
115     suppression_gain.GetGain(E2, R2, N2, analyzer, aec_state, x,
116                              &high_bands_gain, &g);
117   }
118   std::for_each(g.begin(), g.end(),
119                 [](float a) { EXPECT_NEAR(1.f, a, 0.001); });
120 
121   // Ensure that a strong echo is suppressed.
122   E2.fill(1000000000.f);
123   R2.fill(10000000000000.f);
124   N2.fill(0.f);
125   for (int k = 0; k < 10; ++k) {
126     suppression_gain.GetGain(E2, R2, N2, analyzer, aec_state, x,
127                              &high_bands_gain, &g);
128   }
129   std::for_each(g.begin(), g.end(),
130                 [](float a) { EXPECT_NEAR(0.f, a, 0.001); });
131 
132 }
133 
134 }  // namespace aec3
135 }  // namespace webrtc
136