1# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
2#
3# Use of this source code is governed by a BSD-style license
4# that can be found in the LICENSE file in the root of the source
5# tree. An additional intellectual property rights grant can be found
6# in the file PATENTS.  All contributing project authors may
7# be found in the AUTHORS file in the root of the source tree.
8"""Unit tests for the simulation module.
9"""
10
11import logging
12import os
13import shutil
14import tempfile
15import unittest
16
17import mock
18import pydub
19
20from . import audioproc_wrapper
21from . import eval_scores_factory
22from . import evaluation
23from . import external_vad
24from . import signal_processing
25from . import simulation
26from . import test_data_generation_factory
27
28
class TestApmModuleSimulator(unittest.TestCase):
    """Unit tests for the ApmModuleSimulator class.
  """

    def setUp(self):
        """Create temporary folders and a fake audio track.

        The fake track is produced by GenerateWhiteNoise() from a 1000 ms,
        48 kHz silent template and saved as 'fake.wav' in the output folder.
        """
        self._output_path = tempfile.mkdtemp()
        self._tmp_path = tempfile.mkdtemp()

        silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
        fake_signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
            silence)
        self._fake_audio_track_path = os.path.join(self._output_path,
                                                   'fake.wav')
        signal_processing.SignalProcessingUtils.SaveWav(
            self._fake_audio_track_path, fake_signal)

    def tearDown(self):
        """Recursively delete temporary folders."""
        shutil.rmtree(self._output_path)
        shutil.rmtree(self._tmp_path)

    def _CreateSimulator(self, ap_wrapper=None, evaluator=None, **kwargs):
        """Build an ApmModuleSimulator wired with the test factories.

        Args:
          ap_wrapper: AudioProcWrapper to inject; a new one using the default
            APM simulator binary path is created when None.
          evaluator: ApmModuleEvaluator to inject; a new one is created when
            None.
          **kwargs: extra keyword arguments forwarded verbatim to the
            ApmModuleSimulator constructor (e.g., external_vads). They are
            only passed when given, so the constructor's own defaults apply
            otherwise.

        Returns:
          The ApmModuleSimulator instance.
        """
        if ap_wrapper is None:
            ap_wrapper = audioproc_wrapper.AudioProcWrapper(
                audioproc_wrapper.AudioProcWrapper.
                DEFAULT_APM_SIMULATOR_BIN_PATH)
        if evaluator is None:
            evaluator = evaluation.ApmModuleEvaluator()
        return simulation.ApmModuleSimulator(
            test_data_generator_factory=(
                test_data_generation_factory.TestDataGeneratorFactory(
                    aechen_ir_database_path='',
                    noise_tracks_path='',
                    copy_with_identity=False)),
            evaluation_score_factory=(
                eval_scores_factory.EvaluationScoreWorkerFactory(
                    polqa_tool_bin_path=os.path.join(os.path.dirname(__file__),
                                                     'fake_polqa'),
                    echo_metric_tool_bin_path=None)),
            ap_wrapper=ap_wrapper,
            evaluator=evaluator,
            **kwargs)

    def testSimulation(self):
        """Check that Run() invokes APM and the evaluator for each combination."""
        # Instance dependencies to mock and inject.
        ap_wrapper = audioproc_wrapper.AudioProcWrapper(
            audioproc_wrapper.AudioProcWrapper.DEFAULT_APM_SIMULATOR_BIN_PATH)
        evaluator = evaluation.ApmModuleEvaluator()
        ap_wrapper.Run = mock.MagicMock(name='Run')
        evaluator.Run = mock.MagicMock(name='Run')

        # Instance simulator (test data and score factories are not mocked).
        simulator = self._CreateSimulator(
            ap_wrapper=ap_wrapper,
            evaluator=evaluator,
            external_vads={
                'fake':
                external_vad.ExternalVad(
                    os.path.join(os.path.dirname(__file__),
                                 'fake_external_vad.py'), 'fake')
            })

        # What to simulate.
        # NOTE(review): the config path is relative to the current working
        # directory; presumably the test is run from this module's folder.
        config_files = ['apm_configs/default.json']
        input_files = [self._fake_audio_track_path]
        test_data_generators = ['identity', 'white_noise']
        eval_scores = ['audio_level_mean', 'polqa']

        # Run all simulations.
        simulator.Run(config_filepaths=config_files,
                      capture_input_filepaths=input_files,
                      test_data_generator_names=test_data_generators,
                      eval_score_names=eval_scores,
                      output_dir=self._output_path)

        # Check.
        # TODO(alessiob): Once the TestDataGenerator classes can be configured by
        # the client code (e.g., number of SNR pairs for the white noise test data
        # generator), the exact number of calls to ap_wrapper.Run and evaluator.Run
        # is known; use that with assertEqual.
        min_number_of_simulations = (
            len(config_files) * len(input_files) * len(test_data_generators))
        self.assertGreaterEqual(ap_wrapper.Run.call_count,
                                min_number_of_simulations)
        self.assertGreaterEqual(evaluator.Run.call_count,
                                min_number_of_simulations)

    def testInputSignalCreation(self):
        """Check that missing pure-tone input files are created by Run()."""
        simulator = self._CreateSimulator()

        # Nonexistent input files to be silently created.
        input_files = [
            os.path.join(self._tmp_path, 'pure_tone-440_1000.wav'),
            os.path.join(self._tmp_path, 'pure_tone-1000_500.wav'),
        ]
        self.assertFalse(any(os.path.exists(path) for path in input_files))

        # The input files are created during the simulation.
        simulator.Run(config_filepaths=['apm_configs/default.json'],
                      capture_input_filepaths=input_files,
                      test_data_generator_names=['identity'],
                      eval_score_names=['audio_level_peak'],
                      output_dir=self._output_path)
        self.assertTrue(all(os.path.exists(path) for path in input_files))

    def testPureToneGenerationWithTotalHarmonicDistorsion(self):
        """Check that the THD score only works with the identity generator."""
        # Patch via a context manager so logging.warning is restored when the
        # test ends and does not stay mocked for other tests in the process.
        with mock.patch.object(logging, 'warning') as mock_warning:
            simulator = self._CreateSimulator()

            # What to simulate.
            config_files = ['apm_configs/default.json']
            input_files = [
                os.path.join(self._tmp_path, 'pure_tone-440_1000.wav')
            ]
            eval_scores = ['thd']

            # Should work.
            simulator.Run(config_filepaths=config_files,
                          capture_input_filepaths=input_files,
                          test_data_generator_names=['identity'],
                          eval_score_names=eval_scores,
                          output_dir=self._output_path)
            self.assertFalse(mock_warning.called)

            # Warning expected.
            simulator.Run(
                config_filepaths=config_files,
                capture_input_filepaths=input_files,
                test_data_generator_names=['white_noise'],  # Not allowed with THD.
                eval_score_names=eval_scores,
                output_dir=self._output_path)
            mock_warning.assert_called_with('the evaluation failed: %s', (
                'The THD score cannot be used with any test data generator other than '
                '"identity"'))

    # NOTE(review): disabled draft below; it references
    # self._test_data_cache_path and self._base_output_path, which this class
    # never defines. Either port it to this fixture or delete it.
    #   # Init.
    #   generator = test_data_generation.IdentityTestDataGenerator('tmp')
    #   input_signal_filepath = os.path.join(
    #       self._test_data_cache_path, 'pure_tone-440_1000.wav')

    #   # Check that the input signal is generated.
    #   self.assertFalse(os.path.exists(input_signal_filepath))
    #   generator.Generate(
    #       input_signal_filepath=input_signal_filepath,
    #       test_data_cache_path=self._test_data_cache_path,
    #       base_output_path=self._base_output_path)
    #   self.assertTrue(os.path.exists(input_signal_filepath))

    #   # Check input signal properties.
    #   input_signal = signal_processing.SignalProcessingUtils.LoadWav(
    #       input_signal_filepath)
    #   self.assertEqual(1000, len(input_signal))
204