1 #ifndef __BINDINGS_H__ 2 #define __BINDINGS_H__ 3 4 using namespace std; 5 6 #include <cstring> 7 #include <vector> 8 #include <map> 9 #include <assert.h> 10 11 #include "libsvm.h" 12 13 class DataSet { 14 friend class SVM; 15 16 private: 17 double label; 18 struct svm_node *attributes; 19 int n; int max_n; int max_i; 20 bool realigned; 21 public: 22 DataSet(double l); setLabel(double l)23 void setLabel(double l) { label = l; } getLabel()24 double getLabel() { return label; } getMaxI()25 int getMaxI() { return max_i; } 26 void setAttribute(int k, double v); 27 double getAttribute(int k); getIndexAt(int i)28 int getIndexAt(int i) { if (i<=n) { return attributes[i].index; } else { return -1; }} getValueAt(int i)29 double getValueAt(int i) { if (i<=n) { return attributes[i].value; } else { return 0; }} 30 31 void realign(struct svm_node *address); 32 ~DataSet(); 33 }; 34 35 36 class SVM { 37 public: 38 SVM(int st, int kt, int d, double g, double c0, double C, double nu, 39 double e); 40 void addDataSet(DataSet *ds); 41 int saveModel(char *filename); 42 int loadModel(char *filename); 43 void clearDataSet(); 44 int train(int retrain); 45 double predict_value(DataSet *ds); 46 double predict(DataSet *ds); 47 void free_x_space(); setSVMType(int st)48 void setSVMType(int st) { param.svm_type = st; } getSVMType()49 int getSVMType() { return param.svm_type; } setKernelType(int kt)50 void setKernelType(int kt) { param.kernel_type = kt; } getKernelType()51 int getKernelType() { return param.kernel_type; } setGamma(double g)52 void setGamma(double g) { param.gamma = g; } getGamma()53 double getGamma() { return param.gamma; } setDegree(int d)54 void setDegree(int d) { param.degree = d; } getDegree()55 double getDegree() { return param.degree; } setCoef0(double c)56 void setCoef0(double c) { param.coef0 = c; } getCoef0()57 double getCoef0() { return param.coef0; } setC(double c)58 void setC(double c) { param.C = c; } getC()59 double getC() { return param.C; } setNu(double n)60 void setNu(double n) { param.nu = n; } getNu()61 double getNu() { return param.nu; } setEpsilon(double e)62 void setEpsilon(double e) { param.p = e; } getEpsilon()63 double getEpsilon() { return param.p; } 64 double crossValidate(int nfolds); 65 int getNRClass(); 66 int getLabels(int* label); 67 double getSVRProbability(); 68 int checkProbabilityModel(); 69 70 ~SVM(); 71 private: 72 long nelem; 73 struct svm_parameter param; 74 vector<DataSet *> dataset; 75 struct svm_problem *prob; 76 struct svm_model *model; 77 struct svm_node *x_space; 78 int randomized; 79 }; 80 81 #endif 82