1# pylint: disable=too-many-arguments, too-many-locals, no-name-in-module
2# pylint: disable=missing-class-docstring, invalid-name
3# pylint: disable=too-many-lines, fixme
4# pylint: disable=too-few-public-methods
5# pylint: disable=import-error
6"""Dask extensions for distributed training. See
7https://xgboost.readthedocs.io/en/latest/tutorials/dask.html for simple
8tutorial.  Also xgboost/demo/dask for some examples.
9
10There are two sets of APIs in this module, one is the functional API including
11``train`` and ``predict`` methods.  Another is stateful Scikit-Learner wrapper
12inherited from single-node Scikit-Learn interface.
13
14The implementation is heavily influenced by dask_xgboost:
15https://github.com/dask/dask-xgboost
16
17"""
18import platform
19import logging
20from contextlib import contextmanager
21from collections import defaultdict
22from collections.abc import Sequence
23from threading import Thread
24from functools import partial, update_wrapper
25from typing import TYPE_CHECKING, List, Tuple, Callable, Optional, Any, Union, Dict, Set
26from typing import Awaitable, Generator, TypeVar
27
28import numpy
29
30from . import rabit, config
31
32from .callback import TrainingCallback
33
34from .compat import LazyLoader
35from .compat import sparse, scipy_sparse
36from .compat import PANDAS_INSTALLED, DataFrame, Series, pandas_concat
37from .compat import lazy_isinstance
38
39from .core import DMatrix, DeviceQuantileDMatrix, Booster, _expect, DataIter
40from .core import Objective, Metric
41from .core import _deprecate_positional_args
42from .training import train as worker_train
43from .tracker import RabitTracker, get_host_ip
44from .sklearn import XGBModel, XGBClassifier, XGBRegressorBase, XGBClassifierBase
45from .sklearn import _wrap_evaluation_matrices, _objective_decorator, _check_rf_callback
46from .sklearn import XGBRankerMixIn
47from .sklearn import xgboost_model_doc
48from .sklearn import _cls_predict_proba
49from .sklearn import XGBRanker
50
51if TYPE_CHECKING:
52    from dask import dataframe as dd
53    from dask import array as da
54    import dask
55    import distributed
56else:
57    dd = LazyLoader('dd', globals(), 'dask.dataframe')
58    da = LazyLoader('da', globals(), 'dask.array')
59    dask = LazyLoader('dask', globals(), 'dask')
60    distributed = LazyLoader('distributed', globals(), 'dask.distributed')
61
62_DaskCollection = Union["da.Array", "dd.DataFrame", "dd.Series"]
63
64try:
65    from mypy_extensions import TypedDict
66    TrainReturnT = TypedDict('TrainReturnT', {
67        'booster': Booster,
68        'history': Dict,
69    })
70except ImportError:
71    TrainReturnT = Dict[str, Any]  # type:ignore
72
73__all__ = [
74    "RabitContext",
75    "DaskDMatrix",
76    "DaskDeviceQuantileDMatrix",
77    "DaskXGBRegressor",
78    "DaskXGBClassifier",
79    "DaskXGBRanker",
80    "DaskXGBRFRegressor",
81    "DaskXGBRFClassifier",
82    "train",
83    "predict",
84    "inplace_predict",
85]
86
87# TODOs:
88#   - CV
89#
90# Note for developers:
91#
92#   As of writing asyncio is still a new feature of Python and in depth documentation is
93#   rare.  Best examples of various asyncio tricks are in dask (luckily).  Classes like
94#   Client, Worker are awaitable.  Some general rules for the implementation here:
95#
96#     - Synchronous world is different from asynchronous one, and they don't mix well.
97#     - Write everything with async, then use distributed Client sync function to do the
98#       switch.
99#     - Use Any for type hint when the return value can be union of Awaitable and plain
100#       value.  This is caused by Client.sync can return both types depending on context.
101#       Right now there's no good way to silent:
102#
103#         await train(...)
104#
105#       if train returns an Union type.
106
107
108LOGGER = logging.getLogger('[xgboost.dask]')
109
110
111def _multi_lock() -> Any:
112    """MultiLock is only available on latest distributed.  See:
113
114    https://github.com/dask/distributed/pull/4503
115
116"""
117    try:
118        from distributed import MultiLock
119    except ImportError:
120        class MultiLock:        # type:ignore
121            def __init__(self, *args: Any, **kwargs: Any) -> None:
122                pass
123
124            def __enter__(self) -> "MultiLock":
125                return self
126
127            def __exit__(self, *args: Any, **kwargs: Any) -> None:
128                return
129
130            async def __aenter__(self) -> "MultiLock":
131                return self
132
133            async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
134                return
135
136    return MultiLock
137
138
139def _start_tracker(n_workers: int) -> Dict[str, Any]:
140    """Start Rabit tracker """
141    env = {'DMLC_NUM_WORKER': n_workers}
142    host = get_host_ip('auto')
143    rabit_context = RabitTracker(hostIP=host, nslave=n_workers, use_logger=False)
144    env.update(rabit_context.slave_envs())
145
146    rabit_context.start(n_workers)
147    thread = Thread(target=rabit_context.join)
148    thread.daemon = True
149    thread.start()
150    return env
151
152
153def _assert_dask_support() -> None:
154    try:
155        import dask  # pylint: disable=W0621,W0611
156    except ImportError as e:
157        raise ImportError(
158            "Dask needs to be installed in order to use this module"
159        ) from e
160
161    if platform.system() == "Windows":
162        msg = "Windows is not officially supported for dask/xgboost,"
163        msg += " contribution are welcomed."
164        LOGGER.warning(msg)
165
166
167class RabitContext:
168    '''A context controling rabit initialization and finalization.'''
169    def __init__(self, args: List[bytes]) -> None:
170        self.args = args
171        worker = distributed.get_worker()
172        self.args.append(
173            ('DMLC_TASK_ID=[xgboost.dask]:' + str(worker.address)).encode())
174
175    def __enter__(self) -> None:
176        rabit.init(self.args)
177        LOGGER.debug('-------------- rabit say hello ------------------')
178
179    def __exit__(self, *args: List) -> None:
180        rabit.finalize()
181        LOGGER.debug('--------------- rabit say bye ------------------')
182
183
184def concat(value: Any) -> Any:  # pylint: disable=too-many-return-statements
185    '''To be replaced with dask builtin.'''
186    if isinstance(value[0], numpy.ndarray):
187        return numpy.concatenate(value, axis=0)
188    if scipy_sparse and isinstance(value[0], scipy_sparse.spmatrix):
189        return scipy_sparse.vstack(value, format='csr')
190    if sparse and isinstance(value[0], sparse.SparseArray):
191        return sparse.concatenate(value, axis=0)
192    if PANDAS_INSTALLED and isinstance(value[0], (DataFrame, Series)):
193        return pandas_concat(value, axis=0)
194    if lazy_isinstance(value[0], 'cudf.core.dataframe', 'DataFrame') or \
195       lazy_isinstance(value[0], 'cudf.core.series', 'Series'):
196        from cudf import concat as CUDF_concat  # pylint: disable=import-error
197        return CUDF_concat(value, axis=0)
198    if lazy_isinstance(value[0], 'cupy._core.core', 'ndarray'):
199        import cupy
200        # pylint: disable=c-extension-no-member,no-member
201        d = cupy.cuda.runtime.getDevice()
202        for v in value:
203            d_v = v.device.id
204            assert d_v == d, 'Concatenating arrays on different devices.'
205        return cupy.concatenate(value, axis=0)
206    return dd.multi.concat(list(value), axis=0)
207
208
209def _xgb_get_client(client: Optional["distributed.Client"]) -> "distributed.Client":
210    '''Simple wrapper around testing None.'''
211    if not isinstance(client, (type(distributed.get_client()), type(None))):
212        raise TypeError(
213            _expect([type(distributed.get_client()), type(None)], type(client)))
214    ret = distributed.get_client() if client is None else client
215    return ret
216
217# From the implementation point of view, DaskDMatrix complicates a lots of
218# things.  A large portion of the code base is about syncing and extracting
219# stuffs from DaskDMatrix.  But having an independent data structure gives us a
220# chance to perform some specialized optimizations, like building histogram
221# index directly.
222
223
224class DaskDMatrix:
225    # pylint: disable=missing-docstring, too-many-instance-attributes
226    '''DMatrix holding on references to Dask DataFrame or Dask Array.  Constructing a
227    `DaskDMatrix` forces all lazy computation to be carried out.  Wait for the input data
228    explicitly if you want to see actual computation of constructing `DaskDMatrix`.
229
230    See doc for :py:obj:`xgboost.DMatrix` constructor for other parameters.  DaskDMatrix
231    accepts only dask collection.
232
233    .. note::
234
235        DaskDMatrix does not repartition or move data between workers.  It's
236        the caller's responsibility to balance the data.
237
238    .. versionadded:: 1.0.0
239
240    Parameters
241    ----------
242    client :
243        Specify the dask client used for training.  Use default client returned from dask
244        if it's set to None.
245
246    '''
247
248    @_deprecate_positional_args
249    def __init__(
250        self,
251        client: "distributed.Client",
252        data: _DaskCollection,
253        label: Optional[_DaskCollection] = None,
254        *,
255        weight: Optional[_DaskCollection] = None,
256        base_margin: Optional[_DaskCollection] = None,
257        missing: float = None,
258        silent: bool = False,   # pylint: disable=unused-argument
259        feature_names: Optional[Union[str, List[str]]] = None,
260        feature_types: Optional[Union[Any, List[Any]]] = None,
261        group: Optional[_DaskCollection] = None,
262        qid: Optional[_DaskCollection] = None,
263        label_lower_bound: Optional[_DaskCollection] = None,
264        label_upper_bound: Optional[_DaskCollection] = None,
265        feature_weights: Optional[_DaskCollection] = None,
266        enable_categorical: bool = False
267    ) -> None:
268        _assert_dask_support()
269        client = _xgb_get_client(client)
270
271        self.feature_names = feature_names
272        self.feature_types = feature_types
273        self.missing = missing
274        self.enable_categorical = enable_categorical
275
276        if qid is not None and weight is not None:
277            raise NotImplementedError("per-group weight is not implemented.")
278        if group is not None:
279            raise NotImplementedError(
280                "group structure is not implemented, use qid instead."
281            )
282
283        if len(data.shape) != 2:
284            raise ValueError(
285                f"Expecting 2 dimensional input, got: {data.shape}"
286            )
287
288        if not isinstance(data, (dd.DataFrame, da.Array)):
289            raise TypeError(_expect((dd.DataFrame, da.Array), type(data)))
290        if not isinstance(label, (dd.DataFrame, da.Array, dd.Series, type(None))):
291            raise TypeError(_expect((dd.DataFrame, da.Array, dd.Series), type(label)))
292
293        self._n_cols = data.shape[1]
294        assert isinstance(self._n_cols, int)
295        self.worker_map: Dict[str, "distributed.Future"] = defaultdict(list)
296        self.is_quantile: bool = False
297
298        self._init = client.sync(
299            self._map_local_data,
300            client,
301            data,
302            label=label,
303            weights=weight,
304            base_margin=base_margin,
305            qid=qid,
306            feature_weights=feature_weights,
307            label_lower_bound=label_lower_bound,
308            label_upper_bound=label_upper_bound,
309        )
310
311    def __await__(self) -> Generator:
312        return self._init.__await__()
313
314    async def _map_local_data(
315        self,
316        client: "distributed.Client",
317        data: _DaskCollection,
318        label: Optional[_DaskCollection] = None,
319        weights: Optional[_DaskCollection] = None,
320        base_margin: Optional[_DaskCollection] = None,
321        qid: Optional[_DaskCollection] = None,
322        feature_weights: Optional[_DaskCollection] = None,
323        label_lower_bound: Optional[_DaskCollection] = None,
324        label_upper_bound: Optional[_DaskCollection] = None,
325    ) -> "DaskDMatrix":
326        '''Obtain references to local data.'''
327
328        def inconsistent(
329            left: List[Any], left_name: str, right: List[Any], right_name: str
330        ) -> str:
331            msg = (f"Partitions between {left_name} and {right_name} are not "
332                   f"consistent: {len(left)} != {len(right)}.  "
333                   f"Please try to repartition/rechunk your data.")
334            return msg
335
336        def check_columns(parts: Any) -> None:
337            # x is required to be 2 dim in __init__
338            assert parts.ndim == 1 or parts.shape[1], 'Data should be' \
339                ' partitioned by row. To avoid this specify the number' \
340                ' of columns for your dask Array explicitly. e.g.' \
341                ' chunks=(partition_size, X.shape[1])'
342
343        data = client.persist(data)
344        for meta in [label, weights, base_margin, label_lower_bound,
345                     label_upper_bound]:
346            if meta is not None:
347                meta = client.persist(meta)
348        # Breaking data into partitions, a trick borrowed from dask_xgboost.
349
350        # `to_delayed` downgrades high-level objects into numpy or pandas
351        # equivalents.
352        X_parts = data.to_delayed()
353        if isinstance(X_parts, numpy.ndarray):
354            check_columns(X_parts)
355            X_parts = X_parts.flatten().tolist()
356
357        def flatten_meta(
358            meta: Optional[_DaskCollection]
359        ) -> "Optional[List[dask.delayed.Delayed]]":
360            if meta is not None:
361                meta_parts = meta.to_delayed()
362                if isinstance(meta_parts, numpy.ndarray):
363                    check_columns(meta_parts)
364                    meta_parts = meta_parts.flatten().tolist()
365                return meta_parts
366            return None
367
368        y_parts = flatten_meta(label)
369        w_parts = flatten_meta(weights)
370        margin_parts = flatten_meta(base_margin)
371        qid_parts = flatten_meta(qid)
372        ll_parts = flatten_meta(label_lower_bound)
373        lu_parts = flatten_meta(label_upper_bound)
374
375        parts = [X_parts]
376        meta_names = []
377
378        def append_meta(
379            m_parts: Optional[List["dask.delayed.delayed"]], name: str
380        ) -> None:
381            if m_parts is not None:
382                assert len(X_parts) == len(
383                    m_parts), inconsistent(X_parts, 'X', m_parts, name)
384                parts.append(m_parts)
385                meta_names.append(name)
386
387        append_meta(y_parts, 'labels')
388        append_meta(w_parts, 'weights')
389        append_meta(margin_parts, 'base_margin')
390        append_meta(qid_parts, 'qid')
391        append_meta(ll_parts, 'label_lower_bound')
392        append_meta(lu_parts, 'label_upper_bound')
393        # At this point, `parts` looks like:
394        # [(x0, x1, ..), (y0, y1, ..), ..] in delayed form
395
396        # delay the zipped result
397        parts = list(map(dask.delayed, zip(*parts)))  # pylint: disable=no-member
398        # At this point, the mental model should look like:
399        # [(x0, y0, ..), (x1, y1, ..), ..] in delayed form
400
401        parts = client.compute(parts)
402        await distributed.wait(parts)  # async wait for parts to be computed
403
404        for part in parts:
405            assert part.status == 'finished', part.status
406
407        # Preserving the partition order for prediction.
408        self.partition_order = {}
409        for i, part in enumerate(parts):
410            self.partition_order[part.key] = i
411
412        key_to_partition = {part.key: part for part in parts}
413        who_has = await client.scheduler.who_has(keys=[part.key for part in parts])
414
415        worker_map: Dict[str, "distributed.Future"] = defaultdict(list)
416
417        for key, workers in who_has.items():
418            worker_map[next(iter(workers))].append(key_to_partition[key])
419
420        self.worker_map = worker_map
421        self.meta_names = meta_names
422
423        if feature_weights is None:
424            self.feature_weights = None
425        else:
426            self.feature_weights = await client.compute(feature_weights).result()
427
428        return self
429
430    def _create_fn_args(self, worker_addr: str) -> Dict[str, Any]:
431        '''Create a dictionary of objects that can be pickled for function
432        arguments.
433
434        '''
435        return {'feature_names': self.feature_names,
436                'feature_types': self.feature_types,
437                'feature_weights': self.feature_weights,
438                'meta_names': self.meta_names,
439                'missing': self.missing,
440                'enable_categorical': self.enable_categorical,
441                'parts': self.worker_map.get(worker_addr, None),
442                'is_quantile': self.is_quantile}
443
444    def num_col(self) -> int:
445        return self._n_cols
446
447
448_DataParts = List[Tuple[Any, Optional[Any], Optional[Any], Optional[Any], Optional[Any],
449                        Optional[Any], Optional[Any]]]
450
451
452def _get_worker_parts_ordered(
453    meta_names: List[str], list_of_parts: _DataParts
454) -> _DataParts:
455    # List of partitions like: [(x3, y3, w3, m3, ..), ..], order is not preserved.
456    assert isinstance(list_of_parts, list)
457
458    result = []
459
460    for i, _ in enumerate(list_of_parts):
461        data = list_of_parts[i][0]
462        labels = None
463        weights = None
464        base_margin = None
465        qid = None
466        label_lower_bound = None
467        label_upper_bound = None
468        # Iterate through all possible meta info, brings small overhead as in xgboost
469        # there are constant number of meta info available.
470        for j, blob in enumerate(list_of_parts[i][1:]):
471            if meta_names[j] == 'labels':
472                labels = blob
473            elif meta_names[j] == 'weights':
474                weights = blob
475            elif meta_names[j] == 'base_margin':
476                base_margin = blob
477            elif meta_names[j] == 'qid':
478                qid = blob
479            elif meta_names[j] == 'label_lower_bound':
480                label_lower_bound = blob
481            elif meta_names[j] == 'label_upper_bound':
482                label_upper_bound = blob
483            else:
484                raise ValueError('Unknown metainfo:', meta_names[j])
485        result.append((data, labels, weights, base_margin, qid, label_lower_bound,
486                       label_upper_bound))
487    return result
488
489
490def _unzip(list_of_parts: _DataParts) -> List[Tuple[Any, ...]]:
491    return list(zip(*list_of_parts))
492
493
494def _get_worker_parts(
495    list_of_parts: _DataParts, meta_names: List[str]
496) -> List[Tuple[Any, ...]]:
497    partitions = _get_worker_parts_ordered(meta_names, list_of_parts)
498    partitions_unzipped = _unzip(partitions)
499    return partitions_unzipped
500
501
502class DaskPartitionIter(DataIter):  # pylint: disable=R0902
503    """A data iterator for `DaskDeviceQuantileDMatrix`."""
504
505    def __init__(
506        self,
507        data: Tuple[Any, ...],
508        label: Optional[Tuple[Any, ...]] = None,
509        weight: Optional[Tuple[Any, ...]] = None,
510        base_margin: Optional[Tuple[Any, ...]] = None,
511        qid: Optional[Tuple[Any, ...]] = None,
512        label_lower_bound: Optional[Tuple[Any, ...]] = None,
513        label_upper_bound: Optional[Tuple[Any, ...]] = None,
514        feature_names: Optional[Union[str, List[str]]] = None,
515        feature_types: Optional[Union[Any, List[Any]]] = None
516    ) -> None:
517        self._data = data
518        self._labels = label
519        self._weights = weight
520        self._base_margin = base_margin
521        self._qid = qid
522        self._label_lower_bound = label_lower_bound
523        self._label_upper_bound = label_upper_bound
524        self._feature_names = feature_names
525        self._feature_types = feature_types
526
527        assert isinstance(self._data, Sequence)
528
529        types = (Sequence, type(None))
530        assert isinstance(self._labels, types)
531        assert isinstance(self._weights, types)
532        assert isinstance(self._base_margin, types)
533        assert isinstance(self._label_lower_bound, types)
534        assert isinstance(self._label_upper_bound, types)
535
536        self._iter = 0             # set iterator to 0
537        super().__init__()
538
539    def data(self) -> Any:
540        '''Utility function for obtaining current batch of data.'''
541        return self._data[self._iter]
542
543    def labels(self) -> Any:
544        '''Utility function for obtaining current batch of label.'''
545        if self._labels is not None:
546            return self._labels[self._iter]
547        return None
548
549    def weights(self) -> Any:
550        '''Utility function for obtaining current batch of label.'''
551        if self._weights is not None:
552            return self._weights[self._iter]
553        return None
554
555    def qids(self) -> Any:
556        '''Utility function for obtaining current batch of query id.'''
557        if self._qid is not None:
558            return self._qid[self._iter]
559        return None
560
561    def base_margins(self) -> Any:
562        '''Utility function for obtaining current batch of base_margin.'''
563        if self._base_margin is not None:
564            return self._base_margin[self._iter]
565        return None
566
567    def label_lower_bounds(self) -> Any:
568        '''Utility function for obtaining current batch of label_lower_bound.
569        '''
570        if self._label_lower_bound is not None:
571            return self._label_lower_bound[self._iter]
572        return None
573
574    def label_upper_bounds(self) -> Any:
575        '''Utility function for obtaining current batch of label_upper_bound.
576        '''
577        if self._label_upper_bound is not None:
578            return self._label_upper_bound[self._iter]
579        return None
580
581    def reset(self) -> None:
582        '''Reset the iterator'''
583        self._iter = 0
584
585    def next(self, input_data: Callable) -> int:
586        '''Yield next batch of data'''
587        if self._iter == len(self._data):
588            # Return 0 when there's no more batch.
589            return 0
590        feature_names: Optional[Union[List[str], str]] = None
591        if self._feature_names:
592            feature_names = self._feature_names
593        else:
594            if hasattr(self.data(), 'columns'):
595                feature_names = self.data().columns.format()
596            else:
597                feature_names = None
598        input_data(data=self.data(), label=self.labels(),
599                   weight=self.weights(), group=None,
600                   qid=self.qids(),
601                   label_lower_bound=self.label_lower_bounds(),
602                   label_upper_bound=self.label_upper_bounds(),
603                   feature_names=feature_names,
604                   feature_types=self._feature_types)
605        self._iter += 1
606        return 1
607
608
609class DaskDeviceQuantileDMatrix(DaskDMatrix):
610    '''Specialized data type for `gpu_hist` tree method.  This class is used to reduce the
611    memory usage by eliminating data copies.  Internally the all partitions/chunks of data
612    are merged by weighted GK sketching.  So the number of partitions from dask may affect
613    training accuracy as GK generates bounded error for each merge.  See doc string for
614    :py:obj:`xgboost.DeviceQuantileDMatrix` and :py:obj:`xgboost.DMatrix` for other
615    parameters.
616
617    .. versionadded:: 1.2.0
618
619    Parameters
620    ----------
621    max_bin : Number of bins for histogram construction.
622
623    '''
624    @_deprecate_positional_args
625    def __init__(
626        self,
627        client: "distributed.Client",
628        data: _DaskCollection,
629        label: Optional[_DaskCollection] = None,
630        *,
631        weight: Optional[_DaskCollection] = None,
632        base_margin: Optional[_DaskCollection] = None,
633        missing: float = None,
634        silent: bool = False,   # disable=unused-argument
635        feature_names: Optional[Union[str, List[str]]] = None,
636        feature_types: Optional[Union[Any, List[Any]]] = None,
637        max_bin: int = 256,
638        group: Optional[_DaskCollection] = None,
639        qid: Optional[_DaskCollection] = None,
640        label_lower_bound: Optional[_DaskCollection] = None,
641        label_upper_bound: Optional[_DaskCollection] = None,
642        feature_weights: Optional[_DaskCollection] = None,
643        enable_categorical: bool = False,
644    ) -> None:
645        super().__init__(
646            client=client,
647            data=data,
648            label=label,
649            weight=weight,
650            base_margin=base_margin,
651            group=group,
652            qid=qid,
653            label_lower_bound=label_lower_bound,
654            label_upper_bound=label_upper_bound,
655            missing=missing,
656            silent=silent,
657            feature_weights=feature_weights,
658            feature_names=feature_names,
659            feature_types=feature_types,
660            enable_categorical=enable_categorical,
661        )
662        self.max_bin = max_bin
663        self.is_quantile = True
664
665    def _create_fn_args(self, worker_addr: str) -> Dict[str, Any]:
666        args = super()._create_fn_args(worker_addr)
667        args["max_bin"] = self.max_bin
668        return args
669
670
671def _create_device_quantile_dmatrix(
672    feature_names: Optional[Union[str, List[str]]],
673    feature_types: Optional[Union[Any, List[Any]]],
674    feature_weights: Optional[Any],
675    meta_names: List[str],
676    missing: float,
677    parts: Optional[_DataParts],
678    max_bin: int,
679    enable_categorical: bool,
680) -> DeviceQuantileDMatrix:
681    worker = distributed.get_worker()
682    if parts is None:
683        msg = f"worker {worker.address} has an empty DMatrix."
684        LOGGER.warning(msg)
685        import cupy
686
687        d = DeviceQuantileDMatrix(
688            cupy.zeros((0, 0)),
689            feature_names=feature_names,
690            feature_types=feature_types,
691            max_bin=max_bin,
692            enable_categorical=enable_categorical,
693        )
694        return d
695
696    (
697        data,
698        labels,
699        weights,
700        base_margin,
701        qid,
702        label_lower_bound,
703        label_upper_bound,
704    ) = _get_worker_parts(parts, meta_names)
705    it = DaskPartitionIter(
706        data=data,
707        label=labels,
708        weight=weights,
709        base_margin=base_margin,
710        qid=qid,
711        label_lower_bound=label_lower_bound,
712        label_upper_bound=label_upper_bound,
713    )
714
715    dmatrix = DeviceQuantileDMatrix(
716        it,
717        missing=missing,
718        feature_names=feature_names,
719        feature_types=feature_types,
720        nthread=worker.nthreads,
721        max_bin=max_bin,
722        enable_categorical=enable_categorical,
723    )
724    dmatrix.set_info(feature_weights=feature_weights)
725    return dmatrix
726
727
728def _create_dmatrix(
729    feature_names: Optional[Union[str, List[str]]],
730    feature_types: Optional[Union[Any, List[Any]]],
731    feature_weights: Optional[Any],
732    meta_names: List[str],
733    missing: float,
734    enable_categorical: bool,
735    parts: Optional[_DataParts]
736) -> DMatrix:
737    '''Get data that local to worker from DaskDMatrix.
738
739      Returns
740      -------
741      A DMatrix object.
742
743    '''
744    worker = distributed.get_worker()
745    list_of_parts = parts
746    if list_of_parts is None:
747        msg = f"worker {worker.address} has an empty DMatrix."
748        LOGGER.warning(msg)
749        d = DMatrix(
750            numpy.empty((0, 0)),
751            feature_names=feature_names,
752            feature_types=feature_types,
753            enable_categorical=enable_categorical,
754        )
755        return d
756
757    T = TypeVar('T')
758
759    def concat_or_none(data: Tuple[Optional[T], ...]) -> Optional[T]:
760        if any(part is None for part in data):
761            return None
762        return concat(data)
763
764    (data, labels, weights, base_margin, qid,
765     label_lower_bound, label_upper_bound) = _get_worker_parts(list_of_parts, meta_names)
766
767    _labels = concat_or_none(labels)
768    _weights = concat_or_none(weights)
769    _base_margin = concat_or_none(base_margin)
770    _qid = concat_or_none(qid)
771    _label_lower_bound = concat_or_none(label_lower_bound)
772    _label_upper_bound = concat_or_none(label_upper_bound)
773
774    _data = concat(data)
775    dmatrix = DMatrix(
776        _data,
777        _labels,
778        missing=missing,
779        feature_names=feature_names,
780        feature_types=feature_types,
781        nthread=worker.nthreads,
782        enable_categorical=enable_categorical,
783    )
784    dmatrix.set_info(
785        base_margin=_base_margin,
786        qid=_qid,
787        weight=_weights,
788        label_lower_bound=_label_lower_bound,
789        label_upper_bound=_label_upper_bound,
790        feature_weights=feature_weights,
791    )
792    return dmatrix
793
794
795def _dmatrix_from_list_of_parts(
796    is_quantile: bool, **kwargs: Any
797) -> Union[DMatrix, DeviceQuantileDMatrix]:
798    if is_quantile:
799        return _create_device_quantile_dmatrix(**kwargs)
800    return _create_dmatrix(**kwargs)
801
802
803async def _get_rabit_args(n_workers: int, client: "distributed.Client") -> List[bytes]:
804    '''Get rabit context arguments from data distribution in DaskDMatrix.'''
805    env = await client.run_on_scheduler(_start_tracker, n_workers)
806    rabit_args = [f"{k}={v}".encode() for k, v in env.items()]
807    return rabit_args
808
809# train and predict methods are supposed to be "functional", which meets the
810# dask paradigm.  But as a side effect, the `evals_result` in single-node API
811# is no longer supported since it mutates the input parameter, and it's not
812# intuitive to sync the mutation result.  Therefore, a dictionary containing
813# evaluation history is instead returned.
814
815
816def _get_workers_from_data(
817    dtrain: DaskDMatrix,
818    evals: Optional[List[Tuple[DaskDMatrix, str]]]
819) -> List[str]:
820    X_worker_map: Set[str] = set(dtrain.worker_map.keys())
821    if evals:
822        for e in evals:
823            assert len(e) == 2
824            assert isinstance(e[0], DaskDMatrix) and isinstance(e[1], str)
825            if e[0] is dtrain:
826                continue
827            worker_map = set(e[0].worker_map.keys())
828            X_worker_map = X_worker_map.union(worker_map)
829    return list(X_worker_map)
830
831
832async def _train_async(
833    client: "distributed.Client",
834    global_config: Dict[str, Any],
835    params: Dict[str, Any],
836    dtrain: DaskDMatrix,
837    num_boost_round: int,
838    evals: Optional[List[Tuple[DaskDMatrix, str]]],
839    obj: Optional[Objective],
840    feval: Optional[Metric],
841    early_stopping_rounds: Optional[int],
842    verbose_eval: Union[int, bool],
843    xgb_model: Optional[Booster],
844    callbacks: Optional[List[TrainingCallback]],
845) -> Optional[TrainReturnT]:
846    workers = _get_workers_from_data(dtrain, evals)
847    _rabit_args = await _get_rabit_args(len(workers), client)
848
849    if params.get("booster", None) == "gblinear":
850        raise NotImplementedError(
851            f"booster `{params['booster']}` is not yet supported for dask."
852        )
853
854    def dispatched_train(
855        worker_addr: str,
856        rabit_args: List[bytes],
857        dtrain_ref: Dict,
858        dtrain_idt: int,
859        evals_ref: Dict
860    ) -> Optional[Dict[str, Union[Booster, Dict]]]:
861        '''Perform training on a single worker.  A local function prevents pickling.
862
863        '''
864        LOGGER.debug('Training on %s', str(worker_addr))
865        worker = distributed.get_worker()
866        with RabitContext(rabit_args), config.config_context(**global_config):
867            local_dtrain = _dmatrix_from_list_of_parts(**dtrain_ref)
868            local_evals = []
869            if evals_ref:
870                for ref, name, idt in evals_ref:
871                    if idt == dtrain_idt:
872                        local_evals.append((local_dtrain, name))
873                        continue
874                    local_evals.append((_dmatrix_from_list_of_parts(**ref), name))
875
876            local_history: Dict = {}
877            local_param = params.copy()  # just to be consistent
878            msg = 'Overriding `nthreads` defined in dask worker.'
879            override = ['nthread', 'n_jobs']
880            for p in override:
881                val = local_param.get(p, None)
882                if val is not None and val != worker.nthreads:
883                    LOGGER.info(msg)
884                else:
885                    local_param[p] = worker.nthreads
886            bst = worker_train(params=local_param,
887                               dtrain=local_dtrain,
888                               num_boost_round=num_boost_round,
889                               evals_result=local_history,
890                               evals=local_evals,
891                               obj=obj,
892                               feval=feval,
893                               early_stopping_rounds=early_stopping_rounds,
894                               verbose_eval=verbose_eval,
895                               xgb_model=xgb_model,
896                               callbacks=callbacks)
897            ret: Optional[Dict[str, Union[Booster, Dict]]] = {
898                'booster': bst, 'history': local_history}
899            if local_dtrain.num_row() == 0:
900                ret = None
901            return ret
902
903    # Note for function purity:
904    # XGBoost is deterministic in most of the cases, which means train function is
905    # supposed to be idempotent.  One known exception is gblinear with shotgun updater.
906    # We haven't been able to do a full verification so here we keep pure to be False.
907    async with _multi_lock()(workers, client):
908        futures = []
909        for worker_addr in workers:
910            if evals:
911                # pylint: disable=protected-access
912                evals_per_worker = [
913                    (e._create_fn_args(worker_addr), name, id(e)) for e, name in evals
914                ]
915            else:
916                evals_per_worker = []
917            f = client.submit(
918                dispatched_train,
919                worker_addr,
920                _rabit_args,
921                # pylint: disable=protected-access
922                dtrain._create_fn_args(worker_addr),
923                id(dtrain),
924                evals_per_worker,
925                pure=False,
926                workers=[worker_addr],
927                allow_other_workers=False
928            )
929            futures.append(f)
930
931        results = await client.gather(futures, asynchronous=True)
932
933        return list(filter(lambda ret: ret is not None, results))[0]
934
935
936def train(                      # pylint: disable=unused-argument
937    client: "distributed.Client",
938    params: Dict[str, Any],
939    dtrain: DaskDMatrix,
940    num_boost_round: int = 10,
941    evals: Optional[List[Tuple[DaskDMatrix, str]]] = None,
942    obj: Optional[Objective] = None,
943    feval: Optional[Metric] = None,
944    early_stopping_rounds: Optional[int] = None,
945    xgb_model: Optional[Booster] = None,
946    verbose_eval: Union[int, bool] = True,
947    callbacks: Optional[List[TrainingCallback]] = None,
948) -> Any:
949    """Train XGBoost model.
950
951    .. versionadded:: 1.0.0
952
953    .. note::
954
955        Other parameters are the same as :py:func:`xgboost.train` except for
956        `evals_result`, which is returned as part of function return value instead of
957        argument.
958
959    Parameters
960    ----------
961    client :
962        Specify the dask client used for training.  Use default client returned from dask
963        if it's set to None.
964
965    Returns
966    -------
967    results: dict
968        A dictionary containing trained booster and evaluation history.  `history` field
969        is the same as `eval_result` from `xgboost.train`.
970
971        .. code-block:: python
972
973            {'booster': xgboost.Booster,
974             'history': {'train': {'logloss': ['0.48253', '0.35953']},
975                         'eval': {'logloss': ['0.480385', '0.357756']}}}
976
977    """
978    _assert_dask_support()
979    client = _xgb_get_client(client)
980    args = locals()
981    return client.sync(_train_async, global_config=config.get_config(), **args)
982
983
984def _can_output_df(is_df: bool, output_shape: Tuple) -> bool:
985    return is_df and len(output_shape) <= 2
986
987
988def _maybe_dataframe(
989    data: Any, prediction: Any, columns: List[int], is_df: bool
990) -> Any:
991    """Return dataframe for prediction when applicable."""
992    if _can_output_df(is_df, prediction.shape):
993        # Need to preserve the index for dataframe.
994        # See issue: https://github.com/dmlc/xgboost/issues/6939
995        # In older versions of dask, the partition is actually a numpy array when input is
996        # dataframe.
997        index = getattr(data, "index", None)
998        if lazy_isinstance(data, "cudf.core.dataframe", "DataFrame"):
999            import cudf
1000
1001            prediction = cudf.DataFrame(
1002                prediction, columns=columns, dtype=numpy.float32, index=index
1003            )
1004        else:
1005            prediction = DataFrame(
1006                prediction, columns=columns, dtype=numpy.float32, index=index
1007            )
1008    return prediction
1009
1010
1011async def _direct_predict_impl(  # pylint: disable=too-many-branches
1012    mapped_predict: Callable,
1013    booster: "distributed.Future",
1014    data: _DaskCollection,
1015    base_margin: Optional[_DaskCollection],
1016    output_shape: Tuple[int, ...],
1017    meta: Dict[int, str],
1018) -> _DaskCollection:
1019    columns = tuple(meta.keys())
1020    if len(output_shape) >= 3 and isinstance(data, dd.DataFrame):
1021        # Without this check, dask will finish the prediction silently even if output
1022        # dimension is greater than 3.  But during map_partitions, dask passes a
1023        # `dd.DataFrame` as local input to xgboost, which is converted to csr_matrix by
1024        # `_convert_unknown_data` since dd.DataFrame is not known to xgboost native
1025        # binding.
1026        raise ValueError(
1027            "Use `da.Array` or `DaskDMatrix` when output has more than 2 dimensions."
1028        )
1029    if _can_output_df(isinstance(data, dd.DataFrame), output_shape):
1030        if base_margin is not None and isinstance(base_margin, da.Array):
1031            # Easier for map_partitions
1032            base_margin_df: Optional[dd.DataFrame] = base_margin.to_dask_dataframe()
1033        else:
1034            base_margin_df = base_margin
1035        predictions = dd.map_partitions(
1036            mapped_predict,
1037            booster,
1038            data,
1039            True,
1040            columns,
1041            base_margin_df,
1042            meta=dd.utils.make_meta(meta),
1043        )
1044        # classification can return a dataframe, drop 1 dim when it's reg/binary
1045        if len(output_shape) == 1:
1046            predictions = predictions.iloc[:, 0]
1047    else:
1048        if base_margin is not None and isinstance(
1049            base_margin, (dd.Series, dd.DataFrame)
1050        ):
1051            # Easier for map_blocks
1052            base_margin_array: Optional[da.Array] = base_margin.to_dask_array()
1053        else:
1054            base_margin_array = base_margin
1055        # Input data is 2-dim array, output can be 1(reg, binary)/2(multi-class,
1056        # contrib)/3(contrib, interaction)/4(interaction) dims.
1057        if len(output_shape) == 1:
1058            drop_axis: Union[int, List[int]] = [1]  # drop from 2 to 1 dim.
1059            new_axis: Union[int, List[int]] = []
1060        else:
1061            drop_axis = []
1062            if isinstance(data, dd.DataFrame):
1063                new_axis = list(range(len(output_shape) - 2))
1064            else:
1065                new_axis = [i + 2 for i in range(len(output_shape) - 2)]
1066        if len(output_shape) == 2:
1067            # Somehow dask fail to infer output shape change for 2-dim prediction, and
1068            #  `chunks = (None, output_shape[1])` doesn't work due to None is not
1069            #  supported in map_blocks.
1070            chunks: Optional[List[Tuple]] = list(data.chunks)
1071            assert isinstance(chunks, list)
1072            chunks[1] = (output_shape[1], )
1073        else:
1074            chunks = None
1075        predictions = da.map_blocks(
1076            mapped_predict,
1077            booster,
1078            data,
1079            False,
1080            columns,
1081            base_margin_array,
1082
1083            chunks=chunks,
1084            drop_axis=drop_axis,
1085            new_axis=new_axis,
1086            dtype=numpy.float32,
1087        )
1088    return predictions
1089
1090
1091def _infer_predict_output(
1092    booster: Booster, features: int, is_df: bool, inplace: bool, **kwargs: Any
1093) -> Tuple[Tuple[int, ...], Dict[int, str]]:
1094    """Create a dummy test sample to infer output shape for prediction."""
1095    assert isinstance(features, int)
1096    rng = numpy.random.RandomState(1994)
1097    test_sample = rng.randn(1, features)
1098    if inplace:
1099        kwargs = kwargs.copy()
1100        if kwargs.pop("predict_type") == "margin":
1101            kwargs["output_margin"] = True
1102    m = DMatrix(test_sample)
1103    # generated DMatrix doesn't have feature name, so no validation.
1104    test_predt = booster.predict(m, validate_features=False, **kwargs)
1105    n_columns = test_predt.shape[1] if len(test_predt.shape) > 1 else 1
1106    meta: Dict[int, str] = {}
1107    if _can_output_df(is_df, test_predt.shape):
1108        for i in range(n_columns):
1109            meta[i] = "f4"
1110    return test_predt.shape, meta
1111
1112
1113async def _get_model_future(
1114    client: "distributed.Client", model: Union[Booster, Dict, "distributed.Future"]
1115) -> "distributed.Future":
1116    if isinstance(model, Booster):
1117        booster = await client.scatter(model, broadcast=True)
1118    elif isinstance(model, dict):
1119        booster = await client.scatter(model["booster"], broadcast=True)
1120    elif isinstance(model, distributed.Future):
1121        booster = model
1122        if booster.type is not Booster:
1123            raise TypeError(
1124                f"Underlying type of model future should be `Booster`, got {booster.type}"
1125            )
1126    else:
1127        raise TypeError(_expect([Booster, dict, distributed.Future], type(model)))
1128    return booster
1129
1130
1131# pylint: disable=too-many-statements
1132async def _predict_async(
1133    client: "distributed.Client",
1134    global_config: Dict[str, Any],
1135    model: Union[Booster, Dict, "distributed.Future"],
1136    data: _DaskCollection,
1137    output_margin: bool,
1138    missing: float,
1139    pred_leaf: bool,
1140    pred_contribs: bool,
1141    approx_contribs: bool,
1142    pred_interactions: bool,
1143    validate_features: bool,
1144    iteration_range: Tuple[int, int],
1145    strict_shape: bool,
1146) -> _DaskCollection:
1147    _booster = await _get_model_future(client, model)
1148    if not isinstance(data, (DaskDMatrix, da.Array, dd.DataFrame)):
1149        raise TypeError(_expect([DaskDMatrix, da.Array, dd.DataFrame], type(data)))
1150
1151    def mapped_predict(
1152        booster: Booster, partition: Any, is_df: bool, columns: List[int], _: Any
1153    ) -> Any:
1154        with config.config_context(**global_config):
1155            m = DMatrix(data=partition, missing=missing)
1156            predt = booster.predict(
1157                data=m,
1158                output_margin=output_margin,
1159                pred_leaf=pred_leaf,
1160                pred_contribs=pred_contribs,
1161                approx_contribs=approx_contribs,
1162                pred_interactions=pred_interactions,
1163                validate_features=validate_features,
1164                iteration_range=iteration_range,
1165                strict_shape=strict_shape,
1166            )
1167            predt = _maybe_dataframe(partition, predt, columns, is_df)
1168            return predt
1169
1170    # Predict on dask collection directly.
1171    if isinstance(data, (da.Array, dd.DataFrame)):
1172        _output_shape, meta = await client.compute(
1173            client.submit(
1174                _infer_predict_output,
1175                _booster,
1176                features=data.shape[1],
1177                is_df=isinstance(data, dd.DataFrame),
1178                inplace=False,
1179                output_margin=output_margin,
1180                pred_leaf=pred_leaf,
1181                pred_contribs=pred_contribs,
1182                approx_contribs=approx_contribs,
1183                pred_interactions=pred_interactions,
1184                strict_shape=strict_shape,
1185            )
1186        )
1187        return await _direct_predict_impl(
1188            mapped_predict, _booster, data, None, _output_shape, meta
1189        )
1190
1191    output_shape, _ = await client.compute(
1192        client.submit(
1193            _infer_predict_output,
1194            booster=_booster,
1195            features=data.num_col(),
1196            is_df=False,
1197            inplace=False,
1198            output_margin=output_margin,
1199            pred_leaf=pred_leaf,
1200            pred_contribs=pred_contribs,
1201            approx_contribs=approx_contribs,
1202            pred_interactions=pred_interactions,
1203            strict_shape=strict_shape,
1204        )
1205    )
1206    # Prediction on dask DMatrix.
1207    partition_order = data.partition_order
1208    feature_names = data.feature_names
1209    feature_types = data.feature_types
1210    missing = data.missing
1211    meta_names = data.meta_names
1212
1213    def dispatched_predict(booster: Booster, part: Tuple) -> numpy.ndarray:
1214        data = part[0]
1215        assert isinstance(part, tuple), type(part)
1216        base_margin = None
1217        for i, blob in enumerate(part[1:]):
1218            if meta_names[i] == "base_margin":
1219                base_margin = blob
1220        with config.config_context(**global_config):
1221            m = DMatrix(
1222                data,
1223                missing=missing,
1224                base_margin=base_margin,
1225                feature_names=feature_names,
1226                feature_types=feature_types,
1227            )
1228            predt = booster.predict(
1229                m,
1230                output_margin=output_margin,
1231                pred_leaf=pred_leaf,
1232                pred_contribs=pred_contribs,
1233                approx_contribs=approx_contribs,
1234                pred_interactions=pred_interactions,
1235                validate_features=validate_features,
1236                iteration_range=iteration_range,
1237                strict_shape=strict_shape,
1238            )
1239            return predt
1240
1241    all_parts = []
1242    all_orders = []
1243    all_shapes = []
1244    all_workers: List[str] = []
1245    workers_address = list(data.worker_map.keys())
1246    for worker_addr in workers_address:
1247        list_of_parts = data.worker_map[worker_addr]
1248        all_parts.extend(list_of_parts)
1249        all_workers.extend(len(list_of_parts) * [worker_addr])
1250        all_orders.extend([partition_order[part.key] for part in list_of_parts])
1251    for w, part in zip(all_workers, all_parts):
1252        s = client.submit(lambda part: part[0].shape[0], part, workers=[w])
1253        all_shapes.append(s)
1254    all_shapes = await client.gather(all_shapes)
1255
1256    parts_with_order = list(zip(all_parts, all_shapes, all_orders))
1257    parts_with_order = sorted(parts_with_order, key=lambda p: p[2])
1258    all_parts = [part for part, shape, order in parts_with_order]
1259    all_shapes = [shape for part, shape, order in parts_with_order]
1260
1261    futures = []
1262    for w, part in zip(all_workers, all_parts):
1263        f = client.submit(dispatched_predict, _booster, part, workers=[w])
1264        futures.append(f)
1265
1266    # Constructing a dask array from list of numpy arrays
1267    # See https://docs.dask.org/en/latest/array-creation.html
1268    arrays = []
1269    for i, rows in enumerate(all_shapes):
1270        arrays.append(
1271            da.from_delayed(
1272                futures[i], shape=(rows,) + output_shape[1:], dtype=numpy.float32
1273            )
1274        )
1275    predictions = da.concatenate(arrays, axis=0)
1276    return predictions
1277
1278
1279def predict(                    # pylint: disable=unused-argument
1280    client: "distributed.Client",
1281    model: Union[TrainReturnT, Booster, "distributed.Future"],
1282    data: Union[DaskDMatrix, _DaskCollection],
1283    output_margin: bool = False,
1284    missing: float = numpy.nan,
1285    pred_leaf: bool = False,
1286    pred_contribs: bool = False,
1287    approx_contribs: bool = False,
1288    pred_interactions: bool = False,
1289    validate_features: bool = True,
1290    iteration_range: Tuple[int, int] = (0, 0),
1291    strict_shape: bool = False,
1292) -> Any:
1293    '''Run prediction with a trained booster.
1294
1295    .. note::
1296
1297        Using ``inplace_predict`` might be faster when some features are not needed.  See
1298        :py:meth:`xgboost.Booster.predict` for details on various parameters.  When output
1299        has more than 2 dimensions (shap value, leaf with strict_shape), input should be
1300        ``da.Array`` or ``DaskDMatrix``.
1301
1302    .. versionadded:: 1.0.0
1303
1304    Parameters
1305    ----------
1306    client:
1307        Specify the dask client used for training.  Use default client
1308        returned from dask if it's set to None.
1309    model:
1310        The trained model.  It can be a distributed.Future so user can
1311        pre-scatter it onto all workers.
1312    data:
1313        Input data used for prediction.  When input is a dataframe object,
1314        prediction output is a series.
1315    missing:
1316        Used when input data is not DaskDMatrix.  Specify the value
1317        considered as missing.
1318
1319    Returns
1320    -------
1321    prediction: dask.array.Array/dask.dataframe.Series
1322        When input data is ``dask.array.Array`` or ``DaskDMatrix``, the return value is an
1323        array, when input data is ``dask.dataframe.DataFrame``, return value can be
1324        ``dask.dataframe.Series``, ``dask.dataframe.DataFrame``, depending on the output
1325        shape.
1326
1327    '''
1328    _assert_dask_support()
1329    client = _xgb_get_client(client)
1330    return client.sync(_predict_async, global_config=config.get_config(), **locals())
1331
1332
1333async def _inplace_predict_async(  # pylint: disable=too-many-branches
1334    client: "distributed.Client",
1335    global_config: Dict[str, Any],
1336    model: Union[Booster, Dict, "distributed.Future"],
1337    data: _DaskCollection,
1338    iteration_range: Tuple[int, int],
1339    predict_type: str,
1340    missing: float,
1341    validate_features: bool,
1342    base_margin: Optional[_DaskCollection],
1343    strict_shape: bool,
1344) -> _DaskCollection:
1345    client = _xgb_get_client(client)
1346    booster = await _get_model_future(client, model)
1347    if not isinstance(data, (da.Array, dd.DataFrame)):
1348        raise TypeError(_expect([da.Array, dd.DataFrame], type(data)))
1349    if base_margin is not None and not isinstance(
1350        data, (da.Array, dd.DataFrame, dd.Series)
1351    ):
1352        raise TypeError(_expect([da.Array, dd.DataFrame, dd.Series], type(base_margin)))
1353
1354    def mapped_predict(
1355        booster: Booster, partition: Any, is_df: bool, columns: List[int], base_margin: Any
1356    ) -> Any:
1357        with config.config_context(**global_config):
1358            prediction = booster.inplace_predict(
1359                partition,
1360                iteration_range=iteration_range,
1361                predict_type=predict_type,
1362                missing=missing,
1363                base_margin=base_margin,
1364                validate_features=validate_features,
1365                strict_shape=strict_shape,
1366            )
1367        prediction = _maybe_dataframe(partition, prediction, columns, is_df)
1368        return prediction
1369
1370    # await turns future into value.
1371    shape, meta = await client.compute(
1372        client.submit(
1373            _infer_predict_output,
1374            booster,
1375            features=data.shape[1],
1376            is_df=isinstance(data, dd.DataFrame),
1377            inplace=True,
1378            predict_type=predict_type,
1379            iteration_range=iteration_range,
1380            strict_shape=strict_shape,
1381        )
1382    )
1383    return await _direct_predict_impl(
1384        mapped_predict, booster, data, base_margin, shape, meta
1385    )
1386
1387
1388def inplace_predict(  # pylint: disable=unused-argument
1389    client: "distributed.Client",
1390    model: Union[TrainReturnT, Booster, "distributed.Future"],
1391    data: _DaskCollection,
1392    iteration_range: Tuple[int, int] = (0, 0),
1393    predict_type: str = "value",
1394    missing: float = numpy.nan,
1395    validate_features: bool = True,
1396    base_margin: Optional[_DaskCollection] = None,
1397    strict_shape: bool = False,
1398) -> Any:
1399    """Inplace prediction. See doc in :py:meth:`xgboost.Booster.inplace_predict` for details.
1400
1401    .. versionadded:: 1.1.0
1402
1403    Parameters
1404    ----------
1405    client:
1406        Specify the dask client used for training.  Use default client
1407        returned from dask if it's set to None.
1408    model:
1409        See :py:func:`xgboost.dask.predict` for details.
1410    data :
1411        dask collection.
1412    iteration_range:
1413        See :py:meth:`xgboost.Booster.predict` for details.
1414    predict_type:
1415        See :py:meth:`xgboost.Booster.inplace_predict` for details.
1416    missing:
1417        Value in the input data which needs to be present as a missing
1418        value. If None, defaults to np.nan.
1419    base_margin:
1420        See :py:obj:`xgboost.DMatrix` for details. Right now classifier is not well
1421        supported with base_margin as it requires the size of base margin to be `n_classes
1422        * n_samples`.
1423
1424        .. versionadded:: 1.4.0
1425
1426    strict_shape:
1427        See :py:meth:`xgboost.Booster.predict` for details.
1428
1429        .. versionadded:: 1.4.0
1430
1431    Returns
1432    -------
1433    prediction :
1434        When input data is ``dask.array.Array``, the return value is an array, when input
1435        data is ``dask.dataframe.DataFrame``, return value can be
1436        ``dask.dataframe.Series``, ``dask.dataframe.DataFrame``, depending on the output
1437        shape.
1438
1439    """
1440    _assert_dask_support()
1441    client = _xgb_get_client(client)
1442    # When used in an asynchronous environment, the `client` object should have its
1443    # `asynchronous` attribute set to True.  When invoked by the skl interface, the
1444    # interface is responsible for setting up the client.
1445    return client.sync(
1446        _inplace_predict_async, global_config=config.get_config(), **locals()
1447    )
1448
1449
1450async def _async_wrap_evaluation_matrices(
1451    client: "distributed.Client", **kwargs: Any
1452) -> Tuple[DaskDMatrix, Optional[List[Tuple[DaskDMatrix, str]]]]:
1453    """A switch function for async environment."""
1454
1455    def _inner(**kwargs: Any) -> DaskDMatrix:
1456        m = DaskDMatrix(client=client, **kwargs)
1457        return m
1458
1459    train_dmatrix, evals = _wrap_evaluation_matrices(create_dmatrix=_inner, **kwargs)
1460    train_dmatrix = await train_dmatrix
1461    if evals is None:
1462        return train_dmatrix, evals
1463    awaited = []
1464    for e in evals:
1465        if e[0] is train_dmatrix:  # already awaited
1466            awaited.append(e)
1467            continue
1468        awaited.append((await e[0], e[1]))
1469    return train_dmatrix, awaited
1470
1471
1472@contextmanager
1473def _set_worker_client(
1474    model: "DaskScikitLearnBase", client: "distributed.Client"
1475) -> Generator:
1476    """Temporarily set the client for sklearn model."""
1477    try:
1478        model.client = client
1479        yield model
1480    finally:
1481        model.client = None
1482
1483
1484class DaskScikitLearnBase(XGBModel):
1485    """Base class for implementing scikit-learn interface with Dask"""
1486
1487    _client = None
1488
1489    async def _predict_async(
1490        self,
1491        data: _DaskCollection,
1492        output_margin: bool,
1493        validate_features: bool,
1494        base_margin: Optional[_DaskCollection],
1495        iteration_range: Optional[Tuple[int, int]],
1496    ) -> Any:
1497        iteration_range = self._get_iteration_range(iteration_range)
1498        if self._can_use_inplace_predict():
1499            predts = await inplace_predict(
1500                client=self.client,
1501                model=self.get_booster(),
1502                data=data,
1503                iteration_range=iteration_range,
1504                predict_type="margin" if output_margin else "value",
1505                missing=self.missing,
1506                base_margin=base_margin,
1507                validate_features=validate_features,
1508            )
1509            if isinstance(predts, dd.DataFrame):
1510                predts = predts.to_dask_array()
1511        else:
1512            test_dmatrix = await DaskDMatrix(
1513                self.client, data=data, base_margin=base_margin, missing=self.missing
1514            )
1515            predts = await predict(
1516                self.client,
1517                model=self.get_booster(),
1518                data=test_dmatrix,
1519                output_margin=output_margin,
1520                validate_features=validate_features,
1521                iteration_range=iteration_range,
1522            )
1523        return predts
1524
1525    def predict(
1526        self,
1527        X: _DaskCollection,
1528        output_margin: bool = False,
1529        ntree_limit: Optional[int] = None,
1530        validate_features: bool = True,
1531        base_margin: Optional[_DaskCollection] = None,
1532        iteration_range: Optional[Tuple[int, int]] = None,
1533    ) -> Any:
1534        _assert_dask_support()
1535        msg = "`ntree_limit` is not supported on dask, use `iteration_range` instead."
1536        assert ntree_limit is None, msg
1537        return self.client.sync(
1538            self._predict_async,
1539            X,
1540            output_margin=output_margin,
1541            validate_features=validate_features,
1542            base_margin=base_margin,
1543            iteration_range=iteration_range,
1544        )
1545
1546    async def _apply_async(
1547        self,
1548        X: _DaskCollection,
1549        iteration_range: Optional[Tuple[int, int]] = None,
1550    ) -> Any:
1551        iteration_range = self._get_iteration_range(iteration_range)
1552        test_dmatrix = await DaskDMatrix(self.client, data=X, missing=self.missing)
1553        predts = await predict(
1554            self.client,
1555            model=self.get_booster(),
1556            data=test_dmatrix,
1557            pred_leaf=True,
1558            iteration_range=iteration_range,
1559        )
1560        return predts
1561
1562    def apply(
1563        self,
1564        X: _DaskCollection,
1565        ntree_limit: Optional[int] = None,
1566        iteration_range: Optional[Tuple[int, int]] = None,
1567    ) -> Any:
1568        _assert_dask_support()
1569        msg = "`ntree_limit` is not supported on dask, use `iteration_range` instead."
1570        assert ntree_limit is None, msg
1571        return self.client.sync(self._apply_async, X, iteration_range=iteration_range)
1572
1573    def __await__(self) -> Awaitable[Any]:
1574        # Generate a coroutine wrapper to make this class awaitable.
1575        async def _() -> Awaitable[Any]:
1576            return self
1577
1578        return self._client_sync(_).__await__()
1579
1580    def __getstate__(self) -> Dict:
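        # The client can not be pickled, so drop it from the serialized state; after
        # deserialization the ``client`` property falls back to the default client.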
1581        this = self.__dict__.copy()
1582        if "_client" in this.keys():
1583            del this["_client"]
1584        return this
1585
1586    @property
1587    def client(self) -> "distributed.Client":
1588        """The dask client used in this model.  The `Client` object can not be serialized for
1589        transmission, so if task is launched from a worker instead of directly from the
1590        client process, this attribute needs to be set at that worker.
1591
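        A hypothetical sketch of that pattern, for a task running on a worker (the dask
        collections ``X_local`` and ``y_local`` are assumed to be available there):

        .. code-block:: python

            with distributed.worker_client() as client:
                model = DaskXGBRegressor()
                model.client = client  # must be assigned on the worker itself
                model.fit(X_local, y_local)
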
1592        """
1593
1594        client = _xgb_get_client(self._client)
1595        return client
1596
1597    @client.setter
1598    def client(self, clt: "distributed.Client") -> None:
1599        # Calling `worker_client` doesn't return the correct `asynchronous` attribute,
1600        # so we have to pass it along ourselves.
1601        self._asynchronous = clt.asynchronous if clt is not None else False
1602        self._client = clt
1603
1604    def _client_sync(self, func: Callable, **kwargs: Any) -> Any:
1605        """Get the correct client, when method is invoked inside a worker we
1606        should use `worker_client' instead of default client.
1607
1608        """
1609        asynchronous = getattr(self, "_asynchronous", False)
1610        if self._client is None:
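            # `distributed.get_worker` raises a ValueError when called outside a worker,
            # which is how we detect whether this method is running on a worker.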
1611            try:
1612                distributed.get_worker()
1613                in_worker = True
1614            except ValueError:
1615                in_worker = False
1616            if in_worker:
1617                with distributed.worker_client() as client:
1618                    with _set_worker_client(self, client) as this:
1619                        ret = this.client.sync(func, **kwargs, asynchronous=asynchronous)
1620                        return ret
1622
1623        return self.client.sync(func, **kwargs, asynchronous=asynchronous)
1624
1625
1626@xgboost_model_doc(
1627    """Implementation of the Scikit-Learn API for XGBoost.""", ["estimators", "model"]
1628)
1629class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
1630    # pylint: disable=missing-class-docstring
1631    async def _fit_async(
1632        self,
1633        X: _DaskCollection,
1634        y: _DaskCollection,
1635        sample_weight: Optional[_DaskCollection],
1636        base_margin: Optional[_DaskCollection],
1637        eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]],
1638        eval_metric: Optional[Union[str, List[str], Metric]],
1639        sample_weight_eval_set: Optional[List[_DaskCollection]],
1640        base_margin_eval_set: Optional[List[_DaskCollection]],
1641        early_stopping_rounds: Optional[int],
1642        verbose: bool,
1643        xgb_model: Optional[Union[Booster, XGBModel]],
1644        feature_weights: Optional[_DaskCollection],
1645        callbacks: Optional[List[TrainingCallback]],
1646    ) -> _DaskCollection:
1647        params = self.get_xgb_params()
1648        dtrain, evals = await _async_wrap_evaluation_matrices(
1649            client=self.client,
1650            X=X,
1651            y=y,
1652            group=None,
1653            qid=None,
1654            sample_weight=sample_weight,
1655            base_margin=base_margin,
1656            feature_weights=feature_weights,
1657            eval_set=eval_set,
1658            sample_weight_eval_set=sample_weight_eval_set,
1659            base_margin_eval_set=base_margin_eval_set,
1660            eval_group=None,
1661            eval_qid=None,
1662            missing=self.missing,
1663            enable_categorical=self.enable_categorical,
1664        )
1665
1666        if callable(self.objective):
1667            obj: Optional[Callable] = _objective_decorator(self.objective)
1668        else:
1669            obj = None
1670        model, metric, params = self._configure_fit(
1671            booster=xgb_model, eval_metric=eval_metric, params=params
1672        )
1673        results = await self.client.sync(
1674            _train_async,
1675            asynchronous=True,
1676            client=self.client,
1677            global_config=config.get_config(),
1678            params=params,
1679            dtrain=dtrain,
1680            num_boost_round=self.get_num_boosting_rounds(),
1681            evals=evals,
1682            obj=obj,
1683            feval=metric,
1684            verbose_eval=verbose,
1685            early_stopping_rounds=early_stopping_rounds,
1686            callbacks=callbacks,
1687            xgb_model=model,
1688        )
1689        self._Booster = results["booster"]
1690        self._set_evaluation_result(results["history"])
1691        return self
1692
1693    # pylint: disable=missing-docstring, unused-argument
1694    @_deprecate_positional_args
1695    def fit(
1696        self,
1697        X: _DaskCollection,
1698        y: _DaskCollection,
1699        *,
1700        sample_weight: Optional[_DaskCollection] = None,
1701        base_margin: Optional[_DaskCollection] = None,
1702        eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]] = None,
1703        eval_metric: Optional[Union[str, List[str], Metric]] = None,
1704        early_stopping_rounds: Optional[int] = None,
1705        verbose: bool = True,
1706        xgb_model: Optional[Union[Booster, XGBModel]] = None,
1707        sample_weight_eval_set: Optional[List[_DaskCollection]] = None,
1708        base_margin_eval_set: Optional[List[_DaskCollection]] = None,
1709        feature_weights: Optional[_DaskCollection] = None,
1710        callbacks: Optional[List[TrainingCallback]] = None,
1711    ) -> "DaskXGBRegressor":
1712        _assert_dask_support()
1713        args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
1714        return self._client_sync(self._fit_async, **args)
1715
1716
1717@xgboost_model_doc(
1718    'Implementation of the scikit-learn API for XGBoost classification.',
1719    ['estimators', 'model'])
1720class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
1721    # pylint: disable=missing-class-docstring
1722    async def _fit_async(
1723        self, X: _DaskCollection, y: _DaskCollection,
1724        sample_weight: Optional[_DaskCollection],
1725        base_margin: Optional[_DaskCollection],
1726        eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]],
1727        eval_metric: Optional[Union[str, List[str], Metric]],
1728        sample_weight_eval_set: Optional[List[_DaskCollection]],
1729        base_margin_eval_set: Optional[List[_DaskCollection]],
1730        early_stopping_rounds: Optional[int],
1731        verbose: bool,
1732        xgb_model: Optional[Union[Booster, XGBModel]],
1733        feature_weights: Optional[_DaskCollection],
1734        callbacks: Optional[List[TrainingCallback]]
1735    ) -> "DaskXGBClassifier":
1736        params = self.get_xgb_params()
1737        dtrain, evals = await _async_wrap_evaluation_matrices(
1738            self.client,
1739            X=X,
1740            y=y,
1741            group=None,
1742            qid=None,
1743            sample_weight=sample_weight,
1744            base_margin=base_margin,
1745            feature_weights=feature_weights,
1746            eval_set=eval_set,
1747            sample_weight_eval_set=sample_weight_eval_set,
1748            base_margin_eval_set=base_margin_eval_set,
1749            eval_group=None,
1750            eval_qid=None,
1751            missing=self.missing,
1752            enable_categorical=self.enable_categorical,
1753        )
1754
1755        # pylint: disable=attribute-defined-outside-init
1756        if isinstance(y, da.Array):
1757            self.classes_ = await self.client.compute(da.unique(y))
1758        else:
1759            self.classes_ = await self.client.compute(y.drop_duplicates())
1760        self.n_classes_ = len(self.classes_)
1761
1762        if self.n_classes_ > 2:
1763            params["objective"] = "multi:softprob"
1764            params['num_class'] = self.n_classes_
1765        else:
1766            params["objective"] = "binary:logistic"
1767
1768        if callable(self.objective):
1769            obj: Optional[Callable] = _objective_decorator(self.objective)
1770        else:
1771            obj = None
1772        model, metric, params = self._configure_fit(
1773            booster=xgb_model, eval_metric=eval_metric, params=params
1774        )
1775        results = await self.client.sync(
1776            _train_async,
1777            asynchronous=True,
1778            client=self.client,
1779            global_config=config.get_config(),
1780            params=params,
1781            dtrain=dtrain,
1782            num_boost_round=self.get_num_boosting_rounds(),
1783            evals=evals,
1784            obj=obj,
1785            feval=metric,
1786            verbose_eval=verbose,
1787            early_stopping_rounds=early_stopping_rounds,
1788            callbacks=callbacks,
1789            xgb_model=model,
1790        )
1791        self._Booster = results['booster']
1792        if not callable(self.objective):
1793            self.objective = params["objective"]
1794        self._set_evaluation_result(results["history"])
1795        return self
1796
1797    # pylint: disable=unused-argument
1798    def fit(
1799        self,
1800        X: _DaskCollection,
1801        y: _DaskCollection,
1802        *,
1803        sample_weight: Optional[_DaskCollection] = None,
1804        base_margin: Optional[_DaskCollection] = None,
1805        eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]] = None,
1806        eval_metric: Optional[Union[str, List[str], Metric]] = None,
1807        early_stopping_rounds: Optional[int] = None,
1808        verbose: bool = True,
1809        xgb_model: Optional[Union[Booster, XGBModel]] = None,
1810        sample_weight_eval_set: Optional[List[_DaskCollection]] = None,
1811        base_margin_eval_set: Optional[List[_DaskCollection]] = None,
1812        feature_weights: Optional[_DaskCollection] = None,
1813        callbacks: Optional[List[TrainingCallback]] = None
1814    ) -> "DaskXGBClassifier":
1815        _assert_dask_support()
1816        args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
1817        return self._client_sync(self._fit_async, **args)
1818
1819    async def _predict_proba_async(
1820        self,
1821        X: _DaskCollection,
1822        validate_features: bool,
1823        base_margin: Optional[_DaskCollection],
1824        iteration_range: Optional[Tuple[int, int]],
1825    ) -> _DaskCollection:
1826        predts = await super()._predict_async(
1827            data=X,
1828            output_margin=self.objective == "multi:softmax",
1829            validate_features=validate_features,
1830            base_margin=base_margin,
1831            iteration_range=iteration_range,
1832        )
1833        vstack = update_wrapper(
1834            partial(da.vstack, allow_unknown_chunksizes=True), da.vstack
1835        )
1836        return _cls_predict_proba(getattr(self, "n_classes_", None), predts, vstack)
1837
1838    # pylint: disable=missing-function-docstring
1839    def predict_proba(
1840        self,
1841        X: _DaskCollection,
1842        ntree_limit: Optional[int] = None,
1843        validate_features: bool = True,
1844        base_margin: Optional[_DaskCollection] = None,
1845        iteration_range: Optional[Tuple[int, int]] = None,
1846    ) -> Any:
1847        _assert_dask_support()
1848        msg = "`ntree_limit` is not supported on dask, use `iteration_range` instead."
1849        assert ntree_limit is None, msg
1850        return self._client_sync(
1851            self._predict_proba_async,
1852            X=X,
1853            validate_features=validate_features,
1854            base_margin=base_margin,
1855            iteration_range=iteration_range,
1856        )
1857
1858    predict_proba.__doc__ = XGBClassifier.predict_proba.__doc__
1859
1860    async def _predict_async(
1861        self,
1862        data: _DaskCollection,
1863        output_margin: bool,
1864        validate_features: bool,
1865        base_margin: Optional[_DaskCollection],
1866        iteration_range: Optional[Tuple[int, int]],
1867    ) -> _DaskCollection:
1868        pred_probs = await super()._predict_async(
1869            data, output_margin, validate_features, base_margin, iteration_range
1870        )
1871        if output_margin:
1872            return pred_probs
1873
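        # A 1-dim output is the probability of the positive class from the binary
        # objective, so threshold it at 0.5.  A 2-dim output holds per-class scores,
        # so pick the argmax over classes.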
1874        if len(pred_probs.shape) == 1:
1875            preds = (pred_probs > 0.5).astype(int)
1876        else:
1877            assert len(pred_probs.shape) == 2
1878            assert isinstance(pred_probs, da.Array)
1879            # When using da.argmax directly, dask constructs a numpy-based return
1880            # array, which runs into an error when computing GPU-based predictions.
1881
1882            def _argmax(x: Any) -> Any:
1883                return x.argmax(axis=1)
1884
1885            preds = da.map_blocks(_argmax, pred_probs, drop_axis=1)
1886        return preds
1887
1888
1889@xgboost_model_doc(
1890    """Implementation of the Scikit-Learn API for XGBoost Ranking.
1891
1892    .. versionadded:: 1.4.0
1893
1894""",
1895    ["estimators", "model"],
1896    end_note="""
1897        Note
1898        ----
1899        For the dask implementation, ``group`` is not supported; use ``qid`` instead.
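
        A minimal fitting sketch (assuming dask collections ``X``, ``y`` and a per-row
        query id column ``qid``, with a default dask client already created):

        .. code-block:: python

            ranker = DaskXGBRanker()
            ranker.fit(X, y, qid=qid)
            scores = ranker.predict(X)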
1900""",
1901)
1902class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
1903    @_deprecate_positional_args
1904    def __init__(self, *, objective: str = "rank:pairwise", **kwargs: Any):
1905        if callable(objective):
1906            raise ValueError("Custom objective function not supported by XGBRanker.")
1907        super().__init__(objective=objective, **kwargs)
1908
1909    async def _fit_async(
1910        self,
1911        X: _DaskCollection,
1912        y: _DaskCollection,
1913        group: Optional[_DaskCollection],
1914        qid: Optional[_DaskCollection],
1915        sample_weight: Optional[_DaskCollection],
1916        base_margin: Optional[_DaskCollection],
1917        eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]],
1918        sample_weight_eval_set: Optional[List[_DaskCollection]],
1919        base_margin_eval_set: Optional[List[_DaskCollection]],
1920        eval_group: Optional[List[_DaskCollection]],
1921        eval_qid: Optional[List[_DaskCollection]],
1922        eval_metric: Optional[Union[str, List[str], Metric]],
1923        early_stopping_rounds: Optional[int],
1924        verbose: bool,
1925        xgb_model: Optional[Union[XGBModel, Booster]],
1926        feature_weights: Optional[_DaskCollection],
1927        callbacks: Optional[List[TrainingCallback]],
1928    ) -> "DaskXGBRanker":
1929        msg = "Use `qid` instead of `group` on dask interface."
1930        if not (group is None and eval_group is None):
1931            raise ValueError(msg)
1932        if qid is None:
1933            raise ValueError("`qid` is required for ranking.")
1934        params = self.get_xgb_params()
1935        dtrain, evals = await _async_wrap_evaluation_matrices(
1936            self.client,
1937            X=X,
1938            y=y,
1939            group=None,
1940            qid=qid,
1941            sample_weight=sample_weight,
1942            base_margin=base_margin,
1943            feature_weights=feature_weights,
1944            eval_set=eval_set,
1945            sample_weight_eval_set=sample_weight_eval_set,
1946            base_margin_eval_set=base_margin_eval_set,
1947            eval_group=None,
1948            eval_qid=eval_qid,
1949            missing=self.missing,
1950            enable_categorical=self.enable_categorical,
1951        )
1952        if eval_metric is not None:
1953            if callable(eval_metric):
1954                raise ValueError(
1955                    "Custom evaluation metric is not yet supported for XGBRanker."
1956                )
1957        model, metric, params = self._configure_fit(
1958            booster=xgb_model, eval_metric=eval_metric, params=params
1959        )
1960        results = await self.client.sync(
1961            _train_async,
1962            asynchronous=True,
1963            client=self.client,
1964            global_config=config.get_config(),
1965            params=params,
1966            dtrain=dtrain,
1967            num_boost_round=self.get_num_boosting_rounds(),
1968            evals=evals,
1969            obj=None,
1970            feval=metric,
1971            verbose_eval=verbose,
1972            early_stopping_rounds=early_stopping_rounds,
1973            callbacks=callbacks,
1974            xgb_model=model,
1975        )
1976        self._Booster = results["booster"]
1977        self.evals_result_ = results["history"]
1978        return self
1979
1980    # pylint: disable=unused-argument, arguments-differ
1981    @_deprecate_positional_args
1982    def fit(
1983        self,
1984        X: _DaskCollection,
1985        y: _DaskCollection,
1986        *,
1987        group: Optional[_DaskCollection] = None,
1988        qid: Optional[_DaskCollection] = None,
1989        sample_weight: Optional[_DaskCollection] = None,
1990        base_margin: Optional[_DaskCollection] = None,
1991        eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]] = None,
1992        eval_group: Optional[List[_DaskCollection]] = None,
1993        eval_qid: Optional[List[_DaskCollection]] = None,
1994        eval_metric: Optional[Union[str, List[str], Metric]] = None,
1995        early_stopping_rounds: Optional[int] = None,
1996        verbose: bool = False,
1997        xgb_model: Optional[Union[XGBModel, Booster]] = None,
1998        sample_weight_eval_set: Optional[List[_DaskCollection]] = None,
1999        base_margin_eval_set: Optional[List[_DaskCollection]] = None,
2000        feature_weights: Optional[_DaskCollection] = None,
2001        callbacks: Optional[List[TrainingCallback]] = None
2002    ) -> "DaskXGBRanker":
2003        _assert_dask_support()
2004        args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
2005        return self._client_sync(self._fit_async, **args)
2006
2007    # FIXME(trivialfis): arguments differ due to additional parameters like group and qid.
2008    fit.__doc__ = XGBRanker.fit.__doc__
2009
2010
2011@xgboost_model_doc(
2012    """Implementation of the Scikit-Learn API for XGBoost Random Forest Regressor.
2013
2014    .. versionadded:: 1.4.0
2015
2016""",
2017    ["model", "objective"],
2018    extra_parameters="""
2019    n_estimators : int
2020        Number of trees in random forest to fit.
2021""",
2022)
2023class DaskXGBRFRegressor(DaskXGBRegressor):
2024    @_deprecate_positional_args
2025    def __init__(
2026        self,
2027        *,
2028        learning_rate: Optional[float] = 1,
2029        subsample: Optional[float] = 0.8,
2030        colsample_bynode: Optional[float] = 0.8,
2031        reg_lambda: Optional[float] = 1e-5,
2032        **kwargs: Any
2033    ) -> None:
2034        super().__init__(
2035            learning_rate=learning_rate,
2036            subsample=subsample,
2037            colsample_bynode=colsample_bynode,
2038            reg_lambda=reg_lambda,
2039            **kwargs
2040        )
2041
2042    def get_xgb_params(self) -> Dict[str, Any]:
2043        params = super().get_xgb_params()
2044        params["num_parallel_tree"] = self.n_estimators
2045        return params
2046
2047    def get_num_boosting_rounds(self) -> int:
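        # The random forest is fitted in a single boosting round; the number of trees is
        # controlled by ``num_parallel_tree`` set in ``get_xgb_params``.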
2048        return 1
2049
2050    # pylint: disable=unused-argument
2051    def fit(
2052        self,
2053        X: _DaskCollection,
2054        y: _DaskCollection,
2055        *,
2056        sample_weight: Optional[_DaskCollection] = None,
2057        base_margin: Optional[_DaskCollection] = None,
2058        eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]] = None,
2059        eval_metric: Optional[Union[str, List[str], Metric]] = None,
2060        early_stopping_rounds: Optional[int] = None,
2061        verbose: bool = True,
2062        xgb_model: Optional[Union[Booster, XGBModel]] = None,
2063        sample_weight_eval_set: Optional[List[_DaskCollection]] = None,
2064        base_margin_eval_set: Optional[List[_DaskCollection]] = None,
2065        feature_weights: Optional[_DaskCollection] = None,
2066        callbacks: Optional[List[TrainingCallback]] = None
2067    ) -> "DaskXGBRFRegressor":
2068        _assert_dask_support()
2069        args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
2070        _check_rf_callback(early_stopping_rounds, callbacks)
2071        super().fit(**args)
2072        return self
2073
2074
2075@xgboost_model_doc(
2076    """Implementation of the Scikit-Learn API for XGBoost Random Forest Classifier.
2077
2078    .. versionadded:: 1.4.0
2079
2080""",
2081    ["model", "objective"],
2082    extra_parameters="""
2083    n_estimators : int
2084        Number of trees in random forest to fit.
2085""",
2086)
2087class DaskXGBRFClassifier(DaskXGBClassifier):
2088    @_deprecate_positional_args
2089    def __init__(
2090        self,
2091        *,
2092        learning_rate: Optional[float] = 1,
2093        subsample: Optional[float] = 0.8,
2094        colsample_bynode: Optional[float] = 0.8,
2095        reg_lambda: Optional[float] = 1e-5,
2096        **kwargs: Any
2097    ) -> None:
2098        super().__init__(
2099            learning_rate=learning_rate,
2100            subsample=subsample,
2101            colsample_bynode=colsample_bynode,
2102            reg_lambda=reg_lambda,
2103            **kwargs
2104        )
2105
2106    def get_xgb_params(self) -> Dict[str, Any]:
2107        params = super().get_xgb_params()
2108        params["num_parallel_tree"] = self.n_estimators
2109        return params
2110
2111    def get_num_boosting_rounds(self) -> int:
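        # The random forest is fitted in a single boosting round; the number of trees is
        # controlled by ``num_parallel_tree`` set in ``get_xgb_params``.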
2112        return 1
2113
2114    # pylint: disable=unused-argument
2115    def fit(
2116        self,
2117        X: _DaskCollection,
2118        y: _DaskCollection,
2119        *,
2120        sample_weight: Optional[_DaskCollection] = None,
2121        base_margin: Optional[_DaskCollection] = None,
2122        eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]] = None,
2123        eval_metric: Optional[Union[str, List[str], Metric]] = None,
2124        early_stopping_rounds: Optional[int] = None,
2125        verbose: bool = True,
2126        xgb_model: Optional[Union[Booster, XGBModel]] = None,
2127        sample_weight_eval_set: Optional[List[_DaskCollection]] = None,
2128        base_margin_eval_set: Optional[List[_DaskCollection]] = None,
2129        feature_weights: Optional[_DaskCollection] = None,
2130        callbacks: Optional[List[TrainingCallback]] = None
2131    ) -> "DaskXGBRFClassifier":
2132        _assert_dask_support()
2133        args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
2134        _check_rf_callback(early_stopping_rounds, callbacks)
2135        super().fit(**args)
2136        return self
2137