#!/usr/bin/env python
"""Low-level ctypes interface to the LIBSVM shared library.

The module loads the native library at import time, mirrors the C structures
(``svm_node``, ``svm_problem``, ``svm_parameter``, ``svm_model``) as
``ctypes.Structure`` subclasses and declares the prototypes of the exported
C functions.
"""

from ctypes import *
from ctypes.util import find_library
from os import path
import sys

# NumPy/SciPy support is optional.  NumPy names are taken from numpy itself:
# the aliases in the scipy root namespace (scipy.ndarray, scipy.where, ...)
# are deprecated and not available in recent SciPy releases.
try:
    import numpy as np
    import scipy
    from scipy import sparse
except Exception:
    np = None
    scipy = None
    sparse = None

if sys.version_info[0] < 3:
    range = xrange
    from itertools import izip as zip

__all__ = ['libsvm', 'svm_problem', 'svm_parameter',
           'toPyModel', 'gen_svm_nodearray', 'print_null', 'svm_node', 'C_SVC',
           'EPSILON_SVR', 'LINEAR', 'NU_SVC', 'NU_SVR', 'ONE_CLASS',
           'POLY', 'PRECOMPUTED', 'PRINT_STRING_FUN', 'RBF',
           'SIGMOID', 'c_double', 'svm_model']

# Prefer the library shipped next to this file; fall back to a system-wide
# installation located through find_library().
try:
    dirname = path.dirname(path.abspath(__file__))
    if sys.platform == 'win32':
        libsvm = CDLL(path.join(dirname, r'..\windows\libsvm.dll'))
    else:
        libsvm = CDLL(path.join(dirname, '../libsvm.so.2'))
except Exception:
    # For unix the prefix 'lib' is not considered.
    _lib_path = find_library('svm') or find_library('libsvm')
    if _lib_path:
        libsvm = CDLL(_lib_path)
    else:
        raise Exception('LIBSVM library not found.')

# Values accepted by svm_parameter.svm_type.
C_SVC = 0
NU_SVC = 1
ONE_CLASS = 2
EPSILON_SVR = 3
NU_SVR = 4

# Values accepted by svm_parameter.kernel_type.
LINEAR = 0
POLY = 1
RBF = 2
SIGMOID = 3
PRECOMPUTED = 4

# C callback type ``void (*)(const char *)`` used for the print function.
PRINT_STRING_FUN = CFUNCTYPE(None, c_char_p)


def print_null(s):
    """Print callback that discards its argument (installed by ``-q``)."""
    return


def genFields(names, types):
    """Pair field names with ctypes types into a ``_fields_`` list."""
    return list(zip(names, types))


def fillprototype(f, restype, argtypes):
    """Set the return type and argument types of the foreign function *f*."""
    f.restype = restype
    f.argtypes = argtypes


class svm_node(Structure):
    """One ``index:value`` entry of a sparse feature vector."""

    _names = ["index", "value"]
    _types = [c_int, c_double]
    _fields_ = genFields(_names, _types)

    def __init__(self, index=-1, value=0):
        self.index, self.value = index, value

    def __str__(self):
        return '%d:%g' % (self.index, self.value)


def gen_svm_nodearray(xi, feature_max=None, isKernel=False):
    """Convert one instance into a ctypes array of ``svm_node``.

    Parameters:
        xi: a dict ``{index: value}``, a list/tuple of values, a 1-d numpy
            array, or a tuple ``(index_array, data_array)`` of numpy arrays
            describing a sparse vector.
        feature_max: if given (and non-zero), entries whose index exceeds it
            are dropped.
        isKernel: if true, positional inputs are indexed from 0 and zero
            values are kept (precomputed kernel); otherwise indices start
            at 1 and zero values are skipped.

    Returns:
        ``(nodes, max_idx)``: the node array, terminated by an entry with
        ``index == -1``, and the last stored index (0 if none).

    Raises:
        TypeError: if *xi* is of an unsupported type.
    """
    if feature_max:
        assert isinstance(feature_max, int)

    xi_shift = 0  # ensure correct indices of xi
    is_sparse_pair = (scipy is not None and isinstance(xi, tuple)
                      and len(xi) == 2
                      and isinstance(xi[0], np.ndarray)
                      and isinstance(xi[1], np.ndarray))
    sparse_values = None

    if is_sparse_pair:  # for a sparse vector
        if not isKernel:
            index_range = xi[0] + 1  # index starts from 1
        else:
            index_range = xi[0]  # index starts from 0 for precomputed kernel
        sparse_values = xi[1]
        if feature_max:
            # Filter indices and values with the same mask so that each kept
            # index stays paired with its own value.
            keep = index_range <= feature_max
            index_range = index_range[keep]
            sparse_values = sparse_values[keep]
    elif scipy is not None and isinstance(xi, np.ndarray):
        if not isKernel:
            xi_shift = 1
            index_range = xi.nonzero()[0] + 1  # index starts from 1
        else:
            # index starts from 0 for precomputed kernel
            index_range = np.arange(0, len(xi))
        if feature_max:
            index_range = index_range[np.where(index_range <= feature_max)]
    elif isinstance(xi, (dict, list, tuple)):
        if isinstance(xi, dict):
            index_range = xi.keys()
        elif not isKernel:
            xi_shift = 1
            index_range = range(1, len(xi) + 1)  # index starts from 1
        else:
            # index starts from 0 for precomputed kernel
            index_range = range(0, len(xi))

        index_range = sorted(
            j for j in index_range
            if (not feature_max or j <= feature_max)
            and (isKernel or xi[j - xi_shift] != 0))
    else:
        raise TypeError('xi should be a dictionary, list, tuple, 1-d numpy array, or tuple of (index, data)')

    # One extra slot for the terminating node (index == -1).
    ret = (svm_node * (len(index_range) + 1))()
    ret[-1].index = -1

    if is_sparse_pair:
        for idx, j in enumerate(index_range):
            ret[idx].index = j
            ret[idx].value = sparse_values[idx]
    else:
        for idx, j in enumerate(index_range):
            ret[idx].index = j
            ret[idx].value = xi[j - xi_shift]

    max_idx = 0
    if len(index_range) > 0:
        max_idx = index_range[-1]
    return ret, max_idx


# numba is optional; without it ``jit`` is the identity decorator.
try:
    from numba import jit
    jit_enabled = True
except Exception:
    jit = lambda x: x
    jit_enabled = False


@jit
def csr_to_problem_jit(l, x_val, x_ind, x_rowptr, prob_val, prob_ind, prob_rowptr, indx_start):
    """Copy CSR rows into the node buffers element by element.

    For each of the *l* rows, the entries ``x_rowptr[i]:x_rowptr[i+1]`` of
    ``x_ind``/``x_val`` are written to ``prob_ind``/``prob_val`` starting at
    ``prob_rowptr[i]``; indices are shifted by *indx_start*.
    """
    for i in range(l):
        b1, e1 = x_rowptr[i], x_rowptr[i+1]
        b2, e2 = prob_rowptr[i], prob_rowptr[i+1]-1
        for j in range(b1, e1):
            prob_ind[j-b1+b2] = x_ind[j]+indx_start
            prob_val[j-b1+b2] = x_val[j]


def csr_to_problem_nojit(l, x_val, x_ind, x_rowptr, prob_val, prob_ind, prob_rowptr, indx_start):
    """Slice-based equivalent of ``csr_to_problem_jit`` (no numba needed)."""
    for i in range(l):
        x_slice = slice(x_rowptr[i], x_rowptr[i+1])
        # The last slot of every destination row is left for the terminator.
        prob_slice = slice(prob_rowptr[i], prob_rowptr[i+1]-1)
        prob_ind[prob_slice] = x_ind[x_slice]+indx_start
        prob_val[prob_slice] = x_val[x_slice]


def csr_to_problem(x, prob, isKernel):
    """Fill ``prob.x_space`` and ``prob.rowptr`` from the CSR matrix *x*.

    ``prob.x_space`` becomes a numpy array of ``svm_node`` records holding
    all rows back to back, each followed by a node with ``index == -1``;
    ``prob.rowptr[i]`` is the offset of row *i* in that array.
    """
    if not x.has_sorted_indices:
        x.sort_indices()

    # Extra space for termination node and (possibly) bias term
    x_space = prob.x_space = np.empty((x.nnz+x.shape[0]), dtype=svm_node)
    prob.rowptr = x.indptr.copy()
    # Row i is preceded by i terminators, so shift its end offset by i+1.
    prob.rowptr[1:] += np.arange(1, x.shape[0]+1)
    prob_ind = x_space["index"]
    prob_val = x_space["value"]
    prob_ind[:] = -1  # slots not overwritten below stay as terminators
    if not isKernel:
        indx_start = 1  # index starts from 1
    else:
        indx_start = 0  # index starts from 0 for precomputed kernel
    if jit_enabled:
        csr_to_problem_jit(x.shape[0], x.data, x.indices, x.indptr, prob_val, prob_ind, prob.rowptr, indx_start)
    else:
        csr_to_problem_nojit(x.shape[0], x.data, x.indices, x.indptr, prob_val, prob_ind, prob.rowptr, indx_start)


class svm_problem(Structure):
    """Training data: ``l`` labels ``y`` and ``l`` instances ``x``.

    Besides the C fields, the instance keeps ``x_space`` (the Python-side
    storage backing ``x``) and ``n`` (the largest feature index seen, or
    the column count for a CSR matrix).
    """

    _names = ["l", "y", "x"]
    _types = [c_int, POINTER(c_double), POINTER(POINTER(svm_node))]
    _fields_ = genFields(_names, _types)

    def __init__(self, y, x, isKernel=False):
        """Build the problem from labels *y* and instances *x*.

        Parameters:
            y: list, tuple or numpy array of labels.
            x: list/tuple of instances (see ``gen_svm_nodearray``), a numpy
               array with one instance per row, or a scipy sparse matrix.
            isKernel: passed on to the instance conversion (0-based indices
               for a precomputed kernel).

        Raises:
            TypeError: if *y* or *x* has an unsupported type.
            ValueError: if *y* and *x* differ in length.
        """
        y_is_array = scipy is not None and isinstance(y, np.ndarray)
        if not isinstance(y, (list, tuple)) and not y_is_array:
            raise TypeError("type of y: {0} is not supported!".format(type(y)))

        if isinstance(x, (list, tuple)):
            if len(y) != len(x):
                raise ValueError("len(y) != len(x)")
        elif scipy is not None and isinstance(x, (np.ndarray, sparse.spmatrix)):
            if len(y) != x.shape[0]:
                raise ValueError("len(y) != len(x)")
            if isinstance(x, np.ndarray):
                x = np.ascontiguousarray(x)  # enforce row-major
            if isinstance(x, sparse.spmatrix):
                x = x.tocsr()
        else:
            raise TypeError("type of x: {0} is not supported!".format(type(x)))
        self.l = l = len(y)

        x_is_csr = scipy is not None and isinstance(x, sparse.csr_matrix)

        max_idx = 0
        self.x_space = []
        if x_is_csr:
            csr_to_problem(x, self, isKernel)  # replaces self.x_space
            max_idx = x.shape[1]
        else:
            for xi in x:
                tmp_xi, tmp_idx = gen_svm_nodearray(xi, isKernel=isKernel)
                self.x_space.append(tmp_xi)
                max_idx = max(max_idx, tmp_idx)
        self.n = max_idx

        self.y = (c_double * l)()
        if y_is_array:
            np.ctypeslib.as_array(self.y, (self.l,))[:] = y
        else:
            for i, yi in enumerate(y):
                self.y[i] = yi

        self.x = (POINTER(svm_node) * l)()
        if x_is_csr:
            # Write the row addresses into the pointer array in one
            # vectorised step: base address + row offset * node size.
            base = addressof(self.x_space.ctypes.data_as(POINTER(svm_node))[0])
            # NOTE(review): viewing the pointer array as c_uint64 matches the
            # pointer width only on 64-bit platforms -- confirm if 32-bit
            # builds must be supported.
            x_ptr = cast(self.x, POINTER(c_uint64))
            x_ptr = np.ctypeslib.as_array(x_ptr, (self.l,))
            x_ptr[:] = self.rowptr[:-1]*sizeof(svm_node)+base
        else:
            for i, xi in enumerate(self.x_space):
                self.x[i] = xi


class svm_parameter(Structure):
    """Training parameters, parsed from a command-line style option string."""

    _names = ["svm_type", "kernel_type", "degree", "gamma", "coef0",
              "cache_size", "eps", "C", "nr_weight", "weight_label", "weight",
              "nu", "p", "shrinking", "probability"]
    _types = [c_int, c_int, c_int, c_double, c_double,
              c_double, c_double, c_double, c_int, POINTER(c_int), POINTER(c_double),
              c_double, c_double, c_int, c_int]
    _fields_ = genFields(_names, _types)

    # Flags that take exactly one value and store it in one attribute:
    # flag -> (attribute name, converter).
    _option_fields = {
        "-s": ("svm_type", int),
        "-t": ("kernel_type", int),
        "-d": ("degree", int),
        "-g": ("gamma", float),
        "-r": ("coef0", float),
        "-n": ("nu", float),
        "-m": ("cache_size", float),
        "-c": ("C", float),
        "-e": ("eps", float),
        "-p": ("p", float),
        "-h": ("shrinking", int),
        "-b": ("probability", int),
    }

    def __init__(self, options=None):
        """Create parameters from *options* (a str or list; default none)."""
        if options is None:
            options = ''
        self.parse_options(options)

    def __str__(self):
        """Return one ``name: value`` line per C field and extra attribute."""
        attrs = svm_parameter._names + list(self.__dict__.keys())
        lines = [' %s: %s\n' % (attr, getattr(self, attr)) for attr in attrs]
        return ''.join(lines).strip()

    def set_to_default_values(self):
        """Reset every field and Python-side attribute to its default."""
        self.svm_type = C_SVC
        self.kernel_type = RBF
        self.degree = 3
        self.gamma = 0
        self.coef0 = 0
        self.nu = 0.5
        self.cache_size = 100
        self.C = 1
        self.eps = 0.001
        self.p = 0.1
        self.shrinking = 1
        self.probability = 0
        self.nr_weight = 0
        self.weight_label = None
        self.weight = None
        self.cross_validation = False
        self.nr_fold = 0
        self.print_func = cast(None, PRINT_STRING_FUN)

    def parse_options(self, options):
        """Reset to defaults, then apply *options*.

        Parameters:
            options: a list of tokens or a whitespace-separated string, e.g.
                ``'-s 0 -t 2 -c 4'``.  ``-q`` installs the silent print
                callback, ``-v n`` enables n-fold cross validation and
                ``-wLABEL value`` adds a per-label weight.

        Raises:
            TypeError: if *options* is neither a list nor a str.
            ValueError: for an unknown option or a fold count below 2.
        """
        if isinstance(options, list):
            argv = options
        elif isinstance(options, str):
            argv = options.split()
        else:
            raise TypeError("arg 1 should be a list or a str.")
        self.set_to_default_values()
        self.print_func = cast(None, PRINT_STRING_FUN)
        weight_label = []
        weight = []

        i = 0
        while i < len(argv):
            opt = argv[i]
            if opt in self._option_fields:
                attr, convert = self._option_fields[opt]
                i += 1
                setattr(self, attr, convert(argv[i]))
            elif opt == "-q":
                self.print_func = PRINT_STRING_FUN(print_null)
            elif opt == "-v":
                i += 1
                self.cross_validation = 1
                self.nr_fold = int(argv[i])
                if self.nr_fold < 2:
                    raise ValueError("n-fold cross validation: n must >= 2")
            elif opt.startswith("-w"):
                # "-w<label> <weight>": the label is the text after "-w".
                i += 1
                self.nr_weight += 1
                weight_label.append(int(opt[2:]))
                weight.append(float(argv[i]))
            else:
                raise ValueError("Wrong options")
            i += 1

        libsvm.svm_set_print_string_function(self.print_func)
        self.weight_label = (c_int*self.nr_weight)()
        self.weight = (c_double*self.nr_weight)()
        for k in range(self.nr_weight):
            self.weight[k] = weight[k]
            self.weight_label[k] = weight_label[k]


class svm_model(Structure):
    """Mirror of the C ``svm_model`` structure with accessor helpers."""

    _names = ['param', 'nr_class', 'l', 'SV', 'sv_coef', 'rho',
              'probA', 'probB', 'sv_indices', 'label', 'nSV', 'free_sv']
    _types = [svm_parameter, c_int, c_int, POINTER(POINTER(svm_node)),
              POINTER(POINTER(c_double)), POINTER(c_double),
              POINTER(c_double), POINTER(c_double), POINTER(c_int),
              POINTER(c_int), POINTER(c_int), c_int]
    _fields_ = genFields(_names, _types)

    def __init__(self):
        # Marks who owns the memory; toPyModel() switches this to 'C'.
        self.__createfrom__ = 'python'

    def __del__(self):
        # free memory created by C to avoid memory leak
        if hasattr(self, '__createfrom__') and self.__createfrom__ == 'C':
            libsvm.svm_free_and_destroy_model(pointer(pointer(self)))

    def get_svm_type(self):
        """Return the result of ``svm_get_svm_type`` for this model."""
        return libsvm.svm_get_svm_type(self)

    def get_nr_class(self):
        """Return the result of ``svm_get_nr_class`` for this model."""
        return libsvm.svm_get_nr_class(self)

    def get_svr_probability(self):
        """Return the result of ``svm_get_svr_probability`` for this model."""
        return libsvm.svm_get_svr_probability(self)

    def get_labels(self):
        """Return the labels filled in by ``svm_get_labels`` as a list."""
        nr_class = self.get_nr_class()
        labels = (c_int * nr_class)()
        libsvm.svm_get_labels(self, labels)
        return labels[:nr_class]

    def get_sv_indices(self):
        """Return the indices filled in by ``svm_get_sv_indices`` as a list."""
        total_sv = self.get_nr_sv()
        sv_indices = (c_int * total_sv)()
        libsvm.svm_get_sv_indices(self, sv_indices)
        return sv_indices[:total_sv]

    def get_nr_sv(self):
        """Return the result of ``svm_get_nr_sv`` for this model."""
        return libsvm.svm_get_nr_sv(self)

    def is_probability_model(self):
        """Return True if ``svm_check_probability_model`` reports 1."""
        return (libsvm.svm_check_probability_model(self) == 1)

    def get_sv_coef(self):
        """Return one tuple of ``nr_class - 1`` coefficients per SV."""
        return [tuple(self.sv_coef[j][i] for j in range(self.nr_class - 1))
                for i in range(self.l)]

    def get_SV(self):
        """Return the support vectors as a list of ``{index: value}`` dicts."""
        result = []
        for sparse_sv in self.SV[:self.l]:
            row = dict()

            # Walk the node array up to the terminator (index == -1).
            i = 0
            while sparse_sv[i].index != -1:
                row[sparse_sv[i].index] = sparse_sv[i].value
                i += 1

            result.append(row)
        return result


def toPyModel(model_ptr):
    """
    toPyModel(model_ptr) -> svm_model

    Convert a ctypes POINTER(svm_model) to a Python svm_model

    The returned model is marked as created by C, so its memory is released
    through ``svm_free_and_destroy_model`` when the object is deleted.
    Raises ValueError if *model_ptr* is a null pointer.
    """
    if not model_ptr:
        raise ValueError("Null pointer")
    m = model_ptr.contents
    m.__createfrom__ = 'C'
    return m


# Prototypes of the functions exported by the LIBSVM shared library.
fillprototype(libsvm.svm_train, POINTER(svm_model), [POINTER(svm_problem), POINTER(svm_parameter)])
fillprototype(libsvm.svm_cross_validation, None, [POINTER(svm_problem), POINTER(svm_parameter), c_int, POINTER(c_double)])

fillprototype(libsvm.svm_save_model, c_int, [c_char_p, POINTER(svm_model)])
fillprototype(libsvm.svm_load_model, POINTER(svm_model), [c_char_p])

fillprototype(libsvm.svm_get_svm_type, c_int, [POINTER(svm_model)])
fillprototype(libsvm.svm_get_nr_class, c_int, [POINTER(svm_model)])
fillprototype(libsvm.svm_get_labels, None, [POINTER(svm_model), POINTER(c_int)])
fillprototype(libsvm.svm_get_sv_indices, None, [POINTER(svm_model), POINTER(c_int)])
fillprototype(libsvm.svm_get_nr_sv, c_int, [POINTER(svm_model)])
fillprototype(libsvm.svm_get_svr_probability, c_double, [POINTER(svm_model)])

fillprototype(libsvm.svm_predict_values, c_double, [POINTER(svm_model), POINTER(svm_node), POINTER(c_double)])
fillprototype(libsvm.svm_predict, c_double, [POINTER(svm_model), POINTER(svm_node)])
fillprototype(libsvm.svm_predict_probability, c_double, [POINTER(svm_model), POINTER(svm_node), POINTER(c_double)])

fillprototype(libsvm.svm_free_model_content, None, [POINTER(svm_model)])
fillprototype(libsvm.svm_free_and_destroy_model, None, [POINTER(POINTER(svm_model))])
fillprototype(libsvm.svm_destroy_param, None, [POINTER(svm_parameter)])

fillprototype(libsvm.svm_check_parameter, c_char_p, [POINTER(svm_problem), POINTER(svm_parameter)])
fillprototype(libsvm.svm_check_probability_model, c_int, [POINTER(svm_model)])
fillprototype(libsvm.svm_set_print_string_function, None, [PRINT_STRING_FUN])