# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS.  All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.

"""Unit tests for the input mixer module.
"""

import logging
import os
import shutil
import sys
import tempfile
import unittest

# The mock package lives in the source tree, four directories above this file.
SRC = os.path.abspath(os.path.join(
    os.path.dirname(__file__), os.pardir, os.pardir, os.pardir, os.pardir))
sys.path.append(os.path.join(SRC, 'third_party', 'pymock'))

import mock

from . import exceptions
from . import input_mixer
from . import signal_processing


class TestApmInputMixer(unittest.TestCase):
  """Unit tests for the ApmInputMixer class.

  Each test runs against a fresh temporary directory that holds one WAV file
  per entry of _FILENAMES; the directory is removed again in tearDown().
  """

  # Audio track file names created in setUp().
  _FILENAMES = ['capture', 'echo_1', 'echo_2', 'shorter', 'longer']

  # Target peak power level (dBFS) of each audio track file created in setUp().
  # These values are hand-crafted in order to make saturation happen when
  # capture and echo_2 are mixed and the contrary for capture and echo_1.
  # None means that the power is not changed.
  _MAX_PEAK_POWER_LEVELS = [-10.0, -5.0, 0.0, None, None]

  # Audio track file durations in milliseconds.
  _DURATIONS = [1000, 1000, 1000, 800, 1200]

  _SAMPLE_RATE = 48000

  def setUp(self):
    """Creates the temporary directory and the audio track files.

    Fills self._audio_tracks with one entry per file name, each holding the
    path of the written WAV file ('filepath') and its length in samples
    ('num_samples').
    """
    self._tmp_path = tempfile.mkdtemp()

    # Create audio track files.
    self._audio_tracks = {}
    for filename, peak_power, duration in zip(
        self._FILENAMES, self._MAX_PEAK_POWER_LEVELS, self._DURATIONS):
      audio_track_filepath = os.path.join(
          self._tmp_path, '{}.wav'.format(filename))

      # Create a pure tone with the target peak power level.
      template = signal_processing.SignalProcessingUtils.GenerateSilence(
          duration=duration, sample_rate=self._SAMPLE_RATE)
      signal = signal_processing.SignalProcessingUtils.GeneratePureTone(
          template)
      if peak_power is not None:
        # Remove the current peak level and add the target one.
        signal = signal.apply_gain(-signal.max_dBFS + peak_power)

      signal_processing.SignalProcessingUtils.SaveWav(
          audio_track_filepath, signal)
      self._audio_tracks[filename] = {
          'filepath': audio_track_filepath,
          'num_samples': signal_processing.SignalProcessingUtils.CountSamples(
              signal),
      }

  def tearDown(self):
    """Recursively deletes temporary folders."""
    shutil.rmtree(self._tmp_path)

  def _Mix(self, echo_name):
    """Mixes the capture track with the given echo track.

    Args:
      echo_name: key in self._audio_tracks of the track to use as echo.

    Returns:
      Whatever input_mixer.ApmInputMixer.Mix() returns (the tests below use it
      as the path of the mix file).
    """
    return input_mixer.ApmInputMixer.Mix(
        self._tmp_path,
        self._audio_tracks['capture']['filepath'],
        self._audio_tracks[echo_name]['filepath'])

  def _AssertMixHasCaptureLength(self, mix_filepath):
    """Asserts that the mix file exists and is as long as the capture track."""
    self.assertTrue(os.path.exists(mix_filepath))

    mix = signal_processing.SignalProcessingUtils.LoadWav(mix_filepath)
    self.assertEqual(self._audio_tracks['capture']['num_samples'],
                     signal_processing.SignalProcessingUtils.CountSamples(mix))

  def testCheckMixSameDuration(self):
    """Checks the duration when mixing capture and echo with same duration."""
    self._AssertMixHasCaptureLength(self._Mix('echo_1'))

  def testRejectShorterEcho(self):
    """Rejects echo signals that are shorter than the capture signal."""
    with self.assertRaises(exceptions.InputMixerException):
      self._Mix('shorter')

  def testCheckMixDurationWithLongerEcho(self):
    """Checks the duration when mixing an echo longer than the capture."""
    self._AssertMixHasCaptureLength(self._Mix('longer'))

  def testCheckOutputFileNamesConflict(self):
    """Checks that different echo files lead to different output file names."""
    mix1_filepath = self._Mix('echo_1')
    self.assertTrue(os.path.exists(mix1_filepath))

    mix2_filepath = self._Mix('echo_2')
    self.assertTrue(os.path.exists(mix2_filepath))

    self.assertNotEqual(mix1_filepath, mix2_filepath)

  def testHardClippingLogExpected(self):
    """Checks that hard clipping warning is raised when occurring."""
    # Patch only for the duration of the call so that the real logging.warning
    # is restored afterwards and other tests are not affected.
    with mock.patch.object(logging, 'warning') as warning_mock:
      self._Mix('echo_2')
    warning_mock.assert_called_once_with(
        input_mixer.ApmInputMixer.HardClippingLogMessage())

  def testHardClippingLogNotExpected(self):
    """Checks that hard clipping warning is not raised when not occurring."""
    # Patch only for the duration of the call so that the real logging.warning
    # is restored afterwards and other tests are not affected.
    with mock.patch.object(logging, 'warning') as warning_mock:
      self._Mix('echo_1')
    self.assertNotIn(
        mock.call(input_mixer.ApmInputMixer.HardClippingLogMessage()),
        warning_mock.call_args_list)