1 // Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
2 //
3 // Use of this source code is governed by a BSD-style license
4 // that can be found in the LICENSE file in the root of the source
5 // tree. An additional intellectual property rights grant can be found
6 // in the file PATENTS.  All contributing project authors may
7 // be found in the AUTHORS file in the root of the source tree.
8 
#include <array>
#include <cstddef>
#include <cstdint>
#include <fstream>
#include <memory>
#include <string>

#include "absl/flags/flag.h"
#include "absl/flags/parse.h"
#include "common_audio/wav_file.h"
#include "modules/audio_processing/vad/voice_activity_detector.h"
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
18 
19 ABSL_FLAG(std::string, i, "", "Input wav file");
20 ABSL_FLAG(std::string, o_probs, "", "VAD probabilities output file");
21 ABSL_FLAG(std::string, o_rms, "", "VAD output file");
22 
23 namespace webrtc {
24 namespace test {
25 namespace {
26 
27 constexpr uint8_t kAudioFrameLengthMilliseconds = 10;
28 constexpr int kMaxSampleRate = 48000;
29 constexpr size_t kMaxFrameLen =
30     kAudioFrameLengthMilliseconds * kMaxSampleRate / 1000;
31 
main(int argc,char * argv[])32 int main(int argc, char* argv[]) {
33   absl::ParseCommandLine(argc, argv);
34   const std::string input_file = absl::GetFlag(FLAGS_i);
35   const std::string output_probs_file = absl::GetFlag(FLAGS_o_probs);
36   const std::string output_file = absl::GetFlag(FLAGS_o_rms);
37   // Open wav input file and check properties.
38   WavReader wav_reader(input_file);
39   if (wav_reader.num_channels() != 1) {
40     RTC_LOG(LS_ERROR) << "Only mono wav files supported";
41     return 1;
42   }
43   if (wav_reader.sample_rate() > kMaxSampleRate) {
44     RTC_LOG(LS_ERROR) << "Beyond maximum sample rate (" << kMaxSampleRate
45                       << ")";
46     return 1;
47   }
48   const size_t audio_frame_len = rtc::CheckedDivExact(
49       kAudioFrameLengthMilliseconds * wav_reader.sample_rate(), 1000);
50   if (audio_frame_len > kMaxFrameLen) {
51     RTC_LOG(LS_ERROR) << "The frame size and/or the sample rate are too large.";
52     return 1;
53   }
54 
55   // Create output file and write header.
56   std::ofstream out_probs_file(output_probs_file, std::ofstream::binary);
57   std::ofstream out_rms_file(output_file, std::ofstream::binary);
58 
59   // Run VAD and write decisions.
60   VoiceActivityDetector vad;
61   std::array<int16_t, kMaxFrameLen> samples;
62 
63   while (true) {
64     // Process frame.
65     const auto read_samples =
66         wav_reader.ReadSamples(audio_frame_len, samples.data());
67     if (read_samples < audio_frame_len) {
68       break;
69     }
70     vad.ProcessChunk(samples.data(), audio_frame_len, wav_reader.sample_rate());
71     // Write output.
72     auto probs = vad.chunkwise_voice_probabilities();
73     auto rms = vad.chunkwise_rms();
74     RTC_CHECK_EQ(probs.size(), rms.size());
75     RTC_CHECK_EQ(sizeof(double), 8);
76 
77     for (const auto& p : probs) {
78       out_probs_file.write(reinterpret_cast<const char*>(&p), 8);
79     }
80     for (const auto& r : rms) {
81       out_rms_file.write(reinterpret_cast<const char*>(&r), 8);
82     }
83   }
84 
85   out_probs_file.close();
86   out_rms_file.close();
87   return 0;
88 }
89 
90 }  // namespace
91 }  // namespace test
92 }  // namespace webrtc
93 
main(int argc,char * argv[])94 int main(int argc, char* argv[]) {
95   return webrtc::test::main(argc, argv);
96 }
97