• Main Page
  • Related Pages
  • Data Structures
  • Files
  • File List
  • Globals

src/libpocketsphinx/ps_lattice.c

Go to the documentation of this file.
00001 /* -*- c-basic-offset: 4; indent-tabs-mode: nil -*- */
00002 /* ====================================================================
00003  * Copyright (c) 2008 Carnegie Mellon University.  All rights
00004  * reserved.
00005  *
00006  * Redistribution and use in source and binary forms, with or without
00007  * modification, are permitted provided that the following conditions
00008  * are met:
00009  *
00010  * 1. Redistributions of source code must retain the above copyright
00011  *    notice, this list of conditions and the following disclaimer. 
00012  *
00013  * 2. Redistributions in binary form must reproduce the above copyright
00014  *    notice, this list of conditions and the following disclaimer in
00015  *    the documentation and/or other materials provided with the
00016  *    distribution.
00017  *
00018  * This work was supported in part by funding from the Defense Advanced 
00019  * Research Projects Agency and the National Science Foundation of the 
00020  * United States of America, and the CMU Sphinx Speech Consortium.
00021  *
00022  * THIS SOFTWARE IS PROVIDED BY CARNEGIE MELLON UNIVERSITY ``AS IS'' AND 
00023  * ANY EXPRESSED OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, 
00024  * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00025  * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL CARNEGIE MELLON UNIVERSITY
00026  * NOR ITS EMPLOYEES BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
00027  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 
00028  * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 
00029  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 
00030  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 
00031  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 
00032  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00033  *
00034  * ====================================================================
00035  *
00036  */
00037 
00042 /* System headers. */
00043 #include <assert.h>
00044 #include <string.h>
00045 #include <math.h>
00046 
00047 /* SphinxBase headers. */
00048 #include <ckd_alloc.h>
00049 #include <listelem_alloc.h>
00050 #include <strfuncs.h>
00051 #include <pio.h>
00052 
00053 /* Local headers. */
00054 #include "pocketsphinx_internal.h"
00055 #include "ps_lattice_internal.h"
00056 #include "ngram_search.h"
00057 
00058 /*
00059  * Create a directed link between "from" and "to" nodes, but if a link already exists,
00060  * choose one with the best ascr.
00061  */
00062 void
00063 ps_lattice_link(ps_lattice_t *dag, ps_latnode_t *from, ps_latnode_t *to, int32 score, int32 ef)
00064 {
00065     latlink_list_t *fwdlink;
00066 
00067     /* Look for an existing link between "from" and "to" nodes */
00068     for (fwdlink = from->exits; fwdlink; fwdlink = fwdlink->next)
00069         if (fwdlink->link->to == to)
00070             break;
00071 
00072     if (fwdlink == NULL) {
00073         latlink_list_t *revlink;
00074         ps_latlink_t *link;
00075 
00076         /* No link between the two nodes; create a new one */
00077         link = listelem_malloc(dag->latlink_alloc);
00078         fwdlink = listelem_malloc(dag->latlink_list_alloc);
00079         revlink = listelem_malloc(dag->latlink_list_alloc);
00080 
00081         link->from = from;
00082         link->to = to;
00083         link->ascr = score;
00084         link->ef = ef;
00085         link->best_prev = NULL;
00086 
00087         fwdlink->link = revlink->link = link;
00088         fwdlink->next = from->exits;
00089         from->exits = fwdlink;
00090         revlink->next = to->entries;
00091         to->entries = revlink;
00092     }
00093     else {
00094         /* Link already exists; just retain the best ascr */
00095         if (fwdlink->link->ascr < score) {
00096             fwdlink->link->ascr = score;
00097             fwdlink->link->ef = ef;
00098         }
00099     }
00100 }
00101 
00102 void
00103 ps_lattice_bypass_fillers(ps_lattice_t *dag, int32 silpen, int32 fillpen)
00104 {
00105     ps_latnode_t *node;
00106     int32 score;
00107 
00108     /* Bypass filler nodes */
00109     for (node = dag->nodes; node; node = node->next) {
00110         latlink_list_t *revlink;
00111         if (node == dag->end || !ISA_FILLER_WORD(dag->search, node->basewid))
00112             continue;
00113 
00114         /* Replace each link entering filler node with links to all its successors */
00115         for (revlink = node->entries; revlink; revlink = revlink->next) {
00116             latlink_list_t *forlink;
00117             ps_latlink_t *rlink = revlink->link;
00118 
00119             score = (node->basewid == ps_search_silence_wid(dag->search)) ? silpen : fillpen;
00120             score += rlink->ascr;
00121             /*
00122              * Make links from predecessor of filler (from) to successors of filler.
00123              * But if successor is a filler, it has already been eliminated since it
00124              * appears earlier in latnode_list (see build...).  So it can be skipped.
00125              */
00126             for (forlink = node->exits; forlink; forlink = forlink->next) {
00127                 ps_latlink_t *flink = forlink->link;
00128                 if (!ISA_FILLER_WORD(dag->search, flink->to->basewid)) {
00129                     ps_lattice_link(dag, rlink->from, flink->to,
00130                                     score + flink->ascr, flink->ef);
00131                 }
00132             }
00133         }
00134     }
00135 }
00136 
00137 static void
00138 delete_node(ps_lattice_t *dag, ps_latnode_t *node)
00139 {
00140     latlink_list_t *x, *next_x;
00141 
00142     for (x = node->exits; x; x = next_x) {
00143         next_x = x->next;
00144         x->link->from = NULL;
00145         listelem_free(dag->latlink_list_alloc, x);
00146     }
00147     for (x = node->entries; x; x = next_x) {
00148         next_x = x->next;
00149         x->link->to = NULL;
00150         listelem_free(dag->latlink_list_alloc, x);
00151     }
00152     listelem_free(dag->latnode_alloc, node);
00153 }
00154 
00155 static void
00156 remove_dangling_links(ps_lattice_t *dag, ps_latnode_t *node)
00157 {
00158     latlink_list_t *x, *prev_x, *next_x;
00159 
00160     prev_x = NULL;
00161     for (x = node->exits; x; x = next_x) {
00162         next_x = x->next;
00163         if (x->link->to == NULL) {
00164             if (prev_x)
00165                 prev_x->next = next_x;
00166             else
00167                 node->exits = next_x;
00168             listelem_free(dag->latlink_alloc, x->link);
00169             listelem_free(dag->latlink_list_alloc, x);
00170         }
00171         else
00172             prev_x = x;
00173     }
00174     prev_x = NULL;
00175     for (x = node->entries; x; x = next_x) {
00176         next_x = x->next;
00177         if (x->link->from == NULL) {
00178             if (prev_x)
00179                 prev_x->next = next_x;
00180             else
00181                 node->exits = next_x;
00182             listelem_free(dag->latlink_alloc, x->link);
00183             listelem_free(dag->latlink_list_alloc, x);
00184         }
00185         else
00186             prev_x = x;
00187     }
00188 }
00189 
00190 void
00191 ps_lattice_delete_unreachable(ps_lattice_t *dag)
00192 {
00193     ps_latnode_t *node, *prev_node, *next_node;
00194     int i;
00195 
00196     /* Remove unreachable nodes from the list of nodes. */
00197     prev_node = NULL;
00198     for (node = dag->nodes; node; node = next_node) {
00199         next_node = node->next;
00200         if (!node->reachable) {
00201             if (prev_node)
00202                 prev_node->next = next_node;
00203             else
00204                 dag->nodes = next_node;
00205             /* Delete this node and NULLify links to it. */
00206             delete_node(dag, node);
00207         }
00208         else
00209             prev_node = node;
00210     }
00211 
00212     /* Remove all links to and from unreachable nodes. */
00213     i = 0;
00214     for (node = dag->nodes; node; node = node->next) {
00215         /* Assign sequence numbers. */
00216         node->id = i++;
00217 
00218         /* We should obviously not encounter unreachable nodes here! */
00219         assert(node->reachable);
00220 
00221         /* Remove all links that go nowhere. */
00222         remove_dangling_links(dag, node);
00223     }
00224 }
00225 
00226 int32
00227 ps_lattice_write(ps_lattice_t *dag, char const *filename)
00228 {
00229     FILE *fp;
00230     int32 i;
00231     ps_latnode_t *d, *initial, *final;
00232 
00233     initial = dag->start;
00234     final = dag->end;
00235 
00236     E_INFO("Writing lattice file: %s\n", filename);
00237     if ((fp = fopen(filename, "w")) == NULL) {
00238         E_ERROR("fopen(%s,w) failed\n", filename);
00239         return -1;
00240     }
00241 
00242     /* Stupid Sphinx-III lattice code expects 'getcwd:' here */
00243     fprintf(fp, "# getcwd: /this/is/bogus\n");
00244     fprintf(fp, "# -logbase %e\n", logmath_get_base(dag->lmath));
00245     fprintf(fp, "#\n");
00246 
00247     fprintf(fp, "Frames %d\n", dag->n_frames);
00248     fprintf(fp, "#\n");
00249 
00250     for (i = 0, d = dag->nodes; d; d = d->next, i++);
00251     fprintf(fp,
00252             "Nodes %d (NODEID WORD STARTFRAME FIRST-ENDFRAME LAST-ENDFRAME)\n",
00253             i);
00254     for (i = 0, d = dag->nodes; d; d = d->next, i++) {
00255         d->id = i;
00256         fprintf(fp, "%d %s %d %d %d\n",
00257                 i, dict_word_str(ps_search_dict(dag->search), d->wid),
00258                 d->sf, d->fef, d->lef);
00259     }
00260     fprintf(fp, "#\n");
00261 
00262     fprintf(fp, "Initial %d\nFinal %d\n", initial->id, final->id);
00263     fprintf(fp, "#\n");
00264 
00265     /* Don't bother with this, it's not used by anything. */
00266     fprintf(fp, "BestSegAscr %d (NODEID ENDFRAME ASCORE)\n",
00267             0 /* #BPTable entries */ );
00268     fprintf(fp, "#\n");
00269 
00270     fprintf(fp, "Edges (FROM-NODEID TO-NODEID ASCORE)\n");
00271     for (d = dag->nodes; d; d = d->next) {
00272         latlink_list_t *l;
00273         for (l = d->exits; l; l = l->next)
00274             fprintf(fp, "%d %d %d\n",
00275                     d->id, l->link->to->id, l->link->ascr);
00276     }
00277     fprintf(fp, "End\n");
00278     fclose(fp);
00279 
00280     return 0;
00281 }
00282 
00283 /* Read parameter from a lattice file*/
00284 static int
00285 dag_param_read(lineiter_t *li, char *param)
00286 {
00287     int32 n;
00288 
00289     while ((li = lineiter_next(li)) != NULL) {
00290         char *c;
00291 
00292         /* Ignore comments. */
00293         if (li->buf[0] == '#')
00294             continue;
00295 
00296         /* Find the first space. */
00297         c = strchr(li->buf, ' ');
00298         if (c == NULL) continue;
00299 
00300         /* Check that the first field equals param and that there's a number after it. */
00301         if (strncmp(li->buf, param, strlen(param)) == 0
00302             && sscanf(c + 1, "%d", &n) == 1)
00303             return n;
00304     }
00305     return -1;
00306 }
00307 
00308 /* Mark every node that has a path to the argument dagnode as "reachable". */
00309 static void
00310 dag_mark_reachable(ps_latnode_t * d)
00311 {
00312     latlink_list_t *l;
00313 
00314     d->reachable = 1;
00315     for (l = d->entries; l; l = l->next)
00316         if (l->link->from && !l->link->from->reachable)
00317             dag_mark_reachable(l->link->from);
00318 }
00319 
00320 ps_lattice_t *
00321 ps_lattice_read(ps_decoder_t *ps,
00322                 char const *file)
00323 {
00324     FILE *fp;
00325     int32 ispipe;
00326     lineiter_t *line;
00327     float64 lb;
00328     float32 logratio;
00329     ps_latnode_t *tail;
00330     ps_latnode_t **darray;
00331     ps_lattice_t *dag;
00332     int i, k, n_nodes;
00333     int32 pip, silpen, fillpen;
00334 
00335     dag = ckd_calloc(1, sizeof(*dag));
00336     dag->search = ps->search;
00337     dag->lmath = logmath_retain(ps->lmath);
00338     dag->latnode_alloc = listelem_alloc_init(sizeof(ps_latnode_t));
00339     dag->latlink_alloc = listelem_alloc_init(sizeof(ps_latlink_t));
00340     dag->latlink_list_alloc = listelem_alloc_init(sizeof(latlink_list_t));
00341     dag->refcount = 1;
00342 
00343     tail = NULL;
00344     darray = NULL;
00345 
00346     E_INFO("Reading DAG file: %s\n", file);
00347     if ((fp = fopen_compchk(file, &ispipe)) == NULL) {
00348         E_ERROR("fopen_compchk(%s) failed\n", file);
00349         return NULL;
00350     }
00351     line = lineiter_start(fp);
00352 
00353     /* Read and verify logbase (ONE BIG HACK!!) */
00354     if (line == NULL) {
00355         E_ERROR("Premature EOF(%s)\n", file);
00356         goto load_error;
00357     }
00358     if (strncmp(line->buf, "# getcwd: ", 10) != 0) {
00359         E_ERROR("%s does not begin with '# getcwd: '\n%s", file, line->buf);
00360         goto load_error;
00361     }
00362     if ((line = lineiter_next(line)) == NULL) {
00363         E_ERROR("Premature EOF(%s)\n", file);
00364         goto load_error;
00365     }
00366     if ((strncmp(line->buf, "# -logbase ", 11) != 0)
00367         || (sscanf(line->buf + 11, "%lf", &lb) != 1)) {
00368         E_WARN("%s: Cannot find -logbase in header\n", file);
00369         lb = 1.0001;
00370     }
00371     logratio = 1.0f;
00372     if (dag->lmath == NULL)
00373         dag->lmath = logmath_init(lb, 0, TRUE);
00374     else {
00375         float32 pb = logmath_get_base(dag->lmath);
00376         if (fabs(lb - pb) >= 0.0001) {
00377             E_WARN("Inconsistent logbases: %f vs %f: will compensate\n", lb, pb);
00378             logratio = (float32)(log(lb) / log(pb));
00379             E_INFO("Lattice log ratio: %f\n", logratio);
00380         }
00381     }
00382     /* Read Frames parameter */
00383     dag->n_frames = dag_param_read(line, "Frames");
00384     if (dag->n_frames <= 0) {
00385         E_ERROR("Frames parameter missing or invalid\n");
00386         goto load_error;
00387     }
00388     /* Read Nodes parameter */
00389     n_nodes = dag_param_read(line, "Nodes");
00390     if (n_nodes <= 0) {
00391         E_ERROR("Nodes parameter missing or invalid\n");
00392         goto load_error;
00393     }
00394 
00395     /* Read nodes */
00396     darray = ckd_calloc(n_nodes, sizeof(*darray));
00397     for (i = 0; i < n_nodes; i++) {
00398         ps_latnode_t *d;
00399         int32 w;
00400         int seqid, sf, fef, lef;
00401         char wd[256];
00402 
00403         if ((line = lineiter_next(line)) == NULL) {
00404             E_ERROR("Premature EOF while loading Nodes(%s)\n", file);
00405             goto load_error;
00406         }
00407 
00408         if ((k =
00409              sscanf(line->buf, "%d %255s %d %d %d", &seqid, wd, &sf, &fef,
00410                     &lef)) != 5) {
00411             E_ERROR("Cannot parse line: %s, value of count %d\n", line->buf, k);
00412             goto load_error;
00413         }
00414 
00415         w = dict_to_id(ps->dict, wd);
00416         if (w < 0) {
00417             E_ERROR("Unknown word in line: %s\n", line->buf);
00418             goto load_error;
00419         }
00420 
00421         if (seqid != i) {
00422             E_ERROR("Seqno error: %s\n", line->buf);
00423             goto load_error;
00424         }
00425 
00426         d = listelem_malloc(dag->latnode_alloc);
00427         darray[i] = d;
00428         d->wid = w;
00429         d->basewid = dict_base_wid(ps->dict, w);
00430         d->id = seqid;
00431         d->sf = sf;
00432         d->fef = fef;
00433         d->lef = lef;
00434         d->reachable = 0;
00435         d->exits = d->entries = NULL;
00436         d->next = NULL;
00437 
00438         if (!dag->nodes)
00439             dag->nodes = d;
00440         else
00441             tail->next = d;
00442         tail = d;
00443     }
00444 
00445     /* Read initial node ID */
00446     k = dag_param_read(line, "Initial");
00447     if ((k < 0) || (k >= n_nodes)) {
00448         E_ERROR("Initial node parameter missing or invalid\n");
00449         goto load_error;
00450     }
00451     dag->start = darray[k];
00452 
00453     /* Read final node ID */
00454     k = dag_param_read(line, "Final");
00455     if ((k < 0) || (k >= n_nodes)) {
00456         E_ERROR("Final node parameter missing or invalid\n");
00457         goto load_error;
00458     }
00459     dag->end = darray[k];
00460 
00461     /* Read bestsegscore entries and ignore them. */
00462     if ((k = dag_param_read(line, "BestSegAscr")) < 0) {
00463         E_ERROR("BestSegAscr parameter missing\n");
00464         goto load_error;
00465     }
00466     for (i = 0; i < k; i++) {
00467         if ((line = lineiter_next(line)) == NULL) {
00468             E_ERROR("Premature EOF while (%s) ignoring BestSegAscr\n",
00469                     line);
00470             goto load_error;
00471         }
00472     }
00473 
00474     /* Read in edges. */
00475     while ((line = lineiter_next(line)) != NULL) {
00476         if (line->buf[0] == '#')
00477             continue;
00478         if (0 == strncmp(line->buf, "Edges", 5))
00479             break;
00480     }
00481     if (line == NULL) {
00482         E_ERROR("Edges missing\n");
00483         goto load_error;
00484     }
00485     while ((line = lineiter_next(line)) != NULL) {
00486         int from, to, ascr;
00487         ps_latnode_t *pd, *d;
00488 
00489         if (sscanf(line->buf, "%d %d %d", &from, &to, &ascr) != 3)
00490             break;
00491         pd = darray[from];
00492         d = darray[to];
00493         if (logratio != 1.0f)
00494             ascr = (int32)(ascr * logratio);
00495         ps_lattice_link(dag, pd, d, ascr, d->sf - 1);
00496     }
00497     if (strcmp(line->buf, "End\n") != 0) {
00498         E_ERROR("Terminating 'End' missing\n");
00499         goto load_error;
00500     }
00501     lineiter_free(line);
00502     fclose_comp(fp, ispipe);
00503     ckd_free(darray);
00504 
00505     /* Minor hack: If the final node is a filler word and not </s>,
00506      * then set its base word ID to </s>, so that the language model
00507      * scores won't be screwed up. */
00508     if (ISA_FILLER_WORD(dag->search, dag->end->wid))
00509         dag->end->basewid = ps_search_finish_wid(dag->search);
00510 
00511     /* Mark reachable from dag->end */
00512     dag_mark_reachable(dag->end);
00513 
00514     /* Free nodes unreachable from dag->end and their links */
00515     ps_lattice_delete_unreachable(dag);
00516 
00517     /* Build links around silence and filler words, since they do not
00518      * exist in the language model. */
00519     pip = logmath_log(dag->lmath, cmd_ln_float32_r(ps->config, "-pip"));
00520     silpen = pip + logmath_log(dag->lmath,
00521                                cmd_ln_float32_r(ps->config, "-silprob"));
00522     fillpen = pip + logmath_log(dag->lmath,
00523                                 cmd_ln_float32_r(ps->config, "-fillprob"));
00524     ps_lattice_bypass_fillers(dag, silpen, fillpen);
00525 
00526     return dag;
00527 
00528   load_error:
00529     E_ERROR("Failed to load %s\n", file);
00530     lineiter_free(line);
00531     if (fp) fclose_comp(fp, ispipe);
00532     ckd_free(darray);
00533     return NULL;
00534 }
00535 
00536 int
00537 ps_lattice_n_frames(ps_lattice_t *dag)
00538 {
00539     return dag->n_frames;
00540 }
00541 
00542 ps_lattice_t *
00543 ps_lattice_init_search(ps_search_t *search, int n_frame)
00544 {
00545     ps_lattice_t *dag;
00546 
00547     dag = ckd_calloc(1, sizeof(*dag));
00548     dag->search = search;
00549     dag->lmath = logmath_retain(search->acmod->lmath);
00550     dag->n_frames = n_frame;
00551     dag->latnode_alloc = listelem_alloc_init(sizeof(ps_latnode_t));
00552     dag->latlink_alloc = listelem_alloc_init(sizeof(ps_latlink_t));
00553     dag->latlink_list_alloc = listelem_alloc_init(sizeof(latlink_list_t));
00554     dag->refcount = 1;
00555     return dag;
00556 }
00557 
00558 ps_lattice_t *
00559 ps_lattice_retain(ps_lattice_t *dag)
00560 {
00561     ++dag->refcount;
00562     return dag;
00563 }
00564 
00565 int
00566 ps_lattice_free(ps_lattice_t *dag)
00567 {
00568     if (dag == NULL)
00569         return 0;
00570     if (--dag->refcount > 0)
00571         return dag->refcount;
00572     logmath_free(dag->lmath);
00573     listelem_alloc_free(dag->latnode_alloc);
00574     listelem_alloc_free(dag->latlink_alloc);
00575     listelem_alloc_free(dag->latlink_list_alloc);
00576     ckd_free(dag->hyp_str);
00577     ckd_free(dag);
00578     return 0;
00579 }
00580 
00581 logmath_t *
00582 ps_lattice_get_logmath(ps_lattice_t *dag)
00583 {
00584     return dag->lmath;
00585 }
00586 
00587 ps_latnode_iter_t *
00588 ps_latnode_iter(ps_lattice_t *dag)
00589 {
00590     return dag->nodes;
00591 }
00592 
00593 ps_latnode_iter_t *
00594 ps_latnode_iter_next(ps_latnode_iter_t *itor)
00595 {
00596     return itor->next;
00597 }
00598 
00599 void
00600 ps_latnode_iter_free(ps_latnode_iter_t *itor)
00601 {
00602     /* Do absolutely nothing. */
00603 }
00604 
00605 ps_latnode_t *
00606 ps_latnode_iter_node(ps_latnode_iter_t *itor)
00607 {
00608     return itor;
00609 }
00610 
00611 int
00612 ps_latnode_times(ps_latnode_t *node, int16 *out_fef, int16 *out_lef)
00613 {
00614     if (out_fef) *out_fef = (int16)node->fef;
00615     if (out_lef) *out_lef = (int16)node->lef;
00616     return node->sf;
00617 }
00618 
00619 char const *
00620 ps_latnode_word(ps_lattice_t *dag, ps_latnode_t *node)
00621 {
00622     return dict_word_str(ps_search_dict(dag->search), node->wid);
00623 }
00624 
00625 char const *
00626 ps_latnode_baseword(ps_lattice_t *dag, ps_latnode_t *node)
00627 {
00628     return dict_word_str(ps_search_dict(dag->search), node->basewid);
00629 }
00630 
00631 int32
00632 ps_latnode_prob(ps_lattice_t *dag, ps_latnode_t *node,
00633                 ps_latlink_t **out_link)
00634 {
00635     latlink_list_t *links;
00636     int32 bestpost = logmath_get_zero(dag->lmath);
00637 
00638     for (links = node->exits; links; links = links->next) {
00639         int32 post = links->link->alpha + links->link->beta - dag->norm;
00640         if (post > bestpost) {
00641             if (out_link) *out_link = links->link;
00642             bestpost = post;
00643         }
00644     }
00645     return bestpost;
00646 }
00647 
00648 ps_latlink_iter_t *
00649 ps_latnode_exits(ps_latnode_t *node)
00650 {
00651     return node->exits;
00652 }
00653 
00654 ps_latlink_iter_t *
00655 ps_latnode_entries(ps_latnode_t *node)
00656 {
00657     return node->entries;
00658 }
00659 
00660 ps_latlink_iter_t *
00661 ps_latlink_iter_next(ps_latlink_iter_t *itor)
00662 {
00663     return itor->next;
00664 }
00665 
00666 void
00667 ps_latlink_iter_free(ps_latlink_iter_t *itor)
00668 {
00669     /* Do absolutely nothing. */
00670 }
00671 
00672 ps_latlink_t *
00673 ps_latlink_iter_link(ps_latlink_iter_t *itor)
00674 {
00675     return itor->link;
00676 }
00677 
00678 int
00679 ps_latlink_times(ps_latlink_t *link, int16 *out_sf)
00680 {
00681     if (out_sf) {
00682         if (link->from) {
00683             *out_sf = link->from->sf;
00684         }
00685         else {
00686             *out_sf = 0;
00687         }
00688     }
00689     return link->ef;
00690 }
00691 
00692 ps_latnode_t *
00693 ps_latlink_nodes(ps_latlink_t *link, ps_latnode_t **out_src)
00694 {
00695     if (out_src) *out_src = link->from;
00696     return link->to;
00697 }
00698 
00699 char const *
00700 ps_latlink_word(ps_lattice_t *dag, ps_latlink_t *link)
00701 {
00702     if (link->from == NULL)
00703         return NULL;
00704     return dict_word_str(ps_search_dict(dag->search), link->from->wid);
00705 }
00706 
00707 char const *
00708 ps_latlink_baseword(ps_lattice_t *dag, ps_latlink_t *link)
00709 {
00710     if (link->from == NULL)
00711         return NULL;
00712     return dict_word_str(ps_search_dict(dag->search), link->from->basewid);
00713 }
00714 
00715 ps_latlink_t *
00716 ps_latlink_pred(ps_latlink_t *link)
00717 {
00718     return link->best_prev;
00719 }
00720 
00721 int32
00722 ps_latlink_prob(ps_lattice_t *dag, ps_latlink_t *link, int32 *out_ascr)
00723 {
00724     int32 post = link->alpha + link->beta - dag->norm;
00725     if (out_ascr) *out_ascr = link->ascr;
00726     return post;
00727 }
00728 
00729 char const *
00730 ps_lattice_hyp(ps_lattice_t *dag, ps_latlink_t *link)
00731 {
00732     ps_latlink_t *l;
00733     size_t len;
00734     char *c;
00735 
00736     /* Backtrace once to get hypothesis length. */
00737     len = 0;
00738     if (ISA_REAL_WORD(dag->search, link->to->basewid))
00739         len += strlen(dict_word_str(ps_search_dict(dag->search), link->to->basewid)) + 1;
00740     for (l = link; l; l = l->best_prev) {
00741         if (ISA_REAL_WORD(dag->search, l->from->basewid))
00742             len += strlen(dict_word_str(ps_search_dict(dag->search), l->from->basewid)) + 1;
00743     }
00744 
00745     /* Backtrace again to construct hypothesis string. */
00746     ckd_free(dag->hyp_str);
00747     dag->hyp_str = ckd_calloc(1, len);
00748     c = dag->hyp_str + len - 1;
00749     if (ISA_REAL_WORD(dag->search, link->to->basewid)) {
00750         len = strlen(dict_word_str(ps_search_dict(dag->search), link->to->basewid));
00751         c -= len;
00752         memcpy(c, dict_word_str(ps_search_dict(dag->search), link->to->basewid), len);
00753         if (c > dag->hyp_str) {
00754             --c;
00755             *c = ' ';
00756         }
00757     }
00758     for (l = link; l; l = l->best_prev) {
00759         if (ISA_REAL_WORD(dag->search, l->from->basewid)) {
00760             len = strlen(dict_word_str(ps_search_dict(dag->search), l->from->basewid));
00761             c -= len;
00762             memcpy(c, dict_word_str(ps_search_dict(dag->search), l->from->basewid), len);
00763             if (c > dag->hyp_str) {
00764                 --c;
00765                 *c = ' ';
00766             }
00767         }
00768     }
00769 
00770     return dag->hyp_str;
00771 }
00772 
00773 static void
00774 ps_lattice_compute_lscr(ps_seg_t *seg, ps_latlink_t *link, int to)
00775 {
00776     ngram_model_t *lmset;
00777 
00778     /* Language model score is included in the link score for FSG
00779      * search.  FIXME: Of course, this is sort of a hack :( */
00780     if (0 != strcmp(ps_search_name(seg->search), "ngram")) {
00781         seg->lback = 1; /* Unigram... */
00782         seg->lscr = 0;
00783         return;
00784     }
00785         
00786     lmset = ((ngram_search_t *)seg->search)->lmset;
00787 
00788     if (link->best_prev == NULL) {
00789         if (to) /* Sentence has only two words. */
00790             seg->lscr = ngram_bg_score(lmset, link->to->basewid,
00791                                        link->from->basewid, &seg->lback);
00792         else {/* This is the start symbol, its lscr is always 0. */
00793             seg->lscr = 0;
00794             seg->lback = 1;
00795         }
00796     }
00797     else {
00798         /* Find the two predecessor words. */
00799         if (to) {
00800             seg->lscr = ngram_tg_score(lmset, link->to->basewid,
00801                                        link->from->basewid,
00802                                        link->best_prev->from->basewid,
00803                                        &seg->lback);
00804         }
00805         else {
00806             if (link->best_prev->best_prev)
00807                 seg->lscr = ngram_tg_score(lmset, link->from->basewid,
00808                                            link->best_prev->from->basewid,
00809                                            link->best_prev->best_prev->from->basewid,
00810                                            &seg->lback);
00811             else
00812                 seg->lscr = ngram_bg_score(lmset, link->from->basewid,
00813                                            link->best_prev->from->basewid,
00814                                            &seg->lback);
00815         }
00816     }
00817 }
00818 
00819 static void
00820 ps_lattice_link2itor(ps_seg_t *seg, ps_latlink_t *link, int to)
00821 {
00822     dag_seg_t *itor = (dag_seg_t *)seg;
00823     ps_latnode_t *node;
00824 
00825     if (to) {
00826         node = link->to;
00827         seg->ef = node->lef;
00828         seg->prob = 0; /* norm + beta - norm */
00829     }
00830     else {
00831         latlink_list_t *x;
00832         ps_latnode_t *n;
00833         logmath_t *lmath = ps_search_acmod(seg->search)->lmath;
00834 
00835         node = link->from;
00836         seg->ef = link->ef;
00837         seg->prob = link->alpha + link->beta - itor->norm;
00838         /* Sum over all exits for this word and any alternate
00839            pronunciations at the same frame. */
00840         for (n = node; n; n = n->alt) {
00841             for (x = n->exits; x; x = x->next) {
00842                 if (x->link == link)
00843                     continue;
00844                 seg->prob = logmath_add(lmath, seg->prob,
00845                                         x->link->alpha + x->link->beta - itor->norm);
00846             }
00847         }
00848     }
00849     seg->word = dict_word_str(ps_search_dict(seg->search), node->wid);
00850     seg->sf = node->sf;
00851     seg->ascr = link->ascr;
00852     /* Compute language model score from best predecessors. */
00853     ps_lattice_compute_lscr(seg, link, to);
00854 }
00855 
00856 static void
00857 ps_lattice_seg_free(ps_seg_t *seg)
00858 {
00859     dag_seg_t *itor = (dag_seg_t *)seg;
00860     
00861     ckd_free(itor->links);
00862     ckd_free(itor);
00863 }
00864 
00865 static ps_seg_t *
00866 ps_lattice_seg_next(ps_seg_t *seg)
00867 {
00868     dag_seg_t *itor = (dag_seg_t *)seg;
00869 
00870     ++itor->cur;
00871     if (itor->cur == itor->n_links + 1) {
00872         ps_lattice_seg_free(seg);
00873         return NULL;
00874     }
00875     else if (itor->cur == itor->n_links) {
00876         /* Re-use the last link but with the "to" node. */
00877         ps_lattice_link2itor(seg, itor->links[itor->cur - 1], TRUE);
00878     }
00879     else {
00880         ps_lattice_link2itor(seg, itor->links[itor->cur], FALSE);
00881     }
00882 
00883     return seg;
00884 }
00885 
/* Segment-iterator callbacks for lattice backtraces.  The initializer is
 * positional, so the entries must stay in the order in which
 * ps_segfuncs_t declares its members. */
static ps_segfuncs_t ps_lattice_segfuncs = {
    /* seg_next */ ps_lattice_seg_next,
    /* seg_free */ ps_lattice_seg_free
};
00890 
00891 ps_seg_t *
00892 ps_lattice_seg_iter(ps_lattice_t *dag, ps_latlink_t *link, float32 lwf)
00893 {
00894     dag_seg_t *itor;
00895     ps_latlink_t *l;
00896     int cur;
00897 
00898     /* Calling this an "iterator" is a bit of a misnomer since we have
00899      * to get the entire backtrace in order to produce it.
00900      */
00901     itor = ckd_calloc(1, sizeof(*itor));
00902     itor->base.vt = &ps_lattice_segfuncs;
00903     itor->base.search = dag->search;
00904     itor->base.lwf = lwf;
00905     itor->n_links = 0;
00906     itor->norm = dag->norm;
00907 
00908     for (l = link; l; l = l->best_prev) {
00909         ++itor->n_links;
00910     }
00911     if (itor->n_links == 0) {
00912         ckd_free(itor);
00913         return NULL;
00914     }
00915 
00916     itor->links = ckd_calloc(itor->n_links, sizeof(*itor->links));
00917     cur = itor->n_links - 1;
00918     for (l = link; l; l = l->best_prev) {
00919         itor->links[cur] = l;
00920         --cur;
00921     }
00922 
00923     ps_lattice_link2itor((ps_seg_t *)itor, itor->links[0], FALSE);
00924     return (ps_seg_t *)itor;
00925 }
00926 
00927 latlink_list_t *
00928 latlink_list_new(ps_lattice_t *dag, ps_latlink_t *link, latlink_list_t *next)
00929 {
00930     latlink_list_t *ll;
00931 
00932     ll = listelem_malloc(dag->latlink_list_alloc);
00933     ll->link = link;
00934     ll->next = next;
00935 
00936     return ll;
00937 }
00938 
00939 void
00940 ps_lattice_pushq(ps_lattice_t *dag, ps_latlink_t *link)
00941 {
00942     if (dag->q_head == NULL)
00943         dag->q_head = dag->q_tail = latlink_list_new(dag, link, NULL);
00944     else {
00945         dag->q_tail->next = latlink_list_new(dag, link, NULL);
00946         dag->q_tail = dag->q_tail->next;
00947     }
00948 
00949 }
00950 
00951 ps_latlink_t *
00952 ps_lattice_popq(ps_lattice_t *dag)
00953 {
00954     latlink_list_t *x;
00955     ps_latlink_t *link;
00956 
00957     if (dag->q_head == NULL)
00958         return NULL;
00959     link = dag->q_head->link;
00960     x = dag->q_head->next;
00961     listelem_free(dag->latlink_list_alloc, dag->q_head);
00962     dag->q_head = x;
00963     if (dag->q_head == NULL)
00964         dag->q_tail = NULL;
00965     return link;
00966 }
00967 
00968 void
00969 ps_lattice_delq(ps_lattice_t *dag)
00970 {
00971     while (ps_lattice_popq(dag)) {
00972         /* Do nothing. */
00973     }
00974 }
00975 
00976 ps_latlink_t *
00977 ps_lattice_traverse_edges(ps_lattice_t *dag, ps_latnode_t *start, ps_latnode_t *end)
00978 {
00979     ps_latnode_t *node;
00980     latlink_list_t *x;
00981 
00982     /* Cancel any unfinished traversal. */
00983     ps_lattice_delq(dag);
00984 
00985     /* Initialize node fanin counts and path scores. */
00986     for (node = dag->nodes; node; node = node->next)
00987         node->info.fanin = 0;
00988     for (node = dag->nodes; node; node = node->next) {
00989         for (x = node->exits; x; x = x->next)
00990             (x->link->to->info.fanin)++;
00991     }
00992 
00993     /* Initialize agenda with all exits from start. */
00994     if (start == NULL) start = dag->start;
00995     for (x = start->exits; x; x = x->next)
00996         ps_lattice_pushq(dag, x->link);
00997 
00998     /* Pull the first edge off the queue. */
00999     return ps_lattice_traverse_next(dag, end);
01000 }
01001 
01002 ps_latlink_t *
01003 ps_lattice_traverse_next(ps_lattice_t *dag, ps_latnode_t *end)
01004 {
01005     ps_latlink_t *next;
01006 
01007     next = ps_lattice_popq(dag);
01008     if (next == NULL)
01009         return NULL;
01010 
01011     /* Decrease fanin count for destination node and expand outgoing
01012      * edges if all incoming edges have been seen. */
01013     --next->to->info.fanin;
01014     if (next->to->info.fanin == 0) {
01015         latlink_list_t *x;
01016 
01017         if (end == NULL) end = dag->end;
01018         if (next->to == end) {
01019             /* If we have traversed all links entering the end node,
01020              * clear the queue, causing future calls to this function
01021              * to return NULL. */
01022             ps_lattice_delq(dag);
01023             return next;
01024         }
01025 
01026         /* Extend all outgoing edges. */
01027         for (x = next->to->exits; x; x = x->next)
01028             ps_lattice_pushq(dag, x->link);
01029     }
01030     return next;
01031 }
01032 
01033 ps_latlink_t *
01034 ps_lattice_reverse_edges(ps_lattice_t *dag, ps_latnode_t *start, ps_latnode_t *end)
01035 {
01036     ps_latnode_t *node;
01037     latlink_list_t *x;
01038 
01039     /* Cancel any unfinished traversal. */
01040     ps_lattice_delq(dag);
01041 
01042     /* Initialize node fanout counts and path scores. */
01043     for (node = dag->nodes; node; node = node->next) {
01044         node->info.fanin = 0;
01045         for (x = node->exits; x; x = x->next)
01046             ++node->info.fanin;
01047     }
01048 
01049     /* Initialize agenda with all entries from end. */
01050     if (end == NULL) end = dag->end;
01051     for (x = end->entries; x; x = x->next)
01052         ps_lattice_pushq(dag, x->link);
01053 
01054     /* Pull the first edge off the queue. */
01055     return ps_lattice_reverse_next(dag, start);
01056 }
01057 
01058 ps_latlink_t *
01059 ps_lattice_reverse_next(ps_lattice_t *dag, ps_latnode_t *start)
01060 {
01061     ps_latlink_t *next;
01062 
01063     next = ps_lattice_popq(dag);
01064     if (next == NULL)
01065         return NULL;
01066 
01067     /* Decrease fanout count for source node and expand incoming
01068      * edges if all incoming edges have been seen. */
01069     --next->from->info.fanin;
01070     if (next->from->info.fanin == 0) {
01071         latlink_list_t *x;
01072 
01073         if (start == NULL) start = dag->start;
01074         if (next->from == start) {
01075             /* If we have traversed all links entering the start node,
01076              * clear the queue, causing future calls to this function
01077              * to return NULL. */
01078             ps_lattice_delq(dag);
01079             return next;
01080         }
01081 
01082         /* Extend all outgoing edges. */
01083         for (x = next->from->entries; x; x = x->next)
01084             ps_lattice_pushq(dag, x->link);
01085     }
01086     return next;
01087 }
01088 
01089 /*
01090  * Find the best score from dag->start to end point of any link and
01091  * use it to update links further down the path.  This is like
01092  * single-source shortest path search, except that it is done over
01093  * edges rather than nodes, which allows us to do exact trigram scoring.
01094  *
01095  * Helpfully enough, we get half of the posterior probability
01096  * calculation for free that way too.  (interesting research topic: is
01097  * there a reliable Viterbi analogue to word-level Forward-Backward
01098  * like there is for state-level?  Or, is it just lattice density?)
01099  */
01100 ps_latlink_t *
01101 ps_lattice_bestpath(ps_lattice_t *dag, ngram_model_t *lmset,
01102                     float32 lwf, float32 ascale)
01103 {
01104     ps_search_t *search;
01105     ps_latnode_t *node;
01106     ps_latlink_t *link;
01107     ps_latlink_t *bestend;
01108     latlink_list_t *x;
01109     logmath_t *lmath;
01110     int32 bestescr;
01111 
01112     search = dag->search;
01113     lmath = dag->lmath;
01114 
01115     /* Initialize path scores for all links exiting dag->start, and
01116      * set all other scores to the minimum.  Also initialize alphas to
01117      * log-zero. */
01118     for (node = dag->nodes; node; node = node->next) {
01119         for (x = node->exits; x; x = x->next) {
01120             x->link->path_scr = MAX_NEG_INT32;
01121             x->link->alpha = logmath_get_zero(lmath);
01122         }
01123     }
01124     for (x = dag->start->exits; x; x = x->next) {
01125         int32 n_used;
01126 
01127         /* Ignore filler words. */
01128         if (ISA_FILLER_WORD(search, x->link->to->basewid)
01129             && x->link->to != dag->end)
01130             continue;
01131 
01132         /* Best path points to dag->start, obviously. */
01133         if (lmset)
01134             x->link->path_scr = x->link->ascr +
01135                 ngram_bg_score(lmset, x->link->to->basewid,
01136                                ps_search_start_wid(search), &n_used) * lwf;
01137         else
01138             x->link->path_scr = x->link->ascr;
01139         x->link->best_prev = NULL;
01140         /* No predecessors for start links. */
01141         x->link->alpha = 0;
01142     }
01143 
01144     /* Traverse the edges in the graph, updating path scores. */
01145     for (link = ps_lattice_traverse_edges(dag, NULL, NULL);
01146          link; link = ps_lattice_traverse_next(dag, NULL)) {
01147         int32 bprob, n_used;
01148 
01149         /* Skip filler nodes in traversal. */
01150         if (ISA_FILLER_WORD(search, link->from->basewid) && link->from != dag->start)
01151             continue;
01152         if (ISA_FILLER_WORD(search, link->to->basewid) && link->to != dag->end)
01153             continue;
01154 
01155         /* Sanity check, we should not be traversing edges that
01156          * weren't previously updated, otherwise nasty overflows will result. */
01157         assert(link->path_scr != MAX_NEG_INT32);
01158 
01159         /* Calculate common bigram probability for all alphas. */
01160         if (lmset)
01161             bprob = ngram_ng_prob(lmset,
01162                                   link->to->basewid,
01163                                   &link->from->basewid, 1, &n_used);
01164         else
01165             bprob = 0;
01166         /* Add in this link's acoustic score, which was a constant
01167            factor in previous computations (if any). */
01168         link->alpha += link->ascr * ascale;
01169 
01170         /* Update scores for all paths exiting link->to. */
01171         for (x = link->to->exits; x; x = x->next) {
01172             int32 tscore, score;
01173 
01174             /* Skip links to filler words in update. */
01175             if (ISA_FILLER_WORD(search, x->link->to->basewid)
01176                 && x->link->to != dag->end)
01177                 continue;
01178 
01179             /* Update alpha with sum of previous alphas. */
01180             x->link->alpha = logmath_add(lmath, x->link->alpha, link->alpha + bprob);
01181             /* Calculate trigram score for bestpath. */
01182             if (lmset)
01183                 tscore = ngram_tg_score(lmset, x->link->to->basewid,
01184                                         link->to->basewid,
01185                                         link->from->basewid, &n_used) * lwf;
01186             else
01187                 tscore = 0;
01188             /* Update link score with maximum link score. */
01189             score = link->path_scr + tscore + x->link->ascr;
01190             if (score > x->link->path_scr) {
01191                 x->link->path_scr = score;
01192                 x->link->best_prev = link;
01193             }
01194         }
01195     }
01196 
01197     /* Find best link entering final node, and calculate normalizer
01198      * for posterior probabilities. */
01199     bestend = NULL;
01200     bestescr = MAX_NEG_INT32;
01201 
01202     /* Normalizer is the alpha for the imaginary link exiting the
01203        final node. */
01204     dag->norm = logmath_get_zero(lmath);
01205     for (x = dag->end->entries; x; x = x->next) {
01206         int32 bprob, n_used;
01207 
01208         if (ISA_FILLER_WORD(search, x->link->from->basewid))
01209             continue;
01210         if (lmset)
01211             bprob = ngram_ng_prob(lmset,
01212                                   x->link->to->basewid,
01213                                   &x->link->from->basewid, 1, &n_used);
01214         else
01215             bprob = 0;
01216         dag->norm = logmath_add(lmath, dag->norm, x->link->alpha + bprob);
01217         if (x->link->path_scr > bestescr) {
01218             bestescr = x->link->path_scr;
01219             bestend = x->link;
01220         }
01221     }
01222     /* FIXME: floating point... */
01223     dag->norm += (int32)dag->final_node_ascr * ascale;
01224 
01225     E_INFO("Normalizer P(O) = alpha(%s:%d:%d) = %d\n",
01226            dict_word_str(dag->search->dict, dag->end->wid),
01227            dag->end->sf, dag->end->lef,
01228            dag->norm);
01229     return bestend;
01230 }
01231 
01232 static int32
01233 ps_lattice_joint(ps_lattice_t *dag, ps_latlink_t *link, float32 ascale)
01234 {
01235     ngram_model_t *lmset;
01236     int32 jprob;
01237 
01238     /* Sort of a hack... */
01239     if (dag->search && 0 == strcmp(ps_search_name(dag->search), "ngram"))
01240         lmset = ((ngram_search_t *)dag->search)->lmset;
01241     else
01242         lmset = NULL;
01243 
01244     jprob = dag->final_node_ascr * ascale;
01245     while (link) {
01246         if (lmset) {
01247             int lback;
01248             /* Compute unscaled language model probability.  Note that
01249                this is actually not the language model probability
01250                that corresponds to this link, but that is okay,
01251                because we are just taking the sum over all links in
01252                the best path. */
01253             jprob += ngram_ng_prob(lmset, link->to->basewid,
01254                                    &link->from->basewid, 1, &lback);
01255         }
01256         /* If there is no language model, we assume that the language
01257            model probability (such as it is) has been included in the
01258            link score. */
01259         jprob += link->ascr * ascale;
01260         link = link->best_prev;
01261     }
01262 
01263     E_INFO("Joint P(O,S) = %d P(S|O) = %d\n", jprob, jprob - dag->norm);
01264     return jprob;
01265 }
01266 
01267 int32
01268 ps_lattice_posterior(ps_lattice_t *dag, ngram_model_t *lmset,
01269                      float32 ascale)
01270 {
01271     ps_search_t *search;
01272     logmath_t *lmath;
01273     ps_latnode_t *node;
01274     ps_latlink_t *link;
01275     latlink_list_t *x;
01276     ps_latlink_t *bestend;
01277     int32 bestescr;
01278 
01279     search = dag->search;
01280     lmath = dag->lmath;
01281 
01282     /* Reset all betas to zero. */
01283     for (node = dag->nodes; node; node = node->next) {
01284         for (x = node->exits; x; x = x->next) {
01285             x->link->beta = logmath_get_zero(lmath);
01286         }
01287     }
01288 
01289     bestend = NULL;
01290     bestescr = MAX_NEG_INT32;
01291     /* Accumulate backward probabilities for all links. */
01292     for (link = ps_lattice_reverse_edges(dag, NULL, NULL);
01293          link; link = ps_lattice_reverse_next(dag, NULL)) {
01294         int32 bprob, n_used;
01295 
01296         /* Skip filler nodes in traversal. */
01297         if (ISA_FILLER_WORD(search, link->from->basewid) && link->from != dag->start)
01298             continue;
01299         if (ISA_FILLER_WORD(search, link->to->basewid) && link->to != dag->end)
01300             continue;
01301 
01302         /* Calculate LM probability. */
01303         if (lmset)
01304             bprob = ngram_ng_prob(lmset, link->to->basewid,
01305                                   &link->from->basewid, 1, &n_used);
01306         else
01307             bprob = 0;
01308 
01309         if (link->to == dag->end) {
01310             /* Track the best path - we will backtrace in order to
01311                calculate the unscaled joint probability for sentence
01312                posterior. */
01313             if (link->path_scr > bestescr) {
01314                 bestescr = link->path_scr;
01315                 bestend = link;
01316             }
01317             /* Imaginary exit link from final node has beta = 1.0 */
01318             link->beta = bprob + dag->final_node_ascr * ascale;
01319         }
01320         else {
01321             /* Update beta from all outgoing betas. */
01322             for (x = link->to->exits; x; x = x->next) {
01323                 if (ISA_FILLER_WORD(search, x->link->to->basewid) && x->link->to != dag->end)
01324                     continue;
01325                 link->beta = logmath_add(lmath, link->beta,
01326                                          x->link->beta + bprob + x->link->ascr * ascale);
01327             }
01328         }
01329     }
01330 
01331     /* Return P(S|O) = P(O,S)/P(O) */
01332     return ps_lattice_joint(dag, bestend, ascale) - dag->norm;
01333 }
01334 
01335 
/*
 * Pruning limits for the N-best alternatives search.
 * MAX_PATHS caps the number of partial paths kept active at any time.
 */
#define MAX_PATHS       500
#define MAX_HYP_TRIES   10000
01339 
01340 /*
01341  * For each node in any path between from and end of utt, find the
01342  * best score from "from".sf to end of utt.  (NOTE: Uses bigram probs;
01343  * this is an estimate of the best score from "from".)  (NOTE #2: yes,
01344  * this is the "heuristic score" used in A* search)
01345  */
01346 static int32
01347 best_rem_score(ps_astar_t *nbest, ps_latnode_t * from)
01348 {
01349     ps_lattice_t *dag;
01350     latlink_list_t *x;
01351     int32 bestscore, score;
01352 
01353     dag = nbest->dag;
01354     if (from->info.rem_score <= 0)
01355         return (from->info.rem_score);
01356 
01357     /* Best score from "from" to end of utt not known; compute from successors */
01358     bestscore = WORST_SCORE;
01359     for (x = from->exits; x; x = x->next) {
01360         int32 n_used;
01361 
01362         score = best_rem_score(nbest, x->link->to);
01363         score += x->link->ascr;
01364         if (nbest->lmset)
01365             score += ngram_bg_score(nbest->lmset, x->link->to->basewid,
01366                                     from->basewid, &n_used) * nbest->lwf;
01367         if (score > bestscore)
01368             bestscore = score;
01369     }
01370     from->info.rem_score = bestscore;
01371 
01372     return bestscore;
01373 }
01374 
01375 /*
01376  * Insert newpath in sorted (by path score) list of paths.  But if newpath is
01377  * too far down the list, drop it (FIXME: necessary?)
01378  * total_score = path score (newpath) + rem_score to end of utt.
01379  */
01380 static void
01381 path_insert(ps_astar_t *nbest, ps_latpath_t *newpath, int32 total_score)
01382 {
01383     ps_lattice_t *dag;
01384     ps_latpath_t *prev, *p;
01385     int32 i;
01386 
01387     dag = nbest->dag;
01388     prev = NULL;
01389     for (i = 0, p = nbest->path_list; (i < MAX_PATHS) && p; p = p->next, i++) {
01390         if ((p->score + p->node->info.rem_score) < total_score)
01391             break;
01392         prev = p;
01393     }
01394 
01395     /* newpath should be inserted between prev and p */
01396     if (i < MAX_PATHS) {
01397         /* Insert new partial hyp */
01398         newpath->next = p;
01399         if (!prev)
01400             nbest->path_list = newpath;
01401         else
01402             prev->next = newpath;
01403         if (!p)
01404             nbest->path_tail = newpath;
01405 
01406         nbest->n_path++;
01407         nbest->n_hyp_insert++;
01408         nbest->insert_depth += i;
01409     }
01410     else {
01411         /* newpath score too low; reject it and also prune paths beyond MAX_PATHS */
01412         nbest->path_tail = prev;
01413         prev->next = NULL;
01414         nbest->n_path = MAX_PATHS;
01415         listelem_free(nbest->latpath_alloc, newpath);
01416 
01417         nbest->n_hyp_reject++;
01418         for (; p; p = newpath) {
01419             newpath = p->next;
01420             listelem_free(nbest->latpath_alloc, p);
01421             nbest->n_hyp_reject++;
01422         }
01423     }
01424 }
01425 
01426 /* Find all possible extensions to given partial path */
01427 static void
01428 path_extend(ps_astar_t *nbest, ps_latpath_t * path)
01429 {
01430     latlink_list_t *x;
01431     ps_latpath_t *newpath;
01432     int32 total_score, tail_score;
01433     ps_lattice_t *dag;
01434 
01435     dag = nbest->dag;
01436 
01437     /* Consider all successors of path->node */
01438     for (x = path->node->exits; x; x = x->next) {
01439         int32 n_used;
01440 
01441         /* Skip successor if no path from it reaches the final node */
01442         if (x->link->to->info.rem_score <= WORST_SCORE)
01443             continue;
01444 
01445         /* Create path extension and compute exact score for this extension */
01446         newpath = listelem_malloc(nbest->latpath_alloc);
01447         newpath->node = x->link->to;
01448         newpath->parent = path;
01449         newpath->score = path->score + x->link->ascr;
01450         if (nbest->lmset) {
01451             if (path->parent) {
01452                 newpath->score += nbest->lwf
01453                     * ngram_tg_score(nbest->lmset, newpath->node->basewid,
01454                                      path->node->basewid,
01455                                      path->parent->node->basewid, &n_used);
01456             }
01457             else 
01458                 newpath->score += nbest->lwf
01459                     * ngram_bg_score(nbest->lmset, newpath->node->basewid,
01460                                      path->node->basewid, &n_used);
01461         }
01462 
01463         /* Insert new partial path hypothesis into sorted path_list */
01464         nbest->n_hyp_tried++;
01465         total_score = newpath->score + newpath->node->info.rem_score;
01466 
01467         /* First see if hyp would be worse than the worst */
01468         if (nbest->n_path >= MAX_PATHS) {
01469             tail_score =
01470                 nbest->path_tail->score
01471                 + nbest->path_tail->node->info.rem_score;
01472             if (total_score < tail_score) {
01473                 listelem_free(nbest->latpath_alloc, newpath);
01474                 nbest->n_hyp_reject++;
01475                 continue;
01476             }
01477         }
01478 
01479         path_insert(nbest, newpath, total_score);
01480     }
01481 }
01482 
01483 ps_astar_t *
01484 ps_astar_start(ps_lattice_t *dag,
01485                   ngram_model_t *lmset,
01486                   float32 lwf,
01487                   int sf, int ef,
01488                   int w1, int w2)
01489 {
01490     ps_astar_t *nbest;
01491     ps_latnode_t *node;
01492 
01493     nbest = ckd_calloc(1, sizeof(*nbest));
01494     nbest->dag = dag;
01495     nbest->lmset = lmset;
01496     nbest->lwf = lwf;
01497     nbest->sf = sf;
01498     if (ef < 0)
01499         nbest->ef = dag->n_frames - ef;
01500     else
01501         nbest->ef = ef;
01502     nbest->w1 = w1;
01503     nbest->w2 = w2;
01504     nbest->latpath_alloc = listelem_alloc_init(sizeof(ps_latpath_t));
01505 
01506     /* Initialize rem_score (A* heuristic) to default values */
01507     for (node = dag->nodes; node; node = node->next) {
01508         if (node == dag->end)
01509             node->info.rem_score = 0;
01510         else if (node->exits == NULL)
01511             node->info.rem_score = WORST_SCORE;
01512         else
01513             node->info.rem_score = 1;   /* +ve => unknown value */
01514     }
01515 
01516     /* Create initial partial hypotheses list consisting of nodes starting at sf */
01517     nbest->path_list = nbest->path_tail = NULL;
01518     for (node = dag->nodes; node; node = node->next) {
01519         if (node->sf == sf) {
01520             ps_latpath_t *path;
01521             int32 n_used;
01522 
01523             best_rem_score(nbest, node);
01524             path = listelem_malloc(nbest->latpath_alloc);
01525             path->node = node;
01526             path->parent = NULL;
01527             if (nbest->lmset)
01528                 path->score = nbest->lwf *
01529                     (w1 < 0)
01530                     ? ngram_bg_score(nbest->lmset, node->basewid, w2, &n_used)
01531                     : ngram_tg_score(nbest->lmset, node->basewid, w2, w1, &n_used);
01532             else
01533                 path->score = 0;
01534             path_insert(nbest, path, path->score + node->info.rem_score);
01535         }
01536     }
01537 
01538     return nbest;
01539 }
01540 
01541 ps_latpath_t *
01542 ps_astar_next(ps_astar_t *nbest)
01543 {
01544     ps_latpath_t *top;
01545     ps_lattice_t *dag;
01546 
01547     dag = nbest->dag;
01548 
01549     /* Pop the top (best) partial hypothesis */
01550     while ((top = nbest->path_list) != NULL) {
01551         nbest->path_list = nbest->path_list->next;
01552         if (top == nbest->path_tail)
01553             nbest->path_tail = NULL;
01554         nbest->n_path--;
01555 
01556         /* Complete hypothesis? */
01557         if ((top->node->sf >= nbest->ef)
01558             || ((top->node == dag->end) &&
01559                 (nbest->ef > dag->end->sf))) {
01560             /* FIXME: Verify that it is non-empty. */
01561             return top;
01562         }
01563         else {
01564             if (top->node->fef < nbest->ef)
01565                 path_extend(nbest, top);
01566         }
01567 
01568         /*
01569          * Add top to paths already processed; cannot be freed because other paths
01570          * point to it.
01571          */
01572         top->next = nbest->paths_done;
01573         nbest->paths_done = top;
01574     }
01575 
01576     /* Did not find any more paths to extend. */
01577     return NULL;
01578 }
01579 
01580 char const *
01581 ps_astar_hyp(ps_astar_t *nbest, ps_latpath_t *path)
01582 {
01583     ps_search_t *search;
01584     ps_latpath_t *p;
01585     size_t len;
01586     char *c;
01587     char *hyp;
01588 
01589     search = nbest->dag->search;
01590 
01591     /* Backtrace once to get hypothesis length. */
01592     len = 0;
01593     for (p = path; p; p = p->parent) {
01594         if (ISA_REAL_WORD(search, p->node->basewid))
01595             len += strlen(dict_word_str(ps_search_dict(search), p->node->basewid)) + 1;
01596     }
01597 
01598     /* Backtrace again to construct hypothesis string. */
01599     hyp = ckd_calloc(1, len);
01600     c = hyp + len - 1;
01601     for (p = path; p; p = p->parent) {
01602         if (ISA_REAL_WORD(search, p->node->basewid)) {
01603             len = strlen(dict_word_str(ps_search_dict(search), p->node->basewid));
01604             c -= len;
01605             memcpy(c, dict_word_str(ps_search_dict(search), p->node->basewid), len);
01606             if (c > hyp) {
01607                 --c;
01608                 *c = ' ';
01609             }
01610         }
01611     }
01612 
01613     nbest->hyps = glist_add_ptr(nbest->hyps, hyp);
01614     return hyp;
01615 }
01616 
01617 static void
01618 ps_astar_node2itor(astar_seg_t *itor)
01619 {
01620     ps_seg_t *seg = (ps_seg_t *)itor;
01621     ps_latnode_t *node;
01622 
01623     assert(itor->cur < itor->n_nodes);
01624     node = itor->nodes[itor->cur];
01625     if (itor->cur == itor->n_nodes - 1)
01626         seg->ef = node->lef;
01627     else
01628         seg->ef = itor->nodes[itor->cur + 1]->sf - 1;
01629     seg->word = dict_word_str(ps_search_dict(seg->search), node->wid);
01630     seg->sf = node->sf;
01631     seg->prob = 0; /* FIXME: implement forward-backward */
01632 }
01633 
01634 static void
01635 ps_astar_seg_free(ps_seg_t *seg)
01636 {
01637     astar_seg_t *itor = (astar_seg_t *)seg;
01638     ckd_free(itor->nodes);
01639     ckd_free(itor);
01640 }
01641 
01642 static ps_seg_t *
01643 ps_astar_seg_next(ps_seg_t *seg)
01644 {
01645     astar_seg_t *itor = (astar_seg_t *)seg;
01646 
01647     ++itor->cur;
01648     if (itor->cur == itor->n_nodes) {
01649         ps_astar_seg_free(seg);
01650         return NULL;
01651     }
01652     else {
01653         ps_astar_node2itor(itor);
01654     }
01655 
01656     return seg;
01657 }
01658 
01659 static ps_segfuncs_t ps_astar_segfuncs = {
01660     /* seg_next */ ps_astar_seg_next,
01661     /* seg_free */ ps_astar_seg_free
01662 };
01663 
01664 ps_seg_t *
01665 ps_astar_seg_iter(ps_astar_t *astar, ps_latpath_t *path, float32 lwf)
01666 {
01667     astar_seg_t *itor;
01668     ps_latpath_t *p;
01669     int cur;
01670 
01671     /* Backtrace and make an iterator, this should look familiar by now. */
01672     itor = ckd_calloc(1, sizeof(*itor));
01673     itor->base.vt = &ps_astar_segfuncs;
01674     itor->base.search = astar->dag->search;
01675     itor->base.lwf = lwf;
01676     itor->n_nodes = itor->cur = 0;
01677     for (p = path; p; p = p->parent) {
01678         ++itor->n_nodes;
01679     }
01680     itor->nodes = ckd_calloc(itor->n_nodes, sizeof(*itor->nodes));
01681     cur = itor->n_nodes - 1;
01682     for (p = path; p; p = p->parent) {
01683         itor->nodes[cur] = p->node;
01684         --cur;
01685     }
01686 
01687     ps_astar_node2itor(itor);
01688     return (ps_seg_t *)itor;
01689 }
01690 
01691 void
01692 ps_astar_finish(ps_astar_t *nbest)
01693 {
01694     gnode_t *gn;
01695 
01696     /* Free all hyps. */
01697     for (gn = nbest->hyps; gn; gn = gnode_next(gn)) {
01698         ckd_free(gnode_ptr(gn));
01699     }
01700     glist_free(nbest->hyps);
01701     /* Free all paths. */
01702     listelem_alloc_free(nbest->latpath_alloc);
01703     /* Free the Henge. */
01704     ckd_free(nbest);
01705 }

Generated on Thu Jan 27 2011 for PocketSphinx by  doxygen 1.7.1