1 /*
2  *  Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "modules/audio_processing/aec3/signal_dependent_erle_estimator.h"
12 
13 #include <algorithm>
14 #include <functional>
15 #include <numeric>
16 
17 #include "modules/audio_processing/aec3/spectrum_buffer.h"
18 #include "rtc_base/numerics/safe_minmax.h"
19 
20 namespace webrtc {
21 
22 namespace {
23 
24 constexpr std::array<size_t, SignalDependentErleEstimator::kSubbands + 1>
25     kBandBoundaries = {1, 8, 16, 24, 32, 48, kFftLengthBy2Plus1};
26 
FormSubbandMap()27 std::array<size_t, kFftLengthBy2Plus1> FormSubbandMap() {
28   std::array<size_t, kFftLengthBy2Plus1> map_band_to_subband;
29   size_t subband = 1;
30   for (size_t k = 0; k < map_band_to_subband.size(); ++k) {
31     RTC_DCHECK_LT(subband, kBandBoundaries.size());
32     if (k >= kBandBoundaries[subband]) {
33       subband++;
34       RTC_DCHECK_LT(k, kBandBoundaries[subband]);
35     }
36     map_band_to_subband[k] = subband - 1;
37   }
38   return map_band_to_subband;
39 }
40 
41 // Defines the size in blocks of the sections that are used for dividing the
42 // linear filter. The sections are split in a non-linear manner so that lower
43 // sections that typically represent the direct path have a larger resolution
44 // than the higher sections which typically represent more reverberant acoustic
45 // paths.
DefineFilterSectionSizes(size_t delay_headroom_blocks,size_t num_blocks,size_t num_sections)46 std::vector<size_t> DefineFilterSectionSizes(size_t delay_headroom_blocks,
47                                              size_t num_blocks,
48                                              size_t num_sections) {
49   size_t filter_length_blocks = num_blocks - delay_headroom_blocks;
50   std::vector<size_t> section_sizes(num_sections);
51   size_t remaining_blocks = filter_length_blocks;
52   size_t remaining_sections = num_sections;
53   size_t estimator_size = 2;
54   size_t idx = 0;
55   while (remaining_sections > 1 &&
56          remaining_blocks > estimator_size * remaining_sections) {
57     RTC_DCHECK_LT(idx, section_sizes.size());
58     section_sizes[idx] = estimator_size;
59     remaining_blocks -= estimator_size;
60     remaining_sections--;
61     estimator_size *= 2;
62     idx++;
63   }
64 
65   size_t last_groups_size = remaining_blocks / remaining_sections;
66   for (; idx < num_sections; idx++) {
67     section_sizes[idx] = last_groups_size;
68   }
69   section_sizes[num_sections - 1] +=
70       remaining_blocks - last_groups_size * remaining_sections;
71   return section_sizes;
72 }
73 
74 // Forms the limits in blocks for each filter section. Those sections
75 // are used for analyzing the echo estimates and investigating which
76 // linear filter sections contribute most to the echo estimate energy.
SetSectionsBoundaries(size_t delay_headroom_blocks,size_t num_blocks,size_t num_sections)77 std::vector<size_t> SetSectionsBoundaries(size_t delay_headroom_blocks,
78                                           size_t num_blocks,
79                                           size_t num_sections) {
80   std::vector<size_t> estimator_boundaries_blocks(num_sections + 1);
81   if (estimator_boundaries_blocks.size() == 2) {
82     estimator_boundaries_blocks[0] = 0;
83     estimator_boundaries_blocks[1] = num_blocks;
84     return estimator_boundaries_blocks;
85   }
86   RTC_DCHECK_GT(estimator_boundaries_blocks.size(), 2);
87   const std::vector<size_t> section_sizes =
88       DefineFilterSectionSizes(delay_headroom_blocks, num_blocks,
89                                estimator_boundaries_blocks.size() - 1);
90 
91   size_t idx = 0;
92   size_t current_size_block = 0;
93   RTC_DCHECK_EQ(section_sizes.size() + 1, estimator_boundaries_blocks.size());
94   estimator_boundaries_blocks[0] = delay_headroom_blocks;
95   for (size_t k = delay_headroom_blocks; k < num_blocks; ++k) {
96     current_size_block++;
97     if (current_size_block >= section_sizes[idx]) {
98       idx = idx + 1;
99       if (idx == section_sizes.size()) {
100         break;
101       }
102       estimator_boundaries_blocks[idx] = k + 1;
103       current_size_block = 0;
104     }
105   }
106   estimator_boundaries_blocks[section_sizes.size()] = num_blocks;
107   return estimator_boundaries_blocks;
108 }
109 
110 std::array<float, SignalDependentErleEstimator::kSubbands>
SetMaxErleSubbands(float max_erle_l,float max_erle_h,size_t limit_subband_l)111 SetMaxErleSubbands(float max_erle_l, float max_erle_h, size_t limit_subband_l) {
112   std::array<float, SignalDependentErleEstimator::kSubbands> max_erle;
113   std::fill(max_erle.begin(), max_erle.begin() + limit_subband_l, max_erle_l);
114   std::fill(max_erle.begin() + limit_subband_l, max_erle.end(), max_erle_h);
115   return max_erle;
116 }
117 
118 }  // namespace
119 
SignalDependentErleEstimator(const EchoCanceller3Config & config,size_t num_capture_channels)120 SignalDependentErleEstimator::SignalDependentErleEstimator(
121     const EchoCanceller3Config& config,
122     size_t num_capture_channels)
123     : min_erle_(config.erle.min),
124       num_sections_(config.erle.num_sections),
125       num_blocks_(config.filter.refined.length_blocks),
126       delay_headroom_blocks_(config.delay.delay_headroom_samples / kBlockSize),
127       band_to_subband_(FormSubbandMap()),
128       max_erle_(SetMaxErleSubbands(config.erle.max_l,
129                                    config.erle.max_h,
130                                    band_to_subband_[kFftLengthBy2 / 2])),
131       section_boundaries_blocks_(SetSectionsBoundaries(delay_headroom_blocks_,
132                                                        num_blocks_,
133                                                        num_sections_)),
134       erle_(num_capture_channels),
135       S2_section_accum_(
136           num_capture_channels,
137           std::vector<std::array<float, kFftLengthBy2Plus1>>(num_sections_)),
138       erle_estimators_(
139           num_capture_channels,
140           std::vector<std::array<float, kSubbands>>(num_sections_)),
141       erle_ref_(num_capture_channels),
142       correction_factors_(
143           num_capture_channels,
144           std::vector<std::array<float, kSubbands>>(num_sections_)),
145       num_updates_(num_capture_channels),
146       n_active_sections_(num_capture_channels) {
147   RTC_DCHECK_LE(num_sections_, num_blocks_);
148   RTC_DCHECK_GE(num_sections_, 1);
149   Reset();
150 }
151 
// Defaulted: each member is released by its own destructor.
SignalDependentErleEstimator::~SignalDependentErleEstimator() = default;
153 
Reset()154 void SignalDependentErleEstimator::Reset() {
155   for (size_t ch = 0; ch < erle_.size(); ++ch) {
156     erle_[ch].fill(min_erle_);
157     for (auto& erle_estimator : erle_estimators_[ch]) {
158       erle_estimator.fill(min_erle_);
159     }
160     erle_ref_[ch].fill(min_erle_);
161     for (auto& factor : correction_factors_[ch]) {
162       factor.fill(1.0f);
163     }
164     num_updates_[ch].fill(0);
165     n_active_sections_[ch].fill(0);
166   }
167 }
168 
169 // Updates the Erle estimate by analyzing the current input signals. It takes
170 // the render buffer and the filter frequency response in order to do an
171 // estimation of the number of sections of the linear filter that are needed
172 // for getting the majority of the energy in the echo estimate. Based on that
173 // number of sections, it updates the erle estimation by introducing a
174 // correction factor to the erle that is given as an input to this method.
Update(const RenderBuffer & render_buffer,rtc::ArrayView<const std::vector<std::array<float,kFftLengthBy2Plus1>>> filter_frequency_responses,rtc::ArrayView<const float,kFftLengthBy2Plus1> X2,rtc::ArrayView<const std::array<float,kFftLengthBy2Plus1>> Y2,rtc::ArrayView<const std::array<float,kFftLengthBy2Plus1>> E2,rtc::ArrayView<const std::array<float,kFftLengthBy2Plus1>> average_erle,const std::vector<bool> & converged_filters)175 void SignalDependentErleEstimator::Update(
176     const RenderBuffer& render_buffer,
177     rtc::ArrayView<const std::vector<std::array<float, kFftLengthBy2Plus1>>>
178         filter_frequency_responses,
179     rtc::ArrayView<const float, kFftLengthBy2Plus1> X2,
180     rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Y2,
181     rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> E2,
182     rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> average_erle,
183     const std::vector<bool>& converged_filters) {
184   RTC_DCHECK_GT(num_sections_, 1);
185 
186   // Gets the number of filter sections that are needed for achieving 90 %
187   // of the power spectrum energy of the echo estimate.
188   ComputeNumberOfActiveFilterSections(render_buffer,
189                                       filter_frequency_responses);
190 
191   // Updates the correction factors that is used for correcting the erle and
192   // adapt it to the particular characteristics of the input signal.
193   UpdateCorrectionFactors(X2, Y2, E2, converged_filters);
194 
195   // Applies the correction factor to the input erle for getting a more refined
196   // erle estimation for the current input signal.
197   for (size_t ch = 0; ch < erle_.size(); ++ch) {
198     for (size_t k = 0; k < kFftLengthBy2; ++k) {
199       RTC_DCHECK_GT(correction_factors_[ch].size(), n_active_sections_[ch][k]);
200       float correction_factor =
201           correction_factors_[ch][n_active_sections_[ch][k]]
202                              [band_to_subband_[k]];
203       erle_[ch][k] = rtc::SafeClamp(average_erle[ch][k] * correction_factor,
204                                     min_erle_, max_erle_[band_to_subband_[k]]);
205     }
206   }
207 }
208 
Dump(const std::unique_ptr<ApmDataDumper> & data_dumper) const209 void SignalDependentErleEstimator::Dump(
210     const std::unique_ptr<ApmDataDumper>& data_dumper) const {
211   for (auto& erle : erle_estimators_[0]) {
212     data_dumper->DumpRaw("aec3_all_erle", erle);
213   }
214   data_dumper->DumpRaw("aec3_ref_erle", erle_ref_[0]);
215   for (auto& factor : correction_factors_[0]) {
216     data_dumper->DumpRaw("aec3_erle_correction_factor", factor);
217   }
218 }
219 
220 // Estimates for each band the smallest number of sections in the filter that
221 // together constitute 90% of the estimated echo energy.
ComputeNumberOfActiveFilterSections(const RenderBuffer & render_buffer,rtc::ArrayView<const std::vector<std::array<float,kFftLengthBy2Plus1>>> filter_frequency_responses)222 void SignalDependentErleEstimator::ComputeNumberOfActiveFilterSections(
223     const RenderBuffer& render_buffer,
224     rtc::ArrayView<const std::vector<std::array<float, kFftLengthBy2Plus1>>>
225         filter_frequency_responses) {
226   RTC_DCHECK_GT(num_sections_, 1);
227   // Computes an approximation of the power spectrum if the filter would have
228   // been limited to a certain number of filter sections.
229   ComputeEchoEstimatePerFilterSection(render_buffer,
230                                       filter_frequency_responses);
231   // For each band, computes the number of filter sections that are needed for
232   // achieving the 90 % energy in the echo estimate.
233   ComputeActiveFilterSections();
234 }
235 
UpdateCorrectionFactors(rtc::ArrayView<const float,kFftLengthBy2Plus1> X2,rtc::ArrayView<const std::array<float,kFftLengthBy2Plus1>> Y2,rtc::ArrayView<const std::array<float,kFftLengthBy2Plus1>> E2,const std::vector<bool> & converged_filters)236 void SignalDependentErleEstimator::UpdateCorrectionFactors(
237     rtc::ArrayView<const float, kFftLengthBy2Plus1> X2,
238     rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Y2,
239     rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> E2,
240     const std::vector<bool>& converged_filters) {
241   for (size_t ch = 0; ch < converged_filters.size(); ++ch) {
242     if (converged_filters[ch]) {
243       constexpr float kX2BandEnergyThreshold = 44015068.0f;
244       constexpr float kSmthConstantDecreases = 0.1f;
245       constexpr float kSmthConstantIncreases = kSmthConstantDecreases / 2.f;
246       auto subband_powers = [](rtc::ArrayView<const float> power_spectrum,
247                                rtc::ArrayView<float> power_spectrum_subbands) {
248         for (size_t subband = 0; subband < kSubbands; ++subband) {
249           RTC_DCHECK_LE(kBandBoundaries[subband + 1], power_spectrum.size());
250           power_spectrum_subbands[subband] = std::accumulate(
251               power_spectrum.begin() + kBandBoundaries[subband],
252               power_spectrum.begin() + kBandBoundaries[subband + 1], 0.f);
253         }
254       };
255 
256       std::array<float, kSubbands> X2_subbands, E2_subbands, Y2_subbands;
257       subband_powers(X2, X2_subbands);
258       subband_powers(E2[ch], E2_subbands);
259       subband_powers(Y2[ch], Y2_subbands);
260       std::array<size_t, kSubbands> idx_subbands;
261       for (size_t subband = 0; subband < kSubbands; ++subband) {
262         // When aggregating the number of active sections in the filter for
263         // different bands we choose to take the minimum of all of them. As an
264         // example, if for one of the bands it is the direct path its refined
265         // contributor to the final echo estimate, we consider the direct path
266         // is as well the refined contributor for the subband that contains that
267         // particular band. That aggregate number of sections will be later used
268         // as the identifier of the erle estimator that needs to be updated.
269         RTC_DCHECK_LE(kBandBoundaries[subband + 1],
270                       n_active_sections_[ch].size());
271         idx_subbands[subband] = *std::min_element(
272             n_active_sections_[ch].begin() + kBandBoundaries[subband],
273             n_active_sections_[ch].begin() + kBandBoundaries[subband + 1]);
274       }
275 
276       std::array<float, kSubbands> new_erle;
277       std::array<bool, kSubbands> is_erle_updated;
278       is_erle_updated.fill(false);
279       new_erle.fill(0.f);
280       for (size_t subband = 0; subband < kSubbands; ++subband) {
281         if (X2_subbands[subband] > kX2BandEnergyThreshold &&
282             E2_subbands[subband] > 0) {
283           new_erle[subband] = Y2_subbands[subband] / E2_subbands[subband];
284           RTC_DCHECK_GT(new_erle[subband], 0);
285           is_erle_updated[subband] = true;
286           ++num_updates_[ch][subband];
287         }
288       }
289 
290       for (size_t subband = 0; subband < kSubbands; ++subband) {
291         const size_t idx = idx_subbands[subband];
292         RTC_DCHECK_LT(idx, erle_estimators_[ch].size());
293         float alpha = new_erle[subband] > erle_estimators_[ch][idx][subband]
294                           ? kSmthConstantIncreases
295                           : kSmthConstantDecreases;
296         alpha = static_cast<float>(is_erle_updated[subband]) * alpha;
297         erle_estimators_[ch][idx][subband] +=
298             alpha * (new_erle[subband] - erle_estimators_[ch][idx][subband]);
299         erle_estimators_[ch][idx][subband] = rtc::SafeClamp(
300             erle_estimators_[ch][idx][subband], min_erle_, max_erle_[subband]);
301       }
302 
303       for (size_t subband = 0; subband < kSubbands; ++subband) {
304         float alpha = new_erle[subband] > erle_ref_[ch][subband]
305                           ? kSmthConstantIncreases
306                           : kSmthConstantDecreases;
307         alpha = static_cast<float>(is_erle_updated[subband]) * alpha;
308         erle_ref_[ch][subband] +=
309             alpha * (new_erle[subband] - erle_ref_[ch][subband]);
310         erle_ref_[ch][subband] = rtc::SafeClamp(erle_ref_[ch][subband],
311                                                 min_erle_, max_erle_[subband]);
312       }
313 
314       for (size_t subband = 0; subband < kSubbands; ++subband) {
315         constexpr int kNumUpdateThr = 50;
316         if (is_erle_updated[subband] &&
317             num_updates_[ch][subband] > kNumUpdateThr) {
318           const size_t idx = idx_subbands[subband];
319           RTC_DCHECK_GT(erle_ref_[ch][subband], 0.f);
320           // Computes the ratio between the erle that is updated using all the
321           // points and the erle that is updated only on signals that share the
322           // same number of active filter sections.
323           float new_correction_factor =
324               erle_estimators_[ch][idx][subband] / erle_ref_[ch][subband];
325 
326           correction_factors_[ch][idx][subband] +=
327               0.1f *
328               (new_correction_factor - correction_factors_[ch][idx][subband]);
329         }
330       }
331     }
332   }
333 }
334 
ComputeEchoEstimatePerFilterSection(const RenderBuffer & render_buffer,rtc::ArrayView<const std::vector<std::array<float,kFftLengthBy2Plus1>>> filter_frequency_responses)335 void SignalDependentErleEstimator::ComputeEchoEstimatePerFilterSection(
336     const RenderBuffer& render_buffer,
337     rtc::ArrayView<const std::vector<std::array<float, kFftLengthBy2Plus1>>>
338         filter_frequency_responses) {
339   const SpectrumBuffer& spectrum_render_buffer =
340       render_buffer.GetSpectrumBuffer();
341   const size_t num_render_channels = spectrum_render_buffer.buffer[0].size();
342   const size_t num_capture_channels = S2_section_accum_.size();
343   const float one_by_num_render_channels = 1.f / num_render_channels;
344 
345   RTC_DCHECK_EQ(S2_section_accum_.size(), filter_frequency_responses.size());
346 
347   for (size_t capture_ch = 0; capture_ch < num_capture_channels; ++capture_ch) {
348     RTC_DCHECK_EQ(S2_section_accum_[capture_ch].size() + 1,
349                   section_boundaries_blocks_.size());
350     size_t idx_render = render_buffer.Position();
351     idx_render = spectrum_render_buffer.OffsetIndex(
352         idx_render, section_boundaries_blocks_[0]);
353 
354     for (size_t section = 0; section < num_sections_; ++section) {
355       std::array<float, kFftLengthBy2Plus1> X2_section;
356       std::array<float, kFftLengthBy2Plus1> H2_section;
357       X2_section.fill(0.f);
358       H2_section.fill(0.f);
359       const size_t block_limit =
360           std::min(section_boundaries_blocks_[section + 1],
361                    filter_frequency_responses[capture_ch].size());
362       for (size_t block = section_boundaries_blocks_[section];
363            block < block_limit; ++block) {
364         for (size_t render_ch = 0;
365              render_ch < spectrum_render_buffer.buffer[idx_render].size();
366              ++render_ch) {
367           for (size_t k = 0; k < X2_section.size(); ++k) {
368             X2_section[k] +=
369                 spectrum_render_buffer.buffer[idx_render][render_ch][k] *
370                 one_by_num_render_channels;
371           }
372         }
373         std::transform(H2_section.begin(), H2_section.end(),
374                        filter_frequency_responses[capture_ch][block].begin(),
375                        H2_section.begin(), std::plus<float>());
376         idx_render = spectrum_render_buffer.IncIndex(idx_render);
377       }
378 
379       std::transform(X2_section.begin(), X2_section.end(), H2_section.begin(),
380                      S2_section_accum_[capture_ch][section].begin(),
381                      std::multiplies<float>());
382     }
383 
384     for (size_t section = 1; section < num_sections_; ++section) {
385       std::transform(S2_section_accum_[capture_ch][section - 1].begin(),
386                      S2_section_accum_[capture_ch][section - 1].end(),
387                      S2_section_accum_[capture_ch][section].begin(),
388                      S2_section_accum_[capture_ch][section].begin(),
389                      std::plus<float>());
390     }
391   }
392 }
393 
ComputeActiveFilterSections()394 void SignalDependentErleEstimator::ComputeActiveFilterSections() {
395   for (size_t ch = 0; ch < n_active_sections_.size(); ++ch) {
396     std::fill(n_active_sections_[ch].begin(), n_active_sections_[ch].end(), 0);
397     for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
398       size_t section = num_sections_;
399       float target = 0.9f * S2_section_accum_[ch][num_sections_ - 1][k];
400       while (section > 0 && S2_section_accum_[ch][section - 1][k] >= target) {
401         n_active_sections_[ch][k] = --section;
402       }
403     }
404   }
405 }
406 }  // namespace webrtc
407