# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.

"""Unit tests for the test_data_generation module.
"""

import os
import shutil
import tempfile
import unittest

import numpy as np
import scipy.io

from . import test_data_generation
from . import test_data_generation_factory
from . import signal_processing


class TestTestDataGenerators(unittest.TestCase):
  """Unit tests for the test_data_generation module.
  """

  def setUp(self):
    """Creates temporary folders and a fake AIR DB of impulse responses.

    Every temporary folder is scheduled for recursive deletion with
    addCleanup() right after it is created. Unlike tearDown(), cleanup
    functions also run when setUp() itself fails, so no folder is leaked if,
    e.g., writing a fake impulse response raises.
    """
    self._base_output_path = self._CreateTemporaryFolder()
    self._test_data_cache_path = self._CreateTemporaryFolder()
    self._fake_air_db_path = self._CreateTemporaryFolder()

    # Fake AIR DB impulse responses.
    # TODO(alessiob): ReverberationTestDataGenerator will change to allow custom
    # impulse responses. When changed, the coupling below between
    # impulse_response_mat_file_names and
    # ReverberationTestDataGenerator._IMPULSE_RESPONSES can be removed.
    impulse_response_mat_file_names = [
        'air_binaural_lecture_0_0_1.mat',
        'air_binaural_booth_0_0_1.mat',
    ]
    for impulse_response_mat_file_name in impulse_response_mat_file_names:
      # Random 1x1000 float64 ('<f8') array saved as the 'h_air' variable.
      data = {'h_air': np.random.rand(1, 1000).astype('<f8')}
      scipy.io.savemat(os.path.join(
          self._fake_air_db_path, impulse_response_mat_file_name), data)

  def _CreateTemporaryFolder(self):
    """Creates a temporary folder that is removed when the test ends.

    Returns:
      The path of the newly created folder.
    """
    path = tempfile.mkdtemp()
    self.addCleanup(shutil.rmtree, path)
    return path

  def _GetProbingSignalFilepath(self):
    """Returns the path of the clean input signal used by the tests.

    The path is resolved against the current working directory, so the tests
    must run from the folder that contains 'probing_signals'. Fails the test
    if the file does not exist.
    """
    input_signal_filepath = os.path.join(
        os.getcwd(), 'probing_signals', 'tone-880.wav')
    self.assertTrue(os.path.exists(input_signal_filepath))
    return input_signal_filepath

  def testTestDataGenerators(self):
    """Runs every registered generator and checks the pairs it produces."""
    # Preliminary check.
    self.assertTrue(os.path.exists(self._base_output_path))
    self.assertTrue(os.path.exists(self._test_data_cache_path))

    # Check that there is at least one registered test data generator.
    registered_classes = (
        test_data_generation.TestDataGenerator.REGISTERED_CLASSES)
    self.assertIsInstance(registered_classes, dict)
    self.assertGreater(len(registered_classes), 0)

    # Instance generators factory.
    default_noise_tracks_path = (
        test_data_generation.AdditiveNoiseTestDataGenerator
        .DEFAULT_NOISE_TRACKS_PATH)
    generators_factory = test_data_generation_factory.TestDataGeneratorFactory(
        aechen_ir_database_path=self._fake_air_db_path,
        noise_tracks_path=default_noise_tracks_path,
        copy_with_identity=False)
    generators_factory.SetOutputDirectoryPrefix('datagen-')

    # Use a simple input file as clean input signal and load it.
    input_signal_filepath = self._GetProbingSignalFilepath()
    input_signal = signal_processing.SignalProcessingUtils.LoadWav(
        input_signal_filepath)

    # Try each registered test data generator.
    for generator_class in registered_classes.values():
      # Instance test data generator.
      generator = generators_factory.GetInstance(generator_class)

      # Generate the noisy input - reference pairs.
      generator.Generate(
          input_signal_filepath=input_signal_filepath,
          test_data_cache_path=self._test_data_cache_path,
          base_output_path=self._base_output_path)

      # Perform checks.
      self._CheckGeneratedPairsListSizes(generator)
      self._CheckGeneratedPairsSignalDurations(generator, input_signal)
      self._CheckGeneratedPairsOutputPaths(generator)

  def testTestidentityDataGenerator(self):
    """Checks that IdentityTestDataGenerator honors |copy_with_identity|.

    Without copy, the noisy and reference paths must be the input file itself;
    with copy, both must differ from the input file path.
    """
    # Preliminary check.
    self.assertTrue(os.path.exists(self._base_output_path))
    self.assertTrue(os.path.exists(self._test_data_cache_path))

    # Use a simple input file as clean input signal.
    input_signal_filepath = self._GetProbingSignalFilepath()

    def GetNoiseReferenceFilePaths(identity_generator):
      """Returns the (noisy, reference) paths of a single-config generator.

      Fails the test unless the noisy and reference mappings have the same
      keys and exactly one configuration was generated.
      """
      noisy_signal_filepaths = identity_generator.noisy_signal_filepaths
      reference_signal_filepaths = identity_generator.reference_signal_filepaths
      # Compare as sets: key order is irrelevant (and differs across Python
      # versions when keys() returns a list).
      self.assertEqual(set(noisy_signal_filepaths),
                       set(reference_signal_filepaths))
      self.assertEqual(1, len(noisy_signal_filepaths))
      # dict.keys() is not subscriptable on Python 3; take the only key via an
      # iterator, which works on both Python 2 and 3.
      key = next(iter(noisy_signal_filepaths))
      return noisy_signal_filepaths[key], reference_signal_filepaths[key]

    # Test the |copy_with_identity| flag.
    for copy_with_identity in [False, True]:
      # Instance the generator through the factory.
      factory = test_data_generation_factory.TestDataGeneratorFactory(
          aechen_ir_database_path='', noise_tracks_path='',
          copy_with_identity=copy_with_identity)
      factory.SetOutputDirectoryPrefix('datagen-')
      generator = factory.GetInstance(
          test_data_generation.IdentityTestDataGenerator)
      # Check |copy_with_identity| is set correctly.
      self.assertEqual(copy_with_identity, generator.copy_with_identity)

      # Generate test data and extract the paths to the noisy and reference
      # files.
      generator.Generate(
          input_signal_filepath=input_signal_filepath,
          test_data_cache_path=self._test_data_cache_path,
          base_output_path=self._base_output_path)
      noisy_signal_filepath, reference_signal_filepath = (
          GetNoiseReferenceFilePaths(generator))

      # Check that a copy is made if and only if |copy_with_identity| is True.
      if copy_with_identity:
        self.assertNotEqual(noisy_signal_filepath, input_signal_filepath)
        self.assertNotEqual(reference_signal_filepath, input_signal_filepath)
      else:
        self.assertEqual(noisy_signal_filepath, input_signal_filepath)
        self.assertEqual(reference_signal_filepath, input_signal_filepath)

  def _CheckGeneratedPairsListSizes(self, generator):
    """Checks that every configuration has all of its generated outputs.

    Asserts that |generator| holds as many noisy signal paths, APM output
    paths and reference signal paths as it has configuration names.

    Args:
      generator: TestDataGenerator instance.
    """
    config_names = generator.config_names
    number_of_pairs = len(config_names)
    self.assertEqual(number_of_pairs,
                     len(generator.noisy_signal_filepaths))
    self.assertEqual(number_of_pairs,
                     len(generator.apm_output_paths))
    self.assertEqual(number_of_pairs,
                     len(generator.reference_signal_filepaths))

  def _CheckGeneratedPairsSignalDurations(
      self, generator, input_signal):
    """Checks duration of the generated signals.

    Checks that the noisy input and the reference tracks are audio files
    with duration equal to or greater than that of the input signal.

    Args:
      generator: TestDataGenerator instance.
      input_signal: AudioSegment instance.
    """
    input_signal_length = (
        signal_processing.SignalProcessingUtils.CountSamples(input_signal))

    # Iterate over the noisy signal - reference pairs.
    for config_name in generator.config_names:
      # Load the noisy input file.
      noisy_signal_filepath = generator.noisy_signal_filepaths[
          config_name]
      noisy_signal = signal_processing.SignalProcessingUtils.LoadWav(
          noisy_signal_filepath)

      # Check noisy input signal length.
      noisy_signal_length = (
          signal_processing.SignalProcessingUtils.CountSamples(noisy_signal))
      self.assertGreaterEqual(noisy_signal_length, input_signal_length)

      # Load the reference file.
      reference_signal_filepath = generator.reference_signal_filepaths[
          config_name]
      reference_signal = signal_processing.SignalProcessingUtils.LoadWav(
          reference_signal_filepath)

      # Check reference signal length.
      reference_signal_length = (
          signal_processing.SignalProcessingUtils.CountSamples(
              reference_signal))
      self.assertGreaterEqual(reference_signal_length, input_signal_length)

  def _CheckGeneratedPairsOutputPaths(self, generator):
    """Checks that the output path created by the generator exists.

    Args:
      generator: TestDataGenerator instance.
    """
    # Iterate over the noisy signal - reference pairs.
    for config_name in generator.config_names:
      output_path = generator.apm_output_paths[config_name]
      self.assertTrue(os.path.exists(output_path))