1 /* -*- c-basic-offset: 4; indent-tabs-mode: nil -*- */
2 /* ====================================================================
3  * Copyright (c) 2010 Carnegie Mellon University.  All rights
4  * reserved.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions
8  * are met:
9  *
10  * 1. Redistributions of source code must retain the above copyright
11  *    notice, this list of conditions and the following disclaimer.
12  *
13  * 2. Redistributions in binary form must reproduce the above copyright
14  *    notice, this list of conditions and the following disclaimer in
15  *    the documentation and/or other materials provided with the
16  *    distribution.
17  *
18  * This work was supported in part by funding from the Defense Advanced
19  * Research Projects Agency and the National Science Foundation of the
20  * United States of America, and the CMU Sphinx Speech Consortium.
21  *
22  * THIS SOFTWARE IS PROVIDED BY CARNEGIE MELLON UNIVERSITY ``AS IS'' AND
23  * ANY EXPRESSED OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
24  * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
25  * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL CARNEGIE MELLON UNIVERSITY
26  * NOR ITS EMPLOYEES BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
27  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
28  * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
29  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
30  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
31  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
32  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33  *
34  * ====================================================================
35  *
36  */
37 
38 /**
39  * @file state_align_search.c State (and phone and word) alignment search.
40  */
41 
42 #include "state_align_search.h"
43 
44 static int
45 state_align_search_start(ps_search_t *search)
46 {
47     state_align_search_t *sas = (state_align_search_t *)search;
48 
49     /* Activate the initial state. */
50     hmm_enter(sas->hmms, 0, 0, 0);
51 
52     return 0;
53 }
54 
55 static void
56 renormalize_hmms(state_align_search_t *sas, int frame_idx, int32 norm)
57 {
58     int i;
59     for (i = 0; i < sas->n_phones; ++i)
60         hmm_normalize(sas->hmms + i, norm);
61 }
62 
63 static int32
64 evaluate_hmms(state_align_search_t *sas, int16 const *senscr, int frame_idx)
65 {
66     int32 bs = WORST_SCORE;
67     int i;
68 
69     hmm_context_set_senscore(sas->hmmctx, senscr);
70 
71     for (i = 0; i < sas->n_phones; ++i) {
72         hmm_t *hmm = sas->hmms + i;
73         int32 score;
74 
75         if (hmm_frame(hmm) < frame_idx)
76             continue;
77         score = hmm_vit_eval(hmm);
78         if (score BETTER_THAN bs) {
79             bs = score;
80         }
81     }
82     return bs;
83 }
84 
85 static void
86 prune_hmms(state_align_search_t *sas, int frame_idx)
87 {
88     int nf = frame_idx + 1;
89     int i;
90 
91     /* Check all phones to see if they remain active in the next frame. */
92     for (i = 0; i < sas->n_phones; ++i) {
93         hmm_t *hmm = sas->hmms + i;
94         if (hmm_frame(hmm) < frame_idx)
95             continue;
96         hmm_frame(hmm) = nf;
97     }
98 }
99 
100 static void
101 phone_transition(state_align_search_t *sas, int frame_idx)
102 {
103     int nf = frame_idx + 1;
104     int i;
105 
106     for (i = 0; i < sas->n_phones - 1; ++i) {
107         hmm_t *hmm, *nhmm;
108         int32 newphone_score;
109 
110         hmm = sas->hmms + i;
111         if (hmm_frame(hmm) != nf)
112             continue;
113 
114         newphone_score = hmm_out_score(hmm);
115         /* Transition into next phone using the usual Viterbi rule. */
116         nhmm = hmm + 1;
117         if (hmm_frame(nhmm) < frame_idx
118             || newphone_score BETTER_THAN hmm_in_score(nhmm)) {
119             hmm_enter(nhmm, newphone_score, hmm_out_history(hmm), nf);
120         }
121     }
122 }
123 
124 #define TOKEN_STEP 20
125 static void
126 extend_tokenstack(state_align_search_t *sas, int frame_idx)
127 {
128     if (frame_idx >= sas->n_fr_alloc) {
129         sas->n_fr_alloc = frame_idx + TOKEN_STEP + 1;
130         sas->tokens = ckd_realloc(sas->tokens,
131                                   sas->n_emit_state * sas->n_fr_alloc
132                                   * sizeof(*sas->tokens));
133     }
134     memset(sas->tokens + frame_idx * sas->n_emit_state, 0xff,
135            sas->n_emit_state * sizeof(*sas->tokens));
136 }
137 
138 static void
139 record_transitions(state_align_search_t *sas, int frame_idx)
140 {
141     uint16 *tokens;
142     int i;
143 
144     /* Push another frame of tokens on the stack. */
145     extend_tokenstack(sas, frame_idx);
146     tokens = sas->tokens + frame_idx * sas->n_emit_state;
147 
148     /* Scan all active HMMs */
149     for (i = 0; i < sas->n_phones; ++i) {
150         hmm_t *hmm = sas->hmms + i;
151         int j;
152 
153         if (hmm_frame(hmm) < frame_idx)
154             continue;
155         for (j = 0; j < sas->hmmctx->n_emit_state; ++j) {
156             int state_idx = i * sas->hmmctx->n_emit_state + j;
157             /* Record their backpointers on the token stack. */
158             tokens[state_idx] = hmm_history(hmm, j);
159             /* Update backpointer fields with state index. */
160             hmm_history(hmm, j) = state_idx;
161         }
162     }
163 }
164 
165 static int
166 state_align_search_step(ps_search_t *search, int frame_idx)
167 {
168     state_align_search_t *sas = (state_align_search_t *)search;
169     acmod_t *acmod = ps_search_acmod(search);
170     int16 const *senscr;
171     int i;
172 
173     /* Calculate senone scores. */
174     for (i = 0; i < sas->n_phones; ++i)
175         acmod_activate_hmm(acmod, sas->hmms + i);
176     senscr = acmod_score(acmod, &frame_idx);
177 
178     /* Renormalize here if needed. */
179     /* FIXME: Make sure to (unit-)test this!!! */
180     if ((sas->best_score - 0x300000) WORSE_THAN WORST_SCORE) {
181         E_INFO("Renormalizing Scores at frame %d, best score %d\n",
182                frame_idx, sas->best_score);
183         renormalize_hmms(sas, frame_idx, sas->best_score);
184     }
185 
186     /* Viterbi step. */
187     sas->best_score = evaluate_hmms(sas, senscr, frame_idx);
188     prune_hmms(sas, frame_idx);
189 
190     /* Transition out of non-emitting states. */
191     phone_transition(sas, frame_idx);
192 
193     /* Generate new tokens from best path results. */
194     record_transitions(sas, frame_idx);
195 
196     /* Update frame counter */
197     sas->frame = frame_idx;
198 
199     return 0;
200 }
201 
202 static int
203 state_align_search_finish(ps_search_t *search)
204 {
205     state_align_search_t *sas = (state_align_search_t *)search;
206     hmm_t *final_phone = sas->hmms + sas->n_phones - 1;
207     ps_alignment_iter_t *itor;
208     ps_alignment_entry_t *ent;
209     int next_state, next_start, state, frame;
210 
211     /* Best state exiting the last frame. */
212     next_state = state = hmm_out_history(final_phone);
213     if (state == 0xffff) {
214         E_ERROR("Failed to reach final state in alignment\n");
215         return -1;
216     }
217     itor = ps_alignment_states(sas->al);
218     next_start = sas->frame + 1;
219     for (frame = sas->frame - 1; frame >= 0; --frame) {
220         state = sas->tokens[frame * sas->n_emit_state + state];
221         /* State boundary, update alignment entry for next state. */
222         if (state != next_state) {
223             itor = ps_alignment_iter_goto(itor, next_state);
224             assert(itor != NULL);
225             ent = ps_alignment_iter_get(itor);
226             ent->start = frame + 1;
227             ent->duration = next_start - ent->start;
228             E_DEBUG(1,("state %d start %d end %d\n", next_state,
229                        ent->start, next_start));
230             next_state = state;
231             next_start = frame + 1;
232         }
233     }
234     /* Update alignment entry for initial state. */
235     itor = ps_alignment_iter_goto(itor, 0);
236     assert(itor != NULL);
237     ent = ps_alignment_iter_get(itor);
238     ent->start = 0;
239     ent->duration = next_start;
240     E_DEBUG(1,("state %d start %d end %d\n", 0,
241                ent->start, next_start));
242     ps_alignment_iter_free(itor);
243     ps_alignment_propagate(sas->al);
244 
245     return 0;
246 }
247 
248 static int
249 state_align_search_reinit(ps_search_t *search, dict_t *dict, dict2pid_t *d2p)
250 {
251     /* This does nothing. */
252     return 0;
253 }
254 
255 static void
256 state_align_search_free(ps_search_t *search)
257 {
258     state_align_search_t *sas = (state_align_search_t *)search;
259     ps_search_deinit(search);
260     ckd_free(sas->hmms);
261     ckd_free(sas->tokens);
262     hmm_context_free(sas->hmmctx);
263     ckd_free(sas);
264 }
265 
266 static ps_searchfuncs_t state_align_search_funcs = {
267     /* name: */   "state_align",
268     /* start: */  state_align_search_start,
269     /* step: */   state_align_search_step,
270     /* finish: */ state_align_search_finish,
271     /* reinit: */ state_align_search_reinit,
272     /* free: */   state_align_search_free,
273     /* lattice: */  NULL,
274     /* hyp: */      NULL,
275     /* prob: */     NULL,
276     /* seg_iter: */ NULL,
277 };
278 
279 ps_search_t *
280 state_align_search_init(cmd_ln_t *config,
281                         acmod_t *acmod,
282                         ps_alignment_t *al)
283 {
284     state_align_search_t *sas;
285     ps_alignment_iter_t *itor;
286     hmm_t *hmm;
287 
288     sas = ckd_calloc(1, sizeof(*sas));
289     ps_search_init(ps_search_base(sas), &state_align_search_funcs,
290                    config, acmod, al->d2p->dict, al->d2p);
291     sas->hmmctx = hmm_context_init(bin_mdef_n_emit_state(acmod->mdef),
292                                    acmod->tmat->tp, NULL, acmod->mdef->sseq);
293     if (sas->hmmctx == NULL) {
294         ckd_free(sas);
295         return NULL;
296     }
297     sas->al = al;
298 
299     /* Generate HMM vector from phone level of alignment. */
300     sas->n_phones = ps_alignment_n_phones(al);
301     sas->n_emit_state = ps_alignment_n_states(al);
302     sas->hmms = ckd_calloc(sas->n_phones, sizeof(*sas->hmms));
303     for (hmm = sas->hmms, itor = ps_alignment_phones(al); itor;
304          ++hmm, itor = ps_alignment_iter_next(itor)) {
305         ps_alignment_entry_t *ent = ps_alignment_iter_get(itor);
306         hmm_init(sas->hmmctx, hmm, FALSE,
307                  ent->id.pid.ssid, ent->id.pid.tmatid);
308     }
309     return ps_search_base(sas);
310 }
311