1 /*
2
3 Copyright (c) 2007-2019 The LIBLINEAR Project.
4 All rights reserved.
5
6 Redistribution and use in source and binary forms, with or without
7 modification, are permitted provided that the following conditions
8 are met:
9
10 1. Redistributions of source code must retain the above copyright
11 notice, this list of conditions and the following disclaimer.
12
13 2. Redistributions in binary form must reproduce the above copyright
14 notice, this list of conditions and the following disclaimer in the
15 documentation and/or other materials provided with the distribution.
16
17 3. Neither name of copyright holders nor the names of its contributors
18 may be used to endorse or promote products derived from this software
19 without specific prior written permission.
20
21
22 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
23 ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
24 LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
25 A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR
26 CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
27 EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
28 PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
29 PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
30 LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
31 NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
32 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33
34
35 */
36
37 #include <math.h>
38 #include <stdlib.h>
39 #include <string.h>
40 #include <ctype.h>
41 #include "linear.h"
42
43 #include "mex.h"
44 #include "linear_model_matlab.h"
45
46 #ifdef MX_API_VER
47 #if MX_API_VER < 0x07030000
48 typedef int mwIndex;
49 #endif
50 #endif
51
52 #define CMD_LEN 2048
53 #define Malloc(type,n) (type *)malloc((n)*sizeof(type))
54 #define INF HUGE_VAL
55
/* No-op print callback installed for quiet mode (-q): discards the message. */
void print_null(const char *s) { (void)s; }
/* Print callback that routes liblinear messages to the MATLAB console.
 * Uses an explicit "%s" format: liblinear output frequently contains '%'
 * (e.g. "Accuracy = 97%"), and passing s directly as the format string
 * would make mexPrintf interpret those as conversion directives (UB). */
void print_string_matlab(const char *s) {mexPrintf("%s", s);}
58
/* Print the full usage message for the train() MEX function to the
 * MATLAB console. Called whenever the arguments fail to parse. */
void exit_with_help()
{
	mexPrintf(
	"Usage: model = train(training_label_vector, training_instance_matrix, 'liblinear_options', 'col');\n"
	"liblinear_options:\n"
	"-s type : set type of solver (default 1)\n"
	"  for multi-class classification\n"
	"	 0 -- L2-regularized logistic regression (primal)\n"
	"	 1 -- L2-regularized L2-loss support vector classification (dual)\n"
	"	 2 -- L2-regularized L2-loss support vector classification (primal)\n"
	"	 3 -- L2-regularized L1-loss support vector classification (dual)\n"
	"	 4 -- support vector classification by Crammer and Singer\n"
	"	 5 -- L1-regularized L2-loss support vector classification\n"
	"	 6 -- L1-regularized logistic regression\n"
	"	 7 -- L2-regularized logistic regression (dual)\n"
	"  for regression\n"
	"	11 -- L2-regularized L2-loss support vector regression (primal)\n"
	"	12 -- L2-regularized L2-loss support vector regression (dual)\n"
	"	13 -- L2-regularized L1-loss support vector regression (dual)\n"
	"-c cost : set the parameter C (default 1)\n"
	"-p epsilon : set the epsilon in loss function of SVR (default 0.1)\n"
	"-e epsilon : set tolerance of termination criterion\n"
	"	-s 0 and 2\n"
	"		|f'(w)|_2 <= eps*min(pos,neg)/l*|f'(w0)|_2,\n"
	"		where f is the primal function and pos/neg are # of\n"
	"		positive/negative data (default 0.01)\n"
	"	-s 11\n"
	"		|f'(w)|_2 <= eps*|f'(w0)|_2 (default 0.001)\n"
	"	-s 1, 3, 4 and 7\n"
	"		Dual maximal violation <= eps; similar to libsvm (default 0.1)\n"
	"	-s 5 and 6\n"
	"		|f'(w)|_1 <= eps*min(pos,neg)/l*|f'(w0)|_1,\n"
	"		where f is the primal function (default 0.01)\n"
	"	-s 12 and 13\n"
	"		|f'(alpha)|_1 <= eps |f'(alpha0)|,\n"
	"		where f is the dual function (default 0.1)\n"
	"-B bias : if bias >= 0, instance x becomes [x; bias]; if < 0, no bias term added (default -1)\n"
	"-wi weight: weights adjust the parameter C of different classes (see README for details)\n"
	"-v n: n-fold cross validation mode\n"
	"-C : find parameter C (only for -s 0 and 2)\n"
	"-q : quiet mode (no outputs)\n"
	"col:\n"
	"	if 'col' is set, training_instance_matrix is parsed in column format, otherwise is in row format\n"
	);
}
104
// liblinear arguments — file-scope state shared by the MEX helpers below
struct parameter param;		// set by parse_command_line
struct problem prob;		// set by read_problem
struct model *model_;		// trained model handed back to MATLAB
struct feature_node *x_space;	// flat feature storage that prob.x points into
int flag_cross_validation;	// nonzero when -v was given
int flag_find_C;		// nonzero when -C was given
int flag_C_specified;		// nonzero when -c was given
int flag_solver_specified;	// nonzero when -s was given
int col_format_flag;		// nonzero when the 'col' argument was given
int nr_fold;			// number of folds for cross validation (-v)
double bias;			// bias from -B; < 0 means no bias term (default -1)
117
118
do_find_parameter_C(double * best_C,double * best_rate)119 void do_find_parameter_C(double *best_C, double *best_rate)
120 {
121 double start_C;
122 double max_C = 1024;
123 if (flag_C_specified)
124 start_C = param.C;
125 else
126 start_C = -1.0;
127 find_parameter_C(&prob, ¶m, nr_fold, start_C, max_C, best_C, best_rate);
128 mexPrintf("Best C = %lf CV accuracy = %g%%\n", *best_C, 100.0**best_rate);
129 }
130
131
do_cross_validation()132 double do_cross_validation()
133 {
134 int i;
135 int total_correct = 0;
136 double total_error = 0;
137 double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0;
138 double *target = Malloc(double, prob.l);
139 double retval = 0.0;
140
141 cross_validation(&prob,¶m,nr_fold,target);
142 if(param.solver_type == L2R_L2LOSS_SVR ||
143 param.solver_type == L2R_L1LOSS_SVR_DUAL ||
144 param.solver_type == L2R_L2LOSS_SVR_DUAL)
145 {
146 for(i=0;i<prob.l;i++)
147 {
148 double y = prob.y[i];
149 double v = target[i];
150 total_error += (v-y)*(v-y);
151 sumv += v;
152 sumy += y;
153 sumvv += v*v;
154 sumyy += y*y;
155 sumvy += v*y;
156 }
157 mexPrintf("Cross Validation Mean squared error = %g\n",total_error/prob.l);
158 mexPrintf("Cross Validation Squared correlation coefficient = %g\n",
159 ((prob.l*sumvy-sumv*sumy)*(prob.l*sumvy-sumv*sumy))/
160 ((prob.l*sumvv-sumv*sumv)*(prob.l*sumyy-sumy*sumy))
161 );
162 retval = total_error/prob.l;
163 }
164 else
165 {
166 for(i=0;i<prob.l;i++)
167 if(target[i] == prob.y[i])
168 ++total_correct;
169 mexPrintf("Cross Validation Accuracy = %g%%\n",100.0*total_correct/prob.l);
170 retval = 100.0*total_correct/prob.l;
171 }
172
173 free(target);
174 return retval;
175 }
176
177 // nrhs should be 3
parse_command_line(int nrhs,const mxArray * prhs[],char * model_file_name)178 int parse_command_line(int nrhs, const mxArray *prhs[], char *model_file_name)
179 {
180 int i, argc = 1;
181 char cmd[CMD_LEN];
182 char *argv[CMD_LEN/2];
183 void (*print_func)(const char *) = print_string_matlab; // default printing to matlab display
184
185 // default values
186 param.solver_type = L2R_L2LOSS_SVC_DUAL;
187 param.C = 1;
188 param.eps = INF; // see setting below
189 param.p = 0.1;
190 param.nr_weight = 0;
191 param.weight_label = NULL;
192 param.weight = NULL;
193 param.init_sol = NULL;
194 flag_cross_validation = 0;
195 col_format_flag = 0;
196 flag_C_specified = 0;
197 flag_solver_specified = 0;
198 flag_find_C = 0;
199 bias = -1;
200
201
202 if(nrhs <= 1)
203 return 1;
204
205 if(nrhs == 4)
206 {
207 mxGetString(prhs[3], cmd, mxGetN(prhs[3])+1);
208 if(strcmp(cmd, "col") == 0)
209 col_format_flag = 1;
210 }
211
212 // put options in argv[]
213 if(nrhs > 2)
214 {
215 mxGetString(prhs[2], cmd, mxGetN(prhs[2]) + 1);
216 if((argv[argc] = strtok(cmd, " ")) != NULL)
217 while((argv[++argc] = strtok(NULL, " ")) != NULL)
218 ;
219 }
220
221 // parse options
222 for(i=1;i<argc;i++)
223 {
224 if(argv[i][0] != '-') break;
225 ++i;
226 if(i>=argc && argv[i-1][1] != 'q' && argv[i-1][1] != 'C') // since options -q and -C have no parameter
227 return 1;
228 switch(argv[i-1][1])
229 {
230 case 's':
231 param.solver_type = atoi(argv[i]);
232 flag_solver_specified = 1;
233 break;
234 case 'c':
235 param.C = atof(argv[i]);
236 flag_C_specified = 1;
237 break;
238 case 'p':
239 param.p = atof(argv[i]);
240 break;
241 case 'e':
242 param.eps = atof(argv[i]);
243 break;
244 case 'B':
245 bias = atof(argv[i]);
246 break;
247 case 'v':
248 flag_cross_validation = 1;
249 nr_fold = atoi(argv[i]);
250 if(nr_fold < 2)
251 {
252 mexPrintf("n-fold cross validation: n must >= 2\n");
253 return 1;
254 }
255 break;
256 case 'w':
257 ++param.nr_weight;
258 param.weight_label = (int *) realloc(param.weight_label,sizeof(int)*param.nr_weight);
259 param.weight = (double *) realloc(param.weight,sizeof(double)*param.nr_weight);
260 param.weight_label[param.nr_weight-1] = atoi(&argv[i-1][2]);
261 param.weight[param.nr_weight-1] = atof(argv[i]);
262 break;
263 case 'q':
264 print_func = &print_null;
265 i--;
266 break;
267 case 'C':
268 flag_find_C = 1;
269 i--;
270 break;
271 default:
272 mexPrintf("unknown option\n");
273 return 1;
274 }
275 }
276
277 set_print_string_function(print_func);
278
279 // default solver for parameter selection is L2R_L2LOSS_SVC
280 if(flag_find_C)
281 {
282 if(!flag_cross_validation)
283 nr_fold = 5;
284 if(!flag_solver_specified)
285 {
286 mexPrintf("Solver not specified. Using -s 2\n");
287 param.solver_type = L2R_L2LOSS_SVC;
288 }
289 else if(param.solver_type != L2R_LR && param.solver_type != L2R_L2LOSS_SVC)
290 {
291 mexPrintf("Warm-start parameter search only available for -s 0 and -s 2\n");
292 return 1;
293 }
294 }
295
296 if(param.eps == INF)
297 {
298 switch(param.solver_type)
299 {
300 case L2R_LR:
301 case L2R_L2LOSS_SVC:
302 param.eps = 0.01;
303 break;
304 case L2R_L2LOSS_SVR:
305 param.eps = 0.001;
306 break;
307 case L2R_L2LOSS_SVC_DUAL:
308 case L2R_L1LOSS_SVC_DUAL:
309 case MCSVM_CS:
310 case L2R_LR_DUAL:
311 param.eps = 0.1;
312 break;
313 case L1R_L2LOSS_SVC:
314 case L1R_LR:
315 param.eps = 0.01;
316 break;
317 case L2R_L1LOSS_SVR_DUAL:
318 case L2R_L2LOSS_SVR_DUAL:
319 param.eps = 0.1;
320 break;
321 }
322 }
323 return 0;
324 }
325
/* Fill every requested output with an empty matrix so MATLAB always
 * receives well-defined return values on error paths. */
static void fake_answer(int nlhs, mxArray *plhs[])
{
	int k;
	for(k = 0; k < nlhs; k++)
		plhs[k] = mxCreateDoubleMatrix(0, 0, mxREAL);
}
332
/* Convert a MATLAB sparse instance matrix + dense label vector into the
 * liblinear `prob`/`x_space` globals.
 * The matrix is consumed in column-per-instance form: if the caller did
 * not pass 'col', it is transposed first via mexCallMATLAB.
 * Feature indices written to x_space are 1-based; each instance is
 * terminated by an index of -1, with an optional bias feature before it.
 * Returns 0 on success, -1 on error (mismatched lengths or failed
 * transpose). */
int read_problem_sparse(const mxArray *label_vec, const mxArray *instance_mat)
{
	mwIndex *ir, *jc, low, high, k;
	// using size_t due to the output type of matlab functions
	size_t i, j, l, elements, max_index, label_vector_row_num;
	mwSize num_samples;
	double *samples, *labels;
	mxArray *instance_mat_col; // instance sparse matrix in column format

	prob.x = NULL;
	prob.y = NULL;
	x_space = NULL;

	if(col_format_flag)
		instance_mat_col = (mxArray *)instance_mat;
	else
	{
		// transpose instance matrix
		mxArray *prhs[1], *plhs[1];
		prhs[0] = mxDuplicateArray(instance_mat);
		if(mexCallMATLAB(1, plhs, 1, prhs, "transpose"))
		{
			mexPrintf("Error: cannot transpose training instance matrix\n");
			return -1;
		}
		instance_mat_col = plhs[0];
		mxDestroyArray(prhs[0]);
		// NOTE(review): instance_mat_col (plhs[0]) is never destroyed on
		// this path — looks like a transient leak per call; confirm
		// against MATLAB's MEX memory management before changing.
	}

	// the number of instance
	l = mxGetN(instance_mat_col);
	label_vector_row_num = mxGetM(label_vec);
	prob.l = (int) l;

	if(label_vector_row_num!=l)
	{
		mexPrintf("Length of label vector does not match # of instances.\n");
		return -1;
	}

	// each column is one instance
	labels = mxGetPr(label_vec);
	samples = mxGetPr(instance_mat_col);
	ir = mxGetIr(instance_mat_col);	// row indices of nonzeros (CSC)
	jc = mxGetJc(instance_mat_col);	// per-column start offsets into ir/samples

	num_samples = mxGetNzmax(instance_mat_col);

	// worst case: every nonzero, plus a bias slot and a -1 terminator per instance
	elements = num_samples + l*2;
	max_index = mxGetM(instance_mat_col);

	// NOTE(review): these Malloc results are not checked; on allocation
	// failure the loop below would dereference NULL.
	prob.y = Malloc(double, l);
	prob.x = Malloc(struct feature_node*, l);
	x_space = Malloc(struct feature_node, elements);

	prob.bias=bias;

	// j walks x_space; each instance i gets its nonzeros, the optional
	// bias feature, then a -1 index terminator
	j = 0;
	for(i=0;i<l;i++)
	{
		prob.x[i] = &x_space[j];
		prob.y[i] = labels[i];
		low = jc[i], high = jc[i+1];
		for(k=low;k<high;k++)
		{
			x_space[j].index = (int) ir[k]+1;	// liblinear indices are 1-based
			x_space[j].value = samples[k];
			j++;
		}
		if(prob.bias>=0)
		{
			x_space[j].index = (int) max_index+1;	// bias occupies the extra dimension
			x_space[j].value = prob.bias;
			j++;
		}
		x_space[j++].index = -1;	// instance terminator
	}

	// feature count includes the bias dimension when enabled
	if(prob.bias>=0)
		prob.n = (int) max_index+1;
	else
		prob.n = (int) max_index;

	return 0;
}
418
419 // Interface function of matlab
420 // now assume prhs[0]: label prhs[1]: features
mexFunction(int nlhs,mxArray * plhs[],int nrhs,const mxArray * prhs[])421 void mexFunction( int nlhs, mxArray *plhs[],
422 int nrhs, const mxArray *prhs[] )
423 {
424 const char *error_msg;
425 // fix random seed to have same results for each run
426 // (for cross validation)
427 srand(1);
428
429 if(nlhs > 1)
430 {
431 exit_with_help();
432 fake_answer(nlhs, plhs);
433 return;
434 }
435
436 // Transform the input Matrix to libsvm format
437 if(nrhs > 1 && nrhs < 5)
438 {
439 int err=0;
440
441 if(!mxIsDouble(prhs[0]) || !mxIsDouble(prhs[1]))
442 {
443 mexPrintf("Error: label vector and instance matrix must be double\n");
444 fake_answer(nlhs, plhs);
445 return;
446 }
447
448 if(mxIsSparse(prhs[0]))
449 {
450 mexPrintf("Error: label vector should not be in sparse format");
451 fake_answer(nlhs, plhs);
452 return;
453 }
454
455 if(parse_command_line(nrhs, prhs, NULL))
456 {
457 exit_with_help();
458 destroy_param(¶m);
459 fake_answer(nlhs, plhs);
460 return;
461 }
462
463 if(mxIsSparse(prhs[1]))
464 err = read_problem_sparse(prhs[0], prhs[1]);
465 else
466 {
467 mexPrintf("Training_instance_matrix must be sparse; "
468 "use sparse(Training_instance_matrix) first\n");
469 destroy_param(¶m);
470 fake_answer(nlhs, plhs);
471 return;
472 }
473
474 // train's original code
475 error_msg = check_parameter(&prob, ¶m);
476
477 if(err || error_msg)
478 {
479 if (error_msg != NULL)
480 mexPrintf("Error: %s\n", error_msg);
481 destroy_param(¶m);
482 free(prob.y);
483 free(prob.x);
484 free(x_space);
485 fake_answer(nlhs, plhs);
486 return;
487 }
488
489 if (flag_find_C)
490 {
491 double best_C, best_rate, *ptr;
492
493 do_find_parameter_C(&best_C, &best_rate);
494
495 plhs[0] = mxCreateDoubleMatrix(2, 1, mxREAL);
496 ptr = mxGetPr(plhs[0]);
497 ptr[0] = best_C;
498 ptr[1] = best_rate;
499 }
500 else if(flag_cross_validation)
501 {
502 double *ptr;
503 plhs[0] = mxCreateDoubleMatrix(1, 1, mxREAL);
504 ptr = mxGetPr(plhs[0]);
505 ptr[0] = do_cross_validation();
506 }
507 else
508 {
509 const char *error_msg;
510
511 model_ = train(&prob, ¶m);
512 error_msg = model_to_matlab_structure(plhs, model_);
513 if(error_msg)
514 mexPrintf("Error: can't convert libsvm model to matrix structure: %s\n", error_msg);
515 free_and_destroy_model(&model_);
516 }
517 destroy_param(¶m);
518 free(prob.y);
519 free(prob.x);
520 free(x_space);
521 }
522 else
523 {
524 exit_with_help();
525 fake_answer(nlhs, plhs);
526 return;
527 }
528 }
529