1# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. 2# 3# Use of this source code is governed by a BSD-style license 4# that can be found in the LICENSE file in the root of the source 5# tree. An additional intellectual property rights grant can be found 6# in the file PATENTS. All contributing project authors may 7# be found in the AUTHORS file in the root of the source tree. 8"""Extraction of annotations from audio files. 9""" 10 11from __future__ import division 12import logging 13import os 14import shutil 15import struct 16import subprocess 17import sys 18import tempfile 19 20try: 21 import numpy as np 22except ImportError: 23 logging.critical('Cannot import the third-party Python package numpy') 24 sys.exit(1) 25 26from . import external_vad 27from . import exceptions 28from . import signal_processing 29 30 31class AudioAnnotationsExtractor(object): 32 """Extracts annotations from audio files. 33 """ 34 35 class VadType(object): 36 ENERGY_THRESHOLD = 1 # TODO(alessiob): Consider switching to P56 standard. 37 WEBRTC_COMMON_AUDIO = 2 # common_audio/vad/include/vad.h 38 WEBRTC_APM = 4 # modules/audio_processing/vad/vad.h 39 40 def __init__(self, value): 41 if (not isinstance(value, int)) or not 0 <= value <= 7: 42 raise exceptions.InitializationException('Invalid vad type: ' + 43 value) 44 self._value = value 45 46 def Contains(self, vad_type): 47 return self._value | vad_type == self._value 48 49 def __str__(self): 50 vads = [] 51 if self.Contains(self.ENERGY_THRESHOLD): 52 vads.append("energy") 53 if self.Contains(self.WEBRTC_COMMON_AUDIO): 54 vads.append("common_audio") 55 if self.Contains(self.WEBRTC_APM): 56 vads.append("apm") 57 return "VadType({})".format(", ".join(vads)) 58 59 _OUTPUT_FILENAME_TEMPLATE = '{}annotations.npz' 60 61 # Level estimation params. 62 _ONE_DB_REDUCTION = np.power(10.0, -1.0 / 20.0) 63 _LEVEL_FRAME_SIZE_MS = 1.0 64 # The time constants in ms indicate the time it takes for the level estimate 65 # to go down/up by 1 db if the signal is zero. 66 _LEVEL_ATTACK_MS = 5.0 67 _LEVEL_DECAY_MS = 20.0 68 69 # VAD params. 70 _VAD_THRESHOLD = 1 71 _VAD_WEBRTC_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 72 os.pardir, os.pardir) 73 _VAD_WEBRTC_COMMON_AUDIO_PATH = os.path.join(_VAD_WEBRTC_PATH, 'vad') 74 75 _VAD_WEBRTC_APM_PATH = os.path.join(_VAD_WEBRTC_PATH, 'apm_vad') 76 77 def __init__(self, vad_type, external_vads=None): 78 self._signal = None 79 self._level = None 80 self._level_frame_size = None 81 self._common_audio_vad = None 82 self._energy_vad = None 83 self._apm_vad_probs = None 84 self._apm_vad_rms = None 85 self._vad_frame_size = None 86 self._vad_frame_size_ms = None 87 self._c_attack = None 88 self._c_decay = None 89 90 self._vad_type = self.VadType(vad_type) 91 logging.info('VADs used for annotations: ' + str(self._vad_type)) 92 93 if external_vads is None: 94 external_vads = {} 95 self._external_vads = external_vads 96 97 assert len(self._external_vads) == len(external_vads), ( 98 'The external VAD names must be unique.') 99 for vad in external_vads.values(): 100 if not isinstance(vad, external_vad.ExternalVad): 101 raise exceptions.InitializationException('Invalid vad type: ' + 102 str(type(vad))) 103 logging.info('External VAD used for annotation: ' + str(vad.name)) 104 105 assert os.path.exists(self._VAD_WEBRTC_COMMON_AUDIO_PATH), \ 106 self._VAD_WEBRTC_COMMON_AUDIO_PATH 107 assert os.path.exists(self._VAD_WEBRTC_APM_PATH), \ 108 self._VAD_WEBRTC_APM_PATH 109 110 @classmethod 111 def GetOutputFileNameTemplate(cls): 112 return cls._OUTPUT_FILENAME_TEMPLATE 113 114 def GetLevel(self): 115 return self._level 116 117 def GetLevelFrameSize(self): 118 return self._level_frame_size 119 120 @classmethod 121 def GetLevelFrameSizeMs(cls): 122 return cls._LEVEL_FRAME_SIZE_MS 123 124 def GetVadOutput(self, vad_type): 125 if vad_type == self.VadType.ENERGY_THRESHOLD: 126 return self._energy_vad 127 elif vad_type == self.VadType.WEBRTC_COMMON_AUDIO: 128 return self._common_audio_vad 129 elif vad_type == self.VadType.WEBRTC_APM: 130 return (self._apm_vad_probs, self._apm_vad_rms) 131 else: 132 raise exceptions.InitializationException('Invalid vad type: ' + 133 vad_type) 134 135 def GetVadFrameSize(self): 136 return self._vad_frame_size 137 138 def GetVadFrameSizeMs(self): 139 return self._vad_frame_size_ms 140 141 def Extract(self, filepath): 142 # Load signal. 143 self._signal = signal_processing.SignalProcessingUtils.LoadWav( 144 filepath) 145 if self._signal.channels != 1: 146 raise NotImplementedError( 147 'Multiple-channel annotations not implemented') 148 149 # Level estimation params. 150 self._level_frame_size = int(self._signal.frame_rate / 1000 * 151 (self._LEVEL_FRAME_SIZE_MS)) 152 self._c_attack = 0.0 if self._LEVEL_ATTACK_MS == 0 else ( 153 self._ONE_DB_REDUCTION**(self._LEVEL_FRAME_SIZE_MS / 154 self._LEVEL_ATTACK_MS)) 155 self._c_decay = 0.0 if self._LEVEL_DECAY_MS == 0 else ( 156 self._ONE_DB_REDUCTION**(self._LEVEL_FRAME_SIZE_MS / 157 self._LEVEL_DECAY_MS)) 158 159 # Compute level. 160 self._LevelEstimation() 161 162 # Ideal VAD output, it requires clean speech with high SNR as input. 163 if self._vad_type.Contains(self.VadType.ENERGY_THRESHOLD): 164 # Naive VAD based on level thresholding. 165 vad_threshold = np.percentile(self._level, self._VAD_THRESHOLD) 166 self._energy_vad = np.uint8(self._level > vad_threshold) 167 self._vad_frame_size = self._level_frame_size 168 self._vad_frame_size_ms = self._LEVEL_FRAME_SIZE_MS 169 if self._vad_type.Contains(self.VadType.WEBRTC_COMMON_AUDIO): 170 # WebRTC common_audio/ VAD. 171 self._RunWebRtcCommonAudioVad(filepath, self._signal.frame_rate) 172 if self._vad_type.Contains(self.VadType.WEBRTC_APM): 173 # WebRTC modules/audio_processing/ VAD. 174 self._RunWebRtcApmVad(filepath) 175 for extvad_name in self._external_vads: 176 self._external_vads[extvad_name].Run(filepath) 177 178 def Save(self, output_path, annotation_name=""): 179 ext_kwargs = { 180 'extvad_conf-' + ext_vad: 181 self._external_vads[ext_vad].GetVadOutput() 182 for ext_vad in self._external_vads 183 } 184 np.savez_compressed(file=os.path.join( 185 output_path, 186 self.GetOutputFileNameTemplate().format(annotation_name)), 187 level=self._level, 188 level_frame_size=self._level_frame_size, 189 level_frame_size_ms=self._LEVEL_FRAME_SIZE_MS, 190 vad_output=self._common_audio_vad, 191 vad_energy_output=self._energy_vad, 192 vad_frame_size=self._vad_frame_size, 193 vad_frame_size_ms=self._vad_frame_size_ms, 194 vad_probs=self._apm_vad_probs, 195 vad_rms=self._apm_vad_rms, 196 **ext_kwargs) 197 198 def _LevelEstimation(self): 199 # Read samples. 200 samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData( 201 self._signal).astype(np.float32) / 32768.0 202 num_frames = len(samples) // self._level_frame_size 203 num_samples = num_frames * self._level_frame_size 204 205 # Envelope. 206 self._level = np.max(np.reshape(np.abs(samples[:num_samples]), 207 (num_frames, self._level_frame_size)), 208 axis=1) 209 assert len(self._level) == num_frames 210 211 # Envelope smoothing. 212 smooth = lambda curr, prev, k: (1 - k) * curr + k * prev 213 self._level[0] = smooth(self._level[0], 0.0, self._c_attack) 214 for i in range(1, num_frames): 215 self._level[i] = smooth( 216 self._level[i], self._level[i - 1], self._c_attack if 217 (self._level[i] > self._level[i - 1]) else self._c_decay) 218 219 def _RunWebRtcCommonAudioVad(self, wav_file_path, sample_rate): 220 self._common_audio_vad = None 221 self._vad_frame_size = None 222 223 # Create temporary output path. 224 tmp_path = tempfile.mkdtemp() 225 output_file_path = os.path.join( 226 tmp_path, 227 os.path.split(wav_file_path)[1] + '_vad.tmp') 228 229 # Call WebRTC VAD. 230 try: 231 subprocess.call([ 232 self._VAD_WEBRTC_COMMON_AUDIO_PATH, '-i', wav_file_path, '-o', 233 output_file_path 234 ], 235 cwd=self._VAD_WEBRTC_PATH) 236 237 # Read bytes. 238 with open(output_file_path, 'rb') as f: 239 raw_data = f.read() 240 241 # Parse side information. 242 self._vad_frame_size_ms = struct.unpack('B', raw_data[0])[0] 243 self._vad_frame_size = self._vad_frame_size_ms * sample_rate / 1000 244 assert self._vad_frame_size_ms in [10, 20, 30] 245 extra_bits = struct.unpack('B', raw_data[-1])[0] 246 assert 0 <= extra_bits <= 8 247 248 # Init VAD vector. 249 num_bytes = len(raw_data) 250 num_frames = 8 * (num_bytes - 251 2) - extra_bits # 8 frames for each byte. 252 self._common_audio_vad = np.zeros(num_frames, np.uint8) 253 254 # Read VAD decisions. 255 for i, byte in enumerate(raw_data[1:-1]): 256 byte = struct.unpack('B', byte)[0] 257 for j in range(8 if i < num_bytes - 3 else (8 - extra_bits)): 258 self._common_audio_vad[i * 8 + j] = int(byte & 1) 259 byte = byte >> 1 260 except Exception as e: 261 logging.error('Error while running the WebRTC VAD (' + e.message + 262 ')') 263 finally: 264 if os.path.exists(tmp_path): 265 shutil.rmtree(tmp_path) 266 267 def _RunWebRtcApmVad(self, wav_file_path): 268 # Create temporary output path. 269 tmp_path = tempfile.mkdtemp() 270 output_file_path_probs = os.path.join( 271 tmp_path, 272 os.path.split(wav_file_path)[1] + '_vad_probs.tmp') 273 output_file_path_rms = os.path.join( 274 tmp_path, 275 os.path.split(wav_file_path)[1] + '_vad_rms.tmp') 276 277 # Call WebRTC VAD. 278 try: 279 subprocess.call([ 280 self._VAD_WEBRTC_APM_PATH, '-i', wav_file_path, '-o_probs', 281 output_file_path_probs, '-o_rms', output_file_path_rms 282 ], 283 cwd=self._VAD_WEBRTC_PATH) 284 285 # Parse annotations. 286 self._apm_vad_probs = np.fromfile(output_file_path_probs, 287 np.double) 288 self._apm_vad_rms = np.fromfile(output_file_path_rms, np.double) 289 assert len(self._apm_vad_rms) == len(self._apm_vad_probs) 290 291 except Exception as e: 292 logging.error('Error while running the WebRTC APM VAD (' + 293 e.message + ')') 294 finally: 295 if os.path.exists(tmp_path): 296 shutil.rmtree(tmp_path) 297