1import unittest
2import wave
3from memory_profiler import memory_usage
4
5import webrtcvad
6
7
class WebRtcVadTests(unittest.TestCase):
    """Unit tests for the ``webrtcvad`` module.

    Covers construction, mode selection, rate/frame-length validation,
    per-frame classification and a memory-growth check.

    ``test_process_file`` and ``test_leak`` open ``test-audio.raw`` and
    ``leak-test.wav`` by relative path, so the tests have to be run from a
    working directory that contains those files.
    """

    @staticmethod
    def _load_wave(file_name):
        """Read a whole WAV file after checking its format.

        Args:
            file_name: Path of the WAV file to read.

        Returns:
            A ``(sound_data, sampling_frequency)`` tuple: the raw bytes of
            every frame in the file and the frame rate reported by the
            file header.

        Raises:
            AssertionError: If the file is not mono, does not use 2-byte
                samples, or has a frame rate other than 8000, 16000 or
                32000 Hz. The message always starts with ``file_name``.
        """
        fp = wave.open(file_name, 'rb')
        try:
            # The checks raise explicitly instead of using ``assert`` so
            # they still run when Python is started with ``-O``.
            if fp.getnchannels() != 1:
                raise AssertionError(
                    '{0}: sound format is incorrect! '
                    'Sound must be mono.'.format(file_name))
            if fp.getsampwidth() != 2:
                raise AssertionError(
                    ('{0}: sound format is incorrect! '
                     'Sample width of sound must be 2 bytes.'
                     ).format(file_name))
            sampling_frequency = fp.getframerate()
            if sampling_frequency not in (8000, 16000, 32000):
                raise AssertionError(
                    ('{0}: sound format is incorrect! '
                     'Sampling frequency must be 8000 Hz, 16000 Hz or '
                     '32000 Hz.').format(file_name))
            sound_data = fp.readframes(fp.getnframes())
        finally:
            fp.close()
        return sound_data, sampling_frequency

    def test_constructor(self):
        """``Vad`` can be constructed without arguments."""
        webrtcvad.Vad()

    def test_set_mode(self):
        """``set_mode`` accepts 0 through 3 and raises ValueError for 4."""
        vad = webrtcvad.Vad()
        for mode in (0, 1, 2, 3):
            vad.set_mode(mode)
        self.assertRaises(ValueError, vad.set_mode, 4)

    def test_valid_rate_and_frame_length(self):
        """Check ``valid_rate_and_frame_length`` for 160-sample frames.

        160 samples must be accepted at 8000 and 16000 Hz and rejected at
        32000 Hz; a rate of ``2 ** 35`` must raise ValueError.
        """
        self.assertTrue(webrtcvad.valid_rate_and_frame_length(8000, 160))
        self.assertTrue(webrtcvad.valid_rate_and_frame_length(16000, 160))
        self.assertFalse(webrtcvad.valid_rate_and_frame_length(32000, 160))
        self.assertRaises(
            ValueError,
            webrtcvad.valid_rate_and_frame_length, 2 ** 35, 10)

    def test_process_zeroes(self):
        """A frame consisting only of zero bytes is not reported as speech."""
        # One rate for both the validity check and the classification, so
        # the check actually covers the frame handed to ``is_speech``.
        sample_rate = 16000
        frame_len = 160
        self.assertTrue(
            webrtcvad.valid_rate_and_frame_length(sample_rate, frame_len))
        # Two bytes per sample.
        sample = b'\x00' * (frame_len * 2)
        vad = webrtcvad.Vad()
        self.assertFalse(vad.is_speech(sample, sample_rate))

    def test_process_file(self):
        """Classify ``test-audio.raw`` frame by frame in every mode.

        The file is cut into 30 ms frames and each frame is passed to
        ``is_speech``; the resulting string of ``'1'``/``'0'`` decisions
        must match the reference string recorded for that mode.
        """
        with open('test-audio.raw', 'rb') as f:
            data = f.read()
        # The arithmetic assumes 2 bytes per sample at 8000 Hz
        # (presumably 16-bit mono PCM - verify against the fixture).
        sample_rate = 8000
        frame_ms = 30
        bytes_per_frame = int(sample_rate * 2 * frame_ms / 1000.0)
        frame_len = bytes_per_frame // 2
        self.assertTrue(
            webrtcvad.valid_rate_and_frame_length(sample_rate, frame_len))
        # Only full frames are produced; a trailing partial frame is
        # dropped and an empty file yields no frames instead of raising.
        chunks = [
            data[pos:pos + bytes_per_frame]
            for pos in range(0, len(data) - bytes_per_frame + 1,
                             bytes_per_frame)
        ]
        self.assertTrue(
            chunks, 'test-audio.raw holds less than one full frame')
        # One reference string per mode; the list index is the mode.
        expecteds = [
            '011110111111111111111111111100',
            '011110111111111111111111111100',
            '000000111111111111111111110000',
            '000000111111111111111100000000'
        ]
        for mode, expected in enumerate(expecteds):
            vad = webrtcvad.Vad(mode)
            result = ''.join(
                '1' if vad.is_speech(chunk, sample_rate) else '0'
                for chunk in chunks)
            self.assertEqual(expected, result)

    def test_leak(self):
        """Repeated classification must not grow memory noticeably.

        Every 10 ms frame of ``leak-test.wav`` is classified 1000 times
        over. The memory reading taken afterwards may exceed the reading
        taken beforehand by at most one fifth of the initial reading.
        """
        sound, fs = self._load_wave('leak-test.wav')
        frame_seconds = 0.010
        frame_len = int(round(fs * frame_seconds))
        # ``_load_wave`` guarantees 2-byte samples.
        bytes_per_frame = 2 * frame_len
        n_frames = len(sound) // bytes_per_frame
        nrepeats = 1000
        vad = webrtcvad.Vad(3)
        # NOTE(review): -1 is expected to select the current process;
        # confirm against the memory_profiler documentation.
        used_memory_before = memory_usage(-1)[0]
        for _ in range(nrepeats):
            find_voice = False
            # No early exit on purpose: every frame goes through
            # ``is_speech`` on every pass.
            for frame_ind in range(n_frames):
                slice_start = frame_ind * bytes_per_frame
                slice_end = slice_start + bytes_per_frame
                if vad.is_speech(sound[slice_start:slice_end], fs):
                    find_voice = True
            self.assertTrue(find_voice)
        used_memory_after = memory_usage(-1)[0]
        self.assertGreaterEqual(
            used_memory_before / 5.0,
            used_memory_after - used_memory_before)
103
104
def _run_tests():
    """Hand this module's tests to ``unittest.main`` with verbosity 2."""
    unittest.main(verbosity=2)


if __name__ == '__main__':
    _run_tests()
107