1 /*
2 * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11 #include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
12
#include <array>
#include <string>
#include <tuple>
#include <vector>
16
17 #include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
18 #include "rtc_base/strings/string_builder.h"
19 // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
20 // #include "test/fpe_observer.h"
21 #include "test/gtest.h"
22
23 namespace webrtc {
24 namespace rnn_vad {
25 namespace {
26
27 constexpr int kTestPitchPeriodsLow = 3 * kMinPitch48kHz / 2;
28 constexpr int kTestPitchPeriodsHigh = (3 * kMinPitch48kHz + kMaxPitch48kHz) / 2;
29
30 constexpr float kTestPitchStrengthLow = 0.35f;
31 constexpr float kTestPitchStrengthHigh = 0.75f;
32
33 template <class T>
PrintTestIndexAndCpuFeatures(const::testing::TestParamInfo<T> & info)34 std::string PrintTestIndexAndCpuFeatures(
35 const ::testing::TestParamInfo<T>& info) {
36 rtc::StringBuilder builder;
37 builder << info.index << "_" << info.param.cpu_features.ToString();
38 return builder.str();
39 }
40
41 // Finds the relevant CPU features combinations to test.
GetCpuFeaturesToTest()42 std::vector<AvailableCpuFeatures> GetCpuFeaturesToTest() {
43 std::vector<AvailableCpuFeatures> v;
44 v.push_back(NoAvailableCpuFeatures());
45 AvailableCpuFeatures available = GetAvailableCpuFeatures();
46 if (available.avx2) {
47 v.push_back({/*sse2=*/false, /*avx2=*/true, /*neon=*/false});
48 }
49 if (available.sse2) {
50 v.push_back({/*sse2=*/true, /*avx2=*/false, /*neon=*/false});
51 }
52 return v;
53 }
54
55 // Checks that the frame-wise sliding square energy function produces output
56 // within tolerance given test input data.
TEST(RnnVadTest,ComputeSlidingFrameSquareEnergies24kHzWithinTolerance)57 TEST(RnnVadTest, ComputeSlidingFrameSquareEnergies24kHzWithinTolerance) {
58 const AvailableCpuFeatures cpu_features = GetAvailableCpuFeatures();
59
60 PitchTestData test_data;
61 std::array<float, kRefineNumLags24kHz> computed_output;
62 // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
63 // FloatingPointExceptionObserver fpe_observer;
64 ComputeSlidingFrameSquareEnergies24kHz(test_data.PitchBuffer24kHzView(),
65 computed_output, cpu_features);
66 auto square_energies_view = test_data.SquareEnergies24kHzView();
67 ExpectNearAbsolute({square_energies_view.data(), square_energies_view.size()},
68 computed_output, 1e-3f);
69 }
70
71 // Checks that the estimated pitch period is bit-exact given test input data.
TEST(RnnVadTest,ComputePitchPeriod12kHzBitExactness)72 TEST(RnnVadTest, ComputePitchPeriod12kHzBitExactness) {
73 const AvailableCpuFeatures cpu_features = GetAvailableCpuFeatures();
74
75 PitchTestData test_data;
76 std::array<float, kBufSize12kHz> pitch_buf_decimated;
77 Decimate2x(test_data.PitchBuffer24kHzView(), pitch_buf_decimated);
78 CandidatePitchPeriods pitch_candidates;
79 // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
80 // FloatingPointExceptionObserver fpe_observer;
81 pitch_candidates = ComputePitchPeriod12kHz(
82 pitch_buf_decimated, test_data.AutoCorrelation12kHzView(), cpu_features);
83 EXPECT_EQ(pitch_candidates.best, 140);
84 EXPECT_EQ(pitch_candidates.second_best, 142);
85 }
86
87 // Checks that the refined pitch period is bit-exact given test input data.
TEST(RnnVadTest,ComputePitchPeriod48kHzBitExactness)88 TEST(RnnVadTest, ComputePitchPeriod48kHzBitExactness) {
89 const AvailableCpuFeatures cpu_features = GetAvailableCpuFeatures();
90
91 PitchTestData test_data;
92 std::vector<float> y_energy(kRefineNumLags24kHz);
93 rtc::ArrayView<float, kRefineNumLags24kHz> y_energy_view(y_energy.data(),
94 kRefineNumLags24kHz);
95 ComputeSlidingFrameSquareEnergies24kHz(test_data.PitchBuffer24kHzView(),
96 y_energy_view, cpu_features);
97 // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
98 // FloatingPointExceptionObserver fpe_observer;
99 EXPECT_EQ(
100 ComputePitchPeriod48kHz(test_data.PitchBuffer24kHzView(), y_energy_view,
101 /*pitch_candidates=*/{280, 284}, cpu_features),
102 560);
103 EXPECT_EQ(
104 ComputePitchPeriod48kHz(test_data.PitchBuffer24kHzView(), y_energy_view,
105 /*pitch_candidates=*/{260, 284}, cpu_features),
106 568);
107 }
108
109 struct PitchCandidatesParameters {
110 CandidatePitchPeriods pitch_candidates;
111 AvailableCpuFeatures cpu_features;
112 };
113
114 class PitchCandidatesParametrization
115 : public ::testing::TestWithParam<PitchCandidatesParameters> {};
116
117 // Checks that the result of `ComputePitchPeriod48kHz()` does not depend on the
118 // order of the input pitch candidates.
TEST_P(PitchCandidatesParametrization,ComputePitchPeriod48kHzOrderDoesNotMatter)119 TEST_P(PitchCandidatesParametrization,
120 ComputePitchPeriod48kHzOrderDoesNotMatter) {
121 const PitchCandidatesParameters params = GetParam();
122 const CandidatePitchPeriods swapped_pitch_candidates{
123 params.pitch_candidates.second_best, params.pitch_candidates.best};
124
125 PitchTestData test_data;
126 std::vector<float> y_energy(kRefineNumLags24kHz);
127 rtc::ArrayView<float, kRefineNumLags24kHz> y_energy_view(y_energy.data(),
128 kRefineNumLags24kHz);
129 ComputeSlidingFrameSquareEnergies24kHz(test_data.PitchBuffer24kHzView(),
130 y_energy_view, params.cpu_features);
131 EXPECT_EQ(
132 ComputePitchPeriod48kHz(test_data.PitchBuffer24kHzView(), y_energy_view,
133 params.pitch_candidates, params.cpu_features),
134 ComputePitchPeriod48kHz(test_data.PitchBuffer24kHzView(), y_energy_view,
135 swapped_pitch_candidates, params.cpu_features));
136 }
137
CreatePitchCandidatesParameters()138 std::vector<PitchCandidatesParameters> CreatePitchCandidatesParameters() {
139 std::vector<PitchCandidatesParameters> v;
140 for (AvailableCpuFeatures cpu_features : GetCpuFeaturesToTest()) {
141 v.push_back({{0, 2}, cpu_features});
142 v.push_back({{260, 284}, cpu_features});
143 v.push_back({{280, 284}, cpu_features});
144 v.push_back(
145 {{kInitialNumLags24kHz - 2, kInitialNumLags24kHz - 1}, cpu_features});
146 }
147 return v;
148 }
149
150 INSTANTIATE_TEST_SUITE_P(
151 RnnVadTest,
152 PitchCandidatesParametrization,
153 ::testing::ValuesIn(CreatePitchCandidatesParameters()),
154 PrintTestIndexAndCpuFeatures<PitchCandidatesParameters>);
155
156 struct ExtendedPitchPeriodSearchParameters {
157 int initial_pitch_period;
158 PitchInfo last_pitch;
159 PitchInfo expected_pitch;
160 AvailableCpuFeatures cpu_features;
161 };
162
163 class ExtendedPitchPeriodSearchParametrizaion
164 : public ::testing::TestWithParam<ExtendedPitchPeriodSearchParameters> {};
165
166 // Checks that the computed pitch period is bit-exact and that the computed
167 // pitch strength is within tolerance given test input data.
TEST_P(ExtendedPitchPeriodSearchParametrizaion,PeriodBitExactnessGainWithinTolerance)168 TEST_P(ExtendedPitchPeriodSearchParametrizaion,
169 PeriodBitExactnessGainWithinTolerance) {
170 const ExtendedPitchPeriodSearchParameters params = GetParam();
171
172 PitchTestData test_data;
173 std::vector<float> y_energy(kRefineNumLags24kHz);
174 rtc::ArrayView<float, kRefineNumLags24kHz> y_energy_view(y_energy.data(),
175 kRefineNumLags24kHz);
176 ComputeSlidingFrameSquareEnergies24kHz(test_data.PitchBuffer24kHzView(),
177 y_energy_view, params.cpu_features);
178 // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
179 // FloatingPointExceptionObserver fpe_observer;
180 const auto computed_output = ComputeExtendedPitchPeriod48kHz(
181 test_data.PitchBuffer24kHzView(), y_energy_view,
182 params.initial_pitch_period, params.last_pitch, params.cpu_features);
183 EXPECT_EQ(params.expected_pitch.period, computed_output.period);
184 EXPECT_NEAR(params.expected_pitch.strength, computed_output.strength, 1e-6f);
185 }
186
187 std::vector<ExtendedPitchPeriodSearchParameters>
CreateExtendedPitchPeriodSearchParameters()188 CreateExtendedPitchPeriodSearchParameters() {
189 std::vector<ExtendedPitchPeriodSearchParameters> v;
190 for (AvailableCpuFeatures cpu_features : GetCpuFeaturesToTest()) {
191 for (int last_pitch_period :
192 {kTestPitchPeriodsLow, kTestPitchPeriodsHigh}) {
193 for (float last_pitch_strength :
194 {kTestPitchStrengthLow, kTestPitchStrengthHigh}) {
195 v.push_back({kTestPitchPeriodsLow,
196 {last_pitch_period, last_pitch_strength},
197 {91, -0.0188608f},
198 cpu_features});
199 v.push_back({kTestPitchPeriodsHigh,
200 {last_pitch_period, last_pitch_strength},
201 {475, -0.0904344f},
202 cpu_features});
203 }
204 }
205 }
206 return v;
207 }
208
209 INSTANTIATE_TEST_SUITE_P(
210 RnnVadTest,
211 ExtendedPitchPeriodSearchParametrizaion,
212 ::testing::ValuesIn(CreateExtendedPitchPeriodSearchParameters()),
213 PrintTestIndexAndCpuFeatures<ExtendedPitchPeriodSearchParameters>);
214
215 } // namespace
216 } // namespace rnn_vad
217 } // namespace webrtc
218