1 /*
2 * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
#include "modules/audio_processing/aec3/refined_filter_update_gain.h"

#include <algorithm>
#include <functional>
#include <limits>

#include "modules/audio_processing/aec3/adaptive_fir_filter.h"
#include "modules/audio_processing/aec3/aec3_common.h"
#include "modules/audio_processing/aec3/echo_path_variability.h"
#include "modules/audio_processing/aec3/fft_data.h"
#include "modules/audio_processing/aec3/render_signal_analyzer.h"
#include "modules/audio_processing/aec3/subtractor_output.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/atomic_ops.h"
#include "rtc_base/checks.h"
25
26 namespace webrtc {
namespace {

// Value written into every bin of the filter error estimate H_error_ at
// construction and whenever the echo path delay is reported to have changed.
constexpr float kHErrorInitial = 1e4f;

// Value the poor-excitation counter is set to at construction and on echo
// path resets that are not gain changes. Compute() only adapts once this
// counter is at least size_partitions.
constexpr int kPoorExcitationCounterInitial = 1000;

}  // namespace
33
34 int RefinedFilterUpdateGain::instance_count_ = 0;
35
RefinedFilterUpdateGain(const EchoCanceller3Config::Filter::RefinedConfiguration & config,size_t config_change_duration_blocks)36 RefinedFilterUpdateGain::RefinedFilterUpdateGain(
37 const EchoCanceller3Config::Filter::RefinedConfiguration& config,
38 size_t config_change_duration_blocks)
39 : data_dumper_(
40 new ApmDataDumper(rtc::AtomicOps::Increment(&instance_count_))),
41 config_change_duration_blocks_(
42 static_cast<int>(config_change_duration_blocks)),
43 poor_excitation_counter_(kPoorExcitationCounterInitial) {
44 SetConfig(config, true);
45 H_error_.fill(kHErrorInitial);
46 RTC_DCHECK_LT(0, config_change_duration_blocks_);
47 one_by_config_change_duration_blocks_ = 1.f / config_change_duration_blocks_;
48 }
49
~RefinedFilterUpdateGain()50 RefinedFilterUpdateGain::~RefinedFilterUpdateGain() {}
51
HandleEchoPathChange(const EchoPathVariability & echo_path_variability)52 void RefinedFilterUpdateGain::HandleEchoPathChange(
53 const EchoPathVariability& echo_path_variability) {
54 if (echo_path_variability.gain_change) {
55 // TODO(bugs.webrtc.org/9526) Handle gain changes.
56 }
57
58 if (echo_path_variability.delay_change !=
59 EchoPathVariability::DelayAdjustment::kNone) {
60 H_error_.fill(kHErrorInitial);
61 }
62
63 if (!echo_path_variability.gain_change) {
64 poor_excitation_counter_ = kPoorExcitationCounterInitial;
65 call_counter_ = 0;
66 }
67 }
68
Compute(const std::array<float,kFftLengthBy2Plus1> & render_power,const RenderSignalAnalyzer & render_signal_analyzer,const SubtractorOutput & subtractor_output,rtc::ArrayView<const float> erl,size_t size_partitions,bool saturated_capture_signal,bool disallow_leakage_diverged,FftData * gain_fft)69 void RefinedFilterUpdateGain::Compute(
70 const std::array<float, kFftLengthBy2Plus1>& render_power,
71 const RenderSignalAnalyzer& render_signal_analyzer,
72 const SubtractorOutput& subtractor_output,
73 rtc::ArrayView<const float> erl,
74 size_t size_partitions,
75 bool saturated_capture_signal,
76 bool disallow_leakage_diverged,
77 FftData* gain_fft) {
78 RTC_DCHECK(gain_fft);
79 // Introducing shorter notation to improve readability.
80 const FftData& E_refined = subtractor_output.E_refined;
81 const auto& E2_refined = subtractor_output.E2_refined;
82 const auto& E2_coarse = subtractor_output.E2_coarse;
83 FftData* G = gain_fft;
84 const auto& X2 = render_power;
85
86 ++call_counter_;
87
88 UpdateCurrentConfig();
89
90 if (render_signal_analyzer.PoorSignalExcitation()) {
91 poor_excitation_counter_ = 0;
92 }
93
94 // Do not update the filter if the render is not sufficiently excited.
95 if (++poor_excitation_counter_ < size_partitions ||
96 saturated_capture_signal || call_counter_ <= size_partitions) {
97 G->re.fill(0.f);
98 G->im.fill(0.f);
99 } else {
100 // Corresponds to WGN of power -39 dBFS.
101 std::array<float, kFftLengthBy2Plus1> mu;
102 // mu = H_error / (0.5* H_error* X2 + n * E2).
103 for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
104 if (X2[k] >= current_config_.noise_gate) {
105 mu[k] = H_error_[k] /
106 (0.5f * H_error_[k] * X2[k] + size_partitions * E2_refined[k]);
107 } else {
108 mu[k] = 0.f;
109 }
110 }
111
112 // Avoid updating the filter close to narrow bands in the render signals.
113 render_signal_analyzer.MaskRegionsAroundNarrowBands(&mu);
114
115 // H_error = H_error - 0.5 * mu * X2 * H_error.
116 for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
117 H_error_[k] -= 0.5f * mu[k] * X2[k] * H_error_[k];
118 }
119
120 // G = mu * E.
121 for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
122 G->re[k] = mu[k] * E_refined.re[k];
123 G->im[k] = mu[k] * E_refined.im[k];
124 }
125 }
126
127 // H_error = H_error + factor * erl.
128 for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
129 if (E2_refined[k] <= E2_coarse[k] || disallow_leakage_diverged) {
130 H_error_[k] += current_config_.leakage_converged * erl[k];
131 } else {
132 H_error_[k] += current_config_.leakage_diverged * erl[k];
133 }
134
135 H_error_[k] = std::max(H_error_[k], current_config_.error_floor);
136 H_error_[k] = std::min(H_error_[k], current_config_.error_ceil);
137 }
138
139 data_dumper_->DumpRaw("aec3_refined_gain_H_error", H_error_);
140 }
141
UpdateCurrentConfig()142 void RefinedFilterUpdateGain::UpdateCurrentConfig() {
143 RTC_DCHECK_GE(config_change_duration_blocks_, config_change_counter_);
144 if (config_change_counter_ > 0) {
145 if (--config_change_counter_ > 0) {
146 auto average = [](float from, float to, float from_weight) {
147 return from * from_weight + to * (1.f - from_weight);
148 };
149
150 float change_factor =
151 config_change_counter_ * one_by_config_change_duration_blocks_;
152
153 current_config_.leakage_converged =
154 average(old_target_config_.leakage_converged,
155 target_config_.leakage_converged, change_factor);
156 current_config_.leakage_diverged =
157 average(old_target_config_.leakage_diverged,
158 target_config_.leakage_diverged, change_factor);
159 current_config_.error_floor =
160 average(old_target_config_.error_floor, target_config_.error_floor,
161 change_factor);
162 current_config_.error_ceil =
163 average(old_target_config_.error_ceil, target_config_.error_ceil,
164 change_factor);
165 current_config_.noise_gate =
166 average(old_target_config_.noise_gate, target_config_.noise_gate,
167 change_factor);
168 } else {
169 current_config_ = old_target_config_ = target_config_;
170 }
171 }
172 RTC_DCHECK_LE(0, config_change_counter_);
173 }
174
175 } // namespace webrtc
176