1 /*************************************************************************/
2 /*                                                                       */
3 /*                  Language Technologies Institute                      */
4 /*                     Carnegie Mellon University                        */
5 /*                      Copyright (c) 1999-2003                          */
6 /*                        All Rights Reserved.                           */
7 /*                                                                       */
8 /*  Permission is hereby granted, free of charge, to use and distribute  */
9 /*  this software and its documentation without restriction, including   */
10 /*  without limitation the rights to use, copy, modify, merge, publish,  */
11 /*  distribute, sublicense, and/or sell copies of this work, and to      */
12 /*  permit persons to whom this work is furnished to do so, subject to   */
13 /*  the following conditions:                                            */
14 /*   1. The code must retain the above copyright notice, this list of    */
15 /*      conditions and the following disclaimer.                         */
16 /*   2. Any modifications must be clearly marked as such.                */
17 /*   3. Original authors' names are not deleted.                         */
18 /*   4. The authors' names are not used to endorse or promote products   */
19 /*      derived from this software without specific prior written        */
20 /*      permission.                                                      */
21 /*                                                                       */
22 /*  CARNEGIE MELLON UNIVERSITY AND THE CONTRIBUTORS TO THIS WORK         */
23 /*  DISCLAIM ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING      */
24 /*  ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO EVENT   */
25 /*  SHALL CARNEGIE MELLON UNIVERSITY NOR THE CONTRIBUTORS BE LIABLE      */
26 /*  FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES    */
27 /*  WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN   */
28 /*  AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION,          */
29 /*  ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF       */
30 /*  THIS SOFTWARE.                                                       */
31 /*                                                                       */
32 /*************************************************************************/
33 /*                     Author :  Alan W Black                            */
34 /*                     Date   :  October 1999                            */
35 /*-----------------------------------------------------------------------*/
36 /*                                                                       */
37 /* Training method to split states of existing WFST based on data to     */
38 /* optimize entropy                                                      */
39 /*                                                                       */
40 /* Confusing as this has nothing to do with the modelling                */
41 /* technique known as "maximum entropy"                                  */
42 /*                                                                       */
43 /*=======================================================================*/
44 #include <iostream>
45 #include <cstdlib>
46 #include "EST_WFST.h"
47 #include "wfst_aux.h"
48 #include "EST_Token.h"
49 #include "EST_simplestats.h"
50 
51 VAL_REGISTER_TYPE_NODEL(trans,EST_WFST_Transition)
52 SIOD_REGISTER_CLASS(trans,EST_WFST_Transition)
53 VAL_REGISTER_CLASS(pdf,EST_DiscreteProbDistribution)
54 SIOD_REGISTER_CLASS(pdf,EST_DiscreteProbDistribution)
55 
56 static LISP *find_state_usage(EST_WFST &wfst, LISP data);
57 static double entropy(const EST_WFST_State *s);
58 static LISP *find_state_entropies(const EST_WFST &wfst, LISP *data);
59 EST_WFST_Transition *find_best_trans_split(EST_WFST &wfst,
60 					   int split_state,
61 					   LISP *data);
62 static LISP find_best_split(EST_WFST &wfst,
63 			    int split_state_name,
64 			    LISP *data);
65 static double find_score_if_split(EST_WFST &wfst,
66 				  int fromstate,
67 				  EST_WFST_Transition *trans,
68 				  LISP *data);
69 static LISP find_split_pdfs(EST_WFST &wfst,
70 			    int split_state_name,
71 			    LISP *data,
72 			    EST_DiscreteProbDistribution &pdf_all);
73 static double score_pdf_combine(EST_DiscreteProbDistribution &a,
74 				EST_DiscreteProbDistribution &b,
75 				EST_DiscreteProbDistribution &all);
76 #if 0
77 static void split_state(EST_WFST &wfst, EST_WFST_Transition *trans);
78 #endif
79 static void split_state(EST_WFST &wfst, LISP trans_list, int ostate);
80 
load_string_data(EST_WFST & wfst,EST_String & filename)81 LISP load_string_data(EST_WFST &wfst,EST_String &filename)
82 {
83     // Load in sentences into data table, assume sentence per line
84     EST_TokenStream ts;
85     LISP ss = NIL;
86     EST_String t;
87     int id;
88     int i,j;
89 
90     if (ts.open(filename) == -1)
91 	EST_error("wfst_train: failed to read data from \"%s\"",
92 			  (const char *)filename);
93 
94     i = 0;
95     j = 0;
96     while (!ts.eof())
97     {
98 	LISP s = NIL;
99 	do
100 	{
101 	    t = (EST_String)ts.get();
102 	    id = wfst.in_symbol(t);
103 	    if (id == -1)
104 	    {
105 		cerr << "wfst_train: data contains unknown symbol \"" <<
106 		    t << "\"" << endl;
107 	    }
108 	    s = cons(flocons(id),s);
109 	    j++;
110 	}
111 	while (!ts.eoln() && !ts.eof());
112 	i++;
113 	ss = cons(reverse(s),ss);
114     }
115 
116     printf("wfst_train: loaded %d lines of %d tokens\n",
117 	   i,j);
118 
119     return reverse(ss);
120 }
121 
find_state_usage(EST_WFST & wfst,LISP data)122 static LISP *find_state_usage(EST_WFST &wfst, LISP data)
123 {
124     // Builds list of states, and which data points the represent
125     LISP *state_data = new LISP[wfst.num_states()];
126     static LISP ddd = NIL;
127     int s,i,id;
128     LISP d,w;
129     EST_WFST_Transition *trans;
130 //    EST_Litem *tp;
131 
132     if (ddd == NIL)
133 	gc_protect(&ddd);
134 
135     ddd = NIL;
136 
137     wfst.start_cumulate();   // zero existing weights
138 
139     for (i=0; i < wfst.num_states(); i++)
140     {
141 	state_data[i] = NIL;
142 	ddd = cons(state_data[i],ddd);
143 //	// smoothing
144 //	for (tp=wfst.state(i)->transitions.head(); tp != 0; tp = tp->next())
145 //	    wfst.state(i)->transitions(tp)->set_weight(1);
146     }
147 
148     for (i=0,d=data; d; d=cdr(d),i++)
149     {
150 	s = wfst.start_state();
151 	for (w=car(d); w; w=cdr(w))
152 	{
153 	    state_data[s] = cons(w,state_data[s]);
154 	    id = get_c_int(car(w));
155 	    trans = wfst.find_transition(s,id,id);
156 	    if (!trans)
157 	    {
158 		printf("sentence %d not in language, skipping\n",i);
159 		continue;
160 	    }
161 	    else
162 	    {
163 		trans->set_weight(trans->weight()+1);
164 		s = trans->state();
165 	    }
166 	}
167     }
168 
169     wfst.stop_cumulate();
170     return state_data;
171 }
172 
entropy(const EST_WFST_State * s)173 static double entropy(const EST_WFST_State *s)
174 {
175     double sentropy,w;
176     EST_Litem *tp;
177     for (sentropy=0,tp=s->transitions.head(); tp != 0; tp = tp->next())
178     {
179 	w = s->transitions(tp)->weight();  /* the probability */
180 	if (w > 0)
181 	    sentropy += w * log(w);
182     }
183     return -1 * sentropy;
184 }
185 
wfst_train(EST_WFST & wfst,LISP data)186 void wfst_train(EST_WFST &wfst, LISP data)
187 {
188     LISP *state_data;
189     LISP *state_entropies;
190     LISP best_trans_list = NIL;
191     int c=0,i, max_entropy_state;
192     gc_protect(&data);
193 
194     while (1)
195     {
196 	// Build table of state to points in data, and cumulate transitions
197 	state_data = find_state_usage(wfst,data);
198 
199 	/* find entropy for each state (sorted) */
200 	state_entropies = find_state_entropies(wfst,state_data);
201 
202 	max_entropy_state = -1;
203 	for (i=0; i < wfst.num_states(); i++)
204 	{
205 //	    double me = (double)get_c_float(car(state_entropies[i]));
206 	    max_entropy_state = get_c_int(cdr(state_entropies[i]));
207 //	    printf("trying %d %g\n",max_entropy_state,me);
208 
209 //	    best_trans = find_best_trans_split(wfst,max_entropy_state,
210 //					       state_data);
211 	    best_trans_list = find_best_split(wfst,max_entropy_state,
212 					      state_data);
213 	    if (best_trans_list != NIL)
214 		break;
215 //	    else
216 //		printf("No best trans\n");
217 	}
218 	delete [] state_entropies;
219 
220 	if (max_entropy_state == -1)
221 	{
222 	    printf("No new max_entropy state\n");
223 	    break;
224 	}
225 	if (best_trans_list == NIL)
226 	{
227 	    printf("No best_trans in max_entropy state\n");
228 	    break;
229 	}
230 
231         /* for each transition *entering* max_entropy_state */
232         /*     find entropy if it were split          */
233         /*     find best split                      */
234 
235         /* print stats */
236         /* some sort of stop check */
237 	c++;
238 	printf("c is %d\n",c);
239 	if (c > 5000)
240 	{
241 	    printf("reached cycle end %d\n",c);
242 	    break;
243 	}
244         /* split on best split                      */
245         split_state(wfst, best_trans_list, max_entropy_state);
246 
247 	if ((c % 100) == 0)
248 	{
249 	    EST_String chkpntname = "chkpnt";
250 	    char bbb[7];
251 	    sprintf(bbb,"%03d",c);
252 	    wfst.save(chkpntname+bbb+".wfst");
253 	}
254 
255 	delete [] state_data;
256 	user_gc(NIL);
257     }
258 }
259 
me_compare_function(const void * a,const void * b)260 static int me_compare_function(const void *a, const void *b)
261 {
262     LISP la;
263     LISP lb;
264     la = *(LISP *)a;
265     lb = *(LISP *)b;
266 
267     float fa = get_c_float(car(la));
268     float fb = get_c_float(car(lb));
269 
270     if (fa < fb)
271 	return 1;
272     else if (fa == fb)
273 	return 0;
274     else
275 	return -1;
276 }
277 
find_state_entropies(const EST_WFST & wfst,LISP * data)278 static LISP *find_state_entropies(const EST_WFST &wfst, LISP *data)
279 {
280     double all_entropy = 0;
281     int i;
282     double sentropy;
283     LISP *slist = new LISP[wfst.num_states()];
284     static LISP ddd = NIL;
285 
286     if (ddd == NIL)
287 	gc_protect(&ddd);
288     ddd = NIL;
289 
290     for (i=0; i < wfst.num_states(); i++)
291     {
292 	const EST_WFST_State *s = wfst.state(i);
293 	sentropy = entropy(s);
294 //	printf("dlength is %d %d\n",i,siod_llength(data[i]));
295 	all_entropy += sentropy * siod_llength(data[i]);
296 	slist[i] = cons(flocons(sentropy),flocons(i));
297 	ddd = cons(slist[i],ddd);
298     }
299     printf("average entropy is %g\n",all_entropy/i);
300 
301     qsort(slist,wfst.num_states(),sizeof(LISP),me_compare_function);
302 
303     return slist;
304 }
305 
find_best_split(EST_WFST & wfst,int split_state_name,LISP * data)306 static LISP find_best_split(EST_WFST &wfst,
307 			    int split_state_name,
308 			    LISP *data)
309 {
310     // Find the best partition of incoming translations that
311     // minimises entropy
312     EST_DiscreteProbDistribution pdf_all(&wfst.in_symbols());
313     EST_DiscreteProbDistribution *a_pdf, *b_pdf;
314     LISP splits,s,dd,r;
315     LISP *ssplits;
316     gc_protect(&splits);
317     EST_String sname;
318     int b,best_b;
319     EST_Litem *i;
320     int num_pdfs;
321     double best_score, score, sfreq;
322 
323     for (dd = data[split_state_name]; dd; dd = cdr(dd))
324 	pdf_all.cumulate(get_c_int(car(car(dd))));
325     splits = find_split_pdfs(wfst,split_state_name,data,pdf_all);
326     if (siod_llength(splits) < 2)
327 	return NIL;
328     ssplits = new LISP[siod_llength(splits)];
329     for (num_pdfs=0,s=splits; s != NIL; s=cdr(s),num_pdfs++)
330 	ssplits[num_pdfs] = car(s);
331 
332     qsort(ssplits,num_pdfs,sizeof(LISP),me_compare_function);
333     // Combine trans pdfs in pdfs until more combination doesn't improve
334     while (1)
335     {
336 
337 	best_score = get_c_float(car(ssplits[0]));
338 	best_b = -1;
339 	a_pdf = pdf(car(cdr(cdr(ssplits[0]))));
340         for (b=1; b < num_pdfs; b++)
341 	{
342 	    if (ssplits[b] == NIL)
343 		continue;
344 	    score = score_pdf_combine(*a_pdf,*pdf(car(cdr(cdr(ssplits[b])))),
345 				      pdf_all);
346 	    if (score < best_score)
347 	    {
348 		best_score = score;
349 		best_b = b;
350 	    }
351 	}
352 
353 	// combine a and b
354 	if (best_b == -1)
355 	    break;
356 	else
357 	{
358 	    // combine a and b
359 	    // Add trans to 0
360 	    setcar(cdr(ssplits[0]),
361 		   append(car(cdr(ssplits[0])),
362 			  car(cdr(ssplits[best_b]))));
363 	    setcar(ssplits[0], flocons(best_score));
364 	    // Update 0's pdf with values from best_b's
365 	    b_pdf = pdf(car(cdr(cdr(ssplits[best_b]))));
366 	    for (i=b_pdf->item_start(); !b_pdf->item_end(i);
367 		 i = b_pdf->item_next(i))
368 	    {
369 		b_pdf->item_freq(i,sname,sfreq);
370 		a_pdf->cumulate(i,sfreq);
371 	    }
372 	    ssplits[best_b] = NIL;
373 	}
374 
375     }
376 
377     printf("score %g ",(double)get_c_float(car(ssplits[0])));
378     for (dd=car(cdr(ssplits[0])); dd; dd=cdr(dd))
379 	printf("%s ",(const char *)wfst.in_symbol(trans(car(dd))->in_symbol()));
380     printf("\n");
381     gc_unprotect(&splits);
382     r = car(cdr(ssplits[0]));
383     delete [] ssplits;
384     return r;
385 }
386 
score_pdf_combine(EST_DiscreteProbDistribution & a,EST_DiscreteProbDistribution & b,EST_DiscreteProbDistribution & all)387 static double score_pdf_combine(EST_DiscreteProbDistribution &a,
388 				EST_DiscreteProbDistribution &b,
389 				EST_DiscreteProbDistribution &all)
390 {
391     // Find score of (a+b) vs (all-(a+b))
392     EST_DiscreteProbDistribution ab(a);
393     EST_DiscreteProbDistribution all_but_ab(all);
394     EST_Litem *i;
395     EST_String sname;
396     double sfreq, score;
397     for (i=b.item_start(); !b.item_end(i);
398 	 i = b.item_next(i))
399     {
400 	b.item_freq(i,sname,sfreq);
401 	ab.cumulate(i,sfreq);
402     }
403 
404     for (i=ab.item_start(); !ab.item_end(i);
405 	 i = ab.item_next(i))
406     {
407 	ab.item_freq(i,sname,sfreq);
408 	all_but_ab.cumulate(i,-1*sfreq);
409     }
410 
411     score = (ab.entropy() * ab.samples()) +
412 	(all_but_ab.entropy() * all_but_ab.samples());
413 
414     return score;
415 
416 }
417 
find_split_pdfs(EST_WFST & wfst,int split_state_name,LISP * data,EST_DiscreteProbDistribution & pdf_all)418 static LISP find_split_pdfs(EST_WFST &wfst,
419 			    int split_state_name,
420 			    LISP *data,
421 			    EST_DiscreteProbDistribution &pdf_all)
422 {
423     // Find following pdfs for each incoming transition as if they where
424     // split to a new state
425     int i,id, in;
426     EST_Litem *tp;
427     LISP pdfs = NIL,dd,ttt,p,t;
428     EST_DiscreteProbDistribution empty;
429     double value;
430 
431     for (i=0; i < wfst.num_states(); i++)
432     {
433 	const EST_WFST_State *s = wfst.state(i);
434 	for (tp=s->transitions.head(); tp != 0; tp = tp->next())
435 	{
436 	    if ((s->transitions(tp)->state() == split_state_name)
437 		&& (s->transitions(tp)->weight() > 0))
438 	    {
439 		in = s->transitions(tp)->in_symbol();
440 		EST_DiscreteProbDistribution *pdf =
441 		    new EST_DiscreteProbDistribution(&wfst.in_symbols());
442 		for (dd = data[i]; dd; dd = cdr(dd))
443 		{
444 		    id = get_c_int(car(car(dd)));
445 		    if (id == in)
446 		    {   // This one would go to the new state so we count it
447 			if (cdr(car(dd))) // not end of data string
448 			    pdf->cumulate(get_c_int(car(cdr(car(dd)))));
449 		    }
450 		}
451 		// value, list of trans, pdf
452 		value = score_pdf_combine(*pdf,empty,pdf_all);
453 		if ((value > 0) && // ignore transitions with no data
454 		    (pdf->samples() > 10))// and those with only a few data pnts
455 		{
456 		    t = siod(s->transitions(tp));
457 		    p = siod(pdf);
458 		    ttt = cons(flocons(value),
459 			       cons(cons(t,NIL),
460 				    cons(p,NIL)));
461 		    pdfs = cons(ttt,pdfs);
462 		}
463 		else
464 		    delete pdf;
465 	    }
466 	}
467     }
468     return pdfs;
469 }
470 
find_best_trans_split(EST_WFST & wfst,int split_state_name,LISP * data)471 EST_WFST_Transition *find_best_trans_split(EST_WFST &wfst,
472 					   int split_state_name,
473 					   LISP *data)
474 {
475     EST_Litem *tp;
476     EST_WFST_Transition *best_trans = 0;
477     const EST_WFST_State *split_state = wfst.state(split_state_name);
478     double best_score,bb;
479     int i;
480 
481     best_score = entropy(split_state)*siod_llength(data[split_state_name]);
482 //    printf("unsplit score %g\n",best_score);
483 
484     /* For each transition going to split_state */
485     for (i=1; i < wfst.num_states(); i++)
486     {
487 	const EST_WFST_State *s = wfst.state(i);
488 	for (tp=s->transitions.head(); tp != 0; tp = tp->next())
489 	{
490 	    if ((wfst.state(s->transitions(tp)->state()) == split_state) &&
491 		(s->transitions(tp)->weight() > 0))
492 	    {
493 		bb = find_score_if_split(wfst,i,s->transitions(tp),data);
494 //		cout << i << " "
495 //		     << wfst.in_symbol(s->transitions(tp)->in_symbol()) << " "
496 //		     << s->transitions(tp)->state() << " " << bb << endl;
497 		if (bb == -1)  /* didn't find a split */
498 		    continue;
499 		if (bb < best_score)
500 		{
501 		    best_score = bb;
502 		    best_trans = s->transitions(tp);
503 		}
504 	    }
505 	}
506     }
507 
508     if (best_trans)
509 	cout << "best " << wfst.in_symbol(best_trans->in_symbol()) << " "
510 	     << best_trans->weight() << " "
511  	     << best_trans->state() << " " << best_score << endl;
512     return best_trans;
513 }
514 
find_score_if_split(EST_WFST & wfst,int fromstate,EST_WFST_Transition * trans,LISP * data)515 static double find_score_if_split(EST_WFST &wfst,
516                                   int fromstate,
517 				  EST_WFST_Transition *trans,
518 				  LISP *data)
519 {
520     double ent_split;
521     double ent_remain;
522     double score;
523     EST_DiscreteProbDistribution pdf_split(&wfst.in_symbols());
524     EST_DiscreteProbDistribution pdf_remain(&wfst.in_symbols());
525     int in, tostate, id;
526     EST_Litem *i;
527     double sfreq;
528     EST_String sname;
529 
530     ent_split = ent_remain = 32*32*32*32;
531     LISP dd;
532 
533 //    printf("considering %d %s %g %d\n",
534 //	   fromstate,
535 //	   (const char *)wfst.in_symbol(trans->in_symbol()),
536 //	   trans->weight(),
537 //	   trans->state());
538 
539     /* find entropy of possible new state */
540     /* for each data point through fromstate */
541     in = trans->in_symbol();
542     for (dd = data[fromstate]; dd; dd = cdr(dd))
543     {
544 	id = get_c_int(car(car(dd)));
545 	if (id == in)
546 	{   // This one would go to the new state so we count it
547 	    if (cdr(car(dd))) // not end of data string
548 		pdf_split.cumulate(get_c_int(car(cdr(car(dd)))));
549 	}
550     }
551     if (pdf_split.samples() > 0)
552 	ent_split = pdf_split.entropy();
553     /* find entropy of old state minus trans into it */
554     tostate = trans->state();
555     // Actually only need to do this once per state
556     for (dd = data[tostate]; dd; dd = cdr(dd))
557 	pdf_remain.cumulate(get_c_int(car(car(dd))));
558     // Subtract the bit thats split
559     for (i=pdf_split.item_start(); !pdf_split.item_end(i);
560 	 i = pdf_split.item_next(i))
561     {
562 	pdf_split.item_freq(i,sname,sfreq);
563 	pdf_remain.cumulate(i,-1*sfreq);
564     }
565     if (pdf_remain.samples() > 0)
566 	ent_remain = pdf_remain.entropy();
567 
568     if ((pdf_remain.samples() == 0) ||
569 	(pdf_split.samples() == 0))
570 	return -1;
571 
572     score = (ent_remain * pdf_remain.samples()) +
573 	(ent_split * pdf_split.samples());
574 //    printf("tostate %d remain %g %d split %g %d score %g\n",
575 //	   tostate, ent_remain, (int)pdf_remain.samples(),
576 //	   ent_split, (int)pdf_split.samples(), score);
577 
578     return score;
579 }
580 
#if 0
// Disabled single-transition variant, kept for reference only; the
// list-based split_state() is the one that is compiled.
static void split_state(EST_WFST &wfst, EST_WFST_Transition *trans)
{
    /* Split off a new state for given trans.  Add transitions    */
    /* to this new state for all transitions in (old) state trans */
    /* goes to                                                    */
    const int nstate = wfst.add_state(wfst_final);
    const int ostate = trans->state();

    /* must be done before adding the new transitions to nstate */
    trans->set_state(nstate);

    for (EST_Litem *tp = wfst.state(ostate)->transitions.head();
         tp != 0;
         tp = tp->next())
    {
        EST_WFST_Transition *old_trans = wfst.state(ostate)->transitions(tp);

        wfst.state_non_const(nstate)->
            add_transition(0.0,  /* weight will be filled in later*/
                           old_trans->state(),
                           old_trans->in_symbol(),
                           old_trans->out_symbol());
    }
}
#endif
609 
split_state(EST_WFST & wfst,LISP trans_list,int ostate)610 static void split_state(EST_WFST &wfst, LISP trans_list, int ostate)
611 {
612     /* Split off a new state for given trans.  Add transitions    */
613     /* to this new state for all transitions in (old) state trans */
614     /* goes to                                                    */
615     EST_Litem *tp;
616     int nstate = wfst.add_state(wfst_final);
617     LISP t;
618 
619     /* must be done before adding the new transitions to nstate */
620     for (t=trans_list; t; t=cdr(t))
621 	trans(car(t))->set_state(nstate);
622 
623     for (tp=wfst.state(ostate)->transitions.head(); tp != 0; tp = tp->next())
624     {
625 	wfst.state_non_const(nstate)->
626 	    add_transition(0.0,  /* weight will be filled in later*/
627 			   wfst.state(ostate)->transitions(tp)->state(),
628 			   wfst.state(ostate)->transitions(tp)->in_symbol(),
629 			   wfst.state(ostate)->transitions(tp)->out_symbol());
630 
631     }
632 }
633 
634