1 /*
2 This code was extracted from libsvm 3.2.3 in Feb 2019 and
modified for use with Octave and MATLAB
4
5
6 Copyright (c) 2000-2019 Chih-Chung Chang and Chih-Jen Lin
7 All rights reserved.
8
9 Redistribution and use in source and binary forms, with or without
10 modification, are permitted provided that the following conditions
11 are met:
12
13 1. Redistributions of source code must retain the above copyright
14 notice, this list of conditions and the following disclaimer.
15
16 2. Redistributions in binary form must reproduce the above copyright
17 notice, this list of conditions and the following disclaimer in the
18 documentation and/or other materials provided with the distribution.
19
20 3. Neither name of copyright holders nor the names of its contributors
21 may be used to endorse or promote products derived from this software
22 without specific prior written permission.
23
24
25 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
26 ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
27 LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
28 A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR
29 CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
30 EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
31 PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
32 PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
33 LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
34 NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
35 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
36
37 */
38
39 #include <stdio.h>
40 #include <stdlib.h>
41 #include <string.h>
42 #include "svm.h"
43
44 #include "mex.h"
45 #include "svm_model_matlab.h"
46
/* Fallback typedef for MEX API versions older than 0x07030000, which
   presumably do not provide mwIndex themselves. */
#ifdef MX_API_VER
#if MX_API_VER < 0x07030000
typedef int mwIndex;
#endif
#endif

/* Size of the buffer that receives the 'libsvm_options' string in mexFunction. */
#define CMD_LEN 2048
54
/* printf-compatible sink that discards its arguments and always returns 0;
   mexFunction installs it as `info` when the -q option is given. */
int print_null(const char *s,...)
{
	(void)s;
	return 0;
}
/* Output hook for informational messages. mexFunction resets it to mexPrintf
   on every call and switches it to print_null when -q is given. */
int (*info)(const char *fmt,...) = &mexPrintf;
57
read_sparse_instance(const mxArray * prhs,int index,struct svm_node * x)58 void read_sparse_instance(const mxArray *prhs, int index, struct svm_node *x)
59 {
60 int i, j, low, high;
61 mwIndex *ir, *jc;
62 double *samples;
63
64 ir = mxGetIr(prhs);
65 jc = mxGetJc(prhs);
66 samples = mxGetPr(prhs);
67
68 // each column is one instance
69 j = 0;
70 low = (int)jc[index], high = (int)jc[index+1];
71 for(i=low;i<high;i++)
72 {
73 x[j].index = (int)ir[i] + 1;
74 x[j].value = samples[i];
75 j++;
76 }
77 x[j].index = -1;
78 }
79
/* Used on every error path: give each of the nlhs requested outputs an empty
   0x0 double matrix so that all outputs are still assigned. */
static void fake_answer(int nlhs, mxArray *plhs[])
{
	int k = nlhs;

	while(k-- > 0)
		plhs[k] = mxCreateDoubleMatrix(0, 0, mxREAL);
}
86
/* Squared correlation coefficient between predicted and target values, from
   the running sums accumulated in predict(). The denominator is zero when
   total is 0 or when either series has zero variance. */
static double squared_corr_coef(int total, double sump, double sumt,
		double sumpp, double sumtt, double sumpt)
{
	double cov = total*sumpt - sump*sumt;
	return (cov*cov)/((total*sumpp - sump*sump)*(total*sumtt - sumt*sumt));
}

/*
 * Predict every row of the testing instance matrix with `model` and fill the outputs.
 *
 * prhs[0]: testing label vector (must have one column and one row per instance)
 * prhs[1]: testing instance matrix, one instance per row, dense or sparse
 * model:   libsvm model; it is not freed here
 * predict_probability: nonzero to compute probability estimates
 *
 * Outputs (plhs[0] is always set; plhs[1..] only when nlhs > 1):
 *   plhs[0]: predicted labels, testing_instance_number x 1
 *   plhs[1]: [accuracy; mean squared error; squared correlation coefficient]
 *   plhs[2]: with predict_probability: probability estimates for C_SVC/NU_SVC,
 *            otherwise an empty matrix; without it: decision values
 * On error a message is printed and every requested output is set to [].
 * All temporary arrays and buffers are released before returning.
 */
void predict(int nlhs, mxArray *plhs[], const mxArray *prhs[], struct svm_model *model, const int predict_probability)
{
	int label_vector_row_num, label_vector_col_num;
	int feature_number, testing_instance_number;
	int instance_index, k;
	int nr_dec_values;
	double *ptr_instance, *ptr_label, *ptr_predict_label;
	double *ptr_prob_estimates, *ptr_dec_values, *ptr;
	struct svm_node *x = NULL;
	mxArray *instance_t = NULL;    // transposed sparse instance matrix (owned here)
	mxArray *instance_full = NULL; // dense copy for precomputed kernels (owned here)
	mxArray *tplhs[3];             // temporary storage for plhs[]

	int correct = 0;
	int total = 0;
	double error = 0;
	double sump = 0, sumt = 0, sumpp = 0, sumtt = 0, sumpt = 0;

	int svm_type = svm_get_svm_type(model);
	int nr_class = svm_get_nr_class(model);
	int is_svr = (svm_type == NU_SVR || svm_type == EPSILON_SVR);
	// these model types produce exactly one decision value per instance
	int single_dec_value = (svm_type == ONE_CLASS || is_svr);
	int want_prob_estimates = predict_probability && !is_svr;
	int want_dec_values = !predict_probability && !single_dec_value;
	double *prob_estimates = NULL;
	double *dec_values = NULL;

	nr_dec_values = nr_class*(nr_class-1)/2;

	// prhs[1] = testing instance matrix
	feature_number = (int)mxGetN(prhs[1]);
	testing_instance_number = (int)mxGetM(prhs[1]);
	label_vector_row_num = (int)mxGetM(prhs[0]);
	label_vector_col_num = (int)mxGetN(prhs[0]);

	if(label_vector_row_num!=testing_instance_number)
	{
		mexPrintf("Length of label vector does not match # of instances.\n");
		fake_answer(nlhs, plhs);
		return;
	}
	if(label_vector_col_num!=1)
	{
		mexPrintf("label (1st argument) should be a vector (# of column is 1).\n");
		fake_answer(nlhs, plhs);
		return;
	}

	ptr_instance = mxGetPr(prhs[1]);
	ptr_label = mxGetPr(prhs[0]);

	if(mxIsSparse(prhs[1]))
	{
		mxArray *rhs[1], *lhs[1];

		// mexCallMATLAB needs a non-const input, hence the duplicate
		rhs[0] = mxDuplicateArray(prhs[1]);
		if(model->param.kernel_type == PRECOMPUTED)
		{
			// precomputed kernel requires dense matrix, so we make one
			if(mexCallMATLAB(1, lhs, 1, rhs, "full"))
			{
				mexPrintf("Error: cannot full testing instance matrix\n");
				mxDestroyArray(rhs[0]);
				fake_answer(nlhs, plhs);
				return;
			}
			instance_full = lhs[0];
			ptr_instance = mxGetPr(instance_full);
		}
		else
		{
			// each column of the transposed matrix is one instance
			if(mexCallMATLAB(1, lhs, 1, rhs, "transpose"))
			{
				mexPrintf("Error: cannot transpose testing instance matrix\n");
				mxDestroyArray(rhs[0]);
				fake_answer(nlhs, plhs);
				return;
			}
			instance_t = lhs[0];
		}
		mxDestroyArray(rhs[0]);
	}

	if(predict_probability && is_svr)
		info("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma=%g\n",svm_get_svr_probability(model));

	// Work buffers are allocated once and reused for every instance. dec_values
	// gets at least one slot so a single-class model does not rely on malloc(0).
	if(want_prob_estimates)
		prob_estimates = malloc(nr_class * sizeof *prob_estimates);
	if(want_dec_values)
		dec_values = malloc((nr_dec_values > 0 ? nr_dec_values : 1) * sizeof *dec_values);
	x = malloc((feature_number+1) * sizeof *x);
	if(x == NULL ||
	   (want_prob_estimates && prob_estimates == NULL) ||
	   (want_dec_values && dec_values == NULL))
	{
		mexPrintf("Error: cannot allocate memory for prediction\n");
		fake_answer(nlhs, plhs);
		goto cleanup;
	}

	tplhs[0] = mxCreateDoubleMatrix(testing_instance_number, 1, mxREAL);
	if(predict_probability)
	{
		// prob estimates are in plhs[2]
		if(svm_type==C_SVC || svm_type==NU_SVC)
			tplhs[2] = mxCreateDoubleMatrix(testing_instance_number, nr_class, mxREAL);
		else
			tplhs[2] = mxCreateDoubleMatrix(0, 0, mxREAL);
	}
	else
	{
		// decision values are in plhs[2]
		if(single_dec_value ||
		   nr_class == 1) // if only one class in training data, decision values are still returned.
			tplhs[2] = mxCreateDoubleMatrix(testing_instance_number, 1, mxREAL);
		else
			tplhs[2] = mxCreateDoubleMatrix(testing_instance_number, nr_dec_values, mxREAL);
	}

	ptr_predict_label = mxGetPr(tplhs[0]);
	ptr_prob_estimates = mxGetPr(tplhs[2]);
	ptr_dec_values = mxGetPr(tplhs[2]);
	for(instance_index=0;instance_index<testing_instance_number;instance_index++)
	{
		int i;
		double target_label, predict_label;

		target_label = ptr_label[instance_index];

		if(instance_t != NULL) // prhs[1]^T is still sparse
			read_sparse_instance(instance_t, instance_index, x);
		else
		{
			// column-major dense matrix; size_t avoids int overflow on large inputs
			for(i=0;i<feature_number;i++)
			{
				x[i].index = i+1;
				x[i].value = ptr_instance[(size_t)testing_instance_number*i+instance_index];
			}
			x[feature_number].index = -1;
		}

		if(predict_probability)
		{
			if(svm_type==C_SVC || svm_type==NU_SVC)
			{
				predict_label = svm_predict_probability(model, x, prob_estimates);
				ptr_predict_label[instance_index] = predict_label;
				for(i=0;i<nr_class;i++)
					ptr_prob_estimates[instance_index + (size_t)i * testing_instance_number] = prob_estimates[i];
			} else {
				predict_label = svm_predict(model,x);
				ptr_predict_label[instance_index] = predict_label;
			}
		}
		else
		{
			if(single_dec_value)
			{
				double res;
				predict_label = svm_predict_values(model, x, &res);
				ptr_dec_values[instance_index] = res;
			}
			else
			{
				predict_label = svm_predict_values(model, x, dec_values);
				if(nr_class == 1)
					ptr_dec_values[instance_index] = 1;
				else
					for(i=0;i<nr_dec_values;i++)
						ptr_dec_values[instance_index + (size_t)i * testing_instance_number] = dec_values[i];
			}
			ptr_predict_label[instance_index] = predict_label;
		}

		if(predict_label == target_label)
			++correct;
		error += (predict_label-target_label)*(predict_label-target_label);
		sump += predict_label;
		sumt += target_label;
		sumpp += predict_label*predict_label;
		sumtt += target_label*target_label;
		sumpt += predict_label*target_label;
		++total;
	}
	if(is_svr)
	{
		info("Mean squared error = %g (regression)\n",error/total);
		info("Squared correlation coefficient = %g (regression)\n",
			squared_corr_coef(total, sump, sumt, sumpp, sumtt, sumpt));
	}
	else
		info("Accuracy = %g%% (%d/%d) (classification)\n",
			(double)correct/total*100,correct,total);

	// return accuracy, mean squared error, squared correlation coefficient
	tplhs[1] = mxCreateDoubleMatrix(3, 1, mxREAL);
	ptr = mxGetPr(tplhs[1]);
	ptr[0] = (double)correct/total*100;
	ptr[1] = error/total;
	ptr[2] = squared_corr_coef(total, sump, sumt, sumpp, sumtt, sumpt);

	// plhs[0] is always set (it becomes `ans` when nlhs == 0); outputs that
	// were not requested are destroyed instead of being left behind.
	for(k=0;k<3;k++)
	{
		if(k == 0 || k < nlhs)
			plhs[k] = tplhs[k];
		else
			mxDestroyArray(tplhs[k]);
	}

cleanup:
	free(x);
	free(prob_estimates);
	free(dec_values);
	if(instance_t != NULL)
		mxDestroyArray(instance_t);
	if(instance_full != NULL)
		mxDestroyArray(instance_full);
}
292
/* Print the svmpredict usage text: call signatures, options and outputs. */
void exit_with_help()
{
	static const char help_text[] =
		"Usage: [predicted_label, accuracy, decision_values/prob_estimates] = svmpredict(testing_label_vector, testing_instance_matrix, model, 'libsvm_options')\n"
		" [predicted_label] = svmpredict(testing_label_vector, testing_instance_matrix, model, 'libsvm_options')\n"
		"Parameters:\n"
		" model: SVM model structure from svmtrain.\n"
		" libsvm_options:\n"
		" -b probability_estimates: whether to predict probability estimates, 0 or 1 (default 0); one-class SVM not supported yet\n"
		" -q : quiet mode (no outputs)\n"
		"Returns:\n"
		" predicted_label: SVM prediction output vector.\n"
		" accuracy: a vector with accuracy, mean squared error, squared correlation coefficient.\n"
		" prob_estimates: If selected, probability estimate vector.\n";

	mexPrintf("%s", help_text);
}
309
/*
 * MEX entry point:
 *   [predicted_label, accuracy, decision_values/prob_estimates] =
 *       svmpredict(testing_label_vector, testing_instance_matrix, model, 'libsvm_options')
 *
 * Validates the arguments, parses the optional option string (-b, -q),
 * converts the model struct with matlab_matrix_to_model, runs predict() and
 * frees the converted model. Every error path prints a message and sets each
 * requested output to an empty matrix.
 */
void mexFunction( int nlhs, mxArray *plhs[],
		int nrhs, const mxArray *prhs[] )
{
	int prob_estimate_flag = 0;
	struct svm_model *model;
	const char *error_msg;

	// reset on every call so a previous -q does not persist
	info = &mexPrintf;

	if(nlhs == 2 || nlhs > 3 || nrhs > 4 || nrhs < 3)
	{
		exit_with_help();
		fake_answer(nlhs, plhs);
		return;
	}

	if(!mxIsDouble(prhs[0]) || !mxIsDouble(prhs[1])) {
		mexPrintf("Error: label vector and instance matrix must be double\n");
		fake_answer(nlhs, plhs);
		return;
	}

	if(!mxIsStruct(prhs[2]))
	{
		mexPrintf("model file should be a struct array\n");
		fake_answer(nlhs, plhs);
		return;
	}

	// parse options
	if(nrhs==4)
	{
		int i, argc = 1;
		char cmd[CMD_LEN];
		// cmd holds at most CMD_LEN-1 characters, i.e. at most CMD_LEN/2
		// space-separated tokens. argv[0] is unused and the NULL that ends the
		// strtok loop needs one more slot.
		char *argv[CMD_LEN/2 + 2];

		// Pass the real buffer size so long option strings cannot overflow cmd.
		// mxGetString returns nonzero for non-char input or when the text does
		// not fit, and cmd must not be used in either case.
		if(!mxIsChar(prhs[3]) || mxGetString(prhs[3], cmd, sizeof(cmd)) != 0)
		{
			mexPrintf("Error: libsvm_options must be a string shorter than %d characters\n", CMD_LEN);
			fake_answer(nlhs, plhs);
			return;
		}

		// put options in argv[]
		if((argv[argc] = strtok(cmd, " ")) != NULL)
			while((argv[++argc] = strtok(NULL, " ")) != NULL)
				;

		for(i=1;i<argc;i++)
		{
			if(argv[i][0] != '-') break;
			// every option except -q takes a value
			if((++i>=argc) && argv[i-1][1] != 'q')
			{
				exit_with_help();
				fake_answer(nlhs, plhs);
				return;
			}
			switch(argv[i-1][1])
			{
				case 'b':
					prob_estimate_flag = atoi(argv[i]);
					break;
				case 'q':
					i--; // -q has no value; undo the increment above
					info = &print_null;
					break;
				default:
					mexPrintf("Unknown option: -%c\n", argv[i-1][1]);
					exit_with_help();
					fake_answer(nlhs, plhs);
					return;
			}
		}
	}

	model = matlab_matrix_to_model(prhs[2], &error_msg);
	if (model == NULL)
	{
		mexPrintf("Error: can't read model: %s\n", error_msg);
		fake_answer(nlhs, plhs);
		return;
	}

	if(prob_estimate_flag)
	{
		if(svm_check_probability_model(model)==0)
		{
			mexPrintf("Model does not support probability estimates\n");
			fake_answer(nlhs, plhs);
			svm_free_and_destroy_model(&model);
			return;
		}
	}
	else
	{
		if(svm_check_probability_model(model)!=0)
			info("Model supports probability estimates, but disabled in prediction.\n");
	}

	predict(nlhs, plhs, prhs, model, prob_estimate_flag);
	// destroy model
	svm_free_and_destroy_model(&model);
}
409