# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Unit tests for the test_data_generation module.
"""

import os
import shutil
import tempfile
import unittest

import numpy as np
import scipy.io

from . import test_data_generation
from . import test_data_generation_factory
from . import signal_processing


class TestTestDataGenerators(unittest.TestCase):
    """Unit tests for the test_data_generation module.
    """

    def setUp(self):
        """Creates temporary folders and a fake AIR DB.

        The fake AIR DB is a temporary folder of .mat files, each holding a
        random 'h_air' array of shape (1, 1000), used in place of real
        impulse responses.
        """
        self._base_output_path = tempfile.mkdtemp()
        self._test_data_cache_path = tempfile.mkdtemp()
        self._fake_air_db_path = tempfile.mkdtemp()

        # Fake AIR DB impulse responses.
        # TODO(alessiob): ReverberationTestDataGenerator will change to allow
        # custom impulse responses. When changed, the coupling below between
        # impulse_response_mat_file_names and
        # ReverberationTestDataGenerator._IMPULSE_RESPONSES can be removed.
        impulse_response_mat_file_names = [
            'air_binaural_lecture_0_0_1.mat',
            'air_binaural_booth_0_0_1.mat',
        ]
        for impulse_response_mat_file_name in impulse_response_mat_file_names:
            data = {'h_air': np.random.rand(1, 1000).astype('<f8')}
            scipy.io.savemat(
                os.path.join(self._fake_air_db_path,
                             impulse_response_mat_file_name), data)

    def tearDown(self):
        """Recursively deletes the temporary folders created in setUp."""
        shutil.rmtree(self._base_output_path)
        shutil.rmtree(self._test_data_cache_path)
        shutil.rmtree(self._fake_air_db_path)

    def testTestDataGenerators(self):
        """Runs every registered generator and checks the generated pairs.

        Each generator is instantiated through the factory, run on a tone
        probing signal, and its output is checked for consistent list sizes,
        sufficient signal durations and existing output paths.
        """
        # Preliminary check.
        self.assertTrue(os.path.exists(self._base_output_path))
        self.assertTrue(os.path.exists(self._test_data_cache_path))

        # Check that there is at least one registered test data generator.
        registered_classes = (
            test_data_generation.TestDataGenerator.REGISTERED_CLASSES)
        self.assertIsInstance(registered_classes, dict)
        self.assertGreater(len(registered_classes), 0)

        # Instance generators factory.
        default_noise_tracks_path = (
            test_data_generation.AdditiveNoiseTestDataGenerator.
            DEFAULT_NOISE_TRACKS_PATH)
        generators_factory = (
            test_data_generation_factory.TestDataGeneratorFactory(
                aechen_ir_database_path=self._fake_air_db_path,
                noise_tracks_path=default_noise_tracks_path,
                copy_with_identity=False))
        generators_factory.SetOutputDirectoryPrefix('datagen-')

        # Use a simple input file as clean input signal.
        input_signal_filepath = os.path.join(os.getcwd(), 'probing_signals',
                                             'tone-880.wav')
        self.assertTrue(os.path.exists(input_signal_filepath))

        # Load input signal.
        input_signal = signal_processing.SignalProcessingUtils.LoadWav(
            input_signal_filepath)

        # Try each registered test data generator.
        for generator_class in registered_classes.values():
            # Instance test data generator.
            generator = generators_factory.GetInstance(generator_class)

            # Generate the noisy input - reference pairs.
            generator.Generate(input_signal_filepath=input_signal_filepath,
                               test_data_cache_path=self._test_data_cache_path,
                               base_output_path=self._base_output_path)

            # Perform checks.
            self._CheckGeneratedPairsListSizes(generator)
            self._CheckGeneratedPairsSignalDurations(generator, input_signal)
            self._CheckGeneratedPairsOutputPaths(generator)

    def testTestidentityDataGenerator(self):
        """Checks the |copy_with_identity| flag of IdentityTestDataGenerator.

        The generated noisy and reference file paths must differ from the
        input file path if and only if |copy_with_identity| is True.
        """
        # Preliminary check.
        self.assertTrue(os.path.exists(self._base_output_path))
        self.assertTrue(os.path.exists(self._test_data_cache_path))

        # Use a simple input file as clean input signal.
        input_signal_filepath = os.path.join(os.getcwd(), 'probing_signals',
                                             'tone-880.wav')
        self.assertTrue(os.path.exists(input_signal_filepath))

        def GetNoiseReferenceFilePaths(identity_generator):
            """Returns the only (noisy, reference) file path pair generated."""
            noisy_signal_filepaths = identity_generator.noisy_signal_filepaths
            reference_signal_filepaths = (
                identity_generator.reference_signal_filepaths)
            # unittest assertions instead of bare `assert`: they are not
            # stripped under `python -O` and report a diff on failure.
            self.assertEqual(set(noisy_signal_filepaths),
                             set(reference_signal_filepaths))
            self.assertEqual(len(noisy_signal_filepaths), 1)
            # dict.keys() is not indexable in Python 3; take the single key
            # by iterating the dict instead.
            key = next(iter(noisy_signal_filepaths))
            return noisy_signal_filepaths[key], reference_signal_filepaths[key]

        # Test the |copy_with_identity| flag.
        for copy_with_identity in [False, True]:
            # Instance the generator through the factory.
            factory = test_data_generation_factory.TestDataGeneratorFactory(
                aechen_ir_database_path='',
                noise_tracks_path='',
                copy_with_identity=copy_with_identity)
            factory.SetOutputDirectoryPrefix('datagen-')
            generator = factory.GetInstance(
                test_data_generation.IdentityTestDataGenerator)
            # Check |copy_with_identity| is set correctly.
            self.assertEqual(copy_with_identity, generator.copy_with_identity)

            # Generate test data and extract the paths to the noise and the
            # reference files.
            generator.Generate(input_signal_filepath=input_signal_filepath,
                               test_data_cache_path=self._test_data_cache_path,
                               base_output_path=self._base_output_path)
            noisy_signal_filepath, reference_signal_filepath = (
                GetNoiseReferenceFilePaths(generator))

            # Check that a copy is made if and only if |copy_with_identity| is
            # True.
            if copy_with_identity:
                self.assertNotEqual(noisy_signal_filepath,
                                    input_signal_filepath)
                self.assertNotEqual(reference_signal_filepath,
                                    input_signal_filepath)
            else:
                self.assertEqual(noisy_signal_filepath, input_signal_filepath)
                self.assertEqual(reference_signal_filepath,
                                 input_signal_filepath)

    def _CheckGeneratedPairsListSizes(self, generator):
        """Checks that every per-config collection has one entry per config.

        Args:
          generator: TestDataGenerator instance.
        """
        number_of_pairs = len(generator.config_names)
        self.assertEqual(number_of_pairs,
                         len(generator.noisy_signal_filepaths))
        self.assertEqual(number_of_pairs, len(generator.apm_output_paths))
        self.assertEqual(number_of_pairs,
                         len(generator.reference_signal_filepaths))

    def _CheckGeneratedPairsSignalDurations(self, generator, input_signal):
        """Checks duration of the generated signals.

        Checks that the noisy input and the reference tracks are audio files
        with duration equal to or greater than that of the input signal.

        Args:
          generator: TestDataGenerator instance.
          input_signal: AudioSegment instance.
        """
        def CountFileSamples(filepath):
            """Loads the wav file at |filepath| and returns its sample count."""
            return signal_processing.SignalProcessingUtils.CountSamples(
                signal_processing.SignalProcessingUtils.LoadWav(filepath))

        input_signal_length = (
            signal_processing.SignalProcessingUtils.CountSamples(input_signal))

        # Iterate over the noisy signal - reference pairs.
        for config_name in generator.config_names:
            # Check noisy input signal length.
            noisy_signal_length = CountFileSamples(
                generator.noisy_signal_filepaths[config_name])
            self.assertGreaterEqual(noisy_signal_length, input_signal_length)

            # Check reference signal length.
            reference_signal_length = CountFileSamples(
                generator.reference_signal_filepaths[config_name])
            self.assertGreaterEqual(reference_signal_length,
                                    input_signal_length)

    def _CheckGeneratedPairsOutputPaths(self, generator):
        """Checks that the output path created by the generator exists.

        Args:
          generator: TestDataGenerator instance.
        """
        # Iterate over the noisy signal - reference pairs.
        for config_name in generator.config_names:
            output_path = generator.apm_output_paths[config_name]
            self.assertTrue(os.path.exists(output_path))