1# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
2#
3# Use of this source code is governed by a BSD-style license
4# that can be found in the LICENSE file in the root of the source
5# tree. An additional intellectual property rights grant can be found
6# in the file PATENTS.  All contributing project authors may
7# be found in the AUTHORS file in the root of the source tree.
8
9from __future__ import division
10
11import logging
12import os
13import subprocess
14import shutil
15import sys
16import tempfile
17
18try:
19    import numpy as np
20except ImportError:
21    logging.critical('Cannot import the third-party Python package numpy')
22    sys.exit(1)
23
24from . import signal_processing
25
26
class ExternalVad(object):
    """Wrapper around an external voice activity detector (VAD) binary.

    The binary is invoked as ``<binary> -i <wav> -o <output>`` and must write
    raw float32 values to ``<output>``, one value per 10 ms of audio.
    """

    def __init__(self, path_to_binary, name):
        """Args:
          path_to_binary: path to a binary that accepts '-i <wav>' and
            '-o <float probabilities>'. There must be one float value per
            10 ms of audio.
          name: a name to identify the external VAD. Used for saving the
            output as extvad_output-<name>.
        """
        self._path_to_binary = path_to_binary
        self.name = name
        assert os.path.exists(self._path_to_binary), (self._path_to_binary)
        self._vad_output = None

    def Run(self, wav_file_path):
        """Runs the external VAD on a WAV file and stores its output.

        Errors while running the binary or reading its output are logged, not
        raised. In that case the stored output stays cleared, so
        GetVadOutput() cannot return results left over from an earlier call.

        Args:
          wav_file_path: path to the input WAV file (mono, 48 kHz).

        Raises:
          NotImplementedError: if the file is not mono or not sampled at
            48000 Hz.
        """
        # Reset first so a failed run never leaves stale output behind.
        self._vad_output = None

        _signal = signal_processing.SignalProcessingUtils.LoadWav(
            wav_file_path)
        if _signal.channels != 1:
            raise NotImplementedError('Multiple-channel'
                                      ' annotations not implemented')
        if _signal.frame_rate != 48000:
            raise NotImplementedError('Frame rates '
                                      'other than 48000 not implemented')

        tmp_path = tempfile.mkdtemp()
        try:
            output_file_path = os.path.join(tmp_path, self.name + '_vad.tmp')
            return_code = subprocess.call([
                self._path_to_binary, '-i', wav_file_path, '-o',
                output_file_path
            ])
            if return_code != 0:
                # Still try to read whatever the binary produced; a missing
                # or unreadable file is reported by the handler below.
                logging.warning('The %s VAD exited with code %d', self.name,
                                return_code)
            self._vad_output = np.fromfile(output_file_path, np.float32)
        except Exception as e:  # Best effort: report and let the caller go on.
            # Exceptions have no `.message` attribute in Python 3; formatting
            # the exception itself works on both Python 2 and 3.
            logging.error('Error while running the %s VAD (%s)', self.name, e)
        finally:
            if os.path.exists(tmp_path):
                shutil.rmtree(tmp_path)

    def GetVadOutput(self):
        """Returns the output of the last successful Run() call.

        Returns:
          The 1-D float32 numpy array read from the binary's output file.
        """
        assert self._vad_output is not None
        return self._vad_output

    @classmethod
    def ConstructVadDict(cls, vad_paths, vad_names):
        """Builds a mapping from VAD name to VAD instance.

        Args:
          vad_paths: iterable of paths to VAD binaries.
          vad_names: iterable of names, paired positionally with vad_paths.
            Surplus items in the longer iterable are ignored (zip semantics),
            and a repeated name keeps only its last instance.

        Returns:
          A dict mapping each name to an instance of `cls`. Using `cls` means
          subclasses get instances of themselves.
        """
        return {
            name: cls(path, name)
            for path, name in zip(vad_paths, vad_names)
        }
76