1# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
2#
3# Use of this source code is governed by a BSD-style license
4# that can be found in the LICENSE file in the root of the source
5# tree. An additional intellectual property rights grant can be found
6# in the file PATENTS.  All contributing project authors may
7# be found in the AUTHORS file in the root of the source tree.
8
9"""Unit tests for the simulation module.
10"""
11
12import logging
13import os
14import shutil
15import sys
16import tempfile
17import unittest
18
# Root of the source checkout: four directory levels above this file. The
# bundled pymock package under it is appended to the import path so that the
# `import mock` statement below can be resolved.
SRC = os.path.abspath(
    os.path.join(os.path.dirname(__file__), *([os.pardir] * 4)))
sys.path.append(os.path.join(SRC, 'third_party', 'pymock'))
22
23import mock
24import pydub
25
26from . import audioproc_wrapper
27from . import eval_scores_factory
28from . import evaluation
29from . import external_vad
30from . import signal_processing
31from . import simulation
32from . import test_data_generation_factory
33
34
class TestApmModuleSimulator(unittest.TestCase):
  """Unit tests for the ApmModuleSimulator class.
  """

  def setUp(self):
    """Create temporary folders and fake audio track."""
    self._output_path = tempfile.mkdtemp()
    self._tmp_path = tempfile.mkdtemp()

    # Fake capture track: white noise generated from a 1000 ms silent segment
    # with a 48000 Hz frame rate, saved as a WAV file in the output folder.
    silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
    fake_signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
        silence)
    self._fake_audio_track_path = os.path.join(self._output_path, 'fake.wav')
    signal_processing.SignalProcessingUtils.SaveWav(
        self._fake_audio_track_path, fake_signal)

  def tearDown(self):
    """Recursively delete temporary folders."""
    shutil.rmtree(self._output_path)
    shutil.rmtree(self._tmp_path)

  @staticmethod
  def _CreateSimulator(ap_wrapper=None, evaluator=None, **kwargs):
    """Creates an ApmModuleSimulator wired with the dependencies used in tests.

    The test data generator factory gets empty AEC IR database and noise track
    paths and has copy_with_identity disabled; the evaluation score factory
    uses the 'fake_polqa' binary located next to this file and no echo metric
    tool.

    Args:
      ap_wrapper: AudioProcWrapper to inject; if None, a new one using the
        default APM simulator binary path is created.
      evaluator: ApmModuleEvaluator to inject; if None, a new one is created.
      **kwargs: Additional keyword arguments forwarded unchanged to the
        ApmModuleSimulator constructor (e.g., external_vads). Arguments that
        are not given are not passed, so the constructor defaults apply.

    Returns:
      A new simulation.ApmModuleSimulator instance.
    """
    if ap_wrapper is None:
      ap_wrapper = audioproc_wrapper.AudioProcWrapper(
          audioproc_wrapper.AudioProcWrapper.DEFAULT_APM_SIMULATOR_BIN_PATH)
    if evaluator is None:
      evaluator = evaluation.ApmModuleEvaluator()
    return simulation.ApmModuleSimulator(
        test_data_generator_factory=(
            test_data_generation_factory.TestDataGeneratorFactory(
                aechen_ir_database_path='',
                noise_tracks_path='',
                copy_with_identity=False)),
        evaluation_score_factory=(
            eval_scores_factory.EvaluationScoreWorkerFactory(
                polqa_tool_bin_path=os.path.join(
                    os.path.dirname(__file__), 'fake_polqa'),
                echo_metric_tool_bin_path=None)),
        ap_wrapper=ap_wrapper,
        evaluator=evaluator,
        **kwargs)

  def testSimulation(self):
    # Instance dependencies to mock and inject.
    ap_wrapper = audioproc_wrapper.AudioProcWrapper(
        audioproc_wrapper.AudioProcWrapper.DEFAULT_APM_SIMULATOR_BIN_PATH)
    evaluator = evaluation.ApmModuleEvaluator()
    ap_wrapper.Run = mock.MagicMock(name='Run')
    evaluator.Run = mock.MagicMock(name='Run')

    # Instance simulator (the remaining dependencies are not mocked).
    simulator = self._CreateSimulator(
        ap_wrapper=ap_wrapper,
        evaluator=evaluator,
        external_vads={'fake': external_vad.ExternalVad(os.path.join(
            os.path.dirname(__file__), 'fake_external_vad.py'), 'fake')})

    # What to simulate.
    config_files = ['apm_configs/default.json']
    input_files = [self._fake_audio_track_path]
    test_data_generators = ['identity', 'white_noise']
    eval_scores = ['audio_level_mean', 'polqa']

    # Run all simulations.
    simulator.Run(
        config_filepaths=config_files,
        capture_input_filepaths=input_files,
        test_data_generator_names=test_data_generators,
        eval_score_names=eval_scores,
        output_dir=self._output_path)

    # Check.
    # TODO(alessiob): Once the TestDataGenerator classes can be configured by
    # the client code (e.g., number of SNR pairs for the white noise test data
    # generator), the exact number of calls to ap_wrapper.Run and evaluator.Run
    # is known; use that with assertEqual.
    min_number_of_simulations = len(config_files) * len(input_files) * len(
        test_data_generators)
    self.assertGreaterEqual(len(ap_wrapper.Run.call_args_list),
                            min_number_of_simulations)
    self.assertGreaterEqual(len(evaluator.Run.call_args_list),
                            min_number_of_simulations)

  def testInputSignalCreation(self):
    # Instance simulator.
    simulator = self._CreateSimulator()

    # Nonexistent input files, which the simulator is expected to create.
    input_files = [
        os.path.join(self._tmp_path, 'pure_tone-440_1000.wav'),
        os.path.join(self._tmp_path, 'pure_tone-1000_500.wav'),
    ]
    self.assertFalse(any(os.path.exists(input_file)
                         for input_file in input_files))

    # The input files are created during the simulation.
    simulator.Run(
        config_filepaths=['apm_configs/default.json'],
        capture_input_filepaths=input_files,
        test_data_generator_names=['identity'],
        eval_score_names=['audio_level_peak'],
        output_dir=self._output_path)
    self.assertTrue(all(os.path.exists(input_file)
                        for input_file in input_files))

  def testPureToneGenerationWithTotalHarmonicDistorsion(self):
    # Patch logging.warning for the duration of this test only. Assigning a
    # mock directly to logging.warning would never be undone and would leak
    # into every test (and module) that runs later in the same process.
    warning_patcher = mock.patch.object(logging, 'warning')
    mock_warning = warning_patcher.start()
    self.addCleanup(warning_patcher.stop)

    # Instance simulator.
    simulator = self._CreateSimulator()

    # What to simulate.
    config_files = ['apm_configs/default.json']
    input_files = [os.path.join(self._tmp_path, 'pure_tone-440_1000.wav')]
    eval_scores = ['thd']

    # Should work.
    simulator.Run(
        config_filepaths=config_files,
        capture_input_filepaths=input_files,
        test_data_generator_names=['identity'],
        eval_score_names=eval_scores,
        output_dir=self._output_path)
    self.assertFalse(mock_warning.called)

    # Warning expected.
    simulator.Run(
        config_filepaths=config_files,
        capture_input_filepaths=input_files,
        test_data_generator_names=['white_noise'],  # Not allowed with THD.
        eval_score_names=eval_scores,
        output_dir=self._output_path)
    mock_warning.assert_called_with('the evaluation failed: %s', (
        'The THD score cannot be used with any test data generator other than '
        '"identity"'))
210