PocketSphinx  0.6
src/libpocketsphinx/ps_lattice.c
Go to the documentation of this file.
00001 /* -*- c-basic-offset: 4; indent-tabs-mode: nil -*- */
00002 /* ====================================================================
00003  * Copyright (c) 2008 Carnegie Mellon University.  All rights
00004  * reserved.
00005  *
00006  * Redistribution and use in source and binary forms, with or without
00007  * modification, are permitted provided that the following conditions
00008  * are met:
00009  *
00010  * 1. Redistributions of source code must retain the above copyright
00011  *    notice, this list of conditions and the following disclaimer. 
00012  *
00013  * 2. Redistributions in binary form must reproduce the above copyright
00014  *    notice, this list of conditions and the following disclaimer in
00015  *    the documentation and/or other materials provided with the
00016  *    distribution.
00017  *
00018  * This work was supported in part by funding from the Defense Advanced 
00019  * Research Projects Agency and the National Science Foundation of the 
00020  * United States of America, and the CMU Sphinx Speech Consortium.
00021  *
00022  * THIS SOFTWARE IS PROVIDED BY CARNEGIE MELLON UNIVERSITY ``AS IS'' AND 
00023  * ANY EXPRESSED OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, 
00024  * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00025  * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL CARNEGIE MELLON UNIVERSITY
00026  * NOR ITS EMPLOYEES BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
00027  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 
00028  * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 
00029  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 
00030  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 
00031  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 
00032  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00033  *
00034  * ====================================================================
00035  *
00036  */
00037 
00042 /* System headers. */
00043 #include <assert.h>
00044 #include <string.h>
00045 #include <math.h>
00046 
00047 /* SphinxBase headers. */
00048 #include <sphinxbase/ckd_alloc.h>
00049 #include <sphinxbase/listelem_alloc.h>
00050 #include <sphinxbase/strfuncs.h>
00051 #include <sphinxbase/err.h>
00052 #include <sphinxbase/pio.h>
00053 
00054 /* Local headers. */
00055 #include "pocketsphinx_internal.h"
00056 #include "ps_lattice_internal.h"
00057 #include "ngram_search.h"
00058 #include "dict.h"
00059 
00060 /*
00061  * Create a directed link between "from" and "to" nodes, but if a link already exists,
00062  * choose one with the best ascr.
00063  */
00064 void
00065 ps_lattice_link(ps_lattice_t *dag, ps_latnode_t *from, ps_latnode_t *to,
00066                 int32 score, int32 ef)
00067 {
00068     latlink_list_t *fwdlink;
00069 
00070     /* Look for an existing link between "from" and "to" nodes */
00071     for (fwdlink = from->exits; fwdlink; fwdlink = fwdlink->next)
00072         if (fwdlink->link->to == to)
00073             break;
00074 
00075     if (fwdlink == NULL) {
00076         latlink_list_t *revlink;
00077         ps_latlink_t *link;
00078 
00079         /* No link between the two nodes; create a new one */
00080         link = listelem_malloc(dag->latlink_alloc);
00081         fwdlink = listelem_malloc(dag->latlink_list_alloc);
00082         revlink = listelem_malloc(dag->latlink_list_alloc);
00083 
00084         link->from = from;
00085         link->to = to;
00086         link->ascr = score;
00087         link->ef = ef;
00088         link->best_prev = NULL;
00089 
00090         fwdlink->link = revlink->link = link;
00091         fwdlink->next = from->exits;
00092         from->exits = fwdlink;
00093         revlink->next = to->entries;
00094         to->entries = revlink;
00095     }
00096     else {
00097         /* Link already exists; just retain the best ascr */
00098         if (score BETTER_THAN fwdlink->link->ascr) {
00099             fwdlink->link->ascr = score;
00100             fwdlink->link->ef = ef;
00101         }
00102     }           
00103 }
00104 
00105 void
00106 ps_lattice_bypass_fillers(ps_lattice_t *dag, int32 silpen, int32 fillpen)
00107 {
00108     ps_latnode_t *node;
00109     int32 score;
00110 
00111     /* Bypass filler nodes */
00112     for (node = dag->nodes; node; node = node->next) {
00113         latlink_list_t *revlink;
00114         if (node == dag->end || !dict_filler_word(dag->dict, node->basewid))
00115             continue;
00116 
00117         /* Replace each link entering filler node with links to all its successors */
00118         for (revlink = node->entries; revlink; revlink = revlink->next) {
00119             latlink_list_t *forlink;
00120             ps_latlink_t *rlink = revlink->link;
00121 
00122             score = (node->basewid == dag->silence) ? silpen : fillpen;
00123             score += rlink->ascr;
00124             /*
00125              * Make links from predecessor of filler (from) to successors of filler.
00126              * But if successor is a filler, it has already been eliminated since it
00127              * appears earlier in latnode_list (see build...).  So it can be skipped.
00128              */
00129             for (forlink = node->exits; forlink; forlink = forlink->next) {
00130                 ps_latlink_t *flink = forlink->link;
00131                 if (flink->to && rlink->from &&
00132                     !dict_filler_word(dag->dict, flink->to->basewid)) {
00133                     ps_lattice_link(dag, rlink->from, flink->to,
00134                                     score + flink->ascr, flink->ef);
00135                 }
00136             }
00137         }
00138         node->reachable = FALSE;
00139     }
00140 }
00141 
00142 static void
00143 delete_node(ps_lattice_t *dag, ps_latnode_t *node)
00144 {
00145     latlink_list_t *x, *next_x;
00146 
00147     for (x = node->exits; x; x = next_x) {
00148         next_x = x->next;
00149         x->link->from = NULL;
00150         listelem_free(dag->latlink_list_alloc, x);
00151     }
00152     for (x = node->entries; x; x = next_x) {
00153         next_x = x->next;
00154         x->link->to = NULL;
00155         listelem_free(dag->latlink_list_alloc, x);
00156     }
00157     listelem_free(dag->latnode_alloc, node);
00158 }
00159 
00160 
00161 static void
00162 remove_dangling_links(ps_lattice_t *dag, ps_latnode_t *node)
00163 {
00164     latlink_list_t *x, *prev_x, *next_x;
00165 
00166     prev_x = NULL;
00167     for (x = node->exits; x; x = next_x) {
00168         next_x = x->next;
00169         if (x->link->to == NULL) {
00170             if (prev_x)
00171                 prev_x->next = next_x;
00172             else
00173                 node->exits = next_x;
00174             listelem_free(dag->latlink_alloc, x->link);
00175             listelem_free(dag->latlink_list_alloc, x);
00176         }
00177         else
00178             prev_x = x;
00179     }
00180     prev_x = NULL;
00181     for (x = node->entries; x; x = next_x) {
00182         next_x = x->next;
00183         if (x->link->from == NULL) {
00184             if (prev_x)
00185                 prev_x->next = next_x;
00186             else
00187                 node->entries = next_x;
00188             listelem_free(dag->latlink_alloc, x->link);
00189             listelem_free(dag->latlink_list_alloc, x);
00190         }
00191         else
00192             prev_x = x;
00193     }
00194 }
00195 
00196 void
00197 ps_lattice_delete_unreachable(ps_lattice_t *dag)
00198 {
00199     ps_latnode_t *node, *prev_node, *next_node;
00200     int i;
00201 
00202     /* Remove unreachable nodes from the list of nodes. */
00203     prev_node = NULL;
00204     for (node = dag->nodes; node; node = next_node) {
00205         next_node = node->next;
00206         if (!node->reachable) {
00207             if (prev_node)
00208                 prev_node->next = next_node;
00209             else
00210                 dag->nodes = next_node;
00211             /* Delete this node and NULLify links to it. */
00212             delete_node(dag, node);
00213         }
00214         else
00215             prev_node = node;
00216     }
00217 
00218     /* Remove all links to and from unreachable nodes. */
00219     i = 0;
00220     for (node = dag->nodes; node; node = node->next) {
00221         /* Assign sequence numbers. */
00222         node->id = i++;
00223 
00224         /* We should obviously not encounter unreachable nodes here! */
00225         assert(node->reachable);
00226 
00227         /* Remove all links that go nowhere. */
00228         remove_dangling_links(dag, node);
00229     }
00230 }
00231 
00232 int32
00233 ps_lattice_write(ps_lattice_t *dag, char const *filename)
00234 {
00235     FILE *fp;
00236     int32 i;
00237     ps_latnode_t *d, *initial, *final;
00238 
00239     initial = dag->start;
00240     final = dag->end;
00241 
00242     E_INFO("Writing lattice file: %s\n", filename);
00243     if ((fp = fopen(filename, "w")) == NULL) {
00244         E_ERROR("Failed to open lattice file '%s' for writing: %s\n", filename, strerror(errno));
00245         return -1;
00246     }
00247 
00248     /* Stupid Sphinx-III lattice code expects 'getcwd:' here */
00249     fprintf(fp, "# getcwd: /this/is/bogus\n");
00250     fprintf(fp, "# -logbase %e\n", logmath_get_base(dag->lmath));
00251     fprintf(fp, "#\n");
00252 
00253     fprintf(fp, "Frames %d\n", dag->n_frames);
00254     fprintf(fp, "#\n");
00255 
00256     for (i = 0, d = dag->nodes; d; d = d->next, i++);
00257     fprintf(fp,
00258             "Nodes %d (NODEID WORD STARTFRAME FIRST-ENDFRAME LAST-ENDFRAME)\n",
00259             i);
00260     for (i = 0, d = dag->nodes; d; d = d->next, i++) {
00261         d->id = i;
00262         fprintf(fp, "%d %s %d %d %d\n",
00263                 i, dict_wordstr(dag->dict, d->wid),
00264                 d->sf, d->fef, d->lef);
00265     }
00266     fprintf(fp, "#\n");
00267 
00268     fprintf(fp, "Initial %d\nFinal %d\n", initial->id, final->id);
00269     fprintf(fp, "#\n");
00270 
00271     /* Don't bother with this, it's not used by anything. */
00272     fprintf(fp, "BestSegAscr %d (NODEID ENDFRAME ASCORE)\n",
00273             0 /* #BPTable entries */ );
00274     fprintf(fp, "#\n");
00275 
00276     fprintf(fp, "Edges (FROM-NODEID TO-NODEID ASCORE)\n");
00277     for (d = dag->nodes; d; d = d->next) {
00278         latlink_list_t *l;
00279         for (l = d->exits; l; l = l->next) {
00280             if (l->link->ascr WORSE_THAN WORST_SCORE || l->link->ascr BETTER_THAN 0)
00281                 continue;
00282             fprintf(fp, "%d %d %d\n",
00283                     d->id, l->link->to->id, l->link->ascr << SENSCR_SHIFT);
00284         }
00285     }
00286     fprintf(fp, "End\n");
00287     fclose(fp);
00288 
00289     return 0;
00290 }
00291 
00292 int32
00293 ps_lattice_write_htk(ps_lattice_t *dag, char const *filename)
00294 {
00295     FILE *fp;
00296     ps_latnode_t *d, *initial, *final;
00297     int32 i, j, n_links, n_nodes;
00298 
00299     initial = dag->start;
00300     final = dag->end;
00301 
00302     E_INFO("Writing lattice file: %s\n", filename);
00303     if ((fp = fopen(filename, "w")) == NULL) {
00304         E_ERROR("Failed to open lattice file '%s' for writing: %s\n", filename, strerror(errno));
00305         return -1;
00306     }
00307 
00308     for (n_links = n_nodes = 0, d = dag->nodes; d; d = d->next) {
00309         latlink_list_t *l;
00310         if (!d->reachable)
00311             continue;
00312         d->id = n_nodes;
00313         for (l = d->exits; l; l = l->next) {
00314             if (l->link->to == NULL || !l->link->to->reachable)
00315                 continue;
00316             if (l->link->ascr WORSE_THAN WORST_SCORE || l->link->ascr BETTER_THAN 0)
00317                 continue;
00318             ++n_links;
00319         }
00320         ++n_nodes;
00321     }
00322 
00323     fprintf(fp, "# Lattice generated by PocketSphinx\n");
00324     fprintf(fp, "#\n# Header\n#\n");
00325     fprintf(fp, "VERSION=1.0\n");
00326     fprintf(fp, "start=%d\n", initial->id);
00327     fprintf(fp, "end=%d\n", final->id);
00328     fprintf(fp, "#\n");
00329 
00330     fprintf(fp, "N=%d\tL=%d\n", n_nodes, n_links);
00331     fprintf(fp, "#\n# Node definitions\n#\n");
00332     for (i = 0, d = dag->nodes; d; d = d->next) {
00333         char const *word = dict_wordstr(dag->dict, d->wid);
00334         char const *c = strrchr(word, '(');
00335         int altpron = 1;
00336         if (!d->reachable)
00337             continue;
00338         if (c)
00339             altpron = atoi(c + 1);
00340         word = dict_basestr(dag->dict, d->wid);
00341         if (d->wid == dict_startwid(dag->dict))
00342             word = "!SENT_START";
00343         else if (d->wid == dict_finishwid(dag->dict))
00344             word = "!SENT_END";
00345         else if (dict_filler_word(dag->dict, d->wid))
00346             word = "!NULL";
00347         fprintf(fp, "I=%d\tt=%.2f\tW=%s\tv=%d\n",
00348                 d->id, (double)d->sf / dag->frate,
00349                 word, altpron);
00350     }
00351     fprintf(fp, "#\n# Link definitions\n#\n");
00352     for (j = 0, d = dag->nodes; d; d = d->next) {
00353         latlink_list_t *l;
00354         if (!d->reachable)
00355             continue;
00356         for (l = d->exits; l; l = l->next) {
00357             if (l->link->to == NULL || !l->link->to->reachable)
00358                 continue;
00359             if (l->link->ascr WORSE_THAN WORST_SCORE || l->link->ascr BETTER_THAN 0)
00360                 continue;
00361             fprintf(fp, "J=%d\tS=%d\tE=%d\ta=%f\tp=%g\n", j++,
00362                     d->id, l->link->to->id,
00363                     logmath_log_to_ln(dag->lmath, l->link->ascr << SENSCR_SHIFT),
00364                     logmath_exp(dag->lmath, l->link->alpha + l->link->beta - dag->norm));
00365         }
00366     }
00367     fclose(fp);
00368 
00369     return 0;
00370 }
00371 
00372 /* Read parameter from a lattice file*/
00373 static int
00374 dag_param_read(lineiter_t *li, char *param)
00375 {
00376     int32 n;
00377 
00378     while ((li = lineiter_next(li)) != NULL) {
00379         char *c;
00380 
00381         /* Ignore comments. */
00382         if (li->buf[0] == '#')
00383             continue;
00384 
00385         /* Find the first space. */
00386         c = strchr(li->buf, ' ');
00387         if (c == NULL) continue;
00388 
00389         /* Check that the first field equals param and that there's a number after it. */
00390         if (strncmp(li->buf, param, strlen(param)) == 0
00391             && sscanf(c + 1, "%d", &n) == 1)
00392             return n;
00393     }
00394     return -1;
00395 }
00396 
00397 /* Mark every node that has a path to the argument dagnode as "reachable". */
00398 static void
00399 dag_mark_reachable(ps_latnode_t * d)
00400 {
00401     latlink_list_t *l;
00402 
00403     d->reachable = 1;
00404     for (l = d->entries; l; l = l->next)
00405         if (l->link->from && !l->link->from->reachable)
00406             dag_mark_reachable(l->link->from);
00407 }
00408 
00409 ps_lattice_t *
00410 ps_lattice_read(ps_decoder_t *ps,
00411                 char const *file)
00412 {
00413     FILE *fp;
00414     int32 ispipe;
00415     lineiter_t *line;
00416     float64 lb;
00417     float32 logratio;
00418     ps_latnode_t *tail;
00419     ps_latnode_t **darray;
00420     ps_lattice_t *dag;
00421     int i, k, n_nodes;
00422     int32 pip, silpen, fillpen;
00423 
00424     dag = ckd_calloc(1, sizeof(*dag));
00425 
00426     if (ps) {
00427         dag->search = ps->search;
00428         dag->dict = dict_retain(ps->dict);
00429         dag->lmath = logmath_retain(ps->lmath);
00430         dag->frate = cmd_ln_int32_r(dag->search->config, "-frate");
00431     }
00432     else {
00433         dag->dict = dict_init(NULL, NULL);
00434         dag->lmath = logmath_init(1.0001, 0, FALSE);
00435         dag->frate = 100;
00436     }
00437     dag->silence = dict_silwid(dag->dict);
00438     dag->latnode_alloc = listelem_alloc_init(sizeof(ps_latnode_t));
00439     dag->latlink_alloc = listelem_alloc_init(sizeof(ps_latlink_t));
00440     dag->latlink_list_alloc = listelem_alloc_init(sizeof(latlink_list_t));
00441     dag->refcount = 1;
00442 
00443     tail = NULL;
00444     darray = NULL;
00445 
00446     E_INFO("Reading DAG file: %s\n", file);
00447     if ((fp = fopen_compchk(file, &ispipe)) == NULL) {
00448         E_ERROR("Failed to open DAG file '%s': %s\n", file, strerror(errno));
00449         return NULL;
00450     }
00451     line = lineiter_start(fp);
00452 
00453     /* Read and verify logbase (ONE BIG HACK!!) */
00454     if (line == NULL) {
00455         E_ERROR("Premature EOF(%s)\n", file);
00456         goto load_error;
00457     }
00458     if (strncmp(line->buf, "# getcwd: ", 10) != 0) {
00459         E_ERROR("%s does not begin with '# getcwd: '\n%s", file, line->buf);
00460         goto load_error;
00461     }
00462     if ((line = lineiter_next(line)) == NULL) {
00463         E_ERROR("Premature EOF(%s)\n", file);
00464         goto load_error;
00465     }
00466     if ((strncmp(line->buf, "# -logbase ", 11) != 0)
00467         || (sscanf(line->buf + 11, "%lf", &lb) != 1)) {
00468         E_WARN("%s: Cannot find -logbase in header\n", file);
00469         lb = 1.0001;
00470     }
00471     logratio = 1.0f;
00472     if (dag->lmath == NULL)
00473         dag->lmath = logmath_init(lb, 0, TRUE);
00474     else {
00475         float32 pb = logmath_get_base(dag->lmath);
00476         if (fabs(lb - pb) >= 0.0001) {
00477             E_WARN("Inconsistent logbases: %f vs %f: will compensate\n", lb, pb);
00478             logratio = (float32)(log(lb) / log(pb));
00479             E_INFO("Lattice log ratio: %f\n", logratio);
00480         }
00481     }
00482     /* Read Frames parameter */
00483     dag->n_frames = dag_param_read(line, "Frames");
00484     if (dag->n_frames <= 0) {
00485         E_ERROR("Frames parameter missing or invalid\n");
00486         goto load_error;
00487     }
00488     /* Read Nodes parameter */
00489     n_nodes = dag_param_read(line, "Nodes");
00490     if (n_nodes <= 0) {
00491         E_ERROR("Nodes parameter missing or invalid\n");
00492         goto load_error;
00493     }
00494 
00495     /* Read nodes */
00496     darray = ckd_calloc(n_nodes, sizeof(*darray));
00497     for (i = 0; i < n_nodes; i++) {
00498         ps_latnode_t *d;
00499         int32 w;
00500         int seqid, sf, fef, lef;
00501         char wd[256];
00502 
00503         if ((line = lineiter_next(line)) == NULL) {
00504             E_ERROR("Premature EOF while loading Nodes(%s)\n", file);
00505             goto load_error;
00506         }
00507 
00508         if ((k =
00509              sscanf(line->buf, "%d %255s %d %d %d", &seqid, wd, &sf, &fef,
00510                     &lef)) != 5) {
00511             E_ERROR("Cannot parse line: %s, value of count %d\n", line->buf, k);
00512             goto load_error;
00513         }
00514 
00515         w = dict_wordid(dag->dict, wd);
00516         if (w < 0) {
00517             if (dag->search == NULL) {
00518                 char *ww = ckd_salloc(wd);
00519                 if (dict_word2basestr(ww) != -1) {
00520                     if (dict_wordid(dag->dict, ww) == BAD_S3WID)
00521                         dict_add_word(dag->dict, ww, NULL, 0);
00522                 }
00523                 ckd_free(ww);
00524                 w = dict_add_word(dag->dict, wd, NULL, 0);
00525             }
00526             if (w < 0) {
00527                 E_ERROR("Unknown word in line: %s\n", line->buf);
00528                 goto load_error;
00529             }
00530         }
00531 
00532         if (seqid != i) {
00533             E_ERROR("Seqno error: %s\n", line->buf);
00534             goto load_error;
00535         }
00536 
00537         d = listelem_malloc(dag->latnode_alloc);
00538         darray[i] = d;
00539         d->wid = w;
00540         d->basewid = dict_basewid(dag->dict, w);
00541         d->id = seqid;
00542         d->sf = sf;
00543         d->fef = fef;
00544         d->lef = lef;
00545         d->reachable = 0;
00546         d->exits = d->entries = NULL;
00547         d->next = NULL;
00548 
00549         if (!dag->nodes)
00550             dag->nodes = d;
00551         else
00552             tail->next = d;
00553         tail = d;
00554     }
00555 
00556     /* Read initial node ID */
00557     k = dag_param_read(line, "Initial");
00558     if ((k < 0) || (k >= n_nodes)) {
00559         E_ERROR("Initial node parameter missing or invalid\n");
00560         goto load_error;
00561     }
00562     dag->start = darray[k];
00563 
00564     /* Read final node ID */
00565     k = dag_param_read(line, "Final");
00566     if ((k < 0) || (k >= n_nodes)) {
00567         E_ERROR("Final node parameter missing or invalid\n");
00568         goto load_error;
00569     }
00570     dag->end = darray[k];
00571 
00572     /* Read bestsegscore entries and ignore them. */
00573     if ((k = dag_param_read(line, "BestSegAscr")) < 0) {
00574         E_ERROR("BestSegAscr parameter missing\n");
00575         goto load_error;
00576     }
00577     for (i = 0; i < k; i++) {
00578         if ((line = lineiter_next(line)) == NULL) {
00579             E_ERROR("Premature EOF while (%s) ignoring BestSegAscr\n",
00580                     line);
00581             goto load_error;
00582         }
00583     }
00584 
00585     /* Read in edges. */
00586     while ((line = lineiter_next(line)) != NULL) {
00587         if (line->buf[0] == '#')
00588             continue;
00589         if (0 == strncmp(line->buf, "Edges", 5))
00590             break;
00591     }
00592     if (line == NULL) {
00593         E_ERROR("Edges missing\n");
00594         goto load_error;
00595     }
00596     while ((line = lineiter_next(line)) != NULL) {
00597         int from, to, ascr;
00598         ps_latnode_t *pd, *d;
00599 
00600         if (sscanf(line->buf, "%d %d %d", &from, &to, &ascr) != 3)
00601             break;
00602         if (ascr WORSE_THAN WORST_SCORE)
00603             continue;
00604         pd = darray[from];
00605         d = darray[to];
00606         if (logratio != 1.0f)
00607             ascr = (int32)(ascr * logratio);
00608         ps_lattice_link(dag, pd, d, ascr, d->sf - 1);
00609     }
00610     if (strcmp(line->buf, "End\n") != 0) {
00611         E_ERROR("Terminating 'End' missing\n");
00612         goto load_error;
00613     }
00614     lineiter_free(line);
00615     fclose_comp(fp, ispipe);
00616     ckd_free(darray);
00617 
00618     /* Minor hack: If the final node is a filler word and not </s>,
00619      * then set its base word ID to </s>, so that the language model
00620      * scores won't be screwed up. */
00621     if (dict_filler_word(dag->dict, dag->end->wid))
00622         dag->end->basewid = dag->search
00623             ? ps_search_finish_wid(dag->search)
00624             : dict_wordid(dag->dict, S3_FINISH_WORD);
00625 
00626     /* Mark reachable from dag->end */
00627     dag_mark_reachable(dag->end);
00628 
00629     /* Free nodes unreachable from dag->end and their links */
00630     ps_lattice_delete_unreachable(dag);
00631 
00632     if (ps) {
00633         /* Build links around silence and filler words, since they do
00634          * not exist in the language model.  FIXME: This is
00635          * potentially buggy, as we already do this before outputing
00636          * lattices. */
00637         pip = logmath_log(dag->lmath, cmd_ln_float32_r(ps->config, "-pip"));
00638         silpen = pip + logmath_log(dag->lmath,
00639                                    cmd_ln_float32_r(ps->config, "-silprob"));
00640         fillpen = pip + logmath_log(dag->lmath,
00641                                     cmd_ln_float32_r(ps->config, "-fillprob"));
00642         ps_lattice_bypass_fillers(dag, silpen, fillpen);
00643     }
00644 
00645     return dag;
00646 
00647   load_error:
00648     E_ERROR("Failed to load %s\n", file);
00649     lineiter_free(line);
00650     if (fp) fclose_comp(fp, ispipe);
00651     ckd_free(darray);
00652     return NULL;
00653 }
00654 
00655 int
00656 ps_lattice_n_frames(ps_lattice_t *dag)
00657 {
00658     return dag->n_frames;
00659 }
00660 
00661 ps_lattice_t *
00662 ps_lattice_init_search(ps_search_t *search, int n_frame)
00663 {
00664     ps_lattice_t *dag;
00665 
00666     dag = ckd_calloc(1, sizeof(*dag));
00667     dag->search = search;
00668     dag->dict = dict_retain(search->dict);
00669     dag->lmath = logmath_retain(search->acmod->lmath);
00670     dag->frate = cmd_ln_int32_r(dag->search->config, "-frate");
00671     dag->silence = dict_silwid(dag->dict);
00672     dag->n_frames = n_frame;
00673     dag->latnode_alloc = listelem_alloc_init(sizeof(ps_latnode_t));
00674     dag->latlink_alloc = listelem_alloc_init(sizeof(ps_latlink_t));
00675     dag->latlink_list_alloc = listelem_alloc_init(sizeof(latlink_list_t));
00676     dag->refcount = 1;
00677     return dag;
00678 }
00679 
00680 ps_lattice_t *
00681 ps_lattice_retain(ps_lattice_t *dag)
00682 {
00683     ++dag->refcount;
00684     return dag;
00685 }
00686 
00687 int
00688 ps_lattice_free(ps_lattice_t *dag)
00689 {
00690     if (dag == NULL)
00691         return 0;
00692     if (--dag->refcount > 0)
00693         return dag->refcount;
00694     logmath_free(dag->lmath);
00695     listelem_alloc_free(dag->latnode_alloc);
00696     listelem_alloc_free(dag->latlink_alloc);
00697     listelem_alloc_free(dag->latlink_list_alloc);
00698     ckd_free(dag->hyp_str);
00699     ckd_free(dag);
00700     return 0;
00701 }
00702 
00703 logmath_t *
00704 ps_lattice_get_logmath(ps_lattice_t *dag)
00705 {
00706     return dag->lmath;
00707 }
00708 
00709 ps_latnode_iter_t *
00710 ps_latnode_iter(ps_lattice_t *dag)
00711 {
00712     return dag->nodes;
00713 }
00714 
00715 ps_latnode_iter_t *
00716 ps_latnode_iter_next(ps_latnode_iter_t *itor)
00717 {
00718     return itor->next;
00719 }
00720 
00721 void
00722 ps_latnode_iter_free(ps_latnode_iter_t *itor)
00723 {
00724     /* Do absolutely nothing. */
00725 }
00726 
00727 ps_latnode_t *
00728 ps_latnode_iter_node(ps_latnode_iter_t *itor)
00729 {
00730     return itor;
00731 }
00732 
00733 int
00734 ps_latnode_times(ps_latnode_t *node, int16 *out_fef, int16 *out_lef)
00735 {
00736     if (out_fef) *out_fef = (int16)node->fef;
00737     if (out_lef) *out_lef = (int16)node->lef;
00738     return node->sf;
00739 }
00740 
00741 char const *
00742 ps_latnode_word(ps_lattice_t *dag, ps_latnode_t *node)
00743 {
00744     return dict_wordstr(dag->dict, node->wid);
00745 }
00746 
00747 char const *
00748 ps_latnode_baseword(ps_lattice_t *dag, ps_latnode_t *node)
00749 {
00750     return dict_wordstr(dag->dict, node->basewid);
00751 }
00752 
00753 int32
00754 ps_latnode_prob(ps_lattice_t *dag, ps_latnode_t *node,
00755                 ps_latlink_t **out_link)
00756 {
00757     latlink_list_t *links;
00758     int32 bestpost = logmath_get_zero(dag->lmath);
00759 
00760     for (links = node->exits; links; links = links->next) {
00761         int32 post = links->link->alpha + links->link->beta - dag->norm;
00762         if (post > bestpost) {
00763             if (out_link) *out_link = links->link;
00764             bestpost = post;
00765         }
00766     }
00767     return bestpost;
00768 }
00769 
00770 ps_latlink_iter_t *
00771 ps_latnode_exits(ps_latnode_t *node)
00772 {
00773     return node->exits;
00774 }
00775 
00776 ps_latlink_iter_t *
00777 ps_latnode_entries(ps_latnode_t *node)
00778 {
00779     return node->entries;
00780 }
00781 
00782 ps_latlink_iter_t *
00783 ps_latlink_iter_next(ps_latlink_iter_t *itor)
00784 {
00785     return itor->next;
00786 }
00787 
00788 void
00789 ps_latlink_iter_free(ps_latlink_iter_t *itor)
00790 {
00791     /* Do absolutely nothing. */
00792 }
00793 
00794 ps_latlink_t *
00795 ps_latlink_iter_link(ps_latlink_iter_t *itor)
00796 {
00797     return itor->link;
00798 }
00799 
00800 int
00801 ps_latlink_times(ps_latlink_t *link, int16 *out_sf)
00802 {
00803     if (out_sf) {
00804         if (link->from) {
00805             *out_sf = link->from->sf;
00806         }
00807         else {
00808             *out_sf = 0;
00809         }
00810     }
00811     return link->ef;
00812 }
00813 
00814 ps_latnode_t *
00815 ps_latlink_nodes(ps_latlink_t *link, ps_latnode_t **out_src)
00816 {
00817     if (out_src) *out_src = link->from;
00818     return link->to;
00819 }
00820 
00821 char const *
00822 ps_latlink_word(ps_lattice_t *dag, ps_latlink_t *link)
00823 {
00824     if (link->from == NULL)
00825         return NULL;
00826     return dict_wordstr(dag->dict, link->from->wid);
00827 }
00828 
00829 char const *
00830 ps_latlink_baseword(ps_lattice_t *dag, ps_latlink_t *link)
00831 {
00832     if (link->from == NULL)
00833         return NULL;
00834     return dict_wordstr(dag->dict, link->from->basewid);
00835 }
00836 
00837 ps_latlink_t *
00838 ps_latlink_pred(ps_latlink_t *link)
00839 {
00840     return link->best_prev;
00841 }
00842 
00843 int32
00844 ps_latlink_prob(ps_lattice_t *dag, ps_latlink_t *link, int32 *out_ascr)
00845 {
00846     int32 post = link->alpha + link->beta - dag->norm;
00847     if (out_ascr) *out_ascr = link->ascr << SENSCR_SHIFT;
00848     return post;
00849 }
00850 
00851 char const *
00852 ps_lattice_hyp(ps_lattice_t *dag, ps_latlink_t *link)
00853 {
00854     ps_latlink_t *l;
00855     size_t len;
00856     char *c;
00857 
00858     /* Backtrace once to get hypothesis length. */
00859     len = 0;
00860     /* FIXME: There may not be a search, but actually there should be a dict. */
00861     if (dict_real_word(dag->dict, link->to->basewid))
00862         len += strlen(dict_wordstr(dag->dict, link->to->basewid)) + 1;
00863     for (l = link; l; l = l->best_prev) {
00864         if (dict_real_word(dag->dict, l->from->basewid))
00865             len += strlen(dict_wordstr(dag->dict, l->from->basewid)) + 1;
00866     }
00867 
00868     /* Backtrace again to construct hypothesis string. */
00869     ckd_free(dag->hyp_str);
00870     dag->hyp_str = ckd_calloc(1, len+1); /* extra one incase the hyp is empty */
00871     c = dag->hyp_str + len - 1;
00872     if (dict_real_word(dag->dict, link->to->basewid)) {
00873         len = strlen(dict_wordstr(dag->dict, link->to->basewid));
00874         c -= len;
00875         memcpy(c, dict_wordstr(dag->dict, link->to->basewid), len);
00876         if (c > dag->hyp_str) {
00877             --c;
00878             *c = ' ';
00879         }
00880     }
00881     for (l = link; l; l = l->best_prev) {
00882         if (dict_real_word(dag->dict, l->from->basewid)) {
00883             len = strlen(dict_wordstr(dag->dict, l->from->basewid));
00884             c -= len;
00885             memcpy(c, dict_wordstr(dag->dict, l->from->basewid), len);
00886             if (c > dag->hyp_str) {
00887                 --c;
00888                 *c = ' ';
00889             }
00890         }
00891     }
00892 
00893     return dag->hyp_str;
00894 }
00895 
00896 static void
00897 ps_lattice_compute_lscr(ps_seg_t *seg, ps_latlink_t *link, int to)
00898 {
00899     ngram_model_t *lmset;
00900 
00901     /* Language model score is included in the link score for FSG
00902      * search.  FIXME: Of course, this is sort of a hack :( */
00903     if (0 != strcmp(ps_search_name(seg->search), "ngram")) {
00904         seg->lback = 1; /* Unigram... */
00905         seg->lscr = 0;
00906         return;
00907     }
00908         
00909     lmset = ((ngram_search_t *)seg->search)->lmset;
00910 
00911     if (link->best_prev == NULL) {
00912         if (to) /* Sentence has only two words. */
00913             seg->lscr = ngram_bg_score(lmset, link->to->basewid,
00914                                        link->from->basewid, &seg->lback)
00915                 >> SENSCR_SHIFT;
00916         else {/* This is the start symbol, its lscr is always 0. */
00917             seg->lscr = 0;
00918             seg->lback = 1;
00919         }
00920     }
00921     else {
00922         /* Find the two predecessor words. */
00923         if (to) {
00924             seg->lscr = ngram_tg_score(lmset, link->to->basewid,
00925                                        link->from->basewid,
00926                                        link->best_prev->from->basewid,
00927                                        &seg->lback) >> SENSCR_SHIFT;
00928         }
00929         else {
00930             if (link->best_prev->best_prev)
00931                 seg->lscr = ngram_tg_score(lmset, link->from->basewid,
00932                                            link->best_prev->from->basewid,
00933                                            link->best_prev->best_prev->from->basewid,
00934                                            &seg->lback) >> SENSCR_SHIFT;
00935             else
00936                 seg->lscr = ngram_bg_score(lmset, link->from->basewid,
00937                                            link->best_prev->from->basewid,
00938                                            &seg->lback) >> SENSCR_SHIFT;
00939         }
00940     }
00941 }
00942 
00943 static void
00944 ps_lattice_link2itor(ps_seg_t *seg, ps_latlink_t *link, int to)
00945 {
00946     dag_seg_t *itor = (dag_seg_t *)seg;
00947     ps_latnode_t *node;
00948 
00949     if (to) {
00950         node = link->to;
00951         seg->ef = node->lef;
00952         seg->prob = 0; /* norm + beta - norm */
00953     }
00954     else {
00955         latlink_list_t *x;
00956         ps_latnode_t *n;
00957         logmath_t *lmath = ps_search_acmod(seg->search)->lmath;
00958 
00959         node = link->from;
00960         seg->ef = link->ef;
00961         seg->prob = link->alpha + link->beta - itor->norm;
00962         /* Sum over all exits for this word and any alternate
00963            pronunciations at the same frame. */
00964         for (n = node; n; n = n->alt) {
00965             for (x = n->exits; x; x = x->next) {
00966                 if (x->link == link)
00967                     continue;
00968                 seg->prob = logmath_add(lmath, seg->prob,
00969                                         x->link->alpha + x->link->beta - itor->norm);
00970             }
00971         }
00972     }
00973     seg->word = dict_wordstr(ps_search_dict(seg->search), node->wid);
00974     seg->sf = node->sf;
00975     seg->ascr = link->ascr << SENSCR_SHIFT;
00976     /* Compute language model score from best predecessors. */
00977     ps_lattice_compute_lscr(seg, link, to);
00978 }
00979 
00980 static void
00981 ps_lattice_seg_free(ps_seg_t *seg)
00982 {
00983     dag_seg_t *itor = (dag_seg_t *)seg;
00984     
00985     ckd_free(itor->links);
00986     ckd_free(itor);
00987 }
00988 
00989 static ps_seg_t *
00990 ps_lattice_seg_next(ps_seg_t *seg)
00991 {
00992     dag_seg_t *itor = (dag_seg_t *)seg;
00993 
00994     ++itor->cur;
00995     if (itor->cur == itor->n_links + 1) {
00996         ps_lattice_seg_free(seg);
00997         return NULL;
00998     }
00999     else if (itor->cur == itor->n_links) {
01000         /* Re-use the last link but with the "to" node. */
01001         ps_lattice_link2itor(seg, itor->links[itor->cur - 1], TRUE);
01002     }
01003     else {
01004         ps_lattice_link2itor(seg, itor->links[itor->cur], FALSE);
01005     }
01006 
01007     return seg;
01008 }
01009 
01010 static ps_segfuncs_t ps_lattice_segfuncs = {
01011     /* seg_next */ ps_lattice_seg_next,
01012     /* seg_free */ ps_lattice_seg_free
01013 };
01014 
01015 ps_seg_t *
01016 ps_lattice_seg_iter(ps_lattice_t *dag, ps_latlink_t *link, float32 lwf)
01017 {
01018     dag_seg_t *itor;
01019     ps_latlink_t *l;
01020     int cur;
01021 
01022     /* Calling this an "iterator" is a bit of a misnomer since we have
01023      * to get the entire backtrace in order to produce it.
01024      */
01025     itor = ckd_calloc(1, sizeof(*itor));
01026     itor->base.vt = &ps_lattice_segfuncs;
01027     itor->base.search = dag->search;
01028     itor->base.lwf = lwf;
01029     itor->n_links = 0;
01030     itor->norm = dag->norm;
01031 
01032     for (l = link; l; l = l->best_prev) {
01033         ++itor->n_links;
01034     }
01035     if (itor->n_links == 0) {
01036         ckd_free(itor);
01037         return NULL;
01038     }
01039 
01040     itor->links = ckd_calloc(itor->n_links, sizeof(*itor->links));
01041     cur = itor->n_links - 1;
01042     for (l = link; l; l = l->best_prev) {
01043         itor->links[cur] = l;
01044         --cur;
01045     }
01046 
01047     ps_lattice_link2itor((ps_seg_t *)itor, itor->links[0], FALSE);
01048     return (ps_seg_t *)itor;
01049 }
01050 
01051 latlink_list_t *
01052 latlink_list_new(ps_lattice_t *dag, ps_latlink_t *link, latlink_list_t *next)
01053 {
01054     latlink_list_t *ll;
01055 
01056     ll = listelem_malloc(dag->latlink_list_alloc);
01057     ll->link = link;
01058     ll->next = next;
01059 
01060     return ll;
01061 }
01062 
01063 void
01064 ps_lattice_pushq(ps_lattice_t *dag, ps_latlink_t *link)
01065 {
01066     if (dag->q_head == NULL)
01067         dag->q_head = dag->q_tail = latlink_list_new(dag, link, NULL);
01068     else {
01069         dag->q_tail->next = latlink_list_new(dag, link, NULL);
01070         dag->q_tail = dag->q_tail->next;
01071     }
01072 
01073 }
01074 
01075 ps_latlink_t *
01076 ps_lattice_popq(ps_lattice_t *dag)
01077 {
01078     latlink_list_t *x;
01079     ps_latlink_t *link;
01080 
01081     if (dag->q_head == NULL)
01082         return NULL;
01083     link = dag->q_head->link;
01084     x = dag->q_head->next;
01085     listelem_free(dag->latlink_list_alloc, dag->q_head);
01086     dag->q_head = x;
01087     if (dag->q_head == NULL)
01088         dag->q_tail = NULL;
01089     return link;
01090 }
01091 
01092 void
01093 ps_lattice_delq(ps_lattice_t *dag)
01094 {
01095     while (ps_lattice_popq(dag)) {
01096         /* Do nothing. */
01097     }
01098 }
01099 
01100 ps_latlink_t *
01101 ps_lattice_traverse_edges(ps_lattice_t *dag, ps_latnode_t *start, ps_latnode_t *end)
01102 {
01103     ps_latnode_t *node;
01104     latlink_list_t *x;
01105 
01106     /* Cancel any unfinished traversal. */
01107     ps_lattice_delq(dag);
01108 
01109     /* Initialize node fanin counts and path scores. */
01110     for (node = dag->nodes; node; node = node->next)
01111         node->info.fanin = 0;
01112     for (node = dag->nodes; node; node = node->next) {
01113         for (x = node->exits; x; x = x->next)
01114             (x->link->to->info.fanin)++;
01115     }
01116 
01117     /* Initialize agenda with all exits from start. */
01118     if (start == NULL) start = dag->start;
01119     for (x = start->exits; x; x = x->next)
01120         ps_lattice_pushq(dag, x->link);
01121 
01122     /* Pull the first edge off the queue. */
01123     return ps_lattice_traverse_next(dag, end);
01124 }
01125 
01126 ps_latlink_t *
01127 ps_lattice_traverse_next(ps_lattice_t *dag, ps_latnode_t *end)
01128 {
01129     ps_latlink_t *next;
01130 
01131     next = ps_lattice_popq(dag);
01132     if (next == NULL)
01133         return NULL;
01134 
01135     /* Decrease fanin count for destination node and expand outgoing
01136      * edges if all incoming edges have been seen. */
01137     --next->to->info.fanin;
01138     if (next->to->info.fanin == 0) {
01139         latlink_list_t *x;
01140 
01141         if (end == NULL) end = dag->end;
01142         if (next->to == end) {
01143             /* If we have traversed all links entering the end node,
01144              * clear the queue, causing future calls to this function
01145              * to return NULL. */
01146             ps_lattice_delq(dag);
01147             return next;
01148         }
01149 
01150         /* Extend all outgoing edges. */
01151         for (x = next->to->exits; x; x = x->next)
01152             ps_lattice_pushq(dag, x->link);
01153     }
01154     return next;
01155 }
01156 
01157 ps_latlink_t *
01158 ps_lattice_reverse_edges(ps_lattice_t *dag, ps_latnode_t *start, ps_latnode_t *end)
01159 {
01160     ps_latnode_t *node;
01161     latlink_list_t *x;
01162 
01163     /* Cancel any unfinished traversal. */
01164     ps_lattice_delq(dag);
01165 
01166     /* Initialize node fanout counts and path scores. */
01167     for (node = dag->nodes; node; node = node->next) {
01168         node->info.fanin = 0;
01169         for (x = node->exits; x; x = x->next)
01170             ++node->info.fanin;
01171     }
01172 
01173     /* Initialize agenda with all entries from end. */
01174     if (end == NULL) end = dag->end;
01175     for (x = end->entries; x; x = x->next)
01176         ps_lattice_pushq(dag, x->link);
01177 
01178     /* Pull the first edge off the queue. */
01179     return ps_lattice_reverse_next(dag, start);
01180 }
01181 
01182 ps_latlink_t *
01183 ps_lattice_reverse_next(ps_lattice_t *dag, ps_latnode_t *start)
01184 {
01185     ps_latlink_t *next;
01186 
01187     next = ps_lattice_popq(dag);
01188     if (next == NULL)
01189         return NULL;
01190 
01191     /* Decrease fanout count for source node and expand incoming
01192      * edges if all incoming edges have been seen. */
01193     --next->from->info.fanin;
01194     if (next->from->info.fanin == 0) {
01195         latlink_list_t *x;
01196 
01197         if (start == NULL) start = dag->start;
01198         if (next->from == start) {
01199             /* If we have traversed all links entering the start node,
01200              * clear the queue, causing future calls to this function
01201              * to return NULL. */
01202             ps_lattice_delq(dag);
01203             return next;
01204         }
01205 
01206         /* Extend all outgoing edges. */
01207         for (x = next->from->entries; x; x = x->next)
01208             ps_lattice_pushq(dag, x->link);
01209     }
01210     return next;
01211 }
01212 
01213 /*
01214  * Find the best score from dag->start to end point of any link and
01215  * use it to update links further down the path.  This is like
01216  * single-source shortest path search, except that it is done over
01217  * edges rather than nodes, which allows us to do exact trigram scoring.
01218  *
01219  * Helpfully enough, we get half of the posterior probability
01220  * calculation for free that way too.  (interesting research topic: is
01221  * there a reliable Viterbi analogue to word-level Forward-Backward
01222  * like there is for state-level?  Or, is it just lattice density?)
01223  */
01224 ps_latlink_t *
01225 ps_lattice_bestpath(ps_lattice_t *dag, ngram_model_t *lmset,
01226                     float32 lwf, float32 ascale)
01227 {
01228     ps_search_t *search;
01229     ps_latnode_t *node;
01230     ps_latlink_t *link;
01231     ps_latlink_t *bestend;
01232     latlink_list_t *x;
01233     logmath_t *lmath;
01234     int32 bestescr;
01235 
01236     search = dag->search;
01237     lmath = dag->lmath;
01238 
01239     /* Initialize path scores for all links exiting dag->start, and
01240      * set all other scores to the minimum.  Also initialize alphas to
01241      * log-zero. */
01242     for (node = dag->nodes; node; node = node->next) {
01243         for (x = node->exits; x; x = x->next) {
01244             x->link->path_scr = MAX_NEG_INT32;
01245             x->link->alpha = logmath_get_zero(lmath);
01246         }
01247     }
01248     for (x = dag->start->exits; x; x = x->next) {
01249         int32 n_used;
01250 
01251         /* Ignore filler words. */
01252         if (dict_filler_word(ps_search_dict(search), x->link->to->basewid)
01253             && x->link->to != dag->end)
01254             continue;
01255 
01256         /* Best path points to dag->start, obviously. */
01257         if (lmset)
01258             x->link->path_scr = x->link->ascr +
01259                 (ngram_bg_score(lmset, x->link->to->basewid,
01260                                 ps_search_start_wid(search), &n_used) 
01261                  >> SENSCR_SHIFT)
01262                  * lwf;
01263         else
01264             x->link->path_scr = x->link->ascr;
01265         x->link->best_prev = NULL;
01266         /* No predecessors for start links. */
01267         x->link->alpha = 0;
01268     }
01269 
01270     /* Traverse the edges in the graph, updating path scores. */
01271     for (link = ps_lattice_traverse_edges(dag, NULL, NULL);
01272          link; link = ps_lattice_traverse_next(dag, NULL)) {
01273         int32 bprob, n_used;
01274 
01275         /* Skip filler nodes in traversal. */
01276         if (dict_filler_word(ps_search_dict(search), link->from->basewid) && link->from != dag->start)
01277             continue;
01278         if (dict_filler_word(ps_search_dict(search), link->to->basewid) && link->to != dag->end)
01279             continue;
01280 
01281         /* Sanity check, we should not be traversing edges that
01282          * weren't previously updated, otherwise nasty overflows will result. */
01283         assert(link->path_scr != MAX_NEG_INT32);
01284 
01285         /* Calculate common bigram probability for all alphas. */
01286         if (lmset)
01287             bprob = ngram_ng_prob(lmset,
01288                                   link->to->basewid,
01289                                   &link->from->basewid, 1, &n_used);
01290         else
01291             bprob = 0;
01292         /* Add in this link's acoustic score, which was a constant
01293            factor in previous computations (if any). */
01294         link->alpha += (link->ascr << SENSCR_SHIFT) * ascale;
01295 
01296         /* Update scores for all paths exiting link->to. */
01297         for (x = link->to->exits; x; x = x->next) {
01298             int32 tscore, score;
01299 
01300             /* Skip links to filler words in update. */
01301             if (dict_filler_word(ps_search_dict(search), x->link->to->basewid)
01302                 && x->link->to != dag->end)
01303                 continue;
01304 
01305             /* Update alpha with sum of previous alphas. */
01306             x->link->alpha = logmath_add(lmath, x->link->alpha, link->alpha + bprob);
01307             /* Calculate trigram score for bestpath. */
01308             if (lmset)
01309                 tscore = (ngram_tg_score(lmset, x->link->to->basewid,
01310                                         link->to->basewid,
01311                                         link->from->basewid, &n_used) >> SENSCR_SHIFT)
01312                     * lwf;
01313             else
01314                 tscore = 0;
01315             /* Update link score with maximum link score. */
01316             score = link->path_scr + tscore + x->link->ascr;
01317             if (score BETTER_THAN x->link->path_scr) {
01318                 x->link->path_scr = score;
01319                 x->link->best_prev = link;
01320             }
01321         }
01322     }
01323 
01324     /* Find best link entering final node, and calculate normalizer
01325      * for posterior probabilities. */
01326     bestend = NULL;
01327     bestescr = MAX_NEG_INT32;
01328 
01329     /* Normalizer is the alpha for the imaginary link exiting the
01330        final node. */
01331     dag->norm = logmath_get_zero(lmath);
01332     for (x = dag->end->entries; x; x = x->next) {
01333         int32 bprob, n_used;
01334 
01335         if (dict_filler_word(ps_search_dict(search), x->link->from->basewid))
01336             continue;
01337         if (lmset)
01338             bprob = ngram_ng_prob(lmset,
01339                                   x->link->to->basewid,
01340                                   &x->link->from->basewid, 1, &n_used);
01341         else
01342             bprob = 0;
01343         dag->norm = logmath_add(lmath, dag->norm, x->link->alpha + bprob);
01344         if (x->link->path_scr BETTER_THAN bestescr) {
01345             bestescr = x->link->path_scr;
01346             bestend = x->link;
01347         }
01348     }
01349     /* FIXME: floating point... */
01350     dag->norm += (int32)(dag->final_node_ascr << SENSCR_SHIFT) * ascale;
01351 
01352     E_INFO("Normalizer P(O) = alpha(%s:%d:%d) = %d\n",
01353            dict_wordstr(dag->search->dict, dag->end->wid),
01354            dag->end->sf, dag->end->lef,
01355            dag->norm);
01356     return bestend;
01357 }
01358 
01359 static int32
01360 ps_lattice_joint(ps_lattice_t *dag, ps_latlink_t *link, float32 ascale)
01361 {
01362     ngram_model_t *lmset;
01363     int32 jprob;
01364 
01365     /* Sort of a hack... */
01366     if (dag->search && 0 == strcmp(ps_search_name(dag->search), "ngram"))
01367         lmset = ((ngram_search_t *)dag->search)->lmset;
01368     else
01369         lmset = NULL;
01370 
01371     jprob = (dag->final_node_ascr << SENSCR_SHIFT) * ascale;
01372     while (link) {
01373         if (lmset) {
01374             int lback;
01375             /* Compute unscaled language model probability.  Note that
01376                this is actually not the language model probability
01377                that corresponds to this link, but that is okay,
01378                because we are just taking the sum over all links in
01379                the best path. */
01380             jprob += ngram_ng_prob(lmset, link->to->basewid,
01381                                    &link->from->basewid, 1, &lback);
01382         }
01383         /* If there is no language model, we assume that the language
01384            model probability (such as it is) has been included in the
01385            link score. */
01386         jprob += (link->ascr << SENSCR_SHIFT) * ascale;
01387         link = link->best_prev;
01388     }
01389 
01390     E_INFO("Joint P(O,S) = %d P(S|O) = %d\n", jprob, jprob - dag->norm);
01391     return jprob;
01392 }
01393 
01394 int32
01395 ps_lattice_posterior(ps_lattice_t *dag, ngram_model_t *lmset,
01396                      float32 ascale)
01397 {
01398     ps_search_t *search;
01399     logmath_t *lmath;
01400     ps_latnode_t *node;
01401     ps_latlink_t *link;
01402     latlink_list_t *x;
01403     ps_latlink_t *bestend;
01404     int32 bestescr;
01405 
01406     search = dag->search;
01407     lmath = dag->lmath;
01408 
01409     /* Reset all betas to zero. */
01410     for (node = dag->nodes; node; node = node->next) {
01411         for (x = node->exits; x; x = x->next) {
01412             x->link->beta = logmath_get_zero(lmath);
01413         }
01414     }
01415 
01416     bestend = NULL;
01417     bestescr = MAX_NEG_INT32;
01418     /* Accumulate backward probabilities for all links. */
01419     for (link = ps_lattice_reverse_edges(dag, NULL, NULL);
01420          link; link = ps_lattice_reverse_next(dag, NULL)) {
01421         int32 bprob, n_used;
01422 
01423         /* Skip filler nodes in traversal. */
01424         if (dict_filler_word(ps_search_dict(search), link->from->basewid) && link->from != dag->start)
01425             continue;
01426         if (dict_filler_word(ps_search_dict(search), link->to->basewid) && link->to != dag->end)
01427             continue;
01428 
01429         /* Calculate LM probability. */
01430         if (lmset)
01431             bprob = ngram_ng_prob(lmset, link->to->basewid,
01432                                   &link->from->basewid, 1, &n_used);
01433         else
01434             bprob = 0;
01435 
01436         if (link->to == dag->end) {
01437             /* Track the best path - we will backtrace in order to
01438                calculate the unscaled joint probability for sentence
01439                posterior. */
01440             if (link->path_scr BETTER_THAN bestescr) {
01441                 bestescr = link->path_scr;
01442                 bestend = link;
01443             }
01444             /* Imaginary exit link from final node has beta = 1.0 */
01445             link->beta = bprob + (dag->final_node_ascr << SENSCR_SHIFT) * ascale;
01446         }
01447         else {
01448             /* Update beta from all outgoing betas. */
01449             for (x = link->to->exits; x; x = x->next) {
01450                 if (dict_filler_word(ps_search_dict(search), x->link->to->basewid) && x->link->to != dag->end)
01451                     continue;
01452                 link->beta = logmath_add(lmath, link->beta,
01453                                          x->link->beta + bprob
01454                                          + (x->link->ascr << SENSCR_SHIFT) * ascale);
01455             }
01456         }
01457     }
01458 
01459     /* Return P(S|O) = P(O,S)/P(O) */
01460     return ps_lattice_joint(dag, bestend, ascale) - dag->norm;
01461 }
01462 
01463 int32
01464 ps_lattice_posterior_prune(ps_lattice_t *dag, int32 beam)
01465 {
01466     ps_latlink_t *link;
01467     int npruned = 0;
01468 
01469     for (link = ps_lattice_traverse_edges(dag, dag->start, dag->end);
01470          link; link = ps_lattice_traverse_next(dag, dag->end)) {
01471         link->from->reachable = FALSE;
01472         if (link->alpha + link->beta - dag->norm < beam) {
01473             latlink_list_t *x, *tmp, *next;
01474             tmp = NULL;
01475             for (x = link->from->exits; x; x = next) {
01476                 next = x->next;
01477                 if (x->link == link) {
01478                     listelem_free(dag->latlink_list_alloc, x);
01479                 }
01480                 else {
01481                     x->next = tmp;
01482                     tmp = x;
01483                 }
01484             }
01485             link->from->exits = tmp;
01486             tmp = NULL;
01487             for (x = link->to->entries; x; x = next) {
01488                 next = x->next;
01489                 if (x->link == link) {
01490                     listelem_free(dag->latlink_list_alloc, x);
01491                 }
01492                 else {
01493                     x->next = tmp;
01494                     tmp = x;
01495                 }
01496             }
01497             link->to->entries = tmp;
01498             listelem_free(dag->latlink_alloc, link);
01499             ++npruned;
01500         }
01501     }
01502     dag_mark_reachable(dag->end);
01503     ps_lattice_delete_unreachable(dag);
01504     return npruned;
01505 }
01506 
01507 
/* Parameters to prune n-best alternatives search */
#define MAX_PATHS       500     /* Max allowed active paths at any time */
#define MAX_HYP_TRIES   10000   /* NOTE(review): not referenced in this part of the file; presumably a cap on hypotheses tried - confirm */
01511 
01512 /*
01513  * For each node in any path between from and end of utt, find the
01514  * best score from "from".sf to end of utt.  (NOTE: Uses bigram probs;
01515  * this is an estimate of the best score from "from".)  (NOTE #2: yes,
01516  * this is the "heuristic score" used in A* search)
01517  */
01518 static int32
01519 best_rem_score(ps_astar_t *nbest, ps_latnode_t * from)
01520 {
01521     ps_lattice_t *dag;
01522     latlink_list_t *x;
01523     int32 bestscore, score;
01524 
01525     dag = nbest->dag;
01526     if (from->info.rem_score <= 0)
01527         return (from->info.rem_score);
01528 
01529     /* Best score from "from" to end of utt not known; compute from successors */
01530     bestscore = WORST_SCORE;
01531     for (x = from->exits; x; x = x->next) {
01532         int32 n_used;
01533 
01534         score = best_rem_score(nbest, x->link->to);
01535         score += x->link->ascr;
01536         if (nbest->lmset)
01537             score += (ngram_bg_score(nbest->lmset, x->link->to->basewid,
01538                                      from->basewid, &n_used) >> SENSCR_SHIFT)
01539                       * nbest->lwf;
01540         if (score BETTER_THAN bestscore)
01541             bestscore = score;
01542     }
01543     from->info.rem_score = bestscore;
01544 
01545     return bestscore;
01546 }
01547 
01548 /*
01549  * Insert newpath in sorted (by path score) list of paths.  But if newpath is
01550  * too far down the list, drop it (FIXME: necessary?)
01551  * total_score = path score (newpath) + rem_score to end of utt.
01552  */
01553 static void
01554 path_insert(ps_astar_t *nbest, ps_latpath_t *newpath, int32 total_score)
01555 {
01556     ps_lattice_t *dag;
01557     ps_latpath_t *prev, *p;
01558     int32 i;
01559 
01560     dag = nbest->dag;
01561     prev = NULL;
01562     for (i = 0, p = nbest->path_list; (i < MAX_PATHS) && p; p = p->next, i++) {
01563         if ((p->score + p->node->info.rem_score) < total_score)
01564             break;
01565         prev = p;
01566     }
01567 
01568     /* newpath should be inserted between prev and p */
01569     if (i < MAX_PATHS) {
01570         /* Insert new partial hyp */
01571         newpath->next = p;
01572         if (!prev)
01573             nbest->path_list = newpath;
01574         else
01575             prev->next = newpath;
01576         if (!p)
01577             nbest->path_tail = newpath;
01578 
01579         nbest->n_path++;
01580         nbest->n_hyp_insert++;
01581         nbest->insert_depth += i;
01582     }
01583     else {
01584         /* newpath score too low; reject it and also prune paths beyond MAX_PATHS */
01585         nbest->path_tail = prev;
01586         prev->next = NULL;
01587         nbest->n_path = MAX_PATHS;
01588         listelem_free(nbest->latpath_alloc, newpath);
01589 
01590         nbest->n_hyp_reject++;
01591         for (; p; p = newpath) {
01592             newpath = p->next;
01593             listelem_free(nbest->latpath_alloc, p);
01594             nbest->n_hyp_reject++;
01595         }
01596     }
01597 }
01598 
01599 /* Find all possible extensions to given partial path */
01600 static void
01601 path_extend(ps_astar_t *nbest, ps_latpath_t * path)
01602 {
01603     latlink_list_t *x;
01604     ps_latpath_t *newpath;
01605     int32 total_score, tail_score;
01606     ps_lattice_t *dag;
01607 
01608     dag = nbest->dag;
01609 
01610     /* Consider all successors of path->node */
01611     for (x = path->node->exits; x; x = x->next) {
01612         int32 n_used;
01613 
01614         /* Skip successor if no path from it reaches the final node */
01615         if (x->link->to->info.rem_score <= WORST_SCORE)
01616             continue;
01617 
01618         /* Create path extension and compute exact score for this extension */
01619         newpath = listelem_malloc(nbest->latpath_alloc);
01620         newpath->node = x->link->to;
01621         newpath->parent = path;
01622         newpath->score = path->score + x->link->ascr;
01623         if (nbest->lmset) {
01624             if (path->parent) {
01625                 newpath->score += nbest->lwf
01626                     * (ngram_tg_score(nbest->lmset, newpath->node->basewid,
01627                                       path->node->basewid,
01628                                       path->parent->node->basewid, &n_used)
01629                        >> SENSCR_SHIFT);
01630             }
01631             else 
01632                 newpath->score += nbest->lwf
01633                     * (ngram_bg_score(nbest->lmset, newpath->node->basewid,
01634                                       path->node->basewid, &n_used)
01635                        >> SENSCR_SHIFT);
01636         }
01637 
01638         /* Insert new partial path hypothesis into sorted path_list */
01639         nbest->n_hyp_tried++;
01640         total_score = newpath->score + newpath->node->info.rem_score;
01641 
01642         /* First see if hyp would be worse than the worst */
01643         if (nbest->n_path >= MAX_PATHS) {
01644             tail_score =
01645                 nbest->path_tail->score
01646                 + nbest->path_tail->node->info.rem_score;
01647             if (total_score < tail_score) {
01648                 listelem_free(nbest->latpath_alloc, newpath);
01649                 nbest->n_hyp_reject++;
01650                 continue;
01651             }
01652         }
01653 
01654         path_insert(nbest, newpath, total_score);
01655     }
01656 }
01657 
01658 ps_astar_t *
01659 ps_astar_start(ps_lattice_t *dag,
01660                   ngram_model_t *lmset,
01661                   float32 lwf,
01662                   int sf, int ef,
01663                   int w1, int w2)
01664 {
01665     ps_astar_t *nbest;
01666     ps_latnode_t *node;
01667 
01668     nbest = ckd_calloc(1, sizeof(*nbest));
01669     nbest->dag = dag;
01670     nbest->lmset = lmset;
01671     nbest->lwf = lwf;
01672     nbest->sf = sf;
01673     if (ef < 0)
01674         nbest->ef = dag->n_frames + 1;
01675     else
01676         nbest->ef = ef;
01677     nbest->w1 = w1;
01678     nbest->w2 = w2;
01679     nbest->latpath_alloc = listelem_alloc_init(sizeof(ps_latpath_t));
01680 
01681     /* Initialize rem_score (A* heuristic) to default values */
01682     for (node = dag->nodes; node; node = node->next) {
01683         if (node == dag->end)
01684             node->info.rem_score = 0;
01685         else if (node->exits == NULL)
01686             node->info.rem_score = WORST_SCORE;
01687         else
01688             node->info.rem_score = 1;   /* +ve => unknown value */
01689     }
01690 
01691     /* Create initial partial hypotheses list consisting of nodes starting at sf */
01692     nbest->path_list = nbest->path_tail = NULL;
01693     for (node = dag->nodes; node; node = node->next) {
01694         if (node->sf == sf) {
01695             ps_latpath_t *path;
01696             int32 n_used;
01697 
01698             best_rem_score(nbest, node);
01699             path = listelem_malloc(nbest->latpath_alloc);
01700             path->node = node;
01701             path->parent = NULL;
01702             if (nbest->lmset)
01703                 path->score = nbest->lwf *
01704                     (w1 < 0)
01705                     ? ngram_bg_score(nbest->lmset, node->basewid, w2, &n_used)
01706                     : ngram_tg_score(nbest->lmset, node->basewid, w2, w1, &n_used);
01707             else
01708                 path->score = 0;
01709             path->score >>= SENSCR_SHIFT;
01710             path_insert(nbest, path, path->score + node->info.rem_score);
01711         }
01712     }
01713 
01714     return nbest;
01715 }
01716 
01717 ps_latpath_t *
01718 ps_astar_next(ps_astar_t *nbest)
01719 {
01720     ps_lattice_t *dag;
01721 
01722     dag = nbest->dag;
01723 
01724     /* Pop the top (best) partial hypothesis */
01725     while ((nbest->top = nbest->path_list) != NULL) {
01726         nbest->path_list = nbest->path_list->next;
01727         if (nbest->top == nbest->path_tail)
01728             nbest->path_tail = NULL;
01729         nbest->n_path--;
01730 
01731         /* Complete hypothesis? */
01732         if ((nbest->top->node->sf >= nbest->ef)
01733             || ((nbest->top->node == dag->end) &&
01734                 (nbest->ef > dag->end->sf))) {
01735             /* FIXME: Verify that it is non-empty.  Also we may want
01736              * to verify that it is actually distinct from other
01737              * paths, since often this is not the case*/
01738             return nbest->top;
01739         }
01740         else {
01741             if (nbest->top->node->fef < nbest->ef)
01742                 path_extend(nbest, nbest->top);
01743         }
01744     }
01745 
01746     /* Did not find any more paths to extend. */
01747     return NULL;
01748 }
01749 
01750 char const *
01751 ps_astar_hyp(ps_astar_t *nbest, ps_latpath_t *path)
01752 {
01753     ps_search_t *search;
01754     ps_latpath_t *p;
01755     size_t len;
01756     char *c;
01757     char *hyp;
01758 
01759     search = nbest->dag->search;
01760 
01761     /* Backtrace once to get hypothesis length. */
01762     len = 0;
01763     for (p = path; p; p = p->parent) {
01764         if (dict_real_word(ps_search_dict(search), p->node->basewid))
01765             len += strlen(dict_wordstr(ps_search_dict(search), p->node->basewid)) + 1;
01766     }
01767 
01768     if (len == 0) {
01769         return NULL;
01770     }
01771 
01772     /* Backtrace again to construct hypothesis string. */
01773     hyp = ckd_calloc(1, len);
01774     c = hyp + len - 1;
01775     for (p = path; p; p = p->parent) {
01776         if (dict_real_word(ps_search_dict(search), p->node->basewid)) {
01777             len = strlen(dict_wordstr(ps_search_dict(search), p->node->basewid));
01778             c -= len;
01779             memcpy(c, dict_wordstr(ps_search_dict(search), p->node->basewid), len);
01780             if (c > hyp) {
01781                 --c;
01782                 *c = ' ';
01783             }
01784         }
01785     }
01786 
01787     nbest->hyps = glist_add_ptr(nbest->hyps, hyp);
01788     return hyp;
01789 }
01790 
01791 static void
01792 ps_astar_node2itor(astar_seg_t *itor)
01793 {
01794     ps_seg_t *seg = (ps_seg_t *)itor;
01795     ps_latnode_t *node;
01796 
01797     assert(itor->cur < itor->n_nodes);
01798     node = itor->nodes[itor->cur];
01799     if (itor->cur == itor->n_nodes - 1)
01800         seg->ef = node->lef;
01801     else
01802         seg->ef = itor->nodes[itor->cur + 1]->sf - 1;
01803     seg->word = dict_wordstr(ps_search_dict(seg->search), node->wid);
01804     seg->sf = node->sf;
01805     seg->prob = 0; /* FIXME: implement forward-backward */
01806 }
01807 
01808 static void
01809 ps_astar_seg_free(ps_seg_t *seg)
01810 {
01811     astar_seg_t *itor = (astar_seg_t *)seg;
01812     ckd_free(itor->nodes);
01813     ckd_free(itor);
01814 }
01815 
01816 static ps_seg_t *
01817 ps_astar_seg_next(ps_seg_t *seg)
01818 {
01819     astar_seg_t *itor = (astar_seg_t *)seg;
01820 
01821     ++itor->cur;
01822     if (itor->cur == itor->n_nodes) {
01823         ps_astar_seg_free(seg);
01824         return NULL;
01825     }
01826     else {
01827         ps_astar_node2itor(itor);
01828     }
01829 
01830     return seg;
01831 }
01832 
01833 static ps_segfuncs_t ps_astar_segfuncs = {
01834     /* seg_next */ ps_astar_seg_next,
01835     /* seg_free */ ps_astar_seg_free
01836 };
01837 
01838 ps_seg_t *
01839 ps_astar_seg_iter(ps_astar_t *astar, ps_latpath_t *path, float32 lwf)
01840 {
01841     astar_seg_t *itor;
01842     ps_latpath_t *p;
01843     int cur;
01844 
01845     /* Backtrace and make an iterator, this should look familiar by now. */
01846     itor = ckd_calloc(1, sizeof(*itor));
01847     itor->base.vt = &ps_astar_segfuncs;
01848     itor->base.search = astar->dag->search;
01849     itor->base.lwf = lwf;
01850     itor->n_nodes = itor->cur = 0;
01851     for (p = path; p; p = p->parent) {
01852         ++itor->n_nodes;
01853     }
01854     itor->nodes = ckd_calloc(itor->n_nodes, sizeof(*itor->nodes));
01855     cur = itor->n_nodes - 1;
01856     for (p = path; p; p = p->parent) {
01857         itor->nodes[cur] = p->node;
01858         --cur;
01859     }
01860 
01861     ps_astar_node2itor(itor);
01862     return (ps_seg_t *)itor;
01863 }
01864 
01865 void
01866 ps_astar_finish(ps_astar_t *nbest)
01867 {
01868     gnode_t *gn;
01869 
01870     /* Free all hyps. */
01871     for (gn = nbest->hyps; gn; gn = gnode_next(gn)) {
01872         ckd_free(gnode_ptr(gn));
01873     }
01874     glist_free(nbest->hyps);
01875     /* Free all paths. */
01876     listelem_alloc_free(nbest->latpath_alloc);
01877     /* Free the Henge. */
01878     ckd_free(nbest);
01879 }