1 // 2 // svm.hpp 3 // support vector machine 4 // 5 // Created by Joshua Lynch on 6/19/2013. 6 // Copyright (c) 2013 Schloss Lab. All rights reserved. 7 // 8 9 #ifndef svm_hpp_ 10 #define svm_hpp_ 11 12 13 #include <algorithm> 14 #include <cmath> 15 #include <deque> 16 #include <exception> 17 #include <list> 18 #include <map> 19 #include <set> 20 #include <stack> 21 #include <string> 22 #include <sstream> 23 #include "mothurout.h" 24 #include "utils.hpp" 25 26 // For the purpose of training a support vector machine 27 // we need to calculate a dot product between two feature 28 // vectors. In general these feature vectors are not 29 // restricted to lists of doubles, but in this implementation 30 // feature vectors (or 'observations' as they will be called from here on) 31 // will be vectors of doubles. 32 typedef vector<double> Observation; 33 34 /* 35 class Observation { 36 public: 37 Observation() {} 38 ~Observation() {} 39 40 private: 41 vector<double> obs; 42 }; 43 */ 44 45 // A dataset is a collection of labeled observations. 46 // The ObservationVector typedef is a vector 47 // of pointers to ObservationVectors. Pointers are used here since 48 // datasets will be rearranged many times during cross validation. 49 // Using pointers to Observations makes copying the elements of 50 // an ObservationVector cheap. 51 typedef vector<Observation*> ObservationVector; 52 53 // Training a support vector machine requires labeled data. The 54 // Label typedef defines what will constitute a class 'label' in 55 // this implementation. 56 typedef string Label; 57 typedef vector<Label> LabelVector; 58 typedef set<Label> LabelSet; 59 60 // Pairs of class labels are important because a support vector machine 61 // can only learn two classes of data. The LabelPair typedef is a vector 62 // even though a pair might seem more natural, but it is useful to 63 // iterate over the pair. 
64 typedef vector<Label> LabelPair; 65 LabelPair buildLabelPair(const Label& one, const Label& two); 66 67 // Learning to classify a dataset with more than two classes requires 68 // training a separate support vector machine for each pair of classes. 69 // The LabelPairSet typedef defines a container for the collection of 70 // all unique label pairs for a set of data. 71 typedef set<LabelPair> LabelPairSet; 72 73 // A dataset is a set of observations with associated labels. The 74 // LabeledObservation typedef is a label-observation pair intended to 75 // hold one observation and its corresponding label. Using a pointer 76 // to Observation makes these objects cheap to copy. 77 //typedef pair<Label, Observation*> LabeledObservation; 78 79 // This is a refactoring of the original LabeledObservation typedef. 80 // The original typedef has been promoted to a class in order to add 81 // at least one additional member variable, int datasetIndex, which 82 // will be used to implement kernel function optimizations. 83 class LabeledObservation { 84 public: LabeledObservation(int _datasetIndex,Label _label,Observation * _o)85 LabeledObservation(int _datasetIndex, Label _label, Observation* _o) : datasetIndex(_datasetIndex), first(_label), second(_o) {} ~LabeledObservation()86 ~LabeledObservation() {} 87 removeFeatureAtIndex(int n)88 void removeFeatureAtIndex(int n) { 89 int m = 0; 90 Observation::iterator i = second->begin(); 91 while ( m < n ) { 92 i++; 93 m++; 94 } 95 second->erase(i); 96 } 97 getDatasetIndex() const98 int getDatasetIndex() const { return datasetIndex; } getLabel() const99 Label getLabel() const { return first; } getObservation() const100 Observation* getObservation() const { return second; } 101 102 //private: 103 int datasetIndex; 104 Label first; 105 Observation* second; 106 }; 107 108 109 110 // A LabeledObservationVector is a container for an entire dataset (or a 111 // subset of an entire dataset). 
typedef vector<LabeledObservation> LabeledObservationVector;
// Populate the LabelSet with the distinct labels found in the observations
// (defined in svm.cpp).
void buildLabelSet(LabelSet&, const LabeledObservationVector&);


// Minimum/maximum value of the feature at featureIndex across all
// observations; used when rescaling features (defined in svm.cpp).
double getMinimumFeatureValueForObservation(Observation::size_type featureIndex, LabeledObservationVector& observations);
double getMaximumFeatureValueForObservation(Observation::size_type featureIndex, LabeledObservationVector& observations);


// Feature scaling: rescale each feature to [0,1] or to zero mean and unit
// variance.  Both modify the observations in place (defined in svm.cpp).
void transformZeroOne(LabeledObservationVector&);
void transformZeroMeanUnitVariance(LabeledObservationVector&);


// A Feature is one column of the dataset: a feature index plus a
// human-readable label for that column.
class Feature {
public:
    Feature(int i, const string& l) : index(i), label(l) {}
    Feature(const Feature& f) : index(f.index), label(f.label) {}
    ~Feature() {}

    int getFeatureIndex() const { return index; }
    // feature indices are re-assigned as features are removed during SVM-RFE
    void setFeatureIndex(int i) { index = i; }
    string getFeatureLabel() const { return label; }

private:
    int index;
    string label;
};

typedef list<Feature> FeatureList;
typedef vector<Feature> FeatureVector;

// might make sense for this to be a member function of SvmDataset
FeatureVector applyStdThreshold(double, LabeledObservationVector&, FeatureVector&);


// A RankedFeature is just a Feature and its associated 'rank', where
// rank is the SVM-RFE iteration during which the feature was eliminated.
// If the SVM-RFE method eliminates multiple features in an iteration
// then some features will have the same rank.
class RankedFeature {
public:
    RankedFeature(const Feature& f, int r) : feature(f), rank(r) {}
    ~RankedFeature() {}

    // The eliminated feature.
    Feature getFeature() const { return feature; }
    // The SVM-RFE iteration in which the feature was eliminated; features
    // removed in the same iteration share a rank.
    int getRank() const { return rank; }

private:
    Feature feature;
    int rank;
};

typedef list<RankedFeature> RankedFeatureList;


// The SvmDataset class encapsulates labeled observations and feature information.
// All data required to train SVMs is found in SvmDataset.
class SvmDataset {
public:
    SvmDataset(const LabeledObservationVector& v, const FeatureVector& f) : labeledObservationVector(v), featureVector(f) {}
    ~SvmDataset() {}

    LabeledObservationVector& getLabeledObservationVector() { return labeledObservationVector; }
    FeatureVector& getFeatureVector() { return featureVector; }

    // NOTE(review): this body is empty, so calling it has no effect.
    // TODO confirm whether feature removal was meant to be implemented here
    // or is handled elsewhere (e.g. via LabeledObservation::removeFeatureAtIndex).
    void removeFeature(const Feature feature) {

    }

private:
    LabeledObservationVector labeledObservationVector;
    FeatureVector featureVector;
};


//
// Verbosity levels:
//  0 - print no optional information (quiet)
//  1 - print minimum optional information (info)
//  2 - print a little more optional information (debug)
//  3 - print the maximum amount of optional information (trace)
//
class OutputFilter {
public:
    OutputFilter(int v) : verbosity(v) {}
    OutputFilter(const OutputFilter& of) : verbosity(of.verbosity) {}
    ~OutputFilter() {}

    bool info() const  { return verbosity >= INFO; }
    bool debug() const { return verbosity >= mDEBUG; }
    bool trace() const { return verbosity >= TRACE; }

    // named verbosity thresholds (defined in svm.cpp);
    // 'mDEBUG' presumably avoids colliding with a DEBUG macro -- TODO confirm
    static const int QUIET;
    static const int INFO;
    static const int mDEBUG;
    static const int TRACE;

private:
    const int verbosity;
};


// Dividing a dataset into training and testing sets while maintaining equal
// representation of all classes is done using a LabelToLabeledObservationVector.
// This container is used to divide datasets into groups of LabeledObservations
// having the same label.  For example, given a LabeledObservationVector like
//     ["blue",  [1.0, 2.0, 3.0]]
//     ["green", [3.0, 4.0, 5.0]]
//     ["blue",  [2,0, 3.0. 4.0]]
//     ["green", [4.0, 5.0, 6.0]]
// the corresponding LabelToLabeledObservationVector looks like
//     "blue",  [["blue",  [1.0, 2.0, 3.0]], ["blue",  [2,0, 3.0. 4.0]]]
//     "green", [["green", [3.0, 4.0, 5.0]], ["green", [4.0, 5.0, 6.0]]]
typedef map<Label, LabeledObservationVector> LabelToLabeledObservationVector;
void buildLabelToLabeledObservationVector(LabelToLabeledObservationVector&, const LabeledObservationVector&);

// A support vector machine uses +1 and -1 in calculations to represent
// the two classes of data it is trained to distinguish.  The NumericClassToLabel
// container is used to record the labels associated with these integers.
229 // For a dataset with labels "blue" and "green" a NumericClassToLabel map looks like 230 // 1, "blue" 231 // -1, "green" 232 typedef map<int, Label> NumericClassToLabel; 233 void buildNumericClassToLabelMap(LabelPair); 234 235 typedef double Parameter; 236 typedef string ParameterName; 237 typedef vector<double> ParameterRange; 238 typedef map<ParameterName, ParameterRange> ParameterRangeMap; 239 240 typedef map<string, ParameterRangeMap> KernelParameterRangeMap; 241 void getDefaultKernelParameterRangeMap(KernelParameterRangeMap& kernelParameterRangeMap); 242 243 typedef map<ParameterName, Parameter> ParameterMap; 244 typedef vector<ParameterMap> ParameterMapVector; 245 typedef stack<Parameter> ParameterStack; 246 247 class ParameterSetBuilder { 248 public: 249 // If the argument ParameterRangeMap looks like this: 250 // { "a" : [1.0, 2.0], "b" : [-1.0, 1.0], "c" : [0.5, 0.6] } 251 // then the list of parameter sets looks like this: 252 // [ {"a":1.0, "b":-1.0, "c":0.5}, 253 // {"a":1.0, "b":-1.0, "c":0.6}, 254 // {"a":1.0, "b": 1.0, "c":0.5}, 255 // {"a":1.0, "b": 1.0, "c":0.6}, 256 // {"a":2.0, "b":-1.0, "c":0.5}, 257 // {"a":2.0, "b":-1.0, "c":0.6}, 258 // {"a":2.0, "b": 1.0, "c":0.5}, 259 // {"a":2.0, "b": 1.0, "c":0.6}, 260 // ] ParameterSetBuilder(const ParameterRangeMap & parameterRangeMap)261 ParameterSetBuilder(const ParameterRangeMap& parameterRangeMap) { 262 // a small step toward quieting down this code 263 bool verbose = false; 264 265 stack<pair<ParameterName, ParameterStack> > stackOfParameterRanges; 266 stack<pair<ParameterName, ParameterStack> > stackOfEmptyParameterRanges; 267 ParameterMap nextParameterSet; 268 int parameterSetCount = 1; 269 for ( ParameterRangeMap::const_iterator i = parameterRangeMap.begin(); i != parameterRangeMap.end(); i++ ) { 270 parameterSetCount *= i->second.size(); 271 ParameterName parameterName = i->first; 272 ParameterStack emptyParameterStack; 273 stackOfEmptyParameterRanges.push(make_pair(parameterName, 
emptyParameterStack)); 274 } 275 // get started 276 for ( int n = 0; n < parameterSetCount; n++ ) { 277 278 if (verbose) m->mothurOut("n = " + toString(n) ); m->mothurOutEndLine(); 279 280 // pull empty stacks off until there are no empty stacks 281 while ( stackOfParameterRanges.size() > 0 and stackOfParameterRanges.top().second.size() == 0 ) { 282 283 if (verbose) m->mothurOut(" empty parameter range: " + stackOfParameterRanges.top().first); m->mothurOutEndLine(); 284 285 stackOfEmptyParameterRanges.push(stackOfParameterRanges.top()); 286 stackOfParameterRanges.pop(); 287 } 288 289 // move to the next value for the parameter at the top of the stackOfParameterRanges 290 if ( stackOfParameterRanges.size() > 0 ) { 291 if (verbose) { 292 m->mothurOut( " moving to next value for parameter " + toString(stackOfParameterRanges.top().first) ); m->mothurOutEndLine(); 293 m->mothurOut( " next value is " + toString(stackOfParameterRanges.top().second.top()) ); m->mothurOutEndLine(); 294 } 295 ParameterName parameterName = stackOfParameterRanges.top().first; 296 nextParameterSet[parameterName] = stackOfParameterRanges.top().second.top(); 297 stackOfParameterRanges.top().second.pop(); 298 } 299 if (verbose) m->mothurOut( "stack of empty parameter ranges has size " + toString(stackOfEmptyParameterRanges.size() ) ); m->mothurOutEndLine(); 300 // reset each parameter range that has been exhausted 301 while ( stackOfEmptyParameterRanges.size() > 0 ) { 302 ParameterName parameterName = stackOfEmptyParameterRanges.top().first; 303 if (verbose) m->mothurOut( " reseting range for parameter " + toString(stackOfEmptyParameterRanges.top().first) ); m->mothurOutEndLine(); 304 stackOfParameterRanges.push(stackOfEmptyParameterRanges.top()); 305 stackOfEmptyParameterRanges.pop(); 306 const ParameterRange& parameterRange = parameterRangeMap.find(parameterName)->second; 307 // it is nice to have the parameters used in order smallest to largest 308 // so that we choose the smallest in ties 309 
// but we will not enforce this so users can specify parameters in the order they like 310 // this loop will use parameters in the order they are found in the parameter range 311 for (ParameterRange::const_reverse_iterator i = parameterRange.rbegin(); i != parameterRange.rend(); i++ ) { 312 stackOfParameterRanges.top().second.push(*i); 313 } 314 nextParameterSet[parameterName] = stackOfParameterRanges.top().second.top(); 315 stackOfParameterRanges.top().second.pop(); 316 } 317 parameterSetVector.push_back(nextParameterSet); 318 // print out the next parameter set 319 if (verbose) { 320 for (ParameterMap::iterator p = nextParameterSet.begin(); p != nextParameterSet.end(); p++) { 321 m->mothurOut(toString(p->first) + " : " + toString(p->second) ); m->mothurOutEndLine(); 322 } 323 } 324 } 325 } ~ParameterSetBuilder()326 ~ParameterSetBuilder() {} 327 getParameterSetList()328 const ParameterMapVector& getParameterSetList() { return parameterSetVector; } 329 330 private: 331 ParameterMapVector parameterSetVector; 332 MothurOut* m; 333 }; 334 335 336 class RowCache { 337 public: RowCache(int d)338 RowCache(int d) { //: cache(d, NULL) 339 for (int i = 0; i < d; i++) { cache.push_back(NULL); } 340 } 341 ~RowCache()342 virtual ~RowCache() { 343 for (int i = 0; i < cache.size(); i++) { 344 if ( !rowNotCached(i) ) { 345 delete cache[i]; 346 } 347 } 348 } 349 getCachedValue(int i,int j)350 double getCachedValue(int i, int j) { 351 if ( rowNotCached(i) ) { 352 createRow(i); 353 } 354 return cache.at(i)->at(j); 355 } 356 createRow(int i)357 void createRow(int i) { 358 cache[i] = new vector<double>(cache.size(), numeric_limits<double>::signaling_NaN()); 359 for ( int v = 0; v < cache.size(); v++ ) { 360 cache.at(i)->at(v) = calculateValueForCache(i, v); 361 } 362 } 363 rowNotCached(int i)364 bool rowNotCached(int i) { 365 return cache[i] == NULL; 366 } 367 368 virtual double calculateValueForCache(int, int) = 0; 369 370 private: 371 vector<vector<double>* > cache; 372 }; 373 374 
375 class InnerProductRowCache : public RowCache { 376 public: InnerProductRowCache(const LabeledObservationVector & _obs)377 InnerProductRowCache(const LabeledObservationVector& _obs) : obs(_obs), RowCache(_obs.size()) {} ~InnerProductRowCache()378 virtual ~InnerProductRowCache() {} 379 getInnerProduct(const LabeledObservation & obs_i,const LabeledObservation & obs_j)380 double getInnerProduct(const LabeledObservation& obs_i, const LabeledObservation& obs_j) { 381 return getCachedValue( 382 obs_i.datasetIndex, 383 obs_j.datasetIndex 384 ); 385 } 386 calculateValueForCache(int i,int j)387 double calculateValueForCache(int i, int j) { 388 return inner_product(obs[i].second->begin(), obs[i].second->end(), obs[j].second->begin(), 0.0); 389 } 390 391 private: 392 const LabeledObservationVector& obs; 393 }; 394 395 396 // The KernelFunction class caches a partial kernel value that does not depend on kernel parameters. 397 class KernelFunction { 398 public: 399 //KernelFunction(const LabeledObservationVector& _obs, InnerProductCache& _ipc) : obs(_obs), innerProductRowCache(_ipc) {} 400 KernelFunction(const LabeledObservationVector & _obs)401 KernelFunction(const LabeledObservationVector& _obs) : 402 obs(_obs), 403 cache(_obs.size(), NULL) {} 404 ~KernelFunction()405 virtual ~KernelFunction() { 406 for (int i = 0; i < cache.size(); i++) { 407 if ( !rowNotCached(i) ) { 408 delete cache[i]; 409 } 410 } 411 } 412 413 virtual double similarity(const LabeledObservation&, const LabeledObservation&) = 0; 414 virtual void setParameters(const ParameterMap&) = 0; 415 virtual void getDefaultParameterRanges(ParameterRangeMap&) = 0; 416 417 virtual double calculateParameterFreeSimilarity(const LabeledObservation&, const LabeledObservation&) = 0; 418 getCachedParameterFreeSimilarity(const LabeledObservation & obs_i,const LabeledObservation & obs_j)419 double getCachedParameterFreeSimilarity(const LabeledObservation& obs_i, const LabeledObservation& obs_j) { 420 const int i = 
obs_i.datasetIndex; 421 const int j = obs_j.datasetIndex; 422 423 if ( rowNotCached(i) ) { 424 cache[i] = new vector<double>(obs.size(), numeric_limits<double>::signaling_NaN()); 425 for ( int v = 0; v < obs.size(); v++ ) { 426 cache.at(i)->at(v) = calculateParameterFreeSimilarity(obs[i], obs[v]); 427 } 428 } 429 return cache.at(i)->at(j); 430 } 431 rowNotCached(int i)432 bool rowNotCached(int i) { 433 return cache[i] == NULL; 434 } 435 436 private: 437 const LabeledObservationVector& obs; 438 //vector<vector<double> > cache; 439 vector<vector<double>* > cache; 440 //InnerProductRowCache& innerProductCache; 441 }; 442 443 444 class LinearKernelFunction : public KernelFunction { 445 public: 446 // parameters must be set before using a KernelFunction is used LinearKernelFunction(const LabeledObservationVector & _obs)447 LinearKernelFunction(const LabeledObservationVector& _obs) : KernelFunction(_obs), constant(0.0) {} ~LinearKernelFunction()448 ~LinearKernelFunction() {} 449 similarity(const LabeledObservation & i,const LabeledObservation & j)450 double similarity(const LabeledObservation& i, const LabeledObservation& j) { 451 return getCachedParameterFreeSimilarity(i, j) + constant; 452 } 453 calculateParameterFreeSimilarity(const LabeledObservation & i,const LabeledObservation & j)454 double calculateParameterFreeSimilarity(const LabeledObservation& i, const LabeledObservation& j) { 455 return inner_product(i.second->begin(), i.second->end(), j.second->begin(), 0.0); 456 } 457 getConstant()458 double getConstant() { return constant; } setConstant(double c)459 void setConstant(double c) { constant = c; } 460 setParameters(const ParameterMap & p)461 void setParameters(const ParameterMap& p) { 462 setConstant(p.find(MapKey_Constant)->second); 463 }; 464 getDefaultParameterRanges(ParameterRangeMap & p)465 void getDefaultParameterRanges(ParameterRangeMap& p) { 466 p[MapKey_Constant] = defaultConstantRange; 467 } 468 469 static const string MapKey; 470 static const 
string MapKey_Constant; 471 static const ParameterRange defaultConstantRange; 472 473 private: 474 double constant; 475 }; 476 477 478 class RbfKernelFunction : public KernelFunction { 479 public: 480 // parameters must be set before a KernelFunction is used RbfKernelFunction(const LabeledObservationVector & _obs)481 RbfKernelFunction(const LabeledObservationVector& _obs) : KernelFunction(_obs), gamma(0.0) {} ~RbfKernelFunction()482 ~RbfKernelFunction() {} 483 similarity(const LabeledObservation & i,const LabeledObservation & j)484 double similarity(const LabeledObservation& i, const LabeledObservation& j) { 485 //double sumOfSquaredDifs = 0.0; 486 //for (int n = 0; n < i.second->size(); n++) { 487 // sumOfSquaredDifs += pow((i.second->at(n) - j.second->at(n)), 2.0); 488 //} 489 return gamma * getCachedParameterFreeSimilarity(i, j); 490 } 491 calculateParameterFreeSimilarity(const LabeledObservation & i,const LabeledObservation & j)492 double calculateParameterFreeSimilarity(const LabeledObservation& i, const LabeledObservation& j) { 493 //double sumOfSquaredDifs = 0.0; 494 //for (int n = 0; n < i.second->size(); n++) { 495 // sumOfSquaredDifs += pow((i.second->at(n) - j.second->at(n)), 2.0); 496 //} 497 double sumOfSquaredDifs = 498 inner_product(i.second->begin(), i.second->end(), i.second->begin(), 0.0) 499 - 2.0 * inner_product(i.second->begin(), i.second->end(), j.second->begin(), 0.0) 500 + inner_product(j.second->begin(), j.second->end(), j.second->begin(), 0.0); 501 return exp(sqrt(sumOfSquaredDifs)); 502 } 503 getGamma()504 double getGamma() { return gamma; } setGamma(double g)505 void setGamma(double g) { gamma = g; } 506 setParameters(const ParameterMap & p)507 void setParameters(const ParameterMap& p) { 508 setGamma(p.find(MapKey_Gamma)->second); 509 } 510 getDefaultParameterRanges(ParameterRangeMap & p)511 void getDefaultParameterRanges(ParameterRangeMap& p) { 512 p[MapKey_Gamma] = defaultGammaRange; 513 } 514 515 static const string MapKey; 516 static 
const string MapKey_Gamma; 517 518 static const ParameterRange defaultGammaRange; 519 520 private: 521 double gamma; 522 }; 523 524 525 class PolynomialKernelFunction : public KernelFunction { 526 public: 527 // parameters must be set before using a KernelFunction is used PolynomialKernelFunction(const LabeledObservationVector & _obs)528 PolynomialKernelFunction(const LabeledObservationVector& _obs) : KernelFunction(_obs), c(0.0), gamma(0.0), d(0) {} ~PolynomialKernelFunction()529 ~PolynomialKernelFunction() {} 530 similarity(const LabeledObservation & i,const LabeledObservation & j)531 double similarity(const LabeledObservation& i, const LabeledObservation& j) { 532 return pow((gamma * getCachedParameterFreeSimilarity(i, j) + c), d); 533 //return pow(inner_product(i.second->begin(), i.second->end(), j.second->begin(), c), d); 534 } 535 calculateParameterFreeSimilarity(const LabeledObservation & i,const LabeledObservation & j)536 double calculateParameterFreeSimilarity(const LabeledObservation& i, const LabeledObservation& j) { 537 return inner_product(i.second->begin(), i.second->end(), j.second->begin(), 0.0); 538 } 539 setParameters(const ParameterMap & p)540 void setParameters(const ParameterMap& p) { 541 c = p.find(MapKey_Constant)->second; 542 gamma = p.find(MapKey_Coefficient)->second; 543 d = int(p.find(MapKey_Degree)->second); 544 } 545 getDefaultParameterRanges(ParameterRangeMap & p)546 void getDefaultParameterRanges(ParameterRangeMap& p) { 547 p[MapKey_Constant] = defaultConstantRange; 548 p[MapKey_Coefficient] = defaultCoefficientRange; 549 p[MapKey_Degree] = defaultDegreeRange; 550 } 551 552 static const string MapKey; 553 static const string MapKey_Constant; 554 static const string MapKey_Coefficient; 555 static const string MapKey_Degree; 556 557 static const ParameterRange defaultConstantRange; 558 static const ParameterRange defaultCoefficientRange; 559 static const ParameterRange defaultDegreeRange; 560 561 private: 562 double c; 563 double gamma; 
564 int d; 565 }; 566 567 568 class SigmoidKernelFunction : public KernelFunction { 569 public: 570 // parameters must be set before using a KernelFunction is used SigmoidKernelFunction(const LabeledObservationVector & _obs)571 SigmoidKernelFunction(const LabeledObservationVector& _obs) : KernelFunction(_obs), alpha(0.0), c(0.0) {} ~SigmoidKernelFunction()572 ~SigmoidKernelFunction() {} 573 similarity(const LabeledObservation & i,const LabeledObservation & j)574 double similarity(const LabeledObservation& i, const LabeledObservation& j) { 575 return tanh(alpha * getCachedParameterFreeSimilarity(i, j) + c); 576 //return tanh(alpha * inner_product(i.second->begin(), i.second->end(), j.second->begin(), c)); 577 } 578 calculateParameterFreeSimilarity(const LabeledObservation & i,const LabeledObservation & j)579 double calculateParameterFreeSimilarity(const LabeledObservation& i, const LabeledObservation& j) { 580 return inner_product(i.second->begin(), i.second->end(), j.second->begin(), 0.0); 581 } 582 setParameters(const ParameterMap & p)583 void setParameters(const ParameterMap& p) { 584 alpha = p.find(MapKey_Alpha)->second; 585 c = p.find(MapKey_Constant)->second; 586 } 587 getDefaultParameterRanges(ParameterRangeMap & p)588 void getDefaultParameterRanges(ParameterRangeMap& p) { 589 p[MapKey_Alpha] = defaultAlphaRange; 590 p[MapKey_Constant] = defaultConstantRange; 591 } 592 593 static const string MapKey; 594 static const string MapKey_Alpha; 595 static const string MapKey_Constant; 596 597 static const ParameterRange defaultAlphaRange; 598 static const ParameterRange defaultConstantRange; 599 private: 600 double alpha; 601 double c; 602 }; 603 604 605 class KernelFactory { 606 public: getKernelFunctionForKey(string kernelFunctionKey,const LabeledObservationVector & obs)607 static KernelFunction* getKernelFunctionForKey(string kernelFunctionKey, const LabeledObservationVector& obs) { 608 if ( kernelFunctionKey == LinearKernelFunction::MapKey ) { 609 return new 
LinearKernelFunction(obs); 610 } 611 else if ( kernelFunctionKey == RbfKernelFunction::MapKey ) { 612 return new RbfKernelFunction(obs); 613 } 614 else if ( kernelFunctionKey == PolynomialKernelFunction::MapKey ) { 615 return new PolynomialKernelFunction(obs); 616 } 617 else if ( kernelFunctionKey == SigmoidKernelFunction::MapKey ) { 618 return new SigmoidKernelFunction(obs); 619 } 620 else { 621 throw new exception(); 622 } 623 } 624 }; 625 626 627 typedef map<string, KernelFunction*> KernelFunctionMap; 628 629 // An instance of KernelFunctionFactory dynamically allocates kernel function 630 // instances and maintains a table of pointers to them. This allows kernel 631 // function instances to be reused which improves performance since the 632 // kernel values do not have to be recalculated as often. A KernelFunctionFactory 633 // maintains an inner product cache used by the KernelFunctions it builds. 634 class KernelFunctionFactory { 635 public: KernelFunctionFactory(const LabeledObservationVector & _obs)636 KernelFunctionFactory(const LabeledObservationVector& _obs) : obs(_obs) {} ~KernelFunctionFactory()637 ~KernelFunctionFactory() { 638 for ( KernelFunctionMap::iterator i = kernelFunctionTable.begin(); i != kernelFunctionTable.end(); i++ ) { 639 delete i->second; 640 } 641 } 642 getKernelFunctionForKey(string kernelFunctionKey)643 KernelFunction& getKernelFunctionForKey(string kernelFunctionKey) { 644 if ( kernelFunctionTable.count(kernelFunctionKey) == 0 ) { 645 kernelFunctionTable.insert( 646 make_pair( 647 kernelFunctionKey, 648 KernelFactory::getKernelFunctionForKey(kernelFunctionKey, obs) 649 ) 650 ); 651 } 652 return *kernelFunctionTable[kernelFunctionKey]; 653 } 654 655 private: 656 const LabeledObservationVector& obs; 657 KernelFunctionMap kernelFunctionTable; 658 //InnerProductCache innerProductCache; 659 }; 660 661 662 class KernelFunctionCache { 663 public: KernelFunctionCache(KernelFunction & _k,const LabeledObservationVector & _obs)664 
KernelFunctionCache(KernelFunction& _k, const LabeledObservationVector& _obs) : 665 k(_k), obs(_obs), 666 cache(_obs.size(), NULL) {} ~KernelFunctionCache()667 ~KernelFunctionCache() { 668 669 for (int i = 0; i < cache.size(); i++) { 670 if ( !rowNotCached(i) ) { 671 delete cache[i]; 672 } 673 } 674 } 675 similarity(const LabeledObservation & obs_i,const LabeledObservation & obs_j)676 double similarity(const LabeledObservation& obs_i, const LabeledObservation& obs_j) { 677 const int i = obs_i.datasetIndex; 678 const int j = obs_j.datasetIndex; 679 // if the first element of row i is NaN then calculate all elements for row i 680 if ( rowNotCached(i) ) { 681 cache[i] = new vector<double>(obs.size(), numeric_limits<double>::signaling_NaN()); 682 for ( int v = 0; v < obs.size(); v++ ) { 683 cache.at(i)->at(v) = k.similarity( 684 obs[i], 685 obs[v] 686 ); 687 } 688 } 689 return cache.at(i)->at(j); 690 } 691 rowNotCached(int i)692 bool rowNotCached(int i) { 693 return cache[i] == NULL; 694 } 695 696 private: 697 KernelFunction& k; 698 const LabeledObservationVector& obs; 699 //vector<vector<double> > cache; 700 vector<vector<double>* > cache; 701 }; 702 703 704 // The SVM class implements the Support Vector Machine 705 // discriminant function. Instances are constructed with 706 // a vector of class labels (+1.0 or -1.0), a vector of dual 707 // coefficients, a vector of observations, and a bias value. 708 // 709 // The class SmoTrainer is responsible for determining the dual 710 // coefficients and bias value. 711 // 712 class SVM { 713 public: SVM(const vector<double> & yy,const vector<double> & aa,const LabeledObservationVector & oo,double bb,const NumericClassToLabel & mm)714 SVM(const vector<double>& yy, const vector<double>& aa, const LabeledObservationVector& oo, double bb, const NumericClassToLabel& mm) : 715 y(yy), a(aa), x(oo), b(bb), discriminantToLabel(mm) {} ~SVM()716 ~SVM() {} 717 718 // the classify method should accept a list of observations? 
719 int discriminant(const Observation&) const; classify(const Observation & observation) const720 Label classify(const Observation& observation) const { 721 //return discriminantToLabel[discriminant(observation)]; 722 return discriminantToLabel.find(discriminant(observation))->second; 723 } 724 LabelVector classify(const LabeledObservationVector&) const; 725 double score(const LabeledObservationVector&) const; 726 getDiscriminantToLabel() const727 NumericClassToLabel getDiscriminantToLabel() const { return discriminantToLabel; } getLabelPair() const728 LabelPair getLabelPair() const { return buildLabelPair(discriminantToLabel.find(1)->second, discriminantToLabel.find(-1)->second); } 729 730 public: 731 // y holds the numeric class: +1.0 or -1.0 732 const vector<double> y; 733 // a holds the optimal dual coefficients 734 const vector<double> a; 735 // x holds the support vectors 736 const LabeledObservationVector x; 737 const double b; 738 const NumericClassToLabel discriminantToLabel; 739 }; 740 741 742 class SvmPerformanceSummary { 743 public: SvmPerformanceSummary()744 SvmPerformanceSummary() {} 745 // this constructor should be used by clients other than tests SvmPerformanceSummary(const SVM & svm,const LabeledObservationVector & actual)746 SvmPerformanceSummary(const SVM& svm, const LabeledObservationVector& actual) { 747 init(svm, actual, svm.classify(actual)); 748 } 749 // this constructor is intended for unit testing SvmPerformanceSummary(const SVM & svm,const LabeledObservationVector & actual,const LabelVector & predictions)750 SvmPerformanceSummary(const SVM& svm, const LabeledObservationVector& actual, const LabelVector& predictions) { 751 init(svm, actual, predictions); 752 } 753 getPositiveClassLabel() const754 Label getPositiveClassLabel() const { return positiveClassLabel; } getNegativeClassLabel() const755 Label getNegativeClassLabel() const { return negativeClassLabel; } 756 getPrecision() const757 double getPrecision() const { return precision; } 
getRecall() const758 double getRecall() const { return recall; } getF() const759 double getF() const { return f; } getAccuracy() const760 double getAccuracy() const { return accuracy; } 761 762 private: 763 void init(const SVM&, const LabeledObservationVector&, const LabelVector&); 764 765 //const SVM& svm; 766 767 Label positiveClassLabel; 768 Label negativeClassLabel; 769 770 double precision; 771 double recall; 772 double f; 773 double accuracy; 774 }; 775 776 777 class MultiClassSvmClassificationTie : public exception { 778 public: MultiClassSvmClassificationTie(LabelVector & t,int c)779 MultiClassSvmClassificationTie(LabelVector& t, int c) : tiedLabels(t), tiedVoteCount(c) {} ~MultiClassSvmClassificationTie()780 ~MultiClassSvmClassificationTie() throw() {} 781 what() const782 virtual const char* what() const throw() { 783 return "classification tie"; 784 } 785 786 private: 787 const LabelVector tiedLabels; 788 const int tiedVoteCount; 789 }; 790 791 typedef vector<SVM*> SvmVector; 792 typedef map<LabelPair, SvmPerformanceSummary> SvmToSvmPerformanceSummary; 793 794 // Using SVM with more than two classes requires training multiple SVMs. 795 // The MultiClassSVM uses a vector of trained SVMs to do classification 796 // on data having more than two classes. 
class MultiClassSVM {
public:
    // takes ownership of the trained two-class SVMs (deleted in the destructor,
    // defined in svm.cpp)
    MultiClassSVM(const vector<SVM*>, const LabelSet&, const SvmToSvmPerformanceSummary&, OutputFilter);
    ~MultiClassSVM();

    // the classify method should accept a list of observations
    Label classify(const Observation& observation);
    // fraction of observations classified correctly (defined in svm.cpp)
    double score(const LabeledObservationVector&);

    // no need to delete these pointers
    const SvmVector& getSvmList() { return twoClassSvmList; }

    const LabelSet& getLabels() { return labelSet; }

    // per-label-pair performance for one of the underlying two-class SVMs
    const SvmPerformanceSummary& getSvmPerformanceSummary(const SVM& svm) { return svmToSvmPerformanceSummary.at(svm.getLabelPair()); }

    double getAccuracy() { return accuracy; }
    void setAccuracy(const LabeledObservationVector& obs) { accuracy = score(obs); }

private:
    const SvmVector twoClassSvmList;
    const LabelSet labelSet;
    const OutputFilter outputFilter;

    double accuracy;
    // NOTE(review): 'm' is not initialized here -- presumably set in the
    // constructor defined in svm.cpp; verify before use.
    MothurOut* m;

    // this is a map from label pairs to performance summaries
    SvmToSvmPerformanceSummary svmToSvmPerformanceSummary;
};


//class SvmTrainingInterruptedException : public exception {
//public:
//    SvmTrainingInterruptedException(const string& m) : message(m) {}
//    ~SvmTrainingInterruptedException() throw() {}
//    virtual const char* what() const throw() {
//        return message.c_str();
//    }

//private:
//    string message;
//};

// Thrown by the SMO training loop on failure; what() returns the
// human-readable message supplied at construction.
class SmoTrainerException : public exception {
public:
    SmoTrainerException(const string& m) : message(m) {}
    ~SmoTrainerException() throw() {}
    virtual const char* what() const throw() {
        return message.c_str();
    }

private:
    string message;
};

//class ExternalSvmTrainingInterruption {
//public:
//
ExternalSvmTrainingInterruption() {} 856 // virtual ~ExternalSvmTrainingInterruption() throw() {} 857 // virtual bool interruptTraining() { return false; } 858 //}; 859 860 861 // SmoTrainer trains a support vector machine using Sequential 862 // Minimal Optimization as described in the article 863 // "Support Vector Machine Solvers" by Bottou and Lin. 864 class SmoTrainer { 865 public: SmoTrainer(OutputFilter of)866 SmoTrainer(OutputFilter of) : outputFilter(of), C(1.0) {} 867 ~SmoTrainer()868 ~SmoTrainer() {} 869 getC()870 double getC() { return C; } setC(double C)871 void setC(double C) { this->C = C; } 872 setParameters(const ParameterMap & p)873 void setParameters(const ParameterMap& p) { 874 C = p.find(MapKey_C)->second; 875 } 876 877 SVM* train(KernelFunctionCache&, const LabeledObservationVector&); 878 void assignNumericLabels(vector<double>&, const LabeledObservationVector&, NumericClassToLabel&); elementwise_multiply(vector<double> & a,vector<double> & b,vector<double> & c)879 void elementwise_multiply(vector<double>& a, vector<double>& b, vector<double>& c) { 880 transform(a.begin(), a.end(), b.begin(), c.begin(), multiplies<double>()); 881 } 882 883 static const string MapKey_C; 884 static const ParameterRange defaultCRange; 885 886 private: 887 //ExternalSvmTrainingInterruption& externalSvmTrainingInterruption; 888 889 const OutputFilter outputFilter; 890 891 double C; 892 }; 893 894 895 // KFoldLabeledObservationDivider is used in cross validation to generate 896 // training and testing data sets of labeled observations. The labels will 897 // be distributed in proportion to their frequency in the data, as much as possible. 898 // 899 // Consider a data set with 100 observations from five classes. Also, let 900 // each class have 20 observations. If we want to do 10-fold cross validation 901 // then training sets should have 90 observations and test sets should have 902 // 10 observations. 
A training set should have approximately equal representation 903 // from each class, as should the test sets. 904 // 905 // An instance of KFoldLabeledObservationDivider will generate training and test 906 // sets within a for loop like this: 907 // 908 // KFoldLabeledObservationDivider X(10, allLabeledObservations); 909 // for (X.start(); !X.end(); X.next()) { 910 // const LabeledObservationVector& trainingData = X.getTrainingData(); 911 // const LabeledObservationVector& testingData = X.getTestingData(); 912 // // do cross validation on one fold 913 // } 914 class KFoldLabeledObservationsDivider { 915 public: 916 // initialize the k member variable to K so end() will return true if it is called before start() 917 // this is not perfect protection against misuse but it's better than nothing KFoldLabeledObservationsDivider(int _K,const LabeledObservationVector & l)918 KFoldLabeledObservationsDivider(int _K, const LabeledObservationVector& l) : K(_K), k(_K) { 919 buildLabelToLabeledObservationVector(labelToLabeledObservationVector, l); 920 } ~KFoldLabeledObservationsDivider()921 ~KFoldLabeledObservationsDivider() {} 922 start()923 void start() { 924 k = 0; 925 trainingData.clear(); 926 testingData.clear(); 927 for (LabelToLabeledObservationVector::const_iterator p = labelToLabeledObservationVector.begin(); p != labelToLabeledObservationVector.end(); p++) { 928 appendKthFold(k, K, p->second, trainingData, testingData); 929 } 930 } 931 end()932 bool end() { 933 return k >= K; 934 } 935 next()936 void next() { 937 k++; 938 trainingData.clear(); 939 testingData.clear(); 940 for (LabelToLabeledObservationVector::const_iterator p = labelToLabeledObservationVector.begin(); p != labelToLabeledObservationVector.end(); p++) { 941 appendKthFold(k, K, p->second, trainingData, testingData); 942 } 943 } 944 getFoldNumber()945 int getFoldNumber() { return k; } getTrainingData()946 const LabeledObservationVector& getTrainingData() { return trainingData; } getTestingData()947 const 
LabeledObservationVector& getTestingData() { return testingData; } 948 949 // Function appendKthFold takes care of partitioning the observations in x into two sets, 950 // one for training and one for testing. The argument K specifies how many folds 951 // will be requested in all. The argument k specifies which fold to return. 952 // An example: let K=3, k=0, and let there be 10 observations (all having the same label) 953 // i i%3 (i%3)==0 k=0 partition (i%3)==1 k=1 partition (i%3)==2 k=2 partition 954 // 0 0 true testing false training false training 955 // 1 1 false training true testing false training 956 // 2 2 false training false training true testing 957 // 3 0 true testing false training false training 958 // 4 1 false training true testing false training 959 // 5 2 false training false training true testing 960 // 6 0 true testing false training false training 961 // 7 1 false training true testing false training 962 // 8 2 false training false training true testing 963 // 9 0 true testing false training false training 964 // appendKthFold(int k,int K,const LabeledObservationVector & x,LabeledObservationVector & trainingData,LabeledObservationVector & testingData)965 static void appendKthFold(int k, int K, const LabeledObservationVector& x, LabeledObservationVector& trainingData, LabeledObservationVector& testingData) { 966 //for ( int i = 0; i < x.size(); i++) { 967 int i = 0; 968 for (LabeledObservationVector::const_iterator xi = x.begin(); xi != x.end(); xi++) { 969 if ( (i % K) == k) { 970 testingData.push_back(*xi); 971 } 972 else { 973 trainingData.push_back(*xi); 974 } 975 i++; 976 } 977 } 978 979 private: 980 const int K; 981 int k; 982 LabelVector labelVector; 983 LabelToLabeledObservationVector labelToLabeledObservationVector; 984 LabeledObservationVector trainingData; 985 LabeledObservationVector testingData; 986 }; 987 988 989 // OneVsOneMultiClassSvmTrainer trains a support vector machine for each 990 // pair of labels in a set of data. 
991 class OneVsOneMultiClassSvmTrainer { 992 public: 993 OneVsOneMultiClassSvmTrainer(SvmDataset&, int, int, OutputFilter&); ~OneVsOneMultiClassSvmTrainer()994 ~OneVsOneMultiClassSvmTrainer() {} 995 996 MultiClassSVM* train(const KernelParameterRangeMap&); 997 double trainOnKFolds(SmoTrainer&, KernelFunctionCache&, KFoldLabeledObservationsDivider&); getLabelSet()998 const LabelSet& getLabelSet() { return labelSet; } getLabeledObservations()999 const LabeledObservationVector& getLabeledObservations() { return svmDataset.getLabeledObservationVector(); } getLabelPairSet()1000 const LabelPairSet& getLabelPairSet() { return labelPairSet; } getLabeledObservationVectorForLabel(const Label & label)1001 const LabeledObservationVector& getLabeledObservationVectorForLabel(const Label& label) { return labelToLabeledObservationVector[label]; } 1002 getOutputFilter()1003 const OutputFilter& getOutputFilter() { return outputFilter; } 1004 1005 static void buildLabelPairSet(LabelPairSet&, const LabeledObservationVector&); 1006 static void appendTrainingAndTestingData(Label, const LabeledObservationVector&, LabeledObservationVector&, LabeledObservationVector&); 1007 1008 private: 1009 1010 const OutputFilter outputFilter; 1011 //bool verbose; 1012 1013 SvmDataset& svmDataset; 1014 1015 const int evaluationFoldCount; 1016 const int trainFoldCount; 1017 1018 LabelSet labelSet; 1019 LabelToLabeledObservationVector labelToLabeledObservationVector; 1020 LabelPairSet labelPairSet; 1021 1022 }; 1023 1024 // A better name for this class is MsvmRfe after MSVM-RFE described in 1025 // "MSVM-RFE: extensions of SVM-RFE for multiclass gene selection on 1026 // DNA microarray data", Zhou and Tuck, 2007, Bioinformatics 1027 class SvmRfe { 1028 public: SvmRfe()1029 SvmRfe() {} ~SvmRfe()1030 ~SvmRfe() {} 1031 1032 RankedFeatureList getOrderedFeatureList(SvmDataset&, OneVsOneMultiClassSvmTrainer&, const ParameterRange&, const ParameterRange&); 1033 }; 1034 1035 1036 #endif /* svm_hpp_ */ 1037