1# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
2#
3# Use of this source code is governed by a BSD-style license
4# that can be found in the LICENSE file in the root of the source
5# tree. An additional intellectual property rights grant can be found
6# in the file PATENTS.  All contributing project authors may
7# be found in the AUTHORS file in the root of the source tree.
8
9"""Unit tests for the simulation module.
10"""
11
12import logging
13import os
14import shutil
15import sys
16import tempfile
17import unittest
18
# Root of the source tree: four directory levels above this file. It is used
# to locate the bundled pymock package before `mock` is imported below.
_PARENT_HOPS = [os.pardir] * 4
SRC = os.path.abspath(
    os.path.join(os.path.dirname(__file__), *_PARENT_HOPS))
sys.path.append(os.path.join(SRC, 'third_party', 'pymock'))
22
23import mock
24import pydub
25
26from . import audioproc_wrapper
27from . import eval_scores_factory
28from . import evaluation
29from . import external_vad
30from . import signal_processing
31from . import simulation
32from . import test_data_generation_factory
33
34
class TestApmModuleSimulator(unittest.TestCase):
  """Unit tests for the ApmModuleSimulator class.
  """

  def setUp(self):
    """Create temporary folders and fake audio track.

    Each folder is registered for removal with addCleanup() as soon as it is
    created. It is therefore deleted even if a later step of setUp() raises,
    in which case tearDown() would not be called.
    """
    self._output_path = tempfile.mkdtemp()
    self.addCleanup(shutil.rmtree, self._output_path)
    self._tmp_path = tempfile.mkdtemp()
    self.addCleanup(shutil.rmtree, self._tmp_path)

    # Fake capture track: GenerateWhiteNoise() applied to a 1000 ms, 48 kHz
    # silent segment, saved as a WAV file inside the output folder.
    silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
    fake_signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
        silence)
    self._fake_audio_track_path = os.path.join(self._output_path, 'fake.wav')
    signal_processing.SignalProcessingUtils.SaveWav(
        self._fake_audio_track_path, fake_signal)

  @staticmethod
  def _CreateSimulator(ap_wrapper=None, evaluator=None, **kwargs):
    """Creates an ApmModuleSimulator wired with the (non-mocked) test factories.

    Args:
      ap_wrapper: AudioProcWrapper to inject. When None, a new one pointing at
        the default APM simulator binary is created.
      evaluator: ApmModuleEvaluator to inject. When None, a new one is created.
      **kwargs: Extra keyword arguments forwarded as-is to ApmModuleSimulator
        (e.g., external_vads). They are only passed when given.

    Returns:
      A simulation.ApmModuleSimulator instance.
    """
    if ap_wrapper is None:
      ap_wrapper = audioproc_wrapper.AudioProcWrapper(
          audioproc_wrapper.AudioProcWrapper.DEFAULT_APM_SIMULATOR_BIN_PATH)
    if evaluator is None:
      evaluator = evaluation.ApmModuleEvaluator()
    return simulation.ApmModuleSimulator(
        test_data_generator_factory=(
            test_data_generation_factory.TestDataGeneratorFactory(
                aechen_ir_database_path='',
                noise_tracks_path='',
                copy_with_identity=False)),
        evaluation_score_factory=(
            eval_scores_factory.EvaluationScoreWorkerFactory(
                polqa_tool_bin_path=os.path.join(
                    os.path.dirname(__file__), 'fake_polqa'))),
        ap_wrapper=ap_wrapper,
        evaluator=evaluator,
        **kwargs)

  def testSimulation(self):
    """Checks that APM and the evaluator run at least once per combination."""
    # Instance dependencies to mock and inject.
    ap_wrapper = audioproc_wrapper.AudioProcWrapper(
        audioproc_wrapper.AudioProcWrapper.DEFAULT_APM_SIMULATOR_BIN_PATH)
    evaluator = evaluation.ApmModuleEvaluator()
    ap_wrapper.Run = mock.MagicMock(name='Run')
    evaluator.Run = mock.MagicMock(name='Run')

    # Instance simulator. The test data generator and evaluation score
    # factories created by _CreateSimulator() are not mocked.
    simulator = self._CreateSimulator(
        ap_wrapper=ap_wrapper,
        evaluator=evaluator,
        external_vads={'fake': external_vad.ExternalVad(os.path.join(
            os.path.dirname(__file__), 'fake_external_vad.py'), 'fake')})

    # What to simulate.
    config_files = ['apm_configs/default.json']
    input_files = [self._fake_audio_track_path]
    test_data_generators = ['identity', 'white_noise']
    eval_scores = ['audio_level_mean', 'polqa']

    # Run all simulations.
    simulator.Run(
        config_filepaths=config_files,
        capture_input_filepaths=input_files,
        test_data_generator_names=test_data_generators,
        eval_score_names=eval_scores,
        output_dir=self._output_path)

    # Check.
    # TODO(alessiob): Once the TestDataGenerator classes can be configured by
    # the client code (e.g., number of SNR pairs for the white noise test data
    # generator), the exact number of calls to ap_wrapper.Run and evaluator.Run
    # is known; use that with assertEqual.
    min_number_of_simulations = len(config_files) * len(input_files) * len(
        test_data_generators)
    self.assertGreaterEqual(len(ap_wrapper.Run.call_args_list),
                            min_number_of_simulations)
    self.assertGreaterEqual(len(evaluator.Run.call_args_list),
                            min_number_of_simulations)

  def testInputSignalCreation(self):
    """Checks that missing pure tone input files are created by Run()."""
    # Instance simulator.
    simulator = self._CreateSimulator()

    # Inexistent input files to be silently created.
    input_files = [
        os.path.join(self._tmp_path, 'pure_tone-440_1000.wav'),
        os.path.join(self._tmp_path, 'pure_tone-1000_500.wav'),
    ]
    self.assertFalse(any(os.path.exists(input_file)
                         for input_file in input_files))

    # The input files are created during the simulation.
    simulator.Run(
        config_filepaths=['apm_configs/default.json'],
        capture_input_filepaths=input_files,
        test_data_generator_names=['identity'],
        eval_score_names=['audio_level_peak'],
        output_dir=self._output_path)
    self.assertTrue(all(os.path.exists(input_file)
                        for input_file in input_files))

  def testPureToneGenerationWithTotalHarmonicDistorsion(self):
    """Checks that THD only works with the 'identity' test data generator."""
    # Patch logging.warning for this test only. The patch is undone by
    # addCleanup(). The original code assigned a MagicMock to
    # logging.warning directly, which leaked the mock into every test that
    # ran afterwards in the same process.
    warning_patcher = mock.patch.object(logging, 'warning')
    mock_warning = warning_patcher.start()
    self.addCleanup(warning_patcher.stop)

    # Instance simulator.
    simulator = self._CreateSimulator()

    # What to simulate.
    config_files = ['apm_configs/default.json']
    input_files = [os.path.join(self._tmp_path, 'pure_tone-440_1000.wav')]
    eval_scores = ['thd']

    # Should work.
    simulator.Run(
        config_filepaths=config_files,
        capture_input_filepaths=input_files,
        test_data_generator_names=['identity'],
        eval_score_names=eval_scores,
        output_dir=self._output_path)
    # Uses `.called` rather than assert_not_called(), which may be missing in
    # older bundled mock versions.
    self.assertFalse(mock_warning.called)

    # Warning expected.
    simulator.Run(
        config_filepaths=config_files,
        capture_input_filepaths=input_files,
        test_data_generator_names=['white_noise'],  # Not allowed with THD.
        eval_score_names=eval_scores,
        output_dir=self._output_path)
    mock_warning.assert_called_with('the evaluation failed: %s', (
        'The THD score cannot be used with any test data generator other than '
        '"identity"'))
204