1 /*
2 
3 Copyright (c) 2007-2019 The LIBLINEAR Project.
4 All rights reserved.
5 
6 Redistribution and use in source and binary forms, with or without
7 modification, are permitted provided that the following conditions
8 are met:
9 
10 1. Redistributions of source code must retain the above copyright
11 notice, this list of conditions and the following disclaimer.
12 
13 2. Redistributions in binary form must reproduce the above copyright
14 notice, this list of conditions and the following disclaimer in the
15 documentation and/or other materials provided with the distribution.
16 
17 3. Neither name of copyright holders nor the names of its contributors
18 may be used to endorse or promote products derived from this software
19 without specific prior written permission.
20 
21 
22 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
23 ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
24 LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
25 A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE REGENTS OR
26 CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
27 EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
28 PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
29 PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
30 LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
31 NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
32 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33 
34 
35 */
36 
#include <math.h>
#include <stdlib.h>
#include <string.h>
#include <ctype.h>
#include "linear.h"

#include "mex.h"
#include "linear_model_matlab.h"

/* Older MATLAB releases (MX API < 7.3) do not define mwIndex;
   fall back to plain int for sparse-matrix index arrays. */
#ifdef MX_API_VER
#if MX_API_VER < 0x07030000
typedef int mwIndex;
#endif
#endif

#define CMD_LEN 2048	/* max length of the liblinear option string */
#define Malloc(type,n) (type *)malloc((n)*sizeof(type))
#define INF HUGE_VAL	/* sentinel: param.eps not set yet (see parse_command_line) */
55 
/* No-op print hook installed in quiet mode (-q) to suppress all
   liblinear output. (void)s silences the unused-parameter warning. */
void print_null(const char *s) {(void)s;}
/* Print hook forwarding liblinear's messages to the MATLAB console.
   Use the "%s" format so that literal '%' characters in the already
   formatted message (e.g. "accuracy = 85%") are not re-interpreted
   by mexPrintf as conversion specifiers. */
void print_string_matlab(const char *s) {mexPrintf("%s", s);}
58 
/* Print the usage message for the MEX train() function
   (mirrors the command-line train tool's help text). */
void exit_with_help()
{
	mexPrintf(
	"Usage: model = train(training_label_vector, training_instance_matrix, 'liblinear_options', 'col');\n"
	"liblinear_options:\n"
	"-s type : set type of solver (default 1)\n"
	"  for multi-class classification\n"
	"	 0 -- L2-regularized logistic regression (primal)\n"
	"	 1 -- L2-regularized L2-loss support vector classification (dual)\n"
	"	 2 -- L2-regularized L2-loss support vector classification (primal)\n"
	"	 3 -- L2-regularized L1-loss support vector classification (dual)\n"
	"	 4 -- support vector classification by Crammer and Singer\n"
	"	 5 -- L1-regularized L2-loss support vector classification\n"
	"	 6 -- L1-regularized logistic regression\n"
	"	 7 -- L2-regularized logistic regression (dual)\n"
	"  for regression\n"
	"	11 -- L2-regularized L2-loss support vector regression (primal)\n"
	"	12 -- L2-regularized L2-loss support vector regression (dual)\n"
	"	13 -- L2-regularized L1-loss support vector regression (dual)\n"
	"-c cost : set the parameter C (default 1)\n"
	"-p epsilon : set the epsilon in loss function of SVR (default 0.1)\n"
	"-e epsilon : set tolerance of termination criterion\n"
	"	-s 0 and 2\n"
	"		|f'(w)|_2 <= eps*min(pos,neg)/l*|f'(w0)|_2,\n"
	"		where f is the primal function and pos/neg are # of\n"
	"		positive/negative data (default 0.01)\n"
	"	-s 11\n"
	"		|f'(w)|_2 <= eps*|f'(w0)|_2 (default 0.001)\n"
	"	-s 1, 3, 4 and 7\n"
	"		Dual maximal violation <= eps; similar to libsvm (default 0.1)\n"
	"	-s 5 and 6\n"
	"		|f'(w)|_1 <= eps*min(pos,neg)/l*|f'(w0)|_1,\n"
	"		where f is the primal function (default 0.01)\n"
	"	-s 12 and 13\n"
	"		|f'(alpha)|_1 <= eps |f'(alpha0)|,\n"
	"		where f is the dual function (default 0.1)\n"
	"-B bias : if bias >= 0, instance x becomes [x; bias]; if < 0, no bias term added (default -1)\n"
	"-wi weight: weights adjust the parameter C of different classes (see README for details)\n"
	"-v n: n-fold cross validation mode\n"
	"-C : find parameter C (only for -s 0 and 2)\n"
	"-q : quiet mode (no outputs)\n"
	"col:\n"
	"	if 'col' is set, training_instance_matrix is parsed in column format, otherwise is in row format\n"
	);
}
104 
// liblinear arguments (file-scope state shared by the helpers below)
struct parameter param;		// set by parse_command_line
struct problem prob;		// set by read_problem
struct model *model_;		// trained model handed back to MATLAB
struct feature_node *x_space;	// backing storage for all nodes referenced by prob.x
int flag_cross_validation;	// nonzero when -v was given
int flag_find_C;		// nonzero when -C was given
int flag_C_specified;		// nonzero when -c was given (used as start_C for -C search)
int flag_solver_specified;	// nonzero when -s was given
int col_format_flag;		// nonzero when the 'col' argument was given
int nr_fold;			// number of folds for -v / -C
double bias;			// -B value; < 0 means no bias term (default -1)
118 
do_find_parameter_C(double * best_C,double * best_rate)119 void do_find_parameter_C(double *best_C, double *best_rate)
120 {
121 	double start_C;
122 	double max_C = 1024;
123 	if (flag_C_specified)
124 		start_C = param.C;
125 	else
126 		start_C = -1.0;
127 	find_parameter_C(&prob, &param, nr_fold, start_C, max_C, best_C, best_rate);
128 	mexPrintf("Best C = %lf  CV accuracy = %g%%\n", *best_C, 100.0**best_rate);
129 }
130 
131 
do_cross_validation()132 double do_cross_validation()
133 {
134 	int i;
135 	int total_correct = 0;
136 	double total_error = 0;
137 	double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0;
138 	double *target = Malloc(double, prob.l);
139 	double retval = 0.0;
140 
141 	cross_validation(&prob,&param,nr_fold,target);
142 	if(param.solver_type == L2R_L2LOSS_SVR ||
143 	   param.solver_type == L2R_L1LOSS_SVR_DUAL ||
144 	   param.solver_type == L2R_L2LOSS_SVR_DUAL)
145 	{
146 		for(i=0;i<prob.l;i++)
147                 {
148                         double y = prob.y[i];
149                         double v = target[i];
150                         total_error += (v-y)*(v-y);
151                         sumv += v;
152                         sumy += y;
153                         sumvv += v*v;
154                         sumyy += y*y;
155                         sumvy += v*y;
156                 }
157                 mexPrintf("Cross Validation Mean squared error = %g\n",total_error/prob.l);
158                 mexPrintf("Cross Validation Squared correlation coefficient = %g\n",
159                         ((prob.l*sumvy-sumv*sumy)*(prob.l*sumvy-sumv*sumy))/
160                         ((prob.l*sumvv-sumv*sumv)*(prob.l*sumyy-sumy*sumy))
161                         );
162 		retval = total_error/prob.l;
163 	}
164 	else
165 	{
166 		for(i=0;i<prob.l;i++)
167 			if(target[i] == prob.y[i])
168 				++total_correct;
169 		mexPrintf("Cross Validation Accuracy = %g%%\n",100.0*total_correct/prob.l);
170 		retval = 100.0*total_correct/prob.l;
171 	}
172 
173 	free(target);
174 	return retval;
175 }
176 
177 // nrhs should be 3
// nrhs should be 3
/* Parse the liblinear option string (prhs[2]) and the optional 'col'
 * argument (prhs[3]) into the file-scope param/flag globals.
 * Returns 0 on success, 1 on any parse error (caller prints usage).
 * model_file_name is not used by this MATLAB interface; presumably kept
 * for signature parity with the command-line tool -- TODO confirm. */
int parse_command_line(int nrhs, const mxArray *prhs[], char *model_file_name)
{
	int i, argc = 1;
	char cmd[CMD_LEN];
	char *argv[CMD_LEN/2];
	void (*print_func)(const char *) = print_string_matlab;	// default printing to matlab display

	// default values
	param.solver_type = L2R_L2LOSS_SVC_DUAL;
	param.C = 1;
	param.eps = INF; // see setting below
	param.p = 0.1;
	param.nr_weight = 0;
	param.weight_label = NULL;
	param.weight = NULL;
	param.init_sol = NULL;
	flag_cross_validation = 0;
	col_format_flag = 0;
	flag_C_specified = 0;
	flag_solver_specified = 0;
	flag_find_C = 0;
	bias = -1;


	// need at least (labels, instances)
	if(nrhs <= 1)
		return 1;

	// 4th argument, when present, may only be the literal string 'col'
	if(nrhs == 4)
	{
		mxGetString(prhs[3], cmd, mxGetN(prhs[3])+1);
		if(strcmp(cmd, "col") == 0)
			col_format_flag = 1;
	}

	// put options in argv[]
	// tokenize the option string on spaces; argv[0] is left unused so
	// the indices mimic a real command line
	if(nrhs > 2)
	{
		mxGetString(prhs[2], cmd,  mxGetN(prhs[2]) + 1);
		if((argv[argc] = strtok(cmd, " ")) != NULL)
			while((argv[++argc] = strtok(NULL, " ")) != NULL)
				;
	}

	// parse options
	// each iteration consumes "-x value"; i is advanced to the value
	// token first, then decremented back for the no-argument options
	for(i=1;i<argc;i++)
	{
		if(argv[i][0] != '-') break;
		++i;
		if(i>=argc && argv[i-1][1] != 'q' && argv[i-1][1] != 'C') // since options -q and -C have no parameter
			return 1;
		switch(argv[i-1][1])
		{
			case 's':
				param.solver_type = atoi(argv[i]);
				flag_solver_specified = 1;
				break;
			case 'c':
				param.C = atof(argv[i]);
				flag_C_specified = 1;
				break;
			case 'p':
				param.p = atof(argv[i]);
				break;
			case 'e':
				param.eps = atof(argv[i]);
				break;
			case 'B':
				bias = atof(argv[i]);
				break;
			case 'v':
				flag_cross_validation = 1;
				nr_fold = atoi(argv[i]);
				if(nr_fold < 2)
				{
					mexPrintf("n-fold cross validation: n must >= 2\n");
					return 1;
				}
				break;
			case 'w':
				// "-wLABEL value": the class label is embedded in the
				// option token itself (e.g. -w1 10)
				++param.nr_weight;
				param.weight_label = (int *) realloc(param.weight_label,sizeof(int)*param.nr_weight);
				param.weight = (double *) realloc(param.weight,sizeof(double)*param.nr_weight);
				param.weight_label[param.nr_weight-1] = atoi(&argv[i-1][2]);
				param.weight[param.nr_weight-1] = atof(argv[i]);
				break;
			case 'q':
				print_func = &print_null;
				i--;	// -q takes no value; undo the lookahead
				break;
			case 'C':
				flag_find_C = 1;
				i--;	// -C takes no value; undo the lookahead
				break;
			default:
				mexPrintf("unknown option\n");
				return 1;
		}
	}

	set_print_string_function(print_func);

	// default solver for parameter selection is L2R_L2LOSS_SVC
	if(flag_find_C)
	{
		if(!flag_cross_validation)
			nr_fold = 5;
		if(!flag_solver_specified)
		{
			mexPrintf("Solver not specified. Using -s 2\n");
			param.solver_type = L2R_L2LOSS_SVC;
		}
		else if(param.solver_type != L2R_LR && param.solver_type != L2R_L2LOSS_SVC)
		{
			mexPrintf("Warm-start parameter search only available for -s 0 and -s 2\n");
			return 1;
		}
	}

	// -e was not given: pick the per-solver default tolerance
	if(param.eps == INF)
	{
		switch(param.solver_type)
		{
			case L2R_LR:
			case L2R_L2LOSS_SVC:
				param.eps = 0.01;
				break;
			case L2R_L2LOSS_SVR:
				param.eps = 0.001;
				break;
			case L2R_L2LOSS_SVC_DUAL:
			case L2R_L1LOSS_SVC_DUAL:
			case MCSVM_CS:
			case L2R_LR_DUAL:
				param.eps = 0.1;
				break;
			case L1R_L2LOSS_SVC:
			case L1R_LR:
				param.eps = 0.01;
				break;
			case L2R_L1LOSS_SVR_DUAL:
			case L2R_L2LOSS_SVR_DUAL:
				param.eps = 0.1;
				break;
		}
	}
	return 0;
}
325 
/* On any failure path, give MATLAB well-defined outputs by filling
   every requested return slot with an empty 0x0 double matrix. */
static void fake_answer(int nlhs, mxArray *plhs[])
{
	int k = 0;
	while(k < nlhs)
	{
		plhs[k] = mxCreateDoubleMatrix(0, 0, mxREAL);
		++k;
	}
}
332 
/* Convert a MATLAB sparse instance matrix plus a dense label vector into
 * liblinear's representation, filling the globals prob and x_space.
 * Returns 0 on success, -1 on error (transpose failure or size mismatch). */
int read_problem_sparse(const mxArray *label_vec, const mxArray *instance_mat)
{
	mwIndex *ir, *jc, low, high, k;
	// using size_t due to the output type of matlab functions
	size_t i, j, l, elements, max_index, label_vector_row_num;
	mwSize num_samples;
	double *samples, *labels;
	mxArray *instance_mat_col; // instance sparse matrix in column format

	prob.x = NULL;
	prob.y = NULL;
	x_space = NULL;

	// one instance per column is required below; transpose unless the
	// caller already passed the matrix in 'col' format
	if(col_format_flag)
		instance_mat_col = (mxArray *)instance_mat;
	else
	{
		// transpose instance matrix
		mxArray *prhs[1], *plhs[1];
		prhs[0] = mxDuplicateArray(instance_mat);
		if(mexCallMATLAB(1, plhs, 1, prhs, "transpose"))
		{
			mexPrintf("Error: cannot transpose training instance matrix\n");
			return -1;
		}
		instance_mat_col = plhs[0];
		mxDestroyArray(prhs[0]);
	}

	// the number of instance
	l = mxGetN(instance_mat_col);
	label_vector_row_num = mxGetM(label_vec);
	prob.l = (int) l;

	if(label_vector_row_num!=l)
	{
		mexPrintf("Length of label vector does not match # of instances.\n");
		return -1;
	}

	// each column is one instance
	labels = mxGetPr(label_vec);
	samples = mxGetPr(instance_mat_col);
	ir = mxGetIr(instance_mat_col);	// row indices of nonzeros
	jc = mxGetJc(instance_mat_col);	// jc[c]..jc[c+1]-1 index column c's nonzeros

	num_samples = mxGetNzmax(instance_mat_col);

	// up to 2 extra nodes per instance: optional bias node + "-1" terminator
	elements = num_samples + l*2;
	max_index = mxGetM(instance_mat_col);

	// NOTE(review): malloc results are not checked; on OOM this would
	// dereference NULL rather than report an error
	prob.y = Malloc(double, l);
	prob.x = Malloc(struct feature_node*, l);
	x_space = Malloc(struct feature_node, elements);

	prob.bias=bias;

	// j walks x_space; each prob.x[i] points at the start of instance i
	j = 0;
	for(i=0;i<l;i++)
	{
		prob.x[i] = &x_space[j];
		prob.y[i] = labels[i];
		low = jc[i], high = jc[i+1];
		for(k=low;k<high;k++)
		{
			x_space[j].index = (int) ir[k]+1;	// liblinear indices are 1-based
			x_space[j].value = samples[k];
			j++;
	 	}
		if(prob.bias>=0)
		{
			// append the bias term as one extra feature dimension
			x_space[j].index = (int) max_index+1;
			x_space[j].value = prob.bias;
			j++;
		}
		x_space[j++].index = -1;	// end-of-instance sentinel
	}

	// feature dimensionality, +1 when a bias dimension was appended
	if(prob.bias>=0)
		prob.n = (int) max_index+1;
	else
		prob.n = (int) max_index;

	return 0;
}
418 
419 // Interface function of matlab
420 // now assume prhs[0]: label prhs[1]: features
// Interface function of matlab
// now assume prhs[0]: label prhs[1]: features
/* MEX entry point: validates inputs, parses options, builds the problem,
 * then either searches C (-C), cross-validates (-v) or trains a model.
 * plhs[0] receives [best_C; rate], the CV scalar, or the model struct. */
void mexFunction( int nlhs, mxArray *plhs[],
		int nrhs, const mxArray *prhs[] )
{
	const char *error_msg;
	// fix random seed to have same results for each run
	// (for cross validation)
	srand(1);

	// at most one output is produced
	if(nlhs > 1)
	{
		exit_with_help();
		fake_answer(nlhs, plhs);
		return;
	}

	// Transform the input Matrix to libsvm format
	if(nrhs > 1 && nrhs < 5)
	{
		int err=0;

		if(!mxIsDouble(prhs[0]) || !mxIsDouble(prhs[1]))
		{
			mexPrintf("Error: label vector and instance matrix must be double\n");
			fake_answer(nlhs, plhs);
			return;
		}

		// labels must be dense (only the instance matrix may be sparse)
		if(mxIsSparse(prhs[0]))
		{
			mexPrintf("Error: label vector should not be in sparse format");
			fake_answer(nlhs, plhs);
			return;
		}

		if(parse_command_line(nrhs, prhs, NULL))
		{
			exit_with_help();
			destroy_param(&param);
			fake_answer(nlhs, plhs);
			return;
		}

		if(mxIsSparse(prhs[1]))
			err = read_problem_sparse(prhs[0], prhs[1]);
		else
		{
			mexPrintf("Training_instance_matrix must be sparse; "
				"use sparse(Training_instance_matrix) first\n");
			destroy_param(&param);
			fake_answer(nlhs, plhs);
			return;
		}

		// train's original code
		error_msg = check_parameter(&prob, &param);

		// bail out on read or parameter errors, releasing everything
		// allocated so far
		if(err || error_msg)
		{
			if (error_msg != NULL)
				mexPrintf("Error: %s\n", error_msg);
			destroy_param(&param);
			free(prob.y);
			free(prob.x);
			free(x_space);
			fake_answer(nlhs, plhs);
			return;
		}

		if (flag_find_C)
		{
			// -C mode: return [best_C; best_CV_rate] instead of a model
			double best_C, best_rate, *ptr;

			do_find_parameter_C(&best_C, &best_rate);

			plhs[0] = mxCreateDoubleMatrix(2, 1, mxREAL);
			ptr = mxGetPr(plhs[0]);
			ptr[0] = best_C;
			ptr[1] = best_rate;
		}
		else if(flag_cross_validation)
		{
			// -v mode: return the scalar CV accuracy / MSE
			double *ptr;
			plhs[0] = mxCreateDoubleMatrix(1, 1, mxREAL);
			ptr = mxGetPr(plhs[0]);
			ptr[0] = do_cross_validation();
		}
		else
		{
			// normal training path: convert the model to a MATLAB struct
			const char *error_msg;	// NOTE(review): shadows the outer error_msg

			model_ = train(&prob, &param);
			error_msg = model_to_matlab_structure(plhs, model_);
			if(error_msg)
				mexPrintf("Error: can't convert libsvm model to matrix structure: %s\n", error_msg);
			free_and_destroy_model(&model_);
		}
		// common cleanup for all three successful paths
		destroy_param(&param);
		free(prob.y);
		free(prob.x);
		free(x_space);
	}
	else
	{
		exit_with_help();
		fake_answer(nlhs, plhs);
		return;
	}
}
529