# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS.  All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Unit tests for the annotations module."""

from __future__ import division
import logging
import os
import shutil
import tempfile
import unittest

import numpy as np

from . import annotations
from . import external_vad
from . import input_signal_creator
from . import signal_processing


class TestAnnotationsExtraction(unittest.TestCase):
    """Unit tests for the annotations module."""

    _CLEAN_TMP_OUTPUT = True
    _DEBUG_PLOT_VAD = False
    _VAD_TYPE_CLASS = annotations.AudioAnnotationsExtractor.VadType
    _ALL_VAD_TYPES = (_VAD_TYPE_CLASS.ENERGY_THRESHOLD
                      | _VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO
                      | _VAD_TYPE_CLASS.WEBRTC_APM)
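    # The VAD type values above are combined with bitwise OR, so iterating
    # over range(0, _ALL_VAD_TYPES + 1) in the tests below exercises every
    # combination of detectors, including none at all.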

    def setUp(self):
        """Create temporary folder."""
        self._tmp_path = tempfile.mkdtemp()
        self._wav_file_path = os.path.join(self._tmp_path, 'tone.wav')
        pure_tone, _ = input_signal_creator.InputSignalCreator.Create(
            'pure_tone', [440, 1000])
        signal_processing.SignalProcessingUtils.SaveWav(
            self._wav_file_path, pure_tone)
        self._sample_rate = pure_tone.frame_rate

    def tearDown(self):
        """Recursively delete temporary folder."""
        if self._CLEAN_TMP_OUTPUT:
            shutil.rmtree(self._tmp_path)
        else:
            logging.warning('%s did not clean the temporary path %s',
                            self.id(), self._tmp_path)

    def testFrameSizes(self):
        """Checks that frame sizes in samples and in milliseconds agree."""
        e = annotations.AudioAnnotationsExtractor(self._ALL_VAD_TYPES)
        e.Extract(self._wav_file_path)
        samples_to_ms = lambda n, sr: 1000 * n // sr
        self.assertEqual(
            samples_to_ms(e.GetLevelFrameSize(), self._sample_rate),
            e.GetLevelFrameSizeMs())
        self.assertEqual(samples_to_ms(e.GetVadFrameSize(), self._sample_rate),
                         e.GetVadFrameSizeMs())

    def testVoiceActivityDetectors(self):
        """Checks the output of every VAD type combination on the pure tone."""
        for vad_type_value in range(0, self._ALL_VAD_TYPES + 1):
            vad_type = self._VAD_TYPE_CLASS(vad_type_value)
            e = annotations.AudioAnnotationsExtractor(vad_type=vad_type_value)
            e.Extract(self._wav_file_path)
            if vad_type.Contains(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD):
                vad_output = e.GetVadOutput(
                    self._VAD_TYPE_CLASS.ENERGY_THRESHOLD)
                self.assertGreater(len(vad_output), 0)
                self.assertGreaterEqual(
                    float(np.sum(vad_output)) / len(vad_output), 0.95)

            if vad_type.Contains(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO):
                vad_output = e.GetVadOutput(
                    self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO)
                self.assertGreater(len(vad_output), 0)
                self.assertGreaterEqual(
                    float(np.sum(vad_output)) / len(vad_output), 0.95)

            if vad_type.Contains(self._VAD_TYPE_CLASS.WEBRTC_APM):
                # pylint: disable=unpacking-non-sequence
                (vad_probs,
                 vad_rms) = e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)
                self.assertGreater(len(vad_probs), 0)
                self.assertGreater(len(vad_rms), 0)
                self.assertGreaterEqual(
                    float(np.sum(vad_probs)) / len(vad_probs), 0.5)
                self.assertGreaterEqual(
                    float(np.sum(vad_rms)) / len(vad_rms), 20000)

            if self._DEBUG_PLOT_VAD:
                frame_times_s = lambda num_frames, frame_size_ms: np.arange(
                    num_frames).astype(np.float32) * frame_size_ms / 1000.0
                level = e.GetLevel()
                t_level = frame_times_s(num_frames=len(level),
                                        frame_size_ms=e.GetLevelFrameSizeMs())
                t_vad = frame_times_s(num_frames=len(vad_output),
                                      frame_size_ms=e.GetVadFrameSizeMs())
                import matplotlib.pyplot as plt
                plt.figure()
                # Successive plot calls overlay on the same axes by default.
                plt.plot(t_level, level)
                plt.plot(t_vad, vad_output * np.max(level), '.')
                plt.show()

    def testSaveLoad(self):
        """Checks that saved annotations can be loaded back unchanged."""
        e = annotations.AudioAnnotationsExtractor(self._ALL_VAD_TYPES)
        e.Extract(self._wav_file_path)
        e.Save(self._tmp_path, "fake-annotation")

        data = np.load(
            os.path.join(
                self._tmp_path,
                e.GetOutputFileNameTemplate().format("fake-annotation")))
        np.testing.assert_array_equal(e.GetLevel(), data['level'])
        self.assertEqual(np.float32, data['level'].dtype)
        np.testing.assert_array_equal(
            e.GetVadOutput(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD),
            data['vad_energy_output'])
        np.testing.assert_array_equal(
            e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO),
            data['vad_output'])
        np.testing.assert_array_equal(
            e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)[0],
            data['vad_probs'])
        np.testing.assert_array_equal(
            e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)[1],
            data['vad_rms'])
        self.assertEqual(np.uint8, data['vad_energy_output'].dtype)
        self.assertEqual(np.float64, data['vad_probs'].dtype)
        self.assertEqual(np.float64, data['vad_rms'].dtype)

    def testEmptyExternalShouldNotCrash(self):
        """Passing an empty dict of external VADs must not raise."""
        for vad_type_value in range(0, self._ALL_VAD_TYPES + 1):
            annotations.AudioAnnotationsExtractor(vad_type_value, {})

    def testFakeExternalSaveLoad(self):
        """Checks save/load of annotations from a fake external VAD."""
        def FakeExternalFactory():
            return external_vad.ExternalVad(
                os.path.join(os.path.dirname(os.path.abspath(__file__)),
                             'fake_external_vad.py'), 'fake')

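        # The fake external VAD is expected to report 100 confidence values
        # (0 through 99); the assertion on np.arange(100) below relies on it.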
        for vad_type_value in range(0, self._ALL_VAD_TYPES + 1):
            e = annotations.AudioAnnotationsExtractor(
                vad_type_value, {'fake': FakeExternalFactory()})
            e.Extract(self._wav_file_path)
            e.Save(self._tmp_path, annotation_name="fake-annotation")
            data = np.load(
                os.path.join(
                    self._tmp_path,
                    e.GetOutputFileNameTemplate().format("fake-annotation")))
            self.assertEqual(np.float32, data['extvad_conf-fake'].dtype)
            np.testing.assert_almost_equal(np.arange(100, dtype=np.float32),
                                           data['extvad_conf-fake'])
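

# Convenience entry point for running this module directly; the tests do not
# depend on it and can also be driven by an external test runner.
if __name__ == '__main__':
    unittest.main()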