1from AnyQt.QtCore import Qt
2
3from Orange.base import Learner
4from Orange.data import Table
5from Orange.modelling import SklAdaBoostLearner, SklTreeLearner
6from Orange.widgets import gui
7from Orange.widgets.settings import Setting
8from Orange.widgets.utils.owlearnerwidget import OWBaseLearner
9from Orange.widgets.utils.widgetpreview import WidgetPreview
10from Orange.widgets.widget import Msg, Input
11
12
class OWAdaBoost(OWBaseLearner):
    """Widget that builds an AdaBoost learner around a base learner.

    The base learner is either the one received on the "Learner" input or
    ``DEFAULT_BASE_ESTIMATOR`` when nothing is connected. The remaining
    parameters (number of estimators, learning rate, optional fixed seed,
    classification algorithm, regression loss) are stored as settings and
    edited through the controls built in ``add_main_layout``.
    """

    name = "AdaBoost"
    description = "An ensemble meta-algorithm that combines weak learners " \
                  "and adapts to the 'hardness' of each training sample. "
    icon = "icons/AdaBoost.svg"
    replaces = [
        "Orange.widgets.classify.owadaboost.OWAdaBoostClassification",
        "Orange.widgets.regression.owadaboostregression.OWAdaBoostRegression",
    ]
    priority = 80
    keywords = ["boost"]

    LEARNER = SklAdaBoostLearner

    class Inputs(OWBaseLearner.Inputs):
        # Optional base learner that is boosted instead of the default one.
        learner = Input("Learner", Learner)

    #: Algorithms for classification problems
    algorithms = ["SAMME", "SAMME.R"]
    #: Losses for regression problems
    losses = ["Linear", "Square", "Exponential"]

    # Persisted settings; the *_index values index into `algorithms` and
    # `losses` above.
    n_estimators = Setting(50)
    learning_rate = Setting(1.)
    algorithm_index = Setting(1)
    loss_index = Setting(0)
    # `random_seed` is only meant to be applied when `use_random_seed` is set.
    use_random_seed = Setting(False)
    random_seed = Setting(0)

    # Base learner used when no learner is connected to the input.
    DEFAULT_BASE_ESTIMATOR = SklTreeLearner()

    class Error(OWBaseLearner.Error):
        no_weight_support = Msg('The base learner does not support weights.')

    def add_main_layout(self):
        """Build the "Parameters" and "Boosting method" control boxes.

        Also initialises ``self.base_estimator`` to the default base learner
        and creates the label that shows the current base learner's name.
        """
        # this is part of init, pylint: disable=attribute-defined-outside-init
        box = gui.widgetBox(self.controlArea, "Parameters")
        self.base_estimator = self.DEFAULT_BASE_ESTIMATOR
        self.base_label = gui.label(
            box, self, "Base estimator: " + self.base_estimator.name.title())

        self.n_estimators_spin = gui.spin(
            box, self, "n_estimators", 1, 10000, label="Number of estimators:",
            alignment=Qt.AlignRight, controlWidth=80,
            callback=self.settings_changed)
        self.learning_rate_spin = gui.doubleSpin(
            box, self, "learning_rate", 1e-5, 1.0, 1e-5,
            label="Learning rate:", decimals=5, alignment=Qt.AlignRight,
            controlWidth=80, callback=self.settings_changed)
        # The seed spin box is tied to the `use_random_seed` setting through
        # `checked=`; both the value and the check state trigger an update.
        self.random_seed_spin = gui.spin(
            box, self, "random_seed", 0, 2 ** 31 - 1, controlWidth=80,
            label="Fixed seed for random generator:", alignment=Qt.AlignRight,
            callback=self.settings_changed, checked="use_random_seed",
            checkCallback=self.settings_changed)

        # Algorithms
        box = gui.widgetBox(self.controlArea, "Boosting method")
        self.cls_algorithm_combo = gui.comboBox(
            box, self, "algorithm_index", label="Classification algorithm:",
            items=self.algorithms,
            orientation=Qt.Horizontal, callback=self.settings_changed)
        self.reg_algorithm_combo = gui.comboBox(
            box, self, "loss_index", label="Regression loss function:",
            items=self.losses,
            orientation=Qt.Horizontal, callback=self.settings_changed)

    def create_learner(self):
        """Return a ``LEARNER`` configured from the current settings.

        Returns ``None`` when there is no valid base estimator (an input
        learner without weight support was rejected).
        """
        if self.base_estimator is None:
            return None
        # Honour the "fixed seed" check box: the stored seed is used only
        # when the user enabled it; otherwise no seed is fixed (None).
        random_state = self.random_seed if self.use_random_seed else None
        return self.LEARNER(
            base_estimator=self.base_estimator,
            n_estimators=self.n_estimators,
            learning_rate=self.learning_rate,
            random_state=random_state,
            preprocessors=self.preprocessors,
            algorithm=self.algorithms[self.algorithm_index],
            loss=self.losses[self.loss_index].lower())

    @Inputs.learner
    def set_base_learner(self, learner):
        """Handle a new (or removed) base learner on the "Learner" input.

        A learner that does not support weights is rejected: an error is
        shown and ``base_estimator`` is set to ``None``. A missing learner
        (``None``) falls back to ``DEFAULT_BASE_ESTIMATOR``.
        """
        # base_estimator is defined in add_main_layout
        # pylint: disable=attribute-defined-outside-init
        self.Error.no_weight_support.clear()
        if learner and not learner.supports_weights:
            # Show the error and invalidate the base learner, so that
            # create_learner returns None until a usable learner arrives.
            self.Error.no_weight_support()
            self.base_estimator = None
            self.base_label.setText("Base estimator: INVALID")
        else:
            self.base_estimator = learner or self.DEFAULT_BASE_ESTIMATOR
            self.base_label.setText(
                "Base estimator: %s" % self.base_estimator.name.title())
        if self.auto_apply:
            self.apply()

    def get_learner_parameters(self):
        """Return ``(label, value)`` pairs describing the current parameters.

        Algorithm and loss names are reported exactly as listed in
        ``algorithms`` and ``losses`` (``str.capitalize`` would turn
        "SAMME.R" into "Samme.r").
        """
        return (("Base estimator", self.base_estimator),
                ("Number of estimators", self.n_estimators),
                ("Algorithm (classification)",
                 self.algorithms[self.algorithm_index]),
                ("Loss (regression)", self.losses[self.loss_index]))
116
if __name__ == "__main__":  # pragma: no cover
    # Stand-alone preview: show the widget and feed it the "iris" data set.
    preview = WidgetPreview(OWAdaBoost)
    preview.run(Table("iris"))
119