//
//  svm.hpp
//  support vector machine
//
//  Created by Joshua Lynch on 6/19/2013.
//  Copyright (c) 2013 Schloss Lab. All rights reserved.
//
8 
9 #ifndef svm_hpp_
10 #define svm_hpp_
11 
12 
#include <algorithm>
#include <cmath>
#include <deque>
#include <exception>
#include <limits>
#include <list>
#include <map>
#include <numeric>
#include <set>
#include <sstream>
#include <stack>
#include <stdexcept>
#include <string>
#include <vector>
#include "mothurout.h"
#include "utils.hpp"
25 
// For the purpose of training a support vector machine
// we need to calculate a dot product between two feature
// vectors.  In general these feature vectors are not
// restricted to lists of doubles, but in this implementation
// feature vectors (or 'observations' as they will be called from here on)
// will be vectors of doubles.
typedef std::vector<double> Observation;

// A dataset is a collection of labeled observations.
// The ObservationVector typedef is a vector
// of pointers to Observations.  Pointers are used here since
// datasets will be rearranged many times during cross validation.
// Using pointers to Observations makes copying the elements of
// an ObservationVector cheap.
typedef std::vector<Observation*> ObservationVector;

// Training a support vector machine requires labeled data.  The
// Label typedef defines what will constitute a class 'label' in
// this implementation.
typedef std::string Label;
typedef std::vector<Label> LabelVector;
typedef std::set<Label> LabelSet;

// Pairs of class labels are important because a support vector machine
// can only learn two classes of data.  The LabelPair typedef is a vector
// even though a pair might seem more natural, but it is useful to
// iterate over the pair.
typedef std::vector<Label> LabelPair;
LabelPair buildLabelPair(const Label& one, const Label& two);

// Learning to classify a dataset with more than two classes requires
// training a separate support vector machine for each pair of classes.
// The LabelPairSet typedef defines a container for the collection of
// all unique label pairs for a set of data.
typedef std::set<LabelPair> LabelPairSet;

// A dataset is a set of observations with associated labels.
// LabeledObservation (below) began life as a
// typedef pair<Label, Observation*> and was promoted to a class in
// order to add the int datasetIndex member, which is used to implement
// kernel function caching.  Using a pointer to Observation makes these
// objects cheap to copy.
83 class LabeledObservation {
84 public:
LabeledObservation(int _datasetIndex,Label _label,Observation * _o)85     LabeledObservation(int _datasetIndex, Label _label, Observation* _o) : datasetIndex(_datasetIndex), first(_label), second(_o) {}
~LabeledObservation()86     ~LabeledObservation() {}
87 
removeFeatureAtIndex(int n)88     void removeFeatureAtIndex(int n) {
89         int m = 0;
90         Observation::iterator i = second->begin();
91         while ( m < n ) {
92             i++;
93             m++;
94         }
95         second->erase(i);
96     }
97 
getDatasetIndex() const98     int getDatasetIndex()         const   { return datasetIndex; }
getLabel() const99     Label getLabel()              const   { return first; }
getObservation() const100     Observation* getObservation() const   { return second; }
101 
102 //private:
103     int datasetIndex;
104     Label first;
105     Observation* second;
106 };
107 
108 
109 
110 // A LabeledObservationVector is a container for an entire dataset (or a
111 // subset of an entire dataset).
112 typedef vector<LabeledObservation> LabeledObservationVector;
113 void buildLabelSet(LabelSet&, const LabeledObservationVector&);
114 
115 
116 double getMinimumFeatureValueForObservation(Observation::size_type featureIndex, LabeledObservationVector& observations);
117 double getMaximumFeatureValueForObservation(Observation::size_type featureIndex, LabeledObservationVector& observations);
118 
119 
120 void transformZeroOne(LabeledObservationVector&);
121 void transformZeroMeanUnitVariance(LabeledObservationVector&);
122 
123 
// A Feature couples a feature's column index in the dataset with a
// human-readable label.
class Feature {
public:
    Feature(int i, const std::string& l) : index(i), label(l) {}
    Feature(const Feature& f) : index(f.index), label(f.label) {}
    ~Feature() {}

    int getFeatureIndex() const { return index; }
    // The index must be kept current as columns are removed (e.g. by SVM-RFE).
    void setFeatureIndex(int i) { index = i; }
    std::string getFeatureLabel() const { return label; }

private:
    int index;
    std::string label;
};
138 
139 typedef list<Feature> FeatureList;
140 typedef vector<Feature> FeatureVector;
141 
142 // might make sense for this to be a member function of SvmDataset
143 FeatureVector applyStdThreshold(double, LabeledObservationVector&, FeatureVector&);
144 
145 
146 //  A RankedFeature is just a Feature and a its associated 'rank', where
147 //  rank is the SVM-RFE iteration during which the feature was eliminated.
148 //  If the SVM-RFE method eliminates multiple features in an iteration
149 //  then some features will have the same rank.
150 class RankedFeature {
151 public:
RankedFeature(const Feature & f,int r)152     RankedFeature(const Feature& f, int r) : feature(f), rank(r) {}
~RankedFeature()153     ~RankedFeature() {}
154 
getFeature() const155     Feature getFeature() const { return feature; }
getRank() const156     int getRank() const { return rank; }
157 
158 private:
159     Feature feature;
160     int rank;
161 };
162 
163 typedef list<RankedFeature> RankedFeatureList;
164 
165 
166 // The SvmDataset class encapsulates labeled observations and feature information.
167 // All data required to train SVMs is found in SvmDataset.
168 class SvmDataset {
169 public:
SvmDataset(const LabeledObservationVector & v,const FeatureVector & f)170     SvmDataset(const LabeledObservationVector& v, const FeatureVector& f) : labeledObservationVector(v), featureVector(f) {}
~SvmDataset()171     ~SvmDataset() {}
172 
getLabeledObservationVector()173     LabeledObservationVector& getLabeledObservationVector() { return labeledObservationVector; }
getFeatureVector()174     FeatureVector& getFeatureVector() { return featureVector; }
175 
removeFeature(const Feature feature)176     void removeFeature(const Feature feature) {
177 
178     }
179 
180 private:
181     LabeledObservationVector labeledObservationVector;
182     FeatureVector featureVector;
183 };
184 
185 
//
//  0 - print no optional information (quiet)
//  1 - print minimum optional information (info)
//  2 - print a little more optional information (debug)
//  3 - print the maximum amount of optional information (trace)
//
// OutputFilter gates optional console output: construct with one of the
// verbosity constants below and query info()/debug()/trace() before printing.
class OutputFilter {
public:
    OutputFilter(int v) : verbosity(v) {}
    OutputFilter(const OutputFilter& of) : verbosity(of.verbosity) {}
    ~OutputFilter() {}

    bool info()  const { return verbosity >= INFO; }
    bool debug() const { return verbosity >= mDEBUG; }
    bool trace() const { return verbosity >= TRACE; }

    // verbosity constants, defined in svm.cpp
    // ('mDEBUG' presumably avoids a collision with a DEBUG macro)
    static const int QUIET;
    static const int INFO;
    static const int mDEBUG;
    static const int TRACE;

private:
    // const member: copy assignment is implicitly deleted, copy construction is fine
    const int verbosity;
};
210 
211 
212 // Dividing a dataset into training and testing sets while maintaing equal
213 // representation of all classes is done using a LabelToLabeledObservationVector.
214 // This container is used to divide datasets into groups of LabeledObservations
215 // having the same label.  For example, given a LabeledObservationVector like
216 //     ["blue",  [1.0, 2.0, 3.0]]
217 //     ["green", [3.0, 4.0, 5.0]]
218 //     ["blue",  [2,0, 3.0. 4.0]]
219 //     ["green", [4.0, 5.0, 6.0]]
220 // the corresponding LabelToLabeledObservationVector looks like
221 //     "blue",  [["blue",  [1.0, 2.0, 3.0]], ["blue",  [2,0, 3.0. 4.0]]]
222 //     "green", [["green", [3.0, 4.0, 5.0]], ["green", [4.0, 5.0, 6.0]]]
223 typedef map<Label, LabeledObservationVector> LabelToLabeledObservationVector;
224 void buildLabelToLabeledObservationVector(LabelToLabeledObservationVector&, const LabeledObservationVector&);
225 
226 // A support vector machine uses +1 and -1 in calculations to represent
227 // the two classes of data it is trained to distinguish.  The NumericClassToLabel
228 // container is used to record the labels associated with these integers.
229 // For a dataset with labels "blue" and "green" a NumericClassToLabel map looks like
230 //     1, "blue"
231 //    -1, "green"
232 typedef map<int, Label> NumericClassToLabel;
233 void buildNumericClassToLabelMap(LabelPair);
234 
235 typedef double Parameter;
236 typedef string ParameterName;
237 typedef vector<double> ParameterRange;
238 typedef map<ParameterName, ParameterRange> ParameterRangeMap;
239 
240 typedef map<string, ParameterRangeMap> KernelParameterRangeMap;
241 void getDefaultKernelParameterRangeMap(KernelParameterRangeMap& kernelParameterRangeMap);
242 
243 typedef map<ParameterName, Parameter> ParameterMap;
244 typedef vector<ParameterMap> ParameterMapVector;
245 typedef stack<Parameter> ParameterStack;
246 
247 class ParameterSetBuilder {
248 public:
249     // If the argument ParameterRangeMap looks like this:
250     //     { "a" : [1.0, 2.0], "b" : [-1.0, 1.0], "c" : [0.5, 0.6] }
251     // then the list of parameter sets looks like this:
252     //     [ {"a":1.0, "b":-1.0, "c":0.5},
253     //       {"a":1.0, "b":-1.0, "c":0.6},
254     //       {"a":1.0, "b": 1.0, "c":0.5},
255     //       {"a":1.0, "b": 1.0, "c":0.6},
256     //       {"a":2.0, "b":-1.0, "c":0.5},
257     //       {"a":2.0, "b":-1.0, "c":0.6},
258     //       {"a":2.0, "b": 1.0, "c":0.5},
259     //       {"a":2.0, "b": 1.0, "c":0.6},
260     //     ]
ParameterSetBuilder(const ParameterRangeMap & parameterRangeMap)261     ParameterSetBuilder(const ParameterRangeMap& parameterRangeMap) {
262         // a small step toward quieting down this code
263         bool verbose = false;
264 
265         stack<pair<ParameterName, ParameterStack> > stackOfParameterRanges;
266         stack<pair<ParameterName, ParameterStack> > stackOfEmptyParameterRanges;
267         ParameterMap nextParameterSet;
268         int parameterSetCount = 1;
269         for ( ParameterRangeMap::const_iterator i = parameterRangeMap.begin(); i != parameterRangeMap.end(); i++ ) {
270             parameterSetCount *= i->second.size();
271             ParameterName parameterName = i->first;
272             ParameterStack emptyParameterStack;
273             stackOfEmptyParameterRanges.push(make_pair(parameterName, emptyParameterStack));
274         }
275         // get started
276         for ( int n = 0; n < parameterSetCount; n++ ) {
277 
278             if (verbose) m->mothurOut("n = " + toString(n) ); m->mothurOutEndLine();
279 
280             // pull empty stacks off until there are no empty stacks
281             while ( stackOfParameterRanges.size() > 0 and stackOfParameterRanges.top().second.size() == 0 ) {
282 
283                 if (verbose) m->mothurOut("  empty parameter range: " + stackOfParameterRanges.top().first); m->mothurOutEndLine();
284 
285                 stackOfEmptyParameterRanges.push(stackOfParameterRanges.top());
286                 stackOfParameterRanges.pop();
287             }
288 
289             // move to the next value for the parameter at the top of the stackOfParameterRanges
290             if ( stackOfParameterRanges.size() > 0 ) {
291                 if (verbose) {
292                     m->mothurOut( "  moving to next value for parameter " + toString(stackOfParameterRanges.top().first) ); m->mothurOutEndLine();
293                     m->mothurOut( "    next value is "  + toString(stackOfParameterRanges.top().second.top()) ); m->mothurOutEndLine();
294                 }
295                 ParameterName parameterName = stackOfParameterRanges.top().first;
296                 nextParameterSet[parameterName] = stackOfParameterRanges.top().second.top();
297                 stackOfParameterRanges.top().second.pop();
298             }
299             if (verbose) m->mothurOut( "stack of empty parameter ranges has size " + toString(stackOfEmptyParameterRanges.size() ) ); m->mothurOutEndLine();
300             // reset each parameter range that has been exhausted
301             while ( stackOfEmptyParameterRanges.size() > 0 ) {
302                 ParameterName parameterName = stackOfEmptyParameterRanges.top().first;
303                 if (verbose) m->mothurOut( "  reseting range for parameter " + toString(stackOfEmptyParameterRanges.top().first) ); m->mothurOutEndLine();
304                 stackOfParameterRanges.push(stackOfEmptyParameterRanges.top());
305                 stackOfEmptyParameterRanges.pop();
306                 const ParameterRange& parameterRange = parameterRangeMap.find(parameterName)->second;
307                 // it is nice to have the parameters used in order smallest to largest
308                 // so that we choose the smallest in ties
309                 // but we will not enforce this so users can specify parameters in the order they like
310                 // this loop will use parameters in the order they are found in the parameter range
311                 for (ParameterRange::const_reverse_iterator i = parameterRange.rbegin(); i != parameterRange.rend(); i++ ) {
312                     stackOfParameterRanges.top().second.push(*i);
313                 }
314                 nextParameterSet[parameterName] = stackOfParameterRanges.top().second.top();
315                 stackOfParameterRanges.top().second.pop();
316             }
317             parameterSetVector.push_back(nextParameterSet);
318             // print out the next parameter set
319             if (verbose) {
320                 for (ParameterMap::iterator p = nextParameterSet.begin(); p != nextParameterSet.end(); p++) {
321                     m->mothurOut(toString(p->first) + " : " + toString(p->second) ); m->mothurOutEndLine();
322                 }
323             }
324         }
325     }
~ParameterSetBuilder()326     ~ParameterSetBuilder() {}
327 
getParameterSetList()328     const ParameterMapVector& getParameterSetList() { return parameterSetVector; }
329 
330 private:
331     ParameterMapVector parameterSetVector;
332     MothurOut* m;
333 };
334 
335 
// RowCache lazily computes and stores a square (d x d) table of doubles,
// one full row at a time.  Derived classes supply the per-element
// computation by implementing calculateValueForCache().
// NOTE(review): instances own the row pointers and are not safe to copy;
// copying would double-delete the rows.
class RowCache {
public:
    // Allocate a d-element table of row pointers; no row is computed yet.
    RowCache(int d) { //: cache(d, NULL)
        for (int i = 0; i < d; i++) {  cache.push_back(NULL);  }
    }

    virtual ~RowCache() {
        // delete only the rows that were actually created
        for (int i = 0; i < (int)cache.size(); i++) {
            if ( !rowNotCached(i) ) {
                delete cache[i];
            }
        }
    }

    // Return element (i, j), computing all of row i on first access.
    double getCachedValue(int i, int j) {
        if ( rowNotCached(i) ) {
            createRow(i);
        }
        return cache.at(i)->at(j);
    }

    // Fill row i; signaling_NaN makes an accidental read of an
    // uninitialized slot easy to detect while debugging.
    void createRow(int i) {
        cache[i] = new std::vector<double>(cache.size(), std::numeric_limits<double>::signaling_NaN());
        for ( int v = 0; v < (int)cache.size(); v++ ) {
            cache.at(i)->at(v) = calculateValueForCache(i, v);
        }
    }

    // a NULL row pointer marks a row that has not been computed yet
    bool rowNotCached(int i) {
        return cache[i] == NULL;
    }

    // compute the value stored at (row, column); implemented by derived classes
    virtual double calculateValueForCache(int, int) = 0;

private:
    std::vector<std::vector<double>* > cache;
};
373 
374 
375 class InnerProductRowCache : public RowCache {
376 public:
InnerProductRowCache(const LabeledObservationVector & _obs)377     InnerProductRowCache(const LabeledObservationVector& _obs) : obs(_obs), RowCache(_obs.size()) {}
~InnerProductRowCache()378     virtual ~InnerProductRowCache() {}
379 
getInnerProduct(const LabeledObservation & obs_i,const LabeledObservation & obs_j)380     double getInnerProduct(const LabeledObservation& obs_i, const LabeledObservation& obs_j) {
381         return getCachedValue(
382             obs_i.datasetIndex,
383             obs_j.datasetIndex
384         );
385     }
386 
calculateValueForCache(int i,int j)387     double calculateValueForCache(int i, int j) {
388         return inner_product(obs[i].second->begin(), obs[i].second->end(), obs[j].second->begin(), 0.0);
389     }
390 
391 private:
392     const LabeledObservationVector& obs;
393 };
394 
395 
396 // The KernelFunction class caches a partial kernel value that does not depend on kernel parameters.
397 class KernelFunction {
398 public:
399     //KernelFunction(const LabeledObservationVector& _obs, InnerProductCache& _ipc) : obs(_obs), innerProductRowCache(_ipc) {}
400 
KernelFunction(const LabeledObservationVector & _obs)401     KernelFunction(const LabeledObservationVector& _obs) :
402         obs(_obs),
403         cache(_obs.size(), NULL) {}
404 
~KernelFunction()405     virtual ~KernelFunction() {
406         for (int i = 0; i < cache.size(); i++) {
407             if ( !rowNotCached(i) ) {
408                 delete cache[i];
409             }
410         }
411     }
412 
413     virtual double similarity(const LabeledObservation&, const LabeledObservation&) = 0;
414     virtual void setParameters(const ParameterMap&) = 0;
415     virtual void getDefaultParameterRanges(ParameterRangeMap&) = 0;
416 
417     virtual double calculateParameterFreeSimilarity(const LabeledObservation&, const LabeledObservation&) = 0;
418 
getCachedParameterFreeSimilarity(const LabeledObservation & obs_i,const LabeledObservation & obs_j)419     double getCachedParameterFreeSimilarity(const LabeledObservation& obs_i, const LabeledObservation& obs_j) {
420         const int i = obs_i.datasetIndex;
421         const int j = obs_j.datasetIndex;
422 
423         if ( rowNotCached(i) ) {
424             cache[i] = new vector<double>(obs.size(), numeric_limits<double>::signaling_NaN());
425             for ( int v = 0; v < obs.size(); v++ ) {
426                 cache.at(i)->at(v) = calculateParameterFreeSimilarity(obs[i], obs[v]);
427             }
428         }
429         return cache.at(i)->at(j);
430     }
431 
rowNotCached(int i)432     bool rowNotCached(int i) {
433         return cache[i] == NULL;
434     }
435 
436 private:
437     const LabeledObservationVector& obs;
438     //vector<vector<double> > cache;
439     vector<vector<double>* > cache;
440     //InnerProductRowCache& innerProductCache;
441 };
442 
443 
444 class LinearKernelFunction : public KernelFunction {
445 public:
446     // parameters must be set before using a KernelFunction is used
LinearKernelFunction(const LabeledObservationVector & _obs)447     LinearKernelFunction(const LabeledObservationVector& _obs) : KernelFunction(_obs), constant(0.0) {}
~LinearKernelFunction()448     ~LinearKernelFunction() {}
449 
similarity(const LabeledObservation & i,const LabeledObservation & j)450     double similarity(const LabeledObservation& i, const LabeledObservation& j) {
451         return getCachedParameterFreeSimilarity(i, j) + constant;
452     }
453 
calculateParameterFreeSimilarity(const LabeledObservation & i,const LabeledObservation & j)454     double calculateParameterFreeSimilarity(const LabeledObservation& i, const LabeledObservation& j) {
455         return inner_product(i.second->begin(), i.second->end(), j.second->begin(), 0.0);
456     }
457 
getConstant()458     double getConstant() { return constant; }
setConstant(double c)459     void setConstant(double c) { constant = c; }
460 
setParameters(const ParameterMap & p)461     void setParameters(const ParameterMap& p) {
462         setConstant(p.find(MapKey_Constant)->second);
463     };
464 
getDefaultParameterRanges(ParameterRangeMap & p)465     void getDefaultParameterRanges(ParameterRangeMap& p) {
466         p[MapKey_Constant] = defaultConstantRange;
467     }
468 
469     static const string MapKey;
470     static const string MapKey_Constant;
471     static const ParameterRange defaultConstantRange;
472 
473 private:
474     double constant;
475 };
476 
477 
478 class RbfKernelFunction : public KernelFunction {
479 public:
480     // parameters must be set before a KernelFunction is used
RbfKernelFunction(const LabeledObservationVector & _obs)481     RbfKernelFunction(const LabeledObservationVector& _obs) : KernelFunction(_obs), gamma(0.0) {}
~RbfKernelFunction()482     ~RbfKernelFunction() {}
483 
similarity(const LabeledObservation & i,const LabeledObservation & j)484     double similarity(const LabeledObservation& i, const LabeledObservation& j) {
485         //double sumOfSquaredDifs = 0.0;
486         //for (int n = 0; n < i.second->size(); n++) {
487         //    sumOfSquaredDifs += pow((i.second->at(n) - j.second->at(n)), 2.0);
488         //}
489         return gamma * getCachedParameterFreeSimilarity(i, j);
490     }
491 
calculateParameterFreeSimilarity(const LabeledObservation & i,const LabeledObservation & j)492     double calculateParameterFreeSimilarity(const LabeledObservation& i, const LabeledObservation& j) {
493         //double sumOfSquaredDifs = 0.0;
494         //for (int n = 0; n < i.second->size(); n++) {
495         //    sumOfSquaredDifs += pow((i.second->at(n) - j.second->at(n)), 2.0);
496         //}
497         double sumOfSquaredDifs =
498                       inner_product(i.second->begin(), i.second->end(), i.second->begin(), 0.0)
499               - 2.0 * inner_product(i.second->begin(), i.second->end(), j.second->begin(), 0.0)
500               +       inner_product(j.second->begin(), j.second->end(), j.second->begin(), 0.0);
501         return exp(sqrt(sumOfSquaredDifs));
502     }
503 
getGamma()504     double getGamma()       { return gamma; }
setGamma(double g)505     void setGamma(double g) { gamma = g; }
506 
setParameters(const ParameterMap & p)507     void setParameters(const ParameterMap& p) {
508         setGamma(p.find(MapKey_Gamma)->second);
509     }
510 
getDefaultParameterRanges(ParameterRangeMap & p)511     void getDefaultParameterRanges(ParameterRangeMap& p) {
512         p[MapKey_Gamma] = defaultGammaRange;
513     }
514 
515     static const string MapKey;
516     static const string MapKey_Gamma;
517 
518     static const ParameterRange defaultGammaRange;
519 
520 private:
521     double gamma;
522 };
523 
524 
525 class PolynomialKernelFunction : public KernelFunction {
526 public:
527     // parameters must be set before using a KernelFunction is used
PolynomialKernelFunction(const LabeledObservationVector & _obs)528     PolynomialKernelFunction(const LabeledObservationVector& _obs) : KernelFunction(_obs), c(0.0), gamma(0.0), d(0) {}
~PolynomialKernelFunction()529     ~PolynomialKernelFunction() {}
530 
similarity(const LabeledObservation & i,const LabeledObservation & j)531     double similarity(const LabeledObservation& i, const LabeledObservation& j) {
532         return pow((gamma * getCachedParameterFreeSimilarity(i, j) + c), d);
533         //return pow(inner_product(i.second->begin(), i.second->end(), j.second->begin(), c), d);
534     }
535 
calculateParameterFreeSimilarity(const LabeledObservation & i,const LabeledObservation & j)536     double calculateParameterFreeSimilarity(const LabeledObservation& i, const LabeledObservation& j) {
537         return inner_product(i.second->begin(), i.second->end(), j.second->begin(), 0.0);
538     }
539 
setParameters(const ParameterMap & p)540     void setParameters(const ParameterMap& p) {
541         c = p.find(MapKey_Constant)->second;
542         gamma = p.find(MapKey_Coefficient)->second;
543         d = int(p.find(MapKey_Degree)->second);
544     }
545 
getDefaultParameterRanges(ParameterRangeMap & p)546     void getDefaultParameterRanges(ParameterRangeMap& p) {
547         p[MapKey_Constant] = defaultConstantRange;
548         p[MapKey_Coefficient] = defaultCoefficientRange;
549         p[MapKey_Degree] = defaultDegreeRange;
550     }
551 
552     static const string MapKey;
553     static const string MapKey_Constant;
554     static const string MapKey_Coefficient;
555     static const string MapKey_Degree;
556 
557     static const ParameterRange defaultConstantRange;
558     static const ParameterRange defaultCoefficientRange;
559     static const ParameterRange defaultDegreeRange;
560 
561 private:
562     double c;
563     double gamma;
564     int d;
565 };
566 
567 
568 class SigmoidKernelFunction : public KernelFunction {
569 public:
570     // parameters must be set before using a KernelFunction is used
SigmoidKernelFunction(const LabeledObservationVector & _obs)571     SigmoidKernelFunction(const LabeledObservationVector& _obs) : KernelFunction(_obs), alpha(0.0), c(0.0) {}
~SigmoidKernelFunction()572     ~SigmoidKernelFunction() {}
573 
similarity(const LabeledObservation & i,const LabeledObservation & j)574     double similarity(const LabeledObservation& i, const LabeledObservation& j) {
575         return tanh(alpha * getCachedParameterFreeSimilarity(i, j) + c);
576         //return tanh(alpha * inner_product(i.second->begin(), i.second->end(), j.second->begin(), c));
577     }
578 
calculateParameterFreeSimilarity(const LabeledObservation & i,const LabeledObservation & j)579     double calculateParameterFreeSimilarity(const LabeledObservation& i, const LabeledObservation& j) {
580         return inner_product(i.second->begin(), i.second->end(), j.second->begin(), 0.0);
581     }
582 
setParameters(const ParameterMap & p)583     void setParameters(const ParameterMap& p) {
584         alpha = p.find(MapKey_Alpha)->second;
585         c = p.find(MapKey_Constant)->second;
586     }
587 
getDefaultParameterRanges(ParameterRangeMap & p)588     void getDefaultParameterRanges(ParameterRangeMap& p) {
589         p[MapKey_Alpha] = defaultAlphaRange;
590         p[MapKey_Constant] = defaultConstantRange;
591     }
592 
593     static const string MapKey;
594     static const string MapKey_Alpha;
595     static const string MapKey_Constant;
596 
597     static const ParameterRange defaultAlphaRange;
598     static const ParameterRange defaultConstantRange;
599 private:
600     double alpha;
601     double c;
602 };
603 
604 
605 class KernelFactory {
606 public:
getKernelFunctionForKey(string kernelFunctionKey,const LabeledObservationVector & obs)607     static KernelFunction* getKernelFunctionForKey(string kernelFunctionKey, const LabeledObservationVector& obs) {
608         if ( kernelFunctionKey == LinearKernelFunction::MapKey ) {
609             return new LinearKernelFunction(obs);
610         }
611         else if ( kernelFunctionKey == RbfKernelFunction::MapKey ) {
612             return new RbfKernelFunction(obs);
613         }
614         else if ( kernelFunctionKey == PolynomialKernelFunction::MapKey ) {
615             return new PolynomialKernelFunction(obs);
616         }
617         else if ( kernelFunctionKey == SigmoidKernelFunction::MapKey ) {
618             return new SigmoidKernelFunction(obs);
619         }
620         else {
621             throw new exception();
622         }
623     }
624 };
625 
626 
627 typedef map<string, KernelFunction*> KernelFunctionMap;
628 
629 // An instance of KernelFunctionFactory dynamically allocates kernel function
630 // instances and maintains a table of pointers to them.  This allows kernel
631 // function instances to be reused which improves performance since the
632 // kernel values do not have to be recalculated as often.  A KernelFunctionFactory
633 // maintains an inner product cache used by the KernelFunctions it builds.
634 class KernelFunctionFactory {
635 public:
KernelFunctionFactory(const LabeledObservationVector & _obs)636     KernelFunctionFactory(const LabeledObservationVector& _obs) : obs(_obs) {}
~KernelFunctionFactory()637     ~KernelFunctionFactory() {
638         for ( KernelFunctionMap::iterator i = kernelFunctionTable.begin(); i != kernelFunctionTable.end(); i++ ) {
639             delete i->second;
640         }
641     }
642 
getKernelFunctionForKey(string kernelFunctionKey)643     KernelFunction& getKernelFunctionForKey(string kernelFunctionKey) {
644         if ( kernelFunctionTable.count(kernelFunctionKey) == 0 ) {
645             kernelFunctionTable.insert(
646                 make_pair(
647                     kernelFunctionKey,
648                     KernelFactory::getKernelFunctionForKey(kernelFunctionKey, obs)
649                 )
650             );
651         }
652         return *kernelFunctionTable[kernelFunctionKey];
653     }
654 
655 private:
656     const LabeledObservationVector& obs;
657     KernelFunctionMap kernelFunctionTable;
658     //InnerProductCache innerProductCache;
659 };
660 
661 
662 class KernelFunctionCache {
663 public:
KernelFunctionCache(KernelFunction & _k,const LabeledObservationVector & _obs)664     KernelFunctionCache(KernelFunction& _k, const LabeledObservationVector& _obs) :
665         k(_k), obs(_obs),
666         cache(_obs.size(), NULL) {}
~KernelFunctionCache()667     ~KernelFunctionCache() {
668 
669         for (int i = 0; i < cache.size(); i++) {
670             if ( !rowNotCached(i) ) {
671                 delete cache[i];
672             }
673         }
674     }
675 
similarity(const LabeledObservation & obs_i,const LabeledObservation & obs_j)676     double similarity(const LabeledObservation& obs_i, const LabeledObservation& obs_j) {
677         const int i = obs_i.datasetIndex;
678         const int j = obs_j.datasetIndex;
679         // if the first element of row i is NaN then calculate all elements for row i
680         if ( rowNotCached(i) ) {
681             cache[i] = new vector<double>(obs.size(), numeric_limits<double>::signaling_NaN());
682             for ( int v = 0; v < obs.size(); v++ ) {
683                 cache.at(i)->at(v) = k.similarity(
684                     obs[i],
685                     obs[v]
686                 );
687             }
688         }
689         return cache.at(i)->at(j);
690     }
691 
rowNotCached(int i)692     bool rowNotCached(int i) {
693         return cache[i] == NULL;
694     }
695 
696 private:
697     KernelFunction& k;
698     const LabeledObservationVector& obs;
699     //vector<vector<double> > cache;
700     vector<vector<double>* > cache;
701 };
702 
703 
// The SVM class implements the Support Vector Machine
// discriminant function.  Instances are constructed with
// a vector of class labels (+1.0 or -1.0), a vector of dual
// coefficients, a vector of observations, and a bias value.
//
// The class SmoTrainer is responsible for determining the dual
// coefficients and bias value.
//
712 class SVM {
713 public:
SVM(const vector<double> & yy,const vector<double> & aa,const LabeledObservationVector & oo,double bb,const NumericClassToLabel & mm)714     SVM(const vector<double>& yy, const vector<double>& aa, const LabeledObservationVector& oo, double bb, const NumericClassToLabel& mm) :
715         y(yy), a(aa), x(oo), b(bb), discriminantToLabel(mm) {}
~SVM()716     ~SVM() {}
717 
718     // the classify method should accept a list of observations?
719     int discriminant(const Observation&) const;
classify(const Observation & observation) const720     Label classify(const Observation& observation) const {
721         //return discriminantToLabel[discriminant(observation)];
722         return discriminantToLabel.find(discriminant(observation))->second;
723     }
724     LabelVector classify(const LabeledObservationVector&) const;
725     double score(const LabeledObservationVector&) const;
726 
getDiscriminantToLabel() const727     NumericClassToLabel getDiscriminantToLabel() const { return discriminantToLabel; }
getLabelPair() const728     LabelPair getLabelPair() const { return buildLabelPair(discriminantToLabel.find(1)->second, discriminantToLabel.find(-1)->second); }
729 
730 public:
731     // y holds the numeric class: +1.0 or -1.0
732     const vector<double> y;
733     // a holds the optimal dual coefficients
734     const vector<double> a;
735     // x holds the support vectors
736     const LabeledObservationVector x;
737     const double b;
738     const NumericClassToLabel discriminantToLabel;
739 };
740 
741 
742 class SvmPerformanceSummary {
743 public:
SvmPerformanceSummary()744     SvmPerformanceSummary() {}
745     // this constructor should be used by clients other than tests
SvmPerformanceSummary(const SVM & svm,const LabeledObservationVector & actual)746     SvmPerformanceSummary(const SVM& svm, const LabeledObservationVector& actual) {
747         init(svm, actual, svm.classify(actual));
748     }
749     // this constructor is intended for unit testing
SvmPerformanceSummary(const SVM & svm,const LabeledObservationVector & actual,const LabelVector & predictions)750     SvmPerformanceSummary(const SVM& svm, const LabeledObservationVector& actual, const LabelVector& predictions) {
751         init(svm, actual, predictions);
752     }
753 
getPositiveClassLabel() const754     Label getPositiveClassLabel() const { return positiveClassLabel; }
getNegativeClassLabel() const755     Label getNegativeClassLabel() const { return negativeClassLabel; }
756 
getPrecision() const757     double getPrecision() const { return precision; }
getRecall() const758     double getRecall()    const { return recall; }
getF() const759     double getF()         const { return f; }
getAccuracy() const760     double getAccuracy()  const { return accuracy; }
761 
762 private:
763     void init(const SVM&, const LabeledObservationVector&, const LabelVector&);
764 
765     //const SVM& svm;
766 
767     Label positiveClassLabel;
768     Label negativeClassLabel;
769 
770     double precision;
771     double recall;
772     double f;
773     double accuracy;
774 };
775 
776 
777 class MultiClassSvmClassificationTie : public exception {
778 public:
MultiClassSvmClassificationTie(LabelVector & t,int c)779     MultiClassSvmClassificationTie(LabelVector& t, int c) : tiedLabels(t), tiedVoteCount(c) {}
~MultiClassSvmClassificationTie()780     ~MultiClassSvmClassificationTie() throw() {}
781 
what() const782     virtual const char* what() const throw() {
783         return "classification tie";
784     }
785 
786 private:
787     const LabelVector tiedLabels;
788     const int tiedVoteCount;
789 };
790 
791 typedef vector<SVM*> SvmVector;
792 typedef map<LabelPair, SvmPerformanceSummary> SvmToSvmPerformanceSummary;
793 
794 // Using SVM with more than two classes requires training multiple SVMs.
795 // The MultiClassSVM uses a vector of trained SVMs to do classification
796 // on data having more than two classes.
class MultiClassSVM {
public:
    // Takes ownership semantics, label set, per-pair performance summaries and
    // an output filter; constructor and destructor are defined outside this header.
    MultiClassSVM(const vector<SVM*>, const LabelSet&, const SvmToSvmPerformanceSummary&, OutputFilter);
    ~MultiClassSVM();

    // the classify method should accept a list of observations
    Label classify(const Observation& observation);
    // fraction of observations classified correctly; defined outside this header
    double score(const LabeledObservationVector&);

    // no need to delete these pointers
    const SvmVector& getSvmList() { return twoClassSvmList; }

    const LabelSet& getLabels() { return labelSet; }

    // Looks up the performance summary for the label pair handled by svm.
    // map::at throws std::out_of_range if the pair is not in the map.
    const SvmPerformanceSummary& getSvmPerformanceSummary(const SVM& svm) { return svmToSvmPerformanceSummary.at(svm.getLabelPair()); }

    double getAccuracy() { return accuracy; }
    // re-scores the classifier on obs and caches the result in accuracy
    void setAccuracy(const LabeledObservationVector& obs) { accuracy = score(obs); }

private:
    const SvmVector twoClassSvmList;
    const LabelSet labelSet;
    const OutputFilter outputFilter;

    // cached result of the most recent setAccuracy() call
    double accuracy;
    MothurOut* m;

    // this is a map from label pairs to performance summaries
    SvmToSvmPerformanceSummary svmToSvmPerformanceSummary;
};
827 
828 
829 //class SvmTrainingInterruptedException : public exception {
830 //public:
831 //    SvmTrainingInterruptedException(const string& m) : message(m) {}
832 //    ~SvmTrainingInterruptedException() throw() {}
833 //    virtual const char* what() const throw() {
834 //        return message.c_str();
835 //    }
836 
837 //private:
838 //    string message;
839 //};
840 
// Exception raised by SmoTrainer; carries a caller-supplied description
// that is surfaced through what().
class SmoTrainerException : public exception {
public:
    SmoTrainerException(const string& m) : whatMessage(m) {}
    ~SmoTrainerException() throw() {}

    virtual const char* what() const throw() {
        return whatMessage.c_str();
    }

private:
    // the description handed to the constructor; owned by the exception
    string whatMessage;
};
852 
853 //class ExternalSvmTrainingInterruption {
854 //public:
855 //    ExternalSvmTrainingInterruption() {}
856 //    virtual ~ExternalSvmTrainingInterruption() throw() {}
857 //    virtual bool interruptTraining() { return false; }
858 //};
859 
860 
861 // SmoTrainer trains a support vector machine using Sequential
862 // Minimal Optimization as described in the article
863 // "Support Vector Machine Solvers" by Bottou and Lin.
864 class SmoTrainer {
865 public:
SmoTrainer(OutputFilter of)866     SmoTrainer(OutputFilter of) : outputFilter(of), C(1.0) {}
867 
~SmoTrainer()868     ~SmoTrainer() {}
869 
getC()870     double getC()       { return C; }
setC(double C)871     void setC(double C) { this->C = C; }
872 
setParameters(const ParameterMap & p)873     void setParameters(const ParameterMap& p) {
874         C = p.find(MapKey_C)->second;
875     }
876 
877     SVM* train(KernelFunctionCache&, const LabeledObservationVector&);
878     void assignNumericLabels(vector<double>&, const LabeledObservationVector&, NumericClassToLabel&);
elementwise_multiply(vector<double> & a,vector<double> & b,vector<double> & c)879     void elementwise_multiply(vector<double>& a, vector<double>& b, vector<double>& c) {
880         transform(a.begin(), a.end(), b.begin(), c.begin(), multiplies<double>());
881     }
882 
883     static const string MapKey_C;
884     static const ParameterRange defaultCRange;
885 
886 private:
887     //ExternalSvmTrainingInterruption& externalSvmTrainingInterruption;
888 
889     const OutputFilter outputFilter;
890 
891     double C;
892 };
893 
894 
// KFoldLabeledObservationsDivider is used in cross validation to generate
896 // training and testing data sets of labeled observations.  The labels will
897 // be distributed in proportion to their frequency in the data, as much as possible.
898 //
899 // Consider a data set with 100 observations from five classes.  Also, let
900 // each class have 20 observations.  If we want to do 10-fold cross validation
901 // then training sets should have 90 observations and test sets should have
902 // 10 observations.  A training set should have approximately equal representation
903 // from each class, as should the test sets.
904 //
// An instance of KFoldLabeledObservationsDivider will generate training and test
// sets within a for loop like this:
//
//   KFoldLabeledObservationsDivider X(10, allLabeledObservations);
909 //   for (X.start(); !X.end(); X.next()) {
910 //        const LabeledObservationVector& trainingData = X.getTrainingData();
911 //        const LabeledObservationVector& testingData  = X.getTestingData();
912 //        // do cross validation on one fold
913 //   }
914 class KFoldLabeledObservationsDivider {
915 public:
916     // initialize the k member variable to K so end() will return true if it is called before start()
917     // this is not perfect protection against misuse but it's better than nothing
KFoldLabeledObservationsDivider(int _K,const LabeledObservationVector & l)918     KFoldLabeledObservationsDivider(int _K, const LabeledObservationVector& l) : K(_K), k(_K) {
919         buildLabelToLabeledObservationVector(labelToLabeledObservationVector, l);
920     }
~KFoldLabeledObservationsDivider()921     ~KFoldLabeledObservationsDivider() {}
922 
start()923     void start() {
924         k = 0;
925         trainingData.clear();
926         testingData.clear();
927         for (LabelToLabeledObservationVector::const_iterator p = labelToLabeledObservationVector.begin(); p != labelToLabeledObservationVector.end(); p++) {
928             appendKthFold(k, K, p->second, trainingData, testingData);
929         }
930     }
931 
end()932     bool end() {
933         return k >= K;
934     }
935 
next()936     void next() {
937         k++;
938         trainingData.clear();
939         testingData.clear();
940         for (LabelToLabeledObservationVector::const_iterator p = labelToLabeledObservationVector.begin(); p != labelToLabeledObservationVector.end(); p++) {
941             appendKthFold(k, K, p->second, trainingData, testingData);
942         }
943     }
944 
getFoldNumber()945     int getFoldNumber() { return k; }
getTrainingData()946     const LabeledObservationVector& getTrainingData() { return trainingData; }
getTestingData()947     const LabeledObservationVector& getTestingData() { return testingData; }
948 
949     // Function appendKthFold takes care of partitioning the observations in x into two sets,
950     // one for training and one for testing.  The argument K specifies how many folds
951     // will be requested in all.  The argument k specifies which fold to return.
952     // An example: let K=3, k=0, and let there be 10 observations (all having the same label)
953     //     i  i%3  (i%3)==0  k=0 partition  (i%3)==1  k=1 partition  (i%3)==2  k=2 partition
954     //     0   0     true      testing        false     training       false     training
955     //     1   1     false     training       true      testing        false     training
956     //     2   2     false     training       false     training       true      testing
957     //     3   0     true      testing        false     training       false     training
958     //     4   1     false     training       true      testing        false     training
959     //     5   2     false     training       false     training       true      testing
960     //     6   0     true      testing        false     training       false     training
961     //     7   1     false     training       true      testing        false     training
962     //     8   2     false     training       false     training       true      testing
963     //     9   0     true      testing        false     training       false     training
964     //
appendKthFold(int k,int K,const LabeledObservationVector & x,LabeledObservationVector & trainingData,LabeledObservationVector & testingData)965     static void appendKthFold(int k, int K, const LabeledObservationVector& x, LabeledObservationVector& trainingData, LabeledObservationVector& testingData) {
966         //for ( int i = 0; i < x.size(); i++) {
967         int i = 0;
968         for (LabeledObservationVector::const_iterator xi = x.begin(); xi != x.end(); xi++) {
969             if ( (i % K) == k) {
970                 testingData.push_back(*xi);
971             }
972             else {
973                 trainingData.push_back(*xi);
974             }
975             i++;
976         }
977     }
978 
979 private:
980     const int K;
981     int k;
982     LabelVector labelVector;
983     LabelToLabeledObservationVector labelToLabeledObservationVector;
984     LabeledObservationVector trainingData;
985     LabeledObservationVector testingData;
986 };
987 
988 
989 // OneVsOneMultiClassSvmTrainer trains a support vector machine for each
990 // pair of labels in a set of data.
class OneVsOneMultiClassSvmTrainer {
public:
    // constructor and the two train methods are defined outside this header
    OneVsOneMultiClassSvmTrainer(SvmDataset&, int, int, OutputFilter&);
    ~OneVsOneMultiClassSvmTrainer() {}

    // trains one two-class SVM per label pair and returns the combined classifier
    MultiClassSVM* train(const KernelParameterRangeMap&);
    double trainOnKFolds(SmoTrainer&, KernelFunctionCache&, KFoldLabeledObservationsDivider&);
    const LabelSet& getLabelSet() { return labelSet; }
    const LabeledObservationVector& getLabeledObservations() { return svmDataset.getLabeledObservationVector(); }
    const LabelPairSet& getLabelPairSet() { return labelPairSet; }
    // NOTE(review): operator[] default-inserts an empty vector when label is
    // unknown, silently growing the map — confirm callers only pass labels
    // that are present.
    const LabeledObservationVector& getLabeledObservationVectorForLabel(const Label& label) { return labelToLabeledObservationVector[label]; }

    const OutputFilter& getOutputFilter() { return outputFilter; }

    // builds the set of all distinct label pairs found in the observations
    static void buildLabelPairSet(LabelPairSet&, const LabeledObservationVector&);
    static void appendTrainingAndTestingData(Label, const LabeledObservationVector&, LabeledObservationVector&, LabeledObservationVector&);

private:

    const OutputFilter outputFilter;
    //bool verbose;

    SvmDataset& svmDataset;

    // fold counts for the outer (evaluation) and inner (training) cross validation
    const int evaluationFoldCount;
    const int trainFoldCount;

    LabelSet labelSet;
    LabelToLabeledObservationVector labelToLabeledObservationVector;
    LabelPairSet labelPairSet;

};
1023 
1024 // A better name for this class is MsvmRfe after MSVM-RFE described in
1025 // "MSVM-RFE: extensions of SVM-RFE for multiclass gene selection on
1026 // DNA microarray data", Zhou and Tuck, 2007, Bioinformatics
class SvmRfe {
public:
    SvmRfe() {}
    ~SvmRfe() {}

    // Ranks features by recursive feature elimination using the multi-class
    // trainer; the two ParameterRange arguments are the candidate kernel and
    // C ranges to search — defined outside this header.
    RankedFeatureList getOrderedFeatureList(SvmDataset&, OneVsOneMultiClassSvmTrainer&, const ParameterRange&, const ParameterRange&);
};
1034 
1035 
1036 #endif /* svm_hpp_ */
1037