1# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
2#
3# Use of this source code is governed by a BSD-style license
4# that can be found in the LICENSE file in the root of the source
5# tree. An additional intellectual property rights grant can be found
6# in the file PATENTS.  All contributing project authors may
7# be found in the AUTHORS file in the root of the source tree.
8
9"""Evaluation score abstract class and implementations.
10"""
11
12from __future__ import division
13import logging
14import os
15import re
16import subprocess
17import sys
18
19try:
20  import numpy as np
21except ImportError:
22  logging.critical('Cannot import the third-party Python package numpy')
23  sys.exit(1)
24
25from . import data_access
26from . import exceptions
27from . import signal_processing
28
29
class EvaluationScore(object):
  """Base class for evaluation scores.

  Subclasses define NAME, register themselves with the RegisterClass decorator
  and implement _Run(), which is expected to compute the score, store it in
  self._score and write it via _SaveScore().
  """

  NAME = None
  # Maps NAME -> registered EvaluationScore subclass (filled by RegisterClass).
  REGISTERED_CLASSES = {}

  def __init__(self, score_filename_prefix):
    """Creates the score.

    Args:
      score_filename_prefix: prefix of the score file name; Run() writes to
        <output_path>/<score_filename_prefix><NAME>.txt.
    """
    self._score_filename_prefix = score_filename_prefix
    self._input_signal_metadata = None
    self._reference_signal = None
    self._reference_signal_filepath = None
    self._tested_signal = None
    self._tested_signal_filepath = None
    self._output_filepath = None
    self._score = None
    self._render_signal_filepath = None

  @classmethod
  def RegisterClass(cls, class_to_register):
    """Registers an EvaluationScore implementation.

    Decorator to automatically register the classes that extend EvaluationScore.
    Example usage:

    @EvaluationScore.RegisterClass
    class AudioLevelScore(EvaluationScore):
      pass
    """
    cls.REGISTERED_CLASSES[class_to_register.NAME] = class_to_register
    return class_to_register

  @property
  def output_filepath(self):
    """Path of the score file; set by Run()."""
    return self._output_filepath

  @property
  def score(self):
    """The score computed or loaded by Run(); None before Run()."""
    return self._score

  def SetInputSignalMetadata(self, metadata):
    """Sets input signal metadata.

    Args:
      metadata: dict instance.
    """
    self._input_signal_metadata = metadata

  def SetReferenceSignalFilepath(self, filepath):
    """Sets the path to the audio track used as reference signal.

    Args:
      filepath: path to the reference audio track.
    """
    self._reference_signal_filepath = filepath

  def SetTestedSignalFilepath(self, filepath):
    """Sets the path to the audio track used as test signal.

    Args:
      filepath: path to the test audio track.
    """
    self._tested_signal_filepath = filepath

  def SetRenderSignalFilepath(self, filepath):
    """Sets the path to the audio track used as render signal.

    Args:
      filepath: path to the render audio track.
    """
    self._render_signal_filepath = filepath

  def Run(self, output_path):
    """Extracts the score for the set test data pair.

    If the score file already exists in |output_path|, the score is loaded
    from it; otherwise it is computed by _Run(). In both cases the |score|
    property holds the result afterwards.

    Args:
      output_path: path to the directory where the output is written.
    """
    self._output_filepath = os.path.join(
        output_path, self._score_filename_prefix + self.NAME + '.txt')
    try:
      # If the score has already been computed, load it and keep the value so
      # that the |score| property is populated also for cached scores.
      self._score = self._LoadScore()
      logging.debug('score found and loaded')
    except IOError:
      # Compute the score.
      logging.debug('score not found, compute')
      self._Run(output_path)

  def _Run(self, output_path):
    """Computes and saves the score; must be implemented by subclasses."""
    # Abstract method.
    raise NotImplementedError()

  def _LoadReferenceSignal(self):
    assert self._reference_signal_filepath is not None
    self._reference_signal = signal_processing.SignalProcessingUtils.LoadWav(
        self._reference_signal_filepath)

  def _LoadTestedSignal(self):
    assert self._tested_signal_filepath is not None
    self._tested_signal = signal_processing.SignalProcessingUtils.LoadWav(
        self._tested_signal_filepath)

  def _LoadScore(self):
    """Loads and returns the score stored in the output file.

    Run() treats an IOError raised from here as "score not computed yet".
    """
    return data_access.ScoreFile.Load(self._output_filepath)

  def _SaveScore(self):
    """Writes self._score to the output file."""
    return data_access.ScoreFile.Save(self._output_filepath, self._score)
137
138
@EvaluationScore.RegisterClass
class AudioLevelPeakScore(EvaluationScore):
  """Peak audio level score.

  The score is the level of the tested signal minus the level of the reference
  signal, both read from the `dBFS` attribute of the loaded signals.

  Unit: dB
  Ideal: 0 dB
  Worst case: +/-inf dB
  """

  NAME = 'audio_level_peak'

  def __init__(self, score_filename_prefix):
    super(AudioLevelPeakScore, self).__init__(score_filename_prefix)

  def _Run(self, output_path):
    # Both signals must be loaded before their levels can be compared.
    self._LoadReferenceSignal()
    self._LoadTestedSignal()

    # NOTE(review): if LoadWav() returns a pydub AudioSegment, `dBFS` is the
    # RMS-based level and `max_dBFS` is the peak level; confirm which one this
    # "peak" score is meant to use.
    tested_level = self._tested_signal.dBFS
    reference_level = self._reference_signal.dBFS
    self._score = tested_level - reference_level
    self._SaveScore()
161
162
@EvaluationScore.RegisterClass
class MeanAudioLevelScore(EvaluationScore):
  """Mean audio level score.

  Defined as the mean, over consecutive one-second windows, of the difference
  between the audio level of the tested and the reference signals. Only the
  time span covered by both signals is analyzed and a trailing partial window
  is ignored; if that span is shorter than one second, it is analyzed as a
  single window.

  Unit: dB
  Ideal: 0 dB
  Worst case: +/-inf dB
  """

  NAME = 'audio_level_mean'

  # Analysis window length in milliseconds.
  _WINDOW_MS = 1000

  def __init__(self, score_filename_prefix):
    EvaluationScore.__init__(self, score_filename_prefix)

  def _Run(self, output_path):
    """Computes and saves the mean per-window level difference.

    Raises:
      exceptions.EvaluationScoreException: if the signals do not overlap.
    """
    self._LoadReferenceSignal()
    self._LoadTestedSignal()

    # Signal lengths and slice indices are assumed to be in milliseconds (as
    # for pydub AudioSegment objects) - TODO confirm against LoadWav().
    duration_ms = min(len(self._tested_signal), len(self._reference_signal))
    if duration_ms == 0:
      raise exceptions.EvaluationScoreException(
          'The mean audio level score requires non-empty signals')

    num_windows = duration_ms // self._WINDOW_MS
    if num_windows == 0:
      # Shorter than one window: compare the whole overlapping span at once
      # instead of dividing by zero windows.
      windows = [(0, duration_ms)]
    else:
      # Window k spans [k * W, (k + 1) * W) milliseconds.
      windows = [(index * self._WINDOW_MS, (index + 1) * self._WINDOW_MS)
                 for index in range(num_windows)]

    dbfs_diffs = [
        self._tested_signal[t0:t1].dBFS - self._reference_signal[t0:t1].dBFS
        for t0, t1 in windows]
    self._score = sum(dbfs_diffs) / float(len(dbfs_diffs))
    self._SaveScore()
193
194
@EvaluationScore.RegisterClass
class EchoMetric(EvaluationScore):
  """Echo score.

  Proportion of detected echo, as computed by an external echo detector tool
  that receives the tested signal and the render signal.

  Unit: ratio
  Ideal: 0
  Worst case: 1
  """

  NAME = 'echo_metric'

  def __init__(self, score_filename_prefix, echo_detector_bin_filepath):
    """Creates the echo score.

    Args:
      score_filename_prefix: prefix of the file the score is written to.
      echo_detector_bin_filepath: path to the echo detector tool binary.

    Raises:
      exceptions.FileNotFoundError: if the echo detector binary is missing.
    """
    EvaluationScore.__init__(self, score_filename_prefix)

    # Echo detector binary file path.
    self._echo_detector_bin_filepath = echo_detector_bin_filepath
    if not os.path.exists(self._echo_detector_bin_filepath):
      logging.error('cannot find EchoMetric tool binary file')
      raise exceptions.FileNotFoundError()

    # Directory containing the binary; the tool is run from there.
    self._echo_detector_bin_path, _ = os.path.split(
        self._echo_detector_bin_filepath)

  def _Run(self, output_path):
    """Runs the echo detector, then stores and saves the parsed score.

    Args:
      output_path: directory where the tool output file is written.

    Raises:
      exceptions.EvaluationScoreException: if the render signal is not set or
        does not exist.
    """
    echo_detector_out_filepath = os.path.join(output_path, 'echo_detector.out')
    # Remove stale output so that a failed run cannot reuse an old result.
    if os.path.exists(echo_detector_out_filepath):
      os.unlink(echo_detector_out_filepath)

    logging.debug("Render signal filepath: %s", self._render_signal_filepath)
    if (self._render_signal_filepath is None or
        not os.path.exists(self._render_signal_filepath)):
      # Fail early with a clear error instead of invoking the tool with an
      # invalid render argument (a None path would also break ' '.join below).
      logging.error("Render input required for evaluating the echo metric.")
      raise exceptions.EvaluationScoreException(
          'Render input required for evaluating the echo metric.')

    args = [
        self._echo_detector_bin_filepath,
        '--output_file', echo_detector_out_filepath,
        '--',
        '-i', self._tested_signal_filepath,
        '-ri', self._render_signal_filepath
    ]
    logging.debug(' '.join(args))
    subprocess.call(args, cwd=self._echo_detector_bin_path)

    # Parse Echo detector tool output and extract the score.
    self._score = self._ParseOutputFile(echo_detector_out_filepath)
    self._SaveScore()

  @classmethod
  def _ParseOutputFile(cls, echo_metric_file_path):
    """Parses the echo detector tool output.

    Args:
      echo_metric_file_path: path to the echo detector output file, whose
        whole content is parsed as a single number.

    Returns:
      The score as a float (proportion of detected echo, expected in [0, 1]).
    """
    with open(echo_metric_file_path) as f:
      return float(f.read())
256
@EvaluationScore.RegisterClass
class PolqaScore(EvaluationScore):
  """POLQA score.

  See http://www.polqa.info/.

  Unit: MOS
  Ideal: 4.5
  Worst case: 1.0
  """

  NAME = 'polqa'

  def __init__(self, score_filename_prefix, polqa_bin_filepath):
    """Creates the POLQA score.

    Args:
      score_filename_prefix: prefix of the file the score is written to.
      polqa_bin_filepath: path to the POLQA tool binary.

    Raises:
      exceptions.FileNotFoundError: if the POLQA binary is missing.
    """
    EvaluationScore.__init__(self, score_filename_prefix)

    # POLQA binary file path.
    self._polqa_bin_filepath = polqa_bin_filepath
    if not os.path.exists(self._polqa_bin_filepath):
      logging.error('cannot find POLQA tool binary file')
      raise exceptions.FileNotFoundError()

    # Path to the POLQA directory with binary and license files.
    self._polqa_tool_path, _ = os.path.split(self._polqa_bin_filepath)

  def _Run(self, output_path):
    """Runs the POLQA tool, then stores and saves the 'PolqaScore' field.

    Args:
      output_path: directory where the tool output file is written.
    """
    polqa_out_filepath = os.path.join(output_path, 'polqa.out')
    # Remove stale output so that a failed run cannot reuse an old result.
    if os.path.exists(polqa_out_filepath):
      os.unlink(polqa_out_filepath)

    args = [
        self._polqa_bin_filepath, '-t', '-q', '-Overwrite',
        '-Ref', self._reference_signal_filepath,
        '-Test', self._tested_signal_filepath,
        '-LC', 'NB',
        '-Out', polqa_out_filepath,
    ]
    logging.debug(' '.join(args))
    subprocess.call(args, cwd=self._polqa_tool_path)

    # Parse POLQA tool output and extract the score.
    polqa_output = self._ParseOutputFile(polqa_out_filepath)
    self._score = float(polqa_output['PolqaScore'])

    self._SaveScore()

  @classmethod
  def _ParseOutputFile(cls, polqa_out_filepath):
    """Parses the POLQA tool output formatted as a table ('-t' option).

    Empty lines and lines starting with '*' (comments) are skipped. The
    remaining lines must be exactly a header row and a values row, whose
    fields are separated by one or more tabs.

    Args:
      polqa_out_filepath: path to the POLQA tool output file.

    Returns:
      A dict mapping each header field name to the corresponding value string.

    Raises:
      exceptions.EvaluationScoreException: if the output cannot be parsed.
    """
    data = []
    with open(polqa_out_filepath) as f:
      for line in f:
        line = line.strip()
        if not line or line.startswith('*'):
          # Ignore empty lines and comments.
          continue
        # Read fields.
        data.append(re.split(r'\t+', line))

    # Two rows expected (header and values). Explicit checks instead of
    # asserts so that validation is not stripped when running with -O.
    if len(data) != 2:
      raise exceptions.EvaluationScoreException('Cannot parse POLQA output')
    header, values = data
    if len(header) != len(values):
      raise exceptions.EvaluationScoreException(
          'Cannot parse POLQA output: header and values differ in length')

    # Field names (header) as keys, corresponding field values as values.
    return dict(zip(header, values))
332
333
@EvaluationScore.RegisterClass
class TotalHarmonicDistorsionScore(EvaluationScore):
  """Total harmonic distortion plus noise score.

  Total harmonic distortion plus noise (THD+N): RMS of the tested signal after
  removing the fundamental, divided by the RMS of the fundamental.
  See "https://en.wikipedia.org/wiki/Total_harmonic_distortion#THD.2BN".

  Requires a mono pure tone as input signal, generated with the "identity"
  test data generator and its "default" configuration.

  Unit: -.
  Ideal: 0.
  Worst case: +inf
  """

  NAME = 'thd'

  def __init__(self, score_filename_prefix):
    EvaluationScore.__init__(self, score_filename_prefix)
    self._input_frequency = None

  def _Run(self, output_path):
    """Computes and saves the THD+N of the tested signal.

    Raises:
      exceptions.EvaluationScoreException: if the input signal metadata is
        missing/invalid, the tested signal is not mono, or the tone frequency
        cannot be analyzed at the tested signal sample rate.
    """
    self._CheckInputSignal()

    self._LoadTestedSignal()
    if self._tested_signal.channels != 1:
      raise exceptions.EvaluationScoreException(
          'unsupported number of channels')
    samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData(
        self._tested_signal)

    self._score = self._ComputeThdPlusNoise(
        samples, self._tested_signal.frame_rate, float(self._input_frequency))
    self._SaveScore()

  @staticmethod
  def _ComputeThdPlusNoise(samples, sample_rate, f0_freq):
    """Computes THD+N of a (distorted) pure tone.

    Args:
      samples: sequence of mono samples (converted to a float64 array).
      sample_rate: sample rate in Hz.
      f0_freq: frequency of the fundamental in Hz.

    Returns:
      THD+N as a non-negative ratio.

    Raises:
      exceptions.EvaluationScoreException: if there are no samples or f0_freq
        is not in (0, sample_rate / 2).
    """
    samples = np.asarray(samples, dtype=np.float64)
    num_samples = len(samples)
    max_freq = sample_rate / 2.0
    if num_samples == 0:
      raise exceptions.EvaluationScoreException(
          'The THD score requires a non-empty tested signal')
    # A non-positive or above-Nyquist fundamental cannot be analyzed.
    if not 0.0 < f0_freq < max_freq:
      raise exceptions.EvaluationScoreException(
          'The THD score requires a tone frequency in (0, %g) Hz, got %g Hz' % (
              max_freq, f0_freq))

    scaling = 2.0 / num_samples
    # Exact sample instants: sample k is taken at k / sample_rate seconds.
    t = np.arange(num_samples) / float(sample_rate)

    # In-phase and quadrature components of the fundamental, so that it can be
    # removed whatever its phase (e.g., after a processing delay).
    omega_t = 2.0 * np.pi * f0_freq * t
    sin_f0 = np.sin(omega_t)
    cos_f0 = np.cos(omega_t)
    x_1 = np.sum(samples * sin_f0) * scaling
    y_1 = np.sum(samples * cos_f0) * scaling
    fundamental_amplitude = np.sqrt(x_1**2 + y_1**2)

    # Residual = harmonics + noise.
    output_without_fundamental = samples - (x_1 * sin_f0 + y_1 * cos_f0)

    # RMS(residual) / RMS(fundamental)
    #   = sqrt(sum(r^2) / N) / (A / sqrt(2)) = sqrt(sum(r^2) * 2 / N) / A.
    distortion_and_noise = np.sqrt(
        np.sum(output_without_fundamental**2) * scaling)

    # NOTE: individual harmonic magnitudes (needed for THD without noise) are
    # not computed.
    # TODO(alessiob): Check the range of |thd_plus_noise| and update the class
    # docstring above if accordingly.
    thd_plus_noise = distortion_and_noise / fundamental_amplitude
    return thd_plus_noise

  def _CheckInputSignal(self):
    """Validates the input signal metadata and reads the tone frequency.

    Raises:
      exceptions.EvaluationScoreException: if the metadata is missing, has
        missing keys, does not describe a pure tone, or refers to a test data
        generator other than "identity" with "default" config.
    """
    # Check input signal and get properties.
    try:
      if self._input_signal_metadata['signal'] != 'pure_tone':
        raise exceptions.EvaluationScoreException(
            'The THD score requires a pure tone as input signal')
      self._input_frequency = self._input_signal_metadata['frequency']
      if self._input_signal_metadata['test_data_gen_name'] != 'identity' or (
          self._input_signal_metadata['test_data_gen_config'] != 'default'):
        raise exceptions.EvaluationScoreException(
            'The THD score cannot be used with any test data generator other '
            'than "identity"')
    except TypeError:
      # Raised when no metadata has been set (None is not subscriptable).
      raise exceptions.EvaluationScoreException(
          'The THD score requires an input signal with associated metadata')
    except KeyError:
      raise exceptions.EvaluationScoreException(
          'Invalid input signal metadata to compute the THD score')
412