1 /*
2   This code was extracted from libsvm 3.2.3 in Feb 2019 and
3   modified for the use with Octave and Matlab
4 
5 
6 Copyright (c) 2000-2019 Chih-Chung Chang and Chih-Jen Lin
7 All rights reserved.
8 
9 Redistribution and use in source and binary forms, with or without
10 modification, are permitted provided that the following conditions
11 are met:
12 
13 1. Redistributions of source code must retain the above copyright
14 notice, this list of conditions and the following disclaimer.
15 
16 2. Redistributions in binary form must reproduce the above copyright
17 notice, this list of conditions and the following disclaimer in the
18 documentation and/or other materials provided with the distribution.
19 
20 3. Neither name of copyright holders nor the names of its contributors
21 may be used to endorse or promote products derived from this software
22 without specific prior written permission.
23 
24 
25 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
26 ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
27 LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
28 A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE REGENTS OR
29 CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
30 EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
31 PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
32 PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
33 LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
34 NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
35 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
36 
37 */
38 
39 #include <stdio.h>
40 #include <stdlib.h>
41 #include <string.h>
42 #include "svm.h"
43 
44 #include "mex.h"
45 #include "svm_model_matlab.h"
46 
/* Headers that report an MX_API_VER older than 0x07030000 are assumed not to
   provide mwIndex (presumably introduced with MATLAB 7.3); fall back to int. */
#ifdef MX_API_VER
#if MX_API_VER < 0x07030000
typedef int mwIndex;
#endif
#endif

/* Size of the stack buffer that receives the libsvm_options string in mexFunction. */
#define CMD_LEN 2048
54 
print_null(const char * s,...)55 int print_null(const char *s,...) {return 0;}
/* Sink for informational (non-error) messages. mexFunction resets it to
   mexPrintf on every call; the "-q" option redirects it to print_null. */
int (*info)(const char *fmt,...) = &mexPrintf;
57 
read_sparse_instance(const mxArray * prhs,int index,struct svm_node * x)58 void read_sparse_instance(const mxArray *prhs, int index, struct svm_node *x)
59 {
60 	int i, j, low, high;
61 	mwIndex *ir, *jc;
62 	double *samples;
63 
64 	ir = mxGetIr(prhs);
65 	jc = mxGetJc(prhs);
66 	samples = mxGetPr(prhs);
67 
68 	// each column is one instance
69 	j = 0;
70 	low = (int)jc[index], high = (int)jc[index+1];
71 	for(i=low;i<high;i++)
72 	{
73 		x[j].index = (int)ir[i] + 1;
74 		x[j].value = samples[i];
75 		j++;
76 	}
77 	x[j].index = -1;
78 }
79 
fake_answer(int nlhs,mxArray * plhs[])80 static void fake_answer(int nlhs, mxArray *plhs[])
81 {
82 	int i;
83 	for(i=0;i<nlhs;i++)
84 		plhs[i] = mxCreateDoubleMatrix(0, 0, mxREAL);
85 }
86 
predict(int nlhs,mxArray * plhs[],const mxArray * prhs[],struct svm_model * model,const int predict_probability)87 void predict(int nlhs, mxArray *plhs[], const mxArray *prhs[], struct svm_model *model, const int predict_probability)
88 {
89 	int label_vector_row_num, label_vector_col_num;
90 	int feature_number, testing_instance_number;
91 	int instance_index;
92 	double *ptr_instance, *ptr_label, *ptr_predict_label;
93 	double *ptr_prob_estimates, *ptr_dec_values, *ptr;
94 	struct svm_node *x;
95 	mxArray *pplhs[1]; // transposed instance sparse matrix
96 	mxArray *tplhs[3]; // temporary storage for plhs[]
97 
98 	int correct = 0;
99 	int total = 0;
100 	double error = 0;
101 	double sump = 0, sumt = 0, sumpp = 0, sumtt = 0, sumpt = 0;
102 
103 	int svm_type=svm_get_svm_type(model);
104 	int nr_class=svm_get_nr_class(model);
105 	double *prob_estimates=NULL;
106 
107 	// prhs[1] = testing instance matrix
108 	feature_number = (int)mxGetN(prhs[1]);
109 	testing_instance_number = (int)mxGetM(prhs[1]);
110 	label_vector_row_num = (int)mxGetM(prhs[0]);
111 	label_vector_col_num = (int)mxGetN(prhs[0]);
112 
113 	if(label_vector_row_num!=testing_instance_number)
114 	{
115 		mexPrintf("Length of label vector does not match # of instances.\n");
116 		fake_answer(nlhs, plhs);
117 		return;
118 	}
119 	if(label_vector_col_num!=1)
120 	{
121 		mexPrintf("label (1st argument) should be a vector (# of column is 1).\n");
122 		fake_answer(nlhs, plhs);
123 		return;
124 	}
125 
126 	ptr_instance = mxGetPr(prhs[1]);
127 	ptr_label    = mxGetPr(prhs[0]);
128 
129 	// transpose instance matrix
130 	if(mxIsSparse(prhs[1]))
131 	{
132 		if(model->param.kernel_type == PRECOMPUTED)
133 		{
134 			// precomputed kernel requires dense matrix, so we make one
135 			mxArray *rhs[1], *lhs[1];
136 			rhs[0] = mxDuplicateArray(prhs[1]);
137 			if(mexCallMATLAB(1, lhs, 1, rhs, "full"))
138 			{
139 				mexPrintf("Error: cannot full testing instance matrix\n");
140 				fake_answer(nlhs, plhs);
141 				return;
142 			}
143 			ptr_instance = mxGetPr(lhs[0]);
144 			mxDestroyArray(rhs[0]);
145 		}
146 		else
147 		{
148 			mxArray *pprhs[1];
149 			pprhs[0] = mxDuplicateArray(prhs[1]);
150 			if(mexCallMATLAB(1, pplhs, 1, pprhs, "transpose"))
151 			{
152 				mexPrintf("Error: cannot transpose testing instance matrix\n");
153 				fake_answer(nlhs, plhs);
154 				return;
155 			}
156 		}
157 	}
158 
159 	if(predict_probability)
160 	{
161 		if(svm_type==NU_SVR || svm_type==EPSILON_SVR)
162 			info("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma=%g\n",svm_get_svr_probability(model));
163 		else
164 			prob_estimates = (double *) malloc(nr_class*sizeof(double));
165 	}
166 
167 	tplhs[0] = mxCreateDoubleMatrix(testing_instance_number, 1, mxREAL);
168 	if(predict_probability)
169 	{
170 		// prob estimates are in plhs[2]
171 		if(svm_type==C_SVC || svm_type==NU_SVC)
172 			tplhs[2] = mxCreateDoubleMatrix(testing_instance_number, nr_class, mxREAL);
173 		else
174 			tplhs[2] = mxCreateDoubleMatrix(0, 0, mxREAL);
175 	}
176 	else
177 	{
178 		// decision values are in plhs[2]
179 		if(svm_type == ONE_CLASS ||
180 		   svm_type == EPSILON_SVR ||
181 		   svm_type == NU_SVR ||
182 		   nr_class == 1) // if only one class in training data, decision values are still returned.
183 			tplhs[2] = mxCreateDoubleMatrix(testing_instance_number, 1, mxREAL);
184 		else
185 			tplhs[2] = mxCreateDoubleMatrix(testing_instance_number, nr_class*(nr_class-1)/2, mxREAL);
186 	}
187 
188 	ptr_predict_label = mxGetPr(tplhs[0]);
189 	ptr_prob_estimates = mxGetPr(tplhs[2]);
190 	ptr_dec_values = mxGetPr(tplhs[2]);
191 	x = (struct svm_node*)malloc((feature_number+1)*sizeof(struct svm_node) );
192 	for(instance_index=0;instance_index<testing_instance_number;instance_index++)
193 	{
194 		int i;
195 		double target_label, predict_label;
196 
197 		target_label = ptr_label[instance_index];
198 
199 		if(mxIsSparse(prhs[1]) && model->param.kernel_type != PRECOMPUTED) // prhs[1]^T is still sparse
200 			read_sparse_instance(pplhs[0], instance_index, x);
201 		else
202 		{
203 			for(i=0;i<feature_number;i++)
204 			{
205 				x[i].index = i+1;
206 				x[i].value = ptr_instance[testing_instance_number*i+instance_index];
207 			}
208 			x[feature_number].index = -1;
209 		}
210 
211 		if(predict_probability)
212 		{
213 			if(svm_type==C_SVC || svm_type==NU_SVC)
214 			{
215 				predict_label = svm_predict_probability(model, x, prob_estimates);
216 				ptr_predict_label[instance_index] = predict_label;
217 				for(i=0;i<nr_class;i++)
218 					ptr_prob_estimates[instance_index + i * testing_instance_number] = prob_estimates[i];
219 			} else {
220 				predict_label = svm_predict(model,x);
221 				ptr_predict_label[instance_index] = predict_label;
222 			}
223 		}
224 		else
225 		{
226 			if(svm_type == ONE_CLASS ||
227 			   svm_type == EPSILON_SVR ||
228 			   svm_type == NU_SVR)
229 			{
230 				double res;
231 				predict_label = svm_predict_values(model, x, &res);
232 				ptr_dec_values[instance_index] = res;
233 			}
234 			else
235 			{
236 				double *dec_values = (double *) malloc(sizeof(double) * nr_class*(nr_class-1)/2);
237 				predict_label = svm_predict_values(model, x, dec_values);
238 				if(nr_class == 1)
239 					ptr_dec_values[instance_index] = 1;
240 				else
241 					for(i=0;i<(nr_class*(nr_class-1))/2;i++)
242 						ptr_dec_values[instance_index + i * testing_instance_number] = dec_values[i];
243 				free(dec_values);
244 			}
245 			ptr_predict_label[instance_index] = predict_label;
246 		}
247 
248 		if(predict_label == target_label)
249 			++correct;
250 		error += (predict_label-target_label)*(predict_label-target_label);
251 		sump += predict_label;
252 		sumt += target_label;
253 		sumpp += predict_label*predict_label;
254 		sumtt += target_label*target_label;
255 		sumpt += predict_label*target_label;
256 		++total;
257 	}
258 	if(svm_type==NU_SVR || svm_type==EPSILON_SVR)
259 	{
260 		info("Mean squared error = %g (regression)\n",error/total);
261 		info("Squared correlation coefficient = %g (regression)\n",
262 			((total*sumpt-sump*sumt)*(total*sumpt-sump*sumt))/
263 			((total*sumpp-sump*sump)*(total*sumtt-sumt*sumt))
264 			);
265 	}
266 	else
267 		info("Accuracy = %g%% (%d/%d) (classification)\n",
268 			(double)correct/total*100,correct,total);
269 
270 	// return accuracy, mean squared error, squared correlation coefficient
271 	tplhs[1] = mxCreateDoubleMatrix(3, 1, mxREAL);
272 	ptr = mxGetPr(tplhs[1]);
273 	ptr[0] = (double)correct/total*100;
274 	ptr[1] = error/total;
275 	ptr[2] = ((total*sumpt-sump*sumt)*(total*sumpt-sump*sumt))/
276 				((total*sumpp-sump*sump)*(total*sumtt-sumt*sumt));
277 
278 	free(x);
279 	if(prob_estimates != NULL)
280 		free(prob_estimates);
281 
282 	switch(nlhs)
283 	{
284 		case 3:
285 			plhs[2] = tplhs[2];
286 			plhs[1] = tplhs[1];
287 		case 1:
288 		case 0:
289 			plhs[0] = tplhs[0];
290 	}
291 }
292 
exit_with_help()293 void exit_with_help()
294 {
295 	mexPrintf(
296 		"Usage: [predicted_label, accuracy, decision_values/prob_estimates] = svmpredict(testing_label_vector, testing_instance_matrix, model, 'libsvm_options')\n"
297 		"       [predicted_label] = svmpredict(testing_label_vector, testing_instance_matrix, model, 'libsvm_options')\n"
298 		"Parameters:\n"
299 		"  model: SVM model structure from svmtrain.\n"
300 		"  libsvm_options:\n"
301 		"    -b probability_estimates: whether to predict probability estimates, 0 or 1 (default 0); one-class SVM not supported yet\n"
302 		"    -q : quiet mode (no outputs)\n"
303 		"Returns:\n"
304 		"  predicted_label: SVM prediction output vector.\n"
305 		"  accuracy: a vector with accuracy, mean squared error, squared correlation coefficient.\n"
306 		"  prob_estimates: If selected, probability estimate vector.\n"
307 	);
308 }
309 
mexFunction(int nlhs,mxArray * plhs[],int nrhs,const mxArray * prhs[])310 void mexFunction( int nlhs, mxArray *plhs[],
311 		 int nrhs, const mxArray *prhs[] )
312 {
313 	int prob_estimate_flag = 0;
314 	struct svm_model *model;
315 	info = &mexPrintf;
316 
317 	if(nlhs == 2 || nlhs > 3 || nrhs > 4 || nrhs < 3)
318 	{
319 		exit_with_help();
320 		fake_answer(nlhs, plhs);
321 		return;
322 	}
323 
324 	if(!mxIsDouble(prhs[0]) || !mxIsDouble(prhs[1])) {
325 		mexPrintf("Error: label vector and instance matrix must be double\n");
326 		fake_answer(nlhs, plhs);
327 		return;
328 	}
329 
330 	if(mxIsStruct(prhs[2]))
331 	{
332 		const char *error_msg;
333 
334 		// parse options
335 		if(nrhs==4)
336 		{
337 			int i, argc = 1;
338 			char cmd[CMD_LEN], *argv[CMD_LEN/2];
339 
340 			// put options in argv[]
341 			mxGetString(prhs[3], cmd,  mxGetN(prhs[3]) + 1);
342 			if((argv[argc] = strtok(cmd, " ")) != NULL)
343 				while((argv[++argc] = strtok(NULL, " ")) != NULL)
344 					;
345 
346 			for(i=1;i<argc;i++)
347 			{
348 				if(argv[i][0] != '-') break;
349 				if((++i>=argc) && argv[i-1][1] != 'q')
350 				{
351 					exit_with_help();
352 					fake_answer(nlhs, plhs);
353 					return;
354 				}
355 				switch(argv[i-1][1])
356 				{
357 					case 'b':
358 						prob_estimate_flag = atoi(argv[i]);
359 						break;
360 					case 'q':
361 						i--;
362 						info = &print_null;
363 						break;
364 					default:
365 						mexPrintf("Unknown option: -%c\n", argv[i-1][1]);
366 						exit_with_help();
367 						fake_answer(nlhs, plhs);
368 						return;
369 				}
370 			}
371 		}
372 
373 		model = matlab_matrix_to_model(prhs[2], &error_msg);
374 		if (model == NULL)
375 		{
376 			mexPrintf("Error: can't read model: %s\n", error_msg);
377 			fake_answer(nlhs, plhs);
378 			return;
379 		}
380 
381 		if(prob_estimate_flag)
382 		{
383 			if(svm_check_probability_model(model)==0)
384 			{
385 				mexPrintf("Model does not support probabiliy estimates\n");
386 				fake_answer(nlhs, plhs);
387 				svm_free_and_destroy_model(&model);
388 				return;
389 			}
390 		}
391 		else
392 		{
393 			if(svm_check_probability_model(model)!=0)
394 				info("Model supports probability estimates, but disabled in predicton.\n");
395 		}
396 
397 		predict(nlhs, plhs, prhs, model, prob_estimate_flag);
398 		// destroy model
399 		svm_free_and_destroy_model(&model);
400 	}
401 	else
402 	{
403 		mexPrintf("model file should be a struct array\n");
404 		fake_answer(nlhs, plhs);
405 	}
406 
407 	return;
408 }
409