from AnyQt.QtCore import Qt

from Orange.base import Learner
from Orange.data import Table
from Orange.modelling import SklAdaBoostLearner, SklTreeLearner
from Orange.widgets import gui
from Orange.widgets.settings import Setting
from Orange.widgets.utils.owlearnerwidget import OWBaseLearner
from Orange.widgets.utils.widgetpreview import WidgetPreview
from Orange.widgets.widget import Msg, Input


class OWAdaBoost(OWBaseLearner):
    """Learner widget that configures an AdaBoost ensemble.

    The widget builds a ``SklAdaBoostLearner`` from the GUI settings. The
    base estimator is a tree learner by default and can be replaced through
    the optional "Learner" input.
    """
    name = "AdaBoost"
    description = "An ensemble meta-algorithm that combines weak learners " \
                  "and adapts to the 'hardness' of each training sample. "
    icon = "icons/AdaBoost.svg"
    replaces = [
        "Orange.widgets.classify.owadaboost.OWAdaBoostClassification",
        "Orange.widgets.regression.owadaboostregression.OWAdaBoostRegression",
    ]
    priority = 80
    keywords = ["boost"]

    LEARNER = SklAdaBoostLearner

    class Inputs(OWBaseLearner.Inputs):
        # Optional replacement for the default base estimator.
        learner = Input("Learner", Learner)

    #: Algorithms for classification problems
    algorithms = ["SAMME", "SAMME.R"]
    #: Losses for regression problems
    losses = ["Linear", "Square", "Exponential"]

    n_estimators = Setting(50)
    learning_rate = Setting(1.)
    #: Index into ``algorithms``
    algorithm_index = Setting(1)
    #: Index into ``losses``
    loss_index = Setting(0)
    #: State of the checkbox attached to the random seed spin box
    use_random_seed = Setting(False)
    random_seed = Setting(0)

    #: Base estimator used when no learner is connected to the input
    DEFAULT_BASE_ESTIMATOR = SklTreeLearner()

    class Error(OWBaseLearner.Error):
        no_weight_support = Msg('The base learner does not support weights.')

    def add_main_layout(self):
        """Build the "Parameters" and "Boosting method" control boxes."""
        # this is part of init, pylint: disable=attribute-defined-outside-init
        box = gui.widgetBox(self.controlArea, "Parameters")
        self.base_estimator = self.DEFAULT_BASE_ESTIMATOR
        self.base_label = gui.label(
            box, self, "Base estimator: " + self.base_estimator.name.title())

        self.n_estimators_spin = gui.spin(
            box, self, "n_estimators", 1, 10000, label="Number of estimators:",
            alignment=Qt.AlignRight, controlWidth=80,
            callback=self.settings_changed)
        self.learning_rate_spin = gui.doubleSpin(
            box, self, "learning_rate", 1e-5, 1.0, 1e-5,
            label="Learning rate:", decimals=5, alignment=Qt.AlignRight,
            controlWidth=80, callback=self.settings_changed)
        # The spin box carries a checkbox bound to ``use_random_seed``;
        # toggling it or editing the value both report a settings change.
        self.random_seed_spin = gui.spin(
            box, self, "random_seed", 0, 2 ** 31 - 1, controlWidth=80,
            label="Fixed seed for random generator:", alignment=Qt.AlignRight,
            callback=self.settings_changed, checked="use_random_seed",
            checkCallback=self.settings_changed)

        # Algorithms
        box = gui.widgetBox(self.controlArea, "Boosting method")
        self.cls_algorithm_combo = gui.comboBox(
            box, self, "algorithm_index", label="Classification algorithm:",
            items=self.algorithms,
            orientation=Qt.Horizontal, callback=self.settings_changed)
        self.reg_algorithm_combo = gui.comboBox(
            box, self, "loss_index", label="Regression loss function:",
            items=self.losses,
            orientation=Qt.Horizontal, callback=self.settings_changed)

    def create_learner(self):
        """Create the AdaBoost learner from the current settings.

        Returns:
            An instance of ``LEARNER``, or ``None`` when there is no valid
            base estimator (``set_base_learner`` stores ``None`` for an input
            learner that does not support weights).
        """
        if self.base_estimator is None:
            return None
        # Honour the "fixed seed" checkbox: only pin the random state when
        # the user has enabled it, otherwise leave the generator unseeded.
        random_state = self.random_seed if self.use_random_seed else None
        return self.LEARNER(
            base_estimator=self.base_estimator,
            n_estimators=self.n_estimators,
            learning_rate=self.learning_rate,
            random_state=random_state,
            preprocessors=self.preprocessors,
            algorithm=self.algorithms[self.algorithm_index],
            loss=self.losses[self.loss_index].lower())

    @Inputs.learner
    def set_base_learner(self, learner):
        """Handle the "Learner" input and update the base estimator.

        Args:
            learner: The learner received on the input, or a falsy value
                (e.g. ``None``), in which case the default base estimator
                is restored.
        """
        # base_estimator is defined in add_main_layout
        # pylint: disable=attribute-defined-outside-init
        self.Error.no_weight_support.clear()
        if learner and not learner.supports_weights:
            # Show the error and invalidate the base estimator, so that
            # create_learner returns None until a usable learner arrives.
            self.Error.no_weight_support()
            self.base_estimator = None
            self.base_label.setText("Base estimator: INVALID")
        else:
            self.base_estimator = learner or self.DEFAULT_BASE_ESTIMATOR
            self.base_label.setText(
                "Base estimator: %s" % self.base_estimator.name.title())
        if self.auto_apply:
            self.apply()

    def get_learner_parameters(self):
        """Return (label, value) pairs describing the current settings."""
        return (("Base estimator", self.base_estimator),
                ("Number of estimators", self.n_estimators),
                ("Algorithm (classification)", self.algorithms[
                    self.algorithm_index].capitalize()),
                ("Loss (regression)", self.losses[
                    self.loss_index].capitalize()))


if __name__ == "__main__":  # pragma: no cover
    WidgetPreview(OWAdaBoost).run(Table("iris"))