1 /*
2  *  Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "modules/audio_processing/aec3/suppression_filter.h"
12 
13 #include <math.h>
14 #include <algorithm>
15 #include <cstring>
16 #include <functional>
17 #include <numeric>
18 
19 #include "modules/audio_processing/utility/ooura_fft.h"
20 #include "rtc_base/numerics/safe_minmax.h"
21 
22 namespace webrtc {
23 namespace {
24 
// Square-root Hanning window, from the Matlab command
// win = sqrt(hanning(128)).
// The 128 listed values match sin(pi * k / 128) for k = 0..127: the table
// starts at 0, reaches 1 at index 64 and is symmetric around that index.
// SuppressionFilter::ApplyGain uses the first (rising) half and the second
// (falling) half of the table as the two halves of both its analysis and its
// synthesis window.
const float kSqrtHanning[kFftLength] = {
    0.00000000000000f, 0.02454122852291f, 0.04906767432742f, 0.07356456359967f,
    0.09801714032956f, 0.12241067519922f, 0.14673047445536f, 0.17096188876030f,
    0.19509032201613f, 0.21910124015687f, 0.24298017990326f, 0.26671275747490f,
    0.29028467725446f, 0.31368174039889f, 0.33688985339222f, 0.35989503653499f,
    0.38268343236509f, 0.40524131400499f, 0.42755509343028f, 0.44961132965461f,
    0.47139673682600f, 0.49289819222978f, 0.51410274419322f, 0.53499761988710f,
    0.55557023301960f, 0.57580819141785f, 0.59569930449243f, 0.61523159058063f,
    0.63439328416365f, 0.65317284295378f, 0.67155895484702f, 0.68954054473707f,
    0.70710678118655f, 0.72424708295147f, 0.74095112535496f, 0.75720884650648f,
    0.77301045336274f, 0.78834642762661f, 0.80320753148064f, 0.81758481315158f,
    0.83146961230255f, 0.84485356524971f, 0.85772861000027f, 0.87008699110871f,
    0.88192126434835f, 0.89322430119552f, 0.90398929312344f, 0.91420975570353f,
    0.92387953251129f, 0.93299279883474f, 0.94154406518302f, 0.94952818059304f,
    0.95694033573221f, 0.96377606579544f, 0.97003125319454f, 0.97570213003853f,
    0.98078528040323f, 0.98527764238894f, 0.98917650996478f, 0.99247953459871f,
    0.99518472667220f, 0.99729045667869f, 0.99879545620517f, 0.99969881869620f,
    1.00000000000000f, 0.99969881869620f, 0.99879545620517f, 0.99729045667869f,
    0.99518472667220f, 0.99247953459871f, 0.98917650996478f, 0.98527764238894f,
    0.98078528040323f, 0.97570213003853f, 0.97003125319454f, 0.96377606579544f,
    0.95694033573221f, 0.94952818059304f, 0.94154406518302f, 0.93299279883474f,
    0.92387953251129f, 0.91420975570353f, 0.90398929312344f, 0.89322430119552f,
    0.88192126434835f, 0.87008699110871f, 0.85772861000027f, 0.84485356524971f,
    0.83146961230255f, 0.81758481315158f, 0.80320753148064f, 0.78834642762661f,
    0.77301045336274f, 0.75720884650648f, 0.74095112535496f, 0.72424708295147f,
    0.70710678118655f, 0.68954054473707f, 0.67155895484702f, 0.65317284295378f,
    0.63439328416365f, 0.61523159058063f, 0.59569930449243f, 0.57580819141785f,
    0.55557023301960f, 0.53499761988710f, 0.51410274419322f, 0.49289819222978f,
    0.47139673682600f, 0.44961132965461f, 0.42755509343028f, 0.40524131400499f,
    0.38268343236509f, 0.35989503653499f, 0.33688985339222f, 0.31368174039889f,
    0.29028467725446f, 0.26671275747490f, 0.24298017990326f, 0.21910124015687f,
    0.19509032201613f, 0.17096188876030f, 0.14673047445536f, 0.12241067519922f,
    0.09801714032956f, 0.07356456359967f, 0.04906767432742f, 0.02454122852291f};
59 
60 }  // namespace
61 
SuppressionFilter(int sample_rate_hz)62 SuppressionFilter::SuppressionFilter(int sample_rate_hz)
63     : sample_rate_hz_(sample_rate_hz),
64       fft_(),
65       e_output_old_(NumBandsForRate(sample_rate_hz_)) {
66   RTC_DCHECK(ValidFullBandRate(sample_rate_hz_));
67   e_input_old_.fill(0.f);
68   std::for_each(e_output_old_.begin(), e_output_old_.end(),
69                 [](std::array<float, kFftLengthBy2>& a) { a.fill(0.f); });
70 }
71 
// Defaulted destructor, defined out of line in this file; there is no custom
// teardown beyond destroying the members.
SuppressionFilter::~SuppressionFilter() = default;
73 
ApplyGain(const FftData & comfort_noise,const FftData & comfort_noise_high_band,const std::array<float,kFftLengthBy2Plus1> & suppression_gain,float high_bands_gain,std::vector<std::vector<float>> * e)74 void SuppressionFilter::ApplyGain(
75     const FftData& comfort_noise,
76     const FftData& comfort_noise_high_band,
77     const std::array<float, kFftLengthBy2Plus1>& suppression_gain,
78     float high_bands_gain,
79     std::vector<std::vector<float>>* e) {
80   RTC_DCHECK(e);
81   RTC_DCHECK_EQ(e->size(), NumBandsForRate(sample_rate_hz_));
82   FftData E;
83   std::array<float, kFftLength> e_extended;
84   constexpr float kIfftNormalization = 2.f / kFftLength;
85 
86   // Analysis filterbank.
87   std::transform(e_input_old_.begin(), e_input_old_.end(),
88                  std::begin(kSqrtHanning), e_extended.begin(),
89                  std::multiplies<float>());
90   std::transform((*e)[0].begin(), (*e)[0].end(),
91                  std::begin(kSqrtHanning) + kFftLengthBy2,
92                  e_extended.begin() + kFftLengthBy2, std::multiplies<float>());
93   std::copy((*e)[0].begin(), (*e)[0].end(), e_input_old_.begin());
94   fft_.Fft(&e_extended, &E);
95 
96   // Apply gain.
97   std::transform(suppression_gain.begin(), suppression_gain.end(), E.re.begin(),
98                  E.re.begin(), std::multiplies<float>());
99   std::transform(suppression_gain.begin(), suppression_gain.end(), E.im.begin(),
100                  E.im.begin(), std::multiplies<float>());
101 
102   // Compute and add the comfort noise.
103   std::array<float, kFftLengthBy2Plus1> scaled_comfort_noise;
104   std::transform(suppression_gain.begin(), suppression_gain.end(),
105                  comfort_noise.re.begin(), scaled_comfort_noise.begin(),
106                  [](float a, float b) { return std::max(1.f - a, 0.f) * b; });
107   std::transform(scaled_comfort_noise.begin(), scaled_comfort_noise.end(),
108                  E.re.begin(), E.re.begin(), std::plus<float>());
109   std::transform(suppression_gain.begin(), suppression_gain.end(),
110                  comfort_noise.im.begin(), scaled_comfort_noise.begin(),
111                  [](float a, float b) { return std::max(1.f - a, 0.f) * b; });
112   std::transform(scaled_comfort_noise.begin(), scaled_comfort_noise.end(),
113                  E.im.begin(), E.im.begin(), std::plus<float>());
114 
115   // Synthesis filterbank.
116   fft_.Ifft(E, &e_extended);
117   std::transform(e_output_old_[0].begin(), e_output_old_[0].end(),
118                  std::begin(kSqrtHanning) + kFftLengthBy2, (*e)[0].begin(),
119                  [&](float a, float b) { return kIfftNormalization * a * b; });
120   std::transform(e_extended.begin(), e_extended.begin() + kFftLengthBy2,
121                  std::begin(kSqrtHanning), e_extended.begin(),
122                  [&](float a, float b) { return kIfftNormalization * a * b; });
123   std::transform((*e)[0].begin(), (*e)[0].end(), e_extended.begin(),
124                  (*e)[0].begin(), std::plus<float>());
125   std::for_each((*e)[0].begin(), (*e)[0].end(), [](float& x_k) {
126     x_k = rtc::SafeClamp(x_k, -32768.f, 32767.f);
127   });
128   std::copy(e_extended.begin() + kFftLengthBy2, e_extended.begin() + kFftLength,
129             std::begin(e_output_old_[0]));
130 
131   if (e->size() > 1) {
132     // Form time-domain high-band noise.
133     std::array<float, kFftLength> time_domain_high_band_noise;
134     std::transform(comfort_noise_high_band.re.begin(),
135                    comfort_noise_high_band.re.end(), E.re.begin(),
136                    [&](float a) { return kIfftNormalization * a; });
137     std::transform(comfort_noise_high_band.im.begin(),
138                    comfort_noise_high_band.im.end(), E.im.begin(),
139                    [&](float a) { return kIfftNormalization * a; });
140     fft_.Ifft(E, &time_domain_high_band_noise);
141 
142     // Scale and apply the noise to the signals.
143     const float high_bands_noise_scaling =
144         0.4f * std::max(1.f - high_bands_gain, 0.f);
145 
146     std::transform(
147         (*e)[1].begin(), (*e)[1].end(), time_domain_high_band_noise.begin(),
148         (*e)[1].begin(), [&](float a, float b) {
149           return std::max(
150               std::min(b * high_bands_noise_scaling + high_bands_gain * a,
151                        32767.0f),
152               -32768.0f);
153         });
154 
155     if (e->size() > 2) {
156       RTC_DCHECK_EQ(3, e->size());
157       std::for_each((*e)[2].begin(), (*e)[2].end(), [&](float& a) {
158         a = rtc::SafeClamp(a * high_bands_gain, -32768.f, 32767.f);
159       });
160     }
161 
162     std::array<float, kFftLengthBy2> tmp;
163     for (size_t k = 1; k < e->size(); ++k) {
164       std::copy((*e)[k].begin(), (*e)[k].end(), tmp.begin());
165       std::copy(e_output_old_[k].begin(), e_output_old_[k].end(),
166                 (*e)[k].begin());
167       std::copy(tmp.begin(), tmp.end(), e_output_old_[k].begin());
168     }
169   }
170 }
171 
172 }  // namespace webrtc
173