1# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
2#
3# Use of this source code is governed by a BSD-style license
4# that can be found in the LICENSE file in the root of the source
5# tree. An additional intellectual property rights grant can be found
6# in the file PATENTS.  All contributing project authors may
7# be found in the AUTHORS file in the root of the source tree.
8
9from __future__ import division
10
11import logging
12import os
13import subprocess
14import shutil
15import sys
16import tempfile
17
18try:
19  import numpy as np
20except ImportError:
21  logging.critical('Cannot import the third-party Python package numpy')
22  sys.exit(1)
23
24from . import signal_processing
25
class ExternalVad(object):
  """Runs an external voice activity detector (VAD) binary on wav files.

  The binary is invoked as '<binary> -i <wav> -o <output file>'; the
  output file is read back as a flat array of float32 values and kept
  in memory until the next call to Run().
  """

  def __init__(self, path_to_binary, name):
    """Args:
       path_to_binary: path to binary that accepts '-i <wav>', '-o
          <float probabilities>'. There must be one float value per
          10ms audio
       name: a name to identify the external VAD. Used for saving
          the output as extvad_output-<name>.

    Raises:
       AssertionError: if path_to_binary does not exist.
    """
    self._path_to_binary = path_to_binary
    self.name = name
    assert os.path.exists(self._path_to_binary), (
        self._path_to_binary)
    self._vad_output = None

  def Run(self, wav_file_path):
    """Runs the external VAD on a wav file and stores its output.

    Errors raised while running the binary or reading its output are
    logged and not propagated; in that case no output is stored and a
    later GetVadOutput() call fails its assertion.

    Args:
       wav_file_path: path to a mono, 48 kHz wav file.

    Raises:
       NotImplementedError: if the file is not mono or its frame rate
          is not 48000.
    """
    # Drop the result of any previous run so that a failure below cannot
    # be mistaken for a valid output of this file.
    self._vad_output = None

    _signal = signal_processing.SignalProcessingUtils.LoadWav(wav_file_path)
    if _signal.channels != 1:
      raise NotImplementedError('Multiple-channel'
                                ' annotations not implemented')
    if _signal.frame_rate != 48000:
      raise NotImplementedError('Frame rates '
                                'other than 48000 not implemented')

    tmp_path = tempfile.mkdtemp()
    try:
      output_file_path = os.path.join(
          tmp_path, self.name + '_vad.tmp')
      return_code = subprocess.call([
          self._path_to_binary,
          '-i', wav_file_path,
          '-o', output_file_path
      ])
      if return_code != 0:
        # Only reported: the output file is still read below, exactly as
        # when the exit status was not inspected at all.
        logging.warning('The %s VAD binary exited with code %d',
                        self.name, return_code)
      self._vad_output = np.fromfile(output_file_path, np.float32)
    except Exception as e:
      # Exceptions have no 'message' attribute in Python 3; let logging
      # format the exception object itself.
      logging.error('Error while running the %s VAD (%s)', self.name, e)
    finally:
      # The temporary directory is only needed to hand over the output
      # file; remove it whether or not the run succeeded.
      if os.path.exists(tmp_path):
        shutil.rmtree(tmp_path)

  def GetVadOutput(self):
    """Returns the output stored by the last successful Run() call.

    Raises:
       AssertionError: if no output is available.
    """
    assert self._vad_output is not None
    return self._vad_output

  @classmethod
  def ConstructVadDict(cls, vad_paths, vad_names):
    """Builds a dict mapping each name to an instance created from its path.

    Paths and names are paired by position with zip(), so extra items in
    the longer sequence are ignored.

    Args:
       vad_paths: paths to the external VAD binaries.
       vad_names: names identifying the external VADs.

    Returns:
       A dict {name: instance of cls}.
    """
    external_vads = {}
    for path, name in zip(vad_paths, vad_names):
      # Instantiate through cls so that subclasses build their own type.
      external_vads[name] = cls(path, name)
    return external_vads
78