# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS.  All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.

"""Unit tests for the simulation module.
"""

import logging
import os
import shutil
import sys
import tempfile
import unittest

SRC = os.path.abspath(os.path.join(
    os.path.dirname(__file__), os.pardir, os.pardir, os.pardir, os.pardir))
sys.path.append(os.path.join(SRC, 'third_party', 'pymock'))

import mock
import pydub

from . import audioproc_wrapper
from . import eval_scores_factory
from . import evaluation
from . import external_vad
from . import signal_processing
from . import simulation
from . import test_data_generation_factory


class TestApmModuleSimulator(unittest.TestCase):
  """Unit tests for the ApmModuleSimulator class.
  """

  def setUp(self):
    """Create temporary folders and fake audio track."""
    # Register each folder for removal as soon as it exists: if a later step
    # of setUp() raises, tearDown() is not called, but cleanups still run.
    self._output_path = tempfile.mkdtemp()
    self.addCleanup(shutil.rmtree, self._output_path)
    self._tmp_path = tempfile.mkdtemp()
    self.addCleanup(shutil.rmtree, self._tmp_path)

    silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
    fake_signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
        silence)
    self._fake_audio_track_path = os.path.join(self._output_path, 'fake.wav')
    signal_processing.SignalProcessingUtils.SaveWav(
        self._fake_audio_track_path, fake_signal)

  @staticmethod
  def _CreateSimulator(ap_wrapper=None, evaluator=None, external_vads=None):
    """Creates an ApmModuleSimulator with the dependencies shared by the tests.

    The test data generator factory and the evaluation score factory are always
    the non-mocked ones used by every test in this class (no AECHEN IR
    database, no noise tracks, fake POLQA binary, no echo metric tool).

    Args:
      ap_wrapper: AudioProcWrapper to inject; if None, a new one using
        AudioProcWrapper.DEFAULT_APM_SIMULATOR_BIN_PATH is created.
      evaluator: ApmModuleEvaluator to inject; if None, a new one is created.
      external_vads: optional dict forwarded as the `external_vads` argument;
        when None the argument is not passed at all, so the simulator's own
        default applies.

    Returns:
      A simulation.ApmModuleSimulator instance.
    """
    if ap_wrapper is None:
      ap_wrapper = audioproc_wrapper.AudioProcWrapper(
          audioproc_wrapper.AudioProcWrapper.DEFAULT_APM_SIMULATOR_BIN_PATH)
    if evaluator is None:
      evaluator = evaluation.ApmModuleEvaluator()

    # Only pass external_vads when explicitly given, mirroring the original
    # call sites that omitted the argument.
    optional_kwargs = {}
    if external_vads is not None:
      optional_kwargs['external_vads'] = external_vads

    return simulation.ApmModuleSimulator(
        test_data_generator_factory=(
            test_data_generation_factory.TestDataGeneratorFactory(
                aechen_ir_database_path='',
                noise_tracks_path='',
                copy_with_identity=False)),
        evaluation_score_factory=(
            eval_scores_factory.EvaluationScoreWorkerFactory(
                polqa_tool_bin_path=os.path.join(
                    os.path.dirname(__file__), 'fake_polqa'),
                echo_metric_tool_bin_path=None)),
        ap_wrapper=ap_wrapper,
        evaluator=evaluator,
        **optional_kwargs)

  def testSimulation(self):
    """Checks that APM and the evaluator run at least once per simulation."""
    # Instance dependencies to mock and inject.
    ap_wrapper = audioproc_wrapper.AudioProcWrapper(
        audioproc_wrapper.AudioProcWrapper.DEFAULT_APM_SIMULATOR_BIN_PATH)
    evaluator = evaluation.ApmModuleEvaluator()
    ap_wrapper.Run = mock.MagicMock(name='Run')
    evaluator.Run = mock.MagicMock(name='Run')

    # Instance simulator.
    simulator = self._CreateSimulator(
        ap_wrapper=ap_wrapper,
        evaluator=evaluator,
        external_vads={'fake': external_vad.ExternalVad(os.path.join(
            os.path.dirname(__file__), 'fake_external_vad.py'), 'fake')})

    # What to simulate.
    config_files = ['apm_configs/default.json']
    input_files = [self._fake_audio_track_path]
    test_data_generators = ['identity', 'white_noise']
    eval_scores = ['audio_level_mean', 'polqa']

    # Run all simulations.
    simulator.Run(
        config_filepaths=config_files,
        capture_input_filepaths=input_files,
        test_data_generator_names=test_data_generators,
        eval_score_names=eval_scores,
        output_dir=self._output_path)

    # Check.
    # TODO(alessiob): Once the TestDataGenerator classes can be configured by
    # the client code (e.g., number of SNR pairs for the white noise test data
    # generator), the exact number of calls to ap_wrapper.Run and evaluator.Run
    # is known; use that with assertEqual.
    min_number_of_simulations = len(config_files) * len(input_files) * len(
        test_data_generators)
    self.assertGreaterEqual(len(ap_wrapper.Run.call_args_list),
                            min_number_of_simulations)
    self.assertGreaterEqual(len(evaluator.Run.call_args_list),
                            min_number_of_simulations)

  def testInputSignalCreation(self):
    """Checks that missing pure-tone input files are created by Run()."""
    # Instance simulator.
    simulator = self._CreateSimulator()

    # Inexistent input files to be silently created.
    input_files = [
        os.path.join(self._tmp_path, 'pure_tone-440_1000.wav'),
        os.path.join(self._tmp_path, 'pure_tone-1000_500.wav'),
    ]
    self.assertFalse(any(os.path.exists(input_file)
                         for input_file in input_files))

    # The input files are created during the simulation.
    simulator.Run(
        config_filepaths=['apm_configs/default.json'],
        capture_input_filepaths=input_files,
        test_data_generator_names=['identity'],
        eval_score_names=['audio_level_peak'],
        output_dir=self._output_path)
    self.assertTrue(all(os.path.exists(input_file)
                        for input_file in input_files))

  def testPureToneGenerationWithTotalHarmonicDistorsion(self):
    """Checks that THD is accepted with 'identity' only, else a warning."""
    # Patch logging.warning only for the duration of this test. Assigning a
    # mock to the module attribute directly would never be undone and would
    # leak into every test (and other code) run later in the same process.
    with mock.patch.object(logging, 'warning') as mock_warning:
      # Instance simulator.
      simulator = self._CreateSimulator()

      # What to simulate.
      config_files = ['apm_configs/default.json']
      input_files = [os.path.join(self._tmp_path, 'pure_tone-440_1000.wav')]
      eval_scores = ['thd']

      # Should work.
      simulator.Run(
          config_filepaths=config_files,
          capture_input_filepaths=input_files,
          test_data_generator_names=['identity'],
          eval_score_names=eval_scores,
          output_dir=self._output_path)
      self.assertFalse(mock_warning.called)

      # Warning expected.
      simulator.Run(
          config_filepaths=config_files,
          capture_input_filepaths=input_files,
          test_data_generator_names=['white_noise'],  # Not allowed with THD.
          eval_score_names=eval_scores,
          output_dir=self._output_path)
      mock_warning.assert_called_with('the evaluation failed: %s', (
          'The THD score cannot be used with any test data generator other than '
          '"identity"'))