1 /*
2  *  Copyright (c) 2013 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "webrtc/modules/audio_processing/transient/transient_suppressor.h"
12 
13 #include <math.h>
14 #include <string.h>
15 #include <cmath>
16 #include <complex>
17 #include <deque>
18 #include <set>
19 
20 #include "webrtc/base/scoped_ptr.h"
21 #include "webrtc/common_audio/fft4g.h"
22 #include "webrtc/common_audio/include/audio_util.h"
23 #include "webrtc/common_audio/signal_processing/include/signal_processing_library.h"
24 #include "webrtc/modules/audio_processing/transient/common.h"
25 #include "webrtc/modules/audio_processing/transient/transient_detector.h"
26 #include "webrtc/modules/audio_processing/ns/windows_private.h"
27 #include "webrtc/system_wrappers/interface/logging.h"
28 #include "webrtc/typedefs.h"
29 
30 namespace webrtc {
31 
// Weight of the newest block's magnitudes when the running spectral mean is
// updated; the previous mean keeps the remaining (1 - coefficient).
static const float kMeanIIRCoefficient = 0.5f;

// Voice probabilities strictly below this value are treated as "not voiced".
static const float kVoiceThreshold = 0.02f;

// First (inclusive) and last (exclusive) spectral bins averaged to obtain the
// block frequency mean; they also position the two sigmoids of the mean
// factor.
// TODO(aluebs): Check if these values work also for 48kHz.
static const size_t kMinVoiceBin = 3;
static const size_t kMaxVoiceBin = 60;
38 
39 namespace {
40 
// Cheap magnitude estimate for the complex value a + bi: the sum of the
// absolute values of the real and imaginary parts (no square root involved).
float ComplexMagnitude(float a, float b) {
  const float real_abs = std::abs(a);
  const float imag_abs = std::abs(b);
  return real_abs + imag_abs;
}
44 
45 }  // namespace
46 
TransientSuppressor()47 TransientSuppressor::TransientSuppressor()
48     : data_length_(0),
49       detection_length_(0),
50       analysis_length_(0),
51       buffer_delay_(0),
52       complex_analysis_length_(0),
53       num_channels_(0),
54       window_(NULL),
55       detector_smoothed_(0.f),
56       keypress_counter_(0),
57       chunks_since_keypress_(0),
58       detection_enabled_(false),
59       suppression_enabled_(false),
60       use_hard_restoration_(false),
61       chunks_since_voice_change_(0),
62       seed_(182),
63       using_reference_(false) {
64 }
65 
~TransientSuppressor()66 TransientSuppressor::~TransientSuppressor() {}
67 
Initialize(int sample_rate_hz,int detection_rate_hz,int num_channels)68 int TransientSuppressor::Initialize(int sample_rate_hz,
69                                     int detection_rate_hz,
70                                     int num_channels) {
71   switch (sample_rate_hz) {
72     case ts::kSampleRate8kHz:
73       analysis_length_ = 128u;
74       window_ = kBlocks80w128;
75       break;
76     case ts::kSampleRate16kHz:
77       analysis_length_ = 256u;
78       window_ = kBlocks160w256;
79       break;
80     case ts::kSampleRate32kHz:
81       analysis_length_ = 512u;
82       window_ = kBlocks320w512;
83       break;
84     case ts::kSampleRate48kHz:
85       analysis_length_ = 1024u;
86       window_ = kBlocks480w1024;
87       break;
88     default:
89       return -1;
90   }
91   if (detection_rate_hz != ts::kSampleRate8kHz &&
92       detection_rate_hz != ts::kSampleRate16kHz &&
93       detection_rate_hz != ts::kSampleRate32kHz &&
94       detection_rate_hz != ts::kSampleRate48kHz) {
95     return -1;
96   }
97   if (num_channels <= 0) {
98     return -1;
99   }
100 
101   detector_.reset(new TransientDetector(detection_rate_hz));
102   data_length_ = sample_rate_hz * ts::kChunkSizeMs / 1000;
103   if (data_length_ > analysis_length_) {
104     assert(false);
105     return -1;
106   }
107   buffer_delay_ = analysis_length_ - data_length_;
108 
109   complex_analysis_length_ = analysis_length_ / 2 + 1;
110   assert(complex_analysis_length_ >= kMaxVoiceBin);
111   num_channels_ = num_channels;
112   in_buffer_.reset(new float[analysis_length_ * num_channels_]);
113   memset(in_buffer_.get(),
114          0,
115          analysis_length_ * num_channels_ * sizeof(in_buffer_[0]));
116   detection_length_ = detection_rate_hz * ts::kChunkSizeMs / 1000;
117   detection_buffer_.reset(new float[detection_length_]);
118   memset(detection_buffer_.get(),
119          0,
120          detection_length_ * sizeof(detection_buffer_[0]));
121   out_buffer_.reset(new float[analysis_length_ * num_channels_]);
122   memset(out_buffer_.get(),
123          0,
124          analysis_length_ * num_channels_ * sizeof(out_buffer_[0]));
125   // ip[0] must be zero to trigger initialization using rdft().
126   size_t ip_length = 2 + sqrtf(analysis_length_);
127   ip_.reset(new int[ip_length]());
128   memset(ip_.get(), 0, ip_length * sizeof(ip_[0]));
129   wfft_.reset(new float[complex_analysis_length_ - 1]);
130   memset(wfft_.get(), 0, (complex_analysis_length_ - 1) * sizeof(wfft_[0]));
131   spectral_mean_.reset(new float[complex_analysis_length_ * num_channels_]);
132   memset(spectral_mean_.get(),
133          0,
134          complex_analysis_length_ * num_channels_ * sizeof(spectral_mean_[0]));
135   fft_buffer_.reset(new float[analysis_length_ + 2]);
136   memset(fft_buffer_.get(), 0, (analysis_length_ + 2) * sizeof(fft_buffer_[0]));
137   magnitudes_.reset(new float[complex_analysis_length_]);
138   memset(magnitudes_.get(),
139          0,
140          complex_analysis_length_ * sizeof(magnitudes_[0]));
141   mean_factor_.reset(new float[complex_analysis_length_]);
142 
143   static const float kFactorHeight = 10.f;
144   static const float kLowSlope = 1.f;
145   static const float kHighSlope = 0.3f;
146   for (size_t i = 0; i < complex_analysis_length_; ++i) {
147     mean_factor_[i] =
148         kFactorHeight /
149             (1.f + exp(kLowSlope * static_cast<int>(i - kMinVoiceBin))) +
150         kFactorHeight /
151             (1.f + exp(kHighSlope * static_cast<int>(kMaxVoiceBin - i)));
152   }
153   detector_smoothed_ = 0.f;
154   keypress_counter_ = 0;
155   chunks_since_keypress_ = 0;
156   detection_enabled_ = false;
157   suppression_enabled_ = false;
158   use_hard_restoration_ = false;
159   chunks_since_voice_change_ = 0;
160   seed_ = 182;
161   using_reference_ = false;
162   return 0;
163 }
164 
Suppress(float * data,size_t data_length,int num_channels,const float * detection_data,size_t detection_length,const float * reference_data,size_t reference_length,float voice_probability,bool key_pressed)165 int TransientSuppressor::Suppress(float* data,
166                                   size_t data_length,
167                                   int num_channels,
168                                   const float* detection_data,
169                                   size_t detection_length,
170                                   const float* reference_data,
171                                   size_t reference_length,
172                                   float voice_probability,
173                                   bool key_pressed) {
174   if (!data || data_length != data_length_ || num_channels != num_channels_ ||
175       detection_length != detection_length_ || voice_probability < 0 ||
176       voice_probability > 1) {
177     return -1;
178   }
179 
180   UpdateKeypress(key_pressed);
181   UpdateBuffers(data);
182 
183   int result = 0;
184   if (detection_enabled_) {
185     UpdateRestoration(voice_probability);
186 
187     if (!detection_data) {
188       // Use the input data  of the first channel if special detection data is
189       // not supplied.
190       detection_data = &in_buffer_[buffer_delay_];
191     }
192 
193     float detector_result = detector_->Detect(
194         detection_data, detection_length, reference_data, reference_length);
195     if (detector_result < 0) {
196       return -1;
197     }
198 
199     using_reference_ = detector_->using_reference();
200 
201     // |detector_smoothed_| follows the |detector_result| when this last one is
202     // increasing, but has an exponential decaying tail to be able to suppress
203     // the ringing of keyclicks.
204     float smooth_factor = using_reference_ ? 0.6 : 0.1;
205     detector_smoothed_ = detector_result >= detector_smoothed_
206                              ? detector_result
207                              : smooth_factor * detector_smoothed_ +
208                                    (1 - smooth_factor) * detector_result;
209 
210     for (int i = 0; i < num_channels_; ++i) {
211       Suppress(&in_buffer_[i * analysis_length_],
212                &spectral_mean_[i * complex_analysis_length_],
213                &out_buffer_[i * analysis_length_]);
214     }
215   }
216 
217   // If the suppression isn't enabled, we use the in buffer to delay the signal
218   // appropriately. This also gives time for the out buffer to be refreshed with
219   // new data between detection and suppression getting enabled.
220   for (int i = 0; i < num_channels_; ++i) {
221     memcpy(&data[i * data_length_],
222            suppression_enabled_ ? &out_buffer_[i * analysis_length_]
223                                 : &in_buffer_[i * analysis_length_],
224            data_length_ * sizeof(*data));
225   }
226   return result;
227 }
228 
229 // This should only be called when detection is enabled. UpdateBuffers() must
230 // have been called. At return, |out_buffer_| will be filled with the
231 // processed output.
Suppress(float * in_ptr,float * spectral_mean,float * out_ptr)232 void TransientSuppressor::Suppress(float* in_ptr,
233                                    float* spectral_mean,
234                                    float* out_ptr) {
235   // Go to frequency domain.
236   for (size_t i = 0; i < analysis_length_; ++i) {
237     // TODO(aluebs): Rename windows
238     fft_buffer_[i] = in_ptr[i] * window_[i];
239   }
240 
241   WebRtc_rdft(analysis_length_, 1, fft_buffer_.get(), ip_.get(), wfft_.get());
242 
243   // Since WebRtc_rdft puts R[n/2] in fft_buffer_[1], we move it to the end
244   // for convenience.
245   fft_buffer_[analysis_length_] = fft_buffer_[1];
246   fft_buffer_[analysis_length_ + 1] = 0.f;
247   fft_buffer_[1] = 0.f;
248 
249   for (size_t i = 0; i < complex_analysis_length_; ++i) {
250     magnitudes_[i] = ComplexMagnitude(fft_buffer_[i * 2],
251                                       fft_buffer_[i * 2 + 1]);
252   }
253   // Restore audio if necessary.
254   if (suppression_enabled_) {
255     if (use_hard_restoration_) {
256       HardRestoration(spectral_mean);
257     } else {
258       SoftRestoration(spectral_mean);
259     }
260   }
261 
262   // Update the spectral mean.
263   for (size_t i = 0; i < complex_analysis_length_; ++i) {
264     spectral_mean[i] = (1 - kMeanIIRCoefficient) * spectral_mean[i] +
265                        kMeanIIRCoefficient * magnitudes_[i];
266   }
267 
268   // Back to time domain.
269   // Put R[n/2] back in fft_buffer_[1].
270   fft_buffer_[1] = fft_buffer_[analysis_length_];
271 
272   WebRtc_rdft(analysis_length_,
273               -1,
274               fft_buffer_.get(),
275               ip_.get(),
276               wfft_.get());
277   const float fft_scaling = 2.f / analysis_length_;
278 
279   for (size_t i = 0; i < analysis_length_; ++i) {
280     out_ptr[i] += fft_buffer_[i] * window_[i] * fft_scaling;
281   }
282 }
283 
UpdateKeypress(bool key_pressed)284 void TransientSuppressor::UpdateKeypress(bool key_pressed) {
285   const int kKeypressPenalty = 1000 / ts::kChunkSizeMs;
286   const int kIsTypingThreshold = 1000 / ts::kChunkSizeMs;
287   const int kChunksUntilNotTyping = 4000 / ts::kChunkSizeMs;  // 4 seconds.
288 
289   if (key_pressed) {
290     keypress_counter_ += kKeypressPenalty;
291     chunks_since_keypress_ = 0;
292     detection_enabled_ = true;
293   }
294   keypress_counter_ = std::max(0, keypress_counter_ - 1);
295 
296   if (keypress_counter_ > kIsTypingThreshold) {
297     if (!suppression_enabled_) {
298       LOG(LS_INFO) << "[ts] Transient suppression is now enabled.";
299     }
300     suppression_enabled_ = true;
301     keypress_counter_ = 0;
302   }
303 
304   if (detection_enabled_ &&
305       ++chunks_since_keypress_ > kChunksUntilNotTyping) {
306     if (suppression_enabled_) {
307       LOG(LS_INFO) << "[ts] Transient suppression is now disabled.";
308     }
309     detection_enabled_ = false;
310     suppression_enabled_ = false;
311     keypress_counter_ = 0;
312   }
313 }
314 
UpdateRestoration(float voice_probability)315 void TransientSuppressor::UpdateRestoration(float voice_probability) {
316   const int kHardRestorationOffsetDelay = 3;
317   const int kHardRestorationOnsetDelay = 80;
318 
319   bool not_voiced = voice_probability < kVoiceThreshold;
320 
321   if (not_voiced == use_hard_restoration_) {
322     chunks_since_voice_change_ = 0;
323   } else {
324     ++chunks_since_voice_change_;
325 
326     if ((use_hard_restoration_ &&
327          chunks_since_voice_change_ > kHardRestorationOffsetDelay) ||
328         (!use_hard_restoration_ &&
329          chunks_since_voice_change_ > kHardRestorationOnsetDelay)) {
330       use_hard_restoration_ = not_voiced;
331       chunks_since_voice_change_ = 0;
332     }
333   }
334 }
335 
336 // Shift buffers to make way for new data. Must be called after
337 // |detection_enabled_| is updated by UpdateKeypress().
UpdateBuffers(float * data)338 void TransientSuppressor::UpdateBuffers(float* data) {
339   // TODO(aluebs): Change to ring buffer.
340   memmove(in_buffer_.get(),
341           &in_buffer_[data_length_],
342           (buffer_delay_ + (num_channels_ - 1) * analysis_length_) *
343               sizeof(in_buffer_[0]));
344   // Copy new chunk to buffer.
345   for (int i = 0; i < num_channels_; ++i) {
346     memcpy(&in_buffer_[buffer_delay_ + i * analysis_length_],
347            &data[i * data_length_],
348            data_length_ * sizeof(*data));
349   }
350   if (detection_enabled_) {
351     // Shift previous chunk in out buffer.
352     memmove(out_buffer_.get(),
353             &out_buffer_[data_length_],
354             (buffer_delay_ + (num_channels_ - 1) * analysis_length_) *
355                 sizeof(out_buffer_[0]));
356     // Initialize new chunk in out buffer.
357     for (int i = 0; i < num_channels_; ++i) {
358       memset(&out_buffer_[buffer_delay_ + i * analysis_length_],
359              0,
360              data_length_ * sizeof(out_buffer_[0]));
361     }
362   }
363 }
364 
365 // Restores the unvoiced signal if a click is present.
366 // Attenuates by a certain factor every peak in the |fft_buffer_| that exceeds
367 // the spectral mean. The attenuation depends on |detector_smoothed_|.
368 // If a restoration takes place, the |magnitudes_| are updated to the new value.
HardRestoration(float * spectral_mean)369 void TransientSuppressor::HardRestoration(float* spectral_mean) {
370   const float detector_result =
371       1.f - pow(1.f - detector_smoothed_, using_reference_ ? 200.f : 50.f);
372   // To restore, we get the peaks in the spectrum. If higher than the previous
373   // spectral mean we adjust them.
374   for (size_t i = 0; i < complex_analysis_length_; ++i) {
375     if (magnitudes_[i] > spectral_mean[i] && magnitudes_[i] > 0) {
376       // RandU() generates values on [0, int16::max()]
377       const float phase = 2 * ts::kPi * WebRtcSpl_RandU(&seed_) /
378           std::numeric_limits<int16_t>::max();
379       const float scaled_mean = detector_result * spectral_mean[i];
380 
381       fft_buffer_[i * 2] = (1 - detector_result) * fft_buffer_[i * 2] +
382                            scaled_mean * cosf(phase);
383       fft_buffer_[i * 2 + 1] = (1 - detector_result) * fft_buffer_[i * 2 + 1] +
384                                scaled_mean * sinf(phase);
385       magnitudes_[i] = magnitudes_[i] -
386                        detector_result * (magnitudes_[i] - spectral_mean[i]);
387     }
388   }
389 }
390 
391 // Restores the voiced signal if a click is present.
392 // Attenuates by a certain factor every peak in the |fft_buffer_| that exceeds
393 // the spectral mean and that is lower than some function of the current block
394 // frequency mean. The attenuation depends on |detector_smoothed_|.
395 // If a restoration takes place, the |magnitudes_| are updated to the new value.
SoftRestoration(float * spectral_mean)396 void TransientSuppressor::SoftRestoration(float* spectral_mean) {
397   // Get the spectral magnitude mean of the current block.
398   float block_frequency_mean = 0;
399   for (size_t i = kMinVoiceBin; i < kMaxVoiceBin; ++i) {
400     block_frequency_mean += magnitudes_[i];
401   }
402   block_frequency_mean /= (kMaxVoiceBin - kMinVoiceBin);
403 
404   // To restore, we get the peaks in the spectrum. If higher than the
405   // previous spectral mean and lower than a factor of the block mean
406   // we adjust them. The factor is a double sigmoid that has a minimum in the
407   // voice frequency range (300Hz - 3kHz).
408   for (size_t i = 0; i < complex_analysis_length_; ++i) {
409     if (magnitudes_[i] > spectral_mean[i] && magnitudes_[i] > 0 &&
410         (using_reference_ ||
411          magnitudes_[i] < block_frequency_mean * mean_factor_[i])) {
412       const float new_magnitude =
413           magnitudes_[i] -
414           detector_smoothed_ * (magnitudes_[i] - spectral_mean[i]);
415       const float magnitude_ratio = new_magnitude / magnitudes_[i];
416 
417       fft_buffer_[i * 2] *= magnitude_ratio;
418       fft_buffer_[i * 2 + 1] *= magnitude_ratio;
419       magnitudes_[i] = new_magnitude;
420     }
421   }
422 }
423 
424 }  // namespace webrtc
425