1 /*
2  *  Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
12 
#include <array>
#include <string>
#include <tuple>
#include <vector>
16 
17 #include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
18 #include "rtc_base/strings/string_builder.h"
19 // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
20 // #include "test/fpe_observer.h"
21 #include "test/gtest.h"
22 
23 namespace webrtc {
24 namespace rnn_vad {
25 namespace {
26 
27 constexpr int kTestPitchPeriodsLow = 3 * kMinPitch48kHz / 2;
28 constexpr int kTestPitchPeriodsHigh = (3 * kMinPitch48kHz + kMaxPitch48kHz) / 2;
29 
30 constexpr float kTestPitchStrengthLow = 0.35f;
31 constexpr float kTestPitchStrengthHigh = 0.75f;
32 
33 template <class T>
PrintTestIndexAndCpuFeatures(const::testing::TestParamInfo<T> & info)34 std::string PrintTestIndexAndCpuFeatures(
35     const ::testing::TestParamInfo<T>& info) {
36   rtc::StringBuilder builder;
37   builder << info.index << "_" << info.param.cpu_features.ToString();
38   return builder.str();
39 }
40 
41 // Finds the relevant CPU features combinations to test.
GetCpuFeaturesToTest()42 std::vector<AvailableCpuFeatures> GetCpuFeaturesToTest() {
43   std::vector<AvailableCpuFeatures> v;
44   v.push_back(NoAvailableCpuFeatures());
45   AvailableCpuFeatures available = GetAvailableCpuFeatures();
46   if (available.avx2) {
47     v.push_back({/*sse2=*/false, /*avx2=*/true, /*neon=*/false});
48   }
49   if (available.sse2) {
50     v.push_back({/*sse2=*/true, /*avx2=*/false, /*neon=*/false});
51   }
52   return v;
53 }
54 
55 // Checks that the frame-wise sliding square energy function produces output
56 // within tolerance given test input data.
TEST(RnnVadTest,ComputeSlidingFrameSquareEnergies24kHzWithinTolerance)57 TEST(RnnVadTest, ComputeSlidingFrameSquareEnergies24kHzWithinTolerance) {
58   const AvailableCpuFeatures cpu_features = GetAvailableCpuFeatures();
59 
60   PitchTestData test_data;
61   std::array<float, kRefineNumLags24kHz> computed_output;
62   // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
63   // FloatingPointExceptionObserver fpe_observer;
64   ComputeSlidingFrameSquareEnergies24kHz(test_data.PitchBuffer24kHzView(),
65                                          computed_output, cpu_features);
66   auto square_energies_view = test_data.SquareEnergies24kHzView();
67   ExpectNearAbsolute({square_energies_view.data(), square_energies_view.size()},
68                      computed_output, 1e-3f);
69 }
70 
71 // Checks that the estimated pitch period is bit-exact given test input data.
TEST(RnnVadTest,ComputePitchPeriod12kHzBitExactness)72 TEST(RnnVadTest, ComputePitchPeriod12kHzBitExactness) {
73   const AvailableCpuFeatures cpu_features = GetAvailableCpuFeatures();
74 
75   PitchTestData test_data;
76   std::array<float, kBufSize12kHz> pitch_buf_decimated;
77   Decimate2x(test_data.PitchBuffer24kHzView(), pitch_buf_decimated);
78   CandidatePitchPeriods pitch_candidates;
79   // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
80   // FloatingPointExceptionObserver fpe_observer;
81   pitch_candidates = ComputePitchPeriod12kHz(
82       pitch_buf_decimated, test_data.AutoCorrelation12kHzView(), cpu_features);
83   EXPECT_EQ(pitch_candidates.best, 140);
84   EXPECT_EQ(pitch_candidates.second_best, 142);
85 }
86 
87 // Checks that the refined pitch period is bit-exact given test input data.
TEST(RnnVadTest,ComputePitchPeriod48kHzBitExactness)88 TEST(RnnVadTest, ComputePitchPeriod48kHzBitExactness) {
89   const AvailableCpuFeatures cpu_features = GetAvailableCpuFeatures();
90 
91   PitchTestData test_data;
92   std::vector<float> y_energy(kRefineNumLags24kHz);
93   rtc::ArrayView<float, kRefineNumLags24kHz> y_energy_view(y_energy.data(),
94                                                            kRefineNumLags24kHz);
95   ComputeSlidingFrameSquareEnergies24kHz(test_data.PitchBuffer24kHzView(),
96                                          y_energy_view, cpu_features);
97   // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
98   // FloatingPointExceptionObserver fpe_observer;
99   EXPECT_EQ(
100       ComputePitchPeriod48kHz(test_data.PitchBuffer24kHzView(), y_energy_view,
101                               /*pitch_candidates=*/{280, 284}, cpu_features),
102       560);
103   EXPECT_EQ(
104       ComputePitchPeriod48kHz(test_data.PitchBuffer24kHzView(), y_energy_view,
105                               /*pitch_candidates=*/{260, 284}, cpu_features),
106       568);
107 }
108 
// Parameter set for `PitchCandidatesParametrization`: a pair of pitch
// candidates plus the CPU features to run the pitch search with.
// Members are aggregate-initialized positionally by
// `CreatePitchCandidatesParameters()`, so their order matters.
struct PitchCandidatesParameters {
  CandidatePitchPeriods pitch_candidates;
  AvailableCpuFeatures cpu_features;
};

// Value-parametrized fixture over `PitchCandidatesParameters`.
class PitchCandidatesParametrization
    : public ::testing::TestWithParam<PitchCandidatesParameters> {};
116 
117 // Checks that the result of `ComputePitchPeriod48kHz()` does not depend on the
118 // order of the input pitch candidates.
TEST_P(PitchCandidatesParametrization,ComputePitchPeriod48kHzOrderDoesNotMatter)119 TEST_P(PitchCandidatesParametrization,
120        ComputePitchPeriod48kHzOrderDoesNotMatter) {
121   const PitchCandidatesParameters params = GetParam();
122   const CandidatePitchPeriods swapped_pitch_candidates{
123       params.pitch_candidates.second_best, params.pitch_candidates.best};
124 
125   PitchTestData test_data;
126   std::vector<float> y_energy(kRefineNumLags24kHz);
127   rtc::ArrayView<float, kRefineNumLags24kHz> y_energy_view(y_energy.data(),
128                                                            kRefineNumLags24kHz);
129   ComputeSlidingFrameSquareEnergies24kHz(test_data.PitchBuffer24kHzView(),
130                                          y_energy_view, params.cpu_features);
131   EXPECT_EQ(
132       ComputePitchPeriod48kHz(test_data.PitchBuffer24kHzView(), y_energy_view,
133                               params.pitch_candidates, params.cpu_features),
134       ComputePitchPeriod48kHz(test_data.PitchBuffer24kHzView(), y_energy_view,
135                               swapped_pitch_candidates, params.cpu_features));
136 }
137 
CreatePitchCandidatesParameters()138 std::vector<PitchCandidatesParameters> CreatePitchCandidatesParameters() {
139   std::vector<PitchCandidatesParameters> v;
140   for (AvailableCpuFeatures cpu_features : GetCpuFeaturesToTest()) {
141     v.push_back({{0, 2}, cpu_features});
142     v.push_back({{260, 284}, cpu_features});
143     v.push_back({{280, 284}, cpu_features});
144     v.push_back(
145         {{kInitialNumLags24kHz - 2, kInitialNumLags24kHz - 1}, cpu_features});
146   }
147   return v;
148 }
149 
// Runs `PitchCandidatesParametrization` tests once per entry returned by
// `CreatePitchCandidatesParameters()`. Each test name is suffixed with the
// parameter index and the CPU features string.
INSTANTIATE_TEST_SUITE_P(
    RnnVadTest,
    PitchCandidatesParametrization,
    ::testing::ValuesIn(CreatePitchCandidatesParameters()),
    PrintTestIndexAndCpuFeatures<PitchCandidatesParameters>);
155 
// Parameter set for `ExtendedPitchPeriodSearchParametrizaion`. It holds:
// - the initial pitch period and the last pitch passed to
//   `ComputeExtendedPitchPeriod48kHz()`;
// - the pitch the search is expected to return;
// - the CPU features to use.
// Members are aggregate-initialized positionally, so their order matters.
struct ExtendedPitchPeriodSearchParameters {
  int initial_pitch_period;
  PitchInfo last_pitch;
  PitchInfo expected_pitch;
  AvailableCpuFeatures cpu_features;
};

// Value-parametrized fixture over `ExtendedPitchPeriodSearchParameters`.
// NOTE(review): "Parametrizaion" is misspelled. Renaming it changes the test
// names and must be done together with its TEST_P and INSTANTIATE_TEST_SUITE_P.
class ExtendedPitchPeriodSearchParametrizaion
    : public ::testing::TestWithParam<ExtendedPitchPeriodSearchParameters> {};
165 
166 // Checks that the computed pitch period is bit-exact and that the computed
167 // pitch strength is within tolerance given test input data.
TEST_P(ExtendedPitchPeriodSearchParametrizaion,PeriodBitExactnessGainWithinTolerance)168 TEST_P(ExtendedPitchPeriodSearchParametrizaion,
169        PeriodBitExactnessGainWithinTolerance) {
170   const ExtendedPitchPeriodSearchParameters params = GetParam();
171 
172   PitchTestData test_data;
173   std::vector<float> y_energy(kRefineNumLags24kHz);
174   rtc::ArrayView<float, kRefineNumLags24kHz> y_energy_view(y_energy.data(),
175                                                            kRefineNumLags24kHz);
176   ComputeSlidingFrameSquareEnergies24kHz(test_data.PitchBuffer24kHzView(),
177                                          y_energy_view, params.cpu_features);
178   // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
179   // FloatingPointExceptionObserver fpe_observer;
180   const auto computed_output = ComputeExtendedPitchPeriod48kHz(
181       test_data.PitchBuffer24kHzView(), y_energy_view,
182       params.initial_pitch_period, params.last_pitch, params.cpu_features);
183   EXPECT_EQ(params.expected_pitch.period, computed_output.period);
184   EXPECT_NEAR(params.expected_pitch.strength, computed_output.strength, 1e-6f);
185 }
186 
187 std::vector<ExtendedPitchPeriodSearchParameters>
CreateExtendedPitchPeriodSearchParameters()188 CreateExtendedPitchPeriodSearchParameters() {
189   std::vector<ExtendedPitchPeriodSearchParameters> v;
190   for (AvailableCpuFeatures cpu_features : GetCpuFeaturesToTest()) {
191     for (int last_pitch_period :
192          {kTestPitchPeriodsLow, kTestPitchPeriodsHigh}) {
193       for (float last_pitch_strength :
194            {kTestPitchStrengthLow, kTestPitchStrengthHigh}) {
195         v.push_back({kTestPitchPeriodsLow,
196                      {last_pitch_period, last_pitch_strength},
197                      {91, -0.0188608f},
198                      cpu_features});
199         v.push_back({kTestPitchPeriodsHigh,
200                      {last_pitch_period, last_pitch_strength},
201                      {475, -0.0904344f},
202                      cpu_features});
203       }
204     }
205   }
206   return v;
207 }
208 
// Runs `ExtendedPitchPeriodSearchParametrizaion` tests once per entry returned
// by `CreateExtendedPitchPeriodSearchParameters()`. Each test name is suffixed
// with the parameter index and the CPU features string.
INSTANTIATE_TEST_SUITE_P(
    RnnVadTest,
    ExtendedPitchPeriodSearchParametrizaion,
    ::testing::ValuesIn(CreateExtendedPitchPeriodSearchParameters()),
    PrintTestIndexAndCpuFeatures<ExtendedPitchPeriodSearchParameters>);
214 
215 }  // namespace
216 }  // namespace rnn_vad
217 }  // namespace webrtc
218