1 /*
2  *  Copyright (c) 2019 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "modules/audio_processing/ns/prior_signal_model_estimator.h"
12 
13 #include <math.h>
14 #include <algorithm>
15 
16 #include "modules/audio_processing/ns/fast_math.h"
17 #include "rtc_base/checks.h"
18 
19 namespace webrtc {
20 
21 namespace {
22 
23 // Identifies the first of the two largest peaks in the histogram.
FindFirstOfTwoLargestPeaks(float bin_size,rtc::ArrayView<const int,kHistogramSize> spectral_flatness,float * peak_position,int * peak_weight)24 void FindFirstOfTwoLargestPeaks(
25     float bin_size,
26     rtc::ArrayView<const int, kHistogramSize> spectral_flatness,
27     float* peak_position,
28     int* peak_weight) {
29   RTC_DCHECK(peak_position);
30   RTC_DCHECK(peak_weight);
31 
32   int peak_value = 0;
33   int secondary_peak_value = 0;
34   *peak_position = 0.f;
35   float secondary_peak_position = 0.f;
36   *peak_weight = 0;
37   int secondary_peak_weight = 0;
38 
39   // Identify the two largest peaks.
40   for (int i = 0; i < kHistogramSize; ++i) {
41     const float bin_mid = (i + 0.5f) * bin_size;
42     if (spectral_flatness[i] > peak_value) {
43       // Found new "first" peak candidate.
44       secondary_peak_value = peak_value;
45       secondary_peak_weight = *peak_weight;
46       secondary_peak_position = *peak_position;
47 
48       peak_value = spectral_flatness[i];
49       *peak_weight = spectral_flatness[i];
50       *peak_position = bin_mid;
51     } else if (spectral_flatness[i] > secondary_peak_value) {
52       // Found new "second" peak candidate.
53       secondary_peak_value = spectral_flatness[i];
54       secondary_peak_weight = spectral_flatness[i];
55       secondary_peak_position = bin_mid;
56     }
57   }
58 
59   // Merge the peaks if they are close.
60   if ((fabs(secondary_peak_position - *peak_position) < 2 * bin_size) &&
61       (secondary_peak_weight > 0.5f * (*peak_weight))) {
62     *peak_weight += secondary_peak_weight;
63     *peak_position = 0.5f * (*peak_position + secondary_peak_position);
64   }
65 }
66 
UpdateLrt(rtc::ArrayView<const int,kHistogramSize> lrt_histogram,float * prior_model_lrt,bool * low_lrt_fluctuations)67 void UpdateLrt(rtc::ArrayView<const int, kHistogramSize> lrt_histogram,
68                float* prior_model_lrt,
69                bool* low_lrt_fluctuations) {
70   RTC_DCHECK(prior_model_lrt);
71   RTC_DCHECK(low_lrt_fluctuations);
72 
73   float average = 0.f;
74   float average_compl = 0.f;
75   float average_squared = 0.f;
76   int count = 0;
77 
78   for (int i = 0; i < 10; ++i) {
79     float bin_mid = (i + 0.5f) * kBinSizeLrt;
80     average += lrt_histogram[i] * bin_mid;
81     count += lrt_histogram[i];
82   }
83   if (count > 0) {
84     average = average / count;
85   }
86 
87   for (int i = 0; i < kHistogramSize; ++i) {
88     float bin_mid = (i + 0.5f) * kBinSizeLrt;
89     average_squared += lrt_histogram[i] * bin_mid * bin_mid;
90     average_compl += lrt_histogram[i] * bin_mid;
91   }
92   constexpr float kOneFeatureUpdateWindowSize = 1.f / kFeatureUpdateWindowSize;
93   average_squared = average_squared * kOneFeatureUpdateWindowSize;
94   average_compl = average_compl * kOneFeatureUpdateWindowSize;
95 
96   // Fluctuation limit of LRT feature.
97   *low_lrt_fluctuations = average_squared - average * average_compl < 0.05f;
98 
99   // Get threshold for LRT feature.
100   constexpr float kMaxLrt = 1.f;
101   constexpr float kMinLrt = .2f;
102   if (*low_lrt_fluctuations) {
103     // Very low fluctuation, so likely noise.
104     *prior_model_lrt = kMaxLrt;
105   } else {
106     *prior_model_lrt = std::min(kMaxLrt, std::max(kMinLrt, 1.2f * average));
107   }
108 }
109 
110 }  // namespace
111 
PriorSignalModelEstimator(float lrt_initial_value)112 PriorSignalModelEstimator::PriorSignalModelEstimator(float lrt_initial_value)
113     : prior_model_(lrt_initial_value) {}
114 
// Extracts the feature thresholds from the histograms and computes the
// corresponding feature weightings.
Update(const Histograms & histograms)116 void PriorSignalModelEstimator::Update(const Histograms& histograms) {
117   bool low_lrt_fluctuations;
118   UpdateLrt(histograms.get_lrt(), &prior_model_.lrt, &low_lrt_fluctuations);
119 
120   // For spectral flatness and spectral difference: compute the main peaks of
121   // the histograms.
122   float spectral_flatness_peak_position;
123   int spectral_flatness_peak_weight;
124   FindFirstOfTwoLargestPeaks(
125       kBinSizeSpecFlat, histograms.get_spectral_flatness(),
126       &spectral_flatness_peak_position, &spectral_flatness_peak_weight);
127 
128   float spectral_diff_peak_position = 0.f;
129   int spectral_diff_peak_weight = 0;
130   FindFirstOfTwoLargestPeaks(kBinSizeSpecDiff, histograms.get_spectral_diff(),
131                              &spectral_diff_peak_position,
132                              &spectral_diff_peak_weight);
133 
134   // Reject if weight of peaks is not large enough, or peak value too small.
135   // Peak limit for spectral flatness (varies between 0 and 1).
136   const int use_spec_flat = spectral_flatness_peak_weight < 0.3f * 500 ||
137                                     spectral_flatness_peak_position < 0.6f
138                                 ? 0
139                                 : 1;
140 
141   // Reject if weight of peaks is not large enough or if fluctuation of the LRT
142   // feature are very low, indicating a noise state.
143   const int use_spec_diff =
144       spectral_diff_peak_weight < 0.3f * 500 || low_lrt_fluctuations ? 0 : 1;
145 
146   // Update the model.
147   prior_model_.template_diff_threshold = 1.2f * spectral_diff_peak_position;
148   prior_model_.template_diff_threshold =
149       std::min(1.f, std::max(0.16f, prior_model_.template_diff_threshold));
150 
151   float one_by_feature_sum = 1.f / (1.f + use_spec_flat + use_spec_diff);
152   prior_model_.lrt_weighting = one_by_feature_sum;
153 
154   if (use_spec_flat == 1) {
155     prior_model_.flatness_threshold = 0.9f * spectral_flatness_peak_position;
156     prior_model_.flatness_threshold =
157         std::min(.95f, std::max(0.1f, prior_model_.flatness_threshold));
158     prior_model_.flatness_weighting = one_by_feature_sum;
159   } else {
160     prior_model_.flatness_weighting = 0.f;
161   }
162 
163   if (use_spec_diff == 1) {
164     prior_model_.difference_weighting = one_by_feature_sum;
165   } else {
166     prior_model_.difference_weighting = 0.f;
167   }
168 }
169 
170 }  // namespace webrtc
171