1 /*************************************************************************/
2 /* */
3 /* Language Technologies Institute */
4 /* Carnegie Mellon University */
5 /* Copyright (c) 1999-2003 */
6 /* All Rights Reserved. */
7 /* */
8 /* Permission is hereby granted, free of charge, to use and distribute */
9 /* this software and its documentation without restriction, including */
10 /* without limitation the rights to use, copy, modify, merge, publish, */
11 /* distribute, sublicense, and/or sell copies of this work, and to */
12 /* permit persons to whom this work is furnished to do so, subject to */
13 /* the following conditions: */
14 /* 1. The code must retain the above copyright notice, this list of */
15 /* conditions and the following disclaimer. */
16 /* 2. Any modifications must be clearly marked as such. */
17 /* 3. Original authors' names are not deleted. */
18 /* 4. The authors' names are not used to endorse or promote products */
19 /* derived from this software without specific prior written */
20 /* permission. */
21 /* */
22 /* CARNEGIE MELLON UNIVERSITY AND THE CONTRIBUTORS TO THIS WORK */
23 /* DISCLAIM ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING */
24 /* ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO EVENT */
25 /* SHALL CARNEGIE MELLON UNIVERSITY NOR THE CONTRIBUTORS BE LIABLE */
26 /* FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES */
27 /* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN */
28 /* AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, */
29 /* ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF */
30 /* THIS SOFTWARE. */
31 /* */
32 /*************************************************************************/
33 /* Author : Alan W Black */
34 /* Date : October 1999 */
35 /*-----------------------------------------------------------------------*/
36 /* */
37 /* Training method to split states of existing WFST based on data to */
38 /* optimize entropy */
39 /* */
40 /* Confusing as this has nothing to do with the modelling */
41 /* technique known as "maximum entropy" */
42 /* */
43 /*=======================================================================*/
44 #include <iostream>
45 #include <cstdlib>
46 #include "EST_WFST.h"
47 #include "wfst_aux.h"
48 #include "EST_Token.h"
49 #include "EST_simplestats.h"
50
// Register EST_WFST_Transition (as "trans") and EST_DiscreteProbDistribution
// (as "pdf") with the val/SIOD type system.  The code in this file relies on
// the resulting trans()/pdf() accessors and siod() wrappers, e.g.
// trans(car(dd)), pdf(car(...)) and siod(pdf).
// NOTE(review): the _NODEL variant presumably means the wrapped transition
// is not deleted when its wrapper goes away -- confirm against the macro.
VAL_REGISTER_TYPE_NODEL(trans,EST_WFST_Transition)
SIOD_REGISTER_CLASS(trans,EST_WFST_Transition)
VAL_REGISTER_CLASS(pdf,EST_DiscreteProbDistribution)
SIOD_REGISTER_CLASS(pdf,EST_DiscreteProbDistribution)

// Forward declarations for the helpers defined later in this file.
// All of them are file-local except find_best_trans_split().
static LISP *find_state_usage(EST_WFST &wfst, LISP data);
static double entropy(const EST_WFST_State *s);
static LISP *find_state_entropies(const EST_WFST &wfst, LISP *data);
// Single-transition split search; the only call to it in this file is
// commented out (see wfst_train()).
EST_WFST_Transition *find_best_trans_split(EST_WFST &wfst,
                                           int split_state,
                                           LISP *data);
static LISP find_best_split(EST_WFST &wfst,
                            int split_state_name,
                            LISP *data);
static double find_score_if_split(EST_WFST &wfst,
                                  int fromstate,
                                  EST_WFST_Transition *trans,
                                  LISP *data);
static LISP find_split_pdfs(EST_WFST &wfst,
                            int split_state_name,
                            LISP *data,
                            EST_DiscreteProbDistribution &pdf_all);
static double score_pdf_combine(EST_DiscreteProbDistribution &a,
                                EST_DiscreteProbDistribution &b,
                                EST_DiscreteProbDistribution &all);
// Disabled single-transition variant of split_state(); its definition
// further down is compiled out with #if 0 as well.
#if 0
static void split_state(EST_WFST &wfst, EST_WFST_Transition *trans);
#endif
static void split_state(EST_WFST &wfst, LISP trans_list, int ostate);
80
load_string_data(EST_WFST & wfst,EST_String & filename)81 LISP load_string_data(EST_WFST &wfst,EST_String &filename)
82 {
83 // Load in sentences into data table, assume sentence per line
84 EST_TokenStream ts;
85 LISP ss = NIL;
86 EST_String t;
87 int id;
88 int i,j;
89
90 if (ts.open(filename) == -1)
91 EST_error("wfst_train: failed to read data from \"%s\"",
92 (const char *)filename);
93
94 i = 0;
95 j = 0;
96 while (!ts.eof())
97 {
98 LISP s = NIL;
99 do
100 {
101 t = (EST_String)ts.get();
102 id = wfst.in_symbol(t);
103 if (id == -1)
104 {
105 cerr << "wfst_train: data contains unknown symbol \"" <<
106 t << "\"" << endl;
107 }
108 s = cons(flocons(id),s);
109 j++;
110 }
111 while (!ts.eoln() && !ts.eof());
112 i++;
113 ss = cons(reverse(s),ss);
114 }
115
116 printf("wfst_train: loaded %d lines of %d tokens\n",
117 i,j);
118
119 return reverse(ss);
120 }
121
find_state_usage(EST_WFST & wfst,LISP data)122 static LISP *find_state_usage(EST_WFST &wfst, LISP data)
123 {
124 // Builds list of states, and which data points the represent
125 LISP *state_data = new LISP[wfst.num_states()];
126 static LISP ddd = NIL;
127 int s,i,id;
128 LISP d,w;
129 EST_WFST_Transition *trans;
130 // EST_Litem *tp;
131
132 if (ddd == NIL)
133 gc_protect(&ddd);
134
135 ddd = NIL;
136
137 wfst.start_cumulate(); // zero existing weights
138
139 for (i=0; i < wfst.num_states(); i++)
140 {
141 state_data[i] = NIL;
142 ddd = cons(state_data[i],ddd);
143 // // smoothing
144 // for (tp=wfst.state(i)->transitions.head(); tp != 0; tp = tp->next())
145 // wfst.state(i)->transitions(tp)->set_weight(1);
146 }
147
148 for (i=0,d=data; d; d=cdr(d),i++)
149 {
150 s = wfst.start_state();
151 for (w=car(d); w; w=cdr(w))
152 {
153 state_data[s] = cons(w,state_data[s]);
154 id = get_c_int(car(w));
155 trans = wfst.find_transition(s,id,id);
156 if (!trans)
157 {
158 printf("sentence %d not in language, skipping\n",i);
159 continue;
160 }
161 else
162 {
163 trans->set_weight(trans->weight()+1);
164 s = trans->state();
165 }
166 }
167 }
168
169 wfst.stop_cumulate();
170 return state_data;
171 }
172
entropy(const EST_WFST_State * s)173 static double entropy(const EST_WFST_State *s)
174 {
175 double sentropy,w;
176 EST_Litem *tp;
177 for (sentropy=0,tp=s->transitions.head(); tp != 0; tp = tp->next())
178 {
179 w = s->transitions(tp)->weight(); /* the probability */
180 if (w > 0)
181 sentropy += w * log(w);
182 }
183 return -1 * sentropy;
184 }
185
wfst_train(EST_WFST & wfst,LISP data)186 void wfst_train(EST_WFST &wfst, LISP data)
187 {
188 LISP *state_data;
189 LISP *state_entropies;
190 LISP best_trans_list = NIL;
191 int c=0,i, max_entropy_state;
192 gc_protect(&data);
193
194 while (1)
195 {
196 // Build table of state to points in data, and cumulate transitions
197 state_data = find_state_usage(wfst,data);
198
199 /* find entropy for each state (sorted) */
200 state_entropies = find_state_entropies(wfst,state_data);
201
202 max_entropy_state = -1;
203 for (i=0; i < wfst.num_states(); i++)
204 {
205 // double me = (double)get_c_float(car(state_entropies[i]));
206 max_entropy_state = get_c_int(cdr(state_entropies[i]));
207 // printf("trying %d %g\n",max_entropy_state,me);
208
209 // best_trans = find_best_trans_split(wfst,max_entropy_state,
210 // state_data);
211 best_trans_list = find_best_split(wfst,max_entropy_state,
212 state_data);
213 if (best_trans_list != NIL)
214 break;
215 // else
216 // printf("No best trans\n");
217 }
218 delete [] state_entropies;
219
220 if (max_entropy_state == -1)
221 {
222 printf("No new max_entropy state\n");
223 break;
224 }
225 if (best_trans_list == NIL)
226 {
227 printf("No best_trans in max_entropy state\n");
228 break;
229 }
230
231 /* for each transition *entering* max_entropy_state */
232 /* find entropy if it were split */
233 /* find best split */
234
235 /* print stats */
236 /* some sort of stop check */
237 c++;
238 printf("c is %d\n",c);
239 if (c > 5000)
240 {
241 printf("reached cycle end %d\n",c);
242 break;
243 }
244 /* split on best split */
245 split_state(wfst, best_trans_list, max_entropy_state);
246
247 if ((c % 100) == 0)
248 {
249 EST_String chkpntname = "chkpnt";
250 char bbb[7];
251 sprintf(bbb,"%03d",c);
252 wfst.save(chkpntname+bbb+".wfst");
253 }
254
255 delete [] state_data;
256 user_gc(NIL);
257 }
258 }
259
me_compare_function(const void * a,const void * b)260 static int me_compare_function(const void *a, const void *b)
261 {
262 LISP la;
263 LISP lb;
264 la = *(LISP *)a;
265 lb = *(LISP *)b;
266
267 float fa = get_c_float(car(la));
268 float fb = get_c_float(car(lb));
269
270 if (fa < fb)
271 return 1;
272 else if (fa == fb)
273 return 0;
274 else
275 return -1;
276 }
277
find_state_entropies(const EST_WFST & wfst,LISP * data)278 static LISP *find_state_entropies(const EST_WFST &wfst, LISP *data)
279 {
280 double all_entropy = 0;
281 int i;
282 double sentropy;
283 LISP *slist = new LISP[wfst.num_states()];
284 static LISP ddd = NIL;
285
286 if (ddd == NIL)
287 gc_protect(&ddd);
288 ddd = NIL;
289
290 for (i=0; i < wfst.num_states(); i++)
291 {
292 const EST_WFST_State *s = wfst.state(i);
293 sentropy = entropy(s);
294 // printf("dlength is %d %d\n",i,siod_llength(data[i]));
295 all_entropy += sentropy * siod_llength(data[i]);
296 slist[i] = cons(flocons(sentropy),flocons(i));
297 ddd = cons(slist[i],ddd);
298 }
299 printf("average entropy is %g\n",all_entropy/i);
300
301 qsort(slist,wfst.num_states(),sizeof(LISP),me_compare_function);
302
303 return slist;
304 }
305
find_best_split(EST_WFST & wfst,int split_state_name,LISP * data)306 static LISP find_best_split(EST_WFST &wfst,
307 int split_state_name,
308 LISP *data)
309 {
310 // Find the best partition of incoming translations that
311 // minimises entropy
312 EST_DiscreteProbDistribution pdf_all(&wfst.in_symbols());
313 EST_DiscreteProbDistribution *a_pdf, *b_pdf;
314 LISP splits,s,dd,r;
315 LISP *ssplits;
316 gc_protect(&splits);
317 EST_String sname;
318 int b,best_b;
319 EST_Litem *i;
320 int num_pdfs;
321 double best_score, score, sfreq;
322
323 for (dd = data[split_state_name]; dd; dd = cdr(dd))
324 pdf_all.cumulate(get_c_int(car(car(dd))));
325 splits = find_split_pdfs(wfst,split_state_name,data,pdf_all);
326 if (siod_llength(splits) < 2)
327 return NIL;
328 ssplits = new LISP[siod_llength(splits)];
329 for (num_pdfs=0,s=splits; s != NIL; s=cdr(s),num_pdfs++)
330 ssplits[num_pdfs] = car(s);
331
332 qsort(ssplits,num_pdfs,sizeof(LISP),me_compare_function);
333 // Combine trans pdfs in pdfs until more combination doesn't improve
334 while (1)
335 {
336
337 best_score = get_c_float(car(ssplits[0]));
338 best_b = -1;
339 a_pdf = pdf(car(cdr(cdr(ssplits[0]))));
340 for (b=1; b < num_pdfs; b++)
341 {
342 if (ssplits[b] == NIL)
343 continue;
344 score = score_pdf_combine(*a_pdf,*pdf(car(cdr(cdr(ssplits[b])))),
345 pdf_all);
346 if (score < best_score)
347 {
348 best_score = score;
349 best_b = b;
350 }
351 }
352
353 // combine a and b
354 if (best_b == -1)
355 break;
356 else
357 {
358 // combine a and b
359 // Add trans to 0
360 setcar(cdr(ssplits[0]),
361 append(car(cdr(ssplits[0])),
362 car(cdr(ssplits[best_b]))));
363 setcar(ssplits[0], flocons(best_score));
364 // Update 0's pdf with values from best_b's
365 b_pdf = pdf(car(cdr(cdr(ssplits[best_b]))));
366 for (i=b_pdf->item_start(); !b_pdf->item_end(i);
367 i = b_pdf->item_next(i))
368 {
369 b_pdf->item_freq(i,sname,sfreq);
370 a_pdf->cumulate(i,sfreq);
371 }
372 ssplits[best_b] = NIL;
373 }
374
375 }
376
377 printf("score %g ",(double)get_c_float(car(ssplits[0])));
378 for (dd=car(cdr(ssplits[0])); dd; dd=cdr(dd))
379 printf("%s ",(const char *)wfst.in_symbol(trans(car(dd))->in_symbol()));
380 printf("\n");
381 gc_unprotect(&splits);
382 r = car(cdr(ssplits[0]));
383 delete [] ssplits;
384 return r;
385 }
386
score_pdf_combine(EST_DiscreteProbDistribution & a,EST_DiscreteProbDistribution & b,EST_DiscreteProbDistribution & all)387 static double score_pdf_combine(EST_DiscreteProbDistribution &a,
388 EST_DiscreteProbDistribution &b,
389 EST_DiscreteProbDistribution &all)
390 {
391 // Find score of (a+b) vs (all-(a+b))
392 EST_DiscreteProbDistribution ab(a);
393 EST_DiscreteProbDistribution all_but_ab(all);
394 EST_Litem *i;
395 EST_String sname;
396 double sfreq, score;
397 for (i=b.item_start(); !b.item_end(i);
398 i = b.item_next(i))
399 {
400 b.item_freq(i,sname,sfreq);
401 ab.cumulate(i,sfreq);
402 }
403
404 for (i=ab.item_start(); !ab.item_end(i);
405 i = ab.item_next(i))
406 {
407 ab.item_freq(i,sname,sfreq);
408 all_but_ab.cumulate(i,-1*sfreq);
409 }
410
411 score = (ab.entropy() * ab.samples()) +
412 (all_but_ab.entropy() * all_but_ab.samples());
413
414 return score;
415
416 }
417
find_split_pdfs(EST_WFST & wfst,int split_state_name,LISP * data,EST_DiscreteProbDistribution & pdf_all)418 static LISP find_split_pdfs(EST_WFST &wfst,
419 int split_state_name,
420 LISP *data,
421 EST_DiscreteProbDistribution &pdf_all)
422 {
423 // Find following pdfs for each incoming transition as if they where
424 // split to a new state
425 int i,id, in;
426 EST_Litem *tp;
427 LISP pdfs = NIL,dd,ttt,p,t;
428 EST_DiscreteProbDistribution empty;
429 double value;
430
431 for (i=0; i < wfst.num_states(); i++)
432 {
433 const EST_WFST_State *s = wfst.state(i);
434 for (tp=s->transitions.head(); tp != 0; tp = tp->next())
435 {
436 if ((s->transitions(tp)->state() == split_state_name)
437 && (s->transitions(tp)->weight() > 0))
438 {
439 in = s->transitions(tp)->in_symbol();
440 EST_DiscreteProbDistribution *pdf =
441 new EST_DiscreteProbDistribution(&wfst.in_symbols());
442 for (dd = data[i]; dd; dd = cdr(dd))
443 {
444 id = get_c_int(car(car(dd)));
445 if (id == in)
446 { // This one would go to the new state so we count it
447 if (cdr(car(dd))) // not end of data string
448 pdf->cumulate(get_c_int(car(cdr(car(dd)))));
449 }
450 }
451 // value, list of trans, pdf
452 value = score_pdf_combine(*pdf,empty,pdf_all);
453 if ((value > 0) && // ignore transitions with no data
454 (pdf->samples() > 10))// and those with only a few data pnts
455 {
456 t = siod(s->transitions(tp));
457 p = siod(pdf);
458 ttt = cons(flocons(value),
459 cons(cons(t,NIL),
460 cons(p,NIL)));
461 pdfs = cons(ttt,pdfs);
462 }
463 else
464 delete pdf;
465 }
466 }
467 }
468 return pdfs;
469 }
470
find_best_trans_split(EST_WFST & wfst,int split_state_name,LISP * data)471 EST_WFST_Transition *find_best_trans_split(EST_WFST &wfst,
472 int split_state_name,
473 LISP *data)
474 {
475 EST_Litem *tp;
476 EST_WFST_Transition *best_trans = 0;
477 const EST_WFST_State *split_state = wfst.state(split_state_name);
478 double best_score,bb;
479 int i;
480
481 best_score = entropy(split_state)*siod_llength(data[split_state_name]);
482 // printf("unsplit score %g\n",best_score);
483
484 /* For each transition going to split_state */
485 for (i=1; i < wfst.num_states(); i++)
486 {
487 const EST_WFST_State *s = wfst.state(i);
488 for (tp=s->transitions.head(); tp != 0; tp = tp->next())
489 {
490 if ((wfst.state(s->transitions(tp)->state()) == split_state) &&
491 (s->transitions(tp)->weight() > 0))
492 {
493 bb = find_score_if_split(wfst,i,s->transitions(tp),data);
494 // cout << i << " "
495 // << wfst.in_symbol(s->transitions(tp)->in_symbol()) << " "
496 // << s->transitions(tp)->state() << " " << bb << endl;
497 if (bb == -1) /* didn't find a split */
498 continue;
499 if (bb < best_score)
500 {
501 best_score = bb;
502 best_trans = s->transitions(tp);
503 }
504 }
505 }
506 }
507
508 if (best_trans)
509 cout << "best " << wfst.in_symbol(best_trans->in_symbol()) << " "
510 << best_trans->weight() << " "
511 << best_trans->state() << " " << best_score << endl;
512 return best_trans;
513 }
514
find_score_if_split(EST_WFST & wfst,int fromstate,EST_WFST_Transition * trans,LISP * data)515 static double find_score_if_split(EST_WFST &wfst,
516 int fromstate,
517 EST_WFST_Transition *trans,
518 LISP *data)
519 {
520 double ent_split;
521 double ent_remain;
522 double score;
523 EST_DiscreteProbDistribution pdf_split(&wfst.in_symbols());
524 EST_DiscreteProbDistribution pdf_remain(&wfst.in_symbols());
525 int in, tostate, id;
526 EST_Litem *i;
527 double sfreq;
528 EST_String sname;
529
530 ent_split = ent_remain = 32*32*32*32;
531 LISP dd;
532
533 // printf("considering %d %s %g %d\n",
534 // fromstate,
535 // (const char *)wfst.in_symbol(trans->in_symbol()),
536 // trans->weight(),
537 // trans->state());
538
539 /* find entropy of possible new state */
540 /* for each data point through fromstate */
541 in = trans->in_symbol();
542 for (dd = data[fromstate]; dd; dd = cdr(dd))
543 {
544 id = get_c_int(car(car(dd)));
545 if (id == in)
546 { // This one would go to the new state so we count it
547 if (cdr(car(dd))) // not end of data string
548 pdf_split.cumulate(get_c_int(car(cdr(car(dd)))));
549 }
550 }
551 if (pdf_split.samples() > 0)
552 ent_split = pdf_split.entropy();
553 /* find entropy of old state minus trans into it */
554 tostate = trans->state();
555 // Actually only need to do this once per state
556 for (dd = data[tostate]; dd; dd = cdr(dd))
557 pdf_remain.cumulate(get_c_int(car(car(dd))));
558 // Subtract the bit thats split
559 for (i=pdf_split.item_start(); !pdf_split.item_end(i);
560 i = pdf_split.item_next(i))
561 {
562 pdf_split.item_freq(i,sname,sfreq);
563 pdf_remain.cumulate(i,-1*sfreq);
564 }
565 if (pdf_remain.samples() > 0)
566 ent_remain = pdf_remain.entropy();
567
568 if ((pdf_remain.samples() == 0) ||
569 (pdf_split.samples() == 0))
570 return -1;
571
572 score = (ent_remain * pdf_remain.samples()) +
573 (ent_split * pdf_split.samples());
574 // printf("tostate %d remain %g %d split %g %d score %g\n",
575 // tostate, ent_remain, (int)pdf_remain.samples(),
576 // ent_split, (int)pdf_split.samples(), score);
577
578 return score;
579 }
580
#if 0
// Disabled single-transition variant, kept for reference.
static void split_state(EST_WFST &wfst, EST_WFST_Transition *trans)
{
    /* Split off a new state for TRANS: redirect TRANS to a newly     */
    /* added final state and give that state a zero-weight copy of    */
    /* every transition leaving the state TRANS used to go to         */
    const int new_state = wfst.add_state(wfst_final);
    const int old_state = trans->state();

    /* must be done before adding the new transitions to new_state */
    trans->set_state(new_state);

    EST_Litem *item = wfst.state(old_state)->transitions.head();
    while (item != 0)
    {
        EST_WFST_Transition *src = wfst.state(old_state)->transitions(item);
        wfst.state_non_const(new_state)->
            add_transition(0.0, /* weight will be filled in later*/
                           src->state(),
                           src->in_symbol(),
                           src->out_symbol());
        item = item->next();
    }
}
#endif
609
split_state(EST_WFST & wfst,LISP trans_list,int ostate)610 static void split_state(EST_WFST &wfst, LISP trans_list, int ostate)
611 {
612 /* Split off a new state for given trans. Add transitions */
613 /* to this new state for all transitions in (old) state trans */
614 /* goes to */
615 EST_Litem *tp;
616 int nstate = wfst.add_state(wfst_final);
617 LISP t;
618
619 /* must be done before adding the new transitions to nstate */
620 for (t=trans_list; t; t=cdr(t))
621 trans(car(t))->set_state(nstate);
622
623 for (tp=wfst.state(ostate)->transitions.head(); tp != 0; tp = tp->next())
624 {
625 wfst.state_non_const(nstate)->
626 add_transition(0.0, /* weight will be filled in later*/
627 wfst.state(ostate)->transitions(tp)->state(),
628 wfst.state(ostate)->transitions(tp)->in_symbol(),
629 wfst.state(ostate)->transitions(tp)->out_symbol());
630
631 }
632 }
633
634