1 /*
2  *  Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "modules/audio_processing/aec3/main_filter_update_gain.h"
12 
13 #include <algorithm>
14 #include <functional>
15 
16 #include "modules/audio_processing/aec3/aec3_common.h"
17 #include "modules/audio_processing/logging/apm_data_dumper.h"
18 #include "rtc_base/atomicops.h"
19 #include "rtc_base/checks.h"
20 
21 namespace webrtc {
namespace {

// Value every bin of H_error_ is set to at construction and whenever an echo
// path change is handled.
constexpr float kHErrorInitial = 10000.0f;

// Value the poor excitation counter is set to at construction and whenever an
// echo path change is handled.
constexpr int kPoorExcitationCounterInitial = 1000;

}  // namespace
28 
29 int MainFilterUpdateGain::instance_count_ = 0;
30 
MainFilterUpdateGain()31 MainFilterUpdateGain::MainFilterUpdateGain()
32     : data_dumper_(
33           new ApmDataDumper(rtc::AtomicOps::Increment(&instance_count_))),
34       poor_excitation_counter_(kPoorExcitationCounterInitial) {
35   H_error_.fill(kHErrorInitial);
36 }
37 
~MainFilterUpdateGain()38 MainFilterUpdateGain::~MainFilterUpdateGain() {}
39 
HandleEchoPathChange()40 void MainFilterUpdateGain::HandleEchoPathChange() {
41   H_error_.fill(kHErrorInitial);
42   poor_excitation_counter_ = kPoorExcitationCounterInitial;
43   call_counter_ = 0;
44 }
45 
Compute(const RenderBuffer & render_buffer,const RenderSignalAnalyzer & render_signal_analyzer,const SubtractorOutput & subtractor_output,const AdaptiveFirFilter & filter,bool saturated_capture_signal,FftData * gain_fft)46 void MainFilterUpdateGain::Compute(
47     const RenderBuffer& render_buffer,
48     const RenderSignalAnalyzer& render_signal_analyzer,
49     const SubtractorOutput& subtractor_output,
50     const AdaptiveFirFilter& filter,
51     bool saturated_capture_signal,
52     FftData* gain_fft) {
53   RTC_DCHECK(gain_fft);
54   // Introducing shorter notation to improve readability.
55   const FftData& E_main = subtractor_output.E_main;
56   const auto& E2_main = subtractor_output.E2_main;
57   const auto& E2_shadow = subtractor_output.E2_shadow;
58   FftData* G = gain_fft;
59   const size_t size_partitions = filter.SizePartitions();
60   const auto& X2 = render_buffer.SpectralSum(size_partitions);
61   const auto& erl = filter.Erl();
62 
63   ++call_counter_;
64 
65   if (render_signal_analyzer.PoorSignalExcitation()) {
66     poor_excitation_counter_ = 0;
67   }
68 
69   // Do not update the filter if the render is not sufficiently excited.
70   if (++poor_excitation_counter_ < size_partitions ||
71       saturated_capture_signal || call_counter_ <= size_partitions) {
72     G->re.fill(0.f);
73     G->im.fill(0.f);
74   } else {
75     // Corresponds to WGN of power -39 dBFS.
76     constexpr float kNoiseGatePower = 220075344.f;
77     std::array<float, kFftLengthBy2Plus1> mu;
78     // mu = H_error / (0.5* H_error* X2 + n * E2).
79     for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
80       mu[k] = X2[k] > kNoiseGatePower
81                   ? H_error_[k] / (0.5f * H_error_[k] * X2[k] +
82                                    size_partitions * E2_main[k])
83                   : 0.f;
84     }
85 
86     // Avoid updating the filter close to narrow bands in the render signals.
87     render_signal_analyzer.MaskRegionsAroundNarrowBands(&mu);
88 
89     // H_error = H_error - 0.5 * mu * X2 * H_error.
90     for (size_t k = 0; k < H_error_.size(); ++k) {
91       H_error_[k] -= 0.5f * mu[k] * X2[k] * H_error_[k];
92     }
93 
94     // G = mu * E.
95     std::transform(mu.begin(), mu.end(), E_main.re.begin(), G->re.begin(),
96                    std::multiplies<float>());
97     std::transform(mu.begin(), mu.end(), E_main.im.begin(), G->im.begin(),
98                    std::multiplies<float>());
99   }
100 
101   // H_error = H_error + factor * erl.
102   std::array<float, kFftLengthBy2Plus1> H_error_increase;
103   constexpr float kErlScaleAccurate = 1.f / 100.0f;
104   constexpr float kErlScaleInaccurate = 1.f / 60.0f;
105   std::transform(E2_shadow.begin(), E2_shadow.end(), E2_main.begin(),
106                  H_error_increase.begin(), [&](float a, float b) {
107                    return a >= b ? kErlScaleAccurate : kErlScaleInaccurate;
108                  });
109   std::transform(erl.begin(), erl.end(), H_error_increase.begin(),
110                  H_error_increase.begin(), std::multiplies<float>());
111   std::transform(H_error_.begin(), H_error_.end(), H_error_increase.begin(),
112                  H_error_.begin(),
113                  [&](float a, float b) { return std::max(a + b, 0.1f); });
114 
115   data_dumper_->DumpRaw("aec3_main_gain_H_error", H_error_);
116 }
117 
118 }  // namespace webrtc
119