1# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
2#
3# Use of this source code is governed by a BSD-style license
4# that can be found in the LICENSE file in the root of the source
5# tree. An additional intellectual property rights grant can be found
6# in the file PATENTS.  All contributing project authors may
7# be found in the AUTHORS file in the root of the source tree.
8
9"""Unit tests for the annotations module.
10"""
11
12from __future__ import division
13import logging
14import os
15import shutil
16import tempfile
17import unittest
18
19import numpy as np
20
21from . import annotations
22from . import external_vad
23from . import input_signal_creator
24from . import signal_processing
25
26
class TestAnnotationsExtraction(unittest.TestCase):
  """Unit tests for the annotations module.

  Every test runs on a pure-tone WAV file that setUp synthesizes into a fresh
  temporary folder.
  """

  # Set to False to keep the temporary folder after each test (its path is
  # then logged by tearDown).
  _CLEAN_TMP_OUTPUT = True
  # Set to True to plot level and VAD output in testVoiceActivityDetectors
  # (requires matplotlib).
  _DEBUG_PLOT_VAD = False
  _VAD_TYPE_CLASS = annotations.AudioAnnotationsExtractor.VadType
  # Bitwise OR of all the VAD type flags. The tests iterate over
  # range(0, _ALL_VAD_TYPES + 1) which, assuming the flags are distinct bits,
  # covers every combination of VAD types (including none).
  _ALL_VAD_TYPES = (_VAD_TYPE_CLASS.ENERGY_THRESHOLD |
                    _VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO |
                    _VAD_TYPE_CLASS.WEBRTC_APM)

  def setUp(self):
    """Create a temporary folder and write a pure-tone WAV file into it."""
    self._tmp_path = tempfile.mkdtemp()
    self._wav_file_path = os.path.join(self._tmp_path, 'tone.wav')
    pure_tone, _ = input_signal_creator.InputSignalCreator.Create(
        'pure_tone', [440, 1000])
    signal_processing.SignalProcessingUtils.SaveWav(
        self._wav_file_path, pure_tone)
    self._sample_rate = pure_tone.frame_rate

  def tearDown(self):
    """Recursively delete the temporary folder unless told to keep it."""
    if self._CLEAN_TMP_OUTPUT:
      shutil.rmtree(self._tmp_path)
    else:
      logging.warning('%s did not clean the temporary path %s', self.id(),
                      self._tmp_path)

  def _GetAnnotationsFilePath(self, extractor, annotation_name):
    """Return the path of the annotations saved under annotation_name.

    Assumes extractor.Save() was called with self._tmp_path as output folder.
    """
    return os.path.join(
        self._tmp_path,
        extractor.GetOutputFileNameTemplate().format(annotation_name))

  @staticmethod
  def _PlotVad(extractor, vad_output):
    """Plot the extractor level and vad_output against time (debug only)."""
    import matplotlib.pyplot as plt  # Optional dependency, debug only.
    frame_times_s = lambda num_frames, frame_size_ms: np.arange(
        num_frames).astype(np.float32) * frame_size_ms / 1000.0
    level = extractor.GetLevel()
    t_level = frame_times_s(
        num_frames=len(level),
        frame_size_ms=extractor.GetLevelFrameSizeMs())
    t_vad = frame_times_s(
        num_frames=len(vad_output),
        frame_size_ms=extractor.GetVadFrameSizeMs())
    plt.figure()
    # pyplot.hold() was removed in matplotlib 3.0; successive plot() calls
    # already draw on the same axes.
    plt.plot(t_level, level)
    plt.plot(t_vad, np.asarray(vad_output) * np.max(level), '.')
    plt.show()

  def testFrameSizes(self):
    """Check that frame sizes in samples and in ms agree at the WAV rate."""
    e = annotations.AudioAnnotationsExtractor(self._ALL_VAD_TYPES)
    e.Extract(self._wav_file_path)
    samples_to_ms = lambda n, sr: 1000 * n // sr
    self.assertEqual(samples_to_ms(e.GetLevelFrameSize(), self._sample_rate),
                     e.GetLevelFrameSizeMs())
    self.assertEqual(samples_to_ms(e.GetVadFrameSize(), self._sample_rate),
                     e.GetVadFrameSizeMs())

  def testVoiceActivityDetectors(self):
    """Check every VAD type combination on the pure-tone input."""
    for vad_type_value in range(0, self._ALL_VAD_TYPES + 1):
      # Identifies the failing combination in assertion messages.
      msg = 'vad_type_value=%d' % vad_type_value
      vad_type = self._VAD_TYPE_CLASS(vad_type_value)
      e = annotations.AudioAnnotationsExtractor(vad_type=vad_type_value)
      e.Extract(self._wav_file_path)

      # Last VAD output checked below, only used for debug plotting. Reset on
      # every iteration so that a value from a previous combination is never
      # reused (and so it is defined when no such VAD is enabled).
      vad_output = None
      for single_output_vad in (self._VAD_TYPE_CLASS.ENERGY_THRESHOLD,
                                self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO):
        if not vad_type.Contains(single_output_vad):
          continue
        vad_output = e.GetVadOutput(single_output_vad)
        self.assertGreater(len(vad_output), 0, msg)
        # The mean VAD output over the tone must be at least 0.95.
        self.assertGreaterEqual(
            float(np.sum(vad_output)) / len(vad_output), 0.95, msg)

      if vad_type.Contains(self._VAD_TYPE_CLASS.WEBRTC_APM):
        # pylint: disable=unpacking-non-sequence
        (vad_probs, vad_rms) = e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)
        self.assertGreater(len(vad_probs), 0, msg)
        self.assertGreater(len(vad_rms), 0, msg)
        self.assertGreaterEqual(
            float(np.sum(vad_probs)) / len(vad_probs), 0.5, msg)
        self.assertGreaterEqual(
            float(np.sum(vad_rms)) / len(vad_rms), 20000, msg)

      if self._DEBUG_PLOT_VAD and vad_output is not None:
        self._PlotVad(e, vad_output)

  def testSaveLoad(self):
    """Check that saved annotations reload with the expected values/dtypes."""
    e = annotations.AudioAnnotationsExtractor(self._ALL_VAD_TYPES)
    e.Extract(self._wav_file_path)
    e.Save(self._tmp_path, "fake-annotation")

    # np.load() on an archive keeps the file open until it is closed; the
    # context manager closes it before tearDown deletes the folder.
    file_path = self._GetAnnotationsFilePath(e, "fake-annotation")
    with np.load(file_path) as data:
      np.testing.assert_array_equal(e.GetLevel(), data['level'])
      self.assertEqual(np.float32, data['level'].dtype)
      np.testing.assert_array_equal(
          e.GetVadOutput(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD),
          data['vad_energy_output'])
      np.testing.assert_array_equal(
          e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO),
          data['vad_output'])
      np.testing.assert_array_equal(
          e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)[0],
          data['vad_probs'])
      np.testing.assert_array_equal(
          e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)[1],
          data['vad_rms'])
      self.assertEqual(np.uint8, data['vad_energy_output'].dtype)
      self.assertEqual(np.float64, data['vad_probs'].dtype)
      self.assertEqual(np.float64, data['vad_rms'].dtype)

  def testEmptyExternalShouldNotCrash(self):
    """Check that an empty external VAD dict is accepted for every type."""
    for vad_type_value in range(0, self._ALL_VAD_TYPES + 1):
      annotations.AudioAnnotationsExtractor(vad_type_value, {})

  def testFakeExternalSaveLoad(self):
    """Check that a fake external VAD's output is saved and reloaded."""
    def FakeExternalFactory():
      return external_vad.ExternalVad(
          os.path.join(
              os.path.dirname(os.path.abspath(__file__)),
              'fake_external_vad.py'),
          'fake')

    for vad_type_value in range(0, self._ALL_VAD_TYPES + 1):
      e = annotations.AudioAnnotationsExtractor(
          vad_type_value,
          {'fake': FakeExternalFactory()})
      e.Extract(self._wav_file_path)
      e.Save(self._tmp_path, annotation_name="fake-annotation")
      # Close the archive on every iteration instead of leaking one open file
      # per VAD type combination.
      file_path = self._GetAnnotationsFilePath(e, "fake-annotation")
      with np.load(file_path) as data:
        self.assertEqual(np.float32, data['extvad_conf-fake'].dtype)
        np.testing.assert_almost_equal(np.arange(100, dtype=np.float32),
                                       data['extvad_conf-fake'])
158