1 /*************************************************************************/
2 /* */
3 /* Centre for Speech Technology Research */
4 /* University of Edinburgh, UK */
5 /* Copyright (c) 1996,1997 */
6 /* All Rights Reserved. */
7 /* */
8 /* Permission is hereby granted, free of charge, to use and distribute */
9 /* this software and its documentation without restriction, including */
10 /* without limitation the rights to use, copy, modify, merge, publish, */
11 /* distribute, sublicense, and/or sell copies of this work, and to */
12 /* permit persons to whom this work is furnished to do so, subject to */
13 /* the following conditions: */
14 /* 1. The code must retain the above copyright notice, this list of */
15 /* conditions and the following disclaimer. */
16 /* 2. Any modifications must be clearly marked as such. */
17 /* 3. Original authors' names are not deleted. */
18 /* 4. The authors' names are not used to endorse or promote products */
19 /* derived from this software without specific prior written */
20 /* permission. */
21 /* */
22 /* THE UNIVERSITY OF EDINBURGH AND THE CONTRIBUTORS TO THIS WORK */
23 /* DISCLAIM ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING */
24 /* ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO EVENT */
25 /* SHALL THE UNIVERSITY OF EDINBURGH NOR THE CONTRIBUTORS BE LIABLE */
26 /* FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES */
27 /* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN */
28 /* AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, */
29 /* ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF */
30 /* THIS SOFTWARE. */
31 /* */
32 /*************************************************************************/
33 /* Author : Alan W Black */
34 /* Date : May 1996 */
35 /*-----------------------------------------------------------------------*/
36 /* */
37 /* Various method functions */
38 /*=======================================================================*/
39
40 #include <cstdlib>
41 #include <iostream>
42 #include <cstring>
43 #include "EST_unix.h"
44 #include "EST_cutils.h"
45 #include "EST_Token.h"
46 #include "EST_Wagon.h"
47 #include "EST_math.h"
48
49
predict(const WVector & d)50 EST_Val WNode::predict(const WVector &d)
51 {
52 if (leaf())
53 return impurity.value();
54 else if (question.ask(d))
55 return left->predict(d);
56 else
57 return right->predict(d);
58 }
59
predict_node(const WVector & d)60 WNode *WNode::predict_node(const WVector &d)
61 {
62 if (leaf())
63 return this;
64 else if (question.ask(d))
65 return left->predict_node(d);
66 else
67 return right->predict_node(d);
68 }
69
pure(void)70 int WNode::pure(void)
71 {
72 // A node is pure if it has no sub-nodes or its not of type class
73
74 if ((left == 0) && (right == 0))
75 return TRUE;
76 else if (get_impurity().type() != wnim_class)
77 return TRUE;
78 else
79 return FALSE;
80 }
81
prune(void)82 void WNode::prune(void)
83 {
84 // Check all sub-nodes and if they are all of the same class
85 // delete their sub nodes. Returns pureness of this node
86
87 if (pure() == FALSE)
88 {
89 // Ok lets try and make it pure
90 if (left != 0) left->prune();
91 if (right != 0) right->prune();
92
93 // Have to check purity as well as values to ensure left and right
94 // don't further split
95 if ((left->pure() == TRUE) && ((right->pure() == TRUE)) &&
96 (left->get_impurity().value() == right->get_impurity().value()))
97 {
98 delete left; left = 0;
99 delete right; right = 0;
100 }
101 }
102
103 }
104
held_out_prune()105 void WNode::held_out_prune()
106 {
107 // prune tree with held out data
108 // Check if node's questions differentiates for the held out data
109 // if not, prune all sub_nodes
110
111 // Rescore with prune data
112 set_impurity(WImpurity(get_data())); // for this new data
113
114 if (left != 0)
115 {
116 wgn_score_question(question,get_data());
117 if (question.get_score() < get_impurity().measure())
118 { // its worth goint ot the next level
119 wgn_find_split(question,get_data(),
120 left->get_data(),
121 right->get_data());
122 left->held_out_prune();
123 right->held_out_prune();
124 }
125 else
126 { // not worth the split so prune both sub_nodes
127 delete left; left = 0;
128 delete right; right = 0;
129 }
130 }
131 }
132
print_out(ostream & s,int margin)133 void WNode::print_out(ostream &s, int margin)
134 {
135 int i;
136
137 s << endl;
138 for (i=0;i<margin;i++) s << " ";
139 s << "(";
140 if (left==0) // base case
141 s << impurity;
142 else
143 {
144 s << question;
145 left->print_out(s,margin+1);
146 right->print_out(s,margin+1);
147 }
148 s << ")";
149 }
150
operator <<(ostream & s,WNode & n)151 ostream & operator <<(ostream &s, WNode &n)
152 {
153 // Output this node and its sub-node
154
155 n.print_out(s,0);
156 s << endl;
157 return s;
158 }
159
ignore_non_numbers()160 void WDataSet::ignore_non_numbers()
161 {
162 /* For ols we want to ignore anything that is categorial */
163 int i;
164
165 for (i=0; i<dlength; i++)
166 {
167 if ((p_type[i] == wndt_binary) ||
168 (p_type[i] == wndt_float))
169 continue;
170 else
171 {
172 p_ignore[i] = TRUE;
173 }
174 }
175
176 return;
177 }
178
load_description(const EST_String & fname,LISP ignores)179 void WDataSet::load_description(const EST_String &fname, LISP ignores)
180 {
181 // Initialise a dataset with sizes and types
182 EST_String tname;
183 int i;
184 LISP description,d;
185
186 description = car(vload(fname,1));
187 dlength = siod_llength(description);
188
189 p_type.resize(dlength);
190 p_ignore.resize(dlength);
191 p_name.resize(dlength);
192
193 if (wgn_predictee_name == "")
194 wgn_predictee = 0; // default predictee is first field
195 else
196 wgn_predictee = -1;
197
198 for (i=0,d=description; d != NIL; d=cdr(d),i++)
199 {
200 p_name[i] = get_c_string(car(car(d)));
201 tname = get_c_string(car(cdr(car(d))));
202 p_ignore[i] = FALSE;
203 if ((wgn_predictee_name != "") && (wgn_predictee_name == p_name[i]))
204 wgn_predictee = i;
205 if ((wgn_count_field_name != "") &&
206 (wgn_count_field_name == p_name[i]))
207 wgn_count_field = i;
208 if ((tname == "count") || (i == wgn_count_field))
209 {
210 // The count must be ignored, repeat it if you want it too
211 p_type[i] = wndt_ignore; // the count must be ignored
212 p_ignore[i] = TRUE;
213 wgn_count_field = i;
214 }
215 else if ((tname == "ignore") || (siod_member_str(p_name[i],ignores)))
216 {
217 p_type[i] = wndt_ignore; // user specified ignore
218 p_ignore[i] = TRUE;
219 if (i == wgn_predictee)
220 wagon_error(EST_String("predictee \"")+p_name[i]+
221 "\" can't be ignored \n");
222 }
223 else if (siod_llength(car(d)) > 2)
224 {
225 LISP rest = cdr(car(d));
226 EST_StrList sl;
227 siod_list_to_strlist(rest,sl);
228 p_type[i] = wgn_discretes.def(sl);
229 if (streq(get_c_string(car(rest)),"_other_"))
230 wgn_discretes[p_type[i]].def_val("_other_");
231 }
232 else if (tname == "binary")
233 p_type[i] = wndt_binary;
234 else if (tname == "cluster")
235 p_type[i] = wndt_cluster;
236 else if (tname == "vector")
237 p_type[i] = wndt_vector;
238 else if (tname == "trajectory")
239 p_type[i] = wndt_trajectory;
240 else if (tname == "ols")
241 p_type[i] = wndt_ols;
242 else if (tname == "matrix")
243 p_type[i] = wndt_matrix;
244 else if (tname == "float")
245 p_type[i] = wndt_float;
246 else
247 {
248 wagon_error(EST_String("Unknown type \"")+tname+
249 "\" for field number "+itoString(i)+
250 "/"+p_name[i]+" in description file \""+fname+"\"");
251 }
252 }
253
254 if (wgn_predictee == -1)
255 {
256 wagon_error(EST_String("predictee field \"")+wgn_predictee_name+
257 "\" not found in description ");
258 }
259 }
260
ask(const WVector & w) const261 const int WQuestion::ask(const WVector &w) const
262 {
263 // Ask this question of the given vector
264 switch (op)
265 {
266 case wnop_equal: // for numbers
267 if (w.get_flt_val(feature_pos) == operand1.Float())
268 return TRUE;
269 else
270 return FALSE;
271 case wnop_binary: // for numbers
272 if (w.get_int_val(feature_pos) == 1)
273 return TRUE;
274 else
275 return FALSE;
276 case wnop_greaterthan:
277 if (w.get_flt_val(feature_pos) > operand1.Float())
278 return TRUE;
279 else
280 return FALSE;
281 case wnop_lessthan:
282 if (w.get_flt_val(feature_pos) < operand1.Float())
283 return TRUE;
284 else
285 return FALSE;
286 case wnop_is: // for classes
287 if (w.get_int_val(feature_pos) == operand1.Int())
288 return TRUE;
289 else
290 return FALSE;
291 case wnop_in: // for subsets -- note operand is list of ints
292 if (ilist_member(operandl,w.get_int_val(feature_pos)))
293 return TRUE;
294 else
295 return FALSE;
296 default:
297 wagon_error("Unknown test operator");
298 }
299
300 return FALSE;
301 }
302
operator <<(ostream & s,const WQuestion & q)303 ostream& operator<<(ostream& s, const WQuestion &q)
304 {
305 EST_String name;
306 static EST_Regex needquotes(".*[()'\";., \t\n\r].*");
307
308 s << "(" << wgn_dataset.feat_name(q.get_fp());
309 switch (q.get_op())
310 {
311 case wnop_equal:
312 s << " = " << q.get_operand1().string();
313 break;
314 case wnop_binary:
315 break;
316 case wnop_greaterthan:
317 s << " > " << q.get_operand1().Float();
318 break;
319 case wnop_lessthan:
320 s << " < " << q.get_operand1().Float();
321 break;
322 case wnop_is:
323 name = wgn_discretes[wgn_dataset.ftype(q.get_fp())].
324 name(q.get_operand1().Int());
325 s << " is ";
326 if (name.matches(needquotes))
327 s << quote_string(name,"\"","\\",1);
328 else
329 s << name;
330 break;
331 case wnop_matches:
332 name = wgn_discretes[wgn_dataset.ftype(q.get_fp())].
333 name(q.get_operand1().Int());
334 s << " matches " << quote_string(name,"\"","\\",1);
335 break;
336 case wnop_in:
337 s << " in (";
338 for (int l=0; l < q.get_operandl().length(); l++)
339 {
340 name = wgn_discretes[wgn_dataset.ftype(q.get_fp())].
341 name(q.get_operandl().nth(l));
342 if (name.matches(needquotes))
343 s << quote_string(name,"\"","\\",1);
344 else
345 s << name;
346 s << " ";
347 }
348 s << ")";
349 break;
350 // SunCC wont let me add this
351 // default:
352 // s << " unknown operation ";
353 }
354 s << ")";
355
356 return s;
357 }
358
value(void)359 EST_Val WImpurity::value(void)
360 {
361 // Returns the recommended value for this
362 EST_String s;
363 double prob;
364
365 if (t==wnim_unset)
366 {
367 cerr << "WImpurity: no value currently set\n";
368 return EST_Val(0.0);
369 }
370 else if (t==wnim_class)
371 return EST_Val(p.most_probable(&prob));
372 else if (t==wnim_cluster)
373 return EST_Val(a.mean());
374 else if (t==wnim_ols) /* OLS TBA */
375 return EST_Val(a.mean());
376 else if (t==wnim_vector)
377 return EST_Val(a.mean()); /* wnim_vector */
378 else if (t==wnim_trajectory)
379 return EST_Val(a.mean()); /* NOT YET WRITTEN */
380 else
381 return EST_Val(a.mean());
382 }
383
samples(void)384 double WImpurity::samples(void)
385 {
386 if (t==wnim_float)
387 return a.samples();
388 else if (t==wnim_class)
389 return (int)p.samples();
390 else if (t==wnim_cluster)
391 return members.length();
392 else if (t==wnim_ols)
393 return members.length();
394 else if (t==wnim_vector)
395 return members.length();
396 else if (t==wnim_trajectory)
397 return members.length();
398 else
399 return 0;
400 }
401
WImpurity(const WVectorVector & ds)402 WImpurity::WImpurity(const WVectorVector &ds)
403 {
404 int i;
405
406 t=wnim_unset;
407 a.reset(); trajectory=0; l=0; width=0;
408 data = &ds; // for ols, model calculation
409 for (i=0; i < ds.n(); i++)
410 {
411 if (t == wnim_ols)
412 cumulate(i,1);
413 else if (wgn_count_field == -1)
414 cumulate((*(ds(i)))[wgn_predictee],1);
415 else
416 cumulate((*(ds(i)))[wgn_predictee],
417 (*(ds(i)))[wgn_count_field]);
418 }
419 }
420
measure(void)421 float WImpurity::measure(void)
422 {
423 if (t == wnim_float)
424 return a.variance()*a.samples();
425 else if (t == wnim_vector)
426 return vector_impurity();
427 else if (t == wnim_trajectory)
428 return trajectory_impurity();
429 else if (t == wnim_matrix)
430 return a.variance()*a.samples();
431 else if (t == wnim_class)
432 return p.entropy()*p.samples();
433 else if (t == wnim_cluster)
434 return cluster_impurity();
435 else if (t == wnim_ols)
436 return ols_impurity(); /* RMSE for OLS model */
437 else
438 {
439 cerr << "WImpurity: can't measure unset object" << endl;
440 return 0.0;
441 }
442 }
443
vector_impurity()444 float WImpurity::vector_impurity()
445 {
446 // Find the mean/stddev for all values in all vectors
447 // sum the variances and multiply them by the number of members
448 EST_Litem *pp;
449 EST_Litem *countpp;
450 int i,j;
451 EST_SuffStats b;
452 double count = 1;
453
454 a.reset();
455 #if 1
456 /* simple distance */
457 for (j=0; j<wgn_VertexFeats.num_channels(); j++)
458 {
459 if (wgn_VertexFeats.a(0,j) > 0.0)
460 {
461 b.reset();
462 for (pp=members.head(), countpp=member_counts.head(); pp != 0; pp=pp->next(), countpp=countpp->next())
463 {
464 i = members.item(pp);
465
466 // Accumulate the value with count
467 b.cumulate(wgn_VertexTrack.a(i,j), member_counts.item(countpp)) ;
468 }
469 a += b.stddev();
470 count = b.samples();
471 }
472 }
473 #endif
474
475 #if 0
476 EST_SuffStats *c;
477 float x, lshift, rshift, ushift;
478 /* Find base mean, then measure do fshift to find best match */
479 c = new EST_SuffStats[wgn_VertexTrack.num_channels()+1];
480 for (j=0; j<wgn_VertexFeats.num_channels(); j++)
481 {
482 if (wgn_VertexFeats.a(0,j) > 0.0)
483 {
484 c[j].reset();
485 for (pp=members.head(), countpp=member_counts.head(); pp != 0;
486 pp=pp->next(), countpp=countpp->next())
487 {
488 i = members.item(pp);
489 // Accumulate the value with count
490 c[j].cumulate(wgn_VertexTrack.a(i,j),member_counts.item(countpp));
491 }
492 count = c[j].samples();
493 }
494 }
495
496 /* Pass through again but vary the num_channels offset (hardcoded) */
497 for (pp=members.head(), countpp=member_counts.head(); pp != 0;
498 pp=pp->next(), countpp=countpp->next())
499 {
500 int q;
501 float bshift, qshift;
502 /* For each sample */
503 i = members.item(pp);
504 /* Find the value left shifted, unshifted, and right shifted */
505 lshift = 0; ushift = 0; rshift = 0;
506 bshift = 0;
507 for (q=-20; q<=20; q++)
508 {
509 qshift = 0;
510 for (j=67+q; j<147+q/*hardcoded*/; j++)
511 {
512 x = c[j].mean() - wgn_VertexTrack(i,j);
513 qshift += sqrt(x*x);
514 if ((bshift > 0) && (qshift > bshift))
515 break;
516 }
517 if ((bshift == 0) || (qshift < bshift))
518 bshift = qshift;
519 }
520 a += bshift;
521 }
522
523 #endif
524
525 #if 0
526 /* full covariance */
527 /* worse in listening experiments */
528 EST_SuffStats **cs;
529 int mmm;
530 cs = new EST_SuffStats *[wgn_VertexTrack.num_channels()+1];
531 for (j=0; j<=wgn_VertexTrack.num_channels(); j++)
532 cs[j] = new EST_SuffStats[wgn_VertexTrack.num_channels()+1];
533 /* Find means for diagonal */
534 for (j=0; j<wgn_VertexFeats.num_channels(); j++)
535 {
536 if (wgn_VertexFeats.a(0,j) > 0.0)
537 {
538 for (pp=members.head(); pp != 0; pp=pp->next())
539 cs[j][j] += wgn_VertexTrack.a(members.item(pp),j);
540 }
541 }
542 for (j=0; j<wgn_VertexFeats.num_channels(); j++)
543 {
544 for (i=j+1; i<wgn_VertexFeats.num_channels(); i++)
545 if (wgn_VertexFeats.a(0,j) > 0.0)
546 {
547 for (pp=members.head(); pp != 0; pp=pp->next())
548 {
549 mmm = members.item(pp);
550 cs[i][j] += (wgn_VertexTrack.a(mmm,i)-cs[j][j].mean())*
551 (wgn_VertexTrack.a(mmm,j)-cs[j][j].mean());
552 }
553 }
554 }
555 for (j=0; j<wgn_VertexFeats.num_channels(); j++)
556 {
557 for (i=j+1; i<wgn_VertexFeats.num_channels(); i++)
558 if (wgn_VertexFeats.a(0,j) > 0.0)
559 a += cs[i][j].stddev();
560 }
561 count = cs[0][0].samples();
562 #endif
563
564 #if 0
565 // look at mean euclidean distance between vectors
566 EST_Litem *qq;
567 int x,y;
568 double d,q;
569 count = 0;
570 for (pp=members.head(); pp != 0; pp=pp->next())
571 {
572 x = members.item(pp);
573 count++;
574 for (qq=pp->next(); qq != 0; qq=qq->next())
575 {
576 y = members.item(qq);
577 for (q=0.0,j=0; j<wgn_VertexFeats.num_channels(); j++)
578 if (wgn_VertexFeats.a(0,j) > 0.0)
579 {
580 d = wgn_VertexTrack(x,j)-wgn_VertexTrack(y,j);
581 q += d*d;
582 }
583 a += sqrt(q);
584 }
585
586 }
587 #endif
588
589 // This is sum of stddev*samples
590 return a.mean() * count;
591 }
592
~WImpurity()593 WImpurity::~WImpurity()
594 {
595 int j;
596
597 if (trajectory != 0)
598 {
599 for (j=0; j<l; j++)
600 delete [] trajectory[j];
601 delete [] trajectory;
602 trajectory = 0;
603 l = 0;
604 }
605 }
606
607
trajectory_impurity()608 float WImpurity::trajectory_impurity()
609 {
610 // Find the mean length of all the units in the cluster
611 // Create that number of points
612 // Interpolate each unit to that number of points
613 // collect means and standard deviations for each point
614 // impurity is sum of the variance for each point and each coef
615 // multiplied by the number of units.
616 EST_Litem *pp;
617 int i, j;
618 int s, ti, ni, q;
619 int s1l, s2l;
620 double n, m, m1, m2, w;
621 EST_SuffStats lss, stdss;
622 EST_SuffStats l1ss, l2ss;
623 int l1, l2;
624 int ola=0;
625
626 if (trajectory != 0)
627 { /* already done this */
628 return score;
629 }
630
631 lss.reset();
632 l = 0;
633 for (pp=members.head(); pp != 0; pp=pp->next())
634 {
635 i = members.item(pp);
636 for (q=0; q<wgn_UnitTrack.a(i,1); q++)
637 {
638 ni = (int)wgn_UnitTrack.a(i,0)+q;
639 if (wgn_VertexTrack.a(ni,0) == -1.0)
640 {
641 l1ss += q;
642 ola = 1;
643 break;
644 }
645 }
646 if (q==wgn_UnitTrack.a(i,1))
647 { /* can't find -1 center point, so put all in l2 */
648 l1ss += 0;
649 l2ss += q;
650 }
651 else
652 l2ss += wgn_UnitTrack.a(i,1) - (q+1) - 1;
653 lss += wgn_UnitTrack.a(i,1); /* length of each unit in the cluster */
654 if (wgn_UnitTrack.a(i,1) > l)
655 l = (int)wgn_UnitTrack.a(i,1);
656 }
657
658 if (ola==0) /* no -1's so its not an ola type cluster */
659 {
660 l = ((int)lss.mean() < 7) ? 7 : (int)lss.mean();
661
662 /* a list of SuffStats on for each point in the trajectory */
663 trajectory = new EST_SuffStats *[l];
664 width = wgn_VertexTrack.num_channels()+1;
665 for (j=0; j<l; j++)
666 trajectory[j] = new EST_SuffStats[width];
667
668 for (pp=members.head(); pp != 0; pp=pp->next())
669 { /* for each unit */
670 i = members.item(pp);
671 m = (float)wgn_UnitTrack.a(i,1)/(float)l; /* find interpolation */
672 s = (int)wgn_UnitTrack.a(i,0); /* start point */
673 for (ti=0,n=0.0; ti<l; ti++,n+=m)
674 {
675 ni = (int)n; // hmm floor or nint ??
676 for (j=0; j<wgn_VertexFeats.num_channels(); j++)
677 {
678 if (wgn_VertexFeats.a(0,j) > 0.0)
679 trajectory[ti][j] += wgn_VertexTrack.a(s+ni,j);
680 }
681 }
682 }
683
684 /* find sum of sum of stddev for all coefs of all traj points */
685 stdss.reset();
686 for (ti=0; ti<l; ti++)
687 for (j=0; j<wgn_VertexFeats.num_channels(); j++)
688 {
689 if (wgn_VertexFeats.a(0,j) > 0.0)
690 stdss += trajectory[ti][j].stddev();
691 }
692
693 // This is sum of all stddev * samples
694 score = stdss.mean() * members.length();
695 }
696 else
697 { /* OLA model */
698 l1 = (l1ss.mean() < 10.0) ? 10 : (int)l1ss.mean();
699 l2 = (l2ss.mean() < 10.0) ? 10 : (int)l2ss.mean();
700 l = l1 + l2 + 1 + 1;
701
702 /* a list of SuffStats on for each point in the trajectory */
703 trajectory = new EST_SuffStats *[l];
704 for (j=0; j<l; j++)
705 trajectory[j] = new EST_SuffStats[wgn_VertexTrack.num_channels()+1];
706
707 for (pp=members.head(); pp != 0; pp=pp->next())
708 { /* for each unit */
709 i = members.item(pp);
710 s1l = 0;
711 s = (int)wgn_UnitTrack.a(i,0); /* start point */
712 for (q=0; q<wgn_UnitTrack.a(i,1); q++)
713 if (wgn_VertexTrack.a(s+q,0) == -1.0)
714 {
715 s1l = q; /* printf("awb q is -1 at %d\n",q); */
716 break;
717 }
718 s2l = (int)wgn_UnitTrack.a(i,1) - (s1l + 2);
719 m1 = (float)(s1l)/(float)l1; /* find interpolation step */
720 m2 = (float)(s2l)/(float)l2; /* find interpolation step */
721 /* First half */
722 for (ti=0,n=0.0; s1l > 0 && ti<l1; ti++,n+=m1)
723 {
724 ni = s + (((int)n < s1l) ? (int)n : s1l - 1);
725 for (j=0; j<wgn_VertexFeats.num_channels(); j++)
726 if (wgn_VertexFeats.a(0,j) > 0.0)
727 trajectory[ti][j] += wgn_VertexTrack.a(ni,j);
728 }
729 ti = l1; /* do it explicitly in case s1l < 1 */
730 for (j=0; j<wgn_VertexFeats.num_channels(); j++)
731 if (wgn_VertexFeats.a(0,j) > 0.0)
732 trajectory[ti][j] += -1;
733 /* Second half */
734 s += s1l+1;
735 for (ti++,n=0.0; s2l > 0 && ti<l-1; ti++,n+=m2)
736 {
737 ni = s + (((int)n < s2l) ? (int)n : s2l - 1);
738 for (j=0; j<wgn_VertexFeats.num_channels(); j++)
739 if (wgn_VertexFeats.a(0,j) > 0.0)
740 trajectory[ti][j] += wgn_VertexTrack.a(ni,j);
741 }
742 for (j=0; j<wgn_VertexFeats.num_channels(); j++)
743 if (wgn_VertexFeats.a(0,j) > 0.0)
744 trajectory[ti][j] += -2;
745 }
746
747 /* find sum of sum of stddev for all coefs of all traj points */
748 /* windowing the sums with a triangular weight window */
749 stdss.reset();
750 m = 1.0/(float)l1;
751 for (w=0.0,ti=0; ti<l1; ti++,w+=m)
752 for (j=0; j<wgn_VertexFeats.num_channels(); j++)
753 if (wgn_VertexFeats.a(0,j) > 0.0)
754 stdss += trajectory[ti][j].stddev() * w;
755 m = 1.0/(float)l2;
756 for (w=1.0,ti++; ti<l-1; ti++,w-=m)
757 for (j=0; j<wgn_VertexFeats.num_channels(); j++)
758 if (wgn_VertexFeats.a(0,j) > 0.0)
759 stdss += trajectory[ti][j].stddev() * w;
760
761 // This is sum of all stddev * samples
762 score = stdss.mean() * members.length();
763 }
764 return score;
765 }
766
part_to_ols_data(EST_FMatrix & X,EST_FMatrix & Y,EST_IVector & included,EST_StrList & feat_names,const EST_IList & members,const WVectorVector & d)767 static void part_to_ols_data(EST_FMatrix &X, EST_FMatrix &Y,
768 EST_IVector &included,
769 EST_StrList &feat_names,
770 const EST_IList &members,
771 const WVectorVector &d)
772 {
773 int m,n,p;
774 int w, xm=0;
775 EST_Litem *pp;
776 WVector *wv;
777
778 w = wgn_dataset.width();
779 included.resize(w);
780 X.resize(members.length(),w);
781 Y.resize(members.length(),1);
782 feat_names.append("Intercept");
783 included[0] = TRUE;
784
785 for (p=0,pp=members.head(); pp; p++,pp=pp->next())
786 {
787 n = members.item(pp);
788 if (n < 0)
789 {
790 p--;
791 continue;
792 }
793 wv = d(n);
794 Y.a_no_check(p,0) = (*wv)[0];
795 X.a_no_check(p,0) = 1;
796 for (m=1,xm=1; m < w; m++)
797 {
798 if (wgn_dataset.ftype(m) == wndt_float)
799 {
800 if (p == 0) // only do this once
801 {
802 feat_names.append(wgn_dataset.feat_name(m));
803 }
804 X.a_no_check(p,xm) = (*wv)[m];
805 included.a_no_check(xm) = FALSE;
806 included.a_no_check(xm) = TRUE;
807 xm++;
808 }
809 }
810 }
811
812 included.resize(xm);
813 X.resize(p,xm);
814 Y.resize(p,1);
815 }
816
ols_impurity()817 float WImpurity::ols_impurity()
818 {
819 // Build an OLS model for the current data and measure it against
820 // the data itself and give a RMSE
821 EST_FMatrix X,Y;
822 EST_IVector included;
823 EST_FMatrix coeffs;
824 EST_StrList feat_names;
825 float best_score;
826 EST_FMatrix coeffsl;
827 EST_FMatrix pred;
828 float cor,rmse;
829
830 // Load the sample members into matrices for ols
831 part_to_ols_data(X,Y,included,feat_names,members,*data);
832
833 // Find the best ols model.
834 // Far too computationally expensive
835 // if (!stepwise_ols(X,Y,feat_names,0.0,coeffs,
836 // X,Y,included,best_score))
837 // return WGN_HUGE_VAL; // couldn't find a model
838
839 // Non stepwise model
840 if (!robust_ols(X,Y,included,coeffsl))
841 {
842 // printf("no robust ols\n");
843 return WGN_HUGE_VAL;
844 }
845 ols_apply(X,coeffsl,pred);
846 ols_test(Y,pred,cor,rmse);
847 best_score = cor;
848
849 printf("Impurity OLS X(%d,%d) Y(%d,%d) %f, %f, %f\n",
850 X.num_rows(),X.num_columns(),Y.num_rows(),Y.num_columns(),
851 rmse,cor,
852 1-best_score);
853 if (fabs(coeffsl[0]) > 10000)
854 {
855 // printf("weird sized Intercept %f\n",coeffsl[0]);
856 return WGN_HUGE_VAL;
857 }
858
859 return (1-best_score) *members.length();
860 }
861
cluster_impurity()862 float WImpurity::cluster_impurity()
863 {
864 // Find the mean distance between all members of the dataset
865 // Uses the global DistMatrix for distances between members of
866 // the cluster set. Distances are assumed to be symmetric thus only
867 // the bottom half of the distance matrix is filled
868 EST_Litem *pp, *q;
869 int i,j;
870 double dist;
871
872 a.reset();
873 for (pp=members.head(); pp != 0; pp=pp->next())
874 {
875 i = members.item(pp);
876 for (q=pp->next(); q != 0; q=q->next())
877 {
878 j = members.item(q);
879 dist = (j < i ? wgn_DistMatrix.a_no_check(i,j) :
880 wgn_DistMatrix.a_no_check(j,i));
881 a+=dist; // cumulate for whole cluster
882 }
883 }
884
885 // This is sum distance between cross product of members
886 // return a.sum();
887 if (a.samples() > 1)
888 return a.stddev() * a.samples();
889 else
890 return 0.0;
891 }
892
cluster_distance(int i)893 float WImpurity::cluster_distance(int i)
894 {
895 // Distance this unit is from all others in this cluster
896 // in absolute standard deviations from the the mean.
897 float dist = cluster_member_mean(i);
898 float mdist = dist-a.mean();
899
900 if (mdist == 0.0)
901 return 0.0;
902 else
903 return fabs((dist-a.mean())/a.stddev());
904
905 }
906
in_cluster(int i)907 int WImpurity::in_cluster(int i)
908 {
909 // Would this be a member of this cluster?. Returns 1 if
910 // its distance is less than at least one other
911 float dist = cluster_member_mean(i);
912 EST_Litem *pp;
913
914 for (pp=members.head(); pp != 0; pp=pp->next())
915 {
916 if (dist < cluster_member_mean(members.item(pp)))
917 return 1;
918 }
919 return 0;
920 }
921
cluster_ranking(int i)922 float WImpurity::cluster_ranking(int i)
923 {
924 // Position in ranking closest to centre
925 float dist = cluster_distance(i);
926 EST_Litem *pp;
927 int ranking = 1;
928
929 for (pp=members.head(); pp != 0; pp=pp->next())
930 {
931 if (dist >= cluster_distance(members.item(pp)))
932 ranking++;
933 }
934
935 return ranking;
936 }
937
cluster_member_mean(int i)938 float WImpurity::cluster_member_mean(int i)
939 {
940 // Returns the mean difference between this member and all others
941 // in cluster
942 EST_Litem *q;
943 int j,n;
944 double dist,sum;
945
946 for (sum=0.0,n=0,q=members.head(); q != 0; q=q->next())
947 {
948 j = members.item(q);
949 if (i != j)
950 {
951 dist = (j < i ? wgn_DistMatrix(i,j) : wgn_DistMatrix(j,i));
952 sum += dist;
953 n++;
954 }
955 }
956
957 return ( n == 0 ? 0.0 : sum/n );
958 }
959
cumulate(const float pv,double count)960 void WImpurity::cumulate(const float pv,double count)
961 {
962 // Cumulate data for impurity calculation
963
964 if (wgn_dataset.ftype(wgn_predictee) == wndt_cluster)
965 {
966 t = wnim_cluster;
967 members.append((int)pv);
968 }
969 else if (wgn_dataset.ftype(wgn_predictee) == wndt_ols)
970 {
971 t = wnim_ols;
972 members.append((int)pv);
973 }
974 else if (wgn_dataset.ftype(wgn_predictee) == wndt_vector)
975 {
976 t = wnim_vector;
977
978 // AUP: Implement counts in vectors
979 members.append((int)pv);
980 member_counts.append((float)count);
981 }
982 else if (wgn_dataset.ftype(wgn_predictee) == wndt_trajectory)
983 {
984 t = wnim_trajectory;
985 members.append((int)pv);
986 }
987 else if (wgn_dataset.ftype(wgn_predictee) >= wndt_class)
988 {
989 if (t == wnim_unset)
990 p.init(&wgn_discretes[wgn_dataset.ftype(wgn_predictee)]);
991 t = wnim_class;
992 p.cumulate((int)pv,count);
993 }
994 else if (wgn_dataset.ftype(wgn_predictee) == wndt_binary)
995 {
996 t = wnim_float;
997 a.cumulate((int)pv,count);
998 }
999 else if (wgn_dataset.ftype(wgn_predictee) == wndt_float)
1000 {
1001 t = wnim_float;
1002 a.cumulate(pv,count);
1003 }
1004 else
1005 {
1006 wagon_error("WImpurity: cannot cumulate EST_Val type");
1007 }
1008 }
1009
operator <<(ostream & s,WImpurity & imp)1010 ostream & operator <<(ostream &s, WImpurity &imp)
1011 {
1012 int j,i;
1013 EST_SuffStats b;
1014
1015 if (imp.t == wnim_float)
1016 s << "(" << imp.a.stddev() << " " << imp.a.mean() << ")";
1017 else if (imp.t == wnim_vector)
1018 {
1019 EST_Litem *p, *countp;
1020 s << "((";
1021 imp.vector_impurity();
1022 if (wgn_vertex_output == "mean") //output means
1023 {
1024 for (j=0; j<wgn_VertexTrack.num_channels(); j++)
1025 {
1026 b.reset();
1027 for (p=imp.members.head(), countp=imp.member_counts.head(); p != 0; p=p->next(), countp=countp->next())
1028 {
1029 // Accumulate the members with their counts
1030 b.cumulate(wgn_VertexTrack.a(imp.members.item(p),j), imp.member_counts.item(countp));
1031 //b += wgn_VertexTrack.a(imp.members.item(p),j);
1032 }
1033 s << "(" << b.mean() << " ";
1034 if (isfinite(b.stddev()))
1035 s << b.stddev() << ")";
1036 else
1037 s << "0.001" << ")";
1038 if (j+1<wgn_VertexTrack.num_channels())
1039 s << " ";
1040 }
1041 }
1042 else /* output best in the cluster */
1043 {
1044 /* print out vector closest to center, rather than average */
1045 double best = WGN_HUGE_VAL;
1046 double x,d;
1047 int bestp = 0;
1048 EST_SuffStats *cs;
1049
1050 cs = new EST_SuffStats [wgn_VertexTrack.num_channels()+1];
1051
1052 for (j=0; j<wgn_VertexFeats.num_channels(); j++)
1053 if (wgn_VertexFeats.a(0,j) > 0.0)
1054 {
1055 cs[j].reset();
1056 for (p=imp.members.head(); p != 0; p=p->next())
1057 {
1058 cs[j] += wgn_VertexTrack.a(imp.members.item(p),j);
1059 }
1060 }
1061
1062 for (p=imp.members.head(); p != 0; p=p->next())
1063 {
1064 for (x=0.0,j=0; j<wgn_VertexFeats.num_channels(); j++)
1065 if (wgn_VertexFeats.a(0,j) > 0.0)
1066 {
1067 d = (wgn_VertexTrack.a(imp.members.item(p),j)-cs[j].mean())
1068 /* / cs[j].stddev() */ ; /* seems worse 061218 */
1069 x += d*d;
1070 }
1071 if (x < best)
1072 {
1073 bestp = imp.members.item(p);
1074 best = x;
1075 }
1076 }
1077 for (j=0; j<wgn_VertexTrack.num_channels(); j++)
1078 {
1079 s << "( ";
1080 s << wgn_VertexTrack.a(bestp,j);
1081 // s << " 0 "; // fake stddev
1082 s << " ";
1083 if (isfinite(cs[j].stddev()))
1084 s << cs[j].stddev();
1085 else
1086 s << "0";
1087 s << " ) ";
1088 if (j+1<wgn_VertexTrack.num_channels())
1089 s << " ";
1090 }
1091
1092 delete [] cs;
1093 }
1094 s << ") ";
1095 s << imp.a.mean() << ")";
1096 }
1097 else if (imp.t == wnim_trajectory)
1098 {
1099 s << "((";
1100 imp.trajectory_impurity();
1101 for (i=0; i<imp.l; i++)
1102 {
1103 s << "(";
1104 for (j=0; j<wgn_VertexTrack.num_channels(); j++)
1105 {
1106 s << "(" << imp.trajectory[i][j].mean() << " "
1107 << imp.trajectory[i][j].stddev() << " " << ")";
1108 }
1109 s << ")\n";
1110 }
1111 s << ") ";
1112 // Mean of cross product of distances (cluster score)
1113 s << imp.a.mean() << ")";
1114 }
1115 else if (imp.t == wnim_cluster)
1116 {
1117 EST_Litem *p;
1118 s << "((";
1119 for (p=imp.members.head(); p != 0; p=p->next())
1120 {
1121 // Ouput cluster member and its mean distance to others
1122 s << "(" << imp.members.item(p) << " " <<
1123 imp.cluster_member_mean(imp.members.item(p)) << ")";
1124 if (p->next() != 0)
1125 s << " ";
1126 }
1127 s << ") ";
1128 // Mean of cross product of distances (cluster score)
1129 s << imp.a.mean() << ")";
1130 }
1131 else if (imp.t == wnim_ols)
1132 {
1133 /* Output intercept, feature names and coefficients for ols model */
1134 EST_FMatrix X,Y;
1135 EST_IVector included;
1136 EST_FMatrix coeffs;
1137 EST_StrList feat_names;
1138 EST_FMatrix coeffsl;
1139 EST_FMatrix pred;
1140 float cor=0.0,rmse;
1141
1142 s << "((";
1143 // Load the sample members into matrices for ols
1144 part_to_ols_data(X,Y,included,feat_names,imp.members,*(imp.data));
1145 if (!robust_ols(X,Y,included,coeffsl))
1146 {
1147 printf("no robust ols\n");
1148 // shouldn't happen
1149 }
1150 else
1151 {
1152 ols_apply(X,coeffsl,pred);
1153 ols_test(Y,pred,cor,rmse);
1154 for (i=0; i<coeffsl.num_rows(); i++)
1155 {
1156 s << "(";
1157 s << feat_names.nth(i);
1158 s << " ";
1159 s << coeffsl[i];
1160 s << ") ";
1161 }
1162 }
1163
1164 // Mean of cross product of distances (cluster score)
1165 s << ") " << cor << ")";
1166 }
1167 else if (imp.t == wnim_class)
1168 {
1169 EST_Litem *i;
1170 EST_String name;
1171 double prob;
1172
1173 s << "(";
1174 for (i=imp.p.item_start(); !imp.p.item_end(i); i=imp.p.item_next(i))
1175 {
1176 imp.p.item_prob(i,name,prob);
1177 s << "(" << name << " " << prob << ") ";
1178 }
1179 s << imp.p.most_probable(&prob) << ")";
1180 }
1181 else
1182 s << "([WImpurity unset])";
1183
1184 return s;
1185 }
1186
1187
1188
1189
1190