1# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
2#
3# Use of this source code is governed by a BSD-style license
4# that can be found in the LICENSE file in the root of the source
5# tree. An additional intellectual property rights grant can be found
6# in the file PATENTS.  All contributing project authors may
7# be found in the AUTHORS file in the root of the source tree.
8
9"""Evaluation score abstract class and implementations.
10"""
11
12from __future__ import division
13import logging
14import os
15import re
16import subprocess
17import sys
18
19try:
20  import numpy as np
21except ImportError:
22  logging.critical('Cannot import the third-party Python package numpy')
23  sys.exit(1)
24
25from . import data_access
26from . import exceptions
27from . import signal_processing
28
29
class EvaluationScore(object):
  """Abstract base class for evaluation scores.

  An evaluation score compares a tested audio track against a reference audio
  track and stores the result in a text score file. Implementations set NAME,
  register themselves with the RegisterClass decorator and override _Run().
  """

  NAME = None
  REGISTERED_CLASSES = {}

  def __init__(self, score_filename_prefix):
    # Prefix prepended to NAME to build the score file name in Run().
    self._score_filename_prefix = score_filename_prefix
    self._input_signal_metadata = None
    self._reference_signal = None
    self._reference_signal_filepath = None
    self._tested_signal = None
    self._tested_signal_filepath = None
    self._output_filepath = None
    self._score = None

  @classmethod
  def RegisterClass(cls, class_to_register):
    """Registers an EvaluationScore implementation.

    Decorator to automatically register the classes that extend EvaluationScore.
    Example usage:

    @EvaluationScore.RegisterClass
    class AudioLevelScore(EvaluationScore):
      pass
    """
    cls.REGISTERED_CLASSES[class_to_register.NAME] = class_to_register
    return class_to_register

  @property
  def output_filepath(self):
    """Path of the score file; None until Run() has been called."""
    return self._output_filepath

  @property
  def score(self):
    """Score loaded or computed by Run(); None until Run() has been called."""
    return self._score

  def SetInputSignalMetadata(self, metadata):
    """Sets input signal metadata.

    Args:
      metadata: dict instance.
    """
    self._input_signal_metadata = metadata

  def SetReferenceSignalFilepath(self, filepath):
    """Sets the path to the audio track used as reference signal.

    Args:
      filepath: path to the reference audio track.
    """
    self._reference_signal_filepath = filepath

  def SetTestedSignalFilepath(self, filepath):
    """Sets the path to the audio track used as test signal.

    Args:
      filepath: path to the test audio track.
    """
    self._tested_signal_filepath = filepath

  def Run(self, output_path):
    """Extracts the score for the set test data pair.

    If the score file already exists, the score is loaded from it. Otherwise
    it is computed by _Run(). In both cases the score is then available
    through the |score| property.

    Args:
      output_path: path to the directory where the output is written.
    """
    self._output_filepath = os.path.join(
        output_path, self._score_filename_prefix + self.NAME + '.txt')
    try:
      # If the score has already been computed, load it. The loaded value must
      # be kept, otherwise |score| would stay None for cached results.
      self._score = self._LoadScore()
      logging.debug('score found and loaded')
    except IOError:
      # Compute the score.
      logging.debug('score not found, compute')
      self._Run(output_path)

  def _Run(self, output_path):
    """Computes the score, sets self._score and saves it.

    Abstract method; implementations must override it.

    Args:
      output_path: path to the directory where the output is written.
    """
    raise NotImplementedError()

  def _LoadReferenceSignal(self):
    assert self._reference_signal_filepath is not None
    self._reference_signal = signal_processing.SignalProcessingUtils.LoadWav(
        self._reference_signal_filepath)

  def _LoadTestedSignal(self):
    assert self._tested_signal_filepath is not None
    self._tested_signal = signal_processing.SignalProcessingUtils.LoadWav(
        self._tested_signal_filepath)

  def _LoadScore(self):
    # Presumably returns the value written by ScoreFile.Save() — confirm in
    # data_access. Raises IOError when the score file is missing (Run relies
    # on this to decide whether to compute the score).
    return data_access.ScoreFile.Load(self._output_filepath)

  def _SaveScore(self):
    return data_access.ScoreFile.Save(self._output_filepath, self._score)
128
129
@EvaluationScore.RegisterClass
class AudioLevelPeakScore(EvaluationScore):
  """Peak audio level score.

  The score is the level of the tested signal minus the level of the reference
  signal, where each level is read from the `dBFS` attribute of the loaded
  signal.

  NOTE(review): if the loaded signals are pydub AudioSegment instances, `dBFS`
  is an RMS-based level rather than a peak level (pydub exposes the peak as
  `max_dBFS`) — confirm the intended definition.

  Unit: dB
  Ideal: 0 dB
  Worst case: +/-inf dB
  """

  NAME = 'audio_level_peak'

  def __init__(self, score_filename_prefix):
    super(AudioLevelPeakScore, self).__init__(score_filename_prefix)

  def _Run(self, output_path):
    # Both tracks must be loaded before their levels can be compared.
    self._LoadReferenceSignal()
    self._LoadTestedSignal()
    level_delta = self._tested_signal.dBFS - self._reference_signal.dBFS
    self._score = level_delta
    self._SaveScore()
152
153
@EvaluationScore.RegisterClass
class MeanAudioLevelScore(EvaluationScore):
  """Mean audio level score.

  Defined as the difference between the mean audio level of the tested and
  the reference signals. Both signals are split into consecutive 1 second
  windows over their common length. The score is the average, across windows,
  of the per-window level difference. Trailing audio shorter than one window
  is ignored.

  Unit: dB
  Ideal: 0 dB
  Worst case: +/-inf dB
  """

  NAME = 'audio_level_mean'

  # Analysis window length. Signal lengths and slice indices are treated as
  # milliseconds (1000 per second).
  _WINDOW_MS = 1000

  def __init__(self, score_filename_prefix):
    EvaluationScore.__init__(self, score_filename_prefix)

  def _Run(self, output_path):
    """Computes and saves the average per-window level difference.

    Args:
      output_path: path to the directory where the output is written.

    Raises:
      EvaluationScoreException: if the common part of the two signals is
        shorter than one window.
    """
    self._LoadReferenceSignal()
    self._LoadTestedSignal()

    common_length_ms = min(
        len(self._tested_signal), len(self._reference_signal))
    num_windows = common_length_ms // self._WINDOW_MS
    if num_windows == 0:
      # Without this check the average below would divide by zero.
      raise exceptions.EvaluationScoreException(
          'The mean audio level score requires signals at least {} ms '
          'long'.format(self._WINDOW_MS))

    dbfs_diffs_sum = 0.0
    for window_index in range(num_windows):
      # Window k covers [k * window, (k + 1) * window) milliseconds; the
      # previous code used the number of seconds as the window size instead.
      t0 = window_index * self._WINDOW_MS
      t1 = t0 + self._WINDOW_MS
      dbfs_diffs_sum += (
          self._tested_signal[t0:t1].dBFS - self._reference_signal[t0:t1].dBFS)
    self._score = dbfs_diffs_sum / float(num_windows)
    self._SaveScore()
184
185
@EvaluationScore.RegisterClass
class PolqaScore(EvaluationScore):
  """POLQA score.

  Computed by running an external POLQA tool binary on the reference and
  tested tracks, then reading the 'PolqaScore' field of its tabular output.

  See http://www.polqa.info/.

  Unit: MOS
  Ideal: 4.5
  Worst case: 1.0
  """

  NAME = 'polqa'

  def __init__(self, score_filename_prefix, polqa_bin_filepath):
    """Creates a POLQA score extractor.

    Args:
      score_filename_prefix: prefix of the score file name.
      polqa_bin_filepath: path to the POLQA tool binary.

    Raises:
      exceptions.FileNotFoundError: if |polqa_bin_filepath| does not exist.
    """
    EvaluationScore.__init__(self, score_filename_prefix)

    # POLQA binary file path.
    self._polqa_bin_filepath = polqa_bin_filepath
    if not os.path.exists(self._polqa_bin_filepath):
      logging.error('cannot find POLQA tool binary file')
      raise exceptions.FileNotFoundError()

    # Path to the POLQA directory with binary and license files; the tool is
    # run with this directory as its working directory.
    self._polqa_tool_path, _ = os.path.split(self._polqa_bin_filepath)

  def _Run(self, output_path):
    # Same preconditions as the base-class signal loaders.
    assert self._reference_signal_filepath is not None
    assert self._tested_signal_filepath is not None

    polqa_out_filepath = os.path.join(output_path, 'polqa.out')
    # Remove any previous output so that only this run's output is parsed.
    if os.path.exists(polqa_out_filepath):
      os.unlink(polqa_out_filepath)

    args = [
        self._polqa_bin_filepath, '-t', '-q', '-Overwrite',
        '-Ref', self._reference_signal_filepath,
        '-Test', self._tested_signal_filepath,
        '-LC', 'NB',
        '-Out', polqa_out_filepath,
    ]
    logging.debug(' '.join(args))
    return_code = subprocess.call(args, cwd=self._polqa_tool_path)
    if return_code != 0:
      # Not treated as fatal on its own: the output file is still parsed and
      # any missing or malformed output is reported there.
      logging.warning('POLQA tool exited with code %d', return_code)

    # Parse POLQA tool output and extract the score.
    polqa_output = self._ParseOutputFile(polqa_out_filepath)
    self._score = float(polqa_output['PolqaScore'])

    self._SaveScore()

  @classmethod
  def _ParseOutputFile(cls, polqa_out_filepath):
    """Parses the POLQA tool output formatted as a table ('-t' option).

    Blank lines and lines starting with '*' are skipped. The remaining lines
    must be exactly one tab-separated header row and one tab-separated value
    row.

    Args:
      polqa_out_filepath: path to the POLQA tool output file.

    Returns:
      A dict mapping each header field name to the corresponding value (both
      strings).

    Raises:
      IOError: if the output file cannot be opened.
      exceptions.EvaluationScoreException: if the output does not have the
        expected layout.
    """
    data = []
    with open(polqa_out_filepath) as f:
      for line in f:
        line = line.strip()
        if not line or line.startswith('*'):
          # Ignore blank lines and comments.
          continue
        # Read fields.
        data.append(re.split(r'\t+', line))

    # Explicit checks rather than asserts, which are stripped under python -O.
    if len(data) != 2:
      raise exceptions.EvaluationScoreException(
          'Cannot parse POLQA output: expected 2 rows, found {}'.format(
              len(data)))
    header, values = data
    if len(header) != len(values):
      raise exceptions.EvaluationScoreException(
          'Cannot parse POLQA output: {} header fields but {} values'.format(
              len(header), len(values)))

    # Field names (header) as keys, corresponding field values as values.
    return dict(zip(header, values))
261
262
@EvaluationScore.RegisterClass
class TotalHarmonicDistorsionScore(EvaluationScore):
  """Total harmonic distortion plus noise score.

  Requires a mono pure-tone input signal (with 'frequency' metadata) that has
  been processed by the "identity" test data generator.
  See "https://en.wikipedia.org/wiki/Total_harmonic_distortion#THD.2BN".

  Unit: -.
  Ideal: 0.
  Worst case: +inf
  """

  NAME = 'thd'

  def __init__(self, score_filename_prefix):
    EvaluationScore.__init__(self, score_filename_prefix)
    # Fundamental frequency of the pure tone, read from the input metadata.
    self._input_frequency = None

  def _Run(self, output_path):
    """Computes and saves the THD+N score of the tested signal.

    Args:
      output_path: path to the directory where the output is written.

    Raises:
      EvaluationScoreException: on unsupported input metadata, a non-mono or
        empty tested signal, or an input frequency outside (0, Nyquist).
    """
    # TODO(aleloi): Integrate changes made locally.
    self._CheckInputSignal()

    self._LoadTestedSignal()
    if self._tested_signal.channels != 1:
      raise exceptions.EvaluationScoreException(
          'unsupported number of channels')
    samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData(
        self._tested_signal)

    # Init.
    num_samples = len(samples)
    if num_samples == 0:
      # |scaling| below would divide by zero.
      raise exceptions.EvaluationScoreException(
          'The THD score requires a non-empty tested signal')
    duration = len(self._tested_signal) / 1000.0  # Length is in ms.
    scaling = 2.0 / num_samples
    max_freq = self._tested_signal.frame_rate / 2
    f0_freq = float(self._input_frequency)
    # The harmonic loop below only terminates for a positive fundamental, and
    # the fundamental itself must lie below Nyquist for b_terms[0] to exist.
    if not 0.0 < f0_freq < max_freq:
      raise exceptions.EvaluationScoreException(
          'The THD score requires an input frequency in (0, {}) Hz'.format(
              max_freq))
    t = np.linspace(0, duration, num_samples)

    # Analyze harmonics: amplitude of every multiple of f0 below Nyquist.
    b_terms = []
    n = 1
    while f0_freq * n < max_freq:
      phase = 2.0 * np.pi * n * f0_freq * t
      x_n = np.sum(samples * np.sin(phase)) * scaling
      y_n = np.sum(samples * np.cos(phase)) * scaling
      b_terms.append(np.sqrt(x_n**2 + y_n**2))
      n += 1

    output_without_fundamental = samples - b_terms[0] * np.sin(
        2.0 * np.pi * f0_freq * t)
    distortion_and_noise = np.sqrt(np.sum(
        output_without_fundamental**2) * np.pi * scaling)

    # TODO(alessiob): Fix or remove if not needed.
    # thd = np.sqrt(np.sum(b_terms[1:]**2)) / b_terms[0]

    # TODO(alessiob): Check the range of |thd_plus_noise| and update the class
    # docstring above if needed.
    thd_plus_noise = distortion_and_noise / b_terms[0]

    self._score = thd_plus_noise
    self._SaveScore()

  def _CheckInputSignal(self):
    """Validates the input signal metadata and reads the tone frequency.

    Raises:
      EvaluationScoreException: if the metadata is missing (None), lacks a
        required key, does not describe a pure tone, or the test data
        generator is not "identity" with the "default" config.
    """
    try:
      if self._input_signal_metadata['signal'] != 'pure_tone':
        raise exceptions.EvaluationScoreException(
            'The THD score requires a pure tone as input signal')
      self._input_frequency = self._input_signal_metadata['frequency']
      if self._input_signal_metadata['test_data_gen_name'] != 'identity' or (
          self._input_signal_metadata['test_data_gen_config'] != 'default'):
        raise exceptions.EvaluationScoreException(
            'The THD score cannot be used with any test data generator other '
            'than "identity"')
    except TypeError:
      # Indexing None (no metadata set) raises TypeError.
      raise exceptions.EvaluationScoreException(
          'The THD score requires an input signal with associated metadata')
    except KeyError:
      raise exceptions.EvaluationScoreException(
          'Invalid input signal metadata to compute the THD score')
342