import collections
import contextlib
import sys
import wave

import webrtcvad


def read_wave(path):
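    """Reads a .wav file.

    Takes the path, and returns (PCM audio data, sample rate).
    """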
    with contextlib.closing(wave.open(path, 'rb')) as wf:
        num_channels = wf.getnchannels()
        assert num_channels == 1
        sample_width = wf.getsampwidth()
        assert sample_width == 2
        sample_rate = wf.getframerate()
        assert sample_rate in (8000, 16000, 32000)
        pcm_data = wf.readframes(wf.getnframes())
        return pcm_data, sample_rate


def write_wave(path, audio, sample_rate):
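    """Writes a .wav file.

    Takes path, PCM audio data, and sample rate.
    """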
    with contextlib.closing(wave.open(path, 'wb')) as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)
        wf.setframerate(sample_rate)
        wf.writeframes(audio)


class Frame(object):
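    """Represents a "frame" of audio data with its position in the stream."""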
    def __init__(self, bytes, timestamp, duration):
        self.bytes = bytes
        self.timestamp = timestamp
        self.duration = duration


def frame_generator(frame_duration_ms, audio, sample_rate):
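    """Generates audio frames from PCM audio data.

    Takes the desired frame duration in milliseconds, the PCM data, and
    the sample rate. Each frame is n bytes long, where
    n = sample_rate * (frame_duration_ms / 1000) * 2, since each 16-bit
    sample occupies 2 bytes.

    Yields Frames of the requested duration.
    """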
    n = int(sample_rate * (frame_duration_ms / 1000.0) * 2)
    offset = 0
    timestamp = 0.0
    duration = (float(n) / sample_rate) / 2.0
    while offset + n < len(audio):
        yield Frame(audio[offset:offset + n], timestamp, duration)
        timestamp += duration
        offset += n


def vad_collector(sample_rate, frame_duration_ms,
                  padding_duration_ms, vad, frames):
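    """Filters out non-voiced audio frames.

    Given a webrtcvad.Vad and a source of audio frames, yields only the
    voiced audio, as PCM byte strings.

    Uses a padded, sliding-window algorithm: when more than 90% of the
    frames in the window are voiced (as reported by the VAD), the
    collector triggers and begins collecting frames, including the
    buffered window so the start of speech is not clipped. It keeps
    collecting until more than 90% of the frames in the window are
    unvoiced, then detriggers and yields the collected segment.

    As a debug trace, writes '1' or '0' per frame to stdout, plus
    '+(timestamp)' and '-(timestamp)' markers at trigger and detrigger.

    Arguments:

    sample_rate - The audio sample rate, in Hz.
    frame_duration_ms - The frame duration in milliseconds.
    padding_duration_ms - The amount to pad the window, in milliseconds.
    vad - An instance of webrtcvad.Vad.
    frames - a source of audio frames (sequence or generator).

    Returns: A generator that yields PCM audio data.
    """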
    num_padding_frames = int(padding_duration_ms / frame_duration_ms)
    ring_buffer = collections.deque(maxlen=num_padding_frames)
    triggered = False
    voiced_frames = []
    for frame in frames:
        sys.stdout.write(
            '1' if vad.is_speech(frame.bytes, sample_rate) else '0')
        if not triggered:
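            # NOTTRIGGERED state: buffer frames until more than 90% of
            # the frames in the ring buffer are voiced, then trigger.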
            ring_buffer.append(frame)
            num_voiced = len([f for f in ring_buffer
                              if vad.is_speech(f.bytes, sample_rate)])
            if num_voiced > 0.9 * ring_buffer.maxlen:
                sys.stdout.write('+(%s)' % (ring_buffer[0].timestamp,))
                triggered = True
                voiced_frames.extend(ring_buffer)
                ring_buffer.clear()
        else:
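            # TRIGGERED state: collect voiced frames. When more than 90%
            # of the frames in the ring buffer are unvoiced, detrigger
            # and yield the collected segment.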
            voiced_frames.append(frame)
            ring_buffer.append(frame)
            num_unvoiced = len([f for f in ring_buffer
                                if not vad.is_speech(f.bytes, sample_rate)])
            if num_unvoiced > 0.9 * ring_buffer.maxlen:
                sys.stdout.write('-(%s)' % (frame.timestamp + frame.duration))
                triggered = False
                yield b''.join([f.bytes for f in voiced_frames])
                ring_buffer.clear()
                voiced_frames = []
    if triggered:
        sys.stdout.write('-(%s)' % (frame.timestamp + frame.duration))
    sys.stdout.write('\n')
    if voiced_frames:
        yield b''.join([f.bytes for f in voiced_frames])


def main(args):
    if len(args) != 2:
        sys.stderr.write(
            'Usage: example.py <aggressiveness> <path to wav file>\n')
        sys.exit(1)
    audio, sample_rate = read_wave(args[1])
    vad = webrtcvad.Vad(int(args[0]))
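    # Slice the audio into 30 ms frames and collect voiced segments,
    # using a 300 ms window for triggering and detriggering.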
    frames = frame_generator(30, audio, sample_rate)
    frames = list(frames)
    segments = vad_collector(sample_rate, 30, 300, vad, frames)
    for i, segment in enumerate(segments):
        path = 'chunk-%02d.wav' % (i,)
        print(' Writing %s' % (path,))
        write_wave(path, segment, sample_rate)


if __name__ == '__main__':
    main(sys.argv[1:])
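
# Example invocation (the input must be a mono, 16-bit WAV at 8, 16, or
# 32 kHz; 'recording.wav' is a placeholder filename). Aggressiveness
# runs from 0 (least aggressive about filtering out non-speech) to 3
# (most aggressive):
#
#     python example.py 3 recording.wav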