import collections
import contextlib
import sys
import wave

import webrtcvad


def read_wave(path):
    """Read a .wav file and return (pcm_data, sample_rate).

    The file must be mono, 16-bit PCM, at a sample rate the WebRTC VAD
    accepts (8000, 16000, 32000, or 48000 Hz).

    Raises ValueError if the file does not meet those constraints.
    """
    with contextlib.closing(wave.open(path, 'rb')) as wf:
        num_channels = wf.getnchannels()
        # Explicit raises instead of assert: asserts are stripped under -O.
        if num_channels != 1:
            raise ValueError('expected mono audio, got %d channels' % num_channels)
        sample_width = wf.getsampwidth()
        if sample_width != 2:
            raise ValueError('expected 16-bit PCM, got sample width %d' % sample_width)
        sample_rate = wf.getframerate()
        # 48000 Hz is also supported by the WebRTC VAD.
        if sample_rate not in (8000, 16000, 32000, 48000):
            raise ValueError('unsupported sample rate: %d' % sample_rate)
        pcm_data = wf.readframes(wf.getnframes())
        return pcm_data, sample_rate


def write_wave(path, audio, sample_rate):
    """Write raw 16-bit mono PCM bytes to a .wav file at *sample_rate*."""
    with contextlib.closing(wave.open(path, 'wb')) as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)
        wf.setframerate(sample_rate)
        wf.writeframes(audio)


class Frame(object):
    """A fixed-duration chunk of PCM audio.

    Attributes:
        bytes: raw 16-bit mono PCM data for this frame.
        timestamp: start time of the frame in seconds.
        duration: frame length in seconds.
    """

    def __init__(self, bytes, timestamp, duration):
        self.bytes = bytes
        self.timestamp = timestamp
        self.duration = duration


def frame_generator(frame_duration_ms, audio, sample_rate):
    """Yield Frames of *frame_duration_ms* from 16-bit mono PCM *audio*.

    A trailing partial frame shorter than frame_duration_ms is dropped.
    """
    # Bytes per frame: samples-per-frame * 2 bytes per 16-bit sample.
    n = int(sample_rate * (frame_duration_ms / 1000.0) * 2)
    offset = 0
    timestamp = 0.0
    duration = (float(n) / sample_rate) / 2.0
    # `<=` (not `<`) so the final full frame is emitted when len(audio)
    # is an exact multiple of the frame size.
    while offset + n <= len(audio):
        yield Frame(audio[offset:offset + n], timestamp, duration)
        timestamp += duration
        offset += n


def vad_collector(sample_rate, frame_duration_ms,
                  padding_duration_ms, vad, frames):
    """Filter out non-voiced frames, yielding voiced PCM segments as bytes.

    Uses a padded sliding window (ring buffer) over the frames. When more
    than 90% of the frames in the window are voiced, it triggers and begins
    yielding audio (including the buffered padding frames); when more than
    90% are unvoiced, it detriggers and yields the accumulated segment.

    Arguments:
        sample_rate: audio sample rate in Hz.
        frame_duration_ms: duration of each frame in milliseconds.
        padding_duration_ms: padding window size in milliseconds.
        vad: a webrtcvad.Vad instance.
        frames: iterable of Frame objects.

    Yields: bytes objects, one per detected voiced segment. Also writes a
    '1'/'0' per frame (plus +/- trigger markers) to stdout as a debug trace.
    """
    num_padding_frames = int(padding_duration_ms / frame_duration_ms)
    ring_buffer = collections.deque(maxlen=num_padding_frames)
    # NOTTRIGGERED state while False; TRIGGERED (inside a voiced region) while True.
    triggered = False
    voiced_frames = []
    for frame in frames:
        sys.stdout.write(
            '1' if vad.is_speech(frame.bytes, sample_rate) else '0')
        if not triggered:
            ring_buffer.append(frame)
            num_voiced = len([f for f in ring_buffer
                              if vad.is_speech(f.bytes, sample_rate)])
            # Trigger once >90% of the window is voiced; emit the buffered
            # padding so the segment starts slightly before the speech.
            if num_voiced > 0.9 * ring_buffer.maxlen:
                sys.stdout.write('+(%s)' % (ring_buffer[0].timestamp,))
                triggered = True
                voiced_frames.extend(ring_buffer)
                ring_buffer.clear()
        else:
            voiced_frames.append(frame)
            ring_buffer.append(frame)
            num_unvoiced = len([f for f in ring_buffer
                                if not vad.is_speech(f.bytes, sample_rate)])
            # Detrigger once >90% of the window is unvoiced and yield the
            # accumulated segment.
            if num_unvoiced > 0.9 * ring_buffer.maxlen:
                sys.stdout.write('-(%s)' % (frame.timestamp + frame.duration))
                triggered = False
                yield b''.join([f.bytes for f in voiced_frames])
                ring_buffer.clear()
                voiced_frames = []
    if triggered:
        sys.stdout.write('-(%s)' % (frame.timestamp + frame.duration))
    sys.stdout.write('\n')
    # Flush any segment still open when the input ends.
    if voiced_frames:
        yield b''.join([f.bytes for f in voiced_frames])


def main(args):
    """Segment a wav file into voiced chunks: example.py <aggressiveness> <wav>."""
    if len(args) != 2:
        sys.stderr.write(
            'Usage: example.py <aggressiveness> <path to wav file>\n')
        sys.exit(1)
    audio, sample_rate = read_wave(args[1])
    vad = webrtcvad.Vad(int(args[0]))
    frames = frame_generator(30, audio, sample_rate)
    frames = list(frames)
    segments = vad_collector(sample_rate, 30, 300, vad, frames)
    for i, segment in enumerate(segments):
        # %02d (the original '%002d' was a typo that happened to render the same).
        path = 'chunk-%02d.wav' % (i,)
        print(' Writing %s' % (path,))
        write_wave(path, segment, sample_rate)


if __name__ == '__main__':
    main(sys.argv[1:])