1 /*======================================================================================
2 * svmocas_mex.c: Matlab MEX interface to OCAS solver for training two-class
3 * linear SVM classifier
4 *
5 * Synopsis:
6 * [W,W0,stat] = svmocas(X,X0,y,C,Method,TolRel,TolAbs,QPBound,BufSize,
7 * nData,MaxTime,verb)
8 *
9 * See svmocas.m for more help.
10 *
11 * Copyright (C) 2008, 2009 Vojtech Franc, xfrancv@cmp.felk.cvut.cz
12 * Soeren Sonnenburg, soeren.sonnenburg@first.fraunhofer.de
13 *
14 * This program is free software; you can redistribute it and/or
15 * modify it under the terms of the GNU General Public
16 * License as published by the Free Software Foundation;
17 *=====================================================================================*/
18
19 #include <stdio.h>
20 #include <string.h>
21 #include <stdint.h>
22 #include <mex.h>
23
24 #include "libocas.h"
25 #include "ocas_helper.h"
26 #include "features_int8.h"
27 #include "features_double.h"
28 #include "features_single.h"
29
/* Values used for optional input arguments the caller leaves out
   (see the argument parsing in mexFunction below). */
enum {
  DEFAULT_METHOD  = 1,    /* solver selector: 0 = BMRM, 1 = OCAS        */
  DEFAULT_BUFSIZE = 2000, /* capacity of the cutting-plane buffer       */
  DEFAULT_VERB    = 1     /* non-zero: print settings and progress info */
};

#define DEFAULT_TOLREL  0.01        /* relative tolerance TolRel  */
#define DEFAULT_TOLABS  0.0         /* absolute tolerance TolAbs  */
#define DEFAULT_QPVALUE 0.0         /* bound QPBound on Q_P       */
#define DEFAULT_MAXTIME mxGetInf()  /* MaxTime: no time limit     */
37
38 /*======================================================================
39 Main code plus interface to Matlab.
40 ========================================================================*/
41
/*----------------------------------------------------------------------
 Multiply (restore == 0) or divide (restore != 0), in place, column i of
 the dense column-major feature matrix X by y[i], for i = 0..num_cols-1.
 This implements X(:,i) = X(:,i)*y(i) before training and its inverse
 afterwards. X must be a dense double, single or int8 array; any other
 class is left untouched.

 int8 data can only be scaled by an integer, so y[i] is truncated to
 int8_t in both directions (exact for +1/-1 labels). Multiplying, rather
 than dividing, means a factor that truncates to 0 can no longer trigger
 an integer division by zero. Such a column cannot be restored and is
 skipped.
----------------------------------------------------------------------*/
static void scale_dense_columns(mxArray *X, size_t num_rows, size_t num_cols,
                                const double *y, int restore)
{
  size_t i, j;

  if( mxIsDouble(X) )
  {
    double *col = mxGetPr(X);
    for(i = 0; i < num_cols; i++, col += num_rows)
      for(j = 0; j < num_rows; j++)
        col[j] = restore ? col[j]/y[i] : col[j]*y[i];
  }
  else if( mxIsSingle(X) )
  {
    /* NOTE(review): mxGetData is the class-agnostic accessor; mxGetPr is
       kept to match the rest of this file. */
    float *col = (float*)mxGetPr(X);
    for(i = 0; i < num_cols; i++, col += num_rows)
      for(j = 0; j < num_rows; j++)
        col[j] = (float)(restore ? col[j]/y[i] : col[j]*y[i]);
  }
  else if( mxIsInt8(X) )
  {
    int8_t *col = (int8_t*)mxGetPr(X);
    for(i = 0; i < num_cols; i++, col += num_rows)
    {
      int factor = (int8_t)y[i];
      if(restore && factor == 0)
        continue;  /* a zeroed column cannot be recovered */
      for(j = 0; j < num_rows; j++)
        col[j] = (int8_t)(restore ? col[j]/factor : col[j]*factor);
    }
  }
}

/*======================================================================
 MATLAB entry point:
   [W,W0,stat] = svmocas(X,X0,y,C,Method,TolRel,TolAbs,QPBound,BufSize,
                         nData,MaxTime,verb)

 Only the first four inputs are mandatory; the rest fall back to the
 DEFAULT_* values. The training data are passed to the solver callbacks
 through the file-level globals (data_X, data_y, nDim, nData, ...).
 X and y are scaled in place before the solver runs and restored
 afterwards. All buffers are therefore allocated, and all arguments
 validated, BEFORE the first modification, so an error exit never leaves
 the caller's data scaled.
 NOTE(review): MATLAB documents prhs as read-only; in-place modification
 may affect other variables sharing the same data.
========================================================================*/
void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[] )
{
  double C = 0.0, TolRel, TolAbs, QPBound, MaxTime;
  double *vec_C = NULL;
  uint32_t num_of_Cs;
  uint32_t i, BufSize;
  uint16_t Method;
  int verb;
  int is_sparse;
  ocas_return_value_T ocas;
  void (*print_function)(ocas_return_value_T);

  /* timing variables */
  double init_time;
  double total_time;

  total_time = get_time();
  init_time = total_time;

  if(nrhs < 4 || nrhs > 12)
    mexErrMsgTxt("Improper number of input arguments.\n\n"
                 "SVMOCAS solver for training two-class linear SVM classifiers\n\n"
                 "Synopsis:\n"
                 " [W,W0,stat] = svmocas(X,X0,y,C,Method,TolRel,TolAbs,QPBound,"
                 "BufSize,nExamples,MaxTime,verb) \n\n"
                 "Input: \n"
                 " X [nDim x nExamples] training inputs (dense double or sparse double or dense single or dense int8)\n"
                 " X0 [1 x 1 (double)] constant feature added to all examples\n"
                 " y [nExamples x 1 (double)] labels of the examples (+1/-1)\n"
                 " C [1x1] or [nExamples x 1] regularization constant(s) \n"
                 " Method [1x1 (double)] 0 for BMRM; 1 for OCAS \n"
                 " TolRel [1x1 (double)]\n"
                 " TolAbs [1x1 (double)]\n"
                 " QPBound [1x1 (double)]\n"
                 " BufSize [1x1 (double)]\n"
                 " nExamples [1x1 (double)] number of examples to use; "
                 "(inf means use all examples)\n"
                 " MaxTime [1x1 (double)]\n"
                 " verb [1x1 (double)]\n\n"
                 "Output:\n"
                 " W [nDim x 1] Parameter vector\n"
                 " W0 [1x1] Bias term\n"
                 " stat [struct] \n");

  /* plhs is only guaranteed to hold max(nlhs,1) entries, so outputs 2 and 3
     are created below only when requested. */
  if(nlhs > 3)
    mexErrMsgTxt("Too many output arguments.");

  verb = (nrhs >= 12) ? (int)mxGetScalar(prhs[11]) : DEFAULT_VERB;

  /* 1st input argument: training feature vectors */
  data_X = (mxArray*)prhs[0];
  if( mxGetNumberOfDimensions(data_X) != 2 ||
      !( mxIsDouble(data_X) ||
         ( !mxIsSparse(data_X) && ( mxIsSingle(data_X) || mxIsInt8(data_X) ) ) ) )
  {
    mexErrMsgTxt("The first input argument must be two dimensional matrix of the following type:\n"
                 "dense double or sparse double or dense single or dense int8 matrix.\n");
  }
  is_sparse = mxIsSparse(data_X) ? 1 : 0;

  /* 2nd input argument: constant coordinate added to feature vectors */
  X0 = (double)mxGetScalar(prhs[1]);

  /* 3rd input argument: vector of labels */
  if( !mxIsDouble(prhs[2]) || mxIsSparse(prhs[2]) )
    mexErrMsgTxt("The third input argument must be dense vector of doubles.");
  data_y = (double*)mxGetPr(prhs[2]);

  if(LIBOCAS_MAX(mxGetM(prhs[2]),mxGetN(prhs[2])) != mxGetN(prhs[0]))
    mexErrMsgTxt("Length of vector y (3rd input argument) must equal to the number of columns of matrix X (1st input argument).");

  nDim = mxGetM(prhs[0]);

  if(verb)
  {
    /* mwSize may be 64-bit: print through unsigned long, not %d */
    mexPrintf("Input data statistics:\n"
              " # of examples : %lu\n"
              " dimensionality : %lu\n",
              (unsigned long)mxGetN(data_X), (unsigned long)nDim);

    if( is_sparse )
      mexPrintf(" sparse features (density=%.2f%%) ",
                100.0*(double)mxGetNzmax(data_X)/((double)nDim*(double)(mxGetN(data_X))));
    else
      mexPrintf(" dense features ");
    if( mxIsDouble(data_X) )
      mexPrintf("in double precision\n");
    if( mxIsSingle(data_X) )
      mexPrintf("in single precision\n");
    if( mxIsInt8(data_X) )
      mexPrintf("represented as int8\n");
  }

  /* 4th input argument: one C for all examples, or one C per example */
  num_of_Cs = (uint32_t)LIBOCAS_MAX(mxGetN(prhs[3]),mxGetM(prhs[3]));
  if(num_of_Cs == 0)
    mexErrMsgTxt("The fourth input argument C must not be empty.");
  if(num_of_Cs == 1)
  {
    C = (double)mxGetScalar(prhs[3]);
  }
  else
  {
    if( !mxIsDouble(prhs[3]) || mxIsSparse(prhs[3]) )
      mexErrMsgTxt("Vector C (4th input argument) must be a dense vector of doubles.");
    vec_C = (double*)mxGetPr(prhs[3]);
  }

  /* Integer arguments are range-checked as doubles BEFORE the cast:
     converting a negative/NaN/out-of-range double to an unsigned type is UB. */
  if(nrhs >= 5)
  {
    double method_arg = mxGetScalar(prhs[4]);
    if( !(method_arg >= 0.0 && method_arg < (double)UINT16_MAX + 1.0) )
      mexErrMsgTxt("Improper value of argument Method.");
    Method = (uint16_t)method_arg;
  }
  else
    Method = DEFAULT_METHOD;

  TolRel  = (nrhs >= 6) ? (double)mxGetScalar(prhs[5]) : DEFAULT_TOLREL;
  TolAbs  = (nrhs >= 7) ? (double)mxGetScalar(prhs[6]) : DEFAULT_TOLABS;
  QPBound = (nrhs >= 8) ? (double)mxGetScalar(prhs[7]) : DEFAULT_QPVALUE;

  if(nrhs >= 9)
  {
    double bufsize_arg = mxGetScalar(prhs[8]);
    if( !(bufsize_arg >= 1.0 && bufsize_arg < (double)UINT32_MAX + 1.0) )
      mexErrMsgTxt("Improper value of argument BufSize.");
    BufSize = (uint32_t)bufsize_arg;
  }
  else
    BufSize = DEFAULT_BUFSIZE;

  if(nrhs >= 10 && !mxIsInf(mxGetScalar(prhs[9])))
  {
    /* valid iff 1 <= trunc(value) <= N, i.e. 1 <= value < N+1 (also rejects NaN) */
    double ndata_arg = mxGetScalar(prhs[9]);
    if( !(ndata_arg >= 1.0 && ndata_arg < (double)mxGetN(data_X) + 1.0) )
      mexErrMsgTxt("Improper value of argument nData.");
    nData = (uint32_t)ndata_arg;
  }
  else
    nData = mxGetN(data_X);

  if(nData < 1 || nData > mxGetN(prhs[0]))
    mexErrMsgTxt("Improper value of argument nData.");

  if(num_of_Cs > 1)
  {
    if(num_of_Cs < nData)
      mexErrMsgTxt("Length of the vector C less than the number of examples.");
    /* y(i)*C(i) is divided out again after training: it must not be 0 or inf */
    for(i=0; i < nData; i++)
      if( mxIsInf(vec_C[i]) || !(vec_C[i] > 0.0) )
        mexErrMsgTxt("Elements of vector C must be positive finite numbers.");
  }

  /* X(:,i) is divided by y(i) after training: reject zero/NaN/inf labels
     instead of silently turning the caller's data into NaN. */
  for(i=0; i < nData; i++)
    if( mxIsInf(data_y[i]) || !(data_y[i] < 0.0 || data_y[i] > 0.0) )
      mexErrMsgTxt("Labels y must be nonzero finite numbers (+1/-1).");

  MaxTime = (nrhs >= 11) ? (double)mxGetScalar(prhs[10]) : DEFAULT_MAXTIME;

  /*----------------------------------------------------------------
    Print setting
  -------------------------------------------------------------------*/
  if(verb)
  {
    mexPrintf("Setting:\n");

    if( num_of_Cs == 1)
      mexPrintf(" C : %f\n", C);
    else
      mexPrintf(" C : different for each example\n");

    mexPrintf(" bias : %.0f\n"
              " # of examples : %lu\n"
              " solver : %d\n"
              " cache size : %lu\n"
              " TolAbs : %f\n"
              " TolRel : %f\n"
              " QPValue : %f\n"
              " MaxTime : %f [s]\n"
              " verb : %d\n",
              X0, (unsigned long)nData, (int)Method, (unsigned long)BufSize,
              TolAbs, TolRel, QPBound, MaxTime, verb);
  }

  /*----------------------------------------------------------------
    Allocate every buffer before X and y are modified in place.
  -------------------------------------------------------------------*/
  /* learned weight vector */
  plhs[0] = (mxArray*)mxCreateDoubleMatrix(nDim,1,mxREAL);
  W = (double*)mxGetPr(plhs[0]);
  if(W == NULL) mexErrMsgTxt("Not enough memory for vector W.");

  oldW = (double*)mxCalloc(nDim,sizeof(double));
  if(oldW == NULL) mexErrMsgTxt("Not enough memory for vector oldW.");

  W0 = 0;
  oldW0 = 0;

  A0 = mxCalloc(BufSize,sizeof(A0[0]));
  if(A0 == NULL) mexErrMsgTxt("Not enough memory for vector A0.");

  /* allocate buffer for computing cutting plane */
  new_a = (double*)mxCalloc(nDim,sizeof(double));
  if(new_a == NULL)
    mexErrMsgTxt("Not enough memory for auxiliary cutting plane buffer new_a.");

  /* init cutting plane buffer */
  if( is_sparse )
  {
    sparse_A.nz_dims = mxCalloc(BufSize,sizeof(uint32_t));
    sparse_A.index = mxCalloc(BufSize,sizeof(sparse_A.index[0]));
    sparse_A.value = mxCalloc(BufSize,sizeof(sparse_A.value[0]));
    if(sparse_A.nz_dims == NULL || sparse_A.index == NULL || sparse_A.value == NULL)
      mexErrMsgTxt("Not enough memory for cutting plane buffer sparse_A.");
  }
  else
  {
    /* size_t product: BufSize*nDim can exceed 32 bits for large problems */
    full_A = mxCalloc((size_t)BufSize*(size_t)nDim,sizeof(double));
    if( full_A == NULL )
      mexErrMsgTxt("Not enough memory for cutting plane buffer full_A.");
  }

  /*----------------------------------------------------------------
    Scale data in place and run the solver.
  -------------------------------------------------------------------*/
  /* per-example C: fold C(i) into the label, y(i) <- y(i)*C(i) */
  if(num_of_Cs > 1)
  {
    for(i=0; i < nData; i++)
      data_y[i] = data_y[i]*vec_C[i];
  }

  /* select function to print progress info */
  if(verb)
  {
    mexPrintf("Starting optimization:\n");
    print_function = &ocas_print;
  }
  else
  {
    print_function = &ocas_print_null;
  }

  if( is_sparse )
  {
    /* for i=1:nData, X(:,i) = X(:,i)*y(i); end*/
    for(i=0; i < nData; i++)
      mul_sparse_col(data_y[i], data_X, i);

    init_time=get_time()-init_time;

    if(num_of_Cs == 1)
      ocas = svm_ocas_solver( C, nData, TolRel, TolAbs, QPBound, MaxTime, BufSize, Method,
                              &sparse_compute_W, &update_W, &sparse_add_new_cut,
                              &sparse_compute_output, &qsort_data,
                              print_function, 0);
    else
      ocas = svm_ocas_solver_difC( vec_C, nData, TolRel, TolAbs, QPBound,
                                   MaxTime, BufSize, Method,
                                   &sparse_compute_W, &update_W,
                                   &sparse_add_new_cut, &sparse_compute_output,
                                   &qsort_data, print_function, 0);
  }
  else
  {
    int (*add_new_cut)(double*, uint32_t*, uint32_t, uint32_t, void*);
    int (*compute_output)( double*, void* );

    /* for i=1:nData, X(:,i) = X(:,i)*y(i); end*/
    scale_dense_columns(data_X, nDim, nData, data_y, 0);

    if( mxIsDouble(data_X) )
    {
      add_new_cut = &full_add_new_cut;
      compute_output = &full_compute_output;
    }
    else if( mxIsSingle(data_X) )
    {
      add_new_cut = &full_single_add_new_cut;
      compute_output = &full_single_compute_output;
    }
    else /* int8, guaranteed by the input check above */
    {
      add_new_cut = &full_int8_add_new_cut;
      compute_output = &full_int8_compute_output;
    }

    init_time=get_time()-init_time;

    if(num_of_Cs == 1)
      ocas = svm_ocas_solver( C, nData, TolRel, TolAbs, QPBound, MaxTime, BufSize, Method,
                              &full_compute_W, &update_W, add_new_cut,
                              compute_output, &qsort_data, print_function, 0);
    else
      ocas = svm_ocas_solver_difC( vec_C, nData, TolRel, TolAbs, QPBound, MaxTime,
                                   BufSize, Method,
                                   &full_compute_W, &update_W, add_new_cut,
                                   compute_output, &qsort_data, print_function, 0);
  }

  total_time=get_time()-total_time;

  if(verb)
  {
    mexPrintf("Stopping condition: ");
    switch( ocas.exitflag )
    {
      case 1: mexPrintf("1-Q_D/Q_P <= TolRel(=%f) satisfied.\n", TolRel); break;
      case 2: mexPrintf("Q_P-Q_D <= TolAbs(=%f) satisfied.\n", TolAbs); break;
      case 3: mexPrintf("Q_P <= QPBound(=%f) satisfied.\n", QPBound); break;
      case 4: mexPrintf("Optimization time (=%f) >= MaxTime(=%f).\n",
                        ocas.ocas_time, MaxTime); break;
      case -1: mexPrintf("Has not converged!\n" ); break;
      case -2: mexPrintf("Not enough memory for the solver.\n" ); break;
      default: mexPrintf("unknown (exitflag=%d).\n", (int)ocas.exitflag); break;
    }

    mexPrintf("Timing statistics:\n"
              " init_time : %f[s]\n"
              " qp_solver_time : %f[s]\n"
              " sort_time : %f[s]\n"
              " output_time : %f[s]\n"
              " add_time : %f[s]\n"
              " w_time : %f[s]\n"
              " print_time : %f[s]\n"
              " ocas_time : %f[s]\n"
              " total_time : %f[s]\n",
              init_time, ocas.qp_solver_time, ocas.sort_time, ocas.output_time,
              ocas.add_time, ocas.w_time, ocas.print_time, ocas.ocas_time, total_time);

    mexPrintf("Training error: %.4f%%\n", 100*(double)ocas.trn_err/(double)nData);
  }

  /* undo the scaling so X is returned to the caller as it was given */
  if( is_sparse )
  {
    /* for i=1:nData, X(:,i) = X(:,i)/y(i); end*/
    for(i=0; i < nData; i++)
      mul_sparse_col(1/data_y[i], data_X, i);
  }
  else
  {
    scale_dense_columns(data_X, nDim, nData, data_y, 1);
  }

  /* divide labels by Cs as it was at the beginning */
  if(num_of_Cs > 1)
  {
    for(i=0; i < nData; i++)
      data_y[i] = data_y[i]/vec_C[i];
  }

  /* create output variables (only those the caller asked for) */
  if(nlhs >= 2)
    plhs[1] = mxCreateDoubleScalar( W0 );

  if(nlhs >= 3)
  {
    const char *field_names[] = {"nTrnErrors","Q_P","Q_D","nIter","nCutPlanes","exitflag",
                                 "init_time","output_time","sort_time",
                                 "qp_solver_time","add_time","w_time","print_time",
                                 "ocas_time","total_time"};
    mwSize dims[2] = {1,1};

    plhs[2] = mxCreateStructArray(2, dims, (int)(sizeof(field_names)/sizeof(*field_names)),
                                  field_names);

    mxSetField(plhs[2],0,"nIter",mxCreateDoubleScalar((double)ocas.nIter));
    mxSetField(plhs[2],0,"nCutPlanes",mxCreateDoubleScalar((double)ocas.nCutPlanes));
    mxSetField(plhs[2],0,"nTrnErrors",mxCreateDoubleScalar(ocas.trn_err));
    mxSetField(plhs[2],0,"Q_P",mxCreateDoubleScalar(ocas.Q_P));
    mxSetField(plhs[2],0,"Q_D",mxCreateDoubleScalar(ocas.Q_D));
    mxSetField(plhs[2],0,"init_time",mxCreateDoubleScalar(init_time));
    mxSetField(plhs[2],0,"output_time",mxCreateDoubleScalar(ocas.output_time));
    mxSetField(plhs[2],0,"sort_time",mxCreateDoubleScalar(ocas.sort_time));
    mxSetField(plhs[2],0,"qp_solver_time",mxCreateDoubleScalar(ocas.qp_solver_time));
    mxSetField(plhs[2],0,"add_time",mxCreateDoubleScalar(ocas.add_time));
    mxSetField(plhs[2],0,"w_time",mxCreateDoubleScalar(ocas.w_time));
    mxSetField(plhs[2],0,"print_time",mxCreateDoubleScalar(ocas.print_time));
    mxSetField(plhs[2],0,"ocas_time",mxCreateDoubleScalar(ocas.ocas_time));
    mxSetField(plhs[2],0,"total_time",mxCreateDoubleScalar(total_time));
    mxSetField(plhs[2],0,"exitflag",mxCreateDoubleScalar((double)ocas.exitflag));
  }

  return;
}
471
472