1# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
2#
3# Use of this source code is governed by a BSD-style license
4# that can be found in the LICENSE file in the root of the source
5# tree. An additional intellectual property rights grant can be found
6# in the file PATENTS.  All contributing project authors may
7# be found in the AUTHORS file in the root of the source tree.
8
9"""Unit tests for the test_data_generation module.
10"""
11
12import os
13import shutil
14import tempfile
15import unittest
16
17import numpy as np
18import scipy.io
19
20from . import test_data_generation
21from . import test_data_generation_factory
22from . import signal_processing
23
24
class TestTestDataGenerators(unittest.TestCase):
  """Unit tests for the test_data_generation module.

  Each test runs with three temporary folders created in setUp() and
  recursively deleted in tearDown(): a base output folder, a test data cache
  folder and a fake AIR database holding impulse response .mat files.
  """

  def setUp(self):
    """Creates temporary folders and populates the fake AIR database."""
    self._base_output_path = tempfile.mkdtemp()
    self._test_data_cache_path = tempfile.mkdtemp()
    self._fake_air_db_path = tempfile.mkdtemp()

    # Fake AIR DB impulse responses.
    # TODO(alessiob): ReverberationTestDataGenerator will change to allow custom
    # impulse responses. When changed, the coupling below between
    # impulse_response_mat_file_names and
    # ReverberationTestDataGenerator._IMPULSE_RESPONSES can be removed.
    impulse_response_mat_file_names = [
        'air_binaural_lecture_0_0_1.mat',
        'air_binaural_booth_0_0_1.mat',
    ]
    for impulse_response_mat_file_name in impulse_response_mat_file_names:
      # A (1, 1000) array of random little-endian float64 values stored under
      # the 'h_air' key.
      data = {'h_air': np.random.rand(1, 1000).astype('<f8')}
      scipy.io.savemat(os.path.join(
          self._fake_air_db_path, impulse_response_mat_file_name), data)

  def tearDown(self):
    """Recursively deletes the temporary folders created in setUp()."""
    shutil.rmtree(self._base_output_path)
    shutil.rmtree(self._test_data_cache_path)
    shutil.rmtree(self._fake_air_db_path)

  def testTestDataGenerators(self):
    """Runs every registered generator and checks the data it produces."""
    # Preliminary check.
    self.assertTrue(os.path.exists(self._base_output_path))
    self.assertTrue(os.path.exists(self._test_data_cache_path))

    # Check that there is at least one registered test data generator.
    registered_classes = (
        test_data_generation.TestDataGenerator.REGISTERED_CLASSES)
    self.assertIsInstance(registered_classes, dict)
    self.assertGreater(len(registered_classes), 0)

    # Instance generators factory.
    noise_tracks_path = (
        test_data_generation.AdditiveNoiseTestDataGenerator
        .DEFAULT_NOISE_TRACKS_PATH)
    generators_factory = test_data_generation_factory.TestDataGeneratorFactory(
        aechen_ir_database_path=self._fake_air_db_path,
        noise_tracks_path=noise_tracks_path,
        copy_with_identity=False)
    generators_factory.SetOutputDirectoryPrefix('datagen-')

    # Use a simple input file as clean input signal.
    input_signal_filepath = os.path.join(
        os.getcwd(), 'probing_signals', 'tone-880.wav')
    self.assertTrue(os.path.exists(input_signal_filepath))

    # Load input signal.
    input_signal = signal_processing.SignalProcessingUtils.LoadWav(
        input_signal_filepath)

    # Try each registered test data generator.
    for generator_class in registered_classes.values():
      # Instance test data generator.
      generator = generators_factory.GetInstance(generator_class)

      # Generate the noisy input - reference pairs.
      generator.Generate(
          input_signal_filepath=input_signal_filepath,
          test_data_cache_path=self._test_data_cache_path,
          base_output_path=self._base_output_path)

      # Perform checks.
      self._CheckGeneratedPairsListSizes(generator)
      self._CheckGeneratedPairsSignalDurations(generator, input_signal)
      self._CheckGeneratedPairsOutputPaths(generator)

  def testTestidentityDataGenerator(self):
    """Checks that the identity generator honors |copy_with_identity|."""
    # Preliminary check.
    self.assertTrue(os.path.exists(self._base_output_path))
    self.assertTrue(os.path.exists(self._test_data_cache_path))

    # Use a simple input file as clean input signal.
    input_signal_filepath = os.path.join(
        os.getcwd(), 'probing_signals', 'tone-880.wav')
    self.assertTrue(os.path.exists(input_signal_filepath))

    # Test the |copy_with_identity| flag.
    for copy_with_identity in [False, True]:
      # Instance the generator through the factory.
      factory = test_data_generation_factory.TestDataGeneratorFactory(
          aechen_ir_database_path='', noise_tracks_path='',
          copy_with_identity=copy_with_identity)
      factory.SetOutputDirectoryPrefix('datagen-')
      generator = factory.GetInstance(
          test_data_generation.IdentityTestDataGenerator)
      # Check |copy_with_identity| is set correctly.
      self.assertEqual(copy_with_identity, generator.copy_with_identity)

      # Generate test data and extract the paths to the noise and the reference
      # files.
      generator.Generate(
          input_signal_filepath=input_signal_filepath,
          test_data_cache_path=self._test_data_cache_path,
          base_output_path=self._base_output_path)
      noisy_signal_filepath, reference_signal_filepath = (
          self._GetNoiseReferenceFilePaths(generator))

      # Check that a copy is made if and only if |copy_with_identity| is True.
      if copy_with_identity:
        self.assertNotEqual(noisy_signal_filepath, input_signal_filepath)
        self.assertNotEqual(reference_signal_filepath, input_signal_filepath)
      else:
        self.assertEqual(noisy_signal_filepath, input_signal_filepath)
        self.assertEqual(reference_signal_filepath, input_signal_filepath)

  def _GetNoiseReferenceFilePaths(self, identity_generator):
    """Returns the only noisy/reference file path pair of a generator.

    Args:
      identity_generator: IdentityTestDataGenerator instance on which
        Generate() has already been called.

    Returns:
      A (noisy_signal_filepath, reference_signal_filepath) tuple.
    """
    noisy_signal_filepaths = identity_generator.noisy_signal_filepaths
    reference_signal_filepaths = identity_generator.reference_signal_filepaths
    # Both dicts are keyed by config name and must share the same keys.
    self.assertEqual(set(noisy_signal_filepaths),
                     set(reference_signal_filepaths))
    self.assertEqual(1, len(noisy_signal_filepaths))
    # dict.keys() is not subscriptable in Python 3; fetch the only key through
    # an iterator instead (works on Python 2 as well).
    key = next(iter(noisy_signal_filepaths))
    return noisy_signal_filepaths[key], reference_signal_filepaths[key]

  def _CheckGeneratedPairsListSizes(self, generator):
    """Checks that every per-config dict has one entry per config name.

    Args:
      generator: TestDataGenerator instance.
    """
    number_of_pairs = len(generator.config_names)
    self.assertEqual(number_of_pairs,
                     len(generator.noisy_signal_filepaths))
    self.assertEqual(number_of_pairs,
                     len(generator.apm_output_paths))
    self.assertEqual(number_of_pairs,
                     len(generator.reference_signal_filepaths))

  def _CheckGeneratedPairsSignalDurations(
      self, generator, input_signal):
    """Checks duration of the generated signals.

    Checks that the noisy input and the reference tracks are audio files
    whose length, as measured by SignalProcessingUtils.CountSamples, is equal
    to or greater than that of the input signal.

    Args:
      generator: TestDataGenerator instance.
      input_signal: AudioSegment instance.
    """
    input_signal_length = (
        signal_processing.SignalProcessingUtils.CountSamples(input_signal))

    # Iterate over the noisy signal - reference pairs.
    for config_name in generator.config_names:
      # Load the noisy input file.
      noisy_signal_filepath = generator.noisy_signal_filepaths[
          config_name]
      noisy_signal = signal_processing.SignalProcessingUtils.LoadWav(
          noisy_signal_filepath)

      # Check noisy input signal length.
      noisy_signal_length = (
          signal_processing.SignalProcessingUtils.CountSamples(noisy_signal))
      self.assertGreaterEqual(noisy_signal_length, input_signal_length)

      # Load the reference file.
      reference_signal_filepath = generator.reference_signal_filepaths[
          config_name]
      reference_signal = signal_processing.SignalProcessingUtils.LoadWav(
          reference_signal_filepath)

      # Check reference signal length.
      reference_signal_length = (
          signal_processing.SignalProcessingUtils.CountSamples(
              reference_signal))
      self.assertGreaterEqual(reference_signal_length, input_signal_length)

  def _CheckGeneratedPairsOutputPaths(self, generator):
    """Checks that the output path created by the generator exists.

    Args:
      generator: TestDataGenerator instance.
    """
    # Iterate over the noisy signal - reference pairs.
    for config_name in generator.config_names:
      output_path = generator.apm_output_paths[config_name]
      self.assertTrue(os.path.exists(output_path))
207