1 /* -*- c-basic-offset: 4; indent-tabs-mode: nil -*- */ 2 /* ==================================================================== 3 * Copyright (c) 2010 Carnegie Mellon University. All rights 4 * reserved. 5 * 6 * Redistribution and use in source and binary forms, with or without 7 * modification, are permitted provided that the following conditions 8 * are met: 9 * 10 * 1. Redistributions of source code must retain the above copyright 11 * notice, this list of conditions and the following disclaimer. 12 * 13 * 2. Redistributions in binary form must reproduce the above copyright 14 * notice, this list of conditions and the following disclaimer in 15 * the documentation and/or other materials provided with the 16 * distribution. 17 * 18 * This work was supported in part by funding from the Defense Advanced 19 * Research Projects Agency and the National Science Foundation of the 20 * United States of America, and the CMU Sphinx Speech Consortium. 21 * 22 * THIS SOFTWARE IS PROVIDED BY CARNEGIE MELLON UNIVERSITY ``AS IS'' AND 23 * ANY EXPRESSED OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, 24 * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 25 * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL CARNEGIE MELLON UNIVERSITY 26 * NOR ITS EMPLOYEES BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 27 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 28 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 29 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 30 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 31 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 32 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 33 * 34 * ==================================================================== 35 * 36 */ 37 38 /** 39 * @file state_align_search.c State (and phone and word) alignment search. 40 */ 41 42 #include "state_align_search.h" 43 44 static int 45 state_align_search_start(ps_search_t *search) 46 { 47 state_align_search_t *sas = (state_align_search_t *)search; 48 49 /* Activate the initial state. */ 50 hmm_enter(sas->hmms, 0, 0, 0); 51 52 return 0; 53 } 54 55 static void 56 renormalize_hmms(state_align_search_t *sas, int frame_idx, int32 norm) 57 { 58 int i; 59 for (i = 0; i < sas->n_phones; ++i) 60 hmm_normalize(sas->hmms + i, norm); 61 } 62 63 static int32 64 evaluate_hmms(state_align_search_t *sas, int16 const *senscr, int frame_idx) 65 { 66 int32 bs = WORST_SCORE; 67 int i; 68 69 hmm_context_set_senscore(sas->hmmctx, senscr); 70 71 for (i = 0; i < sas->n_phones; ++i) { 72 hmm_t *hmm = sas->hmms + i; 73 int32 score; 74 75 if (hmm_frame(hmm) < frame_idx) 76 continue; 77 score = hmm_vit_eval(hmm); 78 if (score BETTER_THAN bs) { 79 bs = score; 80 } 81 } 82 return bs; 83 } 84 85 static void 86 prune_hmms(state_align_search_t *sas, int frame_idx) 87 { 88 int nf = frame_idx + 1; 89 int i; 90 91 /* Check all phones to see if they remain active in the next frame. */ 92 for (i = 0; i < sas->n_phones; ++i) { 93 hmm_t *hmm = sas->hmms + i; 94 if (hmm_frame(hmm) < frame_idx) 95 continue; 96 hmm_frame(hmm) = nf; 97 } 98 } 99 100 static void 101 phone_transition(state_align_search_t *sas, int frame_idx) 102 { 103 int nf = frame_idx + 1; 104 int i; 105 106 for (i = 0; i < sas->n_phones - 1; ++i) { 107 hmm_t *hmm, *nhmm; 108 int32 newphone_score; 109 110 hmm = sas->hmms + i; 111 if (hmm_frame(hmm) != nf) 112 continue; 113 114 newphone_score = hmm_out_score(hmm); 115 /* Transition into next phone using the usual Viterbi rule. */ 116 nhmm = hmm + 1; 117 if (hmm_frame(nhmm) < frame_idx 118 || newphone_score BETTER_THAN hmm_in_score(nhmm)) { 119 hmm_enter(nhmm, newphone_score, hmm_out_history(hmm), nf); 120 } 121 } 122 } 123 124 #define TOKEN_STEP 20 125 static void 126 extend_tokenstack(state_align_search_t *sas, int frame_idx) 127 { 128 if (frame_idx >= sas->n_fr_alloc) { 129 sas->n_fr_alloc = frame_idx + TOKEN_STEP + 1; 130 sas->tokens = ckd_realloc(sas->tokens, 131 sas->n_emit_state * sas->n_fr_alloc 132 * sizeof(*sas->tokens)); 133 } 134 memset(sas->tokens + frame_idx * sas->n_emit_state, 0xff, 135 sas->n_emit_state * sizeof(*sas->tokens)); 136 } 137 138 static void 139 record_transitions(state_align_search_t *sas, int frame_idx) 140 { 141 uint16 *tokens; 142 int i; 143 144 /* Push another frame of tokens on the stack. */ 145 extend_tokenstack(sas, frame_idx); 146 tokens = sas->tokens + frame_idx * sas->n_emit_state; 147 148 /* Scan all active HMMs */ 149 for (i = 0; i < sas->n_phones; ++i) { 150 hmm_t *hmm = sas->hmms + i; 151 int j; 152 153 if (hmm_frame(hmm) < frame_idx) 154 continue; 155 for (j = 0; j < sas->hmmctx->n_emit_state; ++j) { 156 int state_idx = i * sas->hmmctx->n_emit_state + j; 157 /* Record their backpointers on the token stack. */ 158 tokens[state_idx] = hmm_history(hmm, j); 159 /* Update backpointer fields with state index. */ 160 hmm_history(hmm, j) = state_idx; 161 } 162 } 163 } 164 165 static int 166 state_align_search_step(ps_search_t *search, int frame_idx) 167 { 168 state_align_search_t *sas = (state_align_search_t *)search; 169 acmod_t *acmod = ps_search_acmod(search); 170 int16 const *senscr; 171 int i; 172 173 /* Calculate senone scores. */ 174 for (i = 0; i < sas->n_phones; ++i) 175 acmod_activate_hmm(acmod, sas->hmms + i); 176 senscr = acmod_score(acmod, &frame_idx); 177 178 /* Renormalize here if needed. */ 179 /* FIXME: Make sure to (unit-)test this!!! */ 180 if ((sas->best_score - 0x300000) WORSE_THAN WORST_SCORE) { 181 E_INFO("Renormalizing Scores at frame %d, best score %d\n", 182 frame_idx, sas->best_score); 183 renormalize_hmms(sas, frame_idx, sas->best_score); 184 } 185 186 /* Viterbi step. */ 187 sas->best_score = evaluate_hmms(sas, senscr, frame_idx); 188 prune_hmms(sas, frame_idx); 189 190 /* Transition out of non-emitting states. */ 191 phone_transition(sas, frame_idx); 192 193 /* Generate new tokens from best path results. */ 194 record_transitions(sas, frame_idx); 195 196 /* Update frame counter */ 197 sas->frame = frame_idx; 198 199 return 0; 200 } 201 202 static int 203 state_align_search_finish(ps_search_t *search) 204 { 205 state_align_search_t *sas = (state_align_search_t *)search; 206 hmm_t *final_phone = sas->hmms + sas->n_phones - 1; 207 ps_alignment_iter_t *itor; 208 ps_alignment_entry_t *ent; 209 int next_state, next_start, state, frame; 210 211 /* Best state exiting the last frame. */ 212 next_state = state = hmm_out_history(final_phone); 213 if (state == 0xffff) { 214 E_ERROR("Failed to reach final state in alignment\n"); 215 return -1; 216 } 217 itor = ps_alignment_states(sas->al); 218 next_start = sas->frame + 1; 219 for (frame = sas->frame - 1; frame >= 0; --frame) { 220 state = sas->tokens[frame * sas->n_emit_state + state]; 221 /* State boundary, update alignment entry for next state. */ 222 if (state != next_state) { 223 itor = ps_alignment_iter_goto(itor, next_state); 224 assert(itor != NULL); 225 ent = ps_alignment_iter_get(itor); 226 ent->start = frame + 1; 227 ent->duration = next_start - ent->start; 228 E_DEBUG(1,("state %d start %d end %d\n", next_state, 229 ent->start, next_start)); 230 next_state = state; 231 next_start = frame + 1; 232 } 233 } 234 /* Update alignment entry for initial state. */ 235 itor = ps_alignment_iter_goto(itor, 0); 236 assert(itor != NULL); 237 ent = ps_alignment_iter_get(itor); 238 ent->start = 0; 239 ent->duration = next_start; 240 E_DEBUG(1,("state %d start %d end %d\n", 0, 241 ent->start, next_start)); 242 ps_alignment_iter_free(itor); 243 ps_alignment_propagate(sas->al); 244 245 return 0; 246 } 247 248 static int 249 state_align_search_reinit(ps_search_t *search, dict_t *dict, dict2pid_t *d2p) 250 { 251 /* This does nothing. */ 252 return 0; 253 } 254 255 static void 256 state_align_search_free(ps_search_t *search) 257 { 258 state_align_search_t *sas = (state_align_search_t *)search; 259 ps_search_deinit(search); 260 ckd_free(sas->hmms); 261 ckd_free(sas->tokens); 262 hmm_context_free(sas->hmmctx); 263 ckd_free(sas); 264 } 265 266 static ps_searchfuncs_t state_align_search_funcs = { 267 /* name: */ "state_align", 268 /* start: */ state_align_search_start, 269 /* step: */ state_align_search_step, 270 /* finish: */ state_align_search_finish, 271 /* reinit: */ state_align_search_reinit, 272 /* free: */ state_align_search_free, 273 /* lattice: */ NULL, 274 /* hyp: */ NULL, 275 /* prob: */ NULL, 276 /* seg_iter: */ NULL, 277 }; 278 279 ps_search_t * 280 state_align_search_init(cmd_ln_t *config, 281 acmod_t *acmod, 282 ps_alignment_t *al) 283 { 284 state_align_search_t *sas; 285 ps_alignment_iter_t *itor; 286 hmm_t *hmm; 287 288 sas = ckd_calloc(1, sizeof(*sas)); 289 ps_search_init(ps_search_base(sas), &state_align_search_funcs, 290 config, acmod, al->d2p->dict, al->d2p); 291 sas->hmmctx = hmm_context_init(bin_mdef_n_emit_state(acmod->mdef), 292 acmod->tmat->tp, NULL, acmod->mdef->sseq); 293 if (sas->hmmctx == NULL) { 294 ckd_free(sas); 295 return NULL; 296 } 297 sas->al = al; 298 299 /* Generate HMM vector from phone level of alignment. */ 300 sas->n_phones = ps_alignment_n_phones(al); 301 sas->n_emit_state = ps_alignment_n_states(al); 302 sas->hmms = ckd_calloc(sas->n_phones, sizeof(*sas->hmms)); 303 for (hmm = sas->hmms, itor = ps_alignment_phones(al); itor; 304 ++hmm, itor = ps_alignment_iter_next(itor)) { 305 ps_alignment_entry_t *ent = ps_alignment_iter_get(itor); 306 hmm_init(sas->hmmctx, hmm, FALSE, 307 ent->id.pid.ssid, ent->id.pid.tmatid); 308 } 309 return ps_search_base(sas); 310 } 311