1#!/usr/bin/env python
2
from ctypes import *
from ctypes.util import find_library
from os import path
import sys

try:
	import numpy as np
	import scipy
	from scipy import sparse
except:
	np = None
	scipy = None
	sparse = None
14
# Python 2 only: rebind range/zip to their lazy counterparts so the rest of
# the module can use them the same way on both major versions.
if sys.version_info < (3,):
	from itertools import izip as zip
	range = xrange
18
# Names exported by a star-import of this module.
__all__ = [
	'libsvm', 'svm_problem', 'svm_parameter',
	'toPyModel', 'gen_svm_nodearray', 'print_null', 'svm_node',
	'C_SVC', 'EPSILON_SVR', 'LINEAR', 'NU_SVC', 'NU_SVR', 'ONE_CLASS',
	'POLY', 'PRECOMPUTED', 'PRINT_STRING_FUN', 'RBF',
	'SIGMOID', 'c_double', 'svm_model',
]
24
# Load the LIBSVM shared library.  A copy located relative to this file is
# tried first; if that fails for any reason, fall back to whatever the
# system library search (ctypes.util.find_library) can locate.
try:
	dirname = path.dirname(path.abspath(__file__))
	if sys.platform == 'win32':
		libsvm = CDLL(path.join(dirname, r'..\windows\libsvm.dll'))
	else:
		libsvm = CDLL(path.join(dirname, '../libsvm.so.2'))
except Exception:
	# `except Exception` (not a bare except) so Ctrl-C / SystemExit are not
	# swallowed while loading.
	# For unix the prefix 'lib' is not considered by find_library, so 'svm'
	# is tried before 'libsvm'.  Each lookup is performed only once.
	_libsvm_path = find_library('svm') or find_library('libsvm')
	if not _libsvm_path:
		raise Exception('LIBSVM library not found.')
	libsvm = CDLL(_libsvm_path)
39
# Values accepted for svm_parameter.svm_type (the "-s" option).
C_SVC, NU_SVC, ONE_CLASS, EPSILON_SVR, NU_SVR = 0, 1, 2, 3, 4

# Values accepted for svm_parameter.kernel_type (the "-t" option).
LINEAR, POLY, RBF, SIGMOID, PRECOMPUTED = 0, 1, 2, 3, 4
51
# ctypes callback type: one C string argument, no return value.
PRINT_STRING_FUN = CFUNCTYPE(None, c_char_p)

def print_null(s):
	"""Ignore the message `s`; a print callback that outputs nothing."""
	return None
55
def genFields(names, types):
	"""Return a list of (name, type) pairs, as used for a ctypes `_fields_`.

	Pairing stops at the end of the shorter of the two sequences.
	"""
	return [(field_name, field_type) for field_name, field_type in zip(names, types)]
58
def fillprototype(f, restype, argtypes):
	"""Attach a return type and an argument-type list to the foreign function `f`."""
	f.restype, f.argtypes = restype, argtypes
62
class svm_node(Structure):
	"""A single (index, value) entry of a sparse feature vector."""
	# Field names and their ctypes types, kept as parallel lists.
	_names = ["index", "value"]
	_types = [c_int, c_double]
	_fields_ = genFields(_names, _types)

	def __init__(self, index=-1, value=0):
		# index defaults to -1, the value this module uses to mark the end
		# of a node array.
		self.index = index
		self.value = value

	def __str__(self):
		# Rendered as "index:value".
		return '%d:%g' % (self.index, self.value)
73
def gen_svm_nodearray(xi, feature_max=None, isKernel=False):
	"""Convert one instance into a terminated ctypes array of svm_node.

	Parameters:
		xi: the instance, in one of these forms
			- dict mapping feature index -> value;
			- list or tuple of values: position k holds feature k+1, or
			  feature k when isKernel is true;
			- 1-d numpy array, same convention as list/tuple;
			- tuple (indices, values) of two numpy arrays describing a
			  sparse vector; indices are shifted by +1 unless isKernel is
			  true.  The indices are used in the order given (they are not
			  sorted here).
		feature_max: when given and non-zero, entries whose index is larger
			than this are dropped.  Must be an int.
		isKernel: true for a row of a precomputed kernel; indices then start
			from 0 and zero values are kept.

	Returns:
		(nodes, max_idx): `nodes` is an svm_node array holding the kept
		entries followed by one extra node whose index is -1; `max_idx` is
		the index of the last stored entry, or 0 when nothing was stored.

	Raises:
		TypeError: `xi` is not one of the supported forms.
	"""
	if feature_max:
		assert(isinstance(feature_max, int))

	# numpy support is only active when the optional scientific stack was
	# imported successfully (np is None otherwise).
	is_sparse_pair = (np is not None and isinstance(xi, tuple) and len(xi) == 2
			and isinstance(xi[0], np.ndarray) and isinstance(xi[1], np.ndarray))

	xi_shift = 0 # offset between a stored index and its position in xi
	if is_sparse_pair:
		indices, values = xi
		if not isKernel:
			index_range = indices + 1 # index starts from 1
		else:
			index_range = indices # index starts from 0 for precomputed kernel
		if feature_max:
			# Apply the same mask to the values: reading them by position
			# in the *filtered* index array would pair a kept index with the
			# wrong value whenever a dropped entry is not at the tail.
			keep = index_range <= feature_max
			index_range = index_range[keep]
			values = values[keep]
	elif np is not None and isinstance(xi, np.ndarray):
		if not isKernel:
			xi_shift = 1
			index_range = xi.nonzero()[0] + 1 # index starts from 1
		else:
			index_range = np.arange(0, len(xi)) # index starts from 0 for precomputed kernel
		if feature_max:
			index_range = index_range[index_range <= feature_max]
	elif isinstance(xi, (dict, list, tuple)):
		if isinstance(xi, dict):
			index_range = xi.keys()
		elif not isKernel:
			xi_shift = 1
			index_range = range(1, len(xi) + 1) # index starts from 1
		else:
			index_range = range(0, len(xi)) # index starts from 0 for precomputed kernel

		if feature_max:
			index_range = [j for j in index_range if j <= feature_max]
		if not isKernel:
			# zero-valued features are not stored
			index_range = [j for j in index_range if xi[j - xi_shift] != 0]

		index_range = sorted(index_range)
	else:
		raise TypeError('xi should be a dictionary, list, tuple, 1-d numpy array, or tuple of (index, data)')

	# One extra node at the end carries index -1 as terminator.
	ret = (svm_node*(len(index_range)+1))()
	ret[-1].index = -1

	if is_sparse_pair:
		for idx, j in enumerate(index_range):
			ret[idx].index = j
			ret[idx].value = values[idx]
	else:
		for idx, j in enumerate(index_range):
			ret[idx].index = j
			ret[idx].value = xi[j - xi_shift]

	max_idx = 0
	if len(index_range) > 0:
		max_idx = index_range[-1]
	return ret, max_idx
131
# numba is optional.  When it cannot be imported, `jit` becomes a decorator
# that hands the function back untouched and jit_enabled records that no
# compilation is available.
try:
	from numba import jit
	jit_enabled = True
except Exception:
	# `except Exception` rather than a bare except, so KeyboardInterrupt and
	# SystemExit raised during the import are not swallowed.
	def jit(func):
		"""Fallback decorator: return `func` unchanged."""
		return func
	jit_enabled = False
138
@jit
def csr_to_problem_jit(l, x_val, x_ind, x_rowptr, prob_val, prob_ind, prob_rowptr, indx_start):
	"""Copy l CSR rows into the node buffers, element by element.

	Row i of the source occupies x_val/x_ind[x_rowptr[i]:x_rowptr[i+1]] and
	is written to prob_val/prob_ind starting at prob_rowptr[i]; every
	column index is shifted by indx_start on the way.
	"""
	for i in range(l):
		src_begin, src_end = x_rowptr[i], x_rowptr[i+1]
		dst_begin = prob_rowptr[i]
		# The end of the destination range is implied by the source length,
		# so it is not computed.
		for j in range(src_begin, src_end):
			prob_ind[j-src_begin+dst_begin] = x_ind[j]+indx_start
			prob_val[j-src_begin+dst_begin] = x_val[j]
def csr_to_problem_nojit(l, x_val, x_ind, x_rowptr, prob_val, prob_ind, prob_rowptr, indx_start):
	"""Copy l CSR rows into the node buffers using slice assignment.

	For each row, the source range x_rowptr[i]:x_rowptr[i+1] is written to
	the destination range prob_rowptr[i]:prob_rowptr[i+1]-1 (the last slot
	of each destination row is left untouched), with indx_start added to
	every column index.
	"""
	for row in range(l):
		src_lo, src_hi = x_rowptr[row], x_rowptr[row+1]
		dst_lo, dst_hi = prob_rowptr[row], prob_rowptr[row+1]-1
		prob_ind[dst_lo:dst_hi] = x_ind[src_lo:src_hi]+indx_start
		prob_val[dst_lo:dst_hi] = x_val[src_lo:src_hi]
153
def csr_to_problem(x, prob, isKernel):
	"""Lay out the CSR matrix `x` as one contiguous svm_node buffer on `prob`.

	Parameters:
		x: CSR sparse matrix.  If its indices are not sorted,
			x.sort_indices() is called on it (this modifies `x`).
		prob: object receiving the result in two attributes:
			prob.x_space - numpy array of svm_node records, each row's
				entries followed by a node with index -1;
			prob.rowptr  - for each row, the position in x_space where that
				row starts (plus one trailing end position).
		isKernel: true for a precomputed kernel; column indices then start
			from 0 instead of 1.
	"""
	if not x.has_sorted_indices:
		x.sort_indices()

	n_rows = x.shape[0]
	# One node per stored entry plus one terminating node per row.
	x_space = prob.x_space = np.empty((x.nnz+n_rows), dtype=svm_node)
	# Row i is preceded by i terminators, so its start moves right by i.
	prob.rowptr = x.indptr.copy()
	prob.rowptr[1:] += np.arange(1, n_rows+1, dtype=prob.rowptr.dtype)
	prob_ind = x_space["index"]
	prob_val = x_space["value"]
	# Pre-mark every node as a terminator; the copy below overwrites all
	# slots except the last one of each row.
	prob_ind[:] = -1
	if not isKernel:
		indx_start = 1 # index starts from 1
	else:
		indx_start = 0 # index starts from 0 for precomputed kernel
	copy_rows = csr_to_problem_jit if jit_enabled else csr_to_problem_nojit
	copy_rows(n_rows, x.data, x.indices, x.indptr, prob_val, prob_ind, prob.rowptr, indx_start)
173
class svm_problem(Structure):
	"""A labelled data set: `l` labels in `y` and `l` node arrays in `x`.

	Besides the three C fields, instances carry Python-side attributes:
	`x_space` (the storage the row pointers refer to), `n` (largest feature
	index seen, or the column count for a CSR input) and, for CSR input,
	`rowptr` (start position of each row inside x_space).
	"""
	_names = ["l", "y", "x"]
	_types = [c_int, POINTER(c_double), POINTER(POINTER(svm_node))]
	_fields_ = genFields(_names, _types)

	def __init__(self, y, x, isKernel=False):
		"""Build the problem from labels `y` and instances `x`.

		Parameters:
			y: list, tuple or numpy array of labels.
			x: list/tuple of instances (each in a form accepted by
				gen_svm_nodearray), a numpy array with one instance per
				row, or a scipy sparse matrix (converted to CSR).
			isKernel: true when `x` holds a precomputed kernel.

		Raises:
			TypeError: `y` or `x` has an unsupported type.
			ValueError: `y` and `x` differ in length.
		"""
		y_is_array = np is not None and isinstance(y, np.ndarray)
		if not isinstance(y, (list, tuple)) and not y_is_array:
			raise TypeError("type of y: {0} is not supported!".format(type(y)))

		if isinstance(x, (list, tuple)):
			if len(y) != len(x):
				raise ValueError("len(y) != len(x)")
		elif np is not None and isinstance(x, (np.ndarray, sparse.spmatrix)):
			if len(y) != x.shape[0]:
				raise ValueError("len(y) != len(x)")
			if isinstance(x, np.ndarray):
				x = np.ascontiguousarray(x) # enforce row-major
			else:
				x = x.tocsr()
		else:
			raise TypeError("type of x: {0} is not supported!".format(type(x)))
		self.l = l = len(y)
		x_is_csr = sparse is not None and isinstance(x, sparse.csr_matrix)

		max_idx = 0
		x_space = self.x_space = []
		if x_is_csr:
			# Replaces self.x_space with one contiguous node buffer and
			# sets self.rowptr.
			csr_to_problem(x, self, isKernel)
			max_idx = x.shape[1]
		else:
			for xi in x:
				tmp_xi, tmp_idx = gen_svm_nodearray(xi,isKernel=isKernel)
				x_space.append(tmp_xi)
				max_idx = max(max_idx, tmp_idx)
		self.n = max_idx

		self.y = (c_double * l)()
		if y_is_array:
			# self.y reads back as a bare pointer, so the shape is required.
			np.ctypeslib.as_array(self.y, (self.l,))[:] = y
		else:
			for i, yi in enumerate(y): self.y[i] = yi

		self.x = (POINTER(svm_node) * l)()
		if x_is_csr:
			# View the pointer table as pointer-sized unsigned integers and
			# fill it with "buffer address + row offset in bytes".
			base = self.x_space.ctypes.data
			x_ptr = cast(self.x, POINTER(c_size_t))
			x_ptr = np.ctypeslib.as_array(x_ptr,(self.l,))
			# Widen the row offsets before the arithmetic: in the index
			# dtype (typically int32) the byte offset can overflow, and
			# adding a full-width address to an int32 array is rejected by
			# newer numpy versions.
			x_ptr[:] = self.rowptr[:-1].astype(np.uintp)*sizeof(svm_node)+base
		else:
			for i, xi in enumerate(self.x_space): self.x[i] = xi
224
class svm_parameter(Structure):
	"""Training parameters, filled from a command-line style option string.

	The C fields are listed in `_names`.  Instances additionally carry the
	Python-side attributes `cross_validation`, `nr_fold` and `print_func`.
	"""
	_names = ["svm_type", "kernel_type", "degree", "gamma", "coef0",
			"cache_size", "eps", "C", "nr_weight", "weight_label", "weight",
			"nu", "p", "shrinking", "probability"]
	_types = [c_int, c_int, c_int, c_double, c_double,
			c_double, c_double, c_double, c_int, POINTER(c_int), POINTER(c_double),
			c_double, c_double, c_int, c_int]
	_fields_ = genFields(_names, _types)

	def __init__(self, options = None):
		"""Create the parameter set; `options` is a str or list (default: no options)."""
		self.parse_options('' if options == None else options)

	def __str__(self):
		"""Return one "name: value" line per C field and per extra attribute."""
		attrs = svm_parameter._names + list(self.__dict__.keys())
		lines = [' %s: %s\n' % (attr, getattr(self, attr)) for attr in attrs]
		return ''.join(lines).strip()

	def set_to_default_values(self):
		"""Reset every field and extra attribute to its default."""
		self.svm_type = C_SVC
		self.kernel_type = RBF
		self.degree = 3
		self.gamma = 0
		self.coef0 = 0
		self.nu = 0.5
		self.cache_size = 100
		self.C = 1
		self.eps = 0.001
		self.p = 0.1
		self.shrinking = 1
		self.probability = 0
		# no per-class weights
		self.nr_weight = 0
		self.weight_label = None
		self.weight = None
		# cross validation is off
		self.cross_validation = False
		self.nr_fold = 0
		# a NULL callback pointer
		self.print_func = cast(None, PRINT_STRING_FUN)

	def parse_options(self, options):
		"""Reset to defaults, then apply `options`.

		`options` is either a list of tokens or a str that is split on
		whitespace.  Recognised flags: -s -t -d -g -r -n -m -c -e -p -h -b
		(each followed by a value), -q (no value), -v n (n-fold cross
		validation, n >= 2) and -w<label> weight.

		Raises:
			TypeError: `options` is neither a list nor a str.
			ValueError: unknown flag, or a fold count below 2.
		"""
		if isinstance(options, list):
			argv = options
		elif isinstance(options, str):
			argv = options.split()
		else:
			raise TypeError("arg 1 should be a list or a str.")
		self.set_to_default_values()
		self.print_func = cast(None, PRINT_STRING_FUN)

		# flag -> (attribute, converter) for flags that take exactly one value
		valued_flags = {
			"-s": ("svm_type", int),
			"-t": ("kernel_type", int),
			"-d": ("degree", int),
			"-g": ("gamma", float),
			"-r": ("coef0", float),
			"-n": ("nu", float),
			"-m": ("cache_size", float),
			"-c": ("C", float),
			"-e": ("eps", float),
			"-p": ("p", float),
			"-h": ("shrinking", int),
			"-b": ("probability", int),
		}
		labels = []
		weights = []

		pos = 0
		while pos < len(argv):
			flag = argv[pos]
			if flag in valued_flags:
				attr, convert = valued_flags[flag]
				pos += 1
				setattr(self, attr, convert(argv[pos]))
			elif flag == "-q":
				self.print_func = PRINT_STRING_FUN(print_null)
			elif flag == "-v":
				pos += 1
				self.cross_validation = 1
				self.nr_fold = int(argv[pos])
				if self.nr_fold < 2:
					raise ValueError("n-fold cross validation: n must >= 2")
			elif flag.startswith("-w"):
				# the class label is the text following "-w" in the flag itself
				pos += 1
				self.nr_weight += 1
				labels.append(int(flag[2:]))
				weights.append(float(argv[pos]))
			else:
				raise ValueError("Wrong options")
			pos += 1

		# Install the selected print callback in the library.
		libsvm.svm_set_print_string_function(self.print_func)
		self.weight_label = (c_int*self.nr_weight)(*labels)
		self.weight = (c_double*self.nr_weight)(*weights)
342
class svm_model(Structure):
	"""ctypes view of a model structure, with convenience accessors.

	The `__createfrom__` attribute records who owns the memory: 'python'
	for instances built here, 'C' for models handed over by the library
	(see toPyModel); only the latter are freed through the library.
	"""
	_names = ['param', 'nr_class', 'l', 'SV', 'sv_coef', 'rho',
			'probA', 'probB', 'sv_indices', 'label', 'nSV', 'free_sv']
	_types = [svm_parameter, c_int, c_int, POINTER(POINTER(svm_node)),
			POINTER(POINTER(c_double)), POINTER(c_double),
			POINTER(c_double), POINTER(c_double), POINTER(c_int),
			POINTER(c_int), POINTER(c_int), c_int]
	_fields_ = genFields(_names, _types)

	def __init__(self):
		self.__createfrom__ = 'python'

	def __del__(self):
		# Release memory that was allocated on the C side; models created
		# in Python are left to the garbage collector.
		if getattr(self, '__createfrom__', None) == 'C':
			libsvm.svm_free_and_destroy_model(pointer(pointer(self)))

	def get_svm_type(self):
		"""Return the library's svm_get_svm_type() result for this model."""
		return libsvm.svm_get_svm_type(self)

	def get_nr_class(self):
		"""Return the library's svm_get_nr_class() result for this model."""
		return libsvm.svm_get_nr_class(self)

	def get_svr_probability(self):
		"""Return the library's svm_get_svr_probability() result for this model."""
		return libsvm.svm_get_svr_probability(self)

	def get_labels(self):
		"""Return the class labels as a list of ints (one per class)."""
		count = self.get_nr_class()
		buf = (c_int * count)()
		libsvm.svm_get_labels(self, buf)
		return list(buf)

	def get_sv_indices(self):
		"""Return the support-vector indices as a list of ints."""
		count = self.get_nr_sv()
		buf = (c_int * count)()
		libsvm.svm_get_sv_indices(self, buf)
		return list(buf)

	def get_nr_sv(self):
		"""Return the library's svm_get_nr_sv() result for this model."""
		return libsvm.svm_get_nr_sv(self)

	def is_probability_model(self):
		"""Return True when svm_check_probability_model() reports 1."""
		return libsvm.svm_check_probability_model(self) == 1

	def get_sv_coef(self):
		"""Return, for each of the l support vectors, a tuple of its nr_class-1 coefficients."""
		return [tuple(self.sv_coef[k][sv] for k in range(self.nr_class - 1))
				for sv in range(self.l)]

	def get_SV(self):
		"""Return the support vectors as a list of {index: value} dicts."""
		rows = []
		for nodes in self.SV[:self.l]:
			row = {}
			k = 0
			# walk the node array up to its -1 terminator
			while nodes[k].index != -1:
				row[nodes[k].index] = nodes[k].value
				k += 1
			rows.append(row)
		return rows
405
def toPyModel(model_ptr):
	"""
	toPyModel(model_ptr) -> svm_model

	Dereference a ctypes POINTER(svm_model) and return the pointed-to
	svm_model, tagged as created by C.  Raises ValueError for a null pointer.
	"""
	if not model_ptr:
		raise ValueError("Null pointer")
	model = model_ptr.contents
	# Tag the object as owned by the C side.
	model.__createfrom__ = 'C'
	return model
417
# (function name, return type, argument types) for every library entry point
# used by this module, declared in one table and applied in order.
_svm_prototypes = [
	# training
	('svm_train', POINTER(svm_model), [POINTER(svm_problem), POINTER(svm_parameter)]),
	('svm_cross_validation', None, [POINTER(svm_problem), POINTER(svm_parameter), c_int, POINTER(c_double)]),
	# model files
	('svm_save_model', c_int, [c_char_p, POINTER(svm_model)]),
	('svm_load_model', POINTER(svm_model), [c_char_p]),
	# model queries
	('svm_get_svm_type', c_int, [POINTER(svm_model)]),
	('svm_get_nr_class', c_int, [POINTER(svm_model)]),
	('svm_get_labels', None, [POINTER(svm_model), POINTER(c_int)]),
	('svm_get_sv_indices', None, [POINTER(svm_model), POINTER(c_int)]),
	('svm_get_nr_sv', c_int, [POINTER(svm_model)]),
	('svm_get_svr_probability', c_double, [POINTER(svm_model)]),
	# prediction
	('svm_predict_values', c_double, [POINTER(svm_model), POINTER(svm_node), POINTER(c_double)]),
	('svm_predict', c_double, [POINTER(svm_model), POINTER(svm_node)]),
	('svm_predict_probability', c_double, [POINTER(svm_model), POINTER(svm_node), POINTER(c_double)]),
	# cleanup
	('svm_free_model_content', None, [POINTER(svm_model)]),
	('svm_free_and_destroy_model', None, [POINTER(POINTER(svm_model))]),
	('svm_destroy_param', None, [POINTER(svm_parameter)]),
	# checks and logging
	('svm_check_parameter', c_char_p, [POINTER(svm_problem), POINTER(svm_parameter)]),
	('svm_check_probability_model', c_int, [POINTER(svm_model)]),
	('svm_set_print_string_function', None, [PRINT_STRING_FUN]),
]
for _fn_name, _fn_restype, _fn_argtypes in _svm_prototypes:
	fillprototype(getattr(libsvm, _fn_name), _fn_restype, _fn_argtypes)
# Do not leave the temporaries behind in the module namespace.
del _svm_prototypes, _fn_name, _fn_restype, _fn_argtypes
442