1# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
2#
3# Use of this source code is governed by a BSD-style license
4# that can be found in the LICENSE file in the root of the source
5# tree. An additional intellectual property rights grant can be found
6# in the file PATENTS.  All contributing project authors may
7# be found in the AUTHORS file in the root of the source tree.
8"""Extraction of annotations from audio files.
9"""
10
11from __future__ import division
12import logging
13import os
14import shutil
15import struct
16import subprocess
17import sys
18import tempfile
19
20try:
21    import numpy as np
22except ImportError:
23    logging.critical('Cannot import the third-party Python package numpy')
24    sys.exit(1)
25
26from . import external_vad
27from . import exceptions
28from . import signal_processing
29
30
class AudioAnnotationsExtractor(object):
    """Extracts annotations from audio files.

    The annotations are a frame-wise level envelope plus the output of one or
    more voice activity detectors (VADs). Call |Extract| on a mono wav file and
    then |Save| to serialize all the annotations into a compressed .npz file.
    """

    class VadType(object):
        """Bit field selecting which VADs to run (values can be OR'ed)."""

        ENERGY_THRESHOLD = 1  # TODO(alessiob): Consider switching to P56 standard.
        WEBRTC_COMMON_AUDIO = 2  # common_audio/vad/include/vad.h
        WEBRTC_APM = 4  # modules/audio_processing/vad/vad.h

        def __init__(self, value):
            # |value| must be an int in [0, 7]: any bitwise OR of the three
            # flags above. str(value) is required: concatenating a non-string
            # directly would raise TypeError instead of the intended exception.
            if (not isinstance(value, int)) or not 0 <= value <= 7:
                raise exceptions.InitializationException(
                    'Invalid vad type: ' + str(value))
            self._value = value

        def Contains(self, vad_type):
            """Returns True if all the bits in |vad_type| are enabled."""
            return self._value | vad_type == self._value

        def __str__(self):
            vads = []
            if self.Contains(self.ENERGY_THRESHOLD):
                vads.append("energy")
            if self.Contains(self.WEBRTC_COMMON_AUDIO):
                vads.append("common_audio")
            if self.Contains(self.WEBRTC_APM):
                vads.append("apm")
            return "VadType({})".format(", ".join(vads))

    _OUTPUT_FILENAME_TEMPLATE = '{}annotations.npz'

    # Level estimation params.
    _ONE_DB_REDUCTION = np.power(10.0, -1.0 / 20.0)
    _LEVEL_FRAME_SIZE_MS = 1.0
    # The time constants in ms indicate the time it takes for the level estimate
    # to go down/up by 1 db if the signal is zero.
    _LEVEL_ATTACK_MS = 5.0
    _LEVEL_DECAY_MS = 20.0

    # VAD params.
    _VAD_THRESHOLD = 1
    _VAD_WEBRTC_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                    os.pardir, os.pardir)
    _VAD_WEBRTC_COMMON_AUDIO_PATH = os.path.join(_VAD_WEBRTC_PATH, 'vad')

    _VAD_WEBRTC_APM_PATH = os.path.join(_VAD_WEBRTC_PATH, 'apm_vad')

    def __init__(self, vad_type, external_vads=None):
        """Ctor.

        Args:
          vad_type: int bit field accepted by |VadType|.
          external_vads: optional dict mapping a name to an
            external_vad.ExternalVad instance.

        Raises:
          exceptions.InitializationException: invalid |vad_type| or an
            |external_vads| value that is not an ExternalVad.
        """
        self._signal = None
        self._level = None
        self._level_frame_size = None
        self._common_audio_vad = None
        self._energy_vad = None
        self._apm_vad_probs = None
        self._apm_vad_rms = None
        self._vad_frame_size = None
        self._vad_frame_size_ms = None
        self._c_attack = None
        self._c_decay = None

        self._vad_type = self.VadType(vad_type)
        logging.info('VADs used for annotations: %s', self._vad_type)

        if external_vads is None:
            external_vads = {}
        self._external_vads = external_vads

        for vad in external_vads.values():
            if not isinstance(vad, external_vad.ExternalVad):
                raise exceptions.InitializationException(
                    'Invalid vad type: ' + str(type(vad)))
            logging.info('External VAD used for annotation: %s', vad.name)

        # The names carried by the VAD objects are used as output keys in
        # |Save|, hence they must be unique; the dict keys alone do not
        # guarantee that. (The previous check compared the dict to itself and
        # was vacuously true.)
        assert len({vad.name for vad in external_vads.values()}) == len(
            external_vads), 'The external VAD names must be unique.'

        assert os.path.exists(self._VAD_WEBRTC_COMMON_AUDIO_PATH), \
          self._VAD_WEBRTC_COMMON_AUDIO_PATH
        assert os.path.exists(self._VAD_WEBRTC_APM_PATH), \
          self._VAD_WEBRTC_APM_PATH

    @classmethod
    def GetOutputFileNameTemplate(cls):
        return cls._OUTPUT_FILENAME_TEMPLATE

    def GetLevel(self):
        return self._level

    def GetLevelFrameSize(self):
        return self._level_frame_size

    @classmethod
    def GetLevelFrameSizeMs(cls):
        return cls._LEVEL_FRAME_SIZE_MS

    def GetVadOutput(self, vad_type):
        """Returns the output of the VAD selected by the single flag |vad_type|.

        For WEBRTC_APM a (probabilities, RMS) pair is returned.
        """
        if vad_type == self.VadType.ENERGY_THRESHOLD:
            return self._energy_vad
        elif vad_type == self.VadType.WEBRTC_COMMON_AUDIO:
            return self._common_audio_vad
        elif vad_type == self.VadType.WEBRTC_APM:
            return (self._apm_vad_probs, self._apm_vad_rms)
        else:
            # str() is required to avoid a TypeError when building the message.
            raise exceptions.InitializationException('Invalid vad type: ' +
                                                     str(vad_type))

    def GetVadFrameSize(self):
        return self._vad_frame_size

    def GetVadFrameSizeMs(self):
        return self._vad_frame_size_ms

    def Extract(self, filepath):
        """Computes level and VAD annotations for the mono wav at |filepath|.

        Raises:
          NotImplementedError: if the file has more than one channel.
        """
        # Load signal.
        self._signal = signal_processing.SignalProcessingUtils.LoadWav(
            filepath)
        if self._signal.channels != 1:
            raise NotImplementedError(
                'Multiple-channel annotations not implemented')

        # Level estimation params.
        self._level_frame_size = int(self._signal.frame_rate / 1000 *
                                     (self._LEVEL_FRAME_SIZE_MS))
        # Per-frame smoothing coefficients derived from the attack/decay time
        # constants; 0.0 disables smoothing in that direction.
        self._c_attack = 0.0 if self._LEVEL_ATTACK_MS == 0 else (
            self._ONE_DB_REDUCTION**(self._LEVEL_FRAME_SIZE_MS /
                                     self._LEVEL_ATTACK_MS))
        self._c_decay = 0.0 if self._LEVEL_DECAY_MS == 0 else (
            self._ONE_DB_REDUCTION**(self._LEVEL_FRAME_SIZE_MS /
                                     self._LEVEL_DECAY_MS))

        # Compute level.
        self._LevelEstimation()

        # Ideal VAD output, it requires clean speech with high SNR as input.
        if self._vad_type.Contains(self.VadType.ENERGY_THRESHOLD):
            # Naive VAD based on level thresholding.
            vad_threshold = np.percentile(self._level, self._VAD_THRESHOLD)
            self._energy_vad = np.uint8(self._level > vad_threshold)
            self._vad_frame_size = self._level_frame_size
            self._vad_frame_size_ms = self._LEVEL_FRAME_SIZE_MS
        if self._vad_type.Contains(self.VadType.WEBRTC_COMMON_AUDIO):
            # WebRTC common_audio/ VAD.
            self._RunWebRtcCommonAudioVad(filepath, self._signal.frame_rate)
        if self._vad_type.Contains(self.VadType.WEBRTC_APM):
            # WebRTC modules/audio_processing/ VAD.
            self._RunWebRtcApmVad(filepath)
        for extvad_name in self._external_vads:
            self._external_vads[extvad_name].Run(filepath)

    def Save(self, output_path, annotation_name=""):
        """Saves all the annotations into a compressed .npz in |output_path|."""
        ext_kwargs = {
            'extvad_conf-' + ext_vad:
            self._external_vads[ext_vad].GetVadOutput()
            for ext_vad in self._external_vads
        }
        np.savez_compressed(file=os.path.join(
            output_path,
            self.GetOutputFileNameTemplate().format(annotation_name)),
                            level=self._level,
                            level_frame_size=self._level_frame_size,
                            level_frame_size_ms=self._LEVEL_FRAME_SIZE_MS,
                            vad_output=self._common_audio_vad,
                            vad_energy_output=self._energy_vad,
                            vad_frame_size=self._vad_frame_size,
                            vad_frame_size_ms=self._vad_frame_size_ms,
                            vad_probs=self._apm_vad_probs,
                            vad_rms=self._apm_vad_rms,
                            **ext_kwargs)

    def _LevelEstimation(self):
        """Computes the smoothed per-frame peak level envelope of |_signal|."""
        # Read samples, normalized to [-1.0, 1.0).
        samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData(
            self._signal).astype(np.float32) / 32768.0
        num_frames = len(samples) // self._level_frame_size
        num_samples = num_frames * self._level_frame_size

        # Envelope: per-frame peak of the absolute signal (trailing partial
        # frame, if any, is dropped).
        self._level = np.max(np.reshape(np.abs(samples[:num_samples]),
                                        (num_frames, self._level_frame_size)),
                             axis=1)
        assert len(self._level) == num_frames

        # Envelope smoothing: attack coefficient when the level rises, decay
        # coefficient when it falls.
        smooth = lambda curr, prev, k: (1 - k) * curr + k * prev
        self._level[0] = smooth(self._level[0], 0.0, self._c_attack)
        for i in range(1, num_frames):
            self._level[i] = smooth(
                self._level[i], self._level[i - 1], self._c_attack if
                (self._level[i] > self._level[i - 1]) else self._c_decay)

    def _RunWebRtcCommonAudioVad(self, wav_file_path, sample_rate):
        """Runs the common_audio VAD binary and parses its bit-packed output.

        Output format: byte 0 is the frame size in ms, the last byte is the
        number of padding bits in the second-to-last byte, the bytes in
        between hold one VAD decision per bit (LSB first).
        """
        self._common_audio_vad = None
        self._vad_frame_size = None

        # Create temporary output path.
        tmp_path = tempfile.mkdtemp()
        output_file_path = os.path.join(
            tmp_path,
            os.path.split(wav_file_path)[1] + '_vad.tmp')

        # Call WebRTC VAD.
        try:
            subprocess.call([
                self._VAD_WEBRTC_COMMON_AUDIO_PATH, '-i', wav_file_path, '-o',
                output_file_path
            ],
                            cwd=self._VAD_WEBRTC_PATH)

            # Read bytes.
            with open(output_file_path, 'rb') as f:
                raw_data = f.read()

            # Parse side information. Slicing (not indexing) is required:
            # indexing a bytes object yields an int in Python 3, which
            # struct.unpack rejects.
            self._vad_frame_size_ms = struct.unpack('B', raw_data[0:1])[0]
            self._vad_frame_size = (self._vad_frame_size_ms * sample_rate //
                                    1000)
            assert self._vad_frame_size_ms in [10, 20, 30]
            extra_bits = struct.unpack('B', raw_data[-1:])[0]
            assert 0 <= extra_bits <= 8

            # Init VAD vector.
            num_bytes = len(raw_data)
            num_frames = 8 * (num_bytes -
                              2) - extra_bits  # 8 frames for each byte.
            self._common_audio_vad = np.zeros(num_frames, np.uint8)

            # Read VAD decisions; bytearray yields ints on both Python 2 and 3.
            for i, byte in enumerate(bytearray(raw_data[1:-1])):
                # The last payload byte carries only (8 - extra_bits) frames.
                for j in range(8 if i < num_bytes - 3 else (8 - extra_bits)):
                    self._common_audio_vad[i * 8 + j] = int(byte & 1)
                    byte = byte >> 1
        except Exception as e:
            # e.message does not exist in Python 3; str(e) is portable.
            logging.error('Error while running the WebRTC VAD (%s)', str(e))
        finally:
            if os.path.exists(tmp_path):
                shutil.rmtree(tmp_path)

    def _RunWebRtcApmVad(self, wav_file_path):
        """Runs the APM VAD binary and loads its probability and RMS outputs."""
        # Create temporary output path.
        tmp_path = tempfile.mkdtemp()
        output_file_path_probs = os.path.join(
            tmp_path,
            os.path.split(wav_file_path)[1] + '_vad_probs.tmp')
        output_file_path_rms = os.path.join(
            tmp_path,
            os.path.split(wav_file_path)[1] + '_vad_rms.tmp')

        # Call WebRTC VAD.
        try:
            subprocess.call([
                self._VAD_WEBRTC_APM_PATH, '-i', wav_file_path, '-o_probs',
                output_file_path_probs, '-o_rms', output_file_path_rms
            ],
                            cwd=self._VAD_WEBRTC_PATH)

            # Parse annotations (one float64 per VAD frame in each file).
            self._apm_vad_probs = np.fromfile(output_file_path_probs,
                                              np.double)
            self._apm_vad_rms = np.fromfile(output_file_path_rms, np.double)
            assert len(self._apm_vad_rms) == len(self._apm_vad_probs)

        except Exception as e:
            # e.message does not exist in Python 3; str(e) is portable.
            logging.error('Error while running the WebRTC APM VAD (%s)',
                          str(e))
        finally:
            if os.path.exists(tmp_path):
                shutil.rmtree(tmp_path)