PocketSphinx
0.6
|
00001 /* -*- c-basic-offset: 4; indent-tabs-mode: nil -*- */ 00002 /* ==================================================================== 00003 * Copyright (c) 2008 Carnegie Mellon University. All rights 00004 * reserved. 00005 * 00006 * Redistribution and use in source and binary forms, with or without 00007 * modification, are permitted provided that the following conditions 00008 * are met: 00009 * 00010 * 1. Redistributions of source code must retain the above copyright 00011 * notice, this list of conditions and the following disclaimer. 00012 * 00013 * 2. Redistributions in binary form must reproduce the above copyright 00014 * notice, this list of conditions and the following disclaimer in 00015 * the documentation and/or other materials provided with the 00016 * distribution. 00017 * 00018 * This work was supported in part by funding from the Defense Advanced 00019 * Research Projects Agency and the National Science Foundation of the 00020 * United States of America, and the CMU Sphinx Speech Consortium. 00021 * 00022 * THIS SOFTWARE IS PROVIDED BY CARNEGIE MELLON UNIVERSITY ``AS IS'' AND 00023 * ANY EXPRESSED OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, 00024 * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 00025 * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL CARNEGIE MELLON UNIVERSITY 00026 * NOR ITS EMPLOYEES BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 00027 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 00028 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 00029 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 00030 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 00031 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 00032 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 00033 * 00034 * ==================================================================== 00035 * 00036 */ 00037 00042 /* System headers. */ 00043 #include <assert.h> 00044 #include <string.h> 00045 #include <math.h> 00046 00047 /* SphinxBase headers. */ 00048 #include <sphinxbase/ckd_alloc.h> 00049 #include <sphinxbase/listelem_alloc.h> 00050 #include <sphinxbase/strfuncs.h> 00051 #include <sphinxbase/err.h> 00052 #include <sphinxbase/pio.h> 00053 00054 /* Local headers. */ 00055 #include "pocketsphinx_internal.h" 00056 #include "ps_lattice_internal.h" 00057 #include "ngram_search.h" 00058 #include "dict.h" 00059 00060 /* 00061 * Create a directed link between "from" and "to" nodes, but if a link already exists, 00062 * choose one with the best ascr. 00063 */ 00064 void 00065 ps_lattice_link(ps_lattice_t *dag, ps_latnode_t *from, ps_latnode_t *to, 00066 int32 score, int32 ef) 00067 { 00068 latlink_list_t *fwdlink; 00069 00070 /* Look for an existing link between "from" and "to" nodes */ 00071 for (fwdlink = from->exits; fwdlink; fwdlink = fwdlink->next) 00072 if (fwdlink->link->to == to) 00073 break; 00074 00075 if (fwdlink == NULL) { 00076 latlink_list_t *revlink; 00077 ps_latlink_t *link; 00078 00079 /* No link between the two nodes; create a new one */ 00080 link = listelem_malloc(dag->latlink_alloc); 00081 fwdlink = listelem_malloc(dag->latlink_list_alloc); 00082 revlink = listelem_malloc(dag->latlink_list_alloc); 00083 00084 link->from = from; 00085 link->to = to; 00086 link->ascr = score; 00087 link->ef = ef; 00088 link->best_prev = NULL; 00089 00090 fwdlink->link = revlink->link = link; 00091 fwdlink->next = from->exits; 00092 from->exits = fwdlink; 00093 revlink->next = to->entries; 00094 to->entries = revlink; 00095 } 00096 else { 00097 /* Link already exists; just retain the best ascr */ 00098 if (score BETTER_THAN fwdlink->link->ascr) { 00099 fwdlink->link->ascr = score; 00100 fwdlink->link->ef = ef; 00101 } 00102 } 00103 } 00104 00105 void 00106 ps_lattice_bypass_fillers(ps_lattice_t *dag, int32 silpen, int32 fillpen) 00107 { 00108 ps_latnode_t *node; 00109 int32 score; 00110 00111 /* Bypass filler nodes */ 00112 for (node = dag->nodes; node; node = node->next) { 00113 latlink_list_t *revlink; 00114 if (node == dag->end || !dict_filler_word(dag->dict, node->basewid)) 00115 continue; 00116 00117 /* Replace each link entering filler node with links to all its successors */ 00118 for (revlink = node->entries; revlink; revlink = revlink->next) { 00119 latlink_list_t *forlink; 00120 ps_latlink_t *rlink = revlink->link; 00121 00122 score = (node->basewid == dag->silence) ? silpen : fillpen; 00123 score += rlink->ascr; 00124 /* 00125 * Make links from predecessor of filler (from) to successors of filler. 00126 * But if successor is a filler, it has already been eliminated since it 00127 * appears earlier in latnode_list (see build...). So it can be skipped. 00128 */ 00129 for (forlink = node->exits; forlink; forlink = forlink->next) { 00130 ps_latlink_t *flink = forlink->link; 00131 if (flink->to && rlink->from && 00132 !dict_filler_word(dag->dict, flink->to->basewid)) { 00133 ps_lattice_link(dag, rlink->from, flink->to, 00134 score + flink->ascr, flink->ef); 00135 } 00136 } 00137 } 00138 node->reachable = FALSE; 00139 } 00140 } 00141 00142 static void 00143 delete_node(ps_lattice_t *dag, ps_latnode_t *node) 00144 { 00145 latlink_list_t *x, *next_x; 00146 00147 for (x = node->exits; x; x = next_x) { 00148 next_x = x->next; 00149 x->link->from = NULL; 00150 listelem_free(dag->latlink_list_alloc, x); 00151 } 00152 for (x = node->entries; x; x = next_x) { 00153 next_x = x->next; 00154 x->link->to = NULL; 00155 listelem_free(dag->latlink_list_alloc, x); 00156 } 00157 listelem_free(dag->latnode_alloc, node); 00158 } 00159 00160 00161 static void 00162 remove_dangling_links(ps_lattice_t *dag, ps_latnode_t *node) 00163 { 00164 latlink_list_t *x, *prev_x, *next_x; 00165 00166 prev_x = NULL; 00167 for (x = node->exits; x; x = next_x) { 00168 next_x = x->next; 00169 if (x->link->to == NULL) { 00170 if (prev_x) 00171 prev_x->next = next_x; 00172 else 00173 node->exits = next_x; 00174 listelem_free(dag->latlink_alloc, x->link); 00175 listelem_free(dag->latlink_list_alloc, x); 00176 } 00177 else 00178 prev_x = x; 00179 } 00180 prev_x = NULL; 00181 for (x = node->entries; x; x = next_x) { 00182 next_x = x->next; 00183 if (x->link->from == NULL) { 00184 if (prev_x) 00185 prev_x->next = next_x; 00186 else 00187 node->entries = next_x; 00188 listelem_free(dag->latlink_alloc, x->link); 00189 listelem_free(dag->latlink_list_alloc, x); 00190 } 00191 else 00192 prev_x = x; 00193 } 00194 } 00195 00196 void 00197 ps_lattice_delete_unreachable(ps_lattice_t *dag) 00198 { 00199 ps_latnode_t *node, *prev_node, *next_node; 00200 int i; 00201 00202 /* Remove unreachable nodes from the list of nodes. */ 00203 prev_node = NULL; 00204 for (node = dag->nodes; node; node = next_node) { 00205 next_node = node->next; 00206 if (!node->reachable) { 00207 if (prev_node) 00208 prev_node->next = next_node; 00209 else 00210 dag->nodes = next_node; 00211 /* Delete this node and NULLify links to it. */ 00212 delete_node(dag, node); 00213 } 00214 else 00215 prev_node = node; 00216 } 00217 00218 /* Remove all links to and from unreachable nodes. */ 00219 i = 0; 00220 for (node = dag->nodes; node; node = node->next) { 00221 /* Assign sequence numbers. */ 00222 node->id = i++; 00223 00224 /* We should obviously not encounter unreachable nodes here! */ 00225 assert(node->reachable); 00226 00227 /* Remove all links that go nowhere. */ 00228 remove_dangling_links(dag, node); 00229 } 00230 } 00231 00232 int32 00233 ps_lattice_write(ps_lattice_t *dag, char const *filename) 00234 { 00235 FILE *fp; 00236 int32 i; 00237 ps_latnode_t *d, *initial, *final; 00238 00239 initial = dag->start; 00240 final = dag->end; 00241 00242 E_INFO("Writing lattice file: %s\n", filename); 00243 if ((fp = fopen(filename, "w")) == NULL) { 00244 E_ERROR("Failed to open lattice file '%s' for writing: %s\n", filename, strerror(errno)); 00245 return -1; 00246 } 00247 00248 /* Stupid Sphinx-III lattice code expects 'getcwd:' here */ 00249 fprintf(fp, "# getcwd: /this/is/bogus\n"); 00250 fprintf(fp, "# -logbase %e\n", logmath_get_base(dag->lmath)); 00251 fprintf(fp, "#\n"); 00252 00253 fprintf(fp, "Frames %d\n", dag->n_frames); 00254 fprintf(fp, "#\n"); 00255 00256 for (i = 0, d = dag->nodes; d; d = d->next, i++); 00257 fprintf(fp, 00258 "Nodes %d (NODEID WORD STARTFRAME FIRST-ENDFRAME LAST-ENDFRAME)\n", 00259 i); 00260 for (i = 0, d = dag->nodes; d; d = d->next, i++) { 00261 d->id = i; 00262 fprintf(fp, "%d %s %d %d %d\n", 00263 i, dict_wordstr(dag->dict, d->wid), 00264 d->sf, d->fef, d->lef); 00265 } 00266 fprintf(fp, "#\n"); 00267 00268 fprintf(fp, "Initial %d\nFinal %d\n", initial->id, final->id); 00269 fprintf(fp, "#\n"); 00270 00271 /* Don't bother with this, it's not used by anything. */ 00272 fprintf(fp, "BestSegAscr %d (NODEID ENDFRAME ASCORE)\n", 00273 0 /* #BPTable entries */ ); 00274 fprintf(fp, "#\n"); 00275 00276 fprintf(fp, "Edges (FROM-NODEID TO-NODEID ASCORE)\n"); 00277 for (d = dag->nodes; d; d = d->next) { 00278 latlink_list_t *l; 00279 for (l = d->exits; l; l = l->next) { 00280 if (l->link->ascr WORSE_THAN WORST_SCORE || l->link->ascr BETTER_THAN 0) 00281 continue; 00282 fprintf(fp, "%d %d %d\n", 00283 d->id, l->link->to->id, l->link->ascr << SENSCR_SHIFT); 00284 } 00285 } 00286 fprintf(fp, "End\n"); 00287 fclose(fp); 00288 00289 return 0; 00290 } 00291 00292 int32 00293 ps_lattice_write_htk(ps_lattice_t *dag, char const *filename) 00294 { 00295 FILE *fp; 00296 ps_latnode_t *d, *initial, *final; 00297 int32 i, j, n_links, n_nodes; 00298 00299 initial = dag->start; 00300 final = dag->end; 00301 00302 E_INFO("Writing lattice file: %s\n", filename); 00303 if ((fp = fopen(filename, "w")) == NULL) { 00304 E_ERROR("Failed to open lattice file '%s' for writing: %s\n", filename, strerror(errno)); 00305 return -1; 00306 } 00307 00308 for (n_links = n_nodes = 0, d = dag->nodes; d; d = d->next) { 00309 latlink_list_t *l; 00310 if (!d->reachable) 00311 continue; 00312 d->id = n_nodes; 00313 for (l = d->exits; l; l = l->next) { 00314 if (l->link->to == NULL || !l->link->to->reachable) 00315 continue; 00316 if (l->link->ascr WORSE_THAN WORST_SCORE || l->link->ascr BETTER_THAN 0) 00317 continue; 00318 ++n_links; 00319 } 00320 ++n_nodes; 00321 } 00322 00323 fprintf(fp, "# Lattice generated by PocketSphinx\n"); 00324 fprintf(fp, "#\n# Header\n#\n"); 00325 fprintf(fp, "VERSION=1.0\n"); 00326 fprintf(fp, "start=%d\n", initial->id); 00327 fprintf(fp, "end=%d\n", final->id); 00328 fprintf(fp, "#\n"); 00329 00330 fprintf(fp, "N=%d\tL=%d\n", n_nodes, n_links); 00331 fprintf(fp, "#\n# Node definitions\n#\n"); 00332 for (i = 0, d = dag->nodes; d; d = d->next) { 00333 char const *word = dict_wordstr(dag->dict, d->wid); 00334 char const *c = strrchr(word, '('); 00335 int altpron = 1; 00336 if (!d->reachable) 00337 continue; 00338 if (c) 00339 altpron = atoi(c + 1); 00340 word = dict_basestr(dag->dict, d->wid); 00341 if (d->wid == dict_startwid(dag->dict)) 00342 word = "!SENT_START"; 00343 else if (d->wid == dict_finishwid(dag->dict)) 00344 word = "!SENT_END"; 00345 else if (dict_filler_word(dag->dict, d->wid)) 00346 word = "!NULL"; 00347 fprintf(fp, "I=%d\tt=%.2f\tW=%s\tv=%d\n", 00348 d->id, (double)d->sf / dag->frate, 00349 word, altpron); 00350 } 00351 fprintf(fp, "#\n# Link definitions\n#\n"); 00352 for (j = 0, d = dag->nodes; d; d = d->next) { 00353 latlink_list_t *l; 00354 if (!d->reachable) 00355 continue; 00356 for (l = d->exits; l; l = l->next) { 00357 if (l->link->to == NULL || !l->link->to->reachable) 00358 continue; 00359 if (l->link->ascr WORSE_THAN WORST_SCORE || l->link->ascr BETTER_THAN 0) 00360 continue; 00361 fprintf(fp, "J=%d\tS=%d\tE=%d\ta=%f\tp=%g\n", j++, 00362 d->id, l->link->to->id, 00363 logmath_log_to_ln(dag->lmath, l->link->ascr << SENSCR_SHIFT), 00364 logmath_exp(dag->lmath, l->link->alpha + l->link->beta - dag->norm)); 00365 } 00366 } 00367 fclose(fp); 00368 00369 return 0; 00370 } 00371 00372 /* Read parameter from a lattice file*/ 00373 static int 00374 dag_param_read(lineiter_t *li, char *param) 00375 { 00376 int32 n; 00377 00378 while ((li = lineiter_next(li)) != NULL) { 00379 char *c; 00380 00381 /* Ignore comments. */ 00382 if (li->buf[0] == '#') 00383 continue; 00384 00385 /* Find the first space. */ 00386 c = strchr(li->buf, ' '); 00387 if (c == NULL) continue; 00388 00389 /* Check that the first field equals param and that there's a number after it. */ 00390 if (strncmp(li->buf, param, strlen(param)) == 0 00391 && sscanf(c + 1, "%d", &n) == 1) 00392 return n; 00393 } 00394 return -1; 00395 } 00396 00397 /* Mark every node that has a path to the argument dagnode as "reachable". */ 00398 static void 00399 dag_mark_reachable(ps_latnode_t * d) 00400 { 00401 latlink_list_t *l; 00402 00403 d->reachable = 1; 00404 for (l = d->entries; l; l = l->next) 00405 if (l->link->from && !l->link->from->reachable) 00406 dag_mark_reachable(l->link->from); 00407 } 00408 00409 ps_lattice_t * 00410 ps_lattice_read(ps_decoder_t *ps, 00411 char const *file) 00412 { 00413 FILE *fp; 00414 int32 ispipe; 00415 lineiter_t *line; 00416 float64 lb; 00417 float32 logratio; 00418 ps_latnode_t *tail; 00419 ps_latnode_t **darray; 00420 ps_lattice_t *dag; 00421 int i, k, n_nodes; 00422 int32 pip, silpen, fillpen; 00423 00424 dag = ckd_calloc(1, sizeof(*dag)); 00425 00426 if (ps) { 00427 dag->search = ps->search; 00428 dag->dict = dict_retain(ps->dict); 00429 dag->lmath = logmath_retain(ps->lmath); 00430 dag->frate = cmd_ln_int32_r(dag->search->config, "-frate"); 00431 } 00432 else { 00433 dag->dict = dict_init(NULL, NULL); 00434 dag->lmath = logmath_init(1.0001, 0, FALSE); 00435 dag->frate = 100; 00436 } 00437 dag->silence = dict_silwid(dag->dict); 00438 dag->latnode_alloc = listelem_alloc_init(sizeof(ps_latnode_t)); 00439 dag->latlink_alloc = listelem_alloc_init(sizeof(ps_latlink_t)); 00440 dag->latlink_list_alloc = listelem_alloc_init(sizeof(latlink_list_t)); 00441 dag->refcount = 1; 00442 00443 tail = NULL; 00444 darray = NULL; 00445 00446 E_INFO("Reading DAG file: %s\n", file); 00447 if ((fp = fopen_compchk(file, &ispipe)) == NULL) { 00448 E_ERROR("Failed to open DAG file '%s': %s\n", file, strerror(errno)); 00449 return NULL; 00450 } 00451 line = lineiter_start(fp); 00452 00453 /* Read and verify logbase (ONE BIG HACK!!) */ 00454 if (line == NULL) { 00455 E_ERROR("Premature EOF(%s)\n", file); 00456 goto load_error; 00457 } 00458 if (strncmp(line->buf, "# getcwd: ", 10) != 0) { 00459 E_ERROR("%s does not begin with '# getcwd: '\n%s", file, line->buf); 00460 goto load_error; 00461 } 00462 if ((line = lineiter_next(line)) == NULL) { 00463 E_ERROR("Premature EOF(%s)\n", file); 00464 goto load_error; 00465 } 00466 if ((strncmp(line->buf, "# -logbase ", 11) != 0) 00467 || (sscanf(line->buf + 11, "%lf", &lb) != 1)) { 00468 E_WARN("%s: Cannot find -logbase in header\n", file); 00469 lb = 1.0001; 00470 } 00471 logratio = 1.0f; 00472 if (dag->lmath == NULL) 00473 dag->lmath = logmath_init(lb, 0, TRUE); 00474 else { 00475 float32 pb = logmath_get_base(dag->lmath); 00476 if (fabs(lb - pb) >= 0.0001) { 00477 E_WARN("Inconsistent logbases: %f vs %f: will compensate\n", lb, pb); 00478 logratio = (float32)(log(lb) / log(pb)); 00479 E_INFO("Lattice log ratio: %f\n", logratio); 00480 } 00481 } 00482 /* Read Frames parameter */ 00483 dag->n_frames = dag_param_read(line, "Frames"); 00484 if (dag->n_frames <= 0) { 00485 E_ERROR("Frames parameter missing or invalid\n"); 00486 goto load_error; 00487 } 00488 /* Read Nodes parameter */ 00489 n_nodes = dag_param_read(line, "Nodes"); 00490 if (n_nodes <= 0) { 00491 E_ERROR("Nodes parameter missing or invalid\n"); 00492 goto load_error; 00493 } 00494 00495 /* Read nodes */ 00496 darray = ckd_calloc(n_nodes, sizeof(*darray)); 00497 for (i = 0; i < n_nodes; i++) { 00498 ps_latnode_t *d; 00499 int32 w; 00500 int seqid, sf, fef, lef; 00501 char wd[256]; 00502 00503 if ((line = lineiter_next(line)) == NULL) { 00504 E_ERROR("Premature EOF while loading Nodes(%s)\n", file); 00505 goto load_error; 00506 } 00507 00508 if ((k = 00509 sscanf(line->buf, "%d %255s %d %d %d", &seqid, wd, &sf, &fef, 00510 &lef)) != 5) { 00511 E_ERROR("Cannot parse line: %s, value of count %d\n", line->buf, k); 00512 goto load_error; 00513 } 00514 00515 w = dict_wordid(dag->dict, wd); 00516 if (w < 0) { 00517 if (dag->search == NULL) { 00518 char *ww = ckd_salloc(wd); 00519 if (dict_word2basestr(ww) != -1) { 00520 if (dict_wordid(dag->dict, ww) == BAD_S3WID) 00521 dict_add_word(dag->dict, ww, NULL, 0); 00522 } 00523 ckd_free(ww); 00524 w = dict_add_word(dag->dict, wd, NULL, 0); 00525 } 00526 if (w < 0) { 00527 E_ERROR("Unknown word in line: %s\n", line->buf); 00528 goto load_error; 00529 } 00530 } 00531 00532 if (seqid != i) { 00533 E_ERROR("Seqno error: %s\n", line->buf); 00534 goto load_error; 00535 } 00536 00537 d = listelem_malloc(dag->latnode_alloc); 00538 darray[i] = d; 00539 d->wid = w; 00540 d->basewid = dict_basewid(dag->dict, w); 00541 d->id = seqid; 00542 d->sf = sf; 00543 d->fef = fef; 00544 d->lef = lef; 00545 d->reachable = 0; 00546 d->exits = d->entries = NULL; 00547 d->next = NULL; 00548 00549 if (!dag->nodes) 00550 dag->nodes = d; 00551 else 00552 tail->next = d; 00553 tail = d; 00554 } 00555 00556 /* Read initial node ID */ 00557 k = dag_param_read(line, "Initial"); 00558 if ((k < 0) || (k >= n_nodes)) { 00559 E_ERROR("Initial node parameter missing or invalid\n"); 00560 goto load_error; 00561 } 00562 dag->start = darray[k]; 00563 00564 /* Read final node ID */ 00565 k = dag_param_read(line, "Final"); 00566 if ((k < 0) || (k >= n_nodes)) { 00567 E_ERROR("Final node parameter missing or invalid\n"); 00568 goto load_error; 00569 } 00570 dag->end = darray[k]; 00571 00572 /* Read bestsegscore entries and ignore them. */ 00573 if ((k = dag_param_read(line, "BestSegAscr")) < 0) { 00574 E_ERROR("BestSegAscr parameter missing\n"); 00575 goto load_error; 00576 } 00577 for (i = 0; i < k; i++) { 00578 if ((line = lineiter_next(line)) == NULL) { 00579 E_ERROR("Premature EOF while (%s) ignoring BestSegAscr\n", 00580 line); 00581 goto load_error; 00582 } 00583 } 00584 00585 /* Read in edges. */ 00586 while ((line = lineiter_next(line)) != NULL) { 00587 if (line->buf[0] == '#') 00588 continue; 00589 if (0 == strncmp(line->buf, "Edges", 5)) 00590 break; 00591 } 00592 if (line == NULL) { 00593 E_ERROR("Edges missing\n"); 00594 goto load_error; 00595 } 00596 while ((line = lineiter_next(line)) != NULL) { 00597 int from, to, ascr; 00598 ps_latnode_t *pd, *d; 00599 00600 if (sscanf(line->buf, "%d %d %d", &from, &to, &ascr) != 3) 00601 break; 00602 if (ascr WORSE_THAN WORST_SCORE) 00603 continue; 00604 pd = darray[from]; 00605 d = darray[to]; 00606 if (logratio != 1.0f) 00607 ascr = (int32)(ascr * logratio); 00608 ps_lattice_link(dag, pd, d, ascr, d->sf - 1); 00609 } 00610 if (strcmp(line->buf, "End\n") != 0) { 00611 E_ERROR("Terminating 'End' missing\n"); 00612 goto load_error; 00613 } 00614 lineiter_free(line); 00615 fclose_comp(fp, ispipe); 00616 ckd_free(darray); 00617 00618 /* Minor hack: If the final node is a filler word and not </s>, 00619 * then set its base word ID to </s>, so that the language model 00620 * scores won't be screwed up. */ 00621 if (dict_filler_word(dag->dict, dag->end->wid)) 00622 dag->end->basewid = dag->search 00623 ? ps_search_finish_wid(dag->search) 00624 : dict_wordid(dag->dict, S3_FINISH_WORD); 00625 00626 /* Mark reachable from dag->end */ 00627 dag_mark_reachable(dag->end); 00628 00629 /* Free nodes unreachable from dag->end and their links */ 00630 ps_lattice_delete_unreachable(dag); 00631 00632 if (ps) { 00633 /* Build links around silence and filler words, since they do 00634 * not exist in the language model. FIXME: This is 00635 * potentially buggy, as we already do this before outputing 00636 * lattices. */ 00637 pip = logmath_log(dag->lmath, cmd_ln_float32_r(ps->config, "-pip")); 00638 silpen = pip + logmath_log(dag->lmath, 00639 cmd_ln_float32_r(ps->config, "-silprob")); 00640 fillpen = pip + logmath_log(dag->lmath, 00641 cmd_ln_float32_r(ps->config, "-fillprob")); 00642 ps_lattice_bypass_fillers(dag, silpen, fillpen); 00643 } 00644 00645 return dag; 00646 00647 load_error: 00648 E_ERROR("Failed to load %s\n", file); 00649 lineiter_free(line); 00650 if (fp) fclose_comp(fp, ispipe); 00651 ckd_free(darray); 00652 return NULL; 00653 } 00654 00655 int 00656 ps_lattice_n_frames(ps_lattice_t *dag) 00657 { 00658 return dag->n_frames; 00659 } 00660 00661 ps_lattice_t * 00662 ps_lattice_init_search(ps_search_t *search, int n_frame) 00663 { 00664 ps_lattice_t *dag; 00665 00666 dag = ckd_calloc(1, sizeof(*dag)); 00667 dag->search = search; 00668 dag->dict = dict_retain(search->dict); 00669 dag->lmath = logmath_retain(search->acmod->lmath); 00670 dag->frate = cmd_ln_int32_r(dag->search->config, "-frate"); 00671 dag->silence = dict_silwid(dag->dict); 00672 dag->n_frames = n_frame; 00673 dag->latnode_alloc = listelem_alloc_init(sizeof(ps_latnode_t)); 00674 dag->latlink_alloc = listelem_alloc_init(sizeof(ps_latlink_t)); 00675 dag->latlink_list_alloc = listelem_alloc_init(sizeof(latlink_list_t)); 00676 dag->refcount = 1; 00677 return dag; 00678 } 00679 00680 ps_lattice_t * 00681 ps_lattice_retain(ps_lattice_t *dag) 00682 { 00683 ++dag->refcount; 00684 return dag; 00685 } 00686 00687 int 00688 ps_lattice_free(ps_lattice_t *dag) 00689 { 00690 if (dag == NULL) 00691 return 0; 00692 if (--dag->refcount > 0) 00693 return dag->refcount; 00694 logmath_free(dag->lmath); 00695 listelem_alloc_free(dag->latnode_alloc); 00696 listelem_alloc_free(dag->latlink_alloc); 00697 listelem_alloc_free(dag->latlink_list_alloc); 00698 ckd_free(dag->hyp_str); 00699 ckd_free(dag); 00700 return 0; 00701 } 00702 00703 logmath_t * 00704 ps_lattice_get_logmath(ps_lattice_t *dag) 00705 { 00706 return dag->lmath; 00707 } 00708 00709 ps_latnode_iter_t * 00710 ps_latnode_iter(ps_lattice_t *dag) 00711 { 00712 return dag->nodes; 00713 } 00714 00715 ps_latnode_iter_t * 00716 ps_latnode_iter_next(ps_latnode_iter_t *itor) 00717 { 00718 return itor->next; 00719 } 00720 00721 void 00722 ps_latnode_iter_free(ps_latnode_iter_t *itor) 00723 { 00724 /* Do absolutely nothing. */ 00725 } 00726 00727 ps_latnode_t * 00728 ps_latnode_iter_node(ps_latnode_iter_t *itor) 00729 { 00730 return itor; 00731 } 00732 00733 int 00734 ps_latnode_times(ps_latnode_t *node, int16 *out_fef, int16 *out_lef) 00735 { 00736 if (out_fef) *out_fef = (int16)node->fef; 00737 if (out_lef) *out_lef = (int16)node->lef; 00738 return node->sf; 00739 } 00740 00741 char const * 00742 ps_latnode_word(ps_lattice_t *dag, ps_latnode_t *node) 00743 { 00744 return dict_wordstr(dag->dict, node->wid); 00745 } 00746 00747 char const * 00748 ps_latnode_baseword(ps_lattice_t *dag, ps_latnode_t *node) 00749 { 00750 return dict_wordstr(dag->dict, node->basewid); 00751 } 00752 00753 int32 00754 ps_latnode_prob(ps_lattice_t *dag, ps_latnode_t *node, 00755 ps_latlink_t **out_link) 00756 { 00757 latlink_list_t *links; 00758 int32 bestpost = logmath_get_zero(dag->lmath); 00759 00760 for (links = node->exits; links; links = links->next) { 00761 int32 post = links->link->alpha + links->link->beta - dag->norm; 00762 if (post > bestpost) { 00763 if (out_link) *out_link = links->link; 00764 bestpost = post; 00765 } 00766 } 00767 return bestpost; 00768 } 00769 00770 ps_latlink_iter_t * 00771 ps_latnode_exits(ps_latnode_t *node) 00772 { 00773 return node->exits; 00774 } 00775 00776 ps_latlink_iter_t * 00777 ps_latnode_entries(ps_latnode_t *node) 00778 { 00779 return node->entries; 00780 } 00781 00782 ps_latlink_iter_t * 00783 ps_latlink_iter_next(ps_latlink_iter_t *itor) 00784 { 00785 return itor->next; 00786 } 00787 00788 void 00789 ps_latlink_iter_free(ps_latlink_iter_t *itor) 00790 { 00791 /* Do absolutely nothing. */ 00792 } 00793 00794 ps_latlink_t * 00795 ps_latlink_iter_link(ps_latlink_iter_t *itor) 00796 { 00797 return itor->link; 00798 } 00799 00800 int 00801 ps_latlink_times(ps_latlink_t *link, int16 *out_sf) 00802 { 00803 if (out_sf) { 00804 if (link->from) { 00805 *out_sf = link->from->sf; 00806 } 00807 else { 00808 *out_sf = 0; 00809 } 00810 } 00811 return link->ef; 00812 } 00813 00814 ps_latnode_t * 00815 ps_latlink_nodes(ps_latlink_t *link, ps_latnode_t **out_src) 00816 { 00817 if (out_src) *out_src = link->from; 00818 return link->to; 00819 } 00820 00821 char const * 00822 ps_latlink_word(ps_lattice_t *dag, ps_latlink_t *link) 00823 { 00824 if (link->from == NULL) 00825 return NULL; 00826 return dict_wordstr(dag->dict, link->from->wid); 00827 } 00828 00829 char const * 00830 ps_latlink_baseword(ps_lattice_t *dag, ps_latlink_t *link) 00831 { 00832 if (link->from == NULL) 00833 return NULL; 00834 return dict_wordstr(dag->dict, link->from->basewid); 00835 } 00836 00837 ps_latlink_t * 00838 ps_latlink_pred(ps_latlink_t *link) 00839 { 00840 return link->best_prev; 00841 } 00842 00843 int32 00844 ps_latlink_prob(ps_lattice_t *dag, ps_latlink_t *link, int32 *out_ascr) 00845 { 00846 int32 post = link->alpha + link->beta - dag->norm; 00847 if (out_ascr) *out_ascr = link->ascr << SENSCR_SHIFT; 00848 return post; 00849 } 00850 00851 char const * 00852 ps_lattice_hyp(ps_lattice_t *dag, ps_latlink_t *link) 00853 { 00854 ps_latlink_t *l; 00855 size_t len; 00856 char *c; 00857 00858 /* Backtrace once to get hypothesis length. */ 00859 len = 0; 00860 /* FIXME: There may not be a search, but actually there should be a dict. */ 00861 if (dict_real_word(dag->dict, link->to->basewid)) 00862 len += strlen(dict_wordstr(dag->dict, link->to->basewid)) + 1; 00863 for (l = link; l; l = l->best_prev) { 00864 if (dict_real_word(dag->dict, l->from->basewid)) 00865 len += strlen(dict_wordstr(dag->dict, l->from->basewid)) + 1; 00866 } 00867 00868 /* Backtrace again to construct hypothesis string. */ 00869 ckd_free(dag->hyp_str); 00870 dag->hyp_str = ckd_calloc(1, len+1); /* extra one incase the hyp is empty */ 00871 c = dag->hyp_str + len - 1; 00872 if (dict_real_word(dag->dict, link->to->basewid)) { 00873 len = strlen(dict_wordstr(dag->dict, link->to->basewid)); 00874 c -= len; 00875 memcpy(c, dict_wordstr(dag->dict, link->to->basewid), len); 00876 if (c > dag->hyp_str) { 00877 --c; 00878 *c = ' '; 00879 } 00880 } 00881 for (l = link; l; l = l->best_prev) { 00882 if (dict_real_word(dag->dict, l->from->basewid)) { 00883 len = strlen(dict_wordstr(dag->dict, l->from->basewid)); 00884 c -= len; 00885 memcpy(c, dict_wordstr(dag->dict, l->from->basewid), len); 00886 if (c > dag->hyp_str) { 00887 --c; 00888 *c = ' '; 00889 } 00890 } 00891 } 00892 00893 return dag->hyp_str; 00894 } 00895 00896 static void 00897 ps_lattice_compute_lscr(ps_seg_t *seg, ps_latlink_t *link, int to) 00898 { 00899 ngram_model_t *lmset; 00900 00901 /* Language model score is included in the link score for FSG 00902 * search. FIXME: Of course, this is sort of a hack :( */ 00903 if (0 != strcmp(ps_search_name(seg->search), "ngram")) { 00904 seg->lback = 1; /* Unigram... */ 00905 seg->lscr = 0; 00906 return; 00907 } 00908 00909 lmset = ((ngram_search_t *)seg->search)->lmset; 00910 00911 if (link->best_prev == NULL) { 00912 if (to) /* Sentence has only two words. */ 00913 seg->lscr = ngram_bg_score(lmset, link->to->basewid, 00914 link->from->basewid, &seg->lback) 00915 >> SENSCR_SHIFT; 00916 else {/* This is the start symbol, its lscr is always 0. */ 00917 seg->lscr = 0; 00918 seg->lback = 1; 00919 } 00920 } 00921 else { 00922 /* Find the two predecessor words. */ 00923 if (to) { 00924 seg->lscr = ngram_tg_score(lmset, link->to->basewid, 00925 link->from->basewid, 00926 link->best_prev->from->basewid, 00927 &seg->lback) >> SENSCR_SHIFT; 00928 } 00929 else { 00930 if (link->best_prev->best_prev) 00931 seg->lscr = ngram_tg_score(lmset, link->from->basewid, 00932 link->best_prev->from->basewid, 00933 link->best_prev->best_prev->from->basewid, 00934 &seg->lback) >> SENSCR_SHIFT; 00935 else 00936 seg->lscr = ngram_bg_score(lmset, link->from->basewid, 00937 link->best_prev->from->basewid, 00938 &seg->lback) >> SENSCR_SHIFT; 00939 } 00940 } 00941 } 00942 00943 static void 00944 ps_lattice_link2itor(ps_seg_t *seg, ps_latlink_t *link, int to) 00945 { 00946 dag_seg_t *itor = (dag_seg_t *)seg; 00947 ps_latnode_t *node; 00948 00949 if (to) { 00950 node = link->to; 00951 seg->ef = node->lef; 00952 seg->prob = 0; /* norm + beta - norm */ 00953 } 00954 else { 00955 latlink_list_t *x; 00956 ps_latnode_t *n; 00957 logmath_t *lmath = ps_search_acmod(seg->search)->lmath; 00958 00959 node = link->from; 00960 seg->ef = link->ef; 00961 seg->prob = link->alpha + link->beta - itor->norm; 00962 /* Sum over all exits for this word and any alternate 00963 pronunciations at the same frame. */ 00964 for (n = node; n; n = n->alt) { 00965 for (x = n->exits; x; x = x->next) { 00966 if (x->link == link) 00967 continue; 00968 seg->prob = logmath_add(lmath, seg->prob, 00969 x->link->alpha + x->link->beta - itor->norm); 00970 } 00971 } 00972 } 00973 seg->word = dict_wordstr(ps_search_dict(seg->search), node->wid); 00974 seg->sf = node->sf; 00975 seg->ascr = link->ascr << SENSCR_SHIFT; 00976 /* Compute language model score from best predecessors. */ 00977 ps_lattice_compute_lscr(seg, link, to); 00978 } 00979 00980 static void 00981 ps_lattice_seg_free(ps_seg_t *seg) 00982 { 00983 dag_seg_t *itor = (dag_seg_t *)seg; 00984 00985 ckd_free(itor->links); 00986 ckd_free(itor); 00987 } 00988 00989 static ps_seg_t * 00990 ps_lattice_seg_next(ps_seg_t *seg) 00991 { 00992 dag_seg_t *itor = (dag_seg_t *)seg; 00993 00994 ++itor->cur; 00995 if (itor->cur == itor->n_links + 1) { 00996 ps_lattice_seg_free(seg); 00997 return NULL; 00998 } 00999 else if (itor->cur == itor->n_links) { 01000 /* Re-use the last link but with the "to" node. */ 01001 ps_lattice_link2itor(seg, itor->links[itor->cur - 1], TRUE); 01002 } 01003 else { 01004 ps_lattice_link2itor(seg, itor->links[itor->cur], FALSE); 01005 } 01006 01007 return seg; 01008 } 01009 01010 static ps_segfuncs_t ps_lattice_segfuncs = { 01011 /* seg_next */ ps_lattice_seg_next, 01012 /* seg_free */ ps_lattice_seg_free 01013 }; 01014 01015 ps_seg_t * 01016 ps_lattice_seg_iter(ps_lattice_t *dag, ps_latlink_t *link, float32 lwf) 01017 { 01018 dag_seg_t *itor; 01019 ps_latlink_t *l; 01020 int cur; 01021 01022 /* Calling this an "iterator" is a bit of a misnomer since we have 01023 * to get the entire backtrace in order to produce it. 01024 */ 01025 itor = ckd_calloc(1, sizeof(*itor)); 01026 itor->base.vt = &ps_lattice_segfuncs; 01027 itor->base.search = dag->search; 01028 itor->base.lwf = lwf; 01029 itor->n_links = 0; 01030 itor->norm = dag->norm; 01031 01032 for (l = link; l; l = l->best_prev) { 01033 ++itor->n_links; 01034 } 01035 if (itor->n_links == 0) { 01036 ckd_free(itor); 01037 return NULL; 01038 } 01039 01040 itor->links = ckd_calloc(itor->n_links, sizeof(*itor->links)); 01041 cur = itor->n_links - 1; 01042 for (l = link; l; l = l->best_prev) { 01043 itor->links[cur] = l; 01044 --cur; 01045 } 01046 01047 ps_lattice_link2itor((ps_seg_t *)itor, itor->links[0], FALSE); 01048 return (ps_seg_t *)itor; 01049 } 01050 01051 latlink_list_t * 01052 latlink_list_new(ps_lattice_t *dag, ps_latlink_t *link, latlink_list_t *next) 01053 { 01054 latlink_list_t *ll; 01055 01056 ll = listelem_malloc(dag->latlink_list_alloc); 01057 ll->link = link; 01058 ll->next = next; 01059 01060 return ll; 01061 } 01062 01063 void 01064 ps_lattice_pushq(ps_lattice_t *dag, ps_latlink_t *link) 01065 { 01066 if (dag->q_head == NULL) 01067 dag->q_head = dag->q_tail = latlink_list_new(dag, link, NULL); 01068 else { 01069 dag->q_tail->next = latlink_list_new(dag, link, NULL); 01070 dag->q_tail = dag->q_tail->next; 01071 } 01072 01073 } 01074 01075 ps_latlink_t * 01076 ps_lattice_popq(ps_lattice_t *dag) 01077 { 01078 latlink_list_t *x; 01079 ps_latlink_t *link; 01080 01081 if (dag->q_head == NULL) 01082 return NULL; 01083 link = dag->q_head->link; 01084 x = dag->q_head->next; 01085 listelem_free(dag->latlink_list_alloc, dag->q_head); 01086 dag->q_head = x; 01087 if (dag->q_head == NULL) 01088 dag->q_tail = NULL; 01089 return link; 01090 } 01091 01092 void 01093 ps_lattice_delq(ps_lattice_t *dag) 01094 { 01095 while (ps_lattice_popq(dag)) { 01096 /* Do nothing. */ 01097 } 01098 } 01099 01100 ps_latlink_t * 01101 ps_lattice_traverse_edges(ps_lattice_t *dag, ps_latnode_t *start, ps_latnode_t *end) 01102 { 01103 ps_latnode_t *node; 01104 latlink_list_t *x; 01105 01106 /* Cancel any unfinished traversal. */ 01107 ps_lattice_delq(dag); 01108 01109 /* Initialize node fanin counts and path scores. */ 01110 for (node = dag->nodes; node; node = node->next) 01111 node->info.fanin = 0; 01112 for (node = dag->nodes; node; node = node->next) { 01113 for (x = node->exits; x; x = x->next) 01114 (x->link->to->info.fanin)++; 01115 } 01116 01117 /* Initialize agenda with all exits from start. */ 01118 if (start == NULL) start = dag->start; 01119 for (x = start->exits; x; x = x->next) 01120 ps_lattice_pushq(dag, x->link); 01121 01122 /* Pull the first edge off the queue. */ 01123 return ps_lattice_traverse_next(dag, end); 01124 } 01125 01126 ps_latlink_t * 01127 ps_lattice_traverse_next(ps_lattice_t *dag, ps_latnode_t *end) 01128 { 01129 ps_latlink_t *next; 01130 01131 next = ps_lattice_popq(dag); 01132 if (next == NULL) 01133 return NULL; 01134 01135 /* Decrease fanin count for destination node and expand outgoing 01136 * edges if all incoming edges have been seen. */ 01137 --next->to->info.fanin; 01138 if (next->to->info.fanin == 0) { 01139 latlink_list_t *x; 01140 01141 if (end == NULL) end = dag->end; 01142 if (next->to == end) { 01143 /* If we have traversed all links entering the end node, 01144 * clear the queue, causing future calls to this function 01145 * to return NULL. */ 01146 ps_lattice_delq(dag); 01147 return next; 01148 } 01149 01150 /* Extend all outgoing edges. */ 01151 for (x = next->to->exits; x; x = x->next) 01152 ps_lattice_pushq(dag, x->link); 01153 } 01154 return next; 01155 } 01156 01157 ps_latlink_t * 01158 ps_lattice_reverse_edges(ps_lattice_t *dag, ps_latnode_t *start, ps_latnode_t *end) 01159 { 01160 ps_latnode_t *node; 01161 latlink_list_t *x; 01162 01163 /* Cancel any unfinished traversal. */ 01164 ps_lattice_delq(dag); 01165 01166 /* Initialize node fanout counts and path scores. */ 01167 for (node = dag->nodes; node; node = node->next) { 01168 node->info.fanin = 0; 01169 for (x = node->exits; x; x = x->next) 01170 ++node->info.fanin; 01171 } 01172 01173 /* Initialize agenda with all entries from end. */ 01174 if (end == NULL) end = dag->end; 01175 for (x = end->entries; x; x = x->next) 01176 ps_lattice_pushq(dag, x->link); 01177 01178 /* Pull the first edge off the queue. */ 01179 return ps_lattice_reverse_next(dag, start); 01180 } 01181 01182 ps_latlink_t * 01183 ps_lattice_reverse_next(ps_lattice_t *dag, ps_latnode_t *start) 01184 { 01185 ps_latlink_t *next; 01186 01187 next = ps_lattice_popq(dag); 01188 if (next == NULL) 01189 return NULL; 01190 01191 /* Decrease fanout count for source node and expand incoming 01192 * edges if all incoming edges have been seen. */ 01193 --next->from->info.fanin; 01194 if (next->from->info.fanin == 0) { 01195 latlink_list_t *x; 01196 01197 if (start == NULL) start = dag->start; 01198 if (next->from == start) { 01199 /* If we have traversed all links entering the start node, 01200 * clear the queue, causing future calls to this function 01201 * to return NULL. */ 01202 ps_lattice_delq(dag); 01203 return next; 01204 } 01205 01206 /* Extend all outgoing edges. */ 01207 for (x = next->from->entries; x; x = x->next) 01208 ps_lattice_pushq(dag, x->link); 01209 } 01210 return next; 01211 } 01212 01213 /* 01214 * Find the best score from dag->start to end point of any link and 01215 * use it to update links further down the path. This is like 01216 * single-source shortest path search, except that it is done over 01217 * edges rather than nodes, which allows us to do exact trigram scoring. 01218 * 01219 * Helpfully enough, we get half of the posterior probability 01220 * calculation for free that way too. (interesting research topic: is 01221 * there a reliable Viterbi analogue to word-level Forward-Backward 01222 * like there is for state-level? Or, is it just lattice density?) 01223 */ 01224 ps_latlink_t * 01225 ps_lattice_bestpath(ps_lattice_t *dag, ngram_model_t *lmset, 01226 float32 lwf, float32 ascale) 01227 { 01228 ps_search_t *search; 01229 ps_latnode_t *node; 01230 ps_latlink_t *link; 01231 ps_latlink_t *bestend; 01232 latlink_list_t *x; 01233 logmath_t *lmath; 01234 int32 bestescr; 01235 01236 search = dag->search; 01237 lmath = dag->lmath; 01238 01239 /* Initialize path scores for all links exiting dag->start, and 01240 * set all other scores to the minimum. Also initialize alphas to 01241 * log-zero. */ 01242 for (node = dag->nodes; node; node = node->next) { 01243 for (x = node->exits; x; x = x->next) { 01244 x->link->path_scr = MAX_NEG_INT32; 01245 x->link->alpha = logmath_get_zero(lmath); 01246 } 01247 } 01248 for (x = dag->start->exits; x; x = x->next) { 01249 int32 n_used; 01250 01251 /* Ignore filler words. */ 01252 if (dict_filler_word(ps_search_dict(search), x->link->to->basewid) 01253 && x->link->to != dag->end) 01254 continue; 01255 01256 /* Best path points to dag->start, obviously. */ 01257 if (lmset) 01258 x->link->path_scr = x->link->ascr + 01259 (ngram_bg_score(lmset, x->link->to->basewid, 01260 ps_search_start_wid(search), &n_used) 01261 >> SENSCR_SHIFT) 01262 * lwf; 01263 else 01264 x->link->path_scr = x->link->ascr; 01265 x->link->best_prev = NULL; 01266 /* No predecessors for start links. */ 01267 x->link->alpha = 0; 01268 } 01269 01270 /* Traverse the edges in the graph, updating path scores. */ 01271 for (link = ps_lattice_traverse_edges(dag, NULL, NULL); 01272 link; link = ps_lattice_traverse_next(dag, NULL)) { 01273 int32 bprob, n_used; 01274 01275 /* Skip filler nodes in traversal. */ 01276 if (dict_filler_word(ps_search_dict(search), link->from->basewid) && link->from != dag->start) 01277 continue; 01278 if (dict_filler_word(ps_search_dict(search), link->to->basewid) && link->to != dag->end) 01279 continue; 01280 01281 /* Sanity check, we should not be traversing edges that 01282 * weren't previously updated, otherwise nasty overflows will result. */ 01283 assert(link->path_scr != MAX_NEG_INT32); 01284 01285 /* Calculate common bigram probability for all alphas. */ 01286 if (lmset) 01287 bprob = ngram_ng_prob(lmset, 01288 link->to->basewid, 01289 &link->from->basewid, 1, &n_used); 01290 else 01291 bprob = 0; 01292 /* Add in this link's acoustic score, which was a constant 01293 factor in previous computations (if any). */ 01294 link->alpha += (link->ascr << SENSCR_SHIFT) * ascale; 01295 01296 /* Update scores for all paths exiting link->to. */ 01297 for (x = link->to->exits; x; x = x->next) { 01298 int32 tscore, score; 01299 01300 /* Skip links to filler words in update. */ 01301 if (dict_filler_word(ps_search_dict(search), x->link->to->basewid) 01302 && x->link->to != dag->end) 01303 continue; 01304 01305 /* Update alpha with sum of previous alphas. */ 01306 x->link->alpha = logmath_add(lmath, x->link->alpha, link->alpha + bprob); 01307 /* Calculate trigram score for bestpath. */ 01308 if (lmset) 01309 tscore = (ngram_tg_score(lmset, x->link->to->basewid, 01310 link->to->basewid, 01311 link->from->basewid, &n_used) >> SENSCR_SHIFT) 01312 * lwf; 01313 else 01314 tscore = 0; 01315 /* Update link score with maximum link score. */ 01316 score = link->path_scr + tscore + x->link->ascr; 01317 if (score BETTER_THAN x->link->path_scr) { 01318 x->link->path_scr = score; 01319 x->link->best_prev = link; 01320 } 01321 } 01322 } 01323 01324 /* Find best link entering final node, and calculate normalizer 01325 * for posterior probabilities. */ 01326 bestend = NULL; 01327 bestescr = MAX_NEG_INT32; 01328 01329 /* Normalizer is the alpha for the imaginary link exiting the 01330 final node. */ 01331 dag->norm = logmath_get_zero(lmath); 01332 for (x = dag->end->entries; x; x = x->next) { 01333 int32 bprob, n_used; 01334 01335 if (dict_filler_word(ps_search_dict(search), x->link->from->basewid)) 01336 continue; 01337 if (lmset) 01338 bprob = ngram_ng_prob(lmset, 01339 x->link->to->basewid, 01340 &x->link->from->basewid, 1, &n_used); 01341 else 01342 bprob = 0; 01343 dag->norm = logmath_add(lmath, dag->norm, x->link->alpha + bprob); 01344 if (x->link->path_scr BETTER_THAN bestescr) { 01345 bestescr = x->link->path_scr; 01346 bestend = x->link; 01347 } 01348 } 01349 /* FIXME: floating point... */ 01350 dag->norm += (int32)(dag->final_node_ascr << SENSCR_SHIFT) * ascale; 01351 01352 E_INFO("Normalizer P(O) = alpha(%s:%d:%d) = %d\n", 01353 dict_wordstr(dag->search->dict, dag->end->wid), 01354 dag->end->sf, dag->end->lef, 01355 dag->norm); 01356 return bestend; 01357 } 01358 01359 static int32 01360 ps_lattice_joint(ps_lattice_t *dag, ps_latlink_t *link, float32 ascale) 01361 { 01362 ngram_model_t *lmset; 01363 int32 jprob; 01364 01365 /* Sort of a hack... */ 01366 if (dag->search && 0 == strcmp(ps_search_name(dag->search), "ngram")) 01367 lmset = ((ngram_search_t *)dag->search)->lmset; 01368 else 01369 lmset = NULL; 01370 01371 jprob = (dag->final_node_ascr << SENSCR_SHIFT) * ascale; 01372 while (link) { 01373 if (lmset) { 01374 int lback; 01375 /* Compute unscaled language model probability. Note that 01376 this is actually not the language model probability 01377 that corresponds to this link, but that is okay, 01378 because we are just taking the sum over all links in 01379 the best path. */ 01380 jprob += ngram_ng_prob(lmset, link->to->basewid, 01381 &link->from->basewid, 1, &lback); 01382 } 01383 /* If there is no language model, we assume that the language 01384 model probability (such as it is) has been included in the 01385 link score. */ 01386 jprob += (link->ascr << SENSCR_SHIFT) * ascale; 01387 link = link->best_prev; 01388 } 01389 01390 E_INFO("Joint P(O,S) = %d P(S|O) = %d\n", jprob, jprob - dag->norm); 01391 return jprob; 01392 } 01393 01394 int32 01395 ps_lattice_posterior(ps_lattice_t *dag, ngram_model_t *lmset, 01396 float32 ascale) 01397 { 01398 ps_search_t *search; 01399 logmath_t *lmath; 01400 ps_latnode_t *node; 01401 ps_latlink_t *link; 01402 latlink_list_t *x; 01403 ps_latlink_t *bestend; 01404 int32 bestescr; 01405 01406 search = dag->search; 01407 lmath = dag->lmath; 01408 01409 /* Reset all betas to zero. */ 01410 for (node = dag->nodes; node; node = node->next) { 01411 for (x = node->exits; x; x = x->next) { 01412 x->link->beta = logmath_get_zero(lmath); 01413 } 01414 } 01415 01416 bestend = NULL; 01417 bestescr = MAX_NEG_INT32; 01418 /* Accumulate backward probabilities for all links. */ 01419 for (link = ps_lattice_reverse_edges(dag, NULL, NULL); 01420 link; link = ps_lattice_reverse_next(dag, NULL)) { 01421 int32 bprob, n_used; 01422 01423 /* Skip filler nodes in traversal. */ 01424 if (dict_filler_word(ps_search_dict(search), link->from->basewid) && link->from != dag->start) 01425 continue; 01426 if (dict_filler_word(ps_search_dict(search), link->to->basewid) && link->to != dag->end) 01427 continue; 01428 01429 /* Calculate LM probability. */ 01430 if (lmset) 01431 bprob = ngram_ng_prob(lmset, link->to->basewid, 01432 &link->from->basewid, 1, &n_used); 01433 else 01434 bprob = 0; 01435 01436 if (link->to == dag->end) { 01437 /* Track the best path - we will backtrace in order to 01438 calculate the unscaled joint probability for sentence 01439 posterior. */ 01440 if (link->path_scr BETTER_THAN bestescr) { 01441 bestescr = link->path_scr; 01442 bestend = link; 01443 } 01444 /* Imaginary exit link from final node has beta = 1.0 */ 01445 link->beta = bprob + (dag->final_node_ascr << SENSCR_SHIFT) * ascale; 01446 } 01447 else { 01448 /* Update beta from all outgoing betas. */ 01449 for (x = link->to->exits; x; x = x->next) { 01450 if (dict_filler_word(ps_search_dict(search), x->link->to->basewid) && x->link->to != dag->end) 01451 continue; 01452 link->beta = logmath_add(lmath, link->beta, 01453 x->link->beta + bprob 01454 + (x->link->ascr << SENSCR_SHIFT) * ascale); 01455 } 01456 } 01457 } 01458 01459 /* Return P(S|O) = P(O,S)/P(O) */ 01460 return ps_lattice_joint(dag, bestend, ascale) - dag->norm; 01461 } 01462 01463 int32 01464 ps_lattice_posterior_prune(ps_lattice_t *dag, int32 beam) 01465 { 01466 ps_latlink_t *link; 01467 int npruned = 0; 01468 01469 for (link = ps_lattice_traverse_edges(dag, dag->start, dag->end); 01470 link; link = ps_lattice_traverse_next(dag, dag->end)) { 01471 link->from->reachable = FALSE; 01472 if (link->alpha + link->beta - dag->norm < beam) { 01473 latlink_list_t *x, *tmp, *next; 01474 tmp = NULL; 01475 for (x = link->from->exits; x; x = next) { 01476 next = x->next; 01477 if (x->link == link) { 01478 listelem_free(dag->latlink_list_alloc, x); 01479 } 01480 else { 01481 x->next = tmp; 01482 tmp = x; 01483 } 01484 } 01485 link->from->exits = tmp; 01486 tmp = NULL; 01487 for (x = link->to->entries; x; x = next) { 01488 next = x->next; 01489 if (x->link == link) { 01490 listelem_free(dag->latlink_list_alloc, x); 01491 } 01492 else { 01493 x->next = tmp; 01494 tmp = x; 01495 } 01496 } 01497 link->to->entries = tmp; 01498 listelem_free(dag->latlink_alloc, link); 01499 ++npruned; 01500 } 01501 } 01502 dag_mark_reachable(dag->end); 01503 ps_lattice_delete_unreachable(dag); 01504 return npruned; 01505 } 01506 01507 01508 /* Parameters to prune n-best alternatives search */ 01509 #define MAX_PATHS 500 /* Max allowed active paths at any time */ 01510 #define MAX_HYP_TRIES 10000 01511 01512 /* 01513 * For each node in any path between from and end of utt, find the 01514 * best score from "from".sf to end of utt. (NOTE: Uses bigram probs; 01515 * this is an estimate of the best score from "from".) (NOTE #2: yes, 01516 * this is the "heuristic score" used in A* search) 01517 */ 01518 static int32 01519 best_rem_score(ps_astar_t *nbest, ps_latnode_t * from) 01520 { 01521 ps_lattice_t *dag; 01522 latlink_list_t *x; 01523 int32 bestscore, score; 01524 01525 dag = nbest->dag; 01526 if (from->info.rem_score <= 0) 01527 return (from->info.rem_score); 01528 01529 /* Best score from "from" to end of utt not known; compute from successors */ 01530 bestscore = WORST_SCORE; 01531 for (x = from->exits; x; x = x->next) { 01532 int32 n_used; 01533 01534 score = best_rem_score(nbest, x->link->to); 01535 score += x->link->ascr; 01536 if (nbest->lmset) 01537 score += (ngram_bg_score(nbest->lmset, x->link->to->basewid, 01538 from->basewid, &n_used) >> SENSCR_SHIFT) 01539 * nbest->lwf; 01540 if (score BETTER_THAN bestscore) 01541 bestscore = score; 01542 } 01543 from->info.rem_score = bestscore; 01544 01545 return bestscore; 01546 } 01547 01548 /* 01549 * Insert newpath in sorted (by path score) list of paths. But if newpath is 01550 * too far down the list, drop it (FIXME: necessary?) 01551 * total_score = path score (newpath) + rem_score to end of utt. 01552 */ 01553 static void 01554 path_insert(ps_astar_t *nbest, ps_latpath_t *newpath, int32 total_score) 01555 { 01556 ps_lattice_t *dag; 01557 ps_latpath_t *prev, *p; 01558 int32 i; 01559 01560 dag = nbest->dag; 01561 prev = NULL; 01562 for (i = 0, p = nbest->path_list; (i < MAX_PATHS) && p; p = p->next, i++) { 01563 if ((p->score + p->node->info.rem_score) < total_score) 01564 break; 01565 prev = p; 01566 } 01567 01568 /* newpath should be inserted between prev and p */ 01569 if (i < MAX_PATHS) { 01570 /* Insert new partial hyp */ 01571 newpath->next = p; 01572 if (!prev) 01573 nbest->path_list = newpath; 01574 else 01575 prev->next = newpath; 01576 if (!p) 01577 nbest->path_tail = newpath; 01578 01579 nbest->n_path++; 01580 nbest->n_hyp_insert++; 01581 nbest->insert_depth += i; 01582 } 01583 else { 01584 /* newpath score too low; reject it and also prune paths beyond MAX_PATHS */ 01585 nbest->path_tail = prev; 01586 prev->next = NULL; 01587 nbest->n_path = MAX_PATHS; 01588 listelem_free(nbest->latpath_alloc, newpath); 01589 01590 nbest->n_hyp_reject++; 01591 for (; p; p = newpath) { 01592 newpath = p->next; 01593 listelem_free(nbest->latpath_alloc, p); 01594 nbest->n_hyp_reject++; 01595 } 01596 } 01597 } 01598 01599 /* Find all possible extensions to given partial path */ 01600 static void 01601 path_extend(ps_astar_t *nbest, ps_latpath_t * path) 01602 { 01603 latlink_list_t *x; 01604 ps_latpath_t *newpath; 01605 int32 total_score, tail_score; 01606 ps_lattice_t *dag; 01607 01608 dag = nbest->dag; 01609 01610 /* Consider all successors of path->node */ 01611 for (x = path->node->exits; x; x = x->next) { 01612 int32 n_used; 01613 01614 /* Skip successor if no path from it reaches the final node */ 01615 if (x->link->to->info.rem_score <= WORST_SCORE) 01616 continue; 01617 01618 /* Create path extension and compute exact score for this extension */ 01619 newpath = listelem_malloc(nbest->latpath_alloc); 01620 newpath->node = x->link->to; 01621 newpath->parent = path; 01622 newpath->score = path->score + x->link->ascr; 01623 if (nbest->lmset) { 01624 if (path->parent) { 01625 newpath->score += nbest->lwf 01626 * (ngram_tg_score(nbest->lmset, newpath->node->basewid, 01627 path->node->basewid, 01628 path->parent->node->basewid, &n_used) 01629 >> SENSCR_SHIFT); 01630 } 01631 else 01632 newpath->score += nbest->lwf 01633 * (ngram_bg_score(nbest->lmset, newpath->node->basewid, 01634 path->node->basewid, &n_used) 01635 >> SENSCR_SHIFT); 01636 } 01637 01638 /* Insert new partial path hypothesis into sorted path_list */ 01639 nbest->n_hyp_tried++; 01640 total_score = newpath->score + newpath->node->info.rem_score; 01641 01642 /* First see if hyp would be worse than the worst */ 01643 if (nbest->n_path >= MAX_PATHS) { 01644 tail_score = 01645 nbest->path_tail->score 01646 + nbest->path_tail->node->info.rem_score; 01647 if (total_score < tail_score) { 01648 listelem_free(nbest->latpath_alloc, newpath); 01649 nbest->n_hyp_reject++; 01650 continue; 01651 } 01652 } 01653 01654 path_insert(nbest, newpath, total_score); 01655 } 01656 } 01657 01658 ps_astar_t * 01659 ps_astar_start(ps_lattice_t *dag, 01660 ngram_model_t *lmset, 01661 float32 lwf, 01662 int sf, int ef, 01663 int w1, int w2) 01664 { 01665 ps_astar_t *nbest; 01666 ps_latnode_t *node; 01667 01668 nbest = ckd_calloc(1, sizeof(*nbest)); 01669 nbest->dag = dag; 01670 nbest->lmset = lmset; 01671 nbest->lwf = lwf; 01672 nbest->sf = sf; 01673 if (ef < 0) 01674 nbest->ef = dag->n_frames + 1; 01675 else 01676 nbest->ef = ef; 01677 nbest->w1 = w1; 01678 nbest->w2 = w2; 01679 nbest->latpath_alloc = listelem_alloc_init(sizeof(ps_latpath_t)); 01680 01681 /* Initialize rem_score (A* heuristic) to default values */ 01682 for (node = dag->nodes; node; node = node->next) { 01683 if (node == dag->end) 01684 node->info.rem_score = 0; 01685 else if (node->exits == NULL) 01686 node->info.rem_score = WORST_SCORE; 01687 else 01688 node->info.rem_score = 1; /* +ve => unknown value */ 01689 } 01690 01691 /* Create initial partial hypotheses list consisting of nodes starting at sf */ 01692 nbest->path_list = nbest->path_tail = NULL; 01693 for (node = dag->nodes; node; node = node->next) { 01694 if (node->sf == sf) { 01695 ps_latpath_t *path; 01696 int32 n_used; 01697 01698 best_rem_score(nbest, node); 01699 path = listelem_malloc(nbest->latpath_alloc); 01700 path->node = node; 01701 path->parent = NULL; 01702 if (nbest->lmset) 01703 path->score = nbest->lwf * 01704 (w1 < 0) 01705 ? ngram_bg_score(nbest->lmset, node->basewid, w2, &n_used) 01706 : ngram_tg_score(nbest->lmset, node->basewid, w2, w1, &n_used); 01707 else 01708 path->score = 0; 01709 path->score >>= SENSCR_SHIFT; 01710 path_insert(nbest, path, path->score + node->info.rem_score); 01711 } 01712 } 01713 01714 return nbest; 01715 } 01716 01717 ps_latpath_t * 01718 ps_astar_next(ps_astar_t *nbest) 01719 { 01720 ps_lattice_t *dag; 01721 01722 dag = nbest->dag; 01723 01724 /* Pop the top (best) partial hypothesis */ 01725 while ((nbest->top = nbest->path_list) != NULL) { 01726 nbest->path_list = nbest->path_list->next; 01727 if (nbest->top == nbest->path_tail) 01728 nbest->path_tail = NULL; 01729 nbest->n_path--; 01730 01731 /* Complete hypothesis? */ 01732 if ((nbest->top->node->sf >= nbest->ef) 01733 || ((nbest->top->node == dag->end) && 01734 (nbest->ef > dag->end->sf))) { 01735 /* FIXME: Verify that it is non-empty. Also we may want 01736 * to verify that it is actually distinct from other 01737 * paths, since often this is not the case*/ 01738 return nbest->top; 01739 } 01740 else { 01741 if (nbest->top->node->fef < nbest->ef) 01742 path_extend(nbest, nbest->top); 01743 } 01744 } 01745 01746 /* Did not find any more paths to extend. */ 01747 return NULL; 01748 } 01749 01750 char const * 01751 ps_astar_hyp(ps_astar_t *nbest, ps_latpath_t *path) 01752 { 01753 ps_search_t *search; 01754 ps_latpath_t *p; 01755 size_t len; 01756 char *c; 01757 char *hyp; 01758 01759 search = nbest->dag->search; 01760 01761 /* Backtrace once to get hypothesis length. */ 01762 len = 0; 01763 for (p = path; p; p = p->parent) { 01764 if (dict_real_word(ps_search_dict(search), p->node->basewid)) 01765 len += strlen(dict_wordstr(ps_search_dict(search), p->node->basewid)) + 1; 01766 } 01767 01768 if (len == 0) { 01769 return NULL; 01770 } 01771 01772 /* Backtrace again to construct hypothesis string. */ 01773 hyp = ckd_calloc(1, len); 01774 c = hyp + len - 1; 01775 for (p = path; p; p = p->parent) { 01776 if (dict_real_word(ps_search_dict(search), p->node->basewid)) { 01777 len = strlen(dict_wordstr(ps_search_dict(search), p->node->basewid)); 01778 c -= len; 01779 memcpy(c, dict_wordstr(ps_search_dict(search), p->node->basewid), len); 01780 if (c > hyp) { 01781 --c; 01782 *c = ' '; 01783 } 01784 } 01785 } 01786 01787 nbest->hyps = glist_add_ptr(nbest->hyps, hyp); 01788 return hyp; 01789 } 01790 01791 static void 01792 ps_astar_node2itor(astar_seg_t *itor) 01793 { 01794 ps_seg_t *seg = (ps_seg_t *)itor; 01795 ps_latnode_t *node; 01796 01797 assert(itor->cur < itor->n_nodes); 01798 node = itor->nodes[itor->cur]; 01799 if (itor->cur == itor->n_nodes - 1) 01800 seg->ef = node->lef; 01801 else 01802 seg->ef = itor->nodes[itor->cur + 1]->sf - 1; 01803 seg->word = dict_wordstr(ps_search_dict(seg->search), node->wid); 01804 seg->sf = node->sf; 01805 seg->prob = 0; /* FIXME: implement forward-backward */ 01806 } 01807 01808 static void 01809 ps_astar_seg_free(ps_seg_t *seg) 01810 { 01811 astar_seg_t *itor = (astar_seg_t *)seg; 01812 ckd_free(itor->nodes); 01813 ckd_free(itor); 01814 } 01815 01816 static ps_seg_t * 01817 ps_astar_seg_next(ps_seg_t *seg) 01818 { 01819 astar_seg_t *itor = (astar_seg_t *)seg; 01820 01821 ++itor->cur; 01822 if (itor->cur == itor->n_nodes) { 01823 ps_astar_seg_free(seg); 01824 return NULL; 01825 } 01826 else { 01827 ps_astar_node2itor(itor); 01828 } 01829 01830 return seg; 01831 } 01832 01833 static ps_segfuncs_t ps_astar_segfuncs = { 01834 /* seg_next */ ps_astar_seg_next, 01835 /* seg_free */ ps_astar_seg_free 01836 }; 01837 01838 ps_seg_t * 01839 ps_astar_seg_iter(ps_astar_t *astar, ps_latpath_t *path, float32 lwf) 01840 { 01841 astar_seg_t *itor; 01842 ps_latpath_t *p; 01843 int cur; 01844 01845 /* Backtrace and make an iterator, this should look familiar by now. */ 01846 itor = ckd_calloc(1, sizeof(*itor)); 01847 itor->base.vt = &ps_astar_segfuncs; 01848 itor->base.search = astar->dag->search; 01849 itor->base.lwf = lwf; 01850 itor->n_nodes = itor->cur = 0; 01851 for (p = path; p; p = p->parent) { 01852 ++itor->n_nodes; 01853 } 01854 itor->nodes = ckd_calloc(itor->n_nodes, sizeof(*itor->nodes)); 01855 cur = itor->n_nodes - 1; 01856 for (p = path; p; p = p->parent) { 01857 itor->nodes[cur] = p->node; 01858 --cur; 01859 } 01860 01861 ps_astar_node2itor(itor); 01862 return (ps_seg_t *)itor; 01863 } 01864 01865 void 01866 ps_astar_finish(ps_astar_t *nbest) 01867 { 01868 gnode_t *gn; 01869 01870 /* Free all hyps. */ 01871 for (gn = nbest->hyps; gn; gn = gnode_next(gn)) { 01872 ckd_free(gnode_ptr(gn)); 01873 } 01874 glist_free(nbest->hyps); 01875 /* Free all paths. */ 01876 listelem_alloc_free(nbest->latpath_alloc); 01877 /* Free the Henge. */ 01878 ckd_free(nbest); 01879 }