1# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. 2# 3# Use of this source code is governed by a BSD-style license 4# that can be found in the LICENSE file in the root of the source 5# tree. An additional intellectual property rights grant can be found 6# in the file PATENTS. All contributing project authors may 7# be found in the AUTHORS file in the root of the source tree. 8 9"""Unit tests for the annotations module. 10""" 11 12from __future__ import division 13import logging 14import os 15import shutil 16import tempfile 17import unittest 18 19import numpy as np 20 21from . import annotations 22from . import external_vad 23from . import input_signal_creator 24from . import signal_processing 25 26 27class TestAnnotationsExtraction(unittest.TestCase): 28 """Unit tests for the annotations module. 29 """ 30 31 _CLEAN_TMP_OUTPUT = True 32 _DEBUG_PLOT_VAD = False 33 _VAD_TYPE_CLASS = annotations.AudioAnnotationsExtractor.VadType 34 _ALL_VAD_TYPES = (_VAD_TYPE_CLASS.ENERGY_THRESHOLD | 35 _VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO | 36 _VAD_TYPE_CLASS.WEBRTC_APM) 37 38 39 def setUp(self): 40 """Create temporary folder.""" 41 self._tmp_path = tempfile.mkdtemp() 42 self._wav_file_path = os.path.join(self._tmp_path, 'tone.wav') 43 pure_tone, _ = input_signal_creator.InputSignalCreator.Create( 44 'pure_tone', [440, 1000]) 45 signal_processing.SignalProcessingUtils.SaveWav( 46 self._wav_file_path, pure_tone) 47 self._sample_rate = pure_tone.frame_rate 48 49 def tearDown(self): 50 """Recursively delete temporary folder.""" 51 if self._CLEAN_TMP_OUTPUT: 52 shutil.rmtree(self._tmp_path) 53 else: 54 logging.warning(self.id() + ' did not clean the temporary path ' + ( 55 self._tmp_path)) 56 57 def testFrameSizes(self): 58 e = annotations.AudioAnnotationsExtractor(self._ALL_VAD_TYPES) 59 e.Extract(self._wav_file_path) 60 samples_to_ms = lambda n, sr: 1000 * n // sr 61 self.assertEqual(samples_to_ms(e.GetLevelFrameSize(), self._sample_rate), 62 e.GetLevelFrameSizeMs()) 63 self.assertEqual(samples_to_ms(e.GetVadFrameSize(), self._sample_rate), 64 e.GetVadFrameSizeMs()) 65 66 def testVoiceActivityDetectors(self): 67 for vad_type_value in range(0, self._ALL_VAD_TYPES+1): 68 vad_type = self._VAD_TYPE_CLASS(vad_type_value) 69 e = annotations.AudioAnnotationsExtractor(vad_type=vad_type_value) 70 e.Extract(self._wav_file_path) 71 if vad_type.Contains(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD): 72 # pylint: disable=unpacking-non-sequence 73 vad_output = e.GetVadOutput(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD) 74 self.assertGreater(len(vad_output), 0) 75 self.assertGreaterEqual(float(np.sum(vad_output)) / len(vad_output), 76 0.95) 77 78 if vad_type.Contains(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO): 79 # pylint: disable=unpacking-non-sequence 80 vad_output = e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO) 81 self.assertGreater(len(vad_output), 0) 82 self.assertGreaterEqual(float(np.sum(vad_output)) / len(vad_output), 83 0.95) 84 85 if vad_type.Contains(self._VAD_TYPE_CLASS.WEBRTC_APM): 86 # pylint: disable=unpacking-non-sequence 87 (vad_probs, vad_rms) = e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM) 88 self.assertGreater(len(vad_probs), 0) 89 self.assertGreater(len(vad_rms), 0) 90 self.assertGreaterEqual(float(np.sum(vad_probs)) / len(vad_probs), 91 0.5) 92 self.assertGreaterEqual(float(np.sum(vad_rms)) / len(vad_rms), 20000) 93 94 if self._DEBUG_PLOT_VAD: 95 frame_times_s = lambda num_frames, frame_size_ms: np.arange( 96 num_frames).astype(np.float32) * frame_size_ms / 1000.0 97 level = e.GetLevel() 98 t_level = frame_times_s( 99 num_frames=len(level), 100 frame_size_ms=e.GetLevelFrameSizeMs()) 101 t_vad = frame_times_s( 102 num_frames=len(vad_output), 103 frame_size_ms=e.GetVadFrameSizeMs()) 104 import matplotlib.pyplot as plt 105 plt.figure() 106 plt.hold(True) 107 plt.plot(t_level, level) 108 plt.plot(t_vad, vad_output * np.max(level), '.') 109 plt.show() 110 111 def testSaveLoad(self): 112 e = annotations.AudioAnnotationsExtractor(self._ALL_VAD_TYPES) 113 e.Extract(self._wav_file_path) 114 e.Save(self._tmp_path, "fake-annotation") 115 116 data = np.load(os.path.join( 117 self._tmp_path, 118 e.GetOutputFileNameTemplate().format("fake-annotation"))) 119 np.testing.assert_array_equal(e.GetLevel(), data['level']) 120 self.assertEqual(np.float32, data['level'].dtype) 121 np.testing.assert_array_equal( 122 e.GetVadOutput(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD), 123 data['vad_energy_output']) 124 np.testing.assert_array_equal( 125 e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO), 126 data['vad_output']) 127 np.testing.assert_array_equal( 128 e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)[0], data['vad_probs']) 129 np.testing.assert_array_equal( 130 e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)[1], data['vad_rms']) 131 self.assertEqual(np.uint8, data['vad_energy_output'].dtype) 132 self.assertEqual(np.float64, data['vad_probs'].dtype) 133 self.assertEqual(np.float64, data['vad_rms'].dtype) 134 135 def testEmptyExternalShouldNotCrash(self): 136 for vad_type_value in range(0, self._ALL_VAD_TYPES+1): 137 annotations.AudioAnnotationsExtractor(vad_type_value, {}) 138 139 def testFakeExternalSaveLoad(self): 140 def FakeExternalFactory(): 141 return external_vad.ExternalVad( 142 os.path.join( 143 os.path.dirname(os.path.abspath(__file__)), 'fake_external_vad.py'), 144 'fake' 145 ) 146 for vad_type_value in range(0, self._ALL_VAD_TYPES+1): 147 e = annotations.AudioAnnotationsExtractor( 148 vad_type_value, 149 {'fake': FakeExternalFactory()}) 150 e.Extract(self._wav_file_path) 151 e.Save(self._tmp_path, annotation_name="fake-annotation") 152 data = np.load(os.path.join( 153 self._tmp_path, 154 e.GetOutputFileNameTemplate().format("fake-annotation"))) 155 self.assertEqual(np.float32, data['extvad_conf-fake'].dtype) 156 np.testing.assert_almost_equal(np.arange(100, dtype=np.float32), 157 data['extvad_conf-fake']) 158