1 /*************************************************************************/ 2 /* */ 3 /* Centre for Speech Technology Research */ 4 /* University of Edinburgh, UK */ 5 /* Copyright (c) 1996,1997 */ 6 /* All Rights Reserved. */ 7 /* */ 8 /* Permission is hereby granted, free of charge, to use and distribute */ 9 /* this software and its documentation without restriction, including */ 10 /* without limitation the rights to use, copy, modify, merge, publish, */ 11 /* distribute, sublicense, and/or sell copies of this work, and to */ 12 /* permit persons to whom this work is furnished to do so, subject to */ 13 /* the following conditions: */ 14 /* 1. The code must retain the above copyright notice, this list of */ 15 /* conditions and the following disclaimer. */ 16 /* 2. Any modifications must be clearly marked as such. */ 17 /* 3. Original authors' names are not deleted. */ 18 /* 4. The authors' names are not used to endorse or promote products */ 19 /* derived from this software without specific prior written */ 20 /* permission. */ 21 /* */ 22 /* THE UNIVERSITY OF EDINBURGH AND THE CONTRIBUTORS TO THIS WORK */ 23 /* DISCLAIM ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING */ 24 /* ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO EVENT */ 25 /* SHALL THE UNIVERSITY OF EDINBURGH NOR THE CONTRIBUTORS BE LIABLE */ 26 /* FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES */ 27 /* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN */ 28 /* AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, */ 29 /* ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF */ 30 /* THIS SOFTWARE. */ 31 /* */ 32 /*************************************************************************/ 33 /* Author : Alan W Black */ 34 /* Date : May 1996 */ 35 /*-----------------------------------------------------------------------*/ 36 /* */ 37 /* Public declarations for Wagon (CART builder) */ 38 /* */ 39 /*=======================================================================*/ 40 #ifndef __WAGON_H__ 41 #define __WAGON_H__ 42 43 #include "EST_String.h" 44 #include "EST_Val.h" 45 #include "EST_TVector.h" 46 #include "EST_TList.h" 47 #include "EST_simplestats.h" /* For EST_SuffStats class */ 48 #include "EST_Track.h" 49 #include "siod.h" 50 #define wagon_error(WMESS) (cerr << WMESS << endl,exit(-1)) 51 52 // I get floating point exceptions of Alphas when I do any comparisons 53 // with HUGE_VAL or FLT_MAX so I'll make my own 54 #define WGN_HUGE_VAL 1.0e20 55 56 class WVector : public EST_FVector 57 { 58 public: WVector(int n)59 WVector(int n) : EST_FVector(n) {} get_int_val(int n)60 int get_int_val(int n) const { return (int)a_no_check(n); } get_flt_val(int n)61 float get_flt_val(int n) const { return a_no_check(n); } set_int_val(int n,int i)62 void set_int_val(int n,int i) { a_check(n) = (int)i; } set_flt_val(int n,float f)63 void set_flt_val(int n,float f) { a_check(n) = f; } 64 }; 65 66 typedef EST_TList<WVector *> WVectorList; 67 typedef EST_TVector<WVector *> WVectorVector; 68 69 /* Different types of feature */ 70 enum wn_dtype {/* for predictees and predictors */ 71 wndt_binary, wndt_float, wndt_class, 72 /* for predictees only */ 73 wndt_cluster, wndt_vector, wndt_matrix, wndt_trajectory, 74 wndt_ols, 75 /* for ignored features */ 76 wndt_ignore}; 77 78 class WDataSet : public WVectorList { 79 private: 80 int dlength; 81 EST_IVector p_type; 82 EST_IVector p_ignore; 83 EST_StrVector p_name; 84 public: 85 void load_description(const EST_String& descfname,LISP ignores); 86 void ignore_non_numbers(); 87 ftype(const int & i)88 int ftype(const int &i) const {return p_type(i);} ignore(int i)89 int ignore(int i) const {return p_ignore(i); } set_ignore(int i,int value)90 void set_ignore(int i,int value) { p_ignore[i] = value; } feat_name(const int & i)91 const EST_String &feat_name(const int &i) const {return p_name(i);} samples(void)92 int samples(void) const {return length();} width(void)93 int width(void) const {return dlength;} 94 }; 95 enum wn_oper {wnop_equal, wnop_binary, wnop_greaterthan, 96 wnop_lessthan, wnop_is, wnop_in, wnop_matches}; 97 98 class WQuestion { 99 private: 100 int feature_pos; 101 wn_oper op; 102 int yes; 103 int no; 104 EST_Val operand1; 105 EST_IList operandl; 106 float score; 107 public: WQuestion()108 WQuestion() {;} WQuestion(const WQuestion & s)109 WQuestion(const WQuestion &s) 110 { feature_pos=s.feature_pos; 111 op=s.op; yes=s.yes; no=s.no; operand1=s.operand1; 112 operandl = s.operandl; score=s.score;} ~WQuestion()113 ~WQuestion() {;} WQuestion(int fp,wn_oper o,EST_Val a)114 WQuestion(int fp, wn_oper o,EST_Val a) 115 { feature_pos=fp; op=o; operand1=a; } set_fp(const int & fp)116 void set_fp(const int &fp) {feature_pos=fp;} set_oper(const wn_oper & o)117 void set_oper(const wn_oper &o) {op=o;} set_operand1(const EST_Val & a)118 void set_operand1(const EST_Val &a) {operand1 = a;} set_yes(const int & y)119 void set_yes(const int &y) {yes=y;} set_no(const int & n)120 void set_no(const int &n) {no=n;} get_yes(void)121 int get_yes(void) const {return yes;} get_no(void)122 int get_no(void) const {return no;} get_fp(void)123 const int get_fp(void) const {return feature_pos;} get_op(void)124 const wn_oper get_op(void) const {return op;} get_operand1(void)125 const EST_Val get_operand1(void) const {return operand1;} get_operandl(void)126 const EST_IList &get_operandl(void) const {return operandl;} get_score(void)127 const float get_score(void) const {return score;} set_score(const float & f)128 void set_score(const float &f) {score=f;} 129 const int ask(const WVector &w) const; 130 friend ostream& operator<<(ostream& s, const WQuestion &q); 131 }; 132 133 enum wnim_type {wnim_unset, wnim_float, wnim_class, 134 wnim_cluster, wnim_vector, wnim_matrix, wnim_ols, 135 wnim_trajectory}; 136 137 // Impurity measure for cumulating impurities from set of data 138 class WImpurity { 139 private: 140 wnim_type t; 141 EST_SuffStats a; 142 EST_DiscreteProbDistribution p; 143 144 float cluster_impurity(); 145 float cluster_member_mean(int i); 146 float vector_impurity(); 147 float trajectory_impurity(); 148 float ols_impurity(); 149 public: 150 EST_IList members; // Maybe there should be a cluster class 151 EST_FList member_counts; // AUP: Implement counts for vectors 152 EST_SuffStats **trajectory; 153 const WVectorVector *data; // Needed for ols 154 float score; 155 int l,width; 156 WImpurity()157 WImpurity() { t=wnim_unset; a.reset(); trajectory=0; l=0; width=0; data=0;} 158 ~WImpurity(); 159 WImpurity(const WVectorVector &ds); copy(const WImpurity & s)160 void copy(const WImpurity &s) 161 { 162 int i,j; 163 t=s.t; a=s.a; p=s.p; members=s.members; member_counts = s.member_counts; l=s.l; width=s.width; 164 score = s.score; 165 data = s.data; 166 if (s.trajectory) 167 { 168 trajectory = new EST_SuffStats *[l]; 169 for (i=0; i<l; i++) 170 { 171 trajectory[i] = new EST_SuffStats[width]; 172 for (j=0; j<width; j++) 173 trajectory[i][j] = s.trajectory[i][j]; 174 } 175 } 176 } 177 WImpurity &operator = (const WImpurity &a) { copy(a); return *this; } 178 179 float measure(void); 180 double samples(void); type(void)181 wnim_type type(void) const { return t;} 182 void cumulate(const float pv,double count=1.0); 183 EST_Val value(void); pd()184 EST_DiscreteProbDistribution &pd() { return p; } 185 float cluster_distance(int i); // distance i from centre in sds 186 int in_cluster(int i); // distance i from centre < most remote member 187 float cluster_ranking(int i); // position in closeness to centre 188 friend ostream& operator<<(ostream &s, WImpurity &imp); 189 }; 190 191 class WDlist { 192 private: 193 float p_score; 194 WQuestion p_question; 195 EST_String p_token; 196 int p_freq; 197 int p_samples; 198 WDlist *next; 199 public: WDlist()200 WDlist() { next=0; } ~WDlist()201 ~WDlist() { if (next != 0) delete next; } set_score(float s)202 void set_score(float s) { p_score = s; } set_question(const WQuestion & q)203 void set_question(const WQuestion &q) { p_question = q; } set_best(const EST_String & t,int freq,int samples)204 void set_best(const EST_String &t,int freq, int samples) 205 { p_token = t; p_freq = freq; p_samples = samples;} score()206 float score() const {return p_score;} token(void)207 const EST_String &token(void) const {return p_token;} question()208 const WQuestion &question() const {return p_question;} 209 EST_Val predict(const WVector &w); 210 friend WDlist *add_to_dlist(WDlist *l,WDlist *a); 211 friend ostream &operator<<(ostream &s, WDlist &d); 212 }; 213 214 class WNode { 215 private: 216 WVectorVector data; 217 WQuestion question; 218 WImpurity impurity; 219 WNode *left; 220 WNode *right; 221 void print_out(ostream &s, int margin); leaf(void)222 int leaf(void) const { return ((left == 0) || (right == 0)); } 223 int pure(void); 224 public: WNode()225 WNode() { left = right = 0; } ~WNode()226 ~WNode() { if (left != 0) {delete left; left=0;} 227 if (right != 0) {delete right; right=0;} } get_data(void)228 WVectorVector &get_data(void) { return data; } set_subnodes(WNode * l,WNode * r)229 void set_subnodes(WNode *l,WNode *r) { left=l; right=r; } set_impurity(const WImpurity & imp)230 void set_impurity(const WImpurity &imp) {impurity=imp;} set_question(const WQuestion & q)231 void set_question(const WQuestion &q) {question=q;} 232 void prune(void); 233 void held_out_prune(void); get_impurity(void)234 WImpurity &get_impurity(void) {return impurity;} get_question(void)235 WQuestion &get_question(void) {return question;} 236 EST_Val predict(const WVector &w); 237 WNode *predict_node(const WVector &d); samples(void)238 int samples(void) const { return data.n(); } 239 friend ostream& operator<<(ostream &s, WNode &n); 240 }; 241 242 extern Discretes wgn_discretes; 243 extern WDataSet wgn_dataset; 244 extern WDataSet wgn_test_dataset; 245 extern EST_FMatrix wgn_DistMatrix; 246 extern EST_Track wgn_VertexTrack; 247 extern EST_Track wgn_UnitTrack; 248 extern EST_Track wgn_VertexFeats; 249 250 void wgn_load_datadescription(EST_String fname,LISP ignores); 251 void wgn_load_dataset(WDataSet &ds,EST_String fname); 252 WNode *wgn_build_tree(float &score); 253 WNode *wgn_build_dlist(float &score,ostream *output); 254 WNode *wagon_stepwise(float limit); 255 float wgn_score_question(WQuestion &q, WVectorVector &ds); 256 void wgn_find_split(WQuestion &q,WVectorVector &ds, 257 WVectorVector &y,WVectorVector &n); 258 float summary_results(WNode &tree,ostream *output); 259 260 extern int wgn_min_cluster_size; 261 extern int wgn_held_out; 262 extern int wgn_prune; 263 extern int wgn_quiet; 264 extern int wgn_verbose; 265 extern int wgn_predictee; 266 extern int wgn_count_field; 267 extern EST_String wgn_count_field_name; 268 extern EST_String wgn_predictee_name; 269 extern float wgn_float_range_split; 270 extern float wgn_balance; 271 extern EST_String wgn_opt_param; 272 extern EST_String wgn_vertex_output; 273 274 #define wgn_ques_feature(X) (get_c_string(car(X))) 275 #define wgn_ques_oper_str(X) (get_c_string(car(cdr(X)))) 276 #define wgn_ques_operand(X) (car(cdr(cdr(X)))) 277 278 int wagon_ask_question(LISP question, LISP value); 279 280 int stepwise_ols(const EST_FMatrix &X, 281 const EST_FMatrix &Y, 282 const EST_StrList &feat_names, 283 float limit, 284 EST_FMatrix &coeffs, 285 const EST_FMatrix &Xtest, 286 const EST_FMatrix &Ytest, 287 EST_IVector &included, 288 float &best_score); 289 int robust_ols(const EST_FMatrix &X, 290 const EST_FMatrix &Y, 291 EST_IVector &included, 292 EST_FMatrix &coeffs); 293 int ols_apply(const EST_FMatrix &samples, 294 const EST_FMatrix &coeffs, 295 EST_FMatrix &res); 296 int ols_test(const EST_FMatrix &real, 297 const EST_FMatrix &predicted, 298 float &correlation, 299 float &rmse); 300 301 #endif /* __WAGON_H__ */ 302