1# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
2#
3# Use of this source code is governed by a BSD-style license
4# that can be found in the LICENSE file in the root of the source
5# tree. An additional intellectual property rights grant can be found
6# in the file PATENTS.  All contributing project authors may
7# be found in the AUTHORS file in the root of the source tree.
8"""Unit tests for the test_data_generation module.
9"""
10
11import os
12import shutil
13import tempfile
14import unittest
15
16import numpy as np
17import scipy.io
18
19from . import test_data_generation
20from . import test_data_generation_factory
21from . import signal_processing
22
23
24class TestTestDataGenerators(unittest.TestCase):
25    """Unit tests for the test_data_generation module.
26  """
27
28    def setUp(self):
29        """Create temporary folders."""
30        self._base_output_path = tempfile.mkdtemp()
31        self._test_data_cache_path = tempfile.mkdtemp()
32        self._fake_air_db_path = tempfile.mkdtemp()
33
34        # Fake AIR DB impulse responses.
35        # TODO(alessiob): ReverberationTestDataGenerator will change to allow custom
36        # impulse responses. When changed, the coupling below between
37        # impulse_response_mat_file_names and
38        # ReverberationTestDataGenerator._IMPULSE_RESPONSES can be removed.
39        impulse_response_mat_file_names = [
40            'air_binaural_lecture_0_0_1.mat',
41            'air_binaural_booth_0_0_1.mat',
42        ]
43        for impulse_response_mat_file_name in impulse_response_mat_file_names:
44            data = {'h_air': np.random.rand(1, 1000).astype('<f8')}
45            scipy.io.savemat(
46                os.path.join(self._fake_air_db_path,
47                             impulse_response_mat_file_name), data)
48
49    def tearDown(self):
50        """Recursively delete temporary folders."""
51        shutil.rmtree(self._base_output_path)
52        shutil.rmtree(self._test_data_cache_path)
53        shutil.rmtree(self._fake_air_db_path)
54
55    def testTestDataGenerators(self):
56        # Preliminary check.
57        self.assertTrue(os.path.exists(self._base_output_path))
58        self.assertTrue(os.path.exists(self._test_data_cache_path))
59
60        # Check that there is at least one registered test data generator.
61        registered_classes = (
62            test_data_generation.TestDataGenerator.REGISTERED_CLASSES)
63        self.assertIsInstance(registered_classes, dict)
64        self.assertGreater(len(registered_classes), 0)
65
66        # Instance generators factory.
67        generators_factory = test_data_generation_factory.TestDataGeneratorFactory(
68            aechen_ir_database_path=self._fake_air_db_path,
69            noise_tracks_path=test_data_generation.  \
70                              AdditiveNoiseTestDataGenerator.  \
71                              DEFAULT_NOISE_TRACKS_PATH,
72            copy_with_identity=False)
73        generators_factory.SetOutputDirectoryPrefix('datagen-')
74
75        # Use a simple input file as clean input signal.
76        input_signal_filepath = os.path.join(os.getcwd(), 'probing_signals',
77                                             'tone-880.wav')
78        self.assertTrue(os.path.exists(input_signal_filepath))
79
80        # Load input signal.
81        input_signal = signal_processing.SignalProcessingUtils.LoadWav(
82            input_signal_filepath)
83
84        # Try each registered test data generator.
85        for generator_name in registered_classes:
86            # Instance test data generator.
87            generator = generators_factory.GetInstance(
88                registered_classes[generator_name])
89
90            # Generate the noisy input - reference pairs.
91            generator.Generate(input_signal_filepath=input_signal_filepath,
92                               test_data_cache_path=self._test_data_cache_path,
93                               base_output_path=self._base_output_path)
94
95            # Perform checks.
96            self._CheckGeneratedPairsListSizes(generator)
97            self._CheckGeneratedPairsSignalDurations(generator, input_signal)
98            self._CheckGeneratedPairsOutputPaths(generator)
99
100    def testTestidentityDataGenerator(self):
101        # Preliminary check.
102        self.assertTrue(os.path.exists(self._base_output_path))
103        self.assertTrue(os.path.exists(self._test_data_cache_path))
104
105        # Use a simple input file as clean input signal.
106        input_signal_filepath = os.path.join(os.getcwd(), 'probing_signals',
107                                             'tone-880.wav')
108        self.assertTrue(os.path.exists(input_signal_filepath))
109
110        def GetNoiseReferenceFilePaths(identity_generator):
111            noisy_signal_filepaths = identity_generator.noisy_signal_filepaths
112            reference_signal_filepaths = identity_generator.reference_signal_filepaths
113            assert noisy_signal_filepaths.keys(
114            ) == reference_signal_filepaths.keys()
115            assert len(noisy_signal_filepaths.keys()) == 1
116            key = noisy_signal_filepaths.keys()[0]
117            return noisy_signal_filepaths[key], reference_signal_filepaths[key]
118
119        # Test the |copy_with_identity| flag.
120        for copy_with_identity in [False, True]:
121            # Instance the generator through the factory.
122            factory = test_data_generation_factory.TestDataGeneratorFactory(
123                aechen_ir_database_path='',
124                noise_tracks_path='',
125                copy_with_identity=copy_with_identity)
126            factory.SetOutputDirectoryPrefix('datagen-')
127            generator = factory.GetInstance(
128                test_data_generation.IdentityTestDataGenerator)
129            # Check |copy_with_identity| is set correctly.
130            self.assertEqual(copy_with_identity, generator.copy_with_identity)
131
132            # Generate test data and extract the paths to the noise and the reference
133            # files.
134            generator.Generate(input_signal_filepath=input_signal_filepath,
135                               test_data_cache_path=self._test_data_cache_path,
136                               base_output_path=self._base_output_path)
137            noisy_signal_filepath, reference_signal_filepath = (
138                GetNoiseReferenceFilePaths(generator))
139
140            # Check that a copy is made if and only if |copy_with_identity| is True.
141            if copy_with_identity:
142                self.assertNotEqual(noisy_signal_filepath,
143                                    input_signal_filepath)
144                self.assertNotEqual(reference_signal_filepath,
145                                    input_signal_filepath)
146            else:
147                self.assertEqual(noisy_signal_filepath, input_signal_filepath)
148                self.assertEqual(reference_signal_filepath,
149                                 input_signal_filepath)
150
151    def _CheckGeneratedPairsListSizes(self, generator):
152        config_names = generator.config_names
153        number_of_pairs = len(config_names)
154        self.assertEqual(number_of_pairs,
155                         len(generator.noisy_signal_filepaths))
156        self.assertEqual(number_of_pairs, len(generator.apm_output_paths))
157        self.assertEqual(number_of_pairs,
158                         len(generator.reference_signal_filepaths))
159
160    def _CheckGeneratedPairsSignalDurations(self, generator, input_signal):
161        """Checks duration of the generated signals.
162
163    Checks that the noisy input and the reference tracks are audio files
164    with duration equal to or greater than that of the input signal.
165
166    Args:
167      generator: TestDataGenerator instance.
168      input_signal: AudioSegment instance.
169    """
170        input_signal_length = (
171            signal_processing.SignalProcessingUtils.CountSamples(input_signal))
172
173        # Iterate over the noisy signal - reference pairs.
174        for config_name in generator.config_names:
175            # Load the noisy input file.
176            noisy_signal_filepath = generator.noisy_signal_filepaths[
177                config_name]
178            noisy_signal = signal_processing.SignalProcessingUtils.LoadWav(
179                noisy_signal_filepath)
180
181            # Check noisy input signal length.
182            noisy_signal_length = (signal_processing.SignalProcessingUtils.
183                                   CountSamples(noisy_signal))
184            self.assertGreaterEqual(noisy_signal_length, input_signal_length)
185
186            # Load the reference file.
187            reference_signal_filepath = generator.reference_signal_filepaths[
188                config_name]
189            reference_signal = signal_processing.SignalProcessingUtils.LoadWav(
190                reference_signal_filepath)
191
192            # Check noisy input signal length.
193            reference_signal_length = (signal_processing.SignalProcessingUtils.
194                                       CountSamples(reference_signal))
195            self.assertGreaterEqual(reference_signal_length,
196                                    input_signal_length)
197
198    def _CheckGeneratedPairsOutputPaths(self, generator):
199        """Checks that the output path created by the generator exists.
200
201    Args:
202      generator: TestDataGenerator instance.
203    """
204        # Iterate over the noisy signal - reference pairs.
205        for config_name in generator.config_names:
206            output_path = generator.apm_output_paths[config_name]
207            self.assertTrue(os.path.exists(output_path))
208