PocketSphinx
0.6
|
00001 /* -*- c-basic-offset: 4; indent-tabs-mode: nil -*- */ 00002 /* ==================================================================== 00003 * Copyright (c) 2010 Carnegie Mellon University. All rights 00004 * reserved. 00005 * 00006 * Redistribution and use in source and binary forms, with or without 00007 * modification, are permitted provided that the following conditions 00008 * are met: 00009 * 00010 * 1. Redistributions of source code must retain the above copyright 00011 * notice, this list of conditions and the following disclaimer. 00012 * 00013 * 2. Redistributions in binary form must reproduce the above copyright 00014 * notice, this list of conditions and the following disclaimer in 00015 * the documentation and/or other materials provided with the 00016 * distribution. 00017 * 00018 * This work was supported in part by funding from the Defense Advanced 00019 * Research Projects Agency and the National Science Foundation of the 00020 * United States of America, and the CMU Sphinx Speech Consortium. 00021 * 00022 * THIS SOFTWARE IS PROVIDED BY CARNEGIE MELLON UNIVERSITY ``AS IS'' AND 00023 * ANY EXPRESSED OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, 00024 * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 00025 * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL CARNEGIE MELLON UNIVERSITY 00026 * NOR ITS EMPLOYEES BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 00027 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 00028 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 00029 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 00030 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 00031 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 00032 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 00033 * 00034 * ==================================================================== 00035 * 00036 */ 00037 00042 #include "state_align_search.h" 00043 00044 static int 00045 state_align_search_start(ps_search_t *search) 00046 { 00047 state_align_search_t *sas = (state_align_search_t *)search; 00048 00049 /* Activate the initial state. */ 00050 hmm_enter(sas->hmms, 0, 0, 0); 00051 00052 return 0; 00053 } 00054 00055 static void 00056 renormalize_hmms(state_align_search_t *sas, int frame_idx, int32 norm) 00057 { 00058 int i; 00059 for (i = 0; i < sas->n_phones; ++i) 00060 hmm_normalize(sas->hmms + i, norm); 00061 } 00062 00063 static int32 00064 evaluate_hmms(state_align_search_t *sas, int16 const *senscr, int frame_idx) 00065 { 00066 int32 bs = WORST_SCORE; 00067 int i, bi; 00068 00069 hmm_context_set_senscore(sas->hmmctx, senscr); 00070 00071 bi = 0; 00072 for (i = 0; i < sas->n_phones; ++i) { 00073 hmm_t *hmm = sas->hmms + i; 00074 int32 score; 00075 00076 if (hmm_frame(hmm) < frame_idx) 00077 continue; 00078 score = hmm_vit_eval(hmm); 00079 if (score BETTER_THAN bs) { 00080 bs = score; 00081 bi = i; 00082 } 00083 } 00084 return bs; 00085 } 00086 00087 static void 00088 prune_hmms(state_align_search_t *sas, int frame_idx) 00089 { 00090 int nf = frame_idx + 1; 00091 int i; 00092 00093 /* Check all phones to see if they remain active in the next frame. */ 00094 for (i = 0; i < sas->n_phones; ++i) { 00095 hmm_t *hmm = sas->hmms + i; 00096 if (hmm_frame(hmm) < frame_idx) 00097 continue; 00098 hmm_frame(hmm) = nf; 00099 } 00100 } 00101 00102 static void 00103 phone_transition(state_align_search_t *sas, int frame_idx) 00104 { 00105 int nf = frame_idx + 1; 00106 int i; 00107 00108 for (i = 0; i < sas->n_phones - 1; ++i) { 00109 hmm_t *hmm, *nhmm; 00110 int32 newphone_score; 00111 00112 hmm = sas->hmms + i; 00113 if (hmm_frame(hmm) != nf) 00114 continue; 00115 00116 newphone_score = hmm_out_score(hmm); 00117 /* Transition into next phone using the usual Viterbi rule. */ 00118 nhmm = hmm + 1; 00119 if (hmm_frame(nhmm) < frame_idx 00120 || newphone_score BETTER_THAN hmm_in_score(nhmm)) { 00121 hmm_enter(nhmm, newphone_score, hmm_out_history(hmm), nf); 00122 } 00123 } 00124 } 00125 00126 #define TOKEN_STEP 20 00127 static void 00128 extend_tokenstack(state_align_search_t *sas, int frame_idx) 00129 { 00130 if (frame_idx >= sas->n_fr_alloc) { 00131 sas->n_fr_alloc = frame_idx + TOKEN_STEP + 1; 00132 sas->tokens = ckd_realloc(sas->tokens, 00133 sas->n_emit_state * sas->n_fr_alloc 00134 * sizeof(*sas->tokens)); 00135 } 00136 memset(sas->tokens + frame_idx * sas->n_emit_state, 0xff, 00137 sas->n_emit_state * sizeof(*sas->tokens)); 00138 } 00139 00140 static void 00141 record_transitions(state_align_search_t *sas, int frame_idx) 00142 { 00143 uint16 *tokens; 00144 int i; 00145 00146 /* Push another frame of tokens on the stack. */ 00147 extend_tokenstack(sas, frame_idx); 00148 tokens = sas->tokens + frame_idx * sas->n_emit_state; 00149 00150 /* Scan all active HMMs */ 00151 for (i = 0; i < sas->n_phones; ++i) { 00152 hmm_t *hmm = sas->hmms + i; 00153 int j; 00154 00155 if (hmm_frame(hmm) < frame_idx) 00156 continue; 00157 for (j = 0; j < sas->hmmctx->n_emit_state; ++j) { 00158 int state_idx = i * sas->hmmctx->n_emit_state + j; 00159 /* Record their backpointers on the token stack. */ 00160 tokens[state_idx] = hmm_history(hmm, j); 00161 /* Update backpointer fields with state index. */ 00162 hmm_history(hmm, j) = state_idx; 00163 } 00164 } 00165 } 00166 00167 static int 00168 state_align_search_step(ps_search_t *search, int frame_idx) 00169 { 00170 state_align_search_t *sas = (state_align_search_t *)search; 00171 acmod_t *acmod = ps_search_acmod(search); 00172 int16 const *senscr; 00173 int i; 00174 00175 /* Calculate senone scores. */ 00176 for (i = 0; i < sas->n_phones; ++i) 00177 acmod_activate_hmm(acmod, sas->hmms + i); 00178 senscr = acmod_score(acmod, &frame_idx); 00179 00180 /* Renormalize here if needed. */ 00181 /* FIXME: Make sure to (unit-)test this!!! */ 00182 if ((sas->best_score - 0x300000) WORSE_THAN WORST_SCORE) { 00183 E_INFO("Renormalizing Scores at frame %d, best score %d\n", 00184 frame_idx, sas->best_score); 00185 renormalize_hmms(sas, frame_idx, sas->best_score); 00186 } 00187 00188 /* Viterbi step. */ 00189 sas->best_score = evaluate_hmms(sas, senscr, frame_idx); 00190 prune_hmms(sas, frame_idx); 00191 00192 /* Transition out of non-emitting states. */ 00193 phone_transition(sas, frame_idx); 00194 00195 /* Generate new tokens from best path results. */ 00196 record_transitions(sas, frame_idx); 00197 00198 /* Update frame counter */ 00199 sas->frame = frame_idx; 00200 00201 return 0; 00202 } 00203 00204 static int 00205 state_align_search_finish(ps_search_t *search) 00206 { 00207 state_align_search_t *sas = (state_align_search_t *)search; 00208 hmm_t *final_phone = sas->hmms + sas->n_phones - 1; 00209 ps_alignment_iter_t *itor; 00210 ps_alignment_entry_t *ent; 00211 int next_state, next_start, state, frame; 00212 00213 /* Best state exiting the last frame. */ 00214 next_state = state = hmm_out_history(final_phone); 00215 if (state == 0xffff) { 00216 E_ERROR("Failed to reach final state in alignment\n"); 00217 return -1; 00218 } 00219 itor = ps_alignment_states(sas->al); 00220 next_start = sas->frame + 1; 00221 for (frame = sas->frame - 1; frame >= 0; --frame) { 00222 state = sas->tokens[frame * sas->n_emit_state + state]; 00223 /* State boundary, update alignment entry for next state. */ 00224 if (state != next_state) { 00225 itor = ps_alignment_iter_goto(itor, next_state); 00226 assert(itor != NULL); 00227 ent = ps_alignment_iter_get(itor); 00228 ent->start = frame + 1; 00229 ent->duration = next_start - ent->start; 00230 E_DEBUG(1,("state %d start %d end %d\n", next_state, 00231 ent->start, next_start)); 00232 next_state = state; 00233 next_start = frame + 1; 00234 } 00235 } 00236 /* Update alignment entry for initial state. */ 00237 itor = ps_alignment_iter_goto(itor, 0); 00238 assert(itor != NULL); 00239 ent = ps_alignment_iter_get(itor); 00240 ent->start = 0; 00241 ent->duration = next_start; 00242 E_DEBUG(1,("state %d start %d end %d\n", 0, 00243 ent->start, next_start)); 00244 ps_alignment_iter_free(itor); 00245 ps_alignment_propagate(sas->al); 00246 00247 return 0; 00248 } 00249 00250 static int 00251 state_align_search_reinit(ps_search_t *search, dict_t *dict, dict2pid_t *d2p) 00252 { 00253 /* This does nothing. */ 00254 return 0; 00255 } 00256 00257 static void 00258 state_align_search_free(ps_search_t *search) 00259 { 00260 state_align_search_t *sas = (state_align_search_t *)search; 00261 ps_search_deinit(search); 00262 ckd_free(sas->hmms); 00263 ckd_free(sas->tokens); 00264 hmm_context_free(sas->hmmctx); 00265 ckd_free(sas); 00266 } 00267 00268 static ps_searchfuncs_t state_align_search_funcs = { 00269 /* name: */ "state_align", 00270 /* start: */ state_align_search_start, 00271 /* step: */ state_align_search_step, 00272 /* finish: */ state_align_search_finish, 00273 /* reinit: */ state_align_search_reinit, 00274 /* free: */ state_align_search_free, 00275 /* lattice: */ NULL, 00276 /* hyp: */ NULL, 00277 /* prob: */ NULL, 00278 /* seg_iter: */ NULL, 00279 }; 00280 00281 ps_search_t * 00282 state_align_search_init(cmd_ln_t *config, 00283 acmod_t *acmod, 00284 ps_alignment_t *al) 00285 { 00286 state_align_search_t *sas; 00287 ps_alignment_iter_t *itor; 00288 hmm_t *hmm; 00289 00290 sas = ckd_calloc(1, sizeof(*sas)); 00291 ps_search_init(ps_search_base(sas), &state_align_search_funcs, 00292 config, acmod, al->d2p->dict, al->d2p); 00293 sas->hmmctx = hmm_context_init(bin_mdef_n_emit_state(acmod->mdef), 00294 acmod->tmat->tp, NULL, acmod->mdef->sseq); 00295 if (sas->hmmctx == NULL) { 00296 ckd_free(sas); 00297 return NULL; 00298 } 00299 sas->al = al; 00300 00301 /* Generate HMM vector from phone level of alignment. */ 00302 sas->n_phones = ps_alignment_n_phones(al); 00303 sas->n_emit_state = ps_alignment_n_states(al); 00304 sas->hmms = ckd_calloc(sas->n_phones, sizeof(*sas->hmms)); 00305 for (hmm = sas->hmms, itor = ps_alignment_phones(al); itor; 00306 ++hmm, itor = ps_alignment_iter_next(itor)) { 00307 ps_alignment_entry_t *ent = ps_alignment_iter_get(itor); 00308 hmm_init(sas->hmmctx, hmm, FALSE, 00309 ent->id.pid.ssid, ent->id.pid.tmatid); 00310 } 00311 return ps_search_base(sas); 00312 }