1from collections import OrderedDict
2
3from AnyQt.QtCore import Qt
4from AnyQt.QtWidgets import QLabel, QGridLayout
5import scipy.sparse as sp
6
7from Orange.data import Table
8from Orange.modelling import SVMLearner, NuSVMLearner
9from Orange.widgets import gui
10from Orange.widgets.widget import Msg
11from Orange.widgets.settings import Setting
12from Orange.widgets.utils.owlearnerwidget import OWBaseLearner
13from Orange.widgets.utils.signals import Output
14from Orange.widgets.utils.widgetpreview import WidgetPreview
15
16
class OWSVM(OWBaseLearner):
    """Widget that trains an SVM or ν-SVM learner and outputs support vectors.

    The widget exposes the SVM type (ε-SVM / ν-SVM), the kernel and its
    parameters, and optimization parameters. The learner is built in
    `create_learner`. The rows of the input data selected as support
    vectors are sent on the "Support Vectors" output.
    """
    name = 'SVM'
    description = "Support Vector Machines map inputs to higher-dimensional " \
                  "feature spaces."
    icon = "icons/SVM.svg"
    replaces = [
        "Orange.widgets.classify.owsvmclassification.OWSVMClassification",
        "Orange.widgets.regression.owsvmregression.OWSVMRegression",
    ]
    priority = 50
    keywords = ["support vector machines"]

    LEARNER = SVMLearner

    class Outputs(OWBaseLearner.Outputs):
        support_vectors = Output("Support Vectors", Table, explicit=True,
                                 replaces=["Support vectors"])

    class Warning(OWBaseLearner.Warning):
        sparse_data = Msg('Input data is sparse, default preprocessing is to scale it.')

    #: Different types of SVMs
    SVM, Nu_SVM = range(2)
    #: SVM type
    svm_type = Setting(SVM)

    C = Setting(1.)  # pylint: disable=invalid-name
    epsilon = Setting(.1)
    nu_C = Setting(1.)
    nu = Setting(.5)  # pylint: disable=invalid-name

    #: Kernel types
    Linear, Poly, RBF, Sigmoid = range(4)
    #: Selected kernel type
    kernel_type = Setting(RBF)
    #: kernel degree
    degree = Setting(3)
    #: gamma
    gamma = Setting(0.0)
    #: coef0 (additive constant)
    coef0 = Setting(1.0)

    #: numerical tolerance
    tol = Setting(0.001)
    #: whether or not to limit number of iterations
    limit_iter = Setting(True)
    #: maximum number of iterations
    max_iter = Setting(100)

    # Shown in the gamma spin box when gamma is 0, and passed to the learner
    # in that case
    _default_gamma = "auto"
    # (button label, kernel equation) pairs, indexed by the kernel constants
    kernels = (("Linear", "x⋅y"),
               ("Polynomial", "(g x⋅y + c)<sup>d</sup>"),
               ("RBF", "exp(-g|x-y|²)"),
               ("Sigmoid", "tanh(g x⋅y + c)"))

    def add_main_layout(self):
        """Build the type, kernel and optimization boxes of the control area."""
        self._add_type_box()
        self._add_kernel_box()
        self._add_optimization_box()
        self._show_right_kernel()

    def _add_type_box(self):
        """Create the radio buttons and spin boxes for the SVM type and its costs."""
        # this is part of init, pylint: disable=attribute-defined-outside-init
        form = QGridLayout()
        self.type_box = box = gui.radioButtonsInBox(
            self.controlArea, self, "svm_type", [], box="SVM Type",
            orientation=form, callback=self._update_type)

        self.epsilon_radio = gui.appendRadioButton(
            box, "SVM", addToLayout=False)
        self.c_spin = gui.doubleSpin(
            box, self, "C", 0.1, 512.0, 0.1, decimals=2,
            alignment=Qt.AlignRight, addToLayout=False,
            callback=self.settings_changed)
        self.epsilon_spin = gui.doubleSpin(
            box, self, "epsilon", 0.1, 512.0, 0.1, decimals=2,
            alignment=Qt.AlignRight, addToLayout=False,
            callback=self.settings_changed)
        form.addWidget(self.epsilon_radio, 0, 0, Qt.AlignLeft)
        form.addWidget(QLabel("Cost (C):"), 0, 1, Qt.AlignRight)
        form.addWidget(self.c_spin, 0, 2)
        form.addWidget(QLabel(
            "Regression loss epsilon (ε):"), 1, 1, Qt.AlignRight)
        form.addWidget(self.epsilon_spin, 1, 2)

        self.nu_radio = gui.appendRadioButton(box, "ν-SVM", addToLayout=False)
        self.nu_C_spin = gui.doubleSpin(
            box, self, "nu_C", 0.1, 512.0, 0.1, decimals=2,
            alignment=Qt.AlignRight, addToLayout=False,
            callback=self.settings_changed)
        self.nu_spin = gui.doubleSpin(
            box, self, "nu", 0.05, 1.0, 0.05, decimals=2,
            alignment=Qt.AlignRight, addToLayout=False,
            callback=self.settings_changed)
        form.addWidget(self.nu_radio, 2, 0, Qt.AlignLeft)
        form.addWidget(QLabel("Regression cost (C):"), 2, 1, Qt.AlignRight)
        form.addWidget(self.nu_C_spin, 2, 2)
        form.addWidget(QLabel("Complexity bound (ν):"), 3, 1, Qt.AlignRight)
        form.addWidget(self.nu_spin, 3, 2)

        # Correctly enable/disable the appropriate boxes
        self._update_type()

    def _update_type(self):
        """Enable only the parameters of the selected SVM type, then signal a change."""
        is_svm = self.svm_type == self.SVM
        self.c_spin.setEnabled(is_svm)
        self.epsilon_spin.setEnabled(is_svm)
        self.nu_C_spin.setEnabled(not is_svm)
        self.nu_spin.setEnabled(not is_svm)
        self.settings_changed()

    def _add_kernel_box(self):
        """Create the kernel radio buttons, the equation label and the parameter spins."""
        # this is part of init, pylint: disable=attribute-defined-outside-init
        # Initialize with the widest label to measure max width
        self.kernel_eq = self.kernels[-1][1]

        box = gui.hBox(self.controlArea, "Kernel")

        self.kernel_box = buttonbox = gui.radioButtonsInBox(
            box, self, "kernel_type", btnLabels=[k[0] for k in self.kernels],
            callback=self._on_kernel_changed)
        buttonbox.layout().setSpacing(10)
        gui.rubber(buttonbox)

        parambox = gui.vBox(box)
        gui.label(parambox, self, "Kernel: %(kernel_eq)s")
        common = dict(orientation=Qt.Horizontal, callback=self.settings_changed,
                      alignment=Qt.AlignRight, controlWidth=80)
        spbox = gui.hBox(parambox)
        gui.rubber(spbox)
        inbox = gui.vBox(spbox)
        gamma = gui.doubleSpin(
            inbox, self, "gamma", 0.0, 10.0, 0.01, label=" g: ", **common)
        gamma.setSpecialValueText(self._default_gamma)
        coef0 = gui.doubleSpin(
            inbox, self, "coef0", 0.0, 10.0, 0.01, label=" c: ", **common)
        degree = gui.doubleSpin(
            inbox, self, "degree", 0.0, 10.0, 0.5, label=" d: ", **common)
        # Order must match the columns of the visibility table in
        # _show_right_kernel
        self._kernel_params = [gamma, coef0, degree]
        gui.rubber(parambox)

        # This is the maximal height (all double spins are visible)
        # and the maximal width (the label is initialized to the widest one)
        box.layout().activate()
        box.setFixedHeight(box.sizeHint().height())
        box.setMinimumWidth(box.sizeHint().width())

    def _add_optimization_box(self):
        """Create the numerical tolerance and iteration limit controls."""
        # this is part of init, pylint: disable=attribute-defined-outside-init
        self.optimization_box = gui.vBox(
            self.controlArea, "Optimization Parameters")
        self.tol_spin = gui.doubleSpin(
            self.optimization_box, self, "tol", 1e-4, 1.0, 1e-4,
            label="Numerical tolerance: ",
            alignment=Qt.AlignRight, controlWidth=100,
            callback=self.settings_changed)
        # Integer bounds: this is an integer spin box (max_iter is an int)
        self.max_iter_spin = gui.spin(
            self.optimization_box, self, "max_iter", 5, 1000000, 50,
            label="Iteration limit: ", checked="limit_iter",
            alignment=Qt.AlignRight, controlWidth=100,
            callback=self.settings_changed,
            checkCallback=self.settings_changed)

    def _show_right_kernel(self):
        """Show only the parameter spins used by the selected kernel.

        Also updates the displayed kernel equation.
        """
        # Visibility of the (gamma, coef0, degree) spins for each kernel type
        visibility = ((False, False, False),  # linear
                      (True, True, True),  # poly
                      (True, False, False),  # rbf
                      (True, True, False))  # sigmoid

        # set in _add_kernel_box, pylint: disable=attribute-defined-outside-init
        self.kernel_eq = self.kernels[self.kernel_type][1]
        for spin, visible in zip(self._kernel_params,
                                 visibility[self.kernel_type]):
            spin.box.setVisible(visible)

    def update_model(self):
        """Refit the model and send the support-vector rows of the input data."""
        super().update_model()
        sv = None
        if self.model is not None:
            # NOTE(review): assumes `support_` indexes rows of self.data;
            # if learner preprocessors drop rows this may misalign — verify.
            sv = self.data[self.model.skl_model.support_]
        self.Outputs.support_vectors.send(sv)

    def _on_kernel_changed(self):
        """React to a kernel change by updating the visible parameters."""
        self._show_right_kernel()
        self.settings_changed()

    def set_data(self, data):
        """Set input data and warn if its feature matrix is sparse."""
        self.Warning.sparse_data.clear()
        super().set_data(data)
        if self.data and sp.issparse(self.data.X):
            self.Warning.sparse_data()

    def create_learner(self):
        """Return an SVMLearner or NuSVMLearner configured from the settings."""
        kernel = ["linear", "poly", "rbf", "sigmoid"][self.kernel_type]
        common_args = {
            'kernel': kernel,
            # The degree control is a double spin box, so this setting can
            # hold a float. LIBSVM's degree is an integer, and newer
            # scikit-learn rejects non-integral values.
            'degree': int(self.degree),
            # gamma == 0 means "use the default" (shown as "auto" in the GUI)
            'gamma': self.gamma or self._default_gamma,
            'coef0': self.coef0,
            'probability': True,
            'tol': self.tol,
            # -1 means no iteration limit
            'max_iter': self.max_iter if self.limit_iter else -1,
            'preprocessors': self.preprocessors
        }
        if self.svm_type == self.SVM:
            return SVMLearner(C=self.C, epsilon=self.epsilon, **common_args)
        return NuSVMLearner(nu=self.nu, C=self.nu_C, **common_args)

    def get_learner_parameters(self):
        """Return an ordered mapping of parameter descriptions for the report."""
        items = OrderedDict()
        if self.svm_type == self.SVM:
            items["SVM type"] = "SVM, C={}, ε={}".format(self.C, self.epsilon)
        else:
            items["SVM type"] = "ν-SVM, ν={}, C={}".format(self.nu, self.nu_C)
        self._report_kernel_parameters(items)
        items["Numerical tolerance"] = "{:.6}".format(self.tol)
        items["Iteration limit"] = self.max_iter if self.limit_iter else "unlimited"
        return items

    def _report_kernel_parameters(self, items):
        """Add a "Kernel" entry describing the kernel and its parameters to `items`."""
        gamma = self.gamma or self._default_gamma
        if self.kernel_type == self.Linear:
            items["Kernel"] = "Linear"
        elif self.kernel_type == self.Poly:
            # Report the same integer degree that create_learner passes on
            items["Kernel"] = \
                "Polynomial, ({g:.4} x⋅y + {c:.4})<sup>{d}</sup>".format(
                    g=gamma, c=self.coef0, d=int(self.degree))
        elif self.kernel_type == self.RBF:
            items["Kernel"] = "RBF, exp(-{:.4}|x-y|²)".format(gamma)
        else:
            items["Kernel"] = "Sigmoid, tanh({g:.4} x⋅y + {c:.4})".format(
                g=gamma, c=self.coef0)
256
257
if __name__ == "__main__":  # pragma: no cover
    # Standalone preview: open the widget with the Iris data set as input.
    iris = Table("iris")
    WidgetPreview(OWSVM).run(iris)
260