1 /*
2 * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11 #include "modules/audio_processing/aec3/main_filter_update_gain.h"
12
#include <algorithm>
#include <array>
#include <functional>
#include <limits>
15
16 #include "modules/audio_processing/aec3/aec3_common.h"
17 #include "modules/audio_processing/logging/apm_data_dumper.h"
18 #include "rtc_base/atomicops.h"
19 #include "rtc_base/checks.h"
20
21 namespace webrtc {
namespace {

// Starting per-bin value of the main filter error estimate H_error_.
// The step size mu = H_error / (0.5 * H_error * X2 + n * E2) grows with
// H_error, so a large starting value allows large early adaptation steps.
constexpr float kHErrorInitial = 1.0e4f;

// Starting value of the poor-excitation counter. The counter is incremented
// and then compared with the number of filter partitions before adapting.
// While this value is not less than that partition count, the gate does not
// block adaptation until poor render excitation is first reported.
constexpr int kPoorExcitationCounterInitial = 1000;

}  // namespace
28
29 int MainFilterUpdateGain::instance_count_ = 0;
30
MainFilterUpdateGain()31 MainFilterUpdateGain::MainFilterUpdateGain()
32 : data_dumper_(
33 new ApmDataDumper(rtc::AtomicOps::Increment(&instance_count_))),
34 poor_excitation_counter_(kPoorExcitationCounterInitial) {
35 H_error_.fill(kHErrorInitial);
36 }
37
~MainFilterUpdateGain()38 MainFilterUpdateGain::~MainFilterUpdateGain() {}
39
HandleEchoPathChange()40 void MainFilterUpdateGain::HandleEchoPathChange() {
41 H_error_.fill(kHErrorInitial);
42 poor_excitation_counter_ = kPoorExcitationCounterInitial;
43 call_counter_ = 0;
44 }
45
Compute(const RenderBuffer & render_buffer,const RenderSignalAnalyzer & render_signal_analyzer,const SubtractorOutput & subtractor_output,const AdaptiveFirFilter & filter,bool saturated_capture_signal,FftData * gain_fft)46 void MainFilterUpdateGain::Compute(
47 const RenderBuffer& render_buffer,
48 const RenderSignalAnalyzer& render_signal_analyzer,
49 const SubtractorOutput& subtractor_output,
50 const AdaptiveFirFilter& filter,
51 bool saturated_capture_signal,
52 FftData* gain_fft) {
53 RTC_DCHECK(gain_fft);
54 // Introducing shorter notation to improve readability.
55 const FftData& E_main = subtractor_output.E_main;
56 const auto& E2_main = subtractor_output.E2_main;
57 const auto& E2_shadow = subtractor_output.E2_shadow;
58 FftData* G = gain_fft;
59 const size_t size_partitions = filter.SizePartitions();
60 const auto& X2 = render_buffer.SpectralSum(size_partitions);
61 const auto& erl = filter.Erl();
62
63 ++call_counter_;
64
65 if (render_signal_analyzer.PoorSignalExcitation()) {
66 poor_excitation_counter_ = 0;
67 }
68
69 // Do not update the filter if the render is not sufficiently excited.
70 if (++poor_excitation_counter_ < size_partitions ||
71 saturated_capture_signal || call_counter_ <= size_partitions) {
72 G->re.fill(0.f);
73 G->im.fill(0.f);
74 } else {
75 // Corresponds to WGN of power -39 dBFS.
76 constexpr float kNoiseGatePower = 220075344.f;
77 std::array<float, kFftLengthBy2Plus1> mu;
78 // mu = H_error / (0.5* H_error* X2 + n * E2).
79 for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
80 mu[k] = X2[k] > kNoiseGatePower
81 ? H_error_[k] / (0.5f * H_error_[k] * X2[k] +
82 size_partitions * E2_main[k])
83 : 0.f;
84 }
85
86 // Avoid updating the filter close to narrow bands in the render signals.
87 render_signal_analyzer.MaskRegionsAroundNarrowBands(&mu);
88
89 // H_error = H_error - 0.5 * mu * X2 * H_error.
90 for (size_t k = 0; k < H_error_.size(); ++k) {
91 H_error_[k] -= 0.5f * mu[k] * X2[k] * H_error_[k];
92 }
93
94 // G = mu * E.
95 std::transform(mu.begin(), mu.end(), E_main.re.begin(), G->re.begin(),
96 std::multiplies<float>());
97 std::transform(mu.begin(), mu.end(), E_main.im.begin(), G->im.begin(),
98 std::multiplies<float>());
99 }
100
101 // H_error = H_error + factor * erl.
102 std::array<float, kFftLengthBy2Plus1> H_error_increase;
103 constexpr float kErlScaleAccurate = 1.f / 100.0f;
104 constexpr float kErlScaleInaccurate = 1.f / 60.0f;
105 std::transform(E2_shadow.begin(), E2_shadow.end(), E2_main.begin(),
106 H_error_increase.begin(), [&](float a, float b) {
107 return a >= b ? kErlScaleAccurate : kErlScaleInaccurate;
108 });
109 std::transform(erl.begin(), erl.end(), H_error_increase.begin(),
110 H_error_increase.begin(), std::multiplies<float>());
111 std::transform(H_error_.begin(), H_error_.end(), H_error_increase.begin(),
112 H_error_.begin(),
113 [&](float a, float b) { return std::max(a + b, 0.1f); });
114
115 data_dumper_->DumpRaw("aec3_main_gain_H_error", H_error_);
116 }
117
118 } // namespace webrtc
119