//
//  svm.cpp
//  support vector machine
//
//  Created by Joshua Lynch on 6/19/2013.
//  Copyright (c) 2013 Schloss Lab. All rights reserved.
//
#include <algorithm>
#include <functional>
#include <iomanip>
#include <iostream>
#include <limits>
#include <numeric>
#include <stack>
#include <utility>

#include "svm.hpp"

// OutputFilter constants
const int OutputFilter::QUIET  = 0;
const int OutputFilter::INFO   = 1;
const int OutputFilter::mDEBUG = 2;
const int OutputFilter::TRACE  = 3;


#define RANGE(X) X, X + sizeof(X)/sizeof(double)
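// For example, given
//     const double r[] = {0.1, 1.0, 10.0};
// RANGE(r) expands to "r, r + 3", the begin/end iterator pair expected by
// the two-argument ParameterRange constructor used below.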

// parameters will be tested in the order they are specified

const string LinearKernelFunction::MapKey                      = "linear";//"LinearKernel";
const string LinearKernelFunction::MapKey_Constant             = "constant";//"LinearKernel_Constant";
const double defaultLinearConstantRangeArray[]                      = {0.0, -1.0, 1.0, -10.0, 10.0};
const ParameterRange LinearKernelFunction::defaultConstantRange     = ParameterRange(RANGE(defaultLinearConstantRangeArray));

const string RbfKernelFunction::MapKey                         = "rbf";//"RbfKernel";
const string RbfKernelFunction::MapKey_Gamma                   = "gamma";//"RbfKernel_Gamma";
const double defaultRbfGammaRangeArray[]                            = {0.0001, 0.001, 0.01, 0.1, 1.0, 10.0, 100.0};
const ParameterRange RbfKernelFunction::defaultGammaRange           = ParameterRange(RANGE(defaultRbfGammaRangeArray));

const string PolynomialKernelFunction::MapKey                  = "polynomial";//"PolynomialKernel";
const string PolynomialKernelFunction::MapKey_Constant         = "constant";//"PolynomialKernel_Constant";
const string PolynomialKernelFunction::MapKey_Coefficient      = "coefficient";//"PolynomialKernel_Coefficient";
const string PolynomialKernelFunction::MapKey_Degree           = "degree";//"PolynomialKernel_Degree";

const double defaultPolynomialConstantRangeArray[]                     = {0.0, -1.0, 1.0, -2.0, 2.0, -3.0, 3.0};
const ParameterRange PolynomialKernelFunction::defaultConstantRange    = ParameterRange(RANGE(defaultPolynomialConstantRangeArray));
const double defaultPolynomialCoefficientRangeArray[]                  = {0.01, 0.1, 1.0, 10.0, 100.0};
const ParameterRange PolynomialKernelFunction::defaultCoefficientRange = ParameterRange(RANGE(defaultPolynomialCoefficientRangeArray));
const double defaultPolynomialDegreeRangeArray[]                       = {2.0, 3.0, 4.0};
const ParameterRange PolynomialKernelFunction::defaultDegreeRange      = ParameterRange(RANGE(defaultPolynomialDegreeRangeArray));

const string SigmoidKernelFunction::MapKey                     = "sigmoid";
const string SigmoidKernelFunction::MapKey_Alpha               = "alpha";
const string SigmoidKernelFunction::MapKey_Constant            = "constant";

const double defaultSigmoidAlphaRangeArray[]                        = {1.0, 2.0};
const ParameterRange SigmoidKernelFunction::defaultAlphaRange       = ParameterRange(RANGE(defaultSigmoidAlphaRangeArray));
const double defaultSigmoidConstantRangeArray[]                     = {1.0, 2.0};
const ParameterRange SigmoidKernelFunction::defaultConstantRange    = ParameterRange(RANGE(defaultSigmoidConstantRangeArray));

const string SmoTrainer::MapKey_C                              = "smoc";//"SmoTrainer_C";
const double defaultSmoTrainerCRangeArray[]                         = {0.0001, 0.001, 0.01, 0.1, 1.0, 10.0, 100.0, 1000.0};
const ParameterRange SmoTrainer::defaultCRange                      = ParameterRange(RANGE(defaultSmoTrainerCRangeArray));

MothurOut* m = MothurOut::getInstance();

LabelPair buildLabelPair(const Label& one, const Label& two) {
    LabelVector labelPair(2);
    labelPair[0] = one;
    labelPair[1] = two;
    return labelPair;
}

// Dividing a dataset into training and testing sets while maintaining equal
// representation of all classes is done using a LabelToLabeledObservationVector.
// This container is used to divide datasets into groups of LabeledObservations
// having the same label.  For example, given a LabeledObservationVector like
//     ["blue",  [1.0, 2.0, 3.0]]
//     ["green", [3.0, 4.0, 5.0]]
//     ["blue",  [2.0, 3.0, 4.0]]
//     ["green", [4.0, 5.0, 6.0]]
// the corresponding LabelToLabeledObservationVector looks like
//     "blue"  : [["blue",  [1.0, 2.0, 3.0]], ["blue",  [2.0, 3.0, 4.0]]]
//     "green" : [["green", [3.0, 4.0, 5.0]], ["green", [4.0, 5.0, 6.0]]]
void buildLabelToLabeledObservationVector(LabelToLabeledObservationVector& labelToLabeledObservationVector, const LabeledObservationVector& labeledObservationVector) {
    for ( LabeledObservationVector::const_iterator j = labeledObservationVector.begin(); j != labeledObservationVector.end(); j++ ) {
        labelToLabeledObservationVector[j->first].push_back(*j);
    }
}


class MeanAndStd {
private:
    double n;
    double M2;
    double mean;

public:
    MeanAndStd() {}
    ~MeanAndStd() {}

    void initialize() {
        n = 0.0;
        mean = 0.0;
        M2 = 0.0;
    }

    void processNextValue(double x) {
        n += 1.0;
        double delta = x - mean;
        mean += delta / n;
        M2 += delta * (x - mean);
    }

    double getMean() {
        return mean;
    }

    double getStd() {
        double variance = M2 / (n - 1.0);
        return sqrt(variance);
    }
};
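// MeanAndStd implements the standard one-pass (Welford-style) update of the
// running mean and the sum of squared deviations M2, so the sample standard
// deviation can be computed without storing all of the values.  A small
// illustrative usage sketch:
//     MeanAndStd ms;
//     ms.initialize();
//     ms.processNextValue(1.0);
//     ms.processNextValue(2.0);
//     ms.processNextValue(3.0);
//     // ms.getMean() == 2.0, ms.getStd() == 1.0 (n-1 in the denominator)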


// The FeatureLabelMatches functor is used in the remove_if call inside
// removeFeature.  It returns true if the feature argument has the same
// feature label as the one given to the constructor.
class FeatureLabelMatches {
public:
    FeatureLabelMatches(const string& _featureLabel) : featureLabel(_featureLabel){}

    bool operator() (const Feature& f) {
        return f.getFeatureLabel() == featureLabel;
    }

private:
    const string& featureLabel;

};

Feature removeFeature(Feature featureToRemove, LabeledObservationVector& observations, FeatureVector& featureVector) {
    FeatureLabelMatches matchFeatureLabel(featureToRemove.getFeatureLabel());
    featureVector.erase(
            remove_if(featureVector.begin(), featureVector.end(), matchFeatureLabel),
            featureVector.end()
    );
    for ( ObservationVector::size_type observation = 0; observation < observations.size(); observation++ ) {
        observations[observation].removeFeatureAtIndex(featureToRemove.getFeatureIndex());
    }
    // update the feature indices
    for ( int i = 0; i < featureVector.size(); i++ ) {
        featureVector.at(i).setFeatureIndex(i);
    }
    featureToRemove.setFeatureIndex(-1);
    return featureToRemove;
}

FeatureVector applyStdThreshold(double stdThreshold, LabeledObservationVector& observations, FeatureVector& featureVector) {
    // calculate standard deviation of each feature
    // remove features with standard deviation less than or equal to stdThreshold
    MeanAndStd ms;
    // loop over features in reverse order so that removing a feature does not
    // shift the indices of the features still to be examined
    // for example,
    //     if there are 5 features a,b,c,d,e
    //     and features a, c, e fall below the stdThreshold
    //     loop iteration 0: remove feature e (index 4) -- features are now a,b,c,d
    //     loop iteration 1: leave feature d (index 3)
    //     loop iteration 2: remove feature c (index 2) -- features are now a,b,d
    //     loop iteration 3: leave feature b (index 1)
    //     loop iteration 4: remove feature a (index 0) -- features are now b,d
    FeatureVector removedFeatureVector;
    for ( int feature = observations[0].second->size()-1; feature >= 0 ; feature-- ) {
        ms.initialize();
        m->mothurOut("feature index " + toString(feature)); m->mothurOutEndLine();
        for ( ObservationVector::size_type observation = 0; observation < observations.size(); observation++ ) {
            ms.processNextValue(observations[observation].second->at(feature));
        }
        m->mothurOut(  "feature " + toString(feature) + " has std " + toString(ms.getStd()) ); m->mothurOutEndLine();
        if ( ms.getStd() <= stdThreshold ) {
            m->mothurOut( "removing feature with index " + toString(feature) ); m->mothurOutEndLine();
            // remove this feature
            Feature featureToRemove = featureVector.at(feature);
            removedFeatureVector.push_back(
                removeFeature(featureToRemove, observations, featureVector)
            );
        }
    }
    reverse(removedFeatureVector.begin(), removedFeatureVector.end());
    return removedFeatureVector;
}


// this function standardizes data to mean 0 and variance 1
// but this may not be a good standardization for OTU data
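// Each feature value x is replaced by the z-score (x - mean) / std computed
// over that feature's column.  For example (illustrative values only), a
// feature column {2.0, 4.0, 6.0} has mean 4.0 and sample std 2.0, so it
// becomes {-1.0, 0.0, 1.0}.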
void transformZeroMeanUnitVariance(LabeledObservationVector& observations) {
    bool verbose = false;
    // online method for mean and variance
    MeanAndStd ms;
    for ( Observation::size_type feature = 0; feature < observations[0].second->size(); feature++ ) {
        ms.initialize();
        for ( ObservationVector::size_type observation = 0; observation < observations.size(); observation++ ) {
            ms.processNextValue(observations[observation].second->at(feature));
        }
        if (verbose) {
            m->mothurOut( "mean of feature " + toString(feature) + " is " + toString(ms.getMean()) ); m->mothurOutEndLine();
            m->mothurOut( "std of feature " + toString(feature) + " is " + toString(ms.getStd()) ); m->mothurOutEndLine();
        }
        // normalize the feature
        double mean = ms.getMean();
        double std = ms.getStd();
        for ( ObservationVector::size_type observation = 0; observation < observations.size(); observation++ ) {
            observations[observation].second->at(feature) = (observations[observation].second->at(feature) - mean ) / std;
        }
    }
}


double getMinimumFeatureValueForObservation(Observation::size_type featureIndex, LabeledObservationVector& observations) {
    double featureMinimum = numeric_limits<double>::max();
    for ( ObservationVector::size_type observation = 0; observation < observations.size(); observation++ ) {
        if ( observations[observation].second->at(featureIndex) < featureMinimum ) {
            featureMinimum = observations[observation].second->at(featureIndex);
        }
    }
    return featureMinimum;
}


double getMaximumFeatureValueForObservation(Observation::size_type featureIndex, LabeledObservationVector& observations) {
    // start from the lowest representable value; numeric_limits<double>::min()
    // would be wrong here because it is the smallest *positive* double
    double featureMaximum = -numeric_limits<double>::max();
    for ( ObservationVector::size_type observation = 0; observation < observations.size(); observation++ ) {
        if ( observations[observation].second->at(featureIndex) > featureMaximum ) {
            featureMaximum = observations[observation].second->at(featureIndex);
        }
    }
    return featureMaximum;
}


// this function standardizes data to minimum value 0.0 and maximum value 1.0
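// Each feature value x is mapped to (x - min) / (max - min) per feature
// column.  For example (illustrative values only), a feature column
// {2.0, 4.0, 6.0} has min 2.0 and max 6.0, so it becomes {0.0, 0.5, 1.0}.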
void transformZeroOne(LabeledObservationVector& observations) {
    for ( Observation::size_type feature = 0; feature < observations[0].second->size(); feature++ ) {
        double featureMinimum = getMinimumFeatureValueForObservation(feature, observations);
        double featureMaximum = getMaximumFeatureValueForObservation(feature, observations);
        // standardize the feature
        for ( ObservationVector::size_type observation = 0; observation < observations.size(); observation++ ) {
            double x = observations[observation].second->at(feature);
            // rescale x into the target range [0.0, 1.0]
            double xstd = (x - featureMinimum) / (featureMaximum - featureMinimum);
            observations[observation].second->at(feature) = xstd;
        }
    }
}


//
// SVM member functions
//
// the discriminant member function returns +1 or -1
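// The decision value is d = b + sum_i y_i * a_i * <x_i, observation>, where
// the x_i are the stored support vectors with numeric labels y_i and dual
// coefficients a_i; the inner product below corresponds to applying a linear
// kernel at classification time.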
int SVM::discriminant(const Observation& observation) const {
    // d is the discriminant function
    double d = b;
    for ( int i = 0; i < y.size(); i++ ) {
        d += y[i]*a[i]*inner_product(observation.begin(), observation.end(), x[i].second->begin(), 0.0);
    }
    return d > 0.0 ? 1 : -1;
}

LabelVector SVM::classify(const LabeledObservationVector& twoClassLabeledObservationVector) const {
    LabelVector predictionVector;
    for ( LabeledObservationVector::const_iterator i =  twoClassLabeledObservationVector.begin(); i != twoClassLabeledObservationVector.end(); i++ ) {
        Label prediction = classify(*(i->getObservation()));
        predictionVector.push_back(prediction);
    }
    return predictionVector;
}

// the score member function classifies each labeled observation from the
// argument and returns the fraction of correct classifications
// NOTE: this may no longer be needed
double SVM::score(const LabeledObservationVector& twoClassLabeledObservationVector) const {
    double s = 0.0;
    for ( LabeledObservationVector::const_iterator i = twoClassLabeledObservationVector.begin(); i != twoClassLabeledObservationVector.end(); i++ ) {
        Label predicted_label = classify(*(i->second));
        if ( predicted_label == i->first ) {
            s = s + 1.0;
        }
    }
    return s / double(twoClassLabeledObservationVector.size());
}

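// A small worked example of the statistics computed below (illustrative
// counts only): with tp = 8, fp = 2, fn = 1, tn = 9,
//     precision = tp / (tp + fp)                  =  8 / 10 = 0.80
//     recall    = tp / (tp + fn)                  =  8 / 9  ~ 0.89
//     f         = 2 * p * r / (p + r)                       ~ 0.84
//     accuracy  = (tp + tn) / (tp + tn + fp + fn) = 17 / 20 = 0.85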
void SvmPerformanceSummary::init(const SVM& svm, const LabeledObservationVector& actualLabels, const LabelVector& predictedLabels) {
    // accumulate four counts:
    //     tp (true positive)  -- correct classifications (classified +1 as +1)
    //     fp (false positive) -- incorrect classifications (classified -1 as +1)
    //     fn (false negative) -- incorrect classifications (classified +1 as -1)
    //     tn (true negative)  -- correct classifications (classified -1 as -1)
    // the label corresponding to discriminant +1 will be the 'positive' class
    NumericClassToLabel discriminantToLabel = svm.getDiscriminantToLabel();
    positiveClassLabel = discriminantToLabel[1];
    negativeClassLabel = discriminantToLabel[-1];

    double tp = 0;
    double fp = 0;
    double fn = 0;
    double tn = 0;
    for (int i = 0; i < actualLabels.size(); i++) {
        Label predictedLabel = predictedLabels.at(i);
        Label actualLabel = actualLabels.at(i).getLabel();

        if ( actualLabel.compare(positiveClassLabel) == 0) {
            if ( predictedLabel.compare(positiveClassLabel) == 0 ) {
                tp++;
            }
            else if ( predictedLabel.compare(negativeClassLabel) == 0 ) {
                fn++;
            }
            else {
                m->mothurOut( "actual label is positive but something is wrong" ); m->mothurOutEndLine();
            }
        }
        else if ( actualLabel.compare(negativeClassLabel) == 0 ) {
            if ( predictedLabel.compare(positiveClassLabel) == 0 ) {
                fp++;
            }
            else if ( predictedLabel.compare(negativeClassLabel) == 0 ) {
                tn++;
            }
            else {
                m->mothurOut( "actual label is negative but something is wrong" ); m->mothurOutEndLine();
            }
        }
        else {
            // in the event we have been given an observation that is labeled
            // neither positive nor negative then we will count a false classification
            if ( predictedLabel.compare(positiveClassLabel) == 0 ) {
                fp++;
            }
            else {
                fn++;
            }
        }
    }
    Utils util;
    if (util.isEqual(tp, 0) && util.isEqual(fp, 0) ) {
        precision = 0;
    }
    else {
        precision = tp / (tp + fp);
    }
    if ( util.isEqual(tp, 0) && util.isEqual(fn, 0) ) {
        recall = 0;
    }
    else {
        recall = tp / (tp + fn);
    }
    if ( util.isEqual(precision, 0) && util.isEqual(recall, 0) ) {
        f = 0;
    }
    else {
        f = 2.0 * (precision * recall) / (precision + recall);
    }
    accuracy = (tp + tn) / (tp + tn + fp + fn);
}


MultiClassSVM::MultiClassSVM(const vector<SVM*> s, const LabelSet& l, const SvmToSvmPerformanceSummary& p, OutputFilter of) : twoClassSvmList(s.begin(), s.end()), labelSet(l), svmToSvmPerformanceSummary(p), outputFilter(of), accuracy(0) {}


MultiClassSVM::~MultiClassSVM() {
    for ( int i = 0; i < twoClassSvmList.size(); i++ ) {
        delete twoClassSvmList[i];
    }
}

// The fewerVotes function is used to find the maximum vote
// tally in MultiClassSVM::classify.  This function returns true
// if the first element (number of votes for the first label) is
// less than the second element (number of votes for the second label).
bool fewerVotes(const pair<Label, int>& p, const pair<Label, int>& q) {
    return p.second < q.second;
}

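// classify implements one-vs-one voting: each two-class SVM casts one vote,
// and the label with the most votes wins.  For example (illustrative only),
// with labels {"blue", "green", "red"} there are 3 pairwise SVMs; if they
// predict "blue", "blue", "red", then "blue" wins with 2 votes.  A tie for
// the top vote count raises MultiClassSvmClassificationTie.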
Label MultiClassSVM::classify(const Observation& observation) {
    map<Label, int> labelToVoteCount;
    for ( int i = 0; i < twoClassSvmList.size(); i++ ) {
        Label predictedLabel = twoClassSvmList[i]->classify(observation);
        labelToVoteCount[predictedLabel]++;
    }
    pair<Label, int> winner = *max_element(labelToVoteCount.begin(), labelToVoteCount.end(), fewerVotes);
    LabelVector winningLabels;
    winningLabels.push_back(winner.first);
    for ( map<Label, int>::const_iterator i = labelToVoteCount.begin(); i != labelToVoteCount.end(); i++ ) {
        if ( i->second == winner.second && i->first != winner.first ) {
            winningLabels.push_back(i->first);
        }
    }
    if ( winningLabels.size() > 1 ) {
        // we have a tie
        throw MultiClassSvmClassificationTie(winningLabels, winner.second);
    }

    return winner.first;
}

double MultiClassSVM::score(const LabeledObservationVector& multiClassLabeledObservationVector) {
    double s = 0.0;
    for (LabeledObservationVector::const_iterator i = multiClassLabeledObservationVector.begin(); i != multiClassLabeledObservationVector.end(); i++) {
        try {
            Label predicted_label = classify(*(i->second));
            if ( predicted_label == i->first ) {
                s = s + 1.0;
            }
        }
        catch ( MultiClassSvmClassificationTie& e ) {
            if ( outputFilter.debug() ) {
                m->mothurOut( "classification tie for observation " + toString(i->datasetIndex) + " with label " + toString(i->first) ); m->mothurOutEndLine();
            }
        }
    }
    return s / double(multiClassLabeledObservationVector.size());
}

class MaxIterationsExceeded : public exception {
    virtual const char* what() const throw() {
        return "maximum iterations exceeded during SMO";
    }
} maxIterationsExceeded;


//SvmTrainingInterruptedException smoTrainingInterruptedException("SMO training interrupted by user");

//  The train method implements Sequential Minimal Optimization as described in
//  "Support Vector Machine Solvers" by Bottou and Lin.
//
//  SmoTrainer::train returns a raw pointer to a newly allocated SVM, so the
//  caller takes ownership; care must be taken that the training data (or a
//  copy of it) outlives the returned SVM.
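//
//  A sketch of the loop below (following Bottou & Lin's notation): on each
//  iteration choose the maximal violating pair (i, j), where
//      i = argmax { y_k * g_k : y_k * a_k < B_k }
//      j = argmin { y_k * g_k : A_k < y_k * a_k }
//  take a step of size lambda along the feasible direction that increases
//  a_i and decreases a_j, and update the gradient g accordingly; stop when
//  the pair no longer violates the optimality condition (yg[i] <= yg[j]),
//  when lambda becomes negligible, or when the iteration limit is reached.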
SVM* SmoTrainer::train(KernelFunctionCache& K, const LabeledObservationVector& twoClassLabeledObservationVector) {
    const int observationCount = twoClassLabeledObservationVector.size();
    const int featureCount = twoClassLabeledObservationVector[0].second->size();

    if (outputFilter.debug()) { m->mothurOut( "observation count : " + toString(observationCount) ); m->mothurOutEndLine(); }
    if (outputFilter.debug()) { m->mothurOut( "feature count     : " + toString(featureCount) ); m->mothurOutEndLine(); }
    // dual coefficients
    vector<double> a(observationCount, 0.0);
    // gradient
    vector<double> g(observationCount, 1.0);
    // convert the labels to -1.0,+1.0
    vector<double> y(observationCount);
    if (outputFilter.trace()) { m->mothurOut( "assign numeric labels" ); m->mothurOutEndLine(); }
    NumericClassToLabel discriminantToLabel;
    assignNumericLabels(y, twoClassLabeledObservationVector, discriminantToLabel);
    if (outputFilter.trace()) { m->mothurOut( "assign A and B" ); m->mothurOutEndLine(); }
    vector<double> A(observationCount);
    vector<double> B(observationCount);
    Utils util;
    for ( int n = 0; n < observationCount; n++ ) {
        if ( util.isEqual(y[n], +1.0)) {
            A[n] = 0.0;
            B[n] = C;
        }
        else {
            A[n] = -C;
            B[n] = 0;
        }
        if (outputFilter.trace()) { m->mothurOut( toString(n) + " " + toString(A[n]) + " " + toString(B[n]) ); m->mothurOutEndLine(); }
    }
    if (outputFilter.trace()) { m->mothurOut( "assign K" ); m->mothurOutEndLine(); }
    int m_count = 0;
    vector<double> u(3);
    vector<double> ya(observationCount);
    vector<double> yg(observationCount);
    double lambda = numeric_limits<double>::max();
    while ( true ) {

        if (m->getControl_pressed()) { return 0; }

        m_count++;
        int i = 0;
        int j = 0;
        // start from the lowest representable value; numeric_limits<double>::min()
        // is the smallest positive double and would never be replaced by a
        // negative yg value
        double yg_max = -numeric_limits<double>::max();
        double yg_min = numeric_limits<double>::max();
        if (outputFilter.trace()) { m->mothurOut( "m = " + toString(m_count) ); m->mothurOutEndLine(); }
        for ( int k = 0; k < observationCount; k++ ) {
            ya[k] = y[k] * a[k];
            yg[k] = y[k] * g[k];
        }
        if (outputFilter.trace()) {
            m->mothurOut( "yg =");
            for ( int k = 0; k < observationCount; k++ ) {
                m->mothurOut( " " + toString(yg[k]));
            }
            m->mothurOutEndLine();
        }

        for ( int k = 0; k < observationCount; k++ ) {
            if ( ya[k] < B[k] && yg[k] > yg_max ) {
                yg_max = yg[k];
                i = k;
            }
            if ( A[k] < ya[k] && yg[k] < yg_min ) {
                yg_min = yg[k];
                j = k;
            }
        }
        // maximal violating pair is i,j
        if (outputFilter.trace()) {
            m->mothurOut( "maximal violating pair: " + toString(i) + " " + toString(j) ); m->mothurOutEndLine();
            m->mothurOut( "  i = " + toString(i) + " features: ");
            for ( int feature = 0; feature < featureCount; feature++ ) {
                m->mothurOut( toString(twoClassLabeledObservationVector[i].second->at(feature)) + " ");
            }
            m->mothurOutEndLine();
            m->mothurOut( "  j = " + toString(j) + " features: ");
            for ( int feature = 0; feature < featureCount; feature++ ) {
                m->mothurOut( toString(twoClassLabeledObservationVector[j].second->at(feature)) + " ");
            }
            m->mothurOutEndLine();
        }

        // the iteration limit could be parameterized
        if ( m_count > 1000 ) {
            // what happens if we just go with what we've got instead of throwing an exception?
            // things work pretty well for the most part
            // might be better to look at lambda???
            if (outputFilter.debug()) { m->mothurOut( "iteration limit reached with lambda = " + toString(lambda) ); m->mothurOutEndLine(); }
            break;
        }

        // using lambda to break is a good performance enhancement
        if ( yg[i] <= yg[j] || lambda < 0.0001) {
            break;
        }
        u[0] = B[i] - ya[i];
        u[1] = ya[j] - A[j];

        double K_ii = K.similarity(twoClassLabeledObservationVector[i], twoClassLabeledObservationVector[i]);
        double K_jj = K.similarity(twoClassLabeledObservationVector[j], twoClassLabeledObservationVector[j]);
        double K_ij = K.similarity(twoClassLabeledObservationVector[i], twoClassLabeledObservationVector[j]);
        u[2] = (yg[i] - yg[j]) / (K_ii+K_jj-2.0*K_ij);
        if (outputFilter.trace()) { m->mothurOut( "directions: (" + toString(u[0]) + "," + toString(u[1]) + "," + toString(u[2]) + ")" ); m->mothurOutEndLine(); }
        lambda = *min_element(u.begin(), u.end());
        if (outputFilter.trace()) { m->mothurOut( "lambda: " + toString(lambda) ); m->mothurOutEndLine(); }
        for ( int k = 0; k < observationCount; k++ ) {
            double K_ik = K.similarity(twoClassLabeledObservationVector[i], twoClassLabeledObservationVector[k]);
            double K_jk = K.similarity(twoClassLabeledObservationVector[j], twoClassLabeledObservationVector[k]);
            g[k] += (-lambda * y[k] * K_ik + lambda * y[k] * K_jk);
        }
        if (outputFilter.trace()) {
            m->mothurOut( "g =");
            for ( int k = 0; k < observationCount; k++ ) {
                m->mothurOut( " " + toString(g[k]));
            }
            m->mothurOutEndLine();
        }
        a[i] += y[i] * lambda;
        a[j] -= y[j] * lambda;
    }


    // at this point the optimal a's have been found
    // now use them to find w and b
    if (outputFilter.trace()) { m->mothurOut( "find w" ); m->mothurOutEndLine(); }
    vector<double> w(twoClassLabeledObservationVector[0].second->size(), 0.0);
    double b = 0.0;
    for ( int i = 0; i < y.size(); i++ ) {
        if (outputFilter.trace()) { m->mothurOut( "alpha[" + toString(i) + "] = " + toString(a[i]) ); m->mothurOutEndLine(); }
        for ( int j = 0; j < w.size(); j++ ) {
            w[j] += a[i] * y[i] * twoClassLabeledObservationVector[i].second->at(j);
        }
        if ( A[i] < a[i] && a[i] < B[i] ) {
            b = yg[i];
            if (outputFilter.trace()) { m->mothurOut( "b = " + toString(b) ); m->mothurOutEndLine(); }
        }
    }

    if (outputFilter.trace()) {
        for ( int i = 0; i < w.size(); i++ ) {
            m->mothurOut( "w[" + toString(i) + "] = " + toString(w[i]) ); m->mothurOutEndLine();
        }
    }

    // be careful about the lifetime of twoClassLabeledObservationVector: the
    // LabeledObservations copied into supportVectors share pointers to the
    // underlying observation data, so that data must outlive the returned SVM
    //
    // we can eliminate elements of y, a and observation vectors corresponding to a = 0
    vector<double> support_y;
    vector<double> nonzero_a;
    LabeledObservationVector supportVectors;
    for (int i = 0; i < a.size(); i++) {
        // a dual coefficient of 0.0 does not correspond to a support vector
        if ( !util.isEqual(a.at(i), 0.0) ) {
            support_y.push_back(y.at(i));
            nonzero_a.push_back(a.at(i));
            supportVectors.push_back(twoClassLabeledObservationVector.at(i));
        }
    }
    //return new SVM(y, a, twoClassLabeledObservationVector, b, discriminantToLabel);
    if (outputFilter.info()) { m->mothurOut( "found " + toString(supportVectors.size()) + " support vectors\n" ); }
    return new SVM(support_y, nonzero_a, supportVectors, b, discriminantToLabel);
}

typedef map<Label, double> LabelToNumericClassLabel;

// For SVM training we need to assign numeric class labels of -1.0 and +1.0.
// This method populates the y vector argument with -1.0 and +1.0
// corresponding to the two classes in the labelVector argument.
// For example, if labeledObservationVector looks like this:
//     [ (0, "blue",  [...some observations...]),
//       (1, "green", [...some observations...]),
//       (2, "blue",  [...some observations...]) ]
// Then after the function executes the y vector will look like this:
//     [-1.0,   blue
//      +1.0,   green
//      -1.0]   blue
// and discriminantToLabel will look like this:
//     { -1.0 : "blue",
//       +1.0 : "green" }
// The label "blue" is mapped to -1.0 because it is (lexicographically) less than "green".
// When given labels "blue" and "green" this function will always assign "blue" to -1.0 and
// "green" to +1.0.  This is not fundamentally important but it makes testing easier and is
// not a hassle to implement.
void SmoTrainer::assignNumericLabels(vector<double>& y, const LabeledObservationVector& labeledObservationVector, NumericClassToLabel& discriminantToLabel) {
    // it would be nice to assign -1.0 and +1.0 consistently for each pair of labels;
    // the label set is traversed in sorted order (std::set iteration), so we get this for free

    // we are going to overwrite arguments y and discriminantToLabel
    y.clear();
    discriminantToLabel.clear();

    LabelSet labelSet;
    buildLabelSet(labelSet, labeledObservationVector);
    LabelVector uniqueLabels(labelSet.begin(), labelSet.end());
    if (labelSet.size() != 2) {
        cerr << "unexpected label set size " << labelSet.size() << endl;
        for (LabelSet::const_iterator i = labelSet.begin(); i != labelSet.end(); i++) {
            cerr << "    label " << *i << endl;
        }
        throw SmoTrainerException("SmoTrainer::assignNumericLabels was passed a label set whose size is not 2");
    }
    else {
        LabelToNumericClassLabel labelToNumericClassLabel;
        labelToNumericClassLabel[uniqueLabels[0]] = -1.0;
        labelToNumericClassLabel[uniqueLabels[1]] = +1.0;
        for ( LabeledObservationVector::const_iterator i = labeledObservationVector.begin(); i != labeledObservationVector.end(); i++ ) {
            y.push_back( labelToNumericClassLabel[i->first] );
        }
        discriminantToLabel[-1.0] = uniqueLabels[0];
        discriminantToLabel[+1.0] = uniqueLabels[1];
    }
}

// this is a convenience function for getting default parameter ranges for all kernels
void getDefaultKernelParameterRangeMap(KernelParameterRangeMap& kernelParameterRangeMap) {
    ParameterRangeMap linearParameterRangeMap;
    linearParameterRangeMap[SmoTrainer::MapKey_C] = SmoTrainer::defaultCRange;
    linearParameterRangeMap[LinearKernelFunction::MapKey_Constant] = LinearKernelFunction::defaultConstantRange;

    ParameterRangeMap rbfParameterRangeMap;
    rbfParameterRangeMap[SmoTrainer::MapKey_C] = SmoTrainer::defaultCRange;
    rbfParameterRangeMap[RbfKernelFunction::MapKey_Gamma] = RbfKernelFunction::defaultGammaRange;

    ParameterRangeMap polynomialParameterRangeMap;
    polynomialParameterRangeMap[SmoTrainer::MapKey_C] = SmoTrainer::defaultCRange;
    polynomialParameterRangeMap[PolynomialKernelFunction::MapKey_Constant] = PolynomialKernelFunction::defaultConstantRange;
    polynomialParameterRangeMap[PolynomialKernelFunction::MapKey_Coefficient] = PolynomialKernelFunction::defaultCoefficientRange;
    polynomialParameterRangeMap[PolynomialKernelFunction::MapKey_Degree] = PolynomialKernelFunction::defaultDegreeRange;

    ParameterRangeMap sigmoidParameterRangeMap;
    sigmoidParameterRangeMap[SmoTrainer::MapKey_C] = SmoTrainer::defaultCRange;
    sigmoidParameterRangeMap[SigmoidKernelFunction::MapKey_Alpha] = SigmoidKernelFunction::defaultAlphaRange;
    sigmoidParameterRangeMap[SigmoidKernelFunction::MapKey_Constant] = SigmoidKernelFunction::defaultConstantRange;

    kernelParameterRangeMap[LinearKernelFunction::MapKey] = linearParameterRangeMap;
    kernelParameterRangeMap[RbfKernelFunction::MapKey] = rbfParameterRangeMap;
    kernelParameterRangeMap[PolynomialKernelFunction::MapKey] = polynomialParameterRangeMap;
    kernelParameterRangeMap[SigmoidKernelFunction::MapKey] = sigmoidParameterRangeMap;
}
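// Example usage (an illustrative sketch):
//     KernelParameterRangeMap kernelParameterRangeMap;
//     getDefaultKernelParameterRangeMap(kernelParameterRangeMap);
//     // kernelParameterRangeMap["linear"] now maps "smoc" to SmoTrainer::defaultCRange
//     // and "constant" to LinearKernelFunction::defaultConstantRange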


//
// OneVsOneMultiClassSvmTrainer
//
// An instance of OneVsOneMultiClassSvmTrainer is intended to work with a single set of data
// to produce a single instance of MultiClassSVM.  That's why observations and labels go in to
// the constructor.
OneVsOneMultiClassSvmTrainer::OneVsOneMultiClassSvmTrainer(SvmDataset& d, int e, int t, OutputFilter& of) :
        svmDataset(d),
        evaluationFoldCount(e),
        trainFoldCount(t),
        outputFilter(of) {
    buildLabelSet(labelSet, svmDataset.getLabeledObservationVector());
    buildLabelToLabeledObservationVector(labelToLabeledObservationVector, svmDataset.getLabeledObservationVector());
    buildLabelPairSet(labelPairSet, svmDataset.getLabeledObservationVector());
}

void buildLabelSet(LabelSet& labelSet, const LabeledObservationVector& labeledObservationVector) {
    for (LabeledObservationVector::const_iterator i = labeledObservationVector.begin(); i != labeledObservationVector.end(); i++) {
        labelSet.insert(i->first);
    }
}


//  This function uses the LabeledObservationVector argument to populate the LabelPairSet
//  argument with pairs of labels.  For example, if labeledObservationVector looks like this:
//    [ ("blue", x), ("green", y), ("red", z) ]
//  then the labelPairSet will be populated with the following label pairs:
//    ("blue", "green"), ("blue", "red"), ("green", "red")
//  The order of labels in the pairs is determined by the ordering of labels in the temporary
//  LabelSet.  By default this order will be ascending.  However, labels are taken off the
//  temporary labelStack in reverse order, so the labelStack is initialized with reverse iterators.
//  In the end our label pairs will be in sorted order.
void OneVsOneMultiClassSvmTrainer::buildLabelPairSet(LabelPairSet& labelPairSet, const LabeledObservationVector& labeledObservationVector) {
    LabelSet labelSet;
    buildLabelSet(labelSet, labeledObservationVector);
    LabelVector labelStack(labelSet.rbegin(), labelSet.rend());
    while (labelStack.size() > 1) {
        Label label = labelStack.back();
        labelStack.pop_back();
        LabelPair labelPair(2);
        labelPair[0] = label;
        for (LabelVector::const_iterator i = labelStack.begin(); i != labelStack.end(); i++) {
            labelPair[1] = *i;
            labelPairSet.insert(labelPair);
        }
    }
}


// The LabelMatchesEither functor is used only in calls to remove_copy_if in the
// OneVsOneMultiClassSvmTrainer::train method.  It returns true if the labeled
// observation argument matches neither of the two label arguments, so
// remove_copy_if copies only observations carrying one of the two labels.
class LabelMatchesEither {
public:
    LabelMatchesEither(const Label& _label0, const Label& _label1) : label0(_label0), label1(_label1) {}

    bool operator() (const LabeledObservation& o) {
        return !((o.first == label0) || (o.first == label1));
    }

private:
    const Label& label0;
    const Label& label1;
};

MultiClassSVM* OneVsOneMultiClassSvmTrainer::train(const KernelParameterRangeMap& kernelParameterRangeMap) {
    double bestMultiClassSvmScore = 0.0;
    MultiClassSVM* bestMc = NULL;

    KernelFunctionFactory kernelFunctionFactory(svmDataset.getLabeledObservationVector());

    // first divide the data into a 'development' set for tuning hyperparameters
    // and an 'evaluation' set for measuring performance
    int evaluationFoldNumber = 0;
    KFoldLabeledObservationsDivider kFoldDevEvalDivider(evaluationFoldCount, svmDataset.getLabeledObservationVector());
    for ( kFoldDevEvalDivider.start(); !kFoldDevEvalDivider.end(); kFoldDevEvalDivider.next() ) {
        const LabeledObservationVector& developmentObservations = kFoldDevEvalDivider.getTrainingData();
        const LabeledObservationVector& evaluationObservations  = kFoldDevEvalDivider.getTestingData();

        evaluationFoldNumber++;
        if ( outputFilter.debug() ) {
            m->mothurOut( "evaluation fold " + toString(evaluationFoldNumber) + " of " + toString(evaluationFoldCount) ); m->mothurOutEndLine();
        }

        vector<SVM*> twoClassSvmList;
        SvmToSvmPerformanceSummary svmToSvmPerformanceSummary;
        SmoTrainer smoTrainer(outputFilter);
        LabelPairSet::iterator labelPair;
        for (labelPair = labelPairSet.begin(); labelPair != labelPairSet.end(); labelPair++) {
            // generate training and testing data for this label pair
            Label label0 = (*labelPair)[0];
            Label label1 = (*labelPair)[1];
            if ( outputFilter.debug() ) {
                m->mothurOut("training SVM on labels " + toString(label0) + " and " + toString(label1) ); m->mothurOutEndLine();
            }

            double bestMeanScoreOnKFolds = 0.0;
            ParameterMap bestParameterMap;
            string bestKernelFunctionKey;
            LabeledObservationVector twoClassDevelopmentObservations;
            LabelMatchesEither labelMatchesEither(label0, label1);
            remove_copy_if(
                developmentObservations.begin(),
                developmentObservations.end(),
                back_inserter(twoClassDevelopmentObservations),
                labelMatchesEither
            );
            KFoldLabeledObservationsDivider kFoldLabeledObservationsDivider(trainFoldCount, twoClassDevelopmentObservations);
            // loop on kernel functions and kernel function parameters
            for ( KernelParameterRangeMap::const_iterator kmap = kernelParameterRangeMap.begin(); kmap != kernelParameterRangeMap.end(); kmap++ ) {
                string kernelFunctionKey = kmap->first;
                KernelFunction& kernelFunction = kernelFunctionFactory.getKernelFunctionForKey(kmap->first);
                ParameterSetBuilder p(kmap->second);
                for (ParameterMapVector::const_iterator hp = p.getParameterSetList().begin(); hp != p.getParameterSetList().end(); hp++) {
                    kernelFunction.setParameters(*hp);
                    KernelFunctionCache kernelFunctionCache(kernelFunction, svmDataset.getLabeledObservationVector());
                    smoTrainer.setParameters(*hp);
                    if (outputFilter.debug()) {
                        m->mothurOut( "parameters for " + toString(kernelFunctionKey) + " kernel" ); m->mothurOutEndLine();
                        for ( ParameterMap::const_iterator i = hp->begin(); i != hp->end(); i++ ) {
                            m->mothurOut( "    " + toString(i->first) + ":" + toString(i->second) ); m->mothurOutEndLine();
                        }
                    }
                    double meanScoreOnKFolds = trainOnKFolds(smoTrainer, kernelFunctionCache, kFoldLabeledObservationsDivider);
                    if ( meanScoreOnKFolds > bestMeanScoreOnKFolds ) {
                        bestMeanScoreOnKFolds = meanScoreOnKFolds;
                        bestParameterMap = *hp;
                        bestKernelFunctionKey = kernelFunctionKey;
                    }
                }
            }
            Utils util;
            if ( util.isEqual(bestMeanScoreOnKFolds, 0.0) ) {
                m->mothurOut( "failed to train SVM on labels " + toString(label0) + " and " + toString(label1) ); m->mothurOutEndLine();
                throw exception();
            }
            else {
                if ( outputFilter.debug() ) {
                    m->mothurOut( "trained SVM on labels " + label0 + " and " + label1 ); m->mothurOutEndLine();
                    m->mothurOut( "    best mean score over " + toString(trainFoldCount) + " folds is " + toString(bestMeanScoreOnKFolds) ); m->mothurOutEndLine();
                    m->mothurOut( "    best parameters for " + bestKernelFunctionKey + " kernel" ); m->mothurOutEndLine();
                    for ( ParameterMap::const_iterator p = bestParameterMap.begin(); p != bestParameterMap.end(); p++ ) {
                        m->mothurOut( "        "  + toString(p->first) + " : " + toString(p->second) ); m->mothurOutEndLine();
                    }
                }

                // retrain on all of the development observations for this label pair
                // using the best kernel function and parameters found above
                LabelMatchesEither labelMatchesEither(label0, label1);
                LabeledObservationVector twoClassDevelopmentObservations;
                remove_copy_if(
                    developmentObservations.begin(),
                    developmentObservations.end(),
                    back_inserter(twoClassDevelopmentObservations),
                    labelMatchesEither
                );
                if (outputFilter.info()) {
                    m->mothurOut( "training final SVM with " + toString(twoClassDevelopmentObservations.size()) + " labeled observations" ); m->mothurOutEndLine();
                    for ( ParameterMap::const_iterator i = bestParameterMap.begin(); i != bestParameterMap.end(); i++ ) {
                        m->mothurOut( "    " + toString(i->first) + ":" + toString(i->second) ); m->mothurOutEndLine();
                    }
                }

                KernelFunction& kernelFunction = kernelFunctionFactory.getKernelFunctionForKey(bestKernelFunctionKey);
                kernelFunction.setParameters(bestParameterMap);
                smoTrainer.setParameters(bestParameterMap);
                KernelFunctionCache kernelFunctionCache(kernelFunction, svmDataset.getLabeledObservationVector());
                SVM* svm = smoTrainer.train(kernelFunctionCache, twoClassDevelopmentObservations);

                twoClassSvmList.push_back(svm);
                // record a performance summary using the evaluation dataset
                LabeledObservationVector twoClassEvaluationObservations;
                remove_copy_if(
                    evaluationObservations.begin(),
                    evaluationObservations.end(),
                    back_inserter(twoClassEvaluationObservations),
                    labelMatchesEither
                );
                SvmPerformanceSummary p(*svm, twoClassEvaluationObservations);
                svmToSvmPerformanceSummary[svm->getLabelPair()] = p;
            }
        }

        MultiClassSVM* mc = new MultiClassSVM(twoClassSvmList, labelSet, svmToSvmPerformanceSummary, outputFilter);
        //double score = mc->score(evaluationObservations);
        mc->setAccuracy(evaluationObservations);
        if ( outputFilter.debug() ) {
            m->mothurOut( "fold " + toString(evaluationFoldNumber) + " multiclass SVM score: " + toString(mc->getAccuracy()) ); m->mothurOutEndLine();
        }
        if ( mc->getAccuracy() > bestMultiClassSvmScore ) {
            bestMc = mc;
            bestMultiClassSvmScore = mc->getAccuracy();
        }
        else {
            delete mc;
        }
    }

    if ( outputFilter.info() ) {
        m->mothurOut( "best multiclass SVM has score " + toString(bestMc->getAccuracy()) ); m->mothurOutEndLine();
    }

    return bestMc;
}

//SvmTrainingInterruptedException multiClassSvmTrainingInterruptedException("one-vs-one multiclass SVM training interrupted by user");

double OneVsOneMultiClassSvmTrainer::trainOnKFolds(SmoTrainer& smoTrainer, KernelFunctionCache& kernelFunction, KFoldLabeledObservationsDivider& kFoldLabeledObservationsDivider) {
    double online_mean_n = 0.0;
    double online_mean_score = 0.0;
    double meanScoreOverKFolds = -1.0;  // -1.0 signals that no SVM was successfully trained

    for ( kFoldLabeledObservationsDivider.start(); !kFoldLabeledObservationsDivider.end(); kFoldLabeledObservationsDivider.next() ) {
        const LabeledObservationVector& kthTwoClassTrainingFold = kFoldLabeledObservationsDivider.getTrainingData();
        const LabeledObservationVector& kthTwoClassTestingFold = kFoldLabeledObservationsDivider.getTestingData();
        if (outputFilter.info()) {
            m->mothurOut( "fold " + toString(kFoldLabeledObservationsDivider.getFoldNumber()) + " training data has " + toString(kthTwoClassTrainingFold.size()) + " labeled observations" ); m->mothurOutEndLine();
            m->mothurOut( "fold " + toString(kFoldLabeledObservationsDivider.getFoldNumber()) + " testing data has " + toString(kthTwoClassTestingFold.size()) + " labeled observations" ); m->mothurOutEndLine();
        }
        if (m->getControl_pressed()) { return 0; }
        else {
            try {
                if (outputFilter.debug()) { m->mothurOut( "begin training" ); m->mothurOutEndLine(); }

                SVM* evaluationSvm = smoTrainer.train(kernelFunction, kthTwoClassTrainingFold);
                SvmPerformanceSummary svmPerformanceSummary(*evaluationSvm, kthTwoClassTestingFold);
                double score = evaluationSvm->score(kthTwoClassTestingFold);
                //double score = svmPerformanceSummary.getAccuracy();
                if (outputFilter.debug()) {
                    m->mothurOut( "score on fold " + toString(kFoldLabeledObservationsDivider.getFoldNumber()) + " of test data is " + toString(score) ); m->mothurOutEndLine();
                    m->mothurOut( "positive label: " + toString(svmPerformanceSummary.getPositiveClassLabel()) ); m->mothurOutEndLine();
                    m->mothurOut( "negative label: " + toString(svmPerformanceSummary.getNegativeClassLabel()) ); m->mothurOutEndLine();
                    m->mothurOut( "  precision: " + toString(svmPerformanceSummary.getPrecision())
                              + "     recall: " + toString(svmPerformanceSummary.getRecall())
                              + "          f: " + toString(svmPerformanceSummary.getF())
                              + "   accuracy: " + toString(svmPerformanceSummary.getAccuracy())
                              ); m->mothurOutEndLine();
                }
                // update the running mean score incrementally
                online_mean_n += 1.0;
                double online_mean_delta = score - online_mean_score;
                online_mean_score += online_mean_delta / online_mean_n;
                meanScoreOverKFolds = online_mean_score;

                delete evaluationSvm;
            }
            catch ( exception& e ) {
                m->mothurOut( "exception: " + toString(e.what()) ); m->mothurOutEndLine();
                m->mothurOut( "    on fold " + toString(kFoldLabeledObservationsDivider.getFoldNumber()) + " failed to train SVM with C = " + toString(smoTrainer.getC()) ); m->mothurOutEndLine();
            }
        }
    }
    if (outputFilter.debug()) {
        m->mothurOut( "done with cross validation on C = " + toString(smoTrainer.getC()) ); m->mothurOutEndLine();
        m->mothurOut( "    mean score over " + toString(kFoldLabeledObservationsDivider.getFoldNumber()) + " folds is " + toString(meanScoreOverKFolds) ); m->mothurOutEndLine();
    }
    if ( meanScoreOverKFolds <= 0.0 ) { m->mothurOut( "failed to train SVM with C = " + toString(smoTrainer.getC()) + "\n");  }
    return meanScoreOverKFolds;
}


class UnrankedFeature {
public:
    UnrankedFeature(const Feature& f) : feature(f), rankingCriterion(0.0) {}
    ~UnrankedFeature() {}

    Feature getFeature() const { return feature; }

    double getRankingCriterion() const { return rankingCriterion; }
    void setRankingCriterion(double rc) { rankingCriterion = rc; }

private:
    Feature feature;
    double rankingCriterion;
};

bool lessThanRankingCriterion(const UnrankedFeature& a, const UnrankedFeature& b) {
    return a.getRankingCriterion() < b.getRankingCriterion();
}

bool lessThanFeatureIndex(const UnrankedFeature& a, const UnrankedFeature& b) {
    return a.getFeature().getFeatureIndex() < b.getFeature().getFeatureIndex();
}

typedef list<UnrankedFeature> UnrankedFeatureList;


// Only the linear SVM can meaningfully be used here, since the ranking
// criterion below reconstructs an explicit weight vector; right now any
// kernel can be sent in, so consider allowing only parameter ranges as
// arguments.  It would be useful to remove more than one feature at a time,
// and the last two arguments might be combined into one.
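//
// A sketch of the ranking criterion used below: for each two-class linear SVM
// the weight vector is w = sum_j a_j * y_j * x_j (summing over the support
// vectors), and the criterion for feature i is the sum of w_i^2 over all of
// the pairwise SVMs.  Features with the smallest criterion contribute least
// to the decision functions and are eliminated first, as in SVM-RFE
// (Guyon et al., recursive feature elimination).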
RankedFeatureList SvmRfe::getOrderedFeatureList(SvmDataset& svmDataset, OneVsOneMultiClassSvmTrainer& t, const ParameterRange& linearKernelConstantRange, const ParameterRange& smoTrainerParameterRange) {

    KernelParameterRangeMap rfeKernelParameterRangeMap;
    ParameterRangeMap linearParameterRangeMap;
    linearParameterRangeMap[SmoTrainer::MapKey_C] = smoTrainerParameterRange;
    linearParameterRangeMap[LinearKernelFunction::MapKey_Constant] = linearKernelConstantRange;

    rfeKernelParameterRangeMap[LinearKernelFunction::MapKey] = linearParameterRangeMap;

    // the rankedFeatureList is empty at first
    RankedFeatureList rankedFeatureList;
    // loop until all but one feature have been eliminated
    // no need to eliminate the last feature, after all
    int svmRfeRound = 0;
    //while ( rankedFeatureList.size() < (svmDataset.getFeatureVector().size()-1) ) {
    while ( svmDataset.getFeatureVector().size() > 1 ) {
        svmRfeRound++;
        m->mothurOut( "SVM-RFE round " + toString(svmRfeRound) + ":" ); m->mothurOutEndLine();
        UnrankedFeatureList unrankedFeatureList;
        for (int featureIndex = 0; featureIndex < svmDataset.getFeatureVector().size(); featureIndex++) {
            Feature f = svmDataset.getFeatureVector().at(featureIndex);
            unrankedFeatureList.push_back(UnrankedFeature(f));
        }
        m->mothurOut( toString(unrankedFeatureList.size()) + " unranked features" ); m->mothurOutEndLine();

        MultiClassSVM* s = t.train(rfeKernelParameterRangeMap);
        m->mothurOut( "multiclass SVM accuracy: " + toString(s->getAccuracy()) ); m->mothurOutEndLine();

        m->mothurOut( "two-class SVM performance" ); m->mothurOutEndLine();
        m->mothurOut("class 1\tclass 2\tprecision\trecall\tf\taccuracy\n");
        for ( SvmVector::const_iterator svm = s->getSvmList().begin(); svm != s->getSvmList().end(); svm++ ) {
            SvmPerformanceSummary sps = s->getSvmPerformanceSummary(**svm);
            m->mothurOut(toString(sps.getPositiveClassLabel()) + "\t"
                      + toString(sps.getNegativeClassLabel()) + "\t"
                      + toString(sps.getPrecision()) + "\t"
                      + toString(sps.getRecall()) + "\t"
                      + toString(sps.getF()) + "\t"
                      + toString(sps.getAccuracy()) ); m->mothurOutEndLine();
        }
        // calculate the 'ranking criterion' for each (remaining) feature using each binary svm
        for (UnrankedFeatureList::iterator f = unrankedFeatureList.begin(); f != unrankedFeatureList.end(); f++) {
            const int i = f->getFeature().getFeatureIndex();
            // rankingCriterion combines feature weights for feature i in all svms
            double rankingCriterion = 0.0;
            for ( SvmVector::const_iterator svm = s->getSvmList().begin(); svm != s->getSvmList().end(); svm++ ) {
                // calculate the weight wi of feature i for this svm
                double wi = 0.0;
                for (int j = 0; j < (*svm)->x.size(); j++) {
                    // all support vectors contribute to wi
                    wi += (*svm)->a.at(j) * (*svm)->y.at(j) * (*svm)->x.at(j).second->at(i);
                }
                // accumulate weights for feature i from all svms
                rankingCriterion += pow(wi, 2);
            }
            // update the (unranked) feature ranking criterion
            f->setRankingCriterion(rankingCriterion);
        }
        delete s;

        // sort the unranked features by ranking criterion
        unrankedFeatureList.sort(lessThanRankingCriterion);

        // eliminating the bottom 1/(n+1) features is very slow but gives good results:
        ////int eliminateFeatureCount = ceil(unrankedFeatureList.size() / (iterationCount+1.0));
        // eliminating the bottom 1/4 of features is fast and gives results only
        // slightly different from the above
        int eliminateFeatureCount = ceil(unrankedFeatureList.size() / 4.0);
        m->mothurOut( "eliminating " + toString(eliminateFeatureCount) + " feature(s) of " + toString(unrankedFeatureList.size()) + " total features\n");
        m->mothurOutEndLine();
        UnrankedFeatureList featuresToEliminate;
        for ( int i = 0; i < eliminateFeatureCount; i++ ) {
            // remove the lowest ranked feature(s) from the list of unranked features
            UnrankedFeature unrankedFeature = unrankedFeatureList.front();
            unrankedFeatureList.pop_front();

            featuresToEliminate.push_back(unrankedFeature);
            // put the lowest ranked feature at the front of the list of ranked features
            // the first feature to be eliminated will be at the back of this list
            // the last feature to be eliminated will be at the front of this list
            rankedFeatureList.push_front(RankedFeature(unrankedFeature.getFeature(), svmRfeRound));
        }

        // remove the eliminated features in descending index order so that the
        // stored indices of features not yet removed remain valid
        featuresToEliminate.sort(lessThanFeatureIndex);
        featuresToEliminate.reverse();
        for (UnrankedFeatureList::iterator g = featuresToEliminate.begin(); g != featuresToEliminate.end(); g++) {
            Feature unrankedFeature = g->getFeature();
            removeFeature(unrankedFeature, svmDataset.getLabeledObservationVector(), svmDataset.getFeatureVector());
        }

    }

    // there may be one feature left
    svmRfeRound++;
    for ( FeatureVector::iterator f = svmDataset.getFeatureVector().begin(); f != svmDataset.getFeatureVector().end(); f++ ) {
        rankedFeatureList.push_front(RankedFeature(*f, svmRfeRound));
    }

    return rankedFeatureList;
}