1# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. 2# 3# Use of this source code is governed by a BSD-style license 4# that can be found in the LICENSE file in the root of the source 5# tree. An additional intellectual property rights grant can be found 6# in the file PATENTS. All contributing project authors may 7# be found in the AUTHORS file in the root of the source tree. 8"""Unit tests for the annotations module. 9""" 10 11from __future__ import division 12import logging 13import os 14import shutil 15import tempfile 16import unittest 17 18import numpy as np 19 20from . import annotations 21from . import external_vad 22from . import input_signal_creator 23from . import signal_processing 24 25 26class TestAnnotationsExtraction(unittest.TestCase): 27 """Unit tests for the annotations module. 28 """ 29 30 _CLEAN_TMP_OUTPUT = True 31 _DEBUG_PLOT_VAD = False 32 _VAD_TYPE_CLASS = annotations.AudioAnnotationsExtractor.VadType 33 _ALL_VAD_TYPES = (_VAD_TYPE_CLASS.ENERGY_THRESHOLD 34 | _VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO 35 | _VAD_TYPE_CLASS.WEBRTC_APM) 36 37 def setUp(self): 38 """Create temporary folder.""" 39 self._tmp_path = tempfile.mkdtemp() 40 self._wav_file_path = os.path.join(self._tmp_path, 'tone.wav') 41 pure_tone, _ = input_signal_creator.InputSignalCreator.Create( 42 'pure_tone', [440, 1000]) 43 signal_processing.SignalProcessingUtils.SaveWav( 44 self._wav_file_path, pure_tone) 45 self._sample_rate = pure_tone.frame_rate 46 47 def tearDown(self): 48 """Recursively delete temporary folder.""" 49 if self._CLEAN_TMP_OUTPUT: 50 shutil.rmtree(self._tmp_path) 51 else: 52 logging.warning(self.id() + ' did not clean the temporary path ' + 53 (self._tmp_path)) 54 55 def testFrameSizes(self): 56 e = annotations.AudioAnnotationsExtractor(self._ALL_VAD_TYPES) 57 e.Extract(self._wav_file_path) 58 samples_to_ms = lambda n, sr: 1000 * n // sr 59 self.assertEqual( 60 samples_to_ms(e.GetLevelFrameSize(), self._sample_rate), 61 e.GetLevelFrameSizeMs()) 62 self.assertEqual(samples_to_ms(e.GetVadFrameSize(), self._sample_rate), 63 e.GetVadFrameSizeMs()) 64 65 def testVoiceActivityDetectors(self): 66 for vad_type_value in range(0, self._ALL_VAD_TYPES + 1): 67 vad_type = self._VAD_TYPE_CLASS(vad_type_value) 68 e = annotations.AudioAnnotationsExtractor(vad_type=vad_type_value) 69 e.Extract(self._wav_file_path) 70 if vad_type.Contains(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD): 71 # pylint: disable=unpacking-non-sequence 72 vad_output = e.GetVadOutput( 73 self._VAD_TYPE_CLASS.ENERGY_THRESHOLD) 74 self.assertGreater(len(vad_output), 0) 75 self.assertGreaterEqual( 76 float(np.sum(vad_output)) / len(vad_output), 0.95) 77 78 if vad_type.Contains(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO): 79 # pylint: disable=unpacking-non-sequence 80 vad_output = e.GetVadOutput( 81 self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO) 82 self.assertGreater(len(vad_output), 0) 83 self.assertGreaterEqual( 84 float(np.sum(vad_output)) / len(vad_output), 0.95) 85 86 if vad_type.Contains(self._VAD_TYPE_CLASS.WEBRTC_APM): 87 # pylint: disable=unpacking-non-sequence 88 (vad_probs, 89 vad_rms) = e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM) 90 self.assertGreater(len(vad_probs), 0) 91 self.assertGreater(len(vad_rms), 0) 92 self.assertGreaterEqual( 93 float(np.sum(vad_probs)) / len(vad_probs), 0.5) 94 self.assertGreaterEqual( 95 float(np.sum(vad_rms)) / len(vad_rms), 20000) 96 97 if self._DEBUG_PLOT_VAD: 98 frame_times_s = lambda num_frames, frame_size_ms: np.arange( 99 num_frames).astype(np.float32) * frame_size_ms / 1000.0 100 level = e.GetLevel() 101 t_level = frame_times_s(num_frames=len(level), 102 frame_size_ms=e.GetLevelFrameSizeMs()) 103 t_vad = frame_times_s(num_frames=len(vad_output), 104 frame_size_ms=e.GetVadFrameSizeMs()) 105 import matplotlib.pyplot as plt 106 plt.figure() 107 plt.hold(True) 108 plt.plot(t_level, level) 109 plt.plot(t_vad, vad_output * np.max(level), '.') 110 plt.show() 111 112 def testSaveLoad(self): 113 e = annotations.AudioAnnotationsExtractor(self._ALL_VAD_TYPES) 114 e.Extract(self._wav_file_path) 115 e.Save(self._tmp_path, "fake-annotation") 116 117 data = np.load( 118 os.path.join( 119 self._tmp_path, 120 e.GetOutputFileNameTemplate().format("fake-annotation"))) 121 np.testing.assert_array_equal(e.GetLevel(), data['level']) 122 self.assertEqual(np.float32, data['level'].dtype) 123 np.testing.assert_array_equal( 124 e.GetVadOutput(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD), 125 data['vad_energy_output']) 126 np.testing.assert_array_equal( 127 e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO), 128 data['vad_output']) 129 np.testing.assert_array_equal( 130 e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)[0], 131 data['vad_probs']) 132 np.testing.assert_array_equal( 133 e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)[1], 134 data['vad_rms']) 135 self.assertEqual(np.uint8, data['vad_energy_output'].dtype) 136 self.assertEqual(np.float64, data['vad_probs'].dtype) 137 self.assertEqual(np.float64, data['vad_rms'].dtype) 138 139 def testEmptyExternalShouldNotCrash(self): 140 for vad_type_value in range(0, self._ALL_VAD_TYPES + 1): 141 annotations.AudioAnnotationsExtractor(vad_type_value, {}) 142 143 def testFakeExternalSaveLoad(self): 144 def FakeExternalFactory(): 145 return external_vad.ExternalVad( 146 os.path.join(os.path.dirname(os.path.abspath(__file__)), 147 'fake_external_vad.py'), 'fake') 148 149 for vad_type_value in range(0, self._ALL_VAD_TYPES + 1): 150 e = annotations.AudioAnnotationsExtractor( 151 vad_type_value, {'fake': FakeExternalFactory()}) 152 e.Extract(self._wav_file_path) 153 e.Save(self._tmp_path, annotation_name="fake-annotation") 154 data = np.load( 155 os.path.join( 156 self._tmp_path, 157 e.GetOutputFileNameTemplate().format("fake-annotation"))) 158 self.assertEqual(np.float32, data['extvad_conf-fake'].dtype) 159 np.testing.assert_almost_equal(np.arange(100, dtype=np.float32), 160 data['extvad_conf-fake']) 161