# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS.  All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.

"""Extraction of annotations from audio files.
"""

from __future__ import division
import logging
import os
import shutil
import struct
import subprocess
import sys
import tempfile

try:
  import numpy as np
except ImportError:
  logging.critical('Cannot import the third-party Python package numpy')
  sys.exit(1)

from . import external_vad
from . import exceptions
from . import signal_processing


class AudioAnnotationsExtractor(object):
  """Extracts annotations from audio files.

  The annotations are a smoothed level envelope of the signal plus the output
  of the enabled voice activity detectors (VADs): a naive energy-threshold
  VAD, the WebRTC common_audio VAD binary, the WebRTC APM VAD binary and any
  number of external VADs. Call Extract() on a mono WAV file, then Save() to
  write every annotation into a single compressed .npz file.
  """

  # TODO(aleloi): change to enum.IntEnum when py 3.6 is available.
  class VadType(object):
    """Bit mask selecting which built-in VADs are run."""

    ENERGY_THRESHOLD = 1  # TODO(alessiob): Consider switching to P56 standard.
    WEBRTC_COMMON_AUDIO = 2  # common_audio/vad/include/vad.h
    WEBRTC_APM = 4  # modules/audio_processing/vad/vad.h

    def __init__(self, value):
      """Creates a VAD type mask.

      Args:
        value: int in [0, 7], bitwise OR of the flags defined above.

      Raises:
        exceptions.InitializationException: if value is not an int in [0, 7].
      """
      if (not isinstance(value, int)) or not 0 <= value <= 7:
        # str() is needed: concatenating a non-str value would raise a
        # TypeError instead of the intended exception.
        raise exceptions.InitializationException(
            'Invalid vad type: ' + str(value))
      self._value = value

    def Contains(self, vad_type):
      """Returns True if every bit set in vad_type is also set in the mask."""
      return (self._value | vad_type) == self._value

    def __str__(self):
      vads = []
      if self.Contains(self.ENERGY_THRESHOLD):
        vads.append("energy")
      if self.Contains(self.WEBRTC_COMMON_AUDIO):
        vads.append("common_audio")
      if self.Contains(self.WEBRTC_APM):
        vads.append("apm")
      return "VadType({})".format(", ".join(vads))

  # Formatted with the annotation name to build the output file name.
  _OUTPUT_FILENAME_TEMPLATE = '{}annotations.npz'

  # Level estimation params.
  _ONE_DB_REDUCTION = np.power(10.0, -1.0 / 20.0)
  _LEVEL_FRAME_SIZE_MS = 1.0
  # Each time constant (in ms) is the time it takes for the weight of the
  # previous level estimate to shrink by 1 dB; e.g., with a silent input the
  # estimate decays by 1 dB every _LEVEL_DECAY_MS ms. A value of 0 disables
  # smoothing in that direction.
  _LEVEL_ATTACK_MS = 5.0
  _LEVEL_DECAY_MS = 20.0

  # VAD params.
  # Percentile of the level envelope used as energy VAD threshold; frames
  # whose level is above it are marked as voice.
  _VAD_THRESHOLD = 1
  _VAD_WEBRTC_PATH = os.path.join(os.path.dirname(
      os.path.abspath(__file__)), os.pardir, os.pardir)
  _VAD_WEBRTC_COMMON_AUDIO_PATH = os.path.join(_VAD_WEBRTC_PATH, 'vad')

  _VAD_WEBRTC_APM_PATH = os.path.join(
      _VAD_WEBRTC_PATH, 'apm_vad')

  def __init__(self, vad_type, external_vads=None):
    """Creates the extractor.

    Args:
      vad_type: int bit mask of VadType flags selecting the built-in VADs.
      external_vads: optional dict mapping a name (str) to an
          external_vad.ExternalVad instance.

    Raises:
      exceptions.InitializationException: if vad_type is invalid or a value
          of external_vads is not an external_vad.ExternalVad instance.
      AssertionError: if the WebRTC VAD executables are not found.
    """
    self._signal = None
    self._level = None
    self._level_frame_size = None
    self._common_audio_vad = None
    self._energy_vad = None
    self._apm_vad_probs = None
    self._apm_vad_rms = None
    self._vad_frame_size = None
    self._vad_frame_size_ms = None
    self._c_attack = None
    self._c_decay = None

    self._vad_type = self.VadType(vad_type)
    logging.info('VADs used for annotations: %s', self._vad_type)

    if external_vads is None:
      external_vads = {}
    # The external VADs are stored in a dict, so their names (the keys) are
    # unique by construction.
    self._external_vads = external_vads

    for vad in external_vads.values():
      if not isinstance(vad, external_vad.ExternalVad):
        raise exceptions.InitializationException(
            'Invalid vad type: ' + str(type(vad)))
      logging.info('External VAD used for annotation: %s', vad.name)

    assert os.path.exists(self._VAD_WEBRTC_COMMON_AUDIO_PATH), \
      self._VAD_WEBRTC_COMMON_AUDIO_PATH
    assert os.path.exists(self._VAD_WEBRTC_APM_PATH), \
      self._VAD_WEBRTC_APM_PATH

  @classmethod
  def GetOutputFileNameTemplate(cls):
    return cls._OUTPUT_FILENAME_TEMPLATE

  def GetLevel(self):
    return self._level

  def GetLevelFrameSize(self):
    return self._level_frame_size

  @classmethod
  def GetLevelFrameSizeMs(cls):
    return cls._LEVEL_FRAME_SIZE_MS

  def GetVadOutput(self, vad_type):
    """Returns the output of a single built-in VAD.

    Args:
      vad_type: exactly one of the VadType flags.

    Returns:
      The energy VAD or common_audio VAD decisions (uint8 arrays), or a
      (probabilities, rms) tuple for the APM VAD. Values are None if the VAD
      has not been run.

    Raises:
      exceptions.InitializationException: if vad_type is not a single flag.
    """
    if vad_type == self.VadType.ENERGY_THRESHOLD:
      return self._energy_vad
    elif vad_type == self.VadType.WEBRTC_COMMON_AUDIO:
      return self._common_audio_vad
    elif vad_type == self.VadType.WEBRTC_APM:
      return (self._apm_vad_probs, self._apm_vad_rms)
    else:
      raise exceptions.InitializationException(
          'Invalid vad type: ' + str(vad_type))

  def GetVadFrameSize(self):
    return self._vad_frame_size

  def GetVadFrameSizeMs(self):
    return self._vad_frame_size_ms

  def Extract(self, filepath):
    """Computes the level and all enabled VAD annotations for a WAV file.

    Errors raised while running the WebRTC VAD binaries are logged rather
    than propagated.

    Args:
      filepath: path to a mono WAV file.

    Raises:
      NotImplementedError: if the file has more than one channel.
    """
    # Load signal.
    self._signal = signal_processing.SignalProcessingUtils.LoadWav(filepath)
    if self._signal.channels != 1:
      raise NotImplementedError('Multiple-channel annotations not implemented')

    # Level estimation params.
    self._level_frame_size = int(self._signal.frame_rate / 1000 * (
        self._LEVEL_FRAME_SIZE_MS))
    self._c_attack = self._ComputeSmoothingCoefficient(self._LEVEL_ATTACK_MS)
    self._c_decay = self._ComputeSmoothingCoefficient(self._LEVEL_DECAY_MS)

    # Compute level.
    self._LevelEstimation()

    # Ideal VAD output, it requires clean speech with high SNR as input.
    if self._vad_type.Contains(self.VadType.ENERGY_THRESHOLD):
      # Naive VAD based on level thresholding.
      vad_threshold = np.percentile(self._level, self._VAD_THRESHOLD)
      self._energy_vad = np.uint8(self._level > vad_threshold)
      self._vad_frame_size = self._level_frame_size
      self._vad_frame_size_ms = self._LEVEL_FRAME_SIZE_MS
    if self._vad_type.Contains(self.VadType.WEBRTC_COMMON_AUDIO):
      # WebRTC common_audio/ VAD.
      self._RunWebRtcCommonAudioVad(filepath, self._signal.frame_rate)
    if self._vad_type.Contains(self.VadType.WEBRTC_APM):
      # WebRTC modules/audio_processing/ VAD.
      self._RunWebRtcApmVad(filepath)
    for extvad_name in self._external_vads:
      self._external_vads[extvad_name].Run(filepath)

  def Save(self, output_path, annotation_name=""):
    """Writes all the annotations into one compressed .npz file.

    The file is created in output_path and named by formatting the output
    file name template with annotation_name. Each external VAD output is
    stored under the key 'extvad_conf-<name>'.
    """
    ext_kwargs = {'extvad_conf-' + name: vad.GetVadOutput()
                  for name, vad in self._external_vads.items()}
    # pylint: disable=star-args
    np.savez_compressed(
        file=os.path.join(
            output_path,
            self.GetOutputFileNameTemplate().format(annotation_name)),
        level=self._level,
        level_frame_size=self._level_frame_size,
        level_frame_size_ms=self._LEVEL_FRAME_SIZE_MS,
        vad_output=self._common_audio_vad,
        vad_energy_output=self._energy_vad,
        vad_frame_size=self._vad_frame_size,
        vad_frame_size_ms=self._vad_frame_size_ms,
        vad_probs=self._apm_vad_probs,
        vad_rms=self._apm_vad_rms,
        **ext_kwargs
    )

  @classmethod
  def _ComputeSmoothingCoefficient(cls, time_constant_ms):
    """Returns the per-frame smoothing weight for a time constant (0: none)."""
    if time_constant_ms == 0:
      return 0.0
    return cls._ONE_DB_REDUCTION ** (
        cls._LEVEL_FRAME_SIZE_MS / time_constant_ms)

  def _LevelEstimation(self):
    """Computes the smoothed level envelope of self._signal into self._level.

    The envelope is the peak absolute sample value of each non-overlapping
    frame of self._level_frame_size samples (a trailing partial frame is
    dropped). It is then smoothed with a one-pole filter that weights the
    previous estimate by self._c_attack when the level rises and by
    self._c_decay otherwise.
    """
    # Read samples. Dividing by 32768 presumably maps 16-bit PCM samples to
    # [-1, 1) - TODO confirm the sample format of AudioSegmentToRawData.
    samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData(
        self._signal).astype(np.float32) / 32768.0
    num_frames = len(samples) // self._level_frame_size
    num_samples = num_frames * self._level_frame_size

    # Envelope.
    self._level = np.max(np.reshape(np.abs(samples[:num_samples]), (
        num_frames, self._level_frame_size)), axis=1)
    assert len(self._level) == num_frames

    # Envelope smoothing. Starting from a previous level of 0.0 is equivalent
    # to always using the attack coefficient on the first frame (levels are
    # non-negative, and a zero level stays zero with either coefficient), and
    # a signal shorter than one frame yields an empty level instead of an
    # IndexError.
    previous = 0.0
    for i in range(num_frames):
      current = self._level[i]
      k = self._c_attack if current > previous else self._c_decay
      self._level[i] = (1 - k) * current + k * previous
      previous = self._level[i]

  def _RunWebRtcCommonAudioVad(self, wav_file_path, sample_rate):
    """Runs the WebRTC common_audio VAD binary and parses its output.

    The binary output file is parsed as follows: the first byte is the VAD
    frame size in ms (10, 20 or 30), the last byte is the number of unused
    bits in the last payload byte, and the bytes in between hold one decision
    bit per frame, least significant bit first. Errors are logged, not raised.

    Args:
      wav_file_path: path to the WAV file to analyze.
      sample_rate: sample rate of the file, used to express the VAD frame size
          in samples.
    """
    self._common_audio_vad = None
    self._vad_frame_size = None

    # Create temporary output path.
    tmp_path = tempfile.mkdtemp()
    output_file_path = os.path.join(
        tmp_path, os.path.basename(wav_file_path) + '_vad.tmp')

    # Call WebRTC VAD.
    try:
      subprocess.call([
          self._VAD_WEBRTC_COMMON_AUDIO_PATH,
          '-i', wav_file_path,
          '-o', output_file_path
      ], cwd=self._VAD_WEBRTC_PATH)

      # Read bytes.
      with open(output_file_path, 'rb') as f:
        raw_data = f.read()

      # Parse side information. Length-1 slices (rather than indexing) give
      # struct a bytes object under both Python 2 and Python 3.
      self._vad_frame_size_ms = struct.unpack('B', raw_data[0:1])[0]
      assert self._vad_frame_size_ms in [10, 20, 30]
      # Floor division keeps the frame size an integer number of samples, as
      # for the level frame size.
      self._vad_frame_size = self._vad_frame_size_ms * sample_rate // 1000
      extra_bits = struct.unpack('B', raw_data[-1:])[0]
      assert 0 <= extra_bits <= 8

      # Init VAD vector.
      num_bytes = len(raw_data)
      num_frames = 8 * (num_bytes - 2) - extra_bits  # 8 frames for each byte.
      self._common_audio_vad = np.zeros(num_frames, np.uint8)

      # Read VAD decisions; iterating a bytearray yields ints on Python 2 and
      # Python 3 alike.
      payload = bytearray(raw_data[1:-1])
      last_byte_index = len(payload) - 1
      for i, byte in enumerate(payload):
        num_bits = 8 if i < last_byte_index else 8 - extra_bits
        for j in range(num_bits):
          self._common_audio_vad[i * 8 + j] = byte & 1
          byte >>= 1
    except Exception as e:  # pylint: disable=broad-except
      # Exceptions have no .message attribute on Python 3; format e instead.
      logging.error('Error while running the WebRTC VAD (%s)', e)
    finally:
      if os.path.exists(tmp_path):
        shutil.rmtree(tmp_path)

  def _RunWebRtcApmVad(self, wav_file_path):
    """Runs the WebRTC APM VAD binary and reads its outputs.

    Sets self._apm_vad_probs and self._apm_vad_rms to the float64 arrays
    written by the binary. Errors are logged, not raised.

    Args:
      wav_file_path: path to the WAV file to analyze.
    """
    # Reset the outputs so that a failed run does not leave stale results.
    self._apm_vad_probs = None
    self._apm_vad_rms = None

    # Create temporary output path.
    tmp_path = tempfile.mkdtemp()
    base_name = os.path.basename(wav_file_path)
    output_file_path_probs = os.path.join(
        tmp_path, base_name + '_vad_probs.tmp')
    output_file_path_rms = os.path.join(
        tmp_path, base_name + '_vad_rms.tmp')

    # Call WebRTC VAD.
    try:
      subprocess.call([
          self._VAD_WEBRTC_APM_PATH,
          '-i', wav_file_path,
          '-o_probs', output_file_path_probs,
          '-o_rms', output_file_path_rms
      ], cwd=self._VAD_WEBRTC_PATH)

      # Parse annotations.
      self._apm_vad_probs = np.fromfile(output_file_path_probs, np.double)
      self._apm_vad_rms = np.fromfile(output_file_path_rms, np.double)
      assert len(self._apm_vad_rms) == len(self._apm_vad_probs)

    except Exception as e:  # pylint: disable=broad-except
      # Exceptions have no .message attribute on Python 3; format e instead.
      logging.error('Error while running the WebRTC APM VAD (%s)', e)
    finally:
      if os.path.exists(tmp_path):
        shutil.rmtree(tmp_path)