1# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
2#
3# Use of this source code is governed by a BSD-style license
4# that can be found in the LICENSE file in the root of the source
5# tree. An additional intellectual property rights grant can be found
6# in the file PATENTS.  All contributing project authors may
7# be found in the AUTHORS file in the root of the source tree.
8"""Evaluation score abstract class and implementations.
9"""
10
11from __future__ import division
12import logging
13import os
14import re
15import subprocess
16import sys
17
18try:
19    import numpy as np
20except ImportError:
21    logging.critical('Cannot import the third-party Python package numpy')
22    sys.exit(1)
23
24from . import data_access
25from . import exceptions
26from . import signal_processing
27
28
class EvaluationScore(object):
    """Abstract base class of the evaluation scores.

  Implementations set NAME and override _Run(), which computes the score,
  stores it in self._score and calls self._SaveScore(). Decorate an
  implementation with RegisterClass to add it to REGISTERED_CLASSES.
  """

    NAME = None
    # Maps NAME -> class for every class decorated with RegisterClass.
    REGISTERED_CLASSES = {}

    def __init__(self, score_filename_prefix):
        self._score_filename_prefix = score_filename_prefix
        self._input_signal_metadata = None
        self._reference_signal = None
        self._reference_signal_filepath = None
        self._tested_signal = None
        self._tested_signal_filepath = None
        self._output_filepath = None
        self._score = None
        self._render_signal_filepath = None

    @classmethod
    def RegisterClass(cls, class_to_register):
        """Registers an EvaluationScore implementation.

    Decorator to automatically register the classes that extend EvaluationScore.
    Example usage:

    @EvaluationScore.RegisterClass
    class AudioLevelScore(EvaluationScore):
      pass
    """
        cls.REGISTERED_CLASSES[class_to_register.NAME] = class_to_register
        return class_to_register

    @property
    def output_filepath(self):
        return self._output_filepath

    @property
    def score(self):
        return self._score

    def SetInputSignalMetadata(self, metadata):
        """Sets input signal metadata.

    Args:
      metadata: dict instance.
    """
        self._input_signal_metadata = metadata

    def SetReferenceSignalFilepath(self, filepath):
        """Sets the path to the audio track used as reference signal.

    Args:
      filepath: path to the reference audio track.
    """
        self._reference_signal_filepath = filepath

    def SetTestedSignalFilepath(self, filepath):
        """Sets the path to the audio track used as test signal.

    Args:
      filepath: path to the test audio track.
    """
        self._tested_signal_filepath = filepath

    def SetRenderSignalFilepath(self, filepath):
        """Sets the path to the audio track used as render signal.

    Args:
      filepath: path to the render audio track.
    """
        self._render_signal_filepath = filepath

    def Run(self, output_path):
        """Extracts the score for the set test data pair.

    The score file path is <output_path>/<prefix><NAME>.txt. If that file can
    be loaded, the score is read from it; otherwise (IOError) the score is
    computed by _Run(). In both cases the result is exposed via `score`.

    Args:
      output_path: path to the directory where the output is written.
    """
        self._output_filepath = os.path.join(
            output_path, self._score_filename_prefix + self.NAME + '.txt')
        try:
            # If the score has already been computed, load it. The loaded
            # value must be stored, otherwise `score` would stay None.
            self._score = self._LoadScore()
            logging.debug('score found and loaded')
        except IOError:
            # Compute the score.
            logging.debug('score not found, compute')
            self._Run(output_path)

    def _Run(self, output_path):
        # Abstract method.
        raise NotImplementedError()

    def _LoadReferenceSignal(self):
        assert self._reference_signal_filepath is not None
        self._reference_signal = signal_processing.SignalProcessingUtils.LoadWav(
            self._reference_signal_filepath)

    def _LoadTestedSignal(self):
        assert self._tested_signal_filepath is not None
        self._tested_signal = signal_processing.SignalProcessingUtils.LoadWav(
            self._tested_signal_filepath)

    def _LoadScore(self):
        return data_access.ScoreFile.Load(self._output_filepath)

    def _SaveScore(self):
        return data_access.ScoreFile.Save(self._output_filepath, self._score)
135
136
@EvaluationScore.RegisterClass
class AudioLevelPeakScore(EvaluationScore):
    """Peak audio level score.

  The level of the tested signal minus the level of the reference signal,
  both read from the `dBFS` attribute of the loaded signals.

  NOTE(review): if the loaded signals are pydub AudioSegments, `dBFS` is an
  RMS-based level while `max_dBFS` is the peak level; confirm which one this
  score is meant to use.

  Unit: dB
  Ideal: 0 dB
  Worst case: +/-inf dB
  """

    NAME = 'audio_level_peak'

    def __init__(self, score_filename_prefix):
        super(AudioLevelPeakScore, self).__init__(score_filename_prefix)

    def _Run(self, output_path):
        # Both signals have to be loaded before their levels can be compared.
        self._LoadReferenceSignal()
        self._LoadTestedSignal()
        tested_level = self._tested_signal.dBFS
        reference_level = self._reference_signal.dBFS
        self._score = tested_level - reference_level
        self._SaveScore()
159
160
@EvaluationScore.RegisterClass
class MeanAudioLevelScore(EvaluationScore):
    """Mean audio level score.

  Defined as the difference between the mean audio level of the tested and
  the reference signals. The level difference is computed on consecutive,
  non-overlapping 1-second windows of the overlapping part of the two signals
  and then averaged.

  Raises EvaluationScoreException if the signals overlap for less than one
  second (no complete window to analyze).

  Unit: dB
  Ideal: 0 dB
  Worst case: +/-inf dB
  """

    NAME = 'audio_level_mean'

    # Analysis window length. The loaded signals are measured and sliced in
    # milliseconds (hence the conversion to seconds via // 1000).
    _WINDOW_MS = 1000

    def __init__(self, score_filename_prefix):
        EvaluationScore.__init__(self, score_filename_prefix)

    def _Run(self, output_path):
        self._LoadReferenceSignal()
        self._LoadTestedSignal()

        # Number of complete 1-second windows shared by both signals.
        num_windows = min(len(self._tested_signal),
                          len(self._reference_signal)) // self._WINDOW_MS
        if num_windows == 0:
            # Avoid a division by zero below.
            raise exceptions.EvaluationScoreException(
                'The mean audio level score requires signals that are at '
                'least 1 second long')

        dbfs_diffs_sum = 0.0
        for window_index in range(num_windows):
            # Window boundaries in milliseconds: [k * 1000, (k + 1) * 1000).
            t0 = window_index * self._WINDOW_MS
            t1 = t0 + self._WINDOW_MS
            dbfs_diffs_sum += (self._tested_signal[t0:t1].dBFS -
                               self._reference_signal[t0:t1].dBFS)
        self._score = dbfs_diffs_sum / num_windows
        self._SaveScore()
192
193
@EvaluationScore.RegisterClass
class EchoMetric(EvaluationScore):
    """Echo score.

  Proportion of detected echo, as computed by an external echo detector tool
  that is run on the tested signal and on the render signal.

  Unit: ratio
  Ideal: 0
  Worst case: 1
  """

    NAME = 'echo_metric'

    def __init__(self, score_filename_prefix, echo_detector_bin_filepath):
        EvaluationScore.__init__(self, score_filename_prefix)

        # Echo detector binary file path.
        self._echo_detector_bin_filepath = echo_detector_bin_filepath
        if not os.path.exists(self._echo_detector_bin_filepath):
            logging.error('cannot find EchoMetric tool binary file')
            raise exceptions.FileNotFoundError()

        # Directory of the echo detector binary; the tool is run from there.
        self._echo_detector_bin_path, _ = os.path.split(
            self._echo_detector_bin_filepath)

    def _Run(self, output_path):
        echo_detector_out_filepath = os.path.join(output_path,
                                                  'echo_detector.out')
        # Remove any previous output so that only this run's result is parsed.
        if os.path.exists(echo_detector_out_filepath):
            os.unlink(echo_detector_out_filepath)

        logging.debug("Render signal filepath: %s",
                      self._render_signal_filepath)
        # The tool needs the render signal: fail early with a score exception
        # rather than running it on a missing (or unset) file.
        if (self._render_signal_filepath is None
                or not os.path.exists(self._render_signal_filepath)):
            logging.error(
                "Render input required for evaluating the echo metric.")
            raise exceptions.EvaluationScoreException(
                'Render input required for evaluating the echo metric')

        args = [
            self._echo_detector_bin_filepath, '--output_file',
            echo_detector_out_filepath, '--', '-i',
            self._tested_signal_filepath, '-ri', self._render_signal_filepath
        ]
        logging.debug(' '.join(args))
        subprocess.call(args, cwd=self._echo_detector_bin_path)

        # Parse the echo detector tool output and extract the score.
        self._score = self._ParseOutputFile(echo_detector_out_filepath)
        self._SaveScore()

    @classmethod
    def _ParseOutputFile(cls, echo_metric_file_path):
        """Parses the echo detector tool output.

    Args:
      echo_metric_file_path: path to the echo detector tool output file; its
        whole content is parsed as a single number.

    Returns:
      The score as a float (expected to be in [0, 1]).
    """
        with open(echo_metric_file_path) as f:
            return float(f.read())
256
257
@EvaluationScore.RegisterClass
class PolqaScore(EvaluationScore):
    """POLQA score.

  See http://www.polqa.info/.

  Unit: MOS
  Ideal: 4.5
  Worst case: 1.0
  """

    NAME = 'polqa'

    def __init__(self, score_filename_prefix, polqa_bin_filepath):
        EvaluationScore.__init__(self, score_filename_prefix)

        # POLQA binary file path.
        self._polqa_bin_filepath = polqa_bin_filepath
        if not os.path.exists(self._polqa_bin_filepath):
            logging.error('cannot find POLQA tool binary file')
            raise exceptions.FileNotFoundError()

        # Path to the POLQA directory with binary and license files.
        self._polqa_tool_path, _ = os.path.split(self._polqa_bin_filepath)

    def _Run(self, output_path):
        polqa_out_filepath = os.path.join(output_path, 'polqa.out')
        # Remove any previous output so that only this run's result is parsed.
        if os.path.exists(polqa_out_filepath):
            os.unlink(polqa_out_filepath)

        args = [
            self._polqa_bin_filepath,
            '-t',
            '-q',
            '-Overwrite',
            '-Ref',
            self._reference_signal_filepath,
            '-Test',
            self._tested_signal_filepath,
            '-LC',
            'NB',
            '-Out',
            polqa_out_filepath,
        ]
        logging.debug(' '.join(args))
        subprocess.call(args, cwd=self._polqa_tool_path)

        # Parse POLQA tool output and extract the score.
        polqa_output = self._ParseOutputFile(polqa_out_filepath)
        try:
            self._score = float(polqa_output['PolqaScore'])
        except KeyError:
            raise exceptions.EvaluationScoreException(
                'Cannot find the PolqaScore field in the POLQA output')

        self._SaveScore()

    @classmethod
    def _ParseOutputFile(cls, polqa_out_filepath):
        """Parses the POLQA tool output formatted as a table ('-t' option).

    Blank lines and lines starting with '*' (comments) are skipped; every
    other line is split into fields on runs of tab characters.

    Args:
      polqa_out_filepath: path to the POLQA tool output file.

    Returns:
      A dict mapping each header field name to the corresponding value (both
      as strings).

    Raises:
      EvaluationScoreException: if the output is not made of exactly one header
        row and one values row with the same number of fields.
    """
        data = []
        with open(polqa_out_filepath) as f:
            for line in f:
                line = line.strip()
                if not line or line.startswith('*'):
                    # Ignore empty lines and comments.
                    continue
                # Read fields.
                data.append(re.split(r'\t+', line))

        # Two rows expected (header and values). Explicit checks instead of
        # asserts, which would be stripped when running with -O.
        if len(data) != 2:
            raise exceptions.EvaluationScoreException(
                'Cannot parse POLQA output')
        header, values = data
        if len(header) != len(values):
            raise exceptions.EvaluationScoreException(
                'Cannot parse POLQA output: header and values have a different '
                'number of fields')

        # Field names (header) as keys and the corresponding field values as
        # values.
        return dict(zip(header, values))
343
344
@EvaluationScore.RegisterClass
class TotalHarmonicDistorsionScore(EvaluationScore):
    """Total harmonic distortion plus noise score.

  Total harmonic distortion plus noise score.
  See "https://en.wikipedia.org/wiki/Total_harmonic_distortion#THD.2BN".

  Requires a mono tested signal obtained from a pure tone input processed with
  the "identity" test data generator (see _CheckInputSignal()).

  Unit: -.
  Ideal: 0.
  Worst case: +inf
  """

    NAME = 'thd'

    def __init__(self, score_filename_prefix):
        EvaluationScore.__init__(self, score_filename_prefix)
        # Frequency (Hz) of the pure tone input; read from the input signal
        # metadata by _CheckInputSignal().
        self._input_frequency = None

    def _Run(self, output_path):
        self._CheckInputSignal()

        self._LoadTestedSignal()
        if self._tested_signal.channels != 1:
            raise exceptions.EvaluationScoreException(
                'unsupported number of channels')
        samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData(
            self._tested_signal)

        # Init.
        num_samples = len(samples)
        if num_samples == 0:
            # |scaling| below would divide by zero.
            raise exceptions.EvaluationScoreException(
                'The THD score requires a non-empty tested signal')
        sample_rate = self._tested_signal.frame_rate
        scaling = 2.0 / num_samples
        max_freq = sample_rate / 2
        f0_freq = float(self._input_frequency)
        # A non-positive frequency would make the harmonics loop below run
        # forever; a frequency at or above Nyquist would leave no fundamental.
        if not 0.0 < f0_freq < max_freq:
            raise exceptions.EvaluationScoreException(
                'The THD score requires a pure tone frequency in (0, {0}) Hz, '
                'got {1}'.format(max_freq, f0_freq))
        # Exact sampling instants (seconds) of the samples: k / sample_rate.
        t = np.arange(num_samples) / float(sample_rate)

        # Analyze harmonics: sine and cosine Fourier coefficients of every
        # multiple of the fundamental below the Nyquist frequency.
        coefficients = []
        n = 1
        while f0_freq * n < max_freq:
            phase = 2.0 * np.pi * n * f0_freq * t
            x_n = np.sum(samples * np.sin(phase)) * scaling
            y_n = np.sum(samples * np.cos(phase)) * scaling
            coefficients.append((x_n, y_n))
            n += 1
        b_terms = [np.sqrt(x_n**2 + y_n**2) for x_n, y_n in coefficients]

        # Remove the fundamental using both its sine and cosine components, so
        # that the residual does not depend on the phase of the tone (e.g. a
        # delay introduced by the processing).
        x_1, y_1 = coefficients[0]
        fundamental_phase = 2.0 * np.pi * f0_freq * t
        output_without_fundamental = samples - (
            x_1 * np.sin(fundamental_phase) + y_1 * np.cos(fundamental_phase))
        distortion_and_noise = np.sqrt(
            np.sum(output_without_fundamental**2) * np.pi * scaling)

        # TODO(alessiob): Fix or remove if not needed.
        # thd = np.sqrt(np.sum(b_terms[1:]**2)) / b_terms[0]

        # TODO(alessiob): Check the range of |thd_plus_noise| and update the class
        # docstring above if accordingly.
        thd_plus_noise = distortion_and_noise / b_terms[0]

        self._score = thd_plus_noise
        self._SaveScore()

    def _CheckInputSignal(self):
        """Checks the input signal metadata and reads the tone frequency.

    Raises:
      EvaluationScoreException: if the metadata is missing, incomplete, does
        not describe a pure tone, or the test data generator is not "identity"
        with the "default" configuration.
    """
        # Check input signal and get properties.
        try:
            if self._input_signal_metadata['signal'] != 'pure_tone':
                raise exceptions.EvaluationScoreException(
                    'The THD score requires a pure tone as input signal')
            self._input_frequency = self._input_signal_metadata['frequency']
            if self._input_signal_metadata[
                    'test_data_gen_name'] != 'identity' or (
                        self._input_signal_metadata['test_data_gen_config'] !=
                        'default'):
                raise exceptions.EvaluationScoreException(
                    'The THD score cannot be used with any test data generator other '
                    'than "identity"')
        except TypeError:
            # Raised when no metadata has been set (None is not subscriptable).
            raise exceptions.EvaluationScoreException(
                'The THD score requires an input signal with associated metadata'
            )
        except KeyError:
            raise exceptions.EvaluationScoreException(
                'Invalid input signal metadata to compute the THD score')
428