PocketSphinx  0.6
src/libpocketsphinx/state_align_search.c
Go to the documentation of this file.
00001 /* -*- c-basic-offset: 4; indent-tabs-mode: nil -*- */
00002 /* ====================================================================
00003  * Copyright (c) 2010 Carnegie Mellon University.  All rights
00004  * reserved.
00005  *
00006  * Redistribution and use in source and binary forms, with or without
00007  * modification, are permitted provided that the following conditions
00008  * are met:
00009  *
00010  * 1. Redistributions of source code must retain the above copyright
00011  *    notice, this list of conditions and the following disclaimer. 
00012  *
00013  * 2. Redistributions in binary form must reproduce the above copyright
00014  *    notice, this list of conditions and the following disclaimer in
00015  *    the documentation and/or other materials provided with the
00016  *    distribution.
00017  *
00018  * This work was supported in part by funding from the Defense Advanced 
00019  * Research Projects Agency and the National Science Foundation of the 
00020  * United States of America, and the CMU Sphinx Speech Consortium.
00021  *
00022  * THIS SOFTWARE IS PROVIDED BY CARNEGIE MELLON UNIVERSITY ``AS IS'' AND 
00023  * ANY EXPRESSED OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, 
00024  * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00025  * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL CARNEGIE MELLON UNIVERSITY
00026  * NOR ITS EMPLOYEES BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
00027  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 
00028  * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 
00029  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 
00030  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 
00031  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 
00032  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00033  *
00034  * ====================================================================
00035  *
00036  */
00037 
00042 #include "state_align_search.h"
00043 
00044 static int
00045 state_align_search_start(ps_search_t *search)
00046 {
00047     state_align_search_t *sas = (state_align_search_t *)search;
00048 
00049     /* Activate the initial state. */
00050     hmm_enter(sas->hmms, 0, 0, 0);
00051 
00052     return 0;
00053 }
00054 
00055 static void
00056 renormalize_hmms(state_align_search_t *sas, int frame_idx, int32 norm)
00057 {
00058     int i;
00059     for (i = 0; i < sas->n_phones; ++i)
00060         hmm_normalize(sas->hmms + i, norm);
00061 }
00062 
00063 static int32
00064 evaluate_hmms(state_align_search_t *sas, int16 const *senscr, int frame_idx)
00065 {
00066     int32 bs = WORST_SCORE;
00067     int i, bi;
00068 
00069     hmm_context_set_senscore(sas->hmmctx, senscr);
00070 
00071     bi = 0;
00072     for (i = 0; i < sas->n_phones; ++i) {
00073         hmm_t *hmm = sas->hmms + i;
00074         int32 score;
00075 
00076         if (hmm_frame(hmm) < frame_idx)
00077             continue;
00078         score = hmm_vit_eval(hmm);
00079         if (score BETTER_THAN bs) {
00080             bs = score;
00081             bi = i;
00082         }
00083     }
00084     return bs;
00085 }
00086 
00087 static void
00088 prune_hmms(state_align_search_t *sas, int frame_idx)
00089 {
00090     int nf = frame_idx + 1;
00091     int i;
00092 
00093     /* Check all phones to see if they remain active in the next frame. */
00094     for (i = 0; i < sas->n_phones; ++i) {
00095         hmm_t *hmm = sas->hmms + i;
00096         if (hmm_frame(hmm) < frame_idx)
00097             continue;
00098         hmm_frame(hmm) = nf;
00099     }
00100 }
00101 
00102 static void
00103 phone_transition(state_align_search_t *sas, int frame_idx)
00104 {
00105     int nf = frame_idx + 1;
00106     int i;
00107 
00108     for (i = 0; i < sas->n_phones - 1; ++i) {
00109         hmm_t *hmm, *nhmm;
00110         int32 newphone_score;
00111 
00112         hmm = sas->hmms + i;
00113         if (hmm_frame(hmm) != nf)
00114             continue;
00115 
00116         newphone_score = hmm_out_score(hmm);
00117         /* Transition into next phone using the usual Viterbi rule. */
00118         nhmm = hmm + 1;
00119         if (hmm_frame(nhmm) < frame_idx
00120             || newphone_score BETTER_THAN hmm_in_score(nhmm)) {
00121             hmm_enter(nhmm, newphone_score, hmm_out_history(hmm), nf);
00122         }
00123     }
00124 }
00125 
00126 #define TOKEN_STEP 20
00127 static void
00128 extend_tokenstack(state_align_search_t *sas, int frame_idx)
00129 {
00130     if (frame_idx >= sas->n_fr_alloc) {
00131         sas->n_fr_alloc = frame_idx + TOKEN_STEP + 1;
00132         sas->tokens = ckd_realloc(sas->tokens,
00133                                   sas->n_emit_state * sas->n_fr_alloc
00134                                   * sizeof(*sas->tokens));
00135     }
00136     memset(sas->tokens + frame_idx * sas->n_emit_state, 0xff,
00137            sas->n_emit_state * sizeof(*sas->tokens));
00138 }
00139 
00140 static void
00141 record_transitions(state_align_search_t *sas, int frame_idx)
00142 {
00143     uint16 *tokens;
00144     int i;
00145 
00146     /* Push another frame of tokens on the stack. */
00147     extend_tokenstack(sas, frame_idx);
00148     tokens = sas->tokens + frame_idx * sas->n_emit_state;
00149 
00150     /* Scan all active HMMs */
00151     for (i = 0; i < sas->n_phones; ++i) {
00152         hmm_t *hmm = sas->hmms + i;
00153         int j;
00154 
00155         if (hmm_frame(hmm) < frame_idx)
00156             continue;
00157         for (j = 0; j < sas->hmmctx->n_emit_state; ++j) {
00158             int state_idx = i * sas->hmmctx->n_emit_state + j;
00159             /* Record their backpointers on the token stack. */
00160             tokens[state_idx] = hmm_history(hmm, j);
00161             /* Update backpointer fields with state index. */
00162             hmm_history(hmm, j) = state_idx;
00163         }
00164     }
00165 }
00166 
00167 static int
00168 state_align_search_step(ps_search_t *search, int frame_idx)
00169 {
00170     state_align_search_t *sas = (state_align_search_t *)search;
00171     acmod_t *acmod = ps_search_acmod(search);
00172     int16 const *senscr;
00173     int i;
00174 
00175     /* Calculate senone scores. */
00176     for (i = 0; i < sas->n_phones; ++i)
00177         acmod_activate_hmm(acmod, sas->hmms + i);
00178     senscr = acmod_score(acmod, &frame_idx);
00179 
00180     /* Renormalize here if needed. */
00181     /* FIXME: Make sure to (unit-)test this!!! */
00182     if ((sas->best_score - 0x300000) WORSE_THAN WORST_SCORE) {
00183         E_INFO("Renormalizing Scores at frame %d, best score %d\n",
00184                frame_idx, sas->best_score);
00185         renormalize_hmms(sas, frame_idx, sas->best_score);
00186     }
00187     
00188     /* Viterbi step. */
00189     sas->best_score = evaluate_hmms(sas, senscr, frame_idx);
00190     prune_hmms(sas, frame_idx);
00191 
00192     /* Transition out of non-emitting states. */
00193     phone_transition(sas, frame_idx);
00194 
00195     /* Generate new tokens from best path results. */
00196     record_transitions(sas, frame_idx);
00197 
00198     /* Update frame counter */
00199     sas->frame = frame_idx;
00200 
00201     return 0;
00202 }
00203 
00204 static int
00205 state_align_search_finish(ps_search_t *search)
00206 {
00207     state_align_search_t *sas = (state_align_search_t *)search;
00208     hmm_t *final_phone = sas->hmms + sas->n_phones - 1;
00209     ps_alignment_iter_t *itor;
00210     ps_alignment_entry_t *ent;
00211     int next_state, next_start, state, frame;
00212 
00213     /* Best state exiting the last frame. */
00214     next_state = state = hmm_out_history(final_phone);
00215     if (state == 0xffff) {
00216         E_ERROR("Failed to reach final state in alignment\n");
00217         return -1;
00218     }
00219     itor = ps_alignment_states(sas->al);
00220     next_start = sas->frame + 1;
00221     for (frame = sas->frame - 1; frame >= 0; --frame) {
00222         state = sas->tokens[frame * sas->n_emit_state + state];
00223         /* State boundary, update alignment entry for next state. */
00224         if (state != next_state) {
00225             itor = ps_alignment_iter_goto(itor, next_state);
00226             assert(itor != NULL);
00227             ent = ps_alignment_iter_get(itor);
00228             ent->start = frame + 1;
00229             ent->duration = next_start - ent->start;
00230             E_DEBUG(1,("state %d start %d end %d\n", next_state,
00231                        ent->start, next_start));
00232             next_state = state;
00233             next_start = frame + 1;
00234         }
00235     }
00236     /* Update alignment entry for initial state. */
00237     itor = ps_alignment_iter_goto(itor, 0);
00238     assert(itor != NULL);
00239     ent = ps_alignment_iter_get(itor);
00240     ent->start = 0;
00241     ent->duration = next_start;
00242     E_DEBUG(1,("state %d start %d end %d\n", 0,
00243                ent->start, next_start));
00244     ps_alignment_iter_free(itor);
00245     ps_alignment_propagate(sas->al);
00246 
00247     return 0;
00248 }
00249 
00250 static int
00251 state_align_search_reinit(ps_search_t *search, dict_t *dict, dict2pid_t *d2p)
00252 {
00253     /* This does nothing. */
00254     return 0;
00255 }
00256 
00257 static void
00258 state_align_search_free(ps_search_t *search)
00259 {
00260     state_align_search_t *sas = (state_align_search_t *)search;
00261     ps_search_deinit(search);
00262     ckd_free(sas->hmms);
00263     ckd_free(sas->tokens);
00264     hmm_context_free(sas->hmmctx);
00265     ckd_free(sas);
00266 }
00267 
00268 static ps_searchfuncs_t state_align_search_funcs = {
00269     /* name: */   "state_align",
00270     /* start: */  state_align_search_start,
00271     /* step: */   state_align_search_step,
00272     /* finish: */ state_align_search_finish,
00273     /* reinit: */ state_align_search_reinit,
00274     /* free: */   state_align_search_free,
00275     /* lattice: */  NULL,
00276     /* hyp: */      NULL,
00277     /* prob: */     NULL,
00278     /* seg_iter: */ NULL,
00279 };
00280 
00281 ps_search_t *
00282 state_align_search_init(cmd_ln_t *config,
00283                         acmod_t *acmod,
00284                         ps_alignment_t *al)
00285 {
00286     state_align_search_t *sas;
00287     ps_alignment_iter_t *itor;
00288     hmm_t *hmm;
00289 
00290     sas = ckd_calloc(1, sizeof(*sas));
00291     ps_search_init(ps_search_base(sas), &state_align_search_funcs,
00292                    config, acmod, al->d2p->dict, al->d2p);
00293     sas->hmmctx = hmm_context_init(bin_mdef_n_emit_state(acmod->mdef),
00294                                    acmod->tmat->tp, NULL, acmod->mdef->sseq);
00295     if (sas->hmmctx == NULL) {
00296         ckd_free(sas);
00297         return NULL;
00298     }
00299     sas->al = al;
00300 
00301     /* Generate HMM vector from phone level of alignment. */
00302     sas->n_phones = ps_alignment_n_phones(al);
00303     sas->n_emit_state = ps_alignment_n_states(al);
00304     sas->hmms = ckd_calloc(sas->n_phones, sizeof(*sas->hmms));
00305     for (hmm = sas->hmms, itor = ps_alignment_phones(al); itor;
00306          ++hmm, itor = ps_alignment_iter_next(itor)) {
00307         ps_alignment_entry_t *ent = ps_alignment_iter_get(itor);
00308         hmm_init(sas->hmmctx, hmm, FALSE,
00309                  ent->id.pid.ssid, ent->id.pid.tmatid);
00310     }
00311     return ps_search_base(sas);
00312 }