//
// svm.cpp
// support vector machine
//
// Created by Joshua Lynch on 6/19/2013.
// Copyright (c) 2013 Schloss Lab. All rights reserved.
//
#include <algorithm>
#include <functional>
#include <iomanip>
#include <iostream>
#include <limits>
#include <numeric>
#include <stack>
#include <utility>

#include "svm.hpp"

// OutputFilter constants
const int OutputFilter::QUIET = 0;
const int OutputFilter::INFO = 1;
const int OutputFilter::mDEBUG = 2;
const int OutputFilter::TRACE = 3;


#define RANGE(X) X, X + sizeof(X)/sizeof(double)

// parameters will be tested in the order they are specified

const string LinearKernelFunction::MapKey = "linear";//"LinearKernel";
const string LinearKernelFunction::MapKey_Constant = "constant";//"LinearKernel_Constant";
const double defaultLinearConstantRangeArray[] = {0.0, -1.0, 1.0, -10.0, 10.0};
const ParameterRange LinearKernelFunction::defaultConstantRange = ParameterRange(RANGE(defaultLinearConstantRangeArray));

const string RbfKernelFunction::MapKey = "rbf";//"RbfKernel";
const string RbfKernelFunction::MapKey_Gamma = "gamma";//"RbfKernel_Gamma";
const double defaultRbfGammaRangeArray[] = {0.0001, 0.001, 0.01, 0.1, 1.0, 10.0, 100.0};
const ParameterRange RbfKernelFunction::defaultGammaRange = ParameterRange(RANGE(defaultRbfGammaRangeArray));

const string PolynomialKernelFunction::MapKey = "polynomial";//"PolynomialKernel";
const string PolynomialKernelFunction::MapKey_Constant = "constant";//"PolynomialKernel_Constant";
const string PolynomialKernelFunction::MapKey_Coefficient = "coefficient";//"PolynomialKernel_Coefficient";
const string PolynomialKernelFunction::MapKey_Degree = "degree";//"PolynomialKernel_Degree";

const double defaultPolynomialConstantRangeArray[] = {0.0, -1.0, 1.0, -2.0, 2.0, -3.0, 3.0};
const ParameterRange PolynomialKernelFunction::defaultConstantRange = ParameterRange(RANGE(defaultPolynomialConstantRangeArray));
const double defaultPolynomialCoefficientRangeArray[] = {0.01, 0.1, 1.0, 10.0, 100.0};
const ParameterRange PolynomialKernelFunction::defaultCoefficientRange = ParameterRange(RANGE(defaultPolynomialCoefficientRangeArray));
const double defaultPolynomialDegreeRangeArray[] = {2.0, 3.0, 4.0};
const ParameterRange PolynomialKernelFunction::defaultDegreeRange = ParameterRange(RANGE(defaultPolynomialDegreeRangeArray));

const string SigmoidKernelFunction::MapKey = "sigmoid";
const string SigmoidKernelFunction::MapKey_Alpha = "alpha";
const string SigmoidKernelFunction::MapKey_Constant = "constant";

const double defaultSigmoidAlphaRangeArray[] = {1.0, 2.0};
const ParameterRange SigmoidKernelFunction::defaultAlphaRange = ParameterRange(RANGE(defaultSigmoidAlphaRangeArray));
const double defaultSigmoidConstantRangeArray[] = {1.0, 2.0};
const ParameterRange SigmoidKernelFunction::defaultConstantRange = ParameterRange(RANGE(defaultSigmoidConstantRangeArray));

const string SmoTrainer::MapKey_C = "smoc";//"SmoTrainer_C";
const double defaultSmoTrainerCRangeArray[] = {0.0001, 0.001, 0.01, 0.1, 1.0, 10.0, 100.0, 1000.0};
const ParameterRange SmoTrainer::defaultCRange = ParameterRange(RANGE(defaultSmoTrainerCRangeArray));

MothurOut* m = MothurOut::getInstance();

LabelPair buildLabelPair(const Label& one, const Label& two) {
    LabelVector labelPair(2);
    labelPair[0] = one;
    labelPair[1] = two;
    return labelPair;
}

// Dividing a dataset into training and testing sets while maintaining equal
// representation of all classes is done using a LabelToLabeledObservationVector.
// This container is used to divide datasets into groups of LabeledObservations
// having the same label. For example, given a LabeledObservationVector like
//     ["blue",  [1.0, 2.0, 3.0]]
//     ["green", [3.0, 4.0, 5.0]]
//     ["blue",  [2.0, 3.0, 4.0]]
//     ["green", [4.0, 5.0, 6.0]]
// the corresponding LabelToLabeledObservationVector looks like
//     "blue"  : [["blue",  [1.0, 2.0, 3.0]], ["blue",  [2.0, 3.0, 4.0]]]
//     "green" : [["green", [3.0, 4.0, 5.0]], ["green", [4.0, 5.0, 6.0]]]
void buildLabelToLabeledObservationVector(LabelToLabeledObservationVector& labelToLabeledObservationVector, const LabeledObservationVector& labeledObservationVector) {
    for ( LabeledObservationVector::const_iterator j = labeledObservationVector.begin(); j != labeledObservationVector.end(); j++ ) {
        labelToLabeledObservationVector[j->first].push_back(*j);
    }
}


// MeanAndStd accumulates the mean and standard deviation of a sequence of
// values in a single pass using Welford's online algorithm: n, the running
// mean, and M2 (the sum of squared differences from the current mean) are
// updated as each value is processed.
class MeanAndStd {
private:
    double n;
    double M2;
    double mean;

public:
    MeanAndStd() {}
    ~MeanAndStd() {}

    void initialize() {
        n = 0.0;
        mean = 0.0;
        M2 = 0.0;
    }

    void processNextValue(double x) {
        n += 1.0;
        double delta = x - mean;
        mean += delta / n;
        M2 += delta * (x - mean);
    }

    double getMean() {
        return mean;
    }

    double getStd() {
        // sample standard deviation (n - 1 in the denominator)
        double variance = M2 / (n - 1.0);
        return sqrt(variance);
    }
};
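
// A minimal usage sketch of MeanAndStd (the values are hypothetical, not
// from a real dataset). Welford's recurrences are
//     mean_n = mean_{n-1} + (x_n - mean_{n-1}) / n
//     M2_n   = M2_{n-1} + (x_n - mean_{n-1}) * (x_n - mean_n)
//
//     MeanAndStd ms;
//     ms.initialize();
//     double values[] = {1.0, 2.0, 3.0};
//     for ( int i = 0; i < 3; i++ ) ms.processNextValue(values[i]);
//     // ms.getMean() returns 2.0, ms.getStd() returns 1.0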


// The FeatureLabelMatches functor is used in removeFeature below. It returns
// true if the feature argument has the same feature label as the label the
// functor was constructed with.
class FeatureLabelMatches {
public:
    FeatureLabelMatches(const string& _featureLabel) : featureLabel(_featureLabel){}

    bool operator() (const Feature& f) {
        return f.getFeatureLabel() == featureLabel;
    }

private:
    const string& featureLabel;

};

Feature removeFeature(Feature featureToRemove, LabeledObservationVector& observations, FeatureVector& featureVector) {
    FeatureLabelMatches matchFeatureLabel(featureToRemove.getFeatureLabel());
    featureVector.erase(
        remove_if(featureVector.begin(), featureVector.end(), matchFeatureLabel),
        featureVector.end()
    );
    for ( ObservationVector::size_type observation = 0; observation < observations.size(); observation++ ) {
        observations[observation].removeFeatureAtIndex(featureToRemove.getFeatureIndex());
    }
    // update the feature indices
    for ( int i = 0; i < featureVector.size(); i++ ) {
        featureVector.at(i).setFeatureIndex(i);
    }
    featureToRemove.setFeatureIndex(-1);
    return featureToRemove;
}
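
// A small sketch of the reindexing behavior above (the feature labels are
// hypothetical): removing feature "b" from features a,b,c with indices 0,1,2
// leaves a,c, which are renumbered 0,1; the removed feature's own index is
// set to -1 to mark it as no longer part of the dataset.
//
//     // before: featureVector = [a:0, b:1, c:2]
//     Feature removed = removeFeature(b, observations, featureVector);
//     // after:  featureVector = [a:0, c:1], removed.getFeatureIndex() == -1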

FeatureVector applyStdThreshold(double stdThreshold, LabeledObservationVector& observations, FeatureVector& featureVector) {
    // calculate standard deviation of each feature
    // remove features with standard deviation less than or equal to stdThreshold
    MeanAndStd ms;
    // loop over features in reverse order so each remaining feature's index
    // is still valid when higher-indexed features have been removed
    // for example,
    //   if there are 5 features a,b,c,d,e
    //   and features a, c, e fall below the stdThreshold
    //     loop iteration 0: remove feature e (index 4) -- features are now a,b,c,d
    //     loop iteration 1: leave feature d (index 3)
    //     loop iteration 2: remove feature c (index 2) -- features are now a,b,d
    //     loop iteration 3: leave feature b (index 1)
    //     loop iteration 4: remove feature a (index 0) -- features are now b,d
    FeatureVector removedFeatureVector;
    for ( int feature = observations[0].second->size()-1; feature >= 0 ; feature-- ) {
        ms.initialize();
        m->mothurOut("feature index " + toString(feature)); m->mothurOutEndLine();
        for ( ObservationVector::size_type observation = 0; observation < observations.size(); observation++ ) {
            ms.processNextValue(observations[observation].second->at(feature));
        }
        m->mothurOut( "feature " + toString(feature) + " has std " + toString(ms.getStd()) ); m->mothurOutEndLine();
        if ( ms.getStd() <= stdThreshold ) {
            m->mothurOut( "removing feature with index " + toString(feature) ); m->mothurOutEndLine();
            // remove this feature
            Feature featureToRemove = featureVector.at(feature);
            removedFeatureVector.push_back(
                removeFeature(featureToRemove, observations, featureVector)
            );
        }
    }
    reverse(removedFeatureVector.begin(), removedFeatureVector.end());
    return removedFeatureVector;
}


// this function standardizes data to mean 0 and variance 1
// but this may not be a good standardization for OTU data
void transformZeroMeanUnitVariance(LabeledObservationVector& observations) {
    bool verbose = false;
    // online method for mean and variance
    MeanAndStd ms;
    for ( Observation::size_type feature = 0; feature < observations[0].second->size(); feature++ ) {
        ms.initialize();
        for ( ObservationVector::size_type observation = 0; observation < observations.size(); observation++ ) {
            ms.processNextValue(observations[observation].second->at(feature));
        }
        if (verbose) {
            m->mothurOut( "mean of feature " + toString(feature) + " is " + toString(ms.getMean()) ); m->mothurOutEndLine();
            m->mothurOut( "std of feature " + toString(feature) + " is " + toString(ms.getStd()) ); m->mothurOutEndLine();
        }
        // normalize the feature: x' = (x - mean) / std
        double mean = ms.getMean();
        double std = ms.getStd();
        for ( ObservationVector::size_type observation = 0; observation < observations.size(); observation++ ) {
            observations[observation].second->at(feature) = (observations[observation].second->at(feature) - mean ) / std;
        }
    }
}


double getMinimumFeatureValueForObservation(Observation::size_type featureIndex, LabeledObservationVector& observations) {
    double featureMinimum = numeric_limits<double>::max();
    for ( ObservationVector::size_type observation = 0; observation < observations.size(); observation++ ) {
        if ( observations[observation].second->at(featureIndex) < featureMinimum ) {
            featureMinimum = observations[observation].second->at(featureIndex);
        }
    }
    return featureMinimum;
}


double getMaximumFeatureValueForObservation(Observation::size_type featureIndex, LabeledObservationVector& observations) {
    // start from the most negative double; note that numeric_limits<double>::min()
    // is the smallest positive value and would give a wrong maximum for
    // all-negative features
    double featureMaximum = -numeric_limits<double>::max();
    for ( ObservationVector::size_type observation = 0; observation < observations.size(); observation++ ) {
        if ( observations[observation].second->at(featureIndex) > featureMaximum ) {
            featureMaximum = observations[observation].second->at(featureIndex);
        }
    }
    return featureMaximum;
}


// this function standardizes data to minimum value 0.0 and maximum value 1.0
void transformZeroOne(LabeledObservationVector& observations) {
    for ( Observation::size_type feature = 0; feature < observations[0].second->size(); feature++ ) {
        double featureMinimum = getMinimumFeatureValueForObservation(feature, observations);
        double featureMaximum = getMaximumFeatureValueForObservation(feature, observations);
        // standardize the feature
        // note: a constant feature (featureMaximum == featureMinimum) would
        // divide by zero here; applyStdThreshold can remove such features first
        for ( ObservationVector::size_type observation = 0; observation < observations.size(); observation++ ) {
            double x = observations[observation].second->at(feature);
            double xstd = (x - featureMinimum) / (featureMaximum - featureMinimum);
            // map to the [0.0, 1.0] target range; see the note after this
            // function for the general target-range form
            observations[observation].second->at(feature) = xstd;
        }
    }
}
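
// For reference, the min-max scaling above is
//     x_std = (x - x_min) / (x_max - x_min)
// which maps x_min to 0.0 and x_max to 1.0. The original expression here was
// written in terms of a general target range [lower, upper]; that form is
//     x' = x_std * (upper - lower) + lower
// and reduces to x_std for the [0.0, 1.0] range used in this file.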


//
// SVM member functions
//
// the discriminant member function returns +1 or -1
int SVM::discriminant(const Observation& observation) const {
    // d is the value of the discriminant function
    double d = b;
    for ( int i = 0; i < y.size(); i++ ) {
        d += y[i]*a[i]*inner_product(observation.begin(), observation.end(), x[i].second->begin(), 0.0);
    }
    return d > 0.0 ? 1 : -1;
}
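
// For reference, the discriminant above is the standard SVM decision function
// with a linear (dot-product) kernel over the stored support vectors x_i,
// dual coefficients a_i, and numeric labels y_i:
//     d(x) = b + sum_i y_i * a_i * <x, x_i>
// and the predicted class is sign(d(x)). classify() (declared in svm.hpp)
// presumably maps this +1/-1 value back to a Label via discriminantToLabel.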

LabelVector SVM::classify(const LabeledObservationVector& twoClassLabeledObservationVector) const {
    LabelVector predictionVector;
    for ( LabeledObservationVector::const_iterator i = twoClassLabeledObservationVector.begin(); i != twoClassLabeledObservationVector.end(); i++ ) {
        Label prediction = classify(*(i->getObservation()));
        predictionVector.push_back(prediction);
    }
    return predictionVector;
}

// the score member function classifies each labeled observation from the
// argument and returns the fraction of correct classifications
double SVM::score(const LabeledObservationVector& twoClassLabeledObservationVector) const {
    double s = 0.0;
    for ( LabeledObservationVector::const_iterator i = twoClassLabeledObservationVector.begin(); i != twoClassLabeledObservationVector.end(); i++ ) {
        Label predicted_label = classify(*(i->second));
        if ( predicted_label == i->first ) {
            s = s + 1.0;
        }
    }
    return s / double(twoClassLabeledObservationVector.size());
}

void SvmPerformanceSummary::init(const SVM& svm, const LabeledObservationVector& actualLabels, const LabelVector& predictedLabels) {
    // accumulate four counts:
    //     tp (true positive)  -- correct classifications   (classified +1 as +1)
    //     fp (false positive) -- incorrect classifications (classified -1 as +1)
    //     fn (false negative) -- incorrect classifications (classified +1 as -1)
    //     tn (true negative)  -- correct classifications   (classified -1 as -1)
    // the label corresponding to discriminant +1 will be the 'positive' class
    NumericClassToLabel discriminantToLabel = svm.getDiscriminantToLabel();
    positiveClassLabel = discriminantToLabel[1];
    negativeClassLabel = discriminantToLabel[-1];

    double tp = 0;
    double fp = 0;
    double fn = 0;
    double tn = 0;
    for (int i = 0; i < actualLabels.size(); i++) {
        Label predictedLabel = predictedLabels.at(i);
        Label actualLabel = actualLabels.at(i).getLabel();

        if ( actualLabel.compare(positiveClassLabel) == 0) {
            if ( predictedLabel.compare(positiveClassLabel) == 0 ) {
                tp++;
            }
            else if ( predictedLabel.compare(negativeClassLabel) == 0 ) {
                fn++;
            }
            else {
                m->mothurOut( "actual label is positive but something is wrong" ); m->mothurOutEndLine();
            }
        }
        else if ( actualLabel.compare(negativeClassLabel) == 0 ) {
            if ( predictedLabel.compare(positiveClassLabel) == 0 ) {
                fp++;
            }
            else if ( predictedLabel.compare(negativeClassLabel) == 0 ) {
                tn++;
            }
            else {
                m->mothurOut( "actual label is negative but something is wrong" ); m->mothurOutEndLine();
            }
        }
        else {
            // in the event we have been given an observation that is labeled
            // neither positive nor negative then we will get a false classification
            // (string::compare returns 0 on a match)
            if ( predictedLabel.compare(positiveClassLabel) == 0 ) {
                fp++;
            }
            else {
                fn++;
            }
        }
    }
    Utils util;
    if (util.isEqual(tp, 0) && util.isEqual(fp, 0) ) {
        precision = 0;
    }
    else {
        precision = tp / (tp + fp);
    }
    recall = tp / (tp + fn);
    if ( util.isEqual(precision, 0) && util.isEqual(recall, 0) ) {
        f = 0;
    }
    else {
        f = 2.0 * (precision * recall) / (precision + recall);
    }
    accuracy = (tp + tn) / (tp + tn + fp + fn);
}
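
// For reference, the summary statistics computed above are the usual
// two-class definitions:
//     precision = tp / (tp + fp)
//     recall    = tp / (tp + fn)
//     F         = 2 * precision * recall / (precision + recall)
//     accuracy  = (tp + tn) / (tp + tn + fp + fn)
// e.g. with hypothetical counts tp=8, fp=2, fn=2, tn=8: precision = 0.8,
// recall = 0.8, F = 0.8, accuracy = 0.8.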


MultiClassSVM::MultiClassSVM(const vector<SVM*> s, const LabelSet& l, const SvmToSvmPerformanceSummary& p, OutputFilter of) : twoClassSvmList(s.begin(), s.end()), labelSet(l), svmToSvmPerformanceSummary(p), outputFilter(of), accuracy(0) {}


MultiClassSVM::~MultiClassSVM() {
    for ( int i = 0; i < twoClassSvmList.size(); i++ ) {
        delete twoClassSvmList[i];
    }
}

// The fewerVotes function is used to find the maximum vote
// tally in MultiClassSVM::classify. This function returns true
// if the first element (number of votes for the first label) is
// less than the second element (number of votes for the second label).
bool fewerVotes(const pair<Label, int>& p, const pair<Label, int>& q) {
    return p.second < q.second;
}


Label MultiClassSVM::classify(const Observation& observation) {
    map<Label, int> labelToVoteCount;
    for ( int i = 0; i < twoClassSvmList.size(); i++ ) {
        Label predictedLabel = twoClassSvmList[i]->classify(observation);
        labelToVoteCount[predictedLabel]++;
    }
    pair<Label, int> winner = *max_element(labelToVoteCount.begin(), labelToVoteCount.end(), fewerVotes);
    LabelVector winningLabels;
    winningLabels.push_back(winner.first);
    for ( map<Label, int>::const_iterator i = labelToVoteCount.begin(); i != labelToVoteCount.end(); i++ ) {
        if ( i->second == winner.second && i->first != winner.first ) {
            winningLabels.push_back(i->first);
        }
    }
    if ( winningLabels.size() > 1 ) {
        // more than one label received the winning vote count
        throw MultiClassSvmClassificationTie(winningLabels, winner.second);
    }
    return winner.first;
}
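
// A minimal usage sketch of one-vs-one classification with tie handling
// (mc is a hypothetical trained MultiClassSVM*, obs a hypothetical
// Observation):
//
//     try {
//         Label prediction = mc->classify(obs);
//     }
//     catch ( MultiClassSvmClassificationTie& e ) {
//         // two or more labels received the same (maximum) number of votes
//     }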

double MultiClassSVM::score(const LabeledObservationVector& multiClassLabeledObservationVector) {
    double s = 0.0;
    for (LabeledObservationVector::const_iterator i = multiClassLabeledObservationVector.begin(); i != multiClassLabeledObservationVector.end(); i++) {
        try {
            Label predicted_label = classify(*(i->second));
            if ( predicted_label == i->first ) {
                s = s + 1.0;
            }
            // otherwise the predicted label does not match the actual label
        }
        catch ( MultiClassSvmClassificationTie& e ) {
            if ( outputFilter.debug() ) {
                m->mothurOut( "classification tie for observation " + toString(i->datasetIndex) + " with label " + toString(i->first) ); m->mothurOutEndLine();
            }
        }
    }
    return s / double(multiClassLabeledObservationVector.size());
}

class MaxIterationsExceeded : public exception {
    virtual const char* what() const throw() {
        return "maximum iterations exceeded during SMO";
    }
} maxIterationsExceeded;


//SvmTrainingInterruptedException smoTrainingInterruptedException("SMO training interrupted by user");

// The train method implements Sequential Minimal Optimization as described in
// "Support Vector Machine Solvers" by Bottou and Lin.
//
// SmoTrainer::train releases a pointer to an SVM into the wild, so we must be
// careful about how the LabeledObservationVector is handled: the returned SVM
// keeps copies of elements of that vector, so the underlying observation data
// must outlive the SVM.
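//
// For reference, a sketch of the problem this loop solves, in the notation of
// Bottou & Lin (where alpha_k corresponds to y_k * a_k below):
//     maximize    sum_k alpha_k y_k - (1/2) sum_{i,j} alpha_i alpha_j K(x_i, x_j)
//     subject to  A_k <= alpha_k <= B_k  and  sum_k alpha_k = 0
// with A_k = min(0, C y_k) and B_k = max(0, C y_k). In the loop below, ya and
// yg hold alpha_k and its gradient, (i, j) is the maximal violating pair, and
// lambda is the step length clipped to the box constraints (u[0], u[1]) and
// the unconstrained optimum along the direction (u[2]).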
SVM* SmoTrainer::train(KernelFunctionCache& K, const LabeledObservationVector& twoClassLabeledObservationVector) {
    const int observationCount = twoClassLabeledObservationVector.size();
    const int featureCount = twoClassLabeledObservationVector[0].second->size();

    if (outputFilter.debug()) { m->mothurOut( "observation count : " + toString(observationCount) ); m->mothurOutEndLine(); }
    if (outputFilter.debug()) { m->mothurOut( "feature count     : " + toString(featureCount) ); m->mothurOutEndLine(); }
    // dual coefficients
    vector<double> a(observationCount, 0.0);
    // gradient
    vector<double> g(observationCount, 1.0);
    // convert the labels to -1.0,+1.0
    vector<double> y(observationCount);
    if (outputFilter.trace()) { m->mothurOut( "assign numeric labels" ); m->mothurOutEndLine(); }
    NumericClassToLabel discriminantToLabel;
    assignNumericLabels(y, twoClassLabeledObservationVector, discriminantToLabel);
    if (outputFilter.trace()) { m->mothurOut( "assign A and B" ); m->mothurOutEndLine(); }
    vector<double> A(observationCount);
    vector<double> B(observationCount);
    Utils util;
    for ( int n = 0; n < observationCount; n++ ) {
        if ( util.isEqual(y[n], +1.0)) {
            A[n] = 0.0;
            B[n] = C;
        }
        else {
            A[n] = -C;
            B[n] = 0;
        }
        if (outputFilter.trace()) { m->mothurOut( toString(n) + " " + toString(A[n]) + " " + toString(B[n]) ); m->mothurOutEndLine(); }
    }
    if (outputFilter.trace()) { m->mothurOut( "assign K" ); m->mothurOutEndLine(); }
    int m_count = 0;
    vector<double> u(3);
    vector<double> ya(observationCount);
    vector<double> yg(observationCount);
    double lambda = numeric_limits<double>::max();
    while ( true ) {

        if (m->getControl_pressed()) { return 0; }

        m_count++;
        int i = 0;
        int j = 0;
        // start yg_max from the most negative double; numeric_limits<double>::min()
        // is the smallest positive value and would be wrong here
        double yg_max = -numeric_limits<double>::max();
        double yg_min = numeric_limits<double>::max();
        if (outputFilter.trace()) { m->mothurOut( "m = " + toString(m_count) ); m->mothurOutEndLine(); }
        for ( int k = 0; k < observationCount; k++ ) {
            ya[k] = y[k] * a[k];
            yg[k] = y[k] * g[k];
        }
        if (outputFilter.trace()) {
            m->mothurOut( "yg =");
            for ( int k = 0; k < observationCount; k++ ) {
                m->mothurOut( " " + toString(yg[k]));
            }
            m->mothurOutEndLine();
        }

        for ( int k = 0; k < observationCount; k++ ) {
            if ( ya[k] < B[k] && yg[k] > yg_max ) {
                yg_max = yg[k];
                i = k;
            }
            if ( A[k] < ya[k] && yg[k] < yg_min ) {
                yg_min = yg[k];
                j = k;
            }
        }
        // maximal violating pair is i,j
        if (outputFilter.trace()) {
            m->mothurOut( "maximal violating pair: " + toString(i) + " " + toString(j) ); m->mothurOutEndLine();
            m->mothurOut( "  i = " + toString(i) + " features: ");
            for ( int feature = 0; feature < featureCount; feature++ ) {
                m->mothurOut( toString(twoClassLabeledObservationVector[i].second->at(feature)) + " ");
            }
            m->mothurOutEndLine();
            m->mothurOut( "  j = " + toString(j) + " features: ");
            for ( int feature = 0; feature < featureCount; feature++ ) {
                m->mothurOut( toString(twoClassLabeledObservationVector[j].second->at(feature)) + " ");
            }
            m->mothurOutEndLine();
        }

        // the iteration limit could be parameterized
        if ( m_count > 1000 ) {
            // rather than throwing an exception, accept the coefficients found
            // so far; in practice this works well, although inspecting lambda
            // here might make a better stopping criterion
            if (outputFilter.debug()) { m->mothurOut( "iteration limit reached with lambda = " + toString(lambda) ); m->mothurOutEndLine(); }
            break;
        }

        // using lambda to break is a good performance enhancement
        if ( yg[i] <= yg[j] || lambda < 0.0001) {
            break;
        }
        u[0] = B[i] - ya[i];
        u[1] = ya[j] - A[j];

        double K_ii = K.similarity(twoClassLabeledObservationVector[i], twoClassLabeledObservationVector[i]);
        double K_jj = K.similarity(twoClassLabeledObservationVector[j], twoClassLabeledObservationVector[j]);
        double K_ij = K.similarity(twoClassLabeledObservationVector[i], twoClassLabeledObservationVector[j]);
        u[2] = (yg[i] - yg[j]) / (K_ii+K_jj-2.0*K_ij);
        if (outputFilter.trace()) { m->mothurOut( "directions: (" + toString(u[0]) + "," + toString(u[1]) + "," + toString(u[2]) + ")" ); m->mothurOutEndLine(); }
        lambda = *min_element(u.begin(), u.end());
        if (outputFilter.trace()) { m->mothurOut( "lambda: " + toString(lambda) ); m->mothurOutEndLine(); }
        for ( int k = 0; k < observationCount; k++ ) {
            double K_ik = K.similarity(twoClassLabeledObservationVector[i], twoClassLabeledObservationVector[k]);
            double K_jk = K.similarity(twoClassLabeledObservationVector[j], twoClassLabeledObservationVector[k]);
            g[k] += (-lambda * y[k] * K_ik + lambda * y[k] * K_jk);
        }
        if (outputFilter.trace()) {
            m->mothurOut( "g =");
            for ( int k = 0; k < observationCount; k++ ) {
                m->mothurOut( " " + toString(g[k]));
            }
            m->mothurOutEndLine();
        }
        a[i] += y[i] * lambda;
        a[j] -= y[j] * lambda;
    }


    // at this point the optimal a's have been found
    // now use them to find w and b
    if (outputFilter.trace()) { m->mothurOut( "find w" ); m->mothurOutEndLine(); }
    vector<double> w(twoClassLabeledObservationVector[0].second->size(), 0.0);
    double b = 0.0;
    for ( int i = 0; i < y.size(); i++ ) {
        if (outputFilter.trace()) { m->mothurOut( "alpha[" + toString(i) + "] = " + toString(a[i]) ); m->mothurOutEndLine(); }
        for ( int j = 0; j < w.size(); j++ ) {
            w[j] += a[i] * y[i] * twoClassLabeledObservationVector[i].second->at(j);
        }
        if ( A[i] < a[i] && a[i] < B[i] ) {
            b = yg[i];
            if (outputFilter.trace()) { m->mothurOut( "b = " + toString(b) ); m->mothurOutEndLine(); }
        }
    }

    if (outputFilter.trace()) {
        for ( int i = 0; i < w.size(); i++ ) {
            m->mothurOut( "w[" + toString(i) + "] = " + toString(w[i]) ); m->mothurOutEndLine();
        }
    }

    // the returned SVM keeps copies of elements of twoClassLabeledObservationVector,
    // so the underlying observation data must outlive the SVM
    //
    // we can eliminate elements of y, a and observation vectors corresponding to a = 0
    vector<double> support_y;
    vector<double> nonzero_a;
    LabeledObservationVector supportVectors;
    for (int i = 0; i < a.size(); i++) {
        if ( !util.isEqual(a.at(i), 0.0) ) {
            // only support vectors have nonzero dual coefficients
            support_y.push_back(y.at(i));
            nonzero_a.push_back(a.at(i));
            supportVectors.push_back(twoClassLabeledObservationVector.at(i));
        }
    }
    //return new SVM(y, a, twoClassLabeledObservationVector, b, discriminantToLabel);
    if (outputFilter.info()) m->mothurOut( "found " + toString(supportVectors.size()) + " support vectors\n" );
    return new SVM(support_y, nonzero_a, supportVectors, b, discriminantToLabel);
}

typedef map<Label, double> LabelToNumericClassLabel;

// For SVM training we need to assign numeric class labels of -1.0 and +1.0.
// This method populates the y vector argument with -1.0 and +1.0
// corresponding to the two classes in the labelVector argument.
// For example, if labeledObservationVector looks like this:
//     [ (0, "blue",  [...some observations...]),
//       (1, "green", [...some observations...]),
//       (2, "blue",  [...some observations...]) ]
// then after the function executes the y vector will look like this:
//     [ -1.0,   // blue
//       +1.0,   // green
//       -1.0 ]  // blue
// and discriminantToLabel will look like this:
//     { -1.0 : "blue",
//       +1.0 : "green" }
// The label "blue" is mapped to -1.0 because it is (lexicographically) less than "green".
// When given labels "blue" and "green" this function will always assign "blue" to -1.0 and
// "green" to +1.0. This is not fundamentally important but it makes testing easier and is
// not a hassle to implement.
void SmoTrainer::assignNumericLabels(vector<double>& y, const LabeledObservationVector& labeledObservationVector, NumericClassToLabel& discriminantToLabel) {
    // -1.0 and +1.0 are assigned consistently for each pair of labels because
    // the label set is always traversed in sorted order

    // we are going to overwrite arguments y and discriminantToLabel
    y.clear();
    discriminantToLabel.clear();

    LabelSet labelSet;
    buildLabelSet(labelSet, labeledObservationVector);
    LabelVector uniqueLabels(labelSet.begin(), labelSet.end());
    if (labelSet.size() != 2) {
        cerr << "unexpected label set size " << labelSet.size() << endl;
        for (LabelSet::const_iterator i = labelSet.begin(); i != labelSet.end(); i++) {
            cerr << "    label " << *i << endl;
        }
        throw SmoTrainerException("SmoTrainer::assignNumericLabels was not passed exactly 2 labels");
    }
    else {
        LabelToNumericClassLabel labelToNumericClassLabel;
        labelToNumericClassLabel[uniqueLabels[0]] = -1.0;
        labelToNumericClassLabel[uniqueLabels[1]] = +1.0;
        for ( LabeledObservationVector::const_iterator i = labeledObservationVector.begin(); i != labeledObservationVector.end(); i++ ) {
            y.push_back( labelToNumericClassLabel[i->first] );
        }
        discriminantToLabel[-1.0] = uniqueLabels[0];
        discriminantToLabel[+1.0] = uniqueLabels[1];
    }
}

// this is a convenience function for getting parameter ranges for all kernels
void getDefaultKernelParameterRangeMap(KernelParameterRangeMap& kernelParameterRangeMap) {
    ParameterRangeMap linearParameterRangeMap;
    linearParameterRangeMap[SmoTrainer::MapKey_C] = SmoTrainer::defaultCRange;
    linearParameterRangeMap[LinearKernelFunction::MapKey_Constant] = LinearKernelFunction::defaultConstantRange;

    ParameterRangeMap rbfParameterRangeMap;
    rbfParameterRangeMap[SmoTrainer::MapKey_C] = SmoTrainer::defaultCRange;
    rbfParameterRangeMap[RbfKernelFunction::MapKey_Gamma] = RbfKernelFunction::defaultGammaRange;

    ParameterRangeMap polynomialParameterRangeMap;
    polynomialParameterRangeMap[SmoTrainer::MapKey_C] = SmoTrainer::defaultCRange;
    polynomialParameterRangeMap[PolynomialKernelFunction::MapKey_Constant] = PolynomialKernelFunction::defaultConstantRange;
    polynomialParameterRangeMap[PolynomialKernelFunction::MapKey_Coefficient] = PolynomialKernelFunction::defaultCoefficientRange;
    polynomialParameterRangeMap[PolynomialKernelFunction::MapKey_Degree] = PolynomialKernelFunction::defaultDegreeRange;

    ParameterRangeMap sigmoidParameterRangeMap;
    sigmoidParameterRangeMap[SmoTrainer::MapKey_C] = SmoTrainer::defaultCRange;
    sigmoidParameterRangeMap[SigmoidKernelFunction::MapKey_Alpha] = SigmoidKernelFunction::defaultAlphaRange;
    sigmoidParameterRangeMap[SigmoidKernelFunction::MapKey_Constant] = SigmoidKernelFunction::defaultConstantRange;

    kernelParameterRangeMap[LinearKernelFunction::MapKey] = linearParameterRangeMap;
    kernelParameterRangeMap[RbfKernelFunction::MapKey] = rbfParameterRangeMap;
    kernelParameterRangeMap[PolynomialKernelFunction::MapKey] = polynomialParameterRangeMap;
    kernelParameterRangeMap[SigmoidKernelFunction::MapKey] = sigmoidParameterRangeMap;
}
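
// A minimal usage sketch (the trainer construction is elided here; see
// OneVsOneMultiClassSvmTrainer below for the real constructor signature):
//
//     KernelParameterRangeMap kernelParameterRangeMap;
//     getDefaultKernelParameterRangeMap(kernelParameterRangeMap);
//     // every kernel/parameter combination in the map will be tried
//     MultiClassSVM* mcsvm = trainer.train(kernelParameterRangeMap);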


//
// OneVsOneMultiClassSvmTrainer
//
// An instance of OneVsOneMultiClassSvmTrainer is intended to work with a single set of data
// to produce a single instance of MultiClassSVM. That's why observations and labels go in to
// the constructor.
OneVsOneMultiClassSvmTrainer::OneVsOneMultiClassSvmTrainer(SvmDataset& d, int e, int t, OutputFilter& of) :
        svmDataset(d),
        evaluationFoldCount(e),
        trainFoldCount(t),
        outputFilter(of) {
    buildLabelSet(labelSet, svmDataset.getLabeledObservationVector());
    buildLabelToLabeledObservationVector(labelToLabeledObservationVector, svmDataset.getLabeledObservationVector());
    buildLabelPairSet(labelPairSet, svmDataset.getLabeledObservationVector());
}

void buildLabelSet(LabelSet& labelSet, const LabeledObservationVector& labeledObservationVector) {
    for (LabeledObservationVector::const_iterator i = labeledObservationVector.begin(); i != labeledObservationVector.end(); i++) {
        labelSet.insert(i->first);
    }
}


// This function uses the LabeledObservationVector argument to populate the LabelPairSet
// argument with pairs of labels. For example, if labeledObservationVector looks like this:
//     [ ("blue", x), ("green", y), ("red", z) ]
// then the labelPairSet will be populated with the following label pairs:
//     ("blue", "green"), ("blue", "red"), ("green", "red")
// The order of labels in the pairs is determined by the ordering of labels in the temporary
// LabelSet. By default this order will be ascending. However, labels are taken off the
// temporary labelStack in reverse order, so the labelStack is initialized with reverse iterators.
// In the end our label pairs will be in sorted order.
void OneVsOneMultiClassSvmTrainer::buildLabelPairSet(LabelPairSet& labelPairSet, const LabeledObservationVector& labeledObservationVector) {
    LabelSet labelSet;
    buildLabelSet(labelSet, labeledObservationVector);
    LabelVector labelStack(labelSet.rbegin(), labelSet.rend());
    while (labelStack.size() > 1) {
        Label label = labelStack.back();
        labelStack.pop_back();
        LabelPair labelPair(2);
        labelPair[0] = label;
        for (LabelVector::const_iterator i = labelStack.begin(); i != labelStack.end(); i++) {
            labelPair[1] = *i;
            labelPairSet.insert(labelPair);
        }
    }
}


// The LabelMatchesEither functor is used only in calls to remove_copy_if in the
// OneVsOneMultiClassSvmTrainer::train method. It returns true if the labeled
// observation argument matches NEITHER of the two label arguments, so that
// remove_copy_if copies exactly the observations belonging to the two classes.
class LabelMatchesEither {
public:
    LabelMatchesEither(const Label& _label0, const Label& _label1) : label0(_label0), label1(_label1) {}

    bool operator() (const LabeledObservation& o) {
        return !((o.first == label0) || (o.first == label1));
    }

private:
    const Label& label0;
    const Label& label1;
};

MultiClassSVM* OneVsOneMultiClassSvmTrainer::train(const KernelParameterRangeMap& kernelParameterRangeMap) {
    double bestMultiClassSvmScore = 0.0;
    MultiClassSVM* bestMc = NULL;

    KernelFunctionFactory kernelFunctionFactory(svmDataset.getLabeledObservationVector());

    // first divide the data into a 'development' set for tuning hyperparameters
    // and an 'evaluation' set for measuring performance
    int evaluationFoldNumber = 0;
    KFoldLabeledObservationsDivider kFoldDevEvalDivider(evaluationFoldCount, svmDataset.getLabeledObservationVector());
    for ( kFoldDevEvalDivider.start(); !kFoldDevEvalDivider.end(); kFoldDevEvalDivider.next() ) {
        const LabeledObservationVector& developmentObservations = kFoldDevEvalDivider.getTrainingData();
        const LabeledObservationVector& evaluationObservations = kFoldDevEvalDivider.getTestingData();

        evaluationFoldNumber++;
        if ( outputFilter.debug() ) {
            m->mothurOut( "evaluation fold " + toString(evaluationFoldNumber) + " of " + toString(evaluationFoldCount) ); m->mothurOutEndLine();
        }

        vector<SVM*> twoClassSvmList;
        SvmToSvmPerformanceSummary svmToSvmPerformanceSummary;
        SmoTrainer smoTrainer(outputFilter);
        LabelPairSet::iterator labelPair;
        for (labelPair = labelPairSet.begin(); labelPair != labelPairSet.end(); labelPair++) {
            // generate training and testing data for this label pair
            Label label0 = (*labelPair)[0];
            Label label1 = (*labelPair)[1];
            if ( outputFilter.debug() ) {
                m->mothurOut("training SVM on labels " + toString(label0) + " and " + toString(label1) ); m->mothurOutEndLine();
            }

            double bestMeanScoreOnKFolds = 0.0;
            ParameterMap bestParameterMap;
            string bestKernelFunctionKey;
            LabeledObservationVector twoClassDevelopmentObservations;
            LabelMatchesEither labelMatchesEither(label0, label1);
            remove_copy_if(
                developmentObservations.begin(),
                developmentObservations.end(),
                back_inserter(twoClassDevelopmentObservations),
                labelMatchesEither
            );
            KFoldLabeledObservationsDivider kFoldLabeledObservationsDivider(trainFoldCount, twoClassDevelopmentObservations);
            // loop on kernel functions and kernel function parameters
            for ( KernelParameterRangeMap::const_iterator kmap = kernelParameterRangeMap.begin(); kmap != kernelParameterRangeMap.end(); kmap++ ) {
                string kernelFunctionKey = kmap->first;
                KernelFunction& kernelFunction = kernelFunctionFactory.getKernelFunctionForKey(kmap->first);
                ParameterSetBuilder p(kmap->second);
                for (ParameterMapVector::const_iterator hp = p.getParameterSetList().begin(); hp != p.getParameterSetList().end(); hp++) {
                    kernelFunction.setParameters(*hp);
                    KernelFunctionCache kernelFunctionCache(kernelFunction, svmDataset.getLabeledObservationVector());
                    smoTrainer.setParameters(*hp);
                    if (outputFilter.debug()) {
                        m->mothurOut( "parameters for " + toString(kernelFunctionKey) + " kernel" ); m->mothurOutEndLine();
                        for ( ParameterMap::const_iterator i = hp->begin(); i != hp->end(); i++ ) {
                            m->mothurOut( "    " + toString(i->first) + ":" + toString(i->second) ); m->mothurOutEndLine();
                        }
                    }
                    double meanScoreOnKFolds = trainOnKFolds(smoTrainer, kernelFunctionCache, kFoldLabeledObservationsDivider);
                    if ( meanScoreOnKFolds > bestMeanScoreOnKFolds ) {
                        bestMeanScoreOnKFolds = meanScoreOnKFolds;
                        bestParameterMap = *hp;
                        bestKernelFunctionKey = kernelFunctionKey;
                    }
                }
            }
            Utils util;
            if ( util.isEqual(bestMeanScoreOnKFolds, 0.0) ) {
                m->mothurOut( "failed to train SVM on labels " + toString(label0) + " and " + toString(label1) ); m->mothurOutEndLine();
                throw exception();
            }
            else {
                if ( outputFilter.debug() ) {
                    m->mothurOut( "trained SVM on labels " + label0 + " and " + label1 ); m->mothurOutEndLine();
                    m->mothurOut( "    best mean score over " + toString(trainFoldCount) + " folds is " + toString(bestMeanScoreOnKFolds) ); m->mothurOutEndLine();
                    m->mothurOut( "    best parameters for " + bestKernelFunctionKey + " kernel" ); m->mothurOutEndLine();
                    for ( ParameterMap::const_iterator p = bestParameterMap.begin(); p != bestParameterMap.end(); p++ ) {
                        m->mothurOut( "        " + toString(p->first) + " : " + toString(p->second) ); m->mothurOutEndLine();
                    }
                }

                LabelMatchesEither labelMatchesEither(label0, label1);
                LabeledObservationVector twoClassDevelopmentObservations;
                remove_copy_if(
                    developmentObservations.begin(),
                    developmentObservations.end(),
                    back_inserter(twoClassDevelopmentObservations),
                    labelMatchesEither
                );
                if (outputFilter.info()) {
                    m->mothurOut( "training final SVM with " + toString(twoClassDevelopmentObservations.size()) + " labeled observations" ); m->mothurOutEndLine();
                    for ( ParameterMap::const_iterator i = bestParameterMap.begin(); i != bestParameterMap.end(); i++ ) {
                        m->mothurOut( "    " + toString(i->first) + ":" + toString(i->second) ); m->mothurOutEndLine();
                    }
                }

                // retrain on the whole development set with the best kernel and parameters
                KernelFunction& kernelFunction = kernelFunctionFactory.getKernelFunctionForKey(bestKernelFunctionKey);
                kernelFunction.setParameters(bestParameterMap);
                smoTrainer.setParameters(bestParameterMap);
                KernelFunctionCache kernelFunctionCache(kernelFunction, svmDataset.getLabeledObservationVector());
                SVM* svm = smoTrainer.train(kernelFunctionCache, twoClassDevelopmentObservations);

                twoClassSvmList.push_back(svm);
                // record a performance summary using the evaluation dataset
                LabeledObservationVector twoClassEvaluationObservations;
                remove_copy_if(
                    evaluationObservations.begin(),
                    evaluationObservations.end(),
                    back_inserter(twoClassEvaluationObservations),
                    labelMatchesEither
                );
                SvmPerformanceSummary p(*svm, twoClassEvaluationObservations);
                svmToSvmPerformanceSummary[svm->getLabelPair()] = p;
            }
        }

        MultiClassSVM* mc = new MultiClassSVM(twoClassSvmList, labelSet, svmToSvmPerformanceSummary, outputFilter);
        mc->setAccuracy(evaluationObservations);
        if ( outputFilter.debug() ) {
            m->mothurOut( "fold " + toString(evaluationFoldNumber) + " multiclass SVM score: " + toString(mc->getAccuracy()) ); m->mothurOutEndLine();
        }
        if ( mc->getAccuracy() > bestMultiClassSvmScore ) {
            // release the previous best multiclass SVM before replacing it
            if ( bestMc != NULL ) { delete bestMc; }
            bestMc = mc;
            bestMultiClassSvmScore = mc->getAccuracy();
        }
        else {
            delete mc;
        }
    }

    if ( outputFilter.info() ) {
        m->mothurOut( "best multiclass SVM has score " + toString(bestMc->getAccuracy()) ); m->mothurOutEndLine();
    }

    return bestMc;
}

//SvmTrainingInterruptedException multiClassSvmTrainingInterruptedException("one-vs-one multiclass SVM training interrupted by user");

double OneVsOneMultiClassSvmTrainer::trainOnKFolds(SmoTrainer& smoTrainer, KernelFunctionCache& kernelFunction, KFoldLabeledObservationsDivider& kFoldLabeledObservationsDivider) {
    double online_mean_n = 0.0;
    double online_mean_score = 0.0;
    double meanScoreOverKFolds = -1.0;  // -1.0 means we failed to train an SVM on any fold

    for ( kFoldLabeledObservationsDivider.start(); !kFoldLabeledObservationsDivider.end(); kFoldLabeledObservationsDivider.next() ) {
        const LabeledObservationVector& kthTwoClassTrainingFold = kFoldLabeledObservationsDivider.getTrainingData();
        const LabeledObservationVector& kthTwoClassTestingFold = kFoldLabeledObservationsDivider.getTestingData();
        if (outputFilter.info()) {
            m->mothurOut( "fold " + toString(kFoldLabeledObservationsDivider.getFoldNumber()) + " training data has " + toString(kthTwoClassTrainingFold.size()) + " labeled observations" ); m->mothurOutEndLine();
            m->mothurOut( "fold " + toString(kFoldLabeledObservationsDivider.getFoldNumber()) + " testing data has " + toString(kthTwoClassTestingFold.size()) + " labeled observations" ); m->mothurOutEndLine();
        }
        if (m->getControl_pressed()) { return 0; }
        else {
            try {
                if (outputFilter.debug()) { m->mothurOut( "begin training" ); m->mothurOutEndLine(); }

                SVM* evaluationSvm = smoTrainer.train(kernelFunction, kthTwoClassTrainingFold);
                SvmPerformanceSummary svmPerformanceSummary(*evaluationSvm, kthTwoClassTestingFold);
                double score = evaluationSvm->score(kthTwoClassTestingFold);
                if (outputFilter.debug()) {
                    m->mothurOut( "score on fold " + toString(kFoldLabeledObservationsDivider.getFoldNumber()) + " of test data is " + toString(score) ); m->mothurOutEndLine();
                    m->mothurOut( "positive label: " + toString(svmPerformanceSummary.getPositiveClassLabel()) ); m->mothurOutEndLine();
                    m->mothurOut( "negative label: " + toString(svmPerformanceSummary.getNegativeClassLabel()) ); m->mothurOutEndLine();
                    m->mothurOut( "  precision: " + toString(svmPerformanceSummary.getPrecision())
                            + " recall: " + toString(svmPerformanceSummary.getRecall())
                            + " f: " + toString(svmPerformanceSummary.getF())
                            + " accuracy: " + toString(svmPerformanceSummary.getAccuracy())
                    ); m->mothurOutEndLine();
                }
                // update the running mean score over folds
                online_mean_n += 1.0;
                double online_mean_delta = score - online_mean_score;
                online_mean_score += online_mean_delta / online_mean_n;
                meanScoreOverKFolds = online_mean_score;

                delete evaluationSvm;
            }
            catch ( exception& e ) {
                m->mothurOut( "exception: " + toString(e.what()) ); m->mothurOutEndLine();
                m->mothurOut( "    on fold " + toString(kFoldLabeledObservationsDivider.getFoldNumber()) + " failed to train SVM with C = " + toString(smoTrainer.getC()) ); m->mothurOutEndLine();
            }
        }
    }
    if (outputFilter.debug()) {
        m->mothurOut( "done with cross validation on C = " + toString(smoTrainer.getC()) ); m->mothurOutEndLine();
        m->mothurOut( "    mean score over " + toString(kFoldLabeledObservationsDivider.getFoldNumber()) + " folds is " + toString(meanScoreOverKFolds) ); m->mothurOutEndLine();
    }
    Utils util;
    if ( util.isEqual(meanScoreOverKFolds, 0.0) ) { m->mothurOut( "failed to train SVM with C = " + toString(smoTrainer.getC()) + "\n"); }
    return meanScoreOverKFolds;
}
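
// For reference, the running mean above is the same single-pass update used
// by MeanAndStd::processNextValue:
//     mean_n = mean_{n-1} + (score_n - mean_{n-1}) / n
// so the mean score over folds is computed without storing per-fold scores.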


// UnrankedFeature pairs a Feature with the SVM-RFE ranking criterion
// computed for it in the current round.
class UnrankedFeature {
public:
    UnrankedFeature(const Feature& f) : feature(f), rankingCriterion(0.0) {}
    ~UnrankedFeature() {}

    Feature getFeature() const { return feature; }

    double getRankingCriterion() const { return rankingCriterion; }
    void setRankingCriterion(double rc) { rankingCriterion = rc; }

private:
    Feature feature;
    double rankingCriterion;
};

bool lessThanRankingCriterion(const UnrankedFeature& a, const UnrankedFeature& b) {
    return a.getRankingCriterion() < b.getRankingCriterion();
}

bool lessThanFeatureIndex(const UnrankedFeature& a, const UnrankedFeature& b) {
    return a.getFeature().getFeatureIndex() < b.getFeature().getFeatureIndex();
}

typedef list<UnrankedFeature> UnrankedFeatureList;


// Only the linear SVM can be used here.
// Consider allowing only parameter ranges as arguments;
// right now any kernel can be sent in.
// It would be useful to remove more than one feature at a time,
// and it might make sense to turn the last two arguments into one.
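//
// For reference, a sketch of the ranking criterion computed below, following
// the SVM-RFE approach of Guyon et al. (which assumes a linear kernel; that
// is why only the linear SVM is appropriate here): each two-class SVM with
// dual coefficients a_j, labels y_j, and support vectors x_j has weight vector
//     w = sum_j a_j * y_j * x_j
// and the criterion for feature i is the sum of w_i^2 over all two-class SVMs
// in the multiclass classifier. Features with low criterion values contribute
// least to the decision functions and are eliminated first.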
RankedFeatureList SvmRfe::getOrderedFeatureList(SvmDataset& svmDataset, OneVsOneMultiClassSvmTrainer& t, const ParameterRange& linearKernelConstantRange, const ParameterRange& smoTrainerParameterRange) {
    KernelParameterRangeMap rfeKernelParameterRangeMap;
    ParameterRangeMap linearParameterRangeMap;
    linearParameterRangeMap[SmoTrainer::MapKey_C] = smoTrainerParameterRange;
    linearParameterRangeMap[LinearKernelFunction::MapKey_Constant] = linearKernelConstantRange;

    rfeKernelParameterRangeMap[LinearKernelFunction::MapKey] = linearParameterRangeMap;

    // the rankedFeatureList is empty at first
    RankedFeatureList rankedFeatureList;
    // loop until all but one feature have been eliminated
    // no need to eliminate the last feature, after all
    int svmRfeRound = 0;
    while ( svmDataset.getFeatureVector().size() > 1 ) {
        svmRfeRound++;
        m->mothurOut( "SVM-RFE round " + toString(svmRfeRound) + ":" ); m->mothurOutEndLine();
        UnrankedFeatureList unrankedFeatureList;
        for (int featureIndex = 0; featureIndex < svmDataset.getFeatureVector().size(); featureIndex++) {
            Feature f = svmDataset.getFeatureVector().at(featureIndex);
            unrankedFeatureList.push_back(UnrankedFeature(f));
        }
        m->mothurOut( toString(unrankedFeatureList.size()) + " unranked features" ); m->mothurOutEndLine();

        MultiClassSVM* s = t.train(rfeKernelParameterRangeMap);
        m->mothurOut( "multiclass SVM accuracy: " + toString(s->getAccuracy()) ); m->mothurOutEndLine();

        m->mothurOut( "two-class SVM performance" ); m->mothurOutEndLine();
        m->mothurOut("class 1\tclass 2\tprecision\trecall\tf\taccuracy\n");
        for ( SvmVector::const_iterator svm = s->getSvmList().begin(); svm != s->getSvmList().end(); svm++ ) {
            SvmPerformanceSummary sps = s->getSvmPerformanceSummary(**svm);
            m->mothurOut(toString(sps.getPositiveClassLabel())
                    + "\t" + toString(sps.getNegativeClassLabel())
                    + "\t" + toString(sps.getPrecision())
                    + "\t" + toString(sps.getRecall())
                    + "\t" + toString(sps.getF())
                    + "\t" + toString(sps.getAccuracy()) ); m->mothurOutEndLine();
        }
        // calculate the 'ranking criterion' for each (remaining) feature using each binary svm
        for (UnrankedFeatureList::iterator f = unrankedFeatureList.begin(); f != unrankedFeatureList.end(); f++) {
            const int i = f->getFeature().getFeatureIndex();
            // rankingCriterion combines feature weights for feature i in all svms
            double rankingCriterion = 0.0;
            for ( SvmVector::const_iterator svm = s->getSvmList().begin(); svm != s->getSvmList().end(); svm++ ) {
                // calculate the weight w_i of feature i for this svm
                double wi = 0.0;
                for (int j = 0; j < (*svm)->x.size(); j++) {
                    // all support vectors contribute to wi
                    wi += (*svm)->a.at(j) * (*svm)->y.at(j) * (*svm)->x.at(j).second->at(i);
                }
                // accumulate squared weights for feature i from all svms
                rankingCriterion += pow(wi, 2);
            }
            // update the (unranked) feature ranking criterion
            f->setRankingCriterion(rankingCriterion);
        }
        delete s;

        // sort the unranked features by ranking criterion
        unrankedFeatureList.sort(lessThanRankingCriterion);

        // eliminating the bottom 1/(n+1) features per round is very slow but gives good results;
        // eliminating the bottom 1/3 is fast but gives slightly different results;
        // here the bottom 1/4 of the remaining features is eliminated each round
        int eliminateFeatureCount = ceil(unrankedFeatureList.size() / 4.0);
        m->mothurOut( "eliminating " + toString(eliminateFeatureCount) + " feature(s) of " + toString(unrankedFeatureList.size()) + " total features\n");
        m->mothurOutEndLine();
        UnrankedFeatureList featuresToEliminate;
        for ( int i = 0; i < eliminateFeatureCount; i++ ) {
            // remove the lowest ranked feature(s) from the list of unranked features
            UnrankedFeature unrankedFeature = unrankedFeatureList.front();
            unrankedFeatureList.pop_front();

            featuresToEliminate.push_back(unrankedFeature);
            // put the lowest ranked feature at the front of the list of ranked features:
            // the first feature to be eliminated will be at the back of this list,
            // the last feature to be eliminated will be at the front of this list
            rankedFeatureList.push_front(RankedFeature(unrankedFeature.getFeature(), svmRfeRound));
        }

        // remove eliminated features from highest index to lowest so each
        // feature's index is still valid when it is removed
        featuresToEliminate.sort(lessThanFeatureIndex);
        featuresToEliminate.reverse();
        for (UnrankedFeatureList::iterator g = featuresToEliminate.begin(); g != featuresToEliminate.end(); g++) {
            Feature unrankedFeature = g->getFeature();
            removeFeature(unrankedFeature, svmDataset.getLabeledObservationVector(), svmDataset.getFeatureVector());
        }
    }

    // there may be one feature left
    svmRfeRound++;
    for ( FeatureVector::iterator f = svmDataset.getFeatureVector().begin(); f != svmDataset.getFeatureVector().end(); f++ ) {
        rankedFeatureList.push_front(RankedFeature(*f, svmRfeRound));
    }

    return rankedFeatureList;
}