1 /*
2 * Copyright (c) 2012 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11 #include "modules/audio_processing/noise_suppression_impl.h"
12
13 #include "modules/audio_processing/audio_buffer.h"
14 #include "rtc_base/checks.h"
15 #include "rtc_base/constructormagic.h"
// Build-time selection of the noise suppression backend. The NS_* macros and
// the NsState alias let the code below call whichever C API was chosen without
// repeating the #if at every call site. No fallback is provided when neither
// WEBRTC_NS_FLOAT nor WEBRTC_NS_FIXED is defined.
#if defined(WEBRTC_NS_FLOAT)
// Floating-point backend (WebRtcNs_*).
#include "modules/audio_processing/ns/noise_suppression.h"

#define NS_CREATE WebRtcNs_Create
#define NS_FREE WebRtcNs_Free
#define NS_INIT WebRtcNs_Init
#define NS_SET_POLICY WebRtcNs_set_policy
using NsState = NsHandle;
#elif defined(WEBRTC_NS_FIXED)
// Fixed-point backend (WebRtcNsx_*).
#include "modules/audio_processing/ns/noise_suppression_x.h"

#define NS_CREATE WebRtcNsx_Create
#define NS_FREE WebRtcNsx_Free
#define NS_INIT WebRtcNsx_Init
#define NS_SET_POLICY WebRtcNsx_set_policy
using NsState = NsxHandle;
#endif
33
34 namespace webrtc {
35 class NoiseSuppressionImpl::Suppressor {
36 public:
Suppressor(int sample_rate_hz)37 explicit Suppressor(int sample_rate_hz) {
38 state_ = NS_CREATE();
39 RTC_CHECK(state_);
40 int error = NS_INIT(state_, sample_rate_hz);
41 RTC_DCHECK_EQ(0, error);
42 }
~Suppressor()43 ~Suppressor() { NS_FREE(state_); }
state()44 NsState* state() { return state_; }
45
46 private:
47 NsState* state_ = nullptr;
48 RTC_DISALLOW_IMPLICIT_CONSTRUCTORS(Suppressor);
49 };
50
NoiseSuppressionImpl(rtc::CriticalSection * crit)51 NoiseSuppressionImpl::NoiseSuppressionImpl(rtc::CriticalSection* crit)
52 : crit_(crit) {
53 RTC_DCHECK(crit);
54 }
55
~NoiseSuppressionImpl()56 NoiseSuppressionImpl::~NoiseSuppressionImpl() {}
57
Initialize(size_t channels,int sample_rate_hz)58 void NoiseSuppressionImpl::Initialize(size_t channels, int sample_rate_hz) {
59 rtc::CritScope cs(crit_);
60 channels_ = channels;
61 sample_rate_hz_ = sample_rate_hz;
62 std::vector<std::unique_ptr<Suppressor>> new_suppressors;
63 if (enabled_) {
64 new_suppressors.resize(channels);
65 for (size_t i = 0; i < channels; i++) {
66 new_suppressors[i].reset(new Suppressor(sample_rate_hz));
67 }
68 }
69 suppressors_.swap(new_suppressors);
70 set_level(level_);
71 }
72
AnalyzeCaptureAudio(AudioBuffer * audio)73 void NoiseSuppressionImpl::AnalyzeCaptureAudio(AudioBuffer* audio) {
74 RTC_DCHECK(audio);
75 #if defined(WEBRTC_NS_FLOAT)
76 rtc::CritScope cs(crit_);
77 if (!enabled_) {
78 return;
79 }
80
81 RTC_DCHECK_GE(160, audio->num_frames_per_band());
82 RTC_DCHECK_EQ(suppressors_.size(), audio->num_channels());
83 for (size_t i = 0; i < suppressors_.size(); i++) {
84 WebRtcNs_Analyze(suppressors_[i]->state(),
85 audio->split_bands_const_f(i)[kBand0To8kHz]);
86 }
87 #endif
88 }
89
ProcessCaptureAudio(AudioBuffer * audio)90 void NoiseSuppressionImpl::ProcessCaptureAudio(AudioBuffer* audio) {
91 RTC_DCHECK(audio);
92 rtc::CritScope cs(crit_);
93 if (!enabled_) {
94 return;
95 }
96
97 RTC_DCHECK_GE(160, audio->num_frames_per_band());
98 RTC_DCHECK_EQ(suppressors_.size(), audio->num_channels());
99 for (size_t i = 0; i < suppressors_.size(); i++) {
100 #if defined(WEBRTC_NS_FLOAT)
101 WebRtcNs_Process(suppressors_[i]->state(), audio->split_bands_const_f(i),
102 audio->num_bands(), audio->split_bands_f(i));
103 #elif defined(WEBRTC_NS_FIXED)
104 WebRtcNsx_Process(suppressors_[i]->state(), audio->split_bands_const(i),
105 audio->num_bands(), audio->split_bands(i));
106 #endif
107 }
108 }
109
Enable(bool enable)110 int NoiseSuppressionImpl::Enable(bool enable) {
111 rtc::CritScope cs(crit_);
112 if (enabled_ != enable) {
113 enabled_ = enable;
114 Initialize(channels_, sample_rate_hz_);
115 }
116 return AudioProcessing::kNoError;
117 }
118
is_enabled() const119 bool NoiseSuppressionImpl::is_enabled() const {
120 rtc::CritScope cs(crit_);
121 return enabled_;
122 }
123
set_level(Level level)124 int NoiseSuppressionImpl::set_level(Level level) {
125 int policy = 1;
126 switch (level) {
127 case NoiseSuppression::kLow:
128 policy = 0;
129 break;
130 case NoiseSuppression::kModerate:
131 policy = 1;
132 break;
133 case NoiseSuppression::kHigh:
134 policy = 2;
135 break;
136 case NoiseSuppression::kVeryHigh:
137 policy = 3;
138 break;
139 default:
140 RTC_NOTREACHED();
141 }
142 rtc::CritScope cs(crit_);
143 level_ = level;
144 for (auto& suppressor : suppressors_) {
145 int error = NS_SET_POLICY(suppressor->state(), policy);
146 RTC_DCHECK_EQ(0, error);
147 }
148 return AudioProcessing::kNoError;
149 }
150
level() const151 NoiseSuppression::Level NoiseSuppressionImpl::level() const {
152 rtc::CritScope cs(crit_);
153 return level_;
154 }
155
speech_probability() const156 float NoiseSuppressionImpl::speech_probability() const {
157 rtc::CritScope cs(crit_);
158 #if defined(WEBRTC_NS_FLOAT)
159 float probability_average = 0.0f;
160 for (auto& suppressor : suppressors_) {
161 probability_average +=
162 WebRtcNs_prior_speech_probability(suppressor->state());
163 }
164 if (!suppressors_.empty()) {
165 probability_average /= suppressors_.size();
166 }
167 return probability_average;
168 #elif defined(WEBRTC_NS_FIXED)
169 // TODO(peah): Returning error code as a float! Remove this.
170 // Currently not available for the fixed point implementation.
171 return AudioProcessing::kUnsupportedFunctionError;
172 #endif
173 }
174
NoiseEstimate()175 std::vector<float> NoiseSuppressionImpl::NoiseEstimate() {
176 rtc::CritScope cs(crit_);
177 std::vector<float> noise_estimate;
178 #if defined(WEBRTC_NS_FLOAT)
179 const float kNumChannelsFraction = 1.f / suppressors_.size();
180 noise_estimate.assign(WebRtcNs_num_freq(), 0.f);
181 for (auto& suppressor : suppressors_) {
182 const float* noise = WebRtcNs_noise_estimate(suppressor->state());
183 for (size_t i = 0; i < noise_estimate.size(); ++i) {
184 noise_estimate[i] += kNumChannelsFraction * noise[i];
185 }
186 }
187 #elif defined(WEBRTC_NS_FIXED)
188 noise_estimate.assign(WebRtcNsx_num_freq(), 0.f);
189 for (auto& suppressor : suppressors_) {
190 int q_noise;
191 const uint32_t* noise =
192 WebRtcNsx_noise_estimate(suppressor->state(), &q_noise);
193 const float kNormalizationFactor =
194 1.f / ((1 << q_noise) * suppressors_.size());
195 for (size_t i = 0; i < noise_estimate.size(); ++i) {
196 noise_estimate[i] += kNormalizationFactor * noise[i];
197 }
198 }
199 #endif
200 return noise_estimate;
201 }
202
num_noise_bins()203 size_t NoiseSuppressionImpl::num_noise_bins() {
204 #if defined(WEBRTC_NS_FLOAT)
205 return WebRtcNs_num_freq();
206 #elif defined(WEBRTC_NS_FIXED)
207 return WebRtcNsx_num_freq();
208 #endif
209 }
210
211 } // namespace webrtc
212