1 /*
2  *  Copyright (c) 2013 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "modules/audio_processing/transient/transient_suppressor_impl.h"
12 
13 #include <string.h>
14 
15 #include <algorithm>
16 #include <cmath>
17 #include <complex>
18 #include <deque>
19 #include <limits>
20 #include <set>
21 
22 #include "common_audio/include/audio_util.h"
23 #include "common_audio/signal_processing/include/signal_processing_library.h"
24 #include "common_audio/third_party/ooura/fft_size_256/fft4g.h"
25 #include "modules/audio_processing/transient/common.h"
26 #include "modules/audio_processing/transient/transient_detector.h"
27 #include "modules/audio_processing/transient/transient_suppressor.h"
28 #include "modules/audio_processing/transient/windows_private.h"
29 #include "rtc_base/checks.h"
30 #include "rtc_base/logging.h"
31 
32 namespace webrtc {
33 
// Weight given to the newest magnitude when the running spectral mean is
// updated; the previous mean keeps the remaining (1 - coefficient).
constexpr float kMeanIIRCoefficient = 0.5f;
// Voice probabilities below this value are treated as "not voiced".
constexpr float kVoiceThreshold = 0.02f;

// First and one-past-last spectral bins of the band used for the block mean.
// TODO(aluebs): Check if these values work also for 48kHz.
constexpr size_t kMinVoiceBin = 3;
constexpr size_t kMaxVoiceBin = 60;
40 
namespace {

// Returns |a| + |b| for the complex value a + b*i. Note that this is the sum
// of the absolute parts, not the Euclidean modulus sqrt(a^2 + b^2).
float ComplexMagnitude(float a, float b) {
  const float real_abs = std::fabs(a);
  const float imag_abs = std::fabs(b);
  return real_abs + imag_abs;
}

}  // namespace
48 
TransientSuppressorImpl()49 TransientSuppressorImpl::TransientSuppressorImpl()
50     : data_length_(0),
51       detection_length_(0),
52       analysis_length_(0),
53       buffer_delay_(0),
54       complex_analysis_length_(0),
55       num_channels_(0),
56       window_(NULL),
57       detector_smoothed_(0.f),
58       keypress_counter_(0),
59       chunks_since_keypress_(0),
60       detection_enabled_(false),
61       suppression_enabled_(false),
62       use_hard_restoration_(false),
63       chunks_since_voice_change_(0),
64       seed_(182),
65       using_reference_(false) {}
66 
~TransientSuppressorImpl()67 TransientSuppressorImpl::~TransientSuppressorImpl() {}
68 
Initialize(int sample_rate_hz,int detection_rate_hz,int num_channels)69 int TransientSuppressorImpl::Initialize(int sample_rate_hz,
70                                         int detection_rate_hz,
71                                         int num_channels) {
72   switch (sample_rate_hz) {
73     case ts::kSampleRate8kHz:
74       analysis_length_ = 128u;
75       window_ = kBlocks80w128;
76       break;
77     case ts::kSampleRate16kHz:
78       analysis_length_ = 256u;
79       window_ = kBlocks160w256;
80       break;
81     case ts::kSampleRate32kHz:
82       analysis_length_ = 512u;
83       window_ = kBlocks320w512;
84       break;
85     case ts::kSampleRate48kHz:
86       analysis_length_ = 1024u;
87       window_ = kBlocks480w1024;
88       break;
89     default:
90       return -1;
91   }
92   if (detection_rate_hz != ts::kSampleRate8kHz &&
93       detection_rate_hz != ts::kSampleRate16kHz &&
94       detection_rate_hz != ts::kSampleRate32kHz &&
95       detection_rate_hz != ts::kSampleRate48kHz) {
96     return -1;
97   }
98   if (num_channels <= 0) {
99     return -1;
100   }
101 
102   detector_.reset(new TransientDetector(detection_rate_hz));
103   data_length_ = sample_rate_hz * ts::kChunkSizeMs / 1000;
104   if (data_length_ > analysis_length_) {
105     RTC_NOTREACHED();
106     return -1;
107   }
108   buffer_delay_ = analysis_length_ - data_length_;
109 
110   complex_analysis_length_ = analysis_length_ / 2 + 1;
111   RTC_DCHECK_GE(complex_analysis_length_, kMaxVoiceBin);
112   num_channels_ = num_channels;
113   in_buffer_.reset(new float[analysis_length_ * num_channels_]);
114   memset(in_buffer_.get(), 0,
115          analysis_length_ * num_channels_ * sizeof(in_buffer_[0]));
116   detection_length_ = detection_rate_hz * ts::kChunkSizeMs / 1000;
117   detection_buffer_.reset(new float[detection_length_]);
118   memset(detection_buffer_.get(), 0,
119          detection_length_ * sizeof(detection_buffer_[0]));
120   out_buffer_.reset(new float[analysis_length_ * num_channels_]);
121   memset(out_buffer_.get(), 0,
122          analysis_length_ * num_channels_ * sizeof(out_buffer_[0]));
123   // ip[0] must be zero to trigger initialization using rdft().
124   size_t ip_length = 2 + sqrtf(analysis_length_);
125   ip_.reset(new size_t[ip_length]());
126   memset(ip_.get(), 0, ip_length * sizeof(ip_[0]));
127   wfft_.reset(new float[complex_analysis_length_ - 1]);
128   memset(wfft_.get(), 0, (complex_analysis_length_ - 1) * sizeof(wfft_[0]));
129   spectral_mean_.reset(new float[complex_analysis_length_ * num_channels_]);
130   memset(spectral_mean_.get(), 0,
131          complex_analysis_length_ * num_channels_ * sizeof(spectral_mean_[0]));
132   fft_buffer_.reset(new float[analysis_length_ + 2]);
133   memset(fft_buffer_.get(), 0, (analysis_length_ + 2) * sizeof(fft_buffer_[0]));
134   magnitudes_.reset(new float[complex_analysis_length_]);
135   memset(magnitudes_.get(), 0,
136          complex_analysis_length_ * sizeof(magnitudes_[0]));
137   mean_factor_.reset(new float[complex_analysis_length_]);
138 
139   static const float kFactorHeight = 10.f;
140   static const float kLowSlope = 1.f;
141   static const float kHighSlope = 0.3f;
142   for (size_t i = 0; i < complex_analysis_length_; ++i) {
143     mean_factor_[i] =
144         kFactorHeight /
145             (1.f + std::exp(kLowSlope * static_cast<int>(i - kMinVoiceBin))) +
146         kFactorHeight /
147             (1.f + std::exp(kHighSlope * static_cast<int>(kMaxVoiceBin - i)));
148   }
149   detector_smoothed_ = 0.f;
150   keypress_counter_ = 0;
151   chunks_since_keypress_ = 0;
152   detection_enabled_ = false;
153   suppression_enabled_ = false;
154   use_hard_restoration_ = false;
155   chunks_since_voice_change_ = 0;
156   seed_ = 182;
157   using_reference_ = false;
158   return 0;
159 }
160 
Suppress(float * data,size_t data_length,int num_channels,const float * detection_data,size_t detection_length,const float * reference_data,size_t reference_length,float voice_probability,bool key_pressed)161 int TransientSuppressorImpl::Suppress(float* data,
162                                       size_t data_length,
163                                       int num_channels,
164                                       const float* detection_data,
165                                       size_t detection_length,
166                                       const float* reference_data,
167                                       size_t reference_length,
168                                       float voice_probability,
169                                       bool key_pressed) {
170   if (!data || data_length != data_length_ || num_channels != num_channels_ ||
171       detection_length != detection_length_ || voice_probability < 0 ||
172       voice_probability > 1) {
173     return -1;
174   }
175 
176   UpdateKeypress(key_pressed);
177   UpdateBuffers(data);
178 
179   int result = 0;
180   if (detection_enabled_) {
181     UpdateRestoration(voice_probability);
182 
183     if (!detection_data) {
184       // Use the input data  of the first channel if special detection data is
185       // not supplied.
186       detection_data = &in_buffer_[buffer_delay_];
187     }
188 
189     float detector_result = detector_->Detect(detection_data, detection_length,
190                                               reference_data, reference_length);
191     if (detector_result < 0) {
192       return -1;
193     }
194 
195     using_reference_ = detector_->using_reference();
196 
197     // |detector_smoothed_| follows the |detector_result| when this last one is
198     // increasing, but has an exponential decaying tail to be able to suppress
199     // the ringing of keyclicks.
200     float smooth_factor = using_reference_ ? 0.6 : 0.1;
201     detector_smoothed_ = detector_result >= detector_smoothed_
202                              ? detector_result
203                              : smooth_factor * detector_smoothed_ +
204                                    (1 - smooth_factor) * detector_result;
205 
206     for (int i = 0; i < num_channels_; ++i) {
207       Suppress(&in_buffer_[i * analysis_length_],
208                &spectral_mean_[i * complex_analysis_length_],
209                &out_buffer_[i * analysis_length_]);
210     }
211   }
212 
213   // If the suppression isn't enabled, we use the in buffer to delay the signal
214   // appropriately. This also gives time for the out buffer to be refreshed with
215   // new data between detection and suppression getting enabled.
216   for (int i = 0; i < num_channels_; ++i) {
217     memcpy(&data[i * data_length_],
218            suppression_enabled_ ? &out_buffer_[i * analysis_length_]
219                                 : &in_buffer_[i * analysis_length_],
220            data_length_ * sizeof(*data));
221   }
222   return result;
223 }
224 
225 // This should only be called when detection is enabled. UpdateBuffers() must
226 // have been called. At return, |out_buffer_| will be filled with the
227 // processed output.
Suppress(float * in_ptr,float * spectral_mean,float * out_ptr)228 void TransientSuppressorImpl::Suppress(float* in_ptr,
229                                        float* spectral_mean,
230                                        float* out_ptr) {
231   // Go to frequency domain.
232   for (size_t i = 0; i < analysis_length_; ++i) {
233     // TODO(aluebs): Rename windows
234     fft_buffer_[i] = in_ptr[i] * window_[i];
235   }
236 
237   WebRtc_rdft(analysis_length_, 1, fft_buffer_.get(), ip_.get(), wfft_.get());
238 
239   // Since WebRtc_rdft puts R[n/2] in fft_buffer_[1], we move it to the end
240   // for convenience.
241   fft_buffer_[analysis_length_] = fft_buffer_[1];
242   fft_buffer_[analysis_length_ + 1] = 0.f;
243   fft_buffer_[1] = 0.f;
244 
245   for (size_t i = 0; i < complex_analysis_length_; ++i) {
246     magnitudes_[i] =
247         ComplexMagnitude(fft_buffer_[i * 2], fft_buffer_[i * 2 + 1]);
248   }
249   // Restore audio if necessary.
250   if (suppression_enabled_) {
251     if (use_hard_restoration_) {
252       HardRestoration(spectral_mean);
253     } else {
254       SoftRestoration(spectral_mean);
255     }
256   }
257 
258   // Update the spectral mean.
259   for (size_t i = 0; i < complex_analysis_length_; ++i) {
260     spectral_mean[i] = (1 - kMeanIIRCoefficient) * spectral_mean[i] +
261                        kMeanIIRCoefficient * magnitudes_[i];
262   }
263 
264   // Back to time domain.
265   // Put R[n/2] back in fft_buffer_[1].
266   fft_buffer_[1] = fft_buffer_[analysis_length_];
267 
268   WebRtc_rdft(analysis_length_, -1, fft_buffer_.get(), ip_.get(), wfft_.get());
269   const float fft_scaling = 2.f / analysis_length_;
270 
271   for (size_t i = 0; i < analysis_length_; ++i) {
272     out_ptr[i] += fft_buffer_[i] * window_[i] * fft_scaling;
273   }
274 }
275 
UpdateKeypress(bool key_pressed)276 void TransientSuppressorImpl::UpdateKeypress(bool key_pressed) {
277   const int kKeypressPenalty = 1000 / ts::kChunkSizeMs;
278   const int kIsTypingThreshold = 1000 / ts::kChunkSizeMs;
279   const int kChunksUntilNotTyping = 4000 / ts::kChunkSizeMs;  // 4 seconds.
280 
281   if (key_pressed) {
282     keypress_counter_ += kKeypressPenalty;
283     chunks_since_keypress_ = 0;
284     detection_enabled_ = true;
285   }
286   keypress_counter_ = std::max(0, keypress_counter_ - 1);
287 
288   if (keypress_counter_ > kIsTypingThreshold) {
289     if (!suppression_enabled_) {
290       RTC_LOG(LS_INFO) << "[ts] Transient suppression is now enabled.";
291     }
292     suppression_enabled_ = true;
293     keypress_counter_ = 0;
294   }
295 
296   if (detection_enabled_ && ++chunks_since_keypress_ > kChunksUntilNotTyping) {
297     if (suppression_enabled_) {
298       RTC_LOG(LS_INFO) << "[ts] Transient suppression is now disabled.";
299     }
300     detection_enabled_ = false;
301     suppression_enabled_ = false;
302     keypress_counter_ = 0;
303   }
304 }
305 
UpdateRestoration(float voice_probability)306 void TransientSuppressorImpl::UpdateRestoration(float voice_probability) {
307   const int kHardRestorationOffsetDelay = 3;
308   const int kHardRestorationOnsetDelay = 80;
309 
310   bool not_voiced = voice_probability < kVoiceThreshold;
311 
312   if (not_voiced == use_hard_restoration_) {
313     chunks_since_voice_change_ = 0;
314   } else {
315     ++chunks_since_voice_change_;
316 
317     if ((use_hard_restoration_ &&
318          chunks_since_voice_change_ > kHardRestorationOffsetDelay) ||
319         (!use_hard_restoration_ &&
320          chunks_since_voice_change_ > kHardRestorationOnsetDelay)) {
321       use_hard_restoration_ = not_voiced;
322       chunks_since_voice_change_ = 0;
323     }
324   }
325 }
326 
327 // Shift buffers to make way for new data. Must be called after
328 // |detection_enabled_| is updated by UpdateKeypress().
UpdateBuffers(float * data)329 void TransientSuppressorImpl::UpdateBuffers(float* data) {
330   // TODO(aluebs): Change to ring buffer.
331   memmove(in_buffer_.get(), &in_buffer_[data_length_],
332           (buffer_delay_ + (num_channels_ - 1) * analysis_length_) *
333               sizeof(in_buffer_[0]));
334   // Copy new chunk to buffer.
335   for (int i = 0; i < num_channels_; ++i) {
336     memcpy(&in_buffer_[buffer_delay_ + i * analysis_length_],
337            &data[i * data_length_], data_length_ * sizeof(*data));
338   }
339   if (detection_enabled_) {
340     // Shift previous chunk in out buffer.
341     memmove(out_buffer_.get(), &out_buffer_[data_length_],
342             (buffer_delay_ + (num_channels_ - 1) * analysis_length_) *
343                 sizeof(out_buffer_[0]));
344     // Initialize new chunk in out buffer.
345     for (int i = 0; i < num_channels_; ++i) {
346       memset(&out_buffer_[buffer_delay_ + i * analysis_length_], 0,
347              data_length_ * sizeof(out_buffer_[0]));
348     }
349   }
350 }
351 
352 // Restores the unvoiced signal if a click is present.
353 // Attenuates by a certain factor every peak in the |fft_buffer_| that exceeds
354 // the spectral mean. The attenuation depends on |detector_smoothed_|.
355 // If a restoration takes place, the |magnitudes_| are updated to the new value.
HardRestoration(float * spectral_mean)356 void TransientSuppressorImpl::HardRestoration(float* spectral_mean) {
357   const float detector_result =
358       1.f - std::pow(1.f - detector_smoothed_, using_reference_ ? 200.f : 50.f);
359   // To restore, we get the peaks in the spectrum. If higher than the previous
360   // spectral mean we adjust them.
361   for (size_t i = 0; i < complex_analysis_length_; ++i) {
362     if (magnitudes_[i] > spectral_mean[i] && magnitudes_[i] > 0) {
363       // RandU() generates values on [0, int16::max()]
364       const float phase = 2 * ts::kPi * WebRtcSpl_RandU(&seed_) /
365                           std::numeric_limits<int16_t>::max();
366       const float scaled_mean = detector_result * spectral_mean[i];
367 
368       fft_buffer_[i * 2] = (1 - detector_result) * fft_buffer_[i * 2] +
369                            scaled_mean * cosf(phase);
370       fft_buffer_[i * 2 + 1] = (1 - detector_result) * fft_buffer_[i * 2 + 1] +
371                                scaled_mean * sinf(phase);
372       magnitudes_[i] = magnitudes_[i] -
373                        detector_result * (magnitudes_[i] - spectral_mean[i]);
374     }
375   }
376 }
377 
378 // Restores the voiced signal if a click is present.
379 // Attenuates by a certain factor every peak in the |fft_buffer_| that exceeds
380 // the spectral mean and that is lower than some function of the current block
381 // frequency mean. The attenuation depends on |detector_smoothed_|.
382 // If a restoration takes place, the |magnitudes_| are updated to the new value.
SoftRestoration(float * spectral_mean)383 void TransientSuppressorImpl::SoftRestoration(float* spectral_mean) {
384   // Get the spectral magnitude mean of the current block.
385   float block_frequency_mean = 0;
386   for (size_t i = kMinVoiceBin; i < kMaxVoiceBin; ++i) {
387     block_frequency_mean += magnitudes_[i];
388   }
389   block_frequency_mean /= (kMaxVoiceBin - kMinVoiceBin);
390 
391   // To restore, we get the peaks in the spectrum. If higher than the
392   // previous spectral mean and lower than a factor of the block mean
393   // we adjust them. The factor is a double sigmoid that has a minimum in the
394   // voice frequency range (300Hz - 3kHz).
395   for (size_t i = 0; i < complex_analysis_length_; ++i) {
396     if (magnitudes_[i] > spectral_mean[i] && magnitudes_[i] > 0 &&
397         (using_reference_ ||
398          magnitudes_[i] < block_frequency_mean * mean_factor_[i])) {
399       const float new_magnitude =
400           magnitudes_[i] -
401           detector_smoothed_ * (magnitudes_[i] - spectral_mean[i]);
402       const float magnitude_ratio = new_magnitude / magnitudes_[i];
403 
404       fft_buffer_[i * 2] *= magnitude_ratio;
405       fft_buffer_[i * 2 + 1] *= magnitude_ratio;
406       magnitudes_[i] = new_magnitude;
407     }
408   }
409 }
410 
411 }  // namespace webrtc
412