# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Unit tests for the simulation module.
"""

import logging
import os
import shutil
import tempfile
import unittest

import mock
import pydub

from . import audioproc_wrapper
from . import eval_scores_factory
from . import evaluation
from . import external_vad
from . import signal_processing
from . import simulation
from . import test_data_generation_factory


class TestApmModuleSimulator(unittest.TestCase):
    """Unit tests for the ApmModuleSimulator class.
    """

    def setUp(self):
        """Create temporary folders and a fake (white noise) audio track."""
        self._output_path = tempfile.mkdtemp()
        self._tmp_path = tempfile.mkdtemp()

        # One second of silence at 48 kHz used as a template for the noise.
        silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
        fake_signal = (
            signal_processing.SignalProcessingUtils.GenerateWhiteNoise(silence))
        self._fake_audio_track_path = os.path.join(self._output_path,
                                                   'fake.wav')
        signal_processing.SignalProcessingUtils.SaveWav(
            self._fake_audio_track_path, fake_signal)

    def tearDown(self):
        """Recursively delete temporary folders."""
        shutil.rmtree(self._output_path)
        shutil.rmtree(self._tmp_path)

    @staticmethod
    def _CreateTestDataGeneratorFactory():
        """Returns a TestDataGeneratorFactory with empty database paths."""
        return test_data_generation_factory.TestDataGeneratorFactory(
            aechen_ir_database_path='',
            noise_tracks_path='',
            copy_with_identity=False)

    @staticmethod
    def _CreateEvaluationScoreFactory():
        """Returns an EvaluationScoreWorkerFactory pointing at 'fake_polqa'."""
        return eval_scores_factory.EvaluationScoreWorkerFactory(
            polqa_tool_bin_path=os.path.join(os.path.dirname(__file__),
                                             'fake_polqa'),
            echo_metric_tool_bin_path=None)

    def _CreateSimulator(self):
        """Returns an ApmModuleSimulator with non-mocked dependencies."""
        return simulation.ApmModuleSimulator(
            test_data_generator_factory=(
                self._CreateTestDataGeneratorFactory()),
            evaluation_score_factory=self._CreateEvaluationScoreFactory(),
            ap_wrapper=audioproc_wrapper.AudioProcWrapper(
                audioproc_wrapper.AudioProcWrapper.
                DEFAULT_APM_SIMULATOR_BIN_PATH),
            evaluator=evaluation.ApmModuleEvaluator())

    def testSimulation(self):
        """Checks that APM and the evaluator run at least once per simulation.

        AudioProcWrapper.Run and ApmModuleEvaluator.Run are replaced by mocks
        on the injected instances, so only the number of calls is verified.
        """
        # Instance dependencies to mock and inject.
        ap_wrapper = audioproc_wrapper.AudioProcWrapper(
            audioproc_wrapper.AudioProcWrapper.DEFAULT_APM_SIMULATOR_BIN_PATH)
        evaluator = evaluation.ApmModuleEvaluator()
        ap_wrapper.Run = mock.MagicMock(name='Run')
        evaluator.Run = mock.MagicMock(name='Run')

        # Instance non-mocked dependencies.
        test_data_generator_factory = self._CreateTestDataGeneratorFactory()
        evaluation_score_factory = self._CreateEvaluationScoreFactory()

        # Instance simulator.
        simulator = simulation.ApmModuleSimulator(
            test_data_generator_factory=test_data_generator_factory,
            evaluation_score_factory=evaluation_score_factory,
            ap_wrapper=ap_wrapper,
            evaluator=evaluator,
            external_vads={
                'fake':
                external_vad.ExternalVad(
                    os.path.join(os.path.dirname(__file__),
                                 'fake_external_vad.py'), 'fake')
            })

        # What to simulate.
        config_files = ['apm_configs/default.json']
        input_files = [self._fake_audio_track_path]
        test_data_generators = ['identity', 'white_noise']
        eval_scores = ['audio_level_mean', 'polqa']

        # Run all simulations.
        simulator.Run(config_filepaths=config_files,
                      capture_input_filepaths=input_files,
                      test_data_generator_names=test_data_generators,
                      eval_score_names=eval_scores,
                      output_dir=self._output_path)

        # Check.
        # TODO(alessiob): Once the TestDataGenerator classes can be configured
        # by the client code (e.g., number of SNR pairs for the white noise
        # test data generator), the exact number of calls to ap_wrapper.Run and
        # evaluator.Run is known; use that with assertEqual.
        min_number_of_simulations = len(config_files) * len(input_files) * len(
            test_data_generators)
        self.assertGreaterEqual(len(ap_wrapper.Run.call_args_list),
                                min_number_of_simulations)
        self.assertGreaterEqual(len(evaluator.Run.call_args_list),
                                min_number_of_simulations)

    def testInputSignalCreation(self):
        """Checks that missing input files exist after running a simulation."""
        # Instance simulator.
        simulator = self._CreateSimulator()

        # Inexistent input files to be silently created.
        input_files = [
            os.path.join(self._tmp_path, 'pure_tone-440_1000.wav'),
            os.path.join(self._tmp_path, 'pure_tone-1000_500.wav'),
        ]
        self.assertFalse(
            any(os.path.exists(input_file) for input_file in input_files))

        # The input files are created during the simulation.
        simulator.Run(config_filepaths=['apm_configs/default.json'],
                      capture_input_filepaths=input_files,
                      test_data_generator_names=['identity'],
                      eval_score_names=['audio_level_peak'],
                      output_dir=self._output_path)
        self.assertTrue(
            all(os.path.exists(input_file) for input_file in input_files))

    def testPureToneGenerationWithTotalHarmonicDistorsion(self):
        """Checks that THD only logs a warning with non-identity generators."""
        # Replace logging.warning for the duration of this test only. Assigning
        # the mock directly to logging.warning would leave the global patched
        # for every test that runs afterwards in the same process.
        warning_patcher = mock.patch.object(logging, 'warning',
                                            mock.MagicMock(name='warning'))
        warning_mock = warning_patcher.start()
        self.addCleanup(warning_patcher.stop)

        # Instance simulator.
        simulator = self._CreateSimulator()

        # What to simulate.
        config_files = ['apm_configs/default.json']
        input_files = [os.path.join(self._tmp_path, 'pure_tone-440_1000.wav')]
        eval_scores = ['thd']

        # Should work.
        simulator.Run(config_filepaths=config_files,
                      capture_input_filepaths=input_files,
                      test_data_generator_names=['identity'],
                      eval_score_names=eval_scores,
                      output_dir=self._output_path)
        self.assertFalse(warning_mock.called)

        # Warning expected.
        simulator.Run(
            config_filepaths=config_files,
            capture_input_filepaths=input_files,
            test_data_generator_names=['white_noise'],  # Not allowed with THD.
            eval_score_names=eval_scores,
            output_dir=self._output_path)
        warning_mock.assert_called_with('the evaluation failed: %s', (
            'The THD score cannot be used with any test data generator other than '
            '"identity"'))

        # NOTE(review): the disabled checks below reference names that are not
        # defined or imported in this file (test_data_generation,
        # self._test_data_cache_path, self._base_output_path); confirm whether
        # they should be restored or removed.
        #   # Init.
        #   generator = test_data_generation.IdentityTestDataGenerator('tmp')
        #   input_signal_filepath = os.path.join(
        #       self._test_data_cache_path, 'pure_tone-440_1000.wav')

        #   # Check that the input signal is generated.
        #   self.assertFalse(os.path.exists(input_signal_filepath))
        #   generator.Generate(
        #       input_signal_filepath=input_signal_filepath,
        #       test_data_cache_path=self._test_data_cache_path,
        #       base_output_path=self._base_output_path)
        #   self.assertTrue(os.path.exists(input_signal_filepath))

        #   # Check input signal properties.
        #   input_signal = signal_processing.SignalProcessingUtils.LoadWav(
        #       input_signal_filepath)
        #   self.assertEqual(1000, len(input_signal))