1# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
2#
3# Use of this source code is governed by a BSD-style license
4# that can be found in the LICENSE file in the root of the source
5# tree. An additional intellectual property rights grant can be found
6# in the file PATENTS.  All contributing project authors may
7# be found in the AUTHORS file in the root of the source tree.
8
9"""Unit tests for the input mixer module.
10"""
11
12import logging
13import os
14import shutil
15import sys
16import tempfile
17import unittest
18
# Directory four levels above the one containing this file (presumably the
# root of the source tree). Its bundled third_party/pymock folder is appended
# to sys.path so that the `import mock` statement that follows can resolve.
SRC = os.path.abspath(
    os.path.join(os.path.dirname(__file__), *([os.pardir] * 4)))
sys.path.append(os.path.join(SRC, 'third_party', 'pymock'))
22
23import mock
24
25from . import exceptions
26from . import input_mixer
27from . import signal_processing
28
29
class TestApmInputMixer(unittest.TestCase):
  """Unit tests for the ApmInputMixer class.

  Each test mixes the 'capture' track created in setUp() with one of the other
  tracks and checks the produced mix file, the raised exception or the logged
  hard clipping warning.
  """

  # Audio track file names created in setUp().
  _FILENAMES = ['capture', 'echo_1', 'echo_2', 'shorter', 'longer']

  # Target peak power level (dBFS) of each audio track file created in setUp().
  # These values are hand-crafted so that saturation happens when capture and
  # echo_2 are mixed, and does not happen when capture and echo_1 are mixed.
  # None means that the power is not changed.
  _MAX_PEAK_POWER_LEVELS = [-10.0, -5.0, 0.0, None, None]

  # Audio track file durations in milliseconds.
  _DURATIONS = [1000, 1000, 1000, 800, 1200]

  # Sample rate (Hz) used to generate the audio tracks.
  _SAMPLE_RATE = 48000

  def setUp(self):
    """Creates temporary data.

    Writes one WAV file per entry of _FILENAMES into a new temporary folder and
    stores, for each of them, its file path and number of samples in
    self._audio_tracks.
    """
    self._tmp_path = tempfile.mkdtemp()
    # Register the removal right away: cleanup functions also run when setUp()
    # fails part-way, whereas tearDown() is not called in that case.
    self.addCleanup(shutil.rmtree, self._tmp_path)

    # Create audio track files.
    self._audio_tracks = {}
    for filename, peak_power, duration in zip(
        self._FILENAMES, self._MAX_PEAK_POWER_LEVELS, self._DURATIONS):
      audio_track_filepath = os.path.join(
          self._tmp_path, '{}.wav'.format(filename))

      # Generate a pure tone using a silent signal of the requested duration as
      # template and, if a target is given, shift its gain so that its peak
      # level (max_dBFS) becomes the target peak power level.
      template = signal_processing.SignalProcessingUtils.GenerateSilence(
          duration=duration, sample_rate=self._SAMPLE_RATE)
      signal = signal_processing.SignalProcessingUtils.GeneratePureTone(
          template)
      if peak_power is not None:
        signal = signal.apply_gain(-signal.max_dBFS + peak_power)

      signal_processing.SignalProcessingUtils.SaveWav(
          audio_track_filepath, signal)
      self._audio_tracks[filename] = {
          'filepath': audio_track_filepath,
          'num_samples': signal_processing.SignalProcessingUtils.CountSamples(
              signal)
      }

  def _MixCaptureWith(self, echo_name):
    """Mixes the capture track with another track created in setUp().

    Args:
      echo_name: key in self._audio_tracks of the track to use as echo.

    Returns:
      The mix file path returned by ApmInputMixer.Mix().
    """
    return input_mixer.ApmInputMixer.Mix(
        self._tmp_path,
        self._audio_tracks['capture']['filepath'],
        self._audio_tracks[echo_name]['filepath'])

  def _CheckMixHasCaptureDuration(self, echo_name):
    """Mixes capture with echo_name and checks the resulting mix file.

    The mix file must exist and contain as many samples as the capture track.
    """
    mix_filepath = self._MixCaptureWith(echo_name)
    self.assertTrue(os.path.exists(mix_filepath))

    mix = signal_processing.SignalProcessingUtils.LoadWav(mix_filepath)
    self.assertEqual(self._audio_tracks['capture']['num_samples'],
                     signal_processing.SignalProcessingUtils.CountSamples(mix))

  def testCheckMixSameDuration(self):
    """Checks the duration when mixing capture and echo with same duration."""
    self._CheckMixHasCaptureDuration('echo_1')

  def testRejectShorterEcho(self):
    """Rejects echo signals that are shorter than the capture signal."""
    with self.assertRaises(exceptions.InputMixerException):
      self._MixCaptureWith('shorter')

  def testCheckMixDurationWithLongerEcho(self):
    """Checks the duration when mixing an echo longer than the capture."""
    self._CheckMixHasCaptureDuration('longer')

  def testCheckOutputFileNamesConflict(self):
    """Checks that different echo files lead to different output file names."""
    mix1_filepath = self._MixCaptureWith('echo_1')
    self.assertTrue(os.path.exists(mix1_filepath))

    mix2_filepath = self._MixCaptureWith('echo_2')
    self.assertTrue(os.path.exists(mix2_filepath))

    self.assertNotEqual(mix1_filepath, mix2_filepath)

  def testHardClippingLogExpected(self):
    """Checks that hard clipping warning is raised when occurring."""
    # mock.patch.object restores logging.warning on exit; assigning a mock to
    # it directly would leave it replaced for every test that runs afterwards.
    with mock.patch.object(logging, 'warning') as warning_mock:
      self._MixCaptureWith('echo_2')
    warning_mock.assert_called_once_with(
        input_mixer.ApmInputMixer.HardClippingLogMessage())

  def testHardClippingLogNotExpected(self):
    """Checks that hard clipping warning is not raised when not occurring."""
    # Patched only for the duration of the mix so the global is restored.
    with mock.patch.object(logging, 'warning') as warning_mock:
      self._MixCaptureWith('echo_1')
    self.assertNotIn(
        mock.call(input_mixer.ApmInputMixer.HardClippingLogMessage()),
        warning_mock.call_args_list)
150