1# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
2#
3# Use of this source code is governed by a BSD-style license
4# that can be found in the LICENSE file in the root of the source
5# tree. An additional intellectual property rights grant can be found
6# in the file PATENTS.  All contributing project authors may
7# be found in the AUTHORS file in the root of the source tree.
8
9"""Extraction of annotations from audio files.
10"""
11
12from __future__ import division
13import logging
14import os
15import shutil
16import struct
17import subprocess
18import sys
19import tempfile
20
21try:
22  import numpy as np
23except ImportError:
24  logging.critical('Cannot import the third-party Python package numpy')
25  sys.exit(1)
26
27from . import external_vad
28from . import exceptions
29from . import signal_processing
30
31
class AudioAnnotationsExtractor(object):
  """Extracts annotations from audio files.

  Annotations are: a per-frame envelope-based level estimate, plus voice
  activity decisions from one or more VADs (naive energy threshold, the
  WebRTC common_audio VAD, the WebRTC APM VAD and optional external VADs).
  """

  # TODO(aleloi): change to enum.IntEnum when py 3.6 is available.
  class VadType(object):
    """Bit mask selecting which VADs to run; constants can be OR'ed."""

    ENERGY_THRESHOLD = 1  # TODO(alessiob): Consider switching to P56 standard.
    WEBRTC_COMMON_AUDIO = 2  # common_audio/vad/include/vad.h
    WEBRTC_APM = 4  # modules/audio_processing/vad/vad.h

    def __init__(self, value):
      """Validates and stores the VAD type bit mask.

      Args:
        value: integer in [0, 7], an OR of the class constants above.

      Raises:
        exceptions.InitializationException: if `value` is not a valid mask.
      """
      if (not isinstance(value, int)) or not 0 <= value <= 7:
        # str(value): building the message with raw `value` raised TypeError
        # for non-string inputs instead of the intended exception.
        raise exceptions.InitializationException(
            'Invalid vad type: ' + str(value))
      self._value = value

    def Contains(self, vad_type):
      """Returns True iff every bit set in `vad_type` is set in this mask."""
      return self._value | vad_type == self._value

    def __str__(self):
      vads = []
      if self.Contains(self.ENERGY_THRESHOLD):
        vads.append("energy")
      if self.Contains(self.WEBRTC_COMMON_AUDIO):
        vads.append("common_audio")
      if self.Contains(self.WEBRTC_APM):
        vads.append("apm")
      return "VadType({})".format(", ".join(vads))

  _OUTPUT_FILENAME_TEMPLATE = '{}annotations.npz'

  # Level estimation params.
  _ONE_DB_REDUCTION = np.power(10.0, -1.0 / 20.0)
  _LEVEL_FRAME_SIZE_MS = 1.0
  # The time constants in ms indicate the time it takes for the level estimate
  # to go down/up by 1 db if the signal is zero.
  _LEVEL_ATTACK_MS = 5.0
  _LEVEL_DECAY_MS = 20.0

  # VAD params.
  _VAD_THRESHOLD = 1  # Percentile of the level used as energy-VAD threshold.
  _VAD_WEBRTC_PATH = os.path.join(os.path.dirname(
      os.path.abspath(__file__)), os.pardir, os.pardir)
  _VAD_WEBRTC_COMMON_AUDIO_PATH = os.path.join(_VAD_WEBRTC_PATH, 'vad')

  _VAD_WEBRTC_APM_PATH = os.path.join(
      _VAD_WEBRTC_PATH, 'apm_vad')

  def __init__(self, vad_type, external_vads=None):
    """Initializes the extractor.

    Args:
      vad_type: integer bit mask accepted by VadType, selecting which VADs
        to run in Extract().
      external_vads: optional dict mapping name -> external_vad.ExternalVad.

    Raises:
      exceptions.InitializationException: invalid vad_type or external VAD.
    """
    # Annotation state, populated by Extract().
    self._signal = None
    self._level = None
    self._level_frame_size = None
    self._common_audio_vad = None
    self._energy_vad = None
    self._apm_vad_probs = None
    self._apm_vad_rms = None
    self._vad_frame_size = None
    self._vad_frame_size_ms = None
    # Smoothing coefficients, derived from the attack/decay time constants.
    self._c_attack = None
    self._c_decay = None

    self._vad_type = self.VadType(vad_type)
    logging.info('VADs used for annotations: ' + str(self._vad_type))

    if external_vads is None:
      external_vads = {}
    self._external_vads = external_vads

    # NOTE(review): this assert is vacuous since both names reference the
    # same dict; a dict cannot hold duplicate keys anyway. Kept for parity.
    assert len(self._external_vads) == len(external_vads), (
        'The external VAD names must be unique.')
    for vad in external_vads.values():
      if not isinstance(vad, external_vad.ExternalVad):
        raise exceptions.InitializationException(
            'Invalid vad type: ' + str(type(vad)))
      logging.info('External VAD used for annotation: ' +
                   str(vad.name))

    # Fail early if the WebRTC VAD binaries are missing.
    assert os.path.exists(self._VAD_WEBRTC_COMMON_AUDIO_PATH), \
      self._VAD_WEBRTC_COMMON_AUDIO_PATH
    assert os.path.exists(self._VAD_WEBRTC_APM_PATH), \
      self._VAD_WEBRTC_APM_PATH

  @classmethod
  def GetOutputFileNameTemplate(cls):
    """Returns the file name template used by Save()."""
    return cls._OUTPUT_FILENAME_TEMPLATE

  def GetLevel(self):
    """Returns the smoothed per-frame level estimate (numpy array)."""
    return self._level

  def GetLevelFrameSize(self):
    """Returns the level estimation frame size in samples."""
    return self._level_frame_size

  @classmethod
  def GetLevelFrameSizeMs(cls):
    """Returns the level estimation frame size in milliseconds."""
    return cls._LEVEL_FRAME_SIZE_MS

  def GetVadOutput(self, vad_type):
    """Returns the output of the VAD selected by `vad_type`.

    Args:
      vad_type: one of the VadType constants (a single bit, not a mask).

    Returns:
      The VAD decisions array, or a (probs, rms) pair for WEBRTC_APM.

    Raises:
      exceptions.InitializationException: unknown vad_type.
    """
    if vad_type == self.VadType.ENERGY_THRESHOLD:
      return self._energy_vad
    elif vad_type == self.VadType.WEBRTC_COMMON_AUDIO:
      return self._common_audio_vad
    elif vad_type == self.VadType.WEBRTC_APM:
      return (self._apm_vad_probs, self._apm_vad_rms)
    else:
      # str(vad_type): concatenating the raw int raised TypeError instead of
      # the intended exception.
      raise exceptions.InitializationException(
            'Invalid vad type: ' + str(vad_type))

  def GetVadFrameSize(self):
    """Returns the VAD frame size in samples (None until Extract())."""
    return self._vad_frame_size

  def GetVadFrameSizeMs(self):
    """Returns the VAD frame size in milliseconds (None until Extract())."""
    return self._vad_frame_size_ms

  def Extract(self, filepath):
    """Extracts level and VAD annotations from a mono wav file.

    Args:
      filepath: path to the input wav file.

    Raises:
      NotImplementedError: if the file has more than one channel.
    """
    # Load signal.
    self._signal = signal_processing.SignalProcessingUtils.LoadWav(filepath)
    if self._signal.channels != 1:
      raise NotImplementedError('Multiple-channel annotations not implemented')

    # Level estimation params.
    self._level_frame_size = int(self._signal.frame_rate / 1000 * (
        self._LEVEL_FRAME_SIZE_MS))
    # Per-frame smoothing coefficients; 0.0 disables smoothing and also
    # avoids a division by zero when a time constant is 0.
    self._c_attack = 0.0 if self._LEVEL_ATTACK_MS == 0 else (
        self._ONE_DB_REDUCTION ** (
            self._LEVEL_FRAME_SIZE_MS / self._LEVEL_ATTACK_MS))
    self._c_decay = 0.0 if self._LEVEL_DECAY_MS == 0 else (
        self._ONE_DB_REDUCTION ** (
            self._LEVEL_FRAME_SIZE_MS / self._LEVEL_DECAY_MS))

    # Compute level.
    self._LevelEstimation()

    # Ideal VAD output, it requires clean speech with high SNR as input.
    if self._vad_type.Contains(self.VadType.ENERGY_THRESHOLD):
      # Naive VAD based on level thresholding.
      vad_threshold = np.percentile(self._level, self._VAD_THRESHOLD)
      self._energy_vad = np.uint8(self._level > vad_threshold)
      self._vad_frame_size = self._level_frame_size
      self._vad_frame_size_ms = self._LEVEL_FRAME_SIZE_MS
    if self._vad_type.Contains(self.VadType.WEBRTC_COMMON_AUDIO):
      # WebRTC common_audio/ VAD.
      self._RunWebRtcCommonAudioVad(filepath, self._signal.frame_rate)
    if self._vad_type.Contains(self.VadType.WEBRTC_APM):
      # WebRTC modules/audio_processing/ VAD.
      self._RunWebRtcApmVad(filepath)
    for extvad_name in self._external_vads:
      self._external_vads[extvad_name].Run(filepath)

  def Save(self, output_path, annotation_name=""):
    """Saves all the annotations as a compressed .npz archive.

    Args:
      output_path: directory in which the archive is written.
      annotation_name: optional prefix for the output file name.
    """
    ext_kwargs = {'extvad_conf-' + ext_vad:
                  self._external_vads[ext_vad].GetVadOutput()
                  for ext_vad in self._external_vads}
    # pylint: disable=star-args
    np.savez_compressed(
        file=os.path.join(
            output_path,
            self.GetOutputFileNameTemplate().format(annotation_name)),
        level=self._level,
        level_frame_size=self._level_frame_size,
        level_frame_size_ms=self._LEVEL_FRAME_SIZE_MS,
        vad_output=self._common_audio_vad,
        vad_energy_output=self._energy_vad,
        vad_frame_size=self._vad_frame_size,
        vad_frame_size_ms=self._vad_frame_size_ms,
        vad_probs=self._apm_vad_probs,
        vad_rms=self._apm_vad_rms,
        **ext_kwargs
    )

  def _LevelEstimation(self):
    """Computes the smoothed per-frame envelope into self._level."""
    # Read samples, normalized to [-1.0, 1.0).
    samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData(
        self._signal).astype(np.float32) / 32768.0
    num_frames = len(samples) // self._level_frame_size
    num_samples = num_frames * self._level_frame_size

    # Envelope: per-frame peak of the absolute signal (trailing partial
    # frame is dropped).
    self._level = np.max(np.reshape(np.abs(samples[:num_samples]), (
        num_frames, self._level_frame_size)), axis=1)
    assert len(self._level) == num_frames

    # Envelope smoothing: attack coefficient when the level rises,
    # decay coefficient when it falls.
    smooth = lambda curr, prev, k: (1 - k) * curr  + k * prev
    self._level[0] = smooth(self._level[0], 0.0, self._c_attack)
    for i in range(1, num_frames):
      self._level[i] = smooth(
          self._level[i], self._level[i - 1], self._c_attack if (
              self._level[i] > self._level[i - 1]) else self._c_decay)

  def _RunWebRtcCommonAudioVad(self, wav_file_path, sample_rate):
    """Runs the WebRTC common_audio VAD binary and parses its bit stream.

    Output format: byte 0 = frame size in ms, last byte = number of unused
    bits in the last payload byte, bytes in between = 8 VAD decisions each
    (LSB first). Errors are logged, not raised.
    """
    self._common_audio_vad = None
    self._vad_frame_size = None

    # Create temporary output path.
    tmp_path = tempfile.mkdtemp()
    output_file_path = os.path.join(
        tmp_path, os.path.split(wav_file_path)[1] + '_vad.tmp')

    # Call WebRTC VAD.
    try:
      subprocess.call([
          self._VAD_WEBRTC_COMMON_AUDIO_PATH,
          '-i', wav_file_path,
          '-o', output_file_path
      ], cwd=self._VAD_WEBRTC_PATH)

      # Read bytes.
      with open(output_file_path, 'rb') as f:
        raw_data = f.read()

      # Parse side information. Slice indexing (raw_data[0:1]) keeps a bytes
      # object on both Python 2 and 3; plain indexing yields an int on 3 and
      # breaks struct.unpack.
      self._vad_frame_size_ms = struct.unpack('B', raw_data[0:1])[0]
      # int(): true division would otherwise make the sample count a float.
      self._vad_frame_size = int(self._vad_frame_size_ms * sample_rate / 1000)
      assert self._vad_frame_size_ms in [10, 20, 30]
      extra_bits = struct.unpack('B', raw_data[-1:])[0]
      assert 0 <= extra_bits <= 8

      # Init VAD vector.
      num_bytes = len(raw_data)
      num_frames = 8 * (num_bytes - 2) - extra_bits  # 8 frames for each byte.
      self._common_audio_vad = np.zeros(num_frames, np.uint8)

      # Read VAD decisions, LSB first; the last payload byte carries only
      # (8 - extra_bits) valid decisions.
      for i in range(num_bytes - 2):
        byte_value = struct.unpack('B', raw_data[i + 1:i + 2])[0]
        num_bits = 8 if i < num_bytes - 3 else (8 - extra_bits)
        for j in range(num_bits):
          self._common_audio_vad[i * 8 + j] = int(byte_value & 1)
          byte_value = byte_value >> 1
    except Exception as e:
      # str(e): e.message does not exist in Python 3.
      logging.error('Error while running the WebRTC VAD (' + str(e) + ')')
    finally:
      if os.path.exists(tmp_path):
        shutil.rmtree(tmp_path)

  def _RunWebRtcApmVad(self, wav_file_path):
    """Runs the WebRTC APM VAD binary and loads its probs/rms outputs.

    Errors are logged, not raised.
    """
    # Create temporary output path.
    tmp_path = tempfile.mkdtemp()
    output_file_path_probs = os.path.join(
        tmp_path, os.path.split(wav_file_path)[1] + '_vad_probs.tmp')
    output_file_path_rms = os.path.join(
        tmp_path, os.path.split(wav_file_path)[1] + '_vad_rms.tmp')

    # Call WebRTC VAD.
    try:
      subprocess.call([
          self._VAD_WEBRTC_APM_PATH,
          '-i', wav_file_path,
          '-o_probs', output_file_path_probs,
          '-o_rms', output_file_path_rms
      ], cwd=self._VAD_WEBRTC_PATH)

      # Parse annotations (raw little-endian doubles, one per frame).
      self._apm_vad_probs = np.fromfile(output_file_path_probs, np.double)
      self._apm_vad_rms = np.fromfile(output_file_path_rms, np.double)
      assert len(self._apm_vad_rms) == len(self._apm_vad_probs)

    except Exception as e:
      # str(e): e.message does not exist in Python 3.
      logging.error('Error while running the WebRTC APM VAD (' +
                    str(e) + ')')
    finally:
      if os.path.exists(tmp_path):
        shutil.rmtree(tmp_path)
294