1 /*
2  *  Copyright (c) 2012 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "modules/audio_processing/noise_suppression_impl.h"
12 
13 #include "modules/audio_processing/audio_buffer.h"
14 #include "rtc_base/checks.h"
15 #include "rtc_base/constructormagic.h"
// Backend selection: exactly one of WEBRTC_NS_FLOAT / WEBRTC_NS_FIXED is
// expected to be defined by the build. The NS_* macros and the NsState typedef
// give the rest of this file a single set of names for the create / free /
// init / set-policy entry points and for the handle type of whichever backend
// is selected. If neither macro is defined, none of these names exist.
#if defined(WEBRTC_NS_FLOAT)
#include "modules/audio_processing/ns/noise_suppression.h"

// Floating-point backend (WebRtcNs_*).
#define NS_CREATE WebRtcNs_Create
#define NS_FREE WebRtcNs_Free
#define NS_INIT WebRtcNs_Init
#define NS_SET_POLICY WebRtcNs_set_policy
typedef NsHandle NsState;
#elif defined(WEBRTC_NS_FIXED)
#include "modules/audio_processing/ns/noise_suppression_x.h"

// Fixed-point backend (WebRtcNsx_*).
#define NS_CREATE WebRtcNsx_Create
#define NS_FREE WebRtcNsx_Free
#define NS_INIT WebRtcNsx_Init
#define NS_SET_POLICY WebRtcNsx_set_policy
typedef NsxHandle NsState;
#endif
33 
34 namespace webrtc {
35 class NoiseSuppressionImpl::Suppressor {
36  public:
Suppressor(int sample_rate_hz)37   explicit Suppressor(int sample_rate_hz) {
38     state_ = NS_CREATE();
39     RTC_CHECK(state_);
40     int error = NS_INIT(state_, sample_rate_hz);
41     RTC_DCHECK_EQ(0, error);
42   }
~Suppressor()43   ~Suppressor() { NS_FREE(state_); }
state()44   NsState* state() { return state_; }
45 
46  private:
47   NsState* state_ = nullptr;
48   RTC_DISALLOW_IMPLICIT_CONSTRUCTORS(Suppressor);
49 };
50 
NoiseSuppressionImpl(rtc::CriticalSection * crit)51 NoiseSuppressionImpl::NoiseSuppressionImpl(rtc::CriticalSection* crit)
52     : crit_(crit) {
53   RTC_DCHECK(crit);
54 }
55 
~NoiseSuppressionImpl()56 NoiseSuppressionImpl::~NoiseSuppressionImpl() {}
57 
Initialize(size_t channels,int sample_rate_hz)58 void NoiseSuppressionImpl::Initialize(size_t channels, int sample_rate_hz) {
59   rtc::CritScope cs(crit_);
60   channels_ = channels;
61   sample_rate_hz_ = sample_rate_hz;
62   std::vector<std::unique_ptr<Suppressor>> new_suppressors;
63   if (enabled_) {
64     new_suppressors.resize(channels);
65     for (size_t i = 0; i < channels; i++) {
66       new_suppressors[i].reset(new Suppressor(sample_rate_hz));
67     }
68   }
69   suppressors_.swap(new_suppressors);
70   set_level(level_);
71 }
72 
AnalyzeCaptureAudio(AudioBuffer * audio)73 void NoiseSuppressionImpl::AnalyzeCaptureAudio(AudioBuffer* audio) {
74   RTC_DCHECK(audio);
75 #if defined(WEBRTC_NS_FLOAT)
76   rtc::CritScope cs(crit_);
77   if (!enabled_) {
78     return;
79   }
80 
81   RTC_DCHECK_GE(160, audio->num_frames_per_band());
82   RTC_DCHECK_EQ(suppressors_.size(), audio->num_channels());
83   for (size_t i = 0; i < suppressors_.size(); i++) {
84     WebRtcNs_Analyze(suppressors_[i]->state(),
85                      audio->split_bands_const_f(i)[kBand0To8kHz]);
86   }
87 #endif
88 }
89 
ProcessCaptureAudio(AudioBuffer * audio)90 void NoiseSuppressionImpl::ProcessCaptureAudio(AudioBuffer* audio) {
91   RTC_DCHECK(audio);
92   rtc::CritScope cs(crit_);
93   if (!enabled_) {
94     return;
95   }
96 
97   RTC_DCHECK_GE(160, audio->num_frames_per_band());
98   RTC_DCHECK_EQ(suppressors_.size(), audio->num_channels());
99   for (size_t i = 0; i < suppressors_.size(); i++) {
100 #if defined(WEBRTC_NS_FLOAT)
101     WebRtcNs_Process(suppressors_[i]->state(), audio->split_bands_const_f(i),
102                      audio->num_bands(), audio->split_bands_f(i));
103 #elif defined(WEBRTC_NS_FIXED)
104     WebRtcNsx_Process(suppressors_[i]->state(), audio->split_bands_const(i),
105                       audio->num_bands(), audio->split_bands(i));
106 #endif
107   }
108 }
109 
Enable(bool enable)110 int NoiseSuppressionImpl::Enable(bool enable) {
111   rtc::CritScope cs(crit_);
112   if (enabled_ != enable) {
113     enabled_ = enable;
114     Initialize(channels_, sample_rate_hz_);
115   }
116   return AudioProcessing::kNoError;
117 }
118 
is_enabled() const119 bool NoiseSuppressionImpl::is_enabled() const {
120   rtc::CritScope cs(crit_);
121   return enabled_;
122 }
123 
set_level(Level level)124 int NoiseSuppressionImpl::set_level(Level level) {
125   int policy = 1;
126   switch (level) {
127     case NoiseSuppression::kLow:
128       policy = 0;
129       break;
130     case NoiseSuppression::kModerate:
131       policy = 1;
132       break;
133     case NoiseSuppression::kHigh:
134       policy = 2;
135       break;
136     case NoiseSuppression::kVeryHigh:
137       policy = 3;
138       break;
139     default:
140       RTC_NOTREACHED();
141   }
142   rtc::CritScope cs(crit_);
143   level_ = level;
144   for (auto& suppressor : suppressors_) {
145     int error = NS_SET_POLICY(suppressor->state(), policy);
146     RTC_DCHECK_EQ(0, error);
147   }
148   return AudioProcessing::kNoError;
149 }
150 
level() const151 NoiseSuppression::Level NoiseSuppressionImpl::level() const {
152   rtc::CritScope cs(crit_);
153   return level_;
154 }
155 
speech_probability() const156 float NoiseSuppressionImpl::speech_probability() const {
157   rtc::CritScope cs(crit_);
158 #if defined(WEBRTC_NS_FLOAT)
159   float probability_average = 0.0f;
160   for (auto& suppressor : suppressors_) {
161     probability_average +=
162         WebRtcNs_prior_speech_probability(suppressor->state());
163   }
164   if (!suppressors_.empty()) {
165     probability_average /= suppressors_.size();
166   }
167   return probability_average;
168 #elif defined(WEBRTC_NS_FIXED)
169   // TODO(peah): Returning error code as a float! Remove this.
170   // Currently not available for the fixed point implementation.
171   return AudioProcessing::kUnsupportedFunctionError;
172 #endif
173 }
174 
NoiseEstimate()175 std::vector<float> NoiseSuppressionImpl::NoiseEstimate() {
176   rtc::CritScope cs(crit_);
177   std::vector<float> noise_estimate;
178 #if defined(WEBRTC_NS_FLOAT)
179   const float kNumChannelsFraction = 1.f / suppressors_.size();
180   noise_estimate.assign(WebRtcNs_num_freq(), 0.f);
181   for (auto& suppressor : suppressors_) {
182     const float* noise = WebRtcNs_noise_estimate(suppressor->state());
183     for (size_t i = 0; i < noise_estimate.size(); ++i) {
184       noise_estimate[i] += kNumChannelsFraction * noise[i];
185     }
186   }
187 #elif defined(WEBRTC_NS_FIXED)
188   noise_estimate.assign(WebRtcNsx_num_freq(), 0.f);
189   for (auto& suppressor : suppressors_) {
190     int q_noise;
191     const uint32_t* noise =
192         WebRtcNsx_noise_estimate(suppressor->state(), &q_noise);
193     const float kNormalizationFactor =
194         1.f / ((1 << q_noise) * suppressors_.size());
195     for (size_t i = 0; i < noise_estimate.size(); ++i) {
196       noise_estimate[i] += kNormalizationFactor * noise[i];
197     }
198   }
199 #endif
200   return noise_estimate;
201 }
202 
num_noise_bins()203 size_t NoiseSuppressionImpl::num_noise_bins() {
204 #if defined(WEBRTC_NS_FLOAT)
205   return WebRtcNs_num_freq();
206 #elif defined(WEBRTC_NS_FIXED)
207   return WebRtcNsx_num_freq();
208 #endif
209 }
210 
211 }  // namespace webrtc
212