1 /*************************************************************************/
2 /*                                                                       */
3 /*                Centre for Speech Technology Research                  */
4 /*                     University of Edinburgh, UK                       */
5 /*                      Copyright (c) 1996,1997                          */
6 /*                        All Rights Reserved.                           */
7 /*                                                                       */
8 /*  Permission is hereby granted, free of charge, to use and distribute  */
9 /*  this software and its documentation without restriction, including   */
10 /*  without limitation the rights to use, copy, modify, merge, publish,  */
11 /*  distribute, sublicense, and/or sell copies of this work, and to      */
12 /*  permit persons to whom this work is furnished to do so, subject to   */
13 /*  the following conditions:                                            */
14 /*   1. The code must retain the above copyright notice, this list of    */
15 /*      conditions and the following disclaimer.                         */
16 /*   2. Any modifications must be clearly marked as such.                */
17 /*   3. Original authors' names are not deleted.                         */
18 /*   4. The authors' names are not used to endorse or promote products   */
19 /*      derived from this software without specific prior written        */
20 /*      permission.                                                      */
21 /*                                                                       */
22 /*  THE UNIVERSITY OF EDINBURGH AND THE CONTRIBUTORS TO THIS WORK        */
23 /*  DISCLAIM ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING      */
24 /*  ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO EVENT   */
25 /*  SHALL THE UNIVERSITY OF EDINBURGH NOR THE CONTRIBUTORS BE LIABLE     */
26 /*  FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES    */
27 /*  WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN   */
28 /*  AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION,          */
29 /*  ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF       */
30 /*  THIS SOFTWARE.                                                       */
31 /*                                                                       */
32 /*************************************************************************/
33 /*                     Author :  Alan W Black                            */
34 /*                     Date   :  May 1996                                */
35 /*-----------------------------------------------------------------------*/
36 /*                                                                       */
37 /*  Various method functions                                             */
38 /*=======================================================================*/
39 
40 #include <cstdlib>
41 #include <iostream>
42 #include <cstring>
43 #include "EST_unix.h"
44 #include "EST_cutils.h"
45 #include "EST_Token.h"
46 #include "EST_Wagon.h"
47 #include "EST_math.h"
48 
49 
predict(const WVector & d)50 EST_Val WNode::predict(const WVector &d)
51 {
52     if (leaf())
53 	return impurity.value();
54     else if (question.ask(d))
55 	return left->predict(d);
56     else
57 	return right->predict(d);
58 }
59 
predict_node(const WVector & d)60 WNode *WNode::predict_node(const WVector &d)
61 {
62     if (leaf())
63 	return this;
64     else if (question.ask(d))
65 	return left->predict_node(d);
66     else
67 	return right->predict_node(d);
68 }
69 
pure(void)70 int WNode::pure(void)
71 {
72     //  A node is pure if it has no sub-nodes or its not of type class
73 
74     if ((left == 0) && (right == 0))
75 	return TRUE;
76     else if (get_impurity().type() != wnim_class)
77 	return TRUE;
78     else
79 	return FALSE;
80 }
81 
prune(void)82 void WNode::prune(void)
83 {
84     // Check all sub-nodes and if they are all of the same class
85     // delete their sub nodes.  Returns pureness of this node
86 
87     if (pure() == FALSE)
88     {
89 	// Ok lets try and make it pure
90 	if (left != 0) left->prune();
91 	if (right != 0) right->prune();
92 
93 	// Have to check purity as well as values to ensure left and right
94 	// don't further split
95 	if ((left->pure() == TRUE) && ((right->pure() == TRUE)) &&
96 	    (left->get_impurity().value() == right->get_impurity().value()))
97 	{
98 	     delete left; left = 0;
99 	     delete right; right = 0;
100 	}
101     }
102 
103 }
104 
held_out_prune()105 void WNode::held_out_prune()
106 {
107     // prune tree with held out data
108     // Check if node's questions differentiates for the held out data
109     // if not, prune all sub_nodes
110 
111     // Rescore with prune data
112     set_impurity(WImpurity(get_data()));  // for this new data
113 
114     if (left != 0)
115     {
116 	wgn_score_question(question,get_data());
117 	if (question.get_score() < get_impurity().measure())
118 	{  // its worth goint ot the next level
119 	    wgn_find_split(question,get_data(),
120 		       left->get_data(),
121 		       right->get_data());
122 	    left->held_out_prune();
123 	    right->held_out_prune();
124 	}
125 	else
126 	{  // not worth the split so prune both sub_nodes
127 	    delete left; left = 0;
128 	    delete right; right = 0;
129 	}
130     }
131 }
132 
print_out(ostream & s,int margin)133 void WNode::print_out(ostream &s, int margin)
134 {
135     int i;
136 
137     s << endl;
138     for (i=0;i<margin;i++) s << " ";
139     s << "(";
140     if (left==0) // base case
141 	s << impurity;
142     else
143     {
144 	s << question;
145 	left->print_out(s,margin+1);
146 	right->print_out(s,margin+1);
147     }
148     s << ")";
149 }
150 
operator <<(ostream & s,WNode & n)151 ostream & operator <<(ostream &s, WNode &n)
152 {
153     // Output this node and its sub-node
154 
155     n.print_out(s,0);
156     s << endl;
157     return s;
158 }
159 
ignore_non_numbers()160 void WDataSet::ignore_non_numbers()
161 {
162     /* For ols we want to ignore anything that is categorial */
163     int i;
164 
165     for (i=0; i<dlength; i++)
166     {
167         if ((p_type[i] == wndt_binary) ||
168             (p_type[i] == wndt_float))
169             continue;
170         else
171         {
172             p_ignore[i] = TRUE;
173         }
174     }
175 
176     return;
177 }
178 
void WDataSet::load_description(const EST_String &fname, LISP ignores)
{
    // Initialise a dataset's field count, names and types from the
    // description file `fname`.  The file is read with vload() and its first
    // form is taken as a list with one entry per field:
    //   (name type)            type is one of count, ignore, binary, cluster,
    //                          vector, trajectory, ols, matrix, float
    //   (name v1 v2 ...)       more than two elements: a discrete field whose
    //                          values are defined in wgn_discretes
    // Parameters:
    //   fname   - description file name (also used in error messages)
    //   ignores - LISP list of field names to mark as ignored, in addition to
    //             fields declared with type "ignore"
    // Side effects: sets dlength, p_type, p_ignore and p_name, and the
    // globals wgn_predictee and wgn_count_field.
    // Errors (via wagon_error): unknown type name, ignored predictee, or a
    // named predictee that does not appear in the description.
    EST_String tname;
    int i;
    LISP description,d;

    description = car(vload(fname,1));
    dlength = siod_llength(description);

    p_type.resize(dlength);
    p_ignore.resize(dlength);
    p_name.resize(dlength);

    // With no predictee name given, field 0 is the predictee; otherwise -1
    // marks "not found yet" and is checked after the loop.
    if (wgn_predictee_name == "")
	wgn_predictee = 0;  // default predictee is first field
    else
	wgn_predictee = -1;

    for (i=0,d=description; d != NIL; d=cdr(d),i++)
    {
	// Each entry is (name type ...) -- first element is the field name,
	// second the type name (or first discrete value)
	p_name[i] = get_c_string(car(car(d)));
	tname = get_c_string(car(cdr(car(d))));
	p_ignore[i] = FALSE;
	if ((wgn_predictee_name != "") && (wgn_predictee_name == p_name[i]))
	    wgn_predictee = i;
	if ((wgn_count_field_name != "") &&
	    (wgn_count_field_name == p_name[i]))
	    wgn_count_field = i;
	// Branch order matters: count, then ignore, then discrete lists, then
	// the named scalar/structured types.
	// NOTE(review): unlike the "ignore" branch, this branch does not
	// complain if the predictee is also the count field.
	if ((tname == "count") || (i == wgn_count_field))
	{
	    // The count must be ignored, repeat it if you want it too
	    p_type[i] = wndt_ignore;  // the count must be ignored
	    p_ignore[i] = TRUE;
	    wgn_count_field = i;
	}
	else if ((tname == "ignore") || (siod_member_str(p_name[i],ignores)))
	{
	    p_type[i] = wndt_ignore;  // user specified ignore
	    p_ignore[i] = TRUE;
	    if (i == wgn_predictee)
		wagon_error(EST_String("predictee \"")+p_name[i]+
			    "\" can't be ignored \n");
	}
	else if (siod_llength(car(d)) > 2)
	{
	    // Discrete field: everything after the name is the value list;
	    // its index in wgn_discretes becomes the field type.  A leading
	    // "_other_" value is also made the default value.
	    LISP rest = cdr(car(d));
	    EST_StrList sl;
	    siod_list_to_strlist(rest,sl);
	    p_type[i] = wgn_discretes.def(sl);
	    if (streq(get_c_string(car(rest)),"_other_"))
		wgn_discretes[p_type[i]].def_val("_other_");
	}
	else if (tname == "binary")
	    p_type[i] = wndt_binary;
	else if (tname == "cluster")
	    p_type[i] = wndt_cluster;
	else if (tname == "vector")
	    p_type[i] = wndt_vector;
	else if (tname == "trajectory")
	    p_type[i] = wndt_trajectory;
	else if (tname == "ols")
	    p_type[i] = wndt_ols;
	else if (tname == "matrix")
	    p_type[i] = wndt_matrix;
	else if (tname == "float")
	    p_type[i] = wndt_float;
	else
	{
	    wagon_error(EST_String("Unknown type \"")+tname+
			"\" for field number "+itoString(i)+
                        "/"+p_name[i]+" in description file \""+fname+"\"");
	}
    }

    // A predictee was named but no field carried that name
    if (wgn_predictee == -1)
    {
	wagon_error(EST_String("predictee field \"")+wgn_predictee_name+
			"\" not found in description ");
    }
}
260 
ask(const WVector & w) const261 const int WQuestion::ask(const WVector &w) const
262 {
263     // Ask this question of the given vector
264     switch (op)
265     {
266       case wnop_equal:    // for numbers
267 	if (w.get_flt_val(feature_pos) == operand1.Float())
268 	    return TRUE;
269 	else
270 	    return FALSE;
271       case wnop_binary:    // for numbers
272 	if (w.get_int_val(feature_pos) == 1)
273 	    return TRUE;
274 	else
275 	    return FALSE;
276       case wnop_greaterthan:
277 	if (w.get_flt_val(feature_pos) > operand1.Float())
278 	    return TRUE;
279 	else
280 	    return FALSE;
281       case wnop_lessthan:
282 	if (w.get_flt_val(feature_pos) < operand1.Float())
283 	    return TRUE;
284 	else
285 	    return FALSE;
286       case wnop_is:       // for classes
287 	if (w.get_int_val(feature_pos) == operand1.Int())
288 	    return TRUE;
289 	else
290 	    return FALSE;
291       case wnop_in:       // for subsets -- note operand is list of ints
292 	if (ilist_member(operandl,w.get_int_val(feature_pos)))
293 	    return TRUE;
294 	else
295 	    return FALSE;
296       default:
297 	wagon_error("Unknown test operator");
298     }
299 
300     return FALSE;
301 }
302 
operator <<(ostream & s,const WQuestion & q)303 ostream& operator<<(ostream& s, const WQuestion &q)
304 {
305     EST_String name;
306     static EST_Regex needquotes(".*[()'\";., \t\n\r].*");
307 
308     s << "(" << wgn_dataset.feat_name(q.get_fp());
309     switch (q.get_op())
310     {
311       case wnop_equal:
312 	s << " = " << q.get_operand1().string();
313 	break;
314       case wnop_binary:
315 	break;
316       case wnop_greaterthan:
317 	s << " > " << q.get_operand1().Float();
318 	break;
319       case wnop_lessthan:
320 	s << " < " << q.get_operand1().Float();
321 	break;
322       case wnop_is:
323 	name = wgn_discretes[wgn_dataset.ftype(q.get_fp())].
324 	    name(q.get_operand1().Int());
325 	s << " is ";
326 	if (name.matches(needquotes))
327 	    s << quote_string(name,"\"","\\",1);
328 	else
329 	    s << name;
330 	break;
331       case wnop_matches:
332 	name = wgn_discretes[wgn_dataset.ftype(q.get_fp())].
333 	    name(q.get_operand1().Int());
334 	s << " matches " << quote_string(name,"\"","\\",1);
335 	break;
336       case wnop_in:
337 	s << " in (";
338 	for (int l=0; l < q.get_operandl().length(); l++)
339 	{
340 	    name = wgn_discretes[wgn_dataset.ftype(q.get_fp())].
341 		name(q.get_operandl().nth(l));
342 	    if (name.matches(needquotes))
343 		s << quote_string(name,"\"","\\",1);
344 	    else
345 		s << name;
346 	    s << " ";
347 	}
348 	s << ")";
349 	break;
350         // SunCC wont let me add this
351 //      default:
352 //	s << " unknown operation ";
353     }
354     s << ")";
355 
356     return s;
357 }
358 
value(void)359 EST_Val WImpurity::value(void)
360 {
361     // Returns the recommended value for this
362     EST_String s;
363     double prob;
364 
365     if (t==wnim_unset)
366     {
367 	cerr << "WImpurity: no value currently set\n";
368 	return EST_Val(0.0);
369     }
370     else if (t==wnim_class)
371 	return EST_Val(p.most_probable(&prob));
372     else if (t==wnim_cluster)
373 	return EST_Val(a.mean());
374     else if (t==wnim_ols)     /* OLS TBA */
375 	return EST_Val(a.mean());
376     else if (t==wnim_vector)
377 	return EST_Val(a.mean()); /* wnim_vector */
378     else if (t==wnim_trajectory)
379 	return EST_Val(a.mean()); /* NOT YET WRITTEN */
380     else
381 	return EST_Val(a.mean());
382 }
383 
samples(void)384 double WImpurity::samples(void)
385 {
386     if (t==wnim_float)
387 	return a.samples();
388     else if (t==wnim_class)
389 	return (int)p.samples();
390     else if (t==wnim_cluster)
391 	return members.length();
392     else if (t==wnim_ols)
393 	return members.length();
394     else if (t==wnim_vector)
395 	return members.length();
396     else if (t==wnim_trajectory)
397 	return members.length();
398     else
399 	return 0;
400 }
401 
WImpurity(const WVectorVector & ds)402 WImpurity::WImpurity(const WVectorVector &ds)
403 {
404     int i;
405 
406     t=wnim_unset;
407     a.reset(); trajectory=0; l=0; width=0;
408     data = &ds;  // for ols, model calculation
409     for (i=0; i < ds.n(); i++)
410     {
411         if (t == wnim_ols)
412             cumulate(i,1);
413         else if (wgn_count_field == -1)
414 	    cumulate((*(ds(i)))[wgn_predictee],1);
415         else
416 	    cumulate((*(ds(i)))[wgn_predictee],
417 		     (*(ds(i)))[wgn_count_field]);
418     }
419 }
420 
measure(void)421 float WImpurity::measure(void)
422 {
423     if (t == wnim_float)
424 	return a.variance()*a.samples();
425     else if (t == wnim_vector)
426 	return vector_impurity();
427     else if (t == wnim_trajectory)
428 	return trajectory_impurity();
429     else if (t == wnim_matrix)
430 	return a.variance()*a.samples();
431     else if (t == wnim_class)
432 	return p.entropy()*p.samples();
433     else if (t == wnim_cluster)
434 	return cluster_impurity();
435     else if (t == wnim_ols)
436 	return ols_impurity();  /* RMSE for OLS model */
437     else
438     {
439 	cerr << "WImpurity: can't measure unset object" << endl;
440 	return 0.0;
441     }
442 }
443 
float WImpurity::vector_impurity()
{
    // Impurity for vector predictees.  Each member is a row index into
    // wgn_VertexTrack (cumulate() appends the predictee value as an int),
    // paired element-wise with a weight in member_counts.  For every channel
    // j with wgn_VertexFeats.a(0,j) > 0.0, the weighted standard deviation
    // of the members' values is computed and added to `a`.  The result is
    // the mean of those standard deviations times the weighted sample count.
    //
    // Side effect: `a` is reset and refilled with per-channel stddevs, so
    // value() (which returns a.mean() for vector impurities) afterwards
    // reports the mean stddev.
    // NOTE(review): members and member_counts are walked in lockstep and
    // only `pp` is null-checked, so they are assumed to be the same length.
    // NOTE(review): if no channel is enabled, `a` has no samples when
    // a.mean() is taken -- confirm EST_SuffStats' behaviour for that case.
    EST_Litem *pp;
    EST_Litem *countpp;
    int i,j;
    EST_SuffStats b;
    double count = 1;

    a.reset();
#if 1
    /* simple distance */
    for (j=0; j<wgn_VertexFeats.num_channels(); j++)
    {
        if (wgn_VertexFeats.a(0,j) > 0.0)
        {
            b.reset();
            for (pp=members.head(), countpp=member_counts.head(); pp != 0; pp=pp->next(), countpp=countpp->next())
            {
                i = members.item(pp);

                // Accumulate the value with its count as weight
                b.cumulate(wgn_VertexTrack.a(i,j), member_counts.item(countpp)) ;
            }
            a += b.stddev();
            // Same members and weights for every channel, so the last
            // enabled channel's count is the count for all of them
            count = b.samples();
        }
    }
#endif

    /* Disabled experiment: per-sample best shift (hardcoded channel range
       67..147 shifted by -20..20).  NOTE: `c` is never deleted. */
#if 0
    EST_SuffStats *c;
    float x, lshift, rshift, ushift;
    /* Find base mean, then measure do fshift to find best match */
    c = new EST_SuffStats[wgn_VertexTrack.num_channels()+1];
    for (j=0; j<wgn_VertexFeats.num_channels(); j++)
    {
        if (wgn_VertexFeats.a(0,j) > 0.0)
        {
            c[j].reset();
            for (pp=members.head(), countpp=member_counts.head(); pp != 0;
                 pp=pp->next(), countpp=countpp->next())
            {
                i = members.item(pp);
                // Accumulate the value with count
                c[j].cumulate(wgn_VertexTrack.a(i,j),member_counts.item(countpp));
            }
            count = c[j].samples();
        }
    }

    /* Pass through again but vary the num_channels offset (hardcoded) */
    for (pp=members.head(), countpp=member_counts.head(); pp != 0;
         pp=pp->next(), countpp=countpp->next())
    {
        int q;
        float bshift, qshift;
        /* For each sample */
        i = members.item(pp);
        /* Find the value left shifted, unshifted, and right shifted */
        lshift = 0; ushift = 0; rshift = 0;
        bshift = 0;
        for (q=-20; q<=20; q++)
        {
            qshift = 0;
            for (j=67+q; j<147+q/*hardcoded*/; j++)
            {
                x = c[j].mean() - wgn_VertexTrack(i,j);
                qshift += sqrt(x*x);
                if ((bshift > 0) && (qshift > bshift))
                    break;
            }
            if ((bshift == 0) || (qshift < bshift))
                bshift = qshift;
        }
        a += bshift;
    }

#endif

#if 0
    /* full covariance */
    /* worse in listening experiments */
    /* NOTE: the `cs` arrays are never deleted. */
    EST_SuffStats **cs;
    int mmm;
    cs = new EST_SuffStats *[wgn_VertexTrack.num_channels()+1];
    for (j=0; j<=wgn_VertexTrack.num_channels(); j++)
        cs[j] = new EST_SuffStats[wgn_VertexTrack.num_channels()+1];
    /* Find means for diagonal */
    for (j=0; j<wgn_VertexFeats.num_channels(); j++)
    {
        if (wgn_VertexFeats.a(0,j) > 0.0)
        {
            for (pp=members.head(); pp != 0; pp=pp->next())
                cs[j][j] += wgn_VertexTrack.a(members.item(pp),j);
        }
    }
    for (j=0; j<wgn_VertexFeats.num_channels(); j++)
    {
        for (i=j+1; i<wgn_VertexFeats.num_channels(); i++)
            if (wgn_VertexFeats.a(0,j) > 0.0)
            {
                for (pp=members.head(); pp != 0; pp=pp->next())
                {
                    mmm = members.item(pp);
                    cs[i][j] += (wgn_VertexTrack.a(mmm,i)-cs[j][j].mean())*
                        (wgn_VertexTrack.a(mmm,j)-cs[j][j].mean());
                }
            }
    }
    for (j=0; j<wgn_VertexFeats.num_channels(); j++)
    {
        for (i=j+1; i<wgn_VertexFeats.num_channels(); i++)
            if (wgn_VertexFeats.a(0,j) > 0.0)
                a += cs[i][j].stddev();
    }
    count = cs[0][0].samples();
#endif

#if 0
    // look at mean euclidean distance between vectors
    EST_Litem *qq;
    int x,y;
    double d,q;
    count = 0;
    for (pp=members.head(); pp != 0; pp=pp->next())
    {
        x = members.item(pp);
        count++;
        for (qq=pp->next(); qq != 0; qq=qq->next())
        {
            y = members.item(qq);
            for (q=0.0,j=0; j<wgn_VertexFeats.num_channels(); j++)
                if (wgn_VertexFeats.a(0,j) > 0.0)
                {
                    d = wgn_VertexTrack(x,j)-wgn_VertexTrack(y,j);
                    q += d*d;
                }
            a += sqrt(q);
        }

    }
#endif

    // Mean per-channel stddev scaled by the (weighted) number of samples
    return a.mean() * count;
}
592 
~WImpurity()593 WImpurity::~WImpurity()
594 {
595     int j;
596 
597     if (trajectory != 0)
598     {
599         for (j=0; j<l; j++)
600             delete [] trajectory[j];
601         delete [] trajectory;
602         trajectory = 0;
603         l = 0;
604     }
605 }
606 
607 
float WImpurity::trajectory_impurity()
{
    // Impurity for trajectory predictees.  Each member is a row index into
    // wgn_UnitTrack, where column 0 is the unit's start row in
    // wgn_VertexTrack and column 1 its length in rows.  Every unit is
    // resampled to a common number of points.  Per-point, per-channel
    // statistics are collected in `trajectory`.  The score is the mean of
    // their standard deviations times the number of members.  Only channels
    // j with wgn_VertexFeats.a(0,j) > 0.0 contribute.
    //
    // Two layouts are handled:
    //  - plain: no unit contains a row whose channel 0 is -1.0.  Every unit
    //    is resampled to l = max(7, mean unit length) points.
    //  - "OLA": at least one unit has such a -1.0 marker row.  The part before
    //    the marker is resampled to l1 points and the part after it to l2
    //    points (each at least 10), with extra marker slots (-1 and -2).
    //
    // The result is cached: once `trajectory` exists, later calls return the
    // stored `score`.  The arrays are freed by ~WImpurity().
    EST_Litem *pp;
    int i, j;
    int s, ti, ni, q;
    int s1l, s2l;
    double n, m, m1, m2, w;
    EST_SuffStats lss, stdss;
    EST_SuffStats l1ss, l2ss;
    int l1, l2;
    int ola=0;

    if (trajectory != 0)
    {   /* already done this */
        return score;
    }

    // First pass: collect unit lengths, and for units with a -1.0 marker the
    // lengths of the parts before (l1ss) and after (l2ss) it.
    lss.reset();
    l = 0;
    for (pp=members.head(); pp != 0; pp=pp->next())
    {
        i = members.item(pp);
        for (q=0; q<wgn_UnitTrack.a(i,1); q++)
        {
            ni = (int)wgn_UnitTrack.a(i,0)+q;
            if (wgn_VertexTrack.a(ni,0) == -1.0)
            {
                l1ss += q;
                ola = 1;
                break;
            }
        }
        if (q==wgn_UnitTrack.a(i,1))
        {   /* can't find -1 center point, so put all in l2 */
            l1ss += 0;
            l2ss += q;
        }
        else
            // rows after the marker, excluding the final row (len - q - 2)
            l2ss += wgn_UnitTrack.a(i,1) - (q+1) - 1;
        lss += wgn_UnitTrack.a(i,1); /* length of each unit in the cluster */
        if (wgn_UnitTrack.a(i,1) > l)
            l = (int)wgn_UnitTrack.a(i,1);
    }

    if (ola==0)  /* no -1's so its not an ola type cluster */
    {
        // l is recomputed here, so the maximum found above is discarded
        l = ((int)lss.mean() < 7) ? 7 : (int)lss.mean();

        /* a list of SuffStats on for each point in the trajectory */
        trajectory = new EST_SuffStats *[l];
        width = wgn_VertexTrack.num_channels()+1;
        for (j=0; j<l; j++)
            trajectory[j] = new EST_SuffStats[width];

        for (pp=members.head(); pp != 0; pp=pp->next())
        {   /* for each unit */
            i = members.item(pp);
            m = (float)wgn_UnitTrack.a(i,1)/(float)l; /* find interpolation */
            s = (int)wgn_UnitTrack.a(i,0); /* start point */
            // step through the unit in increments of len/l, truncating
            for (ti=0,n=0.0; ti<l; ti++,n+=m)
            {
                ni = (int)n;  // hmm floor or nint ??
                for (j=0; j<wgn_VertexFeats.num_channels(); j++)
                {
                    if (wgn_VertexFeats.a(0,j) > 0.0)
                        trajectory[ti][j] += wgn_VertexTrack.a(s+ni,j);
                }
            }
        }

        /* find sum of sum of stddev for all coefs of all traj points */
        stdss.reset();
        for (ti=0; ti<l; ti++)
            for (j=0; j<wgn_VertexFeats.num_channels(); j++)
            {
                if (wgn_VertexFeats.a(0,j) > 0.0)
                    stdss += trajectory[ti][j].stddev();
            }

        // This is sum of all stddev * samples
        score = stdss.mean() * members.length();
    }
    else
    {   /* OLA model */
        // Layout: slots 0..l1-1 first part, slot l1 the -1 marker,
        // slots l1+1..l-2 second part, slot l-1 the -2 end marker.
        l1 = (l1ss.mean() < 10.0) ? 10 : (int)l1ss.mean();
        l2 = (l2ss.mean() < 10.0) ? 10 : (int)l2ss.mean();
        l = l1 + l2 + 1 + 1;

        /* a list of SuffStats on for each point in the trajectory */
        // NOTE(review): unlike the plain branch, `width` is not set here and
        // keeps its previous value (0 from the constructor) -- confirm
        // whether any reader of `width` relies on it for OLA clusters.
        trajectory = new EST_SuffStats *[l];
        for (j=0; j<l; j++)
            trajectory[j] = new EST_SuffStats[wgn_VertexTrack.num_channels()+1];

        for (pp=members.head(); pp != 0; pp=pp->next())
        {   /* for each unit */
            i = members.item(pp);
            s1l = 0;
            s = (int)wgn_UnitTrack.a(i,0); /* start point */
            for (q=0; q<wgn_UnitTrack.a(i,1); q++)
                if (wgn_VertexTrack.a(s+q,0) == -1.0)
                {
                    s1l = q; /* printf("awb q is -1 at %d\n",q); */
                    break;
                }
            s2l = (int)wgn_UnitTrack.a(i,1) - (s1l + 2);
            m1 = (float)(s1l)/(float)l1; /* find interpolation step */
            m2 = (float)(s2l)/(float)l2; /* find interpolation step */
            /* First half */
            // skipped entirely when the marker is at row 0 or not found
            for (ti=0,n=0.0; s1l > 0 && ti<l1; ti++,n+=m1)
            {
                ni = s + (((int)n < s1l) ? (int)n : s1l - 1);
                for (j=0; j<wgn_VertexFeats.num_channels(); j++)
                    if (wgn_VertexFeats.a(0,j) > 0.0)
                        trajectory[ti][j] += wgn_VertexTrack.a(ni,j);
            }
            ti = l1; /* do it explicitly in case s1l < 1 */
            for (j=0; j<wgn_VertexFeats.num_channels(); j++)
                if (wgn_VertexFeats.a(0,j) > 0.0)
                    trajectory[ti][j] += -1;
            /* Second half */
            s += s1l+1;
            for (ti++,n=0.0; s2l > 0 && ti<l-1; ti++,n+=m2)
            {
                ni = s + (((int)n < s2l) ? (int)n : s2l - 1);
                for (j=0; j<wgn_VertexFeats.num_channels(); j++)
                    if (wgn_VertexFeats.a(0,j) > 0.0)
                        trajectory[ti][j] += wgn_VertexTrack.a(ni,j);
            }
            // NOTE(review): when s2l <= 0 the loop above does not run, so ti
            // is l1+1 here and the -2 end marker lands in slot l1+1 instead
            // of slot l-1.
            for (j=0; j<wgn_VertexFeats.num_channels(); j++)
                if (wgn_VertexFeats.a(0,j) > 0.0)
                    trajectory[ti][j] += -2;
        }

        /* find sum of sum of stddev for all coefs of all traj points */
        /* windowing the sums with a triangular weight window         */
        // First part: weights rise from 0 in steps of 1/l1.  Second part:
        // weights fall from 1 in steps of 1/l2.  The marker slots l1 and
        // l-1 are not scored.
        stdss.reset();
        m = 1.0/(float)l1;
        for (w=0.0,ti=0; ti<l1; ti++,w+=m)
            for (j=0; j<wgn_VertexFeats.num_channels(); j++)
                if (wgn_VertexFeats.a(0,j) > 0.0)
                stdss += trajectory[ti][j].stddev() * w;
        m = 1.0/(float)l2;
        for (w=1.0,ti++; ti<l-1; ti++,w-=m)
            for (j=0; j<wgn_VertexFeats.num_channels(); j++)
                if (wgn_VertexFeats.a(0,j) > 0.0)
                    stdss += trajectory[ti][j].stddev() * w;

        // This is sum of all stddev * samples
        score = stdss.mean() * members.length();
    }
    return score;
}
766 
part_to_ols_data(EST_FMatrix & X,EST_FMatrix & Y,EST_IVector & included,EST_StrList & feat_names,const EST_IList & members,const WVectorVector & d)767 static void part_to_ols_data(EST_FMatrix &X, EST_FMatrix &Y,
768                              EST_IVector &included,
769                              EST_StrList &feat_names,
770                              const EST_IList &members,
771                              const WVectorVector &d)
772 {
773     int m,n,p;
774     int w, xm=0;
775     EST_Litem *pp;
776     WVector *wv;
777 
778     w = wgn_dataset.width();
779     included.resize(w);
780     X.resize(members.length(),w);
781     Y.resize(members.length(),1);
782     feat_names.append("Intercept");
783     included[0] = TRUE;
784 
785     for (p=0,pp=members.head(); pp; p++,pp=pp->next())
786     {
787         n = members.item(pp);
788         if (n < 0)
789         {
790             p--;
791             continue;
792         }
793         wv = d(n);
794 	Y.a_no_check(p,0) = (*wv)[0];
795 	X.a_no_check(p,0) = 1;
796 	for (m=1,xm=1; m < w; m++)
797         {
798             if (wgn_dataset.ftype(m) == wndt_float)
799             {
800                 if (p == 0) // only do this once
801                 {
802                     feat_names.append(wgn_dataset.feat_name(m));
803                 }
804                 X.a_no_check(p,xm) = (*wv)[m];
805                 included.a_no_check(xm) = FALSE;
806                 included.a_no_check(xm) = TRUE;
807                 xm++;
808             }
809         }
810     }
811 
812     included.resize(xm);
813     X.resize(p,xm);
814     Y.resize(p,1);
815 }
816 
ols_impurity()817 float WImpurity::ols_impurity()
818 {
819     // Build an OLS model for the current data and measure it against
820     // the data itself and give a RMSE
821     EST_FMatrix X,Y;
822     EST_IVector included;
823     EST_FMatrix coeffs;
824     EST_StrList feat_names;
825     float best_score;
826     EST_FMatrix coeffsl;
827     EST_FMatrix pred;
828     float cor,rmse;
829 
830     // Load the sample members into matrices for ols
831     part_to_ols_data(X,Y,included,feat_names,members,*data);
832 
833     // Find the best ols model.
834     // Far too computationally expensive
835     //    if (!stepwise_ols(X,Y,feat_names,0.0,coeffs,
836     //                      X,Y,included,best_score))
837     //  return WGN_HUGE_VAL;  // couldn't find a model
838 
839     // Non stepwise model
840     if (!robust_ols(X,Y,included,coeffsl))
841     {
842         //        printf("no robust ols\n");
843         return WGN_HUGE_VAL;
844     }
845     ols_apply(X,coeffsl,pred);
846     ols_test(Y,pred,cor,rmse);
847     best_score = cor;
848 
849     printf("Impurity OLS X(%d,%d) Y(%d,%d) %f, %f, %f\n",
850              X.num_rows(),X.num_columns(),Y.num_rows(),Y.num_columns(),
851              rmse,cor,
852              1-best_score);
853     if (fabs(coeffsl[0]) > 10000)
854     {
855         // printf("weird sized Intercept %f\n",coeffsl[0]);
856         return WGN_HUGE_VAL;
857     }
858 
859     return (1-best_score) *members.length();
860 }
861 
cluster_impurity()862 float WImpurity::cluster_impurity()
863 {
864     // Find the mean distance between all members of the dataset
865     // Uses the global DistMatrix for distances between members of
866     // the cluster set.  Distances are assumed to be symmetric thus only
867     // the bottom half of the distance matrix is filled
868     EST_Litem *pp, *q;
869     int i,j;
870     double dist;
871 
872     a.reset();
873     for (pp=members.head(); pp != 0; pp=pp->next())
874     {
875 	i = members.item(pp);
876 	for (q=pp->next(); q != 0; q=q->next())
877 	{
878 	    j = members.item(q);
879 	    dist = (j < i ? wgn_DistMatrix.a_no_check(i,j) :
880  		            wgn_DistMatrix.a_no_check(j,i));
881 	    a+=dist;  // cumulate for whole cluster
882 	}
883     }
884 
885     // This is sum distance between cross product of members
886 //    return a.sum();
887     if (a.samples() > 1)
888         return a.stddev() * a.samples();
889     else
890         return 0.0;
891 }
892 
cluster_distance(int i)893 float WImpurity::cluster_distance(int i)
894 {
895     // Distance this unit is from all others in this cluster
896     // in absolute standard deviations from the the mean.
897     float dist = cluster_member_mean(i);
898     float mdist = dist-a.mean();
899 
900     if (mdist == 0.0)
901 	return 0.0;
902     else
903 	return fabs((dist-a.mean())/a.stddev());
904 
905 }
906 
in_cluster(int i)907 int WImpurity::in_cluster(int i)
908 {
909     // Would this be a member of this cluster?.  Returns 1 if
910     // its distance is less than at least one other
911     float dist = cluster_member_mean(i);
912     EST_Litem *pp;
913 
914     for (pp=members.head(); pp != 0; pp=pp->next())
915     {
916 	if (dist < cluster_member_mean(members.item(pp)))
917 	    return 1;
918     }
919     return 0;
920 }
921 
cluster_ranking(int i)922 float WImpurity::cluster_ranking(int i)
923 {
924     // Position in ranking closest to centre
925     float dist = cluster_distance(i);
926     EST_Litem *pp;
927     int ranking = 1;
928 
929     for (pp=members.head(); pp != 0; pp=pp->next())
930     {
931 	if (dist >= cluster_distance(members.item(pp)))
932 	    ranking++;
933     }
934 
935     return ranking;
936 }
937 
cluster_member_mean(int i)938 float WImpurity::cluster_member_mean(int i)
939 {
940     // Returns the mean difference between this member and all others
941     // in cluster
942     EST_Litem *q;
943     int j,n;
944     double dist,sum;
945 
946     for (sum=0.0,n=0,q=members.head(); q != 0; q=q->next())
947     {
948 	j = members.item(q);
949 	if (i != j)
950 	{
951 	    dist = (j < i ? wgn_DistMatrix(i,j) : wgn_DistMatrix(j,i));
952 	    sum += dist;
953 	    n++;
954 	}
955     }
956 
957     return ( n == 0 ? 0.0 : sum/n );
958 }
959 
cumulate(const float pv,double count)960 void WImpurity::cumulate(const float pv,double count)
961 {
962     // Cumulate data for impurity calculation
963 
964     if (wgn_dataset.ftype(wgn_predictee) == wndt_cluster)
965     {
966 	t = wnim_cluster;
967 	members.append((int)pv);
968     }
969     else if (wgn_dataset.ftype(wgn_predictee) == wndt_ols)
970     {
971 	t = wnim_ols;
972 	members.append((int)pv);
973     }
974     else if (wgn_dataset.ftype(wgn_predictee) == wndt_vector)
975     {
976 	t = wnim_vector;
977 
978 	// AUP: Implement counts in vectors
979 	members.append((int)pv);
980 	member_counts.append((float)count);
981     }
982     else if (wgn_dataset.ftype(wgn_predictee) == wndt_trajectory)
983     {
984 	t = wnim_trajectory;
985 	members.append((int)pv);
986     }
987     else if (wgn_dataset.ftype(wgn_predictee) >= wndt_class)
988     {
989 	if (t == wnim_unset)
990 	    p.init(&wgn_discretes[wgn_dataset.ftype(wgn_predictee)]);
991 	t = wnim_class;
992 	p.cumulate((int)pv,count);
993     }
994     else if (wgn_dataset.ftype(wgn_predictee) == wndt_binary)
995     {
996 	t = wnim_float;
997 	a.cumulate((int)pv,count);
998     }
999     else if (wgn_dataset.ftype(wgn_predictee) == wndt_float)
1000     {
1001 	t = wnim_float;
1002 	a.cumulate(pv,count);
1003     }
1004     else
1005     {
1006 	wagon_error("WImpurity: cannot cumulate EST_Val type");
1007     }
1008 }
1009 
// Write a WImpurity to s as a bracketed summary; the layout depends on imp.t:
//   wnim_float      (stddev mean)
//   wnim_vector     (((value stddev) ...) score)  one pair per channel of
//                   wgn_VertexTrack; value is the count-weighted channel mean
//                   when wgn_vertex_output == "mean", otherwise the channel
//                   value of the member closest to the channel means
//   wnim_trajectory (( (per-channel (mean stddev )) lines ...) score)
//   wnim_cluster    (((member mean_distance) ...) score)
//   wnim_ols        (((feat_name coeff) ...) cor)
//   wnim_class      ((name prob) ... most_probable)
//   anything else   ([WImpurity unset])
// "score" is imp.a.mean().  imp is taken by non-const reference because the
// vector and trajectory branches call imp.vector_impurity() /
// imp.trajectory_impurity() before printing.
ostream & operator <<(ostream &s, WImpurity &imp)
{
    int j,i;
    EST_SuffStats b;

    if (imp.t == wnim_float)
	s << "(" << imp.a.stddev() << " " << imp.a.mean() << ")";
    else if (imp.t == wnim_vector)
    {
      EST_Litem *p, *countp;
	s << "((";
        // Presumably (re)computes imp.a, whose mean is printed as the
        // trailing score below -- TODO confirm against vector_impurity().
        imp.vector_impurity();
        if (wgn_vertex_output == "mean")  //output means
        {
            for (j=0; j<wgn_VertexTrack.num_channels(); j++)
            {
                b.reset();
                // Walks members and member_counts in lockstep; assumes the
                // two lists have the same length (cumulate() appends to both
                // for vector predictees).
                for (p=imp.members.head(), countp=imp.member_counts.head(); p != 0; p=p->next(), countp=countp->next())
                {
		  // Accumulate the members with their counts
		  b.cumulate(wgn_VertexTrack.a(imp.members.item(p),j), imp.member_counts.item(countp));
		  //b += wgn_VertexTrack.a(imp.members.item(p),j);
                }
                s << "(" << b.mean() << " ";
                // A non-finite stddev is printed as the fixed value 0.001.
                if (isfinite(b.stddev()))
                    s << b.stddev() << ")";
                else
                    s << "0.001" << ")";
                if (j+1<wgn_VertexTrack.num_channels())
                    s << " ";
            }
        }
        else /* output best in the cluster */
        {
            /* print out vector closest to center, rather than average */
            double best = WGN_HUGE_VAL;
            double x,d;
            int bestp = 0;
            EST_SuffStats *cs;

            // NOTE(review): cs is sized from wgn_VertexTrack but indexed up
            // to wgn_VertexFeats.num_channels() below; this assumes
            // VertexFeats has no more than VertexTrack.num_channels()+1
            // channels.
            cs = new EST_SuffStats [wgn_VertexTrack.num_channels()+1];

            // Unweighted per-channel statistics over all members, only for
            // channels flagged (> 0.0) in row 0 of wgn_VertexFeats.
            // member_counts is not used in this branch.
            for (j=0; j<wgn_VertexFeats.num_channels(); j++)
                if (wgn_VertexFeats.a(0,j) > 0.0)
                {
                    cs[j].reset();
                    for (p=imp.members.head(); p != 0; p=p->next())
                    {
                        cs[j] += wgn_VertexTrack.a(imp.members.item(p),j);
                    }
                }

            // Choose the member with the smallest sum of squared differences
            // from the channel means (flagged channels only).  If members is
            // empty, bestp stays 0.
            for (p=imp.members.head(); p != 0; p=p->next())
            {
                for (x=0.0,j=0; j<wgn_VertexFeats.num_channels(); j++)
                    if (wgn_VertexFeats.a(0,j) > 0.0)
                    {
                        d = (wgn_VertexTrack.a(imp.members.item(p),j)-cs[j].mean())
                            /* / cs[j].stddev() */ ; /* seems worse 061218 */
                        x += d*d;
                    }
                if (x < best)
                {
                    bestp = imp.members.item(p);
                    best = x;
                }
            }
            // Print the chosen member's value and the channel stddev for
            // every channel.  Channels not flagged above were never
            // accumulated into cs[j]; a non-finite stddev is printed as "0".
            for (j=0; j<wgn_VertexTrack.num_channels(); j++)
            {
                s << "( ";
                s << wgn_VertexTrack.a(bestp,j);
                //                s << " 0 "; // fake stddev
                s << " ";
                if (isfinite(cs[j].stddev()))
                    s << cs[j].stddev();
                else
                    s << "0";
                s << " ) ";
                if (j+1<wgn_VertexTrack.num_channels())
                    s << " ";
            }

            delete [] cs;
        }
	s << ") ";
	s << imp.a.mean() << ")";
    }
    else if (imp.t == wnim_trajectory)
    {
	s << "((";
        // Presumably fills imp.l, imp.trajectory and imp.a before they are
        // printed -- TODO confirm against trajectory_impurity().
        imp.trajectory_impurity();
        // One output line per index i < imp.l, each holding a
        // "(mean stddev )" group per channel of wgn_VertexTrack.
        for (i=0; i<imp.l; i++)
        {
            s << "(";
            for (j=0; j<wgn_VertexTrack.num_channels(); j++)
            {
                s << "(" << imp.trajectory[i][j].mean() << " "
                  << imp.trajectory[i][j].stddev() << " " << ")";
            }
            s << ")\n";
        }
	s << ") ";
	// Mean of cross product of distances (cluster score)
	s << imp.a.mean() << ")";
    }
    else if (imp.t == wnim_cluster)
    {
	EST_Litem *p;
	s << "((";
	for (p=imp.members.head(); p != 0; p=p->next())
	{
	    // Output cluster member and its mean distance to others
	    s << "(" << imp.members.item(p) << " " <<
		imp.cluster_member_mean(imp.members.item(p)) << ")";
	    if (p->next() != 0)
		s << " ";
	}
	s << ") ";
	// Mean of cross product of distances (cluster score)
	s << imp.a.mean() << ")";
    }
    else if (imp.t == wnim_ols)
    {
        /* Output intercept, feature names and coefficients for ols model */
        EST_FMatrix X,Y;
        EST_IVector included;
        EST_FMatrix coeffs;          // NOTE(review): never used
        EST_StrList feat_names;
        EST_FMatrix coeffsl;
        EST_FMatrix pred;
        float cor=0.0,rmse;          // rmse is written by ols_test() but not printed

        s << "((";
        // Load the sample members into matrices for ols
        part_to_ols_data(X,Y,included,feat_names,imp.members,*(imp.data));
        if (!robust_ols(X,Y,included,coeffsl))
        {
            // NOTE(review): this message goes to stdout, not to s; s still
            // receives an empty coefficient list followed by cor (0.0).
            printf("no robust ols\n");
            // shouldn't happen
        }
        else
        {
            ols_apply(X,coeffsl,pred);
            ols_test(Y,pred,cor,rmse);
            // One (feature_name coefficient) pair per row of coeffsl;
            // coeffsl[i] presumably addresses row i of a single-column
            // matrix -- TODO confirm.
            for (i=0; i<coeffsl.num_rows(); i++)
            {
                s << "(";
                s << feat_names.nth(i);
                s << " ";
                s << coeffsl[i];
                s << ") ";
            }
        }

	// Close the coefficient list; the trailing value is cor as set by
	// ols_test() (still 0.0 if robust_ols failed), not a cluster score.
	s << ") " << cor << ")";
    }
    else if (imp.t == wnim_class)
    {
	EST_Litem *i;   // shadows the int i declared at the top of the function
	EST_String name;
	double prob;

	// One "(name prob)" pair per item of the distribution p, followed by
	// the result of p.most_probable().
	s << "(";
	for (i=imp.p.item_start(); !imp.p.item_end(i); i=imp.p.item_next(i))
	{
	    imp.p.item_prob(i,name,prob);
	    s << "(" << name << " " << prob << ") ";
	}
	s << imp.p.most_probable(&prob) << ")";
    }
    else
	s << "([WImpurity unset])";

    return s;
}
1186 
1187 
1188 
1189 
1190