1 /*
2  *  Copyright (c) 2013 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "modules/audio_processing/transient/transient_suppressor.h"
12 
13 #include <math.h>
14 #include <string.h>
15 #include <cmath>
16 #include <complex>
17 #include <deque>
18 #include <set>
19 
20 #include "common_audio/include/audio_util.h"
21 #include "common_audio/signal_processing/include/signal_processing_library.h"
22 #include "common_audio/third_party/fft4g/fft4g.h"
23 #include "modules/audio_processing/ns/windows_private.h"
24 #include "modules/audio_processing/transient/common.h"
25 #include "modules/audio_processing/transient/transient_detector.h"
26 #include "rtc_base/checks.h"
27 #include "rtc_base/logging.h"
28 
29 namespace webrtc {
30 
// Weight given to the current block's magnitudes when the running spectral
// mean is updated; the previous mean is weighted by one minus this value.
static constexpr float kMeanIIRCoefficient = 0.5f;
// Voice probabilities below this value are treated as "not voiced" when
// choosing between hard and soft restoration.
static constexpr float kVoiceThreshold = 0.02f;

// Bounds of the FFT-bin range treated as the voice band. The block mean used
// by the soft restoration is taken over [kMinVoiceBin, kMaxVoiceBin).
// TODO(aluebs): Check if these values work also for 48kHz.
static constexpr size_t kMinVoiceBin = 3;
static constexpr size_t kMaxVoiceBin = 60;
37 
38 namespace {
39 
// Returns a cheap estimate of the magnitude of the complex number with real
// part |a| and imaginary part |b|: the sum of the absolute values of both
// parts rather than the Euclidean norm.
float ComplexMagnitude(float a, float b) {
  const float real_abs = std::fabs(a);
  const float imag_abs = std::fabs(b);
  return real_abs + imag_abs;
}
43 
44 }  // namespace
45 
TransientSuppressor()46 TransientSuppressor::TransientSuppressor()
47     : data_length_(0),
48       detection_length_(0),
49       analysis_length_(0),
50       buffer_delay_(0),
51       complex_analysis_length_(0),
52       num_channels_(0),
53       window_(NULL),
54       detector_smoothed_(0.f),
55       keypress_counter_(0),
56       chunks_since_keypress_(0),
57       detection_enabled_(false),
58       suppression_enabled_(false),
59       use_hard_restoration_(false),
60       chunks_since_voice_change_(0),
61       seed_(182),
62       using_reference_(false) {}
63 
~TransientSuppressor()64 TransientSuppressor::~TransientSuppressor() {}
65 
Initialize(int sample_rate_hz,int detection_rate_hz,int num_channels)66 int TransientSuppressor::Initialize(int sample_rate_hz,
67                                     int detection_rate_hz,
68                                     int num_channels) {
69   switch (sample_rate_hz) {
70     case ts::kSampleRate8kHz:
71       analysis_length_ = 128u;
72       window_ = kBlocks80w128;
73       break;
74     case ts::kSampleRate16kHz:
75       analysis_length_ = 256u;
76       window_ = kBlocks160w256;
77       break;
78     case ts::kSampleRate32kHz:
79       analysis_length_ = 512u;
80       window_ = kBlocks320w512;
81       break;
82     case ts::kSampleRate48kHz:
83       analysis_length_ = 1024u;
84       window_ = kBlocks480w1024;
85       break;
86     default:
87       return -1;
88   }
89   if (detection_rate_hz != ts::kSampleRate8kHz &&
90       detection_rate_hz != ts::kSampleRate16kHz &&
91       detection_rate_hz != ts::kSampleRate32kHz &&
92       detection_rate_hz != ts::kSampleRate48kHz) {
93     return -1;
94   }
95   if (num_channels <= 0) {
96     return -1;
97   }
98 
99   detector_.reset(new TransientDetector(detection_rate_hz));
100   data_length_ = sample_rate_hz * ts::kChunkSizeMs / 1000;
101   if (data_length_ > analysis_length_) {
102     RTC_NOTREACHED();
103     return -1;
104   }
105   buffer_delay_ = analysis_length_ - data_length_;
106 
107   complex_analysis_length_ = analysis_length_ / 2 + 1;
108   RTC_DCHECK_GE(complex_analysis_length_, kMaxVoiceBin);
109   num_channels_ = num_channels;
110   in_buffer_.reset(new float[analysis_length_ * num_channels_]);
111   memset(in_buffer_.get(), 0,
112          analysis_length_ * num_channels_ * sizeof(in_buffer_[0]));
113   detection_length_ = detection_rate_hz * ts::kChunkSizeMs / 1000;
114   detection_buffer_.reset(new float[detection_length_]);
115   memset(detection_buffer_.get(), 0,
116          detection_length_ * sizeof(detection_buffer_[0]));
117   out_buffer_.reset(new float[analysis_length_ * num_channels_]);
118   memset(out_buffer_.get(), 0,
119          analysis_length_ * num_channels_ * sizeof(out_buffer_[0]));
120   // ip[0] must be zero to trigger initialization using rdft().
121   size_t ip_length = 2 + sqrtf(analysis_length_);
122   ip_.reset(new size_t[ip_length]());
123   memset(ip_.get(), 0, ip_length * sizeof(ip_[0]));
124   wfft_.reset(new float[complex_analysis_length_ - 1]);
125   memset(wfft_.get(), 0, (complex_analysis_length_ - 1) * sizeof(wfft_[0]));
126   spectral_mean_.reset(new float[complex_analysis_length_ * num_channels_]);
127   memset(spectral_mean_.get(), 0,
128          complex_analysis_length_ * num_channels_ * sizeof(spectral_mean_[0]));
129   fft_buffer_.reset(new float[analysis_length_ + 2]);
130   memset(fft_buffer_.get(), 0, (analysis_length_ + 2) * sizeof(fft_buffer_[0]));
131   magnitudes_.reset(new float[complex_analysis_length_]);
132   memset(magnitudes_.get(), 0,
133          complex_analysis_length_ * sizeof(magnitudes_[0]));
134   mean_factor_.reset(new float[complex_analysis_length_]);
135 
136   static const float kFactorHeight = 10.f;
137   static const float kLowSlope = 1.f;
138   static const float kHighSlope = 0.3f;
139   for (size_t i = 0; i < complex_analysis_length_; ++i) {
140     mean_factor_[i] =
141         kFactorHeight /
142             (1.f + exp(kLowSlope * static_cast<int>(i - kMinVoiceBin))) +
143         kFactorHeight /
144             (1.f + exp(kHighSlope * static_cast<int>(kMaxVoiceBin - i)));
145   }
146   detector_smoothed_ = 0.f;
147   keypress_counter_ = 0;
148   chunks_since_keypress_ = 0;
149   detection_enabled_ = false;
150   suppression_enabled_ = false;
151   use_hard_restoration_ = false;
152   chunks_since_voice_change_ = 0;
153   seed_ = 182;
154   using_reference_ = false;
155   return 0;
156 }
157 
Suppress(float * data,size_t data_length,int num_channels,const float * detection_data,size_t detection_length,const float * reference_data,size_t reference_length,float voice_probability,bool key_pressed)158 int TransientSuppressor::Suppress(float* data,
159                                   size_t data_length,
160                                   int num_channels,
161                                   const float* detection_data,
162                                   size_t detection_length,
163                                   const float* reference_data,
164                                   size_t reference_length,
165                                   float voice_probability,
166                                   bool key_pressed) {
167   if (!data || data_length != data_length_ || num_channels != num_channels_ ||
168       detection_length != detection_length_ || voice_probability < 0 ||
169       voice_probability > 1) {
170     return -1;
171   }
172 
173   UpdateKeypress(key_pressed);
174   UpdateBuffers(data);
175 
176   int result = 0;
177   if (detection_enabled_) {
178     UpdateRestoration(voice_probability);
179 
180     if (!detection_data) {
181       // Use the input data  of the first channel if special detection data is
182       // not supplied.
183       detection_data = &in_buffer_[buffer_delay_];
184     }
185 
186     float detector_result = detector_->Detect(detection_data, detection_length,
187                                               reference_data, reference_length);
188     if (detector_result < 0) {
189       return -1;
190     }
191 
192     using_reference_ = detector_->using_reference();
193 
194     // |detector_smoothed_| follows the |detector_result| when this last one is
195     // increasing, but has an exponential decaying tail to be able to suppress
196     // the ringing of keyclicks.
197     float smooth_factor = using_reference_ ? 0.6 : 0.1;
198     detector_smoothed_ = detector_result >= detector_smoothed_
199                              ? detector_result
200                              : smooth_factor * detector_smoothed_ +
201                                    (1 - smooth_factor) * detector_result;
202 
203     for (int i = 0; i < num_channels_; ++i) {
204       Suppress(&in_buffer_[i * analysis_length_],
205                &spectral_mean_[i * complex_analysis_length_],
206                &out_buffer_[i * analysis_length_]);
207     }
208   }
209 
210   // If the suppression isn't enabled, we use the in buffer to delay the signal
211   // appropriately. This also gives time for the out buffer to be refreshed with
212   // new data between detection and suppression getting enabled.
213   for (int i = 0; i < num_channels_; ++i) {
214     memcpy(&data[i * data_length_],
215            suppression_enabled_ ? &out_buffer_[i * analysis_length_]
216                                 : &in_buffer_[i * analysis_length_],
217            data_length_ * sizeof(*data));
218   }
219   return result;
220 }
221 
222 // This should only be called when detection is enabled. UpdateBuffers() must
223 // have been called. At return, |out_buffer_| will be filled with the
224 // processed output.
Suppress(float * in_ptr,float * spectral_mean,float * out_ptr)225 void TransientSuppressor::Suppress(float* in_ptr,
226                                    float* spectral_mean,
227                                    float* out_ptr) {
228   // Go to frequency domain.
229   for (size_t i = 0; i < analysis_length_; ++i) {
230     // TODO(aluebs): Rename windows
231     fft_buffer_[i] = in_ptr[i] * window_[i];
232   }
233 
234   WebRtc_rdft(analysis_length_, 1, fft_buffer_.get(), ip_.get(), wfft_.get());
235 
236   // Since WebRtc_rdft puts R[n/2] in fft_buffer_[1], we move it to the end
237   // for convenience.
238   fft_buffer_[analysis_length_] = fft_buffer_[1];
239   fft_buffer_[analysis_length_ + 1] = 0.f;
240   fft_buffer_[1] = 0.f;
241 
242   for (size_t i = 0; i < complex_analysis_length_; ++i) {
243     magnitudes_[i] =
244         ComplexMagnitude(fft_buffer_[i * 2], fft_buffer_[i * 2 + 1]);
245   }
246   // Restore audio if necessary.
247   if (suppression_enabled_) {
248     if (use_hard_restoration_) {
249       HardRestoration(spectral_mean);
250     } else {
251       SoftRestoration(spectral_mean);
252     }
253   }
254 
255   // Update the spectral mean.
256   for (size_t i = 0; i < complex_analysis_length_; ++i) {
257     spectral_mean[i] = (1 - kMeanIIRCoefficient) * spectral_mean[i] +
258                        kMeanIIRCoefficient * magnitudes_[i];
259   }
260 
261   // Back to time domain.
262   // Put R[n/2] back in fft_buffer_[1].
263   fft_buffer_[1] = fft_buffer_[analysis_length_];
264 
265   WebRtc_rdft(analysis_length_, -1, fft_buffer_.get(), ip_.get(), wfft_.get());
266   const float fft_scaling = 2.f / analysis_length_;
267 
268   for (size_t i = 0; i < analysis_length_; ++i) {
269     out_ptr[i] += fft_buffer_[i] * window_[i] * fft_scaling;
270   }
271 }
272 
UpdateKeypress(bool key_pressed)273 void TransientSuppressor::UpdateKeypress(bool key_pressed) {
274   const int kKeypressPenalty = 1000 / ts::kChunkSizeMs;
275   const int kIsTypingThreshold = 1000 / ts::kChunkSizeMs;
276   const int kChunksUntilNotTyping = 4000 / ts::kChunkSizeMs;  // 4 seconds.
277 
278   if (key_pressed) {
279     keypress_counter_ += kKeypressPenalty;
280     chunks_since_keypress_ = 0;
281     detection_enabled_ = true;
282   }
283   keypress_counter_ = std::max(0, keypress_counter_ - 1);
284 
285   if (keypress_counter_ > kIsTypingThreshold) {
286     if (!suppression_enabled_) {
287       RTC_LOG(LS_INFO) << "[ts] Transient suppression is now enabled.";
288     }
289     suppression_enabled_ = true;
290     keypress_counter_ = 0;
291   }
292 
293   if (detection_enabled_ && ++chunks_since_keypress_ > kChunksUntilNotTyping) {
294     if (suppression_enabled_) {
295       RTC_LOG(LS_INFO) << "[ts] Transient suppression is now disabled.";
296     }
297     detection_enabled_ = false;
298     suppression_enabled_ = false;
299     keypress_counter_ = 0;
300   }
301 }
302 
UpdateRestoration(float voice_probability)303 void TransientSuppressor::UpdateRestoration(float voice_probability) {
304   const int kHardRestorationOffsetDelay = 3;
305   const int kHardRestorationOnsetDelay = 80;
306 
307   bool not_voiced = voice_probability < kVoiceThreshold;
308 
309   if (not_voiced == use_hard_restoration_) {
310     chunks_since_voice_change_ = 0;
311   } else {
312     ++chunks_since_voice_change_;
313 
314     if ((use_hard_restoration_ &&
315          chunks_since_voice_change_ > kHardRestorationOffsetDelay) ||
316         (!use_hard_restoration_ &&
317          chunks_since_voice_change_ > kHardRestorationOnsetDelay)) {
318       use_hard_restoration_ = not_voiced;
319       chunks_since_voice_change_ = 0;
320     }
321   }
322 }
323 
324 // Shift buffers to make way for new data. Must be called after
325 // |detection_enabled_| is updated by UpdateKeypress().
UpdateBuffers(float * data)326 void TransientSuppressor::UpdateBuffers(float* data) {
327   // TODO(aluebs): Change to ring buffer.
328   memmove(in_buffer_.get(), &in_buffer_[data_length_],
329           (buffer_delay_ + (num_channels_ - 1) * analysis_length_) *
330               sizeof(in_buffer_[0]));
331   // Copy new chunk to buffer.
332   for (int i = 0; i < num_channels_; ++i) {
333     memcpy(&in_buffer_[buffer_delay_ + i * analysis_length_],
334            &data[i * data_length_], data_length_ * sizeof(*data));
335   }
336   if (detection_enabled_) {
337     // Shift previous chunk in out buffer.
338     memmove(out_buffer_.get(), &out_buffer_[data_length_],
339             (buffer_delay_ + (num_channels_ - 1) * analysis_length_) *
340                 sizeof(out_buffer_[0]));
341     // Initialize new chunk in out buffer.
342     for (int i = 0; i < num_channels_; ++i) {
343       memset(&out_buffer_[buffer_delay_ + i * analysis_length_], 0,
344              data_length_ * sizeof(out_buffer_[0]));
345     }
346   }
347 }
348 
349 // Restores the unvoiced signal if a click is present.
350 // Attenuates by a certain factor every peak in the |fft_buffer_| that exceeds
351 // the spectral mean. The attenuation depends on |detector_smoothed_|.
352 // If a restoration takes place, the |magnitudes_| are updated to the new value.
HardRestoration(float * spectral_mean)353 void TransientSuppressor::HardRestoration(float* spectral_mean) {
354   const float detector_result =
355       1.f - pow(1.f - detector_smoothed_, using_reference_ ? 200.f : 50.f);
356   // To restore, we get the peaks in the spectrum. If higher than the previous
357   // spectral mean we adjust them.
358   for (size_t i = 0; i < complex_analysis_length_; ++i) {
359     if (magnitudes_[i] > spectral_mean[i] && magnitudes_[i] > 0) {
360       // RandU() generates values on [0, int16::max()]
361       const float phase = 2 * ts::kPi * WebRtcSpl_RandU(&seed_) /
362                           std::numeric_limits<int16_t>::max();
363       const float scaled_mean = detector_result * spectral_mean[i];
364 
365       fft_buffer_[i * 2] = (1 - detector_result) * fft_buffer_[i * 2] +
366                            scaled_mean * cosf(phase);
367       fft_buffer_[i * 2 + 1] = (1 - detector_result) * fft_buffer_[i * 2 + 1] +
368                                scaled_mean * sinf(phase);
369       magnitudes_[i] = magnitudes_[i] -
370                        detector_result * (magnitudes_[i] - spectral_mean[i]);
371     }
372   }
373 }
374 
375 // Restores the voiced signal if a click is present.
376 // Attenuates by a certain factor every peak in the |fft_buffer_| that exceeds
377 // the spectral mean and that is lower than some function of the current block
378 // frequency mean. The attenuation depends on |detector_smoothed_|.
379 // If a restoration takes place, the |magnitudes_| are updated to the new value.
SoftRestoration(float * spectral_mean)380 void TransientSuppressor::SoftRestoration(float* spectral_mean) {
381   // Get the spectral magnitude mean of the current block.
382   float block_frequency_mean = 0;
383   for (size_t i = kMinVoiceBin; i < kMaxVoiceBin; ++i) {
384     block_frequency_mean += magnitudes_[i];
385   }
386   block_frequency_mean /= (kMaxVoiceBin - kMinVoiceBin);
387 
388   // To restore, we get the peaks in the spectrum. If higher than the
389   // previous spectral mean and lower than a factor of the block mean
390   // we adjust them. The factor is a double sigmoid that has a minimum in the
391   // voice frequency range (300Hz - 3kHz).
392   for (size_t i = 0; i < complex_analysis_length_; ++i) {
393     if (magnitudes_[i] > spectral_mean[i] && magnitudes_[i] > 0 &&
394         (using_reference_ ||
395          magnitudes_[i] < block_frequency_mean * mean_factor_[i])) {
396       const float new_magnitude =
397           magnitudes_[i] -
398           detector_smoothed_ * (magnitudes_[i] - spectral_mean[i]);
399       const float magnitude_ratio = new_magnitude / magnitudes_[i];
400 
401       fft_buffer_[i * 2] *= magnitude_ratio;
402       fft_buffer_[i * 2 + 1] *= magnitude_ratio;
403       magnitudes_[i] = new_magnitude;
404     }
405   }
406 }
407 
408 }  // namespace webrtc
409