1import inspect
2import itertools
3from collections.abc import Iterable
4import re
5import warnings
6from typing import Callable, Dict
7
8import numpy as np
9import scipy
10
11from Orange.data import Table, Storage, Instance, Value
12from Orange.data.filter import HasClass
13from Orange.data.table import DomainTransformationError
14from Orange.data.util import one_hot
15from Orange.misc.environ import cache_dir
16from Orange.misc.wrapper_meta import WrapperMeta
17from Orange.preprocess import Continuize, RemoveNaNColumns, SklImpute, Normalize
18from Orange.statistics.util import all_nan
19from Orange.util import Reprable, OrangeDeprecationWarning, wrap_callback, \
20    dummy_callback
21
22__all__ = ["Learner", "Model", "SklLearner", "SklModel",
23           "ReprableWithPreprocessors"]
24
25
26class ReprableWithPreprocessors(Reprable):
27    def _reprable_omit_param(self, name, default, value):
28        if name == "preprocessors":
29            default_cls = type(self).preprocessors
30            if value is default or value is default_cls:
31                return True
32            else:
33                try:
34                    return all(p1 is p2 for p1, p2 in
35                               itertools.zip_longest(value, default_cls))
36                except (ValueError, TypeError):
37                    return False
38        else:
39            return super()._reprable_omit_param(name, default, value)
40
41
42class Learner(ReprableWithPreprocessors):
43    """The base learner class.
44
45    Preprocessors can behave in a number of different ways, all of which are
46    described here.
47    If the user does not pass a preprocessor argument into the Learner
48    constructor, the default learner preprocessors are used. We assume the user
49    would simply like to get things done without having to worry about
50    preprocessors.
51    If the user chooses to pass in their own preprocessors, we assume they know
52    what they are doing. In this case, only the user preprocessors are used and
53    the default preprocessors are ignored.
54    In case the user would like to use the default preprocessors as well as
55    their own ones, the `use_default_preprocessors` flag should be set.
56
57    Parameters
58    ----------
59    preprocessors : Preprocessor or tuple[Preprocessor], optional
60        User defined preprocessors. If the user specifies their own
61        preprocessors, the default ones will not be used, unless the
62        `use_default_preprocessors` flag is set.
63
64    Attributes
65    ----------
66    preprocessors : tuple[Preprocessor] (default None)
67        The used defined preprocessors that will be used on any data.
68    use_default_preprocessors : bool (default False)
69        This flag indicates whether to use the default preprocessors that are
70        defined on the Learner class. Since preprocessors can be applied in a
71        number of ways
72    active_preprocessors : tuple[Preprocessor]
73        The processors that will be used when data is passed to the learner.
74        This depends on whether the user has passed in their own preprocessors
75        and whether the `use_default_preprocessors` flag is set.
76
77        This property is needed mainly because of the `Fitter` class, which can
78        not know in advance, which preprocessors it will need to use. Therefore
79        this resolves the active preprocessors using a lazy approach.
80    params : dict
81        The params that the learner is constructed with.
82
83    """
84    supports_multiclass = False
85    supports_weights = False
86    #: A sequence of data preprocessors to apply on data prior to
87    #: fitting the model
88    preprocessors = ()
89    learner_adequacy_err_msg = ''
90
91    def __init__(self, preprocessors=None):
92        self.use_default_preprocessors = False
93        if isinstance(preprocessors, Iterable):
94            self.preprocessors = tuple(preprocessors)
95        elif preprocessors:
96            self.preprocessors = (preprocessors,)
97
98    def fit(self, X, Y, W=None):
99        raise RuntimeError(
100            "Descendants of Learner must overload method fit or fit_storage")
101
102    def fit_storage(self, data):
103        """Default implementation of fit_storage defaults to calling fit.
104        Derived classes must define fit_storage or fit"""
105        X, Y, W = data.X, data.Y, data.W if data.has_weights() else None
106        return self.fit(X, Y, W)
107
108    def __call__(self, data, progress_callback=None):
109        if not self.check_learner_adequacy(data.domain):
110            raise ValueError(self.learner_adequacy_err_msg)
111
112        origdomain = data.domain
113
114        if isinstance(data, Instance):
115            data = Table(data.domain, [data])
116        origdata = data
117
118        if progress_callback is None:
119            progress_callback = dummy_callback
120        progress_callback(0, "Preprocessing...")
121        try:
122            cb = wrap_callback(progress_callback, end=0.1)
123            data = self.preprocess(data, progress_callback=cb)
124        except TypeError:
125            data = self.preprocess(data)
126            warnings.warn("A keyword argument 'progress_callback' has been "
127                          "added to the preprocess() signature. Implementing "
128                          "the method without the argument is deprecated and "
129                          "will result in an error in the future.",
130                          OrangeDeprecationWarning)
131
132        if len(data.domain.class_vars) > 1 and not self.supports_multiclass:
133            raise TypeError("%s doesn't support multiple class variables" %
134                            self.__class__.__name__)
135
136        progress_callback(0.1, "Fitting...")
137        model = self._fit_model(data)
138        model.used_vals = [np.unique(y).astype(int) for y in data.Y[:, None].T]
139        if not hasattr(model, "domain") or model.domain is None:
140            # some models set domain themself and it should be respected
141            # e.g. calibration learners set the base_learner's domain which
142            # would be wrongly overwritten if we set it here for any model
143            model.domain = data.domain
144        model.supports_multiclass = self.supports_multiclass
145        model.name = self.name
146        model.original_domain = origdomain
147        model.original_data = origdata
148        progress_callback(1)
149        return model
150
151    def _fit_model(self, data):
152        if type(self).fit is Learner.fit:
153            return self.fit_storage(data)
154        else:
155            X, Y, W = data.X, data.Y, data.W if data.has_weights() else None
156            return self.fit(X, Y, W)
157
158    def preprocess(self, data, progress_callback=None):
159        """Apply the `preprocessors` to the data"""
160        if progress_callback is None:
161            progress_callback = dummy_callback
162        n_pps = len(list(self.active_preprocessors))
163        for i, pp in enumerate(self.active_preprocessors):
164            progress_callback(i / n_pps)
165            data = pp(data)
166        progress_callback(1)
167        return data
168
169    @property
170    def active_preprocessors(self):
171        yield from self.preprocessors
172        if (self.use_default_preprocessors and
173                self.preprocessors is not type(self).preprocessors):
174            yield from type(self).preprocessors
175
176    def check_learner_adequacy(self, _):
177        return True
178
179    @property
180    def name(self):
181        """Return a short name derived from Learner type name"""
182        try:
183            return self.__name
184        except AttributeError:
185            name = self.__class__.__name__
186            if name.endswith('Learner'):
187                name = name[:-len('Learner')]
188            if name.endswith('Fitter'):
189                name = name[:-len('Fitter')]
190            if isinstance(self, SklLearner) and name.startswith('Skl'):
191                name = name[len('Skl'):]
192            name = name or 'learner'
193            # From http://stackoverflow.com/a/1176023/1090455 <3
194            self.name = re.sub(r'([a-z0-9])([A-Z])', r'\1 \2',
195                               re.sub(r'(.)([A-Z][a-z]+)', r'\1 \2', name)).lower()
196            return self.name
197
198    @name.setter
199    def name(self, value):
200        self.__name = value
201
202    def __str__(self):
203        return self.name
204
205
206class Model(Reprable):
207    supports_multiclass = False
208    supports_weights = False
209    Value = 0
210    Probs = 1
211    ValueProbs = 2
212
213    def __init__(self, domain=None, original_domain=None):
214        self.domain = domain
215        if original_domain is not None:
216            self.original_domain = original_domain
217        else:
218            self.original_domain = domain
219        self.used_vals = None
220
221    def predict(self, X):
222        if type(self).predict_storage is Model.predict_storage:
223            raise TypeError("Descendants of Model must overload method predict")
224        else:
225            Y = np.zeros((len(X), len(self.domain.class_vars)))
226            Y[:] = np.nan
227            table = Table(self.domain, X, Y)
228            return self.predict_storage(table)
229
230    def predict_storage(self, data):
231        if isinstance(data, Storage):
232            return self.predict(data.X)
233        elif isinstance(data, Instance):
234            return self.predict(np.atleast_2d(data.x))
235        raise TypeError("Unrecognized argument (instance of '{}')"
236                        .format(type(data).__name__))
237
238    def get_backmappers(self, data):
239        backmappers = []
240        n_values = []
241
242        dataclasses = data.domain.class_vars
243        modelclasses = self.domain.class_vars
244        if not (modelclasses and dataclasses):
245            return None, []  # classless model or data; don't touch
246        if len(dataclasses) != len(modelclasses):
247            raise DomainTransformationError(
248                "Mismatching number of model's classes and data classes")
249        for dataclass, modelclass in zip(dataclasses, modelclasses):
250            if dataclass != modelclass:
251                if dataclass.name != modelclass.name:
252                    raise DomainTransformationError(
253                        f"Model for '{modelclass.name}' "
254                        f"cannot predict '{dataclass.name}'")
255                else:
256                    raise DomainTransformationError(
257                        f"Variables '{modelclass.name}' in the model is "
258                        "incompatible with the variable of the same name "
259                        "in the data.")
260            n_values.append(dataclass.is_discrete and len(dataclass.values))
261            if dataclass is not modelclass and dataclass.is_discrete:
262                backmappers.append(dataclass.get_mapper_from(modelclass))
263            else:
264                backmappers.append(None)
265        if all(x is None for x in backmappers):
266            backmappers = None
267        return backmappers, n_values
268
269    def backmap_value(self, value, mapped_probs, n_values, backmappers):
270        if backmappers is None:
271            return value
272
273        if value.ndim == 2:  # For multitarget, recursive call by columns
274            new_value = np.zeros(value.shape)
275            for i, n_value, backmapper in zip(
276                    itertools.count(), n_values, backmappers):
277                new_value[:, i] = self.backmap_value(
278                    value[:, i], mapped_probs[:, i, :], [n_value], [backmapper])
279            return new_value
280
281        backmapper = backmappers[0]
282        if backmapper is None:
283            return value
284
285        value = backmapper(value)
286        nans = np.isnan(value)
287        if not np.any(nans) or n_values[0] < 2:
288            return value
289        if mapped_probs is not None:
290            value[nans] = np.argmax(mapped_probs[nans], axis=1)
291        else:
292            value[nans] = np.random.RandomState(0).choice(
293                backmapper(np.arange(0, n_values[0] - 1)),
294                (np.sum(nans), ))
295        return value
296
297    def backmap_probs(self, probs, n_values, backmappers):
298        if backmappers is None:
299            return probs
300
301        if probs.ndim == 3:
302            new_probs = np.zeros((len(probs), len(n_values), max(n_values)),
303                                 dtype=probs.dtype)
304            for i, n_value, backmapper in zip(
305                    itertools.count(), n_values, backmappers):
306                new_probs[:, i, :n_value] = self.backmap_probs(
307                    probs[:, i, :], [n_value], [backmapper])
308            return new_probs
309
310        backmapper = backmappers[0]
311        if backmapper is None:
312            return probs
313        n_value = n_values[0]
314        new_probs = np.zeros((len(probs), n_value), dtype=probs.dtype)
315        for col in range(probs.shape[1]):
316            target = backmapper(col)
317            if not np.isnan(target):
318                new_probs[:, int(target)] = probs[:, col]
319        tots = np.sum(new_probs, axis=1)
320        zero_sum = tots == 0
321        new_probs[zero_sum] = 1
322        tots[zero_sum] = n_value
323        new_probs = new_probs / tots[:, None]
324        return new_probs
325
326    def data_to_model_domain(
327            self, data: Table, progress_callback: Callable = dummy_callback
328    ) -> Table:
329        """
330        Transforms data to the model domain if possible.
331
332        Parameters
333        ----------
334        data
335            Data to be transformed to the model domain
336        progress_callback
337            Callback - callable - to report the progress
338
339        Returns
340        -------
341        Transformed data table
342
343        Raises
344        ------
345        DomainTransformationError
346            Error indicates that transformation is not possible since domains
347            are not compatible
348        """
349        if data.domain == self.domain:
350            return data
351
352        progress_callback(0)
353        if self.original_domain.attributes != data.domain.attributes \
354                and data.X.size \
355                and not all_nan(data.X):
356            progress_callback(0.5)
357            new_data = data.transform(self.original_domain)
358            if all_nan(new_data.X):
359                raise DomainTransformationError(
360                    "domain transformation produced no defined values")
361            progress_callback(0.75)
362            data = new_data.transform(self.domain)
363            progress_callback(1)
364            return data
365
366        progress_callback(0.5)
367        data = data.transform(self.domain)
368        progress_callback(1)
369        return data
370
371    def __call__(self, data, ret=Value):
372        multitarget = len(self.domain.class_vars) > 1
373
374        def one_hot_probs(value):
375            if not multitarget:
376                return one_hot(
377                    value,
378                    dim=len(self.domain.class_var.values)
379                    if self.domain is not None else None
380                )
381
382            max_card = max(len(c.values) for c in self.domain.class_vars)
383            probs = np.zeros(value.shape + (max_card,), float)
384            for i in range(len(self.domain.class_vars)):
385                probs[:, i, :] = one_hot(value[:, i])
386            return probs
387
388        def extend_probabilities(probs):
389            """
390            Since SklModels and models implementing `fit` and not `fit_storage`
391            do not guarantee correct prediction dimensionality, extend
392            dimensionality of probabilities when it does not match the number
393            of values in the domain.
394            """
395            class_vars = self.domain.class_vars
396            max_values = max(len(cv.values) for cv in class_vars)
397            if max_values == probs.shape[-1]:
398                return probs
399
400            if not self.supports_multiclass:
401                probs = probs[:, np.newaxis, :]
402
403            probs_ext = np.zeros((len(probs), len(class_vars), max_values))
404            for c, used_vals in enumerate(self.used_vals):
405                for i, cv in enumerate(used_vals):
406                    probs_ext[:, c, cv] = probs[:, c, i]
407
408            if not self.supports_multiclass:
409                probs_ext = probs_ext[:, 0, :]
410            return probs_ext
411
412        def fix_dim(x):
413            return x[0] if one_d else x
414
415        if not 0 <= ret <= 2:
416            raise ValueError("invalid value of argument 'ret'")
417        if ret > 0 and any(v.is_continuous for v in self.domain.class_vars):
418            raise ValueError("cannot predict continuous distributions")
419
420        # Convert 1d structures to 2d and remember doing it
421        one_d = True
422        if isinstance(data, Instance):
423            data = Table.from_list(data.domain, [data])
424        elif isinstance(data, (list, tuple)) \
425                and not isinstance(data[0], (list, tuple)):
426            data = [data]
427        elif isinstance(data, np.ndarray) and data.ndim == 1:
428            data = np.atleast_2d(data)
429        else:
430            one_d = False
431
432        # if sparse convert to csr_matrix
433        if scipy.sparse.issparse(data):
434            data = data.tocsr()
435
436        # Call the predictor
437        backmappers = None
438        n_values = []
439        if isinstance(data, (np.ndarray, scipy.sparse.csr.csr_matrix)):
440            prediction = self.predict(data)
441        elif isinstance(data, Table):
442            backmappers, n_values = self.get_backmappers(data)
443            data = self.data_to_model_domain(data)
444            prediction = self.predict_storage(data)
445        elif isinstance(data, (list, tuple)):
446            data = Table.from_list(self.original_domain, data)
447            data = data.transform(self.domain)
448            prediction = self.predict_storage(data)
449        else:
450            raise TypeError("Unrecognized argument (instance of '{}')"
451                            .format(type(data).__name__))
452
453        # Parse the result into value and probs
454        if isinstance(prediction, tuple):
455            value, probs = prediction
456        elif prediction.ndim == 1 + multitarget:
457            value, probs = prediction, None
458        elif prediction.ndim == 2 + multitarget:
459            value, probs = None, prediction
460        else:
461            raise TypeError("model returned a %i-dimensional array",
462                            prediction.ndim)
463
464        # Ensure that we have what we need to return; backmapp everything
465        if probs is None and (ret != Model.Value or backmappers is not None):
466            probs = one_hot_probs(value)
467        if probs is not None:
468            probs = extend_probabilities(probs)
469            probs = self.backmap_probs(probs, n_values, backmappers)
470        if ret != Model.Probs:
471            if value is None:
472                value = np.argmax(probs, axis=-1)
473                # probs are already backmapped
474            else:
475                value = self.backmap_value(value, probs, n_values, backmappers)
476
477        # Return what we need to
478        if ret == Model.Probs:
479            return fix_dim(probs)
480        if isinstance(data, Instance) and not multitarget:
481            value = [Value(self.domain.class_var, value[0])]
482        if ret == Model.Value:
483            return fix_dim(value)
484        else:  # ret == Model.ValueProbs
485            return fix_dim(value), fix_dim(probs)
486
487    def __getstate__(self):
488        """Skip (possibly large) data when pickling models"""
489        state = self.__dict__
490        if 'original_data' in state:
491            state = state.copy()
492            state['original_data'] = None
493        return state
494
495
496class SklModel(Model, metaclass=WrapperMeta):
497    used_vals = None
498
499    def __init__(self, skl_model):
500        self.skl_model = skl_model
501
502    def predict(self, X):
503        value = self.skl_model.predict(X)
504        # SVM has probability attribute which defines if method compute probs
505        has_prob_attr = hasattr(self.skl_model, "probability")
506        if (has_prob_attr and self.skl_model.probability
507                or not has_prob_attr
508                and hasattr(self.skl_model, "predict_proba")):
509            probs = self.skl_model.predict_proba(X)
510            return value, probs
511        return value
512
513    def __repr__(self):
514        # Params represented as a comment because not passed into constructor
515        return super().__repr__() + '  # params=' + repr(self.params)
516
517
518class SklLearner(Learner, metaclass=WrapperMeta):
519    """
520    ${skldoc}
521    Additional Orange parameters
522
523    preprocessors : list, optional
524        An ordered list of preprocessors applied to data before
525        training or testing.
526        Defaults to
527        `[RemoveNaNClasses(), Continuize(), SklImpute(), RemoveNaNColumns()]`
528    """
529    __wraps__ = None
530    __returns__ = SklModel
531    _params = {}
532
533    preprocessors = default_preprocessors = [
534        HasClass(),
535        Continuize(),
536        RemoveNaNColumns(),
537        SklImpute()]
538
539    @property
540    def params(self):
541        return self._params
542
543    @params.setter
544    def params(self, value):
545        self._params = self._get_sklparams(value)
546
547    def _get_sklparams(self, values):
548        skllearner = self.__wraps__
549        if skllearner is not None:
550            spec = list(
551                inspect.signature(skllearner.__init__).parameters.keys()
552            )
553            # first argument is 'self'
554            assert spec[0] == "self"
555            params = {
556                name: values[name] for name in spec[1:] if name in values
557            }
558        else:
559            raise TypeError("Wrapper does not define '__wraps__'")
560        return params
561
562    def preprocess(self, data, progress_callback=None):
563        data = super().preprocess(data, progress_callback)
564
565        if any(v.is_discrete and len(v.values) > 2
566               for v in data.domain.attributes):
567            raise ValueError("Wrapped scikit-learn methods do not support " +
568                             "multinomial variables.")
569
570        return data
571
572    def __call__(self, data, progress_callback=None):
573        m = super().__call__(data, progress_callback)
574        m.params = self.params
575        return m
576
577    def _initialize_wrapped(self):
578        # pylint: disable=not-callable
579        return self.__wraps__(**self.params)
580
581    def fit(self, X, Y, W=None):
582        clf = self._initialize_wrapped()
583        Y = Y.reshape(-1)
584        if W is None or not self.supports_weights:
585            return self.__returns__(clf.fit(X, Y))
586        return self.__returns__(clf.fit(X, Y, sample_weight=W.reshape(-1)))
587
588    @property
589    def supports_weights(self):
590        """Indicates whether this learner supports weighted instances.
591        """
592        return 'sample_weight' in self.__wraps__.fit.__code__.co_varnames
593
594    def __getattr__(self, item):
595        try:
596            return self.params[item]
597        except (KeyError, AttributeError):
598            raise AttributeError(item) from None
599
600    # TODO: Disallow (or mirror) __setattr__ for keys in params?
601
602    def __dir__(self):
603        dd = super().__dir__()
604        return list(sorted(set(dd) | set(self.params.keys())))
605
606
607class TreeModel(Model):
608    pass
609
610
611class RandomForestModel(Model):
612    """Interface for random forest models
613    """
614
615    @property
616    def trees(self):
617        """Return a list of Trees in the forest
618
619        Returns
620        -------
621        List[Tree]
622        """
623
624
625class KNNBase:
626    """Base class for KNN (classification and regression) learners
627    """
628
629    # pylint: disable=unused-argument
630    def __init__(self, n_neighbors=5, metric="euclidean", weights="uniform",
631                 algorithm='auto', metric_params=None,
632                 preprocessors=None):
633        super().__init__(preprocessors=preprocessors)
634        self.params = vars()
635
636    def fit(self, X, Y, W=None):
637        if self.params["metric_params"] is None and \
638                        self.params.get("metric") == "mahalanobis":
639            self.params["metric_params"] = {"V": np.cov(X.T)}
640        return super().fit(X, Y, W)
641
642
643class NNBase:
644    """Base class for neural network (classification and regression) learners
645    """
646    preprocessors = SklLearner.preprocessors + [Normalize()]
647
648    # pylint: disable=unused-argument,too-many-arguments
649    def __init__(self, hidden_layer_sizes=(100,), activation='relu',
650                 solver='adam', alpha=0.0001, batch_size='auto',
651                 learning_rate='constant', learning_rate_init=0.001,
652                 power_t=0.5, max_iter=200, shuffle=True, random_state=None,
653                 tol=0.0001, verbose=False, warm_start=False, momentum=0.9,
654                 nesterovs_momentum=True, early_stopping=False,
655                 validation_fraction=0.1, beta_1=0.9, beta_2=0.999,
656                 epsilon=1e-08, preprocessors=None):
657        super().__init__(preprocessors=preprocessors)
658        self.params = vars()
659
660
661class CatGBModel(Model, metaclass=WrapperMeta):
662    def __init__(self, cat_model, cat_features, domain):
663        super().__init__(domain)
664        self.cat_model = cat_model
665        self.cat_features = cat_features
666
667    def predict(self, X):
668        if self.cat_features:
669            X = X.astype(str)
670        value = self.cat_model.predict(X).flatten()
671        if hasattr(self.cat_model, "predict_proba"):
672            probs = self.cat_model.predict_proba(X)
673            return value, probs
674        return value
675
676    def __repr__(self):
677        # Params represented as a comment because not passed into constructor
678        return super().__repr__() + '  # params=' + repr(self.params)
679
680
681class CatGBBaseLearner(Learner, metaclass=WrapperMeta):
682    """
683    ${skldoc}
684    Additional Orange parameters
685
686    preprocessors : list, optional
687        An ordered list of preprocessors applied to data before
688        training or testing.
689        Defaults to
690        `[RemoveNaNClasses(), RemoveNaNColumns()]`
691    """
692    supports_weights = True
693    __wraps__ = None
694    __returns__ = CatGBModel
695    _params = {}
696    preprocessors = default_preprocessors = [
697        HasClass(),
698        RemoveNaNColumns(),
699    ]
700
701    # pylint: disable=unused-argument,too-many-arguments,too-many-locals
702    def __init__(self,
703                 iterations=None,
704                 learning_rate=None,
705                 depth=None,
706                 l2_leaf_reg=None,
707                 model_size_reg=None,
708                 rsm=None,
709                 loss_function=None,
710                 border_count=None,
711                 feature_border_type=None,
712                 per_float_feature_quantization=None,
713                 input_borders=None,
714                 output_borders=None,
715                 fold_permutation_block=None,
716                 od_pval=None,
717                 od_wait=None,
718                 od_type=None,
719                 nan_mode=None,
720                 counter_calc_method=None,
721                 leaf_estimation_iterations=None,
722                 leaf_estimation_method=None,
723                 thread_count=None,
724                 random_seed=None,
725                 use_best_model=None,
726                 verbose=False,
727                 logging_level=None,
728                 metric_period=None,
729                 ctr_leaf_count_limit=None,
730                 store_all_simple_ctr=None,
731                 max_ctr_complexity=None,
732                 has_time=None,
733                 allow_const_label=None,
734                 classes_count=None,
735                 class_weights=None,
736                 one_hot_max_size=None,
737                 random_strength=None,
738                 name=None,
739                 ignored_features=None,
740                 train_dir=cache_dir(),
741                 custom_loss=None,
742                 custom_metric=None,
743                 eval_metric=None,
744                 bagging_temperature=None,
745                 save_snapshot=None,
746                 snapshot_file=None,
747                 snapshot_interval=None,
748                 fold_len_multiplier=None,
749                 used_ram_limit=None,
750                 gpu_ram_part=None,
751                 allow_writing_files=False,
752                 final_ctr_computation_mode=None,
753                 approx_on_full_history=None,
754                 boosting_type=None,
755                 simple_ctr=None,
756                 combinations_ctr=None,
757                 per_feature_ctr=None,
758                 task_type=None,
759                 device_config=None,
760                 devices=None,
761                 bootstrap_type=None,
762                 subsample=None,
763                 sampling_unit=None,
764                 dev_score_calc_obj_block_size=None,
765                 max_depth=None,
766                 n_estimators=None,
767                 num_boost_round=None,
768                 num_trees=None,
769                 colsample_bylevel=None,
770                 random_state=None,
771                 reg_lambda=None,
772                 objective=None,
773                 eta=None,
774                 max_bin=None,
775                 scale_pos_weight=None,
776                 gpu_cat_features_storage=None,
777                 data_partition=None,
778                 metadata=None,
779                 early_stopping_rounds=None,
780                 cat_features=None,
781                 grow_policy=None,
782                 min_data_in_leaf=None,
783                 min_child_samples=None,
784                 max_leaves=None,
785                 num_leaves=None,
786                 score_function=None,
787                 leaf_estimation_backtracking=None,
788                 ctr_history_unit=None,
789                 monotone_constraints=None,
790                 feature_weights=None,
791                 penalties_coefficient=None,
792                 first_feature_use_penalties=None,
793                 model_shrink_rate=None,
794                 model_shrink_mode=None,
795                 langevin=None,
796                 diffusion_temperature=None,
797                 posterior_sampling=None,
798                 boost_from_average=None,
799                 text_features=None,
800                 tokenizers=None,
801                 dictionaries=None,
802                 feature_calcers=None,
803                 text_processing=None,
804                 preprocessors=None):
805        super().__init__(preprocessors=preprocessors)
806        self.params = vars()
807
808    @property
809    def params(self):
810        return self._params
811
812    @params.setter
813    def params(self, value):
814        self._params = self._get_wrapper_params(value)
815
816    def _get_wrapper_params(self, values):
817        spec = list(inspect.signature(
818            self.__wraps__.__init__).parameters.keys())
819        return {name: values[name] for name in spec[1:] if name in values}
820
821    def __call__(self, data, progress_callback=None):
822        m = super().__call__(data, progress_callback)
823        m.params = self.params
824        return m
825
826    def fit_storage(self, data: Table):
827        domain, X, Y, W = data.domain, data.X, data.Y.reshape(-1), None
828        if self.supports_weights and data.has_weights():
829            W = data.W.reshape(-1)
830        # pylint: disable=not-callable
831        clf = self.__wraps__(**self.params)
832        cat_features = [i for i, attr in enumerate(domain.attributes)
833                        if attr.is_discrete]
834        if cat_features:
835            X = X.astype(str)
836        cat_model = clf.fit(X, Y, cat_features=cat_features, sample_weight=W)
837        return self.__returns__(cat_model, cat_features, domain)
838
839    def __getattr__(self, item):
840        try:
841            return self.params[item]
842        except (KeyError, AttributeError):
843            raise AttributeError(item) from None
844
845    def __dir__(self):
846        dd = super().__dir__()
847        return list(sorted(set(dd) | set(self.params.keys())))
848
849
850class XGBBase(SklLearner):
851    """Base class for xgboost (classification and regression) learners """
852    preprocessors = default_preprocessors = [
853        HasClass(),
854        Continuize(),
855        RemoveNaNColumns(),
856    ]
857
858    def __init__(self, preprocessors=None, **kwargs):
859        super().__init__(preprocessors=preprocessors)
860        self.params = kwargs
861
862    @SklLearner.params.setter
863    def params(self, values: Dict):
864        self._params = values
865