1 // Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
2 //
3 // Use of this source code is governed by a BSD-style license
4 // that can be found in the LICENSE file in the root of the source
5 // tree. An additional intellectual property rights grant can be found
6 // in the file PATENTS.  All contributing project authors may
7 // be found in the AUTHORS file in the root of the source tree.
8 
9 #include "common_audio/vad/include/vad.h"
10 
#include <array>
#include <cstdint>
#include <fstream>
#include <memory>
#include <string>
14 
15 #include "absl/flags/flag.h"
16 #include "absl/flags/parse.h"
17 #include "common_audio/wav_file.h"
18 #include "rtc_base/logging.h"
19 
20 ABSL_FLAG(std::string, i, "", "Input wav file");
21 ABSL_FLAG(std::string, o, "", "VAD output file");
22 
23 namespace webrtc {
24 namespace test {
25 namespace {
26 
27 // The allowed values are 10, 20 or 30 ms.
28 constexpr uint8_t kAudioFrameLengthMilliseconds = 30;
29 constexpr int kMaxSampleRate = 48000;
30 constexpr size_t kMaxFrameLen =
31     kAudioFrameLengthMilliseconds * kMaxSampleRate / 1000;
32 
33 constexpr uint8_t kBitmaskBuffSize = 8;
34 
main(int argc,char * argv[])35 int main(int argc, char* argv[]) {
36   absl::ParseCommandLine(argc, argv);
37   const std::string input_file = absl::GetFlag(FLAGS_i);
38   const std::string output_file = absl::GetFlag(FLAGS_o);
39   // Open wav input file and check properties.
40   WavReader wav_reader(input_file);
41   if (wav_reader.num_channels() != 1) {
42     RTC_LOG(LS_ERROR) << "Only mono wav files supported";
43     return 1;
44   }
45   if (wav_reader.sample_rate() > kMaxSampleRate) {
46     RTC_LOG(LS_ERROR) << "Beyond maximum sample rate (" << kMaxSampleRate
47                       << ")";
48     return 1;
49   }
50   const size_t audio_frame_length = rtc::CheckedDivExact(
51       kAudioFrameLengthMilliseconds * wav_reader.sample_rate(), 1000);
52   if (audio_frame_length > kMaxFrameLen) {
53     RTC_LOG(LS_ERROR) << "The frame size and/or the sample rate are too large.";
54     return 1;
55   }
56 
57   // Create output file and write header.
58   std::ofstream out_file(output_file, std::ofstream::binary);
59   const char audio_frame_length_ms = kAudioFrameLengthMilliseconds;
60   out_file.write(&audio_frame_length_ms, 1);  // Header.
61 
62   // Run VAD and write decisions.
63   std::unique_ptr<Vad> vad = CreateVad(Vad::Aggressiveness::kVadNormal);
64   std::array<int16_t, kMaxFrameLen> samples;
65   char buff = 0;     // Buffer to write one bit per frame.
66   uint8_t next = 0;  // Points to the next bit to write in |buff|.
67   while (true) {
68     // Process frame.
69     const auto read_samples =
70         wav_reader.ReadSamples(audio_frame_length, samples.data());
71     if (read_samples < audio_frame_length)
72       break;
73     const auto is_speech = vad->VoiceActivity(
74         samples.data(), audio_frame_length, wav_reader.sample_rate());
75 
76     // Write output.
77     buff = is_speech ? buff | (1 << next) : buff & ~(1 << next);
78     if (++next == kBitmaskBuffSize) {
79       out_file.write(&buff, 1);  // Flush.
80       buff = 0;                  // Reset.
81       next = 0;
82     }
83   }
84 
85   // Finalize.
86   char extra_bits = 0;
87   if (next > 0) {
88     extra_bits = kBitmaskBuffSize - next;
89     out_file.write(&buff, 1);  // Flush.
90   }
91   out_file.write(&extra_bits, 1);
92   out_file.close();
93 
94   return 0;
95 }
96 
97 }  // namespace
98 }  // namespace test
99 }  // namespace webrtc
100 
main(int argc,char * argv[])101 int main(int argc, char* argv[]) {
102   return webrtc::test::main(argc, argv);
103 }
104