1# pylint: disable=too-many-arguments, too-many-locals, no-name-in-module 2# pylint: disable=missing-class-docstring, invalid-name 3# pylint: disable=too-many-lines, fixme 4# pylint: disable=too-few-public-methods 5# pylint: disable=import-error 6"""Dask extensions for distributed training. See 7https://xgboost.readthedocs.io/en/latest/tutorials/dask.html for simple 8tutorial. Also xgboost/demo/dask for some examples. 9 10There are two sets of APIs in this module, one is the functional API including 11``train`` and ``predict`` methods. Another is stateful Scikit-Learner wrapper 12inherited from single-node Scikit-Learn interface. 13 14The implementation is heavily influenced by dask_xgboost: 15https://github.com/dask/dask-xgboost 16 17""" 18import platform 19import logging 20from contextlib import contextmanager 21from collections import defaultdict 22from collections.abc import Sequence 23from threading import Thread 24from functools import partial, update_wrapper 25from typing import TYPE_CHECKING, List, Tuple, Callable, Optional, Any, Union, Dict, Set 26from typing import Awaitable, Generator, TypeVar 27 28import numpy 29 30from . import rabit, config 31 32from .callback import TrainingCallback 33 34from .compat import LazyLoader 35from .compat import sparse, scipy_sparse 36from .compat import PANDAS_INSTALLED, DataFrame, Series, pandas_concat 37from .compat import lazy_isinstance 38 39from .core import DMatrix, DeviceQuantileDMatrix, Booster, _expect, DataIter 40from .core import Objective, Metric 41from .core import _deprecate_positional_args 42from .training import train as worker_train 43from .tracker import RabitTracker, get_host_ip 44from .sklearn import XGBModel, XGBClassifier, XGBRegressorBase, XGBClassifierBase 45from .sklearn import _wrap_evaluation_matrices, _objective_decorator, _check_rf_callback 46from .sklearn import XGBRankerMixIn 47from .sklearn import xgboost_model_doc 48from .sklearn import _cls_predict_proba 49from .sklearn import XGBRanker 50 51if TYPE_CHECKING: 52 from dask import dataframe as dd 53 from dask import array as da 54 import dask 55 import distributed 56else: 57 dd = LazyLoader('dd', globals(), 'dask.dataframe') 58 da = LazyLoader('da', globals(), 'dask.array') 59 dask = LazyLoader('dask', globals(), 'dask') 60 distributed = LazyLoader('distributed', globals(), 'dask.distributed') 61 62_DaskCollection = Union["da.Array", "dd.DataFrame", "dd.Series"] 63 64try: 65 from mypy_extensions import TypedDict 66 TrainReturnT = TypedDict('TrainReturnT', { 67 'booster': Booster, 68 'history': Dict, 69 }) 70except ImportError: 71 TrainReturnT = Dict[str, Any] # type:ignore 72 73__all__ = [ 74 "RabitContext", 75 "DaskDMatrix", 76 "DaskDeviceQuantileDMatrix", 77 "DaskXGBRegressor", 78 "DaskXGBClassifier", 79 "DaskXGBRanker", 80 "DaskXGBRFRegressor", 81 "DaskXGBRFClassifier", 82 "train", 83 "predict", 84 "inplace_predict", 85] 86 87# TODOs: 88# - CV 89# 90# Note for developers: 91# 92# As of writing asyncio is still a new feature of Python and in depth documentation is 93# rare. Best examples of various asyncio tricks are in dask (luckily). Classes like 94# Client, Worker are awaitable. Some general rules for the implementation here: 95# 96# - Synchronous world is different from asynchronous one, and they don't mix well. 97# - Write everything with async, then use distributed Client sync function to do the 98# switch. 99# - Use Any for type hint when the return value can be union of Awaitable and plain 100# value. This is caused by Client.sync can return both types depending on context. 101# Right now there's no good way to silent: 102# 103# await train(...) 104# 105# if train returns an Union type. 106 107 108LOGGER = logging.getLogger('[xgboost.dask]') 109 110 111def _multi_lock() -> Any: 112 """MultiLock is only available on latest distributed. See: 113 114 https://github.com/dask/distributed/pull/4503 115 116""" 117 try: 118 from distributed import MultiLock 119 except ImportError: 120 class MultiLock: # type:ignore 121 def __init__(self, *args: Any, **kwargs: Any) -> None: 122 pass 123 124 def __enter__(self) -> "MultiLock": 125 return self 126 127 def __exit__(self, *args: Any, **kwargs: Any) -> None: 128 return 129 130 async def __aenter__(self) -> "MultiLock": 131 return self 132 133 async def __aexit__(self, *args: Any, **kwargs: Any) -> None: 134 return 135 136 return MultiLock 137 138 139def _start_tracker(n_workers: int) -> Dict[str, Any]: 140 """Start Rabit tracker """ 141 env = {'DMLC_NUM_WORKER': n_workers} 142 host = get_host_ip('auto') 143 rabit_context = RabitTracker(hostIP=host, nslave=n_workers, use_logger=False) 144 env.update(rabit_context.slave_envs()) 145 146 rabit_context.start(n_workers) 147 thread = Thread(target=rabit_context.join) 148 thread.daemon = True 149 thread.start() 150 return env 151 152 153def _assert_dask_support() -> None: 154 try: 155 import dask # pylint: disable=W0621,W0611 156 except ImportError as e: 157 raise ImportError( 158 "Dask needs to be installed in order to use this module" 159 ) from e 160 161 if platform.system() == "Windows": 162 msg = "Windows is not officially supported for dask/xgboost," 163 msg += " contribution are welcomed." 164 LOGGER.warning(msg) 165 166 167class RabitContext: 168 '''A context controling rabit initialization and finalization.''' 169 def __init__(self, args: List[bytes]) -> None: 170 self.args = args 171 worker = distributed.get_worker() 172 self.args.append( 173 ('DMLC_TASK_ID=[xgboost.dask]:' + str(worker.address)).encode()) 174 175 def __enter__(self) -> None: 176 rabit.init(self.args) 177 LOGGER.debug('-------------- rabit say hello ------------------') 178 179 def __exit__(self, *args: List) -> None: 180 rabit.finalize() 181 LOGGER.debug('--------------- rabit say bye ------------------') 182 183 184def concat(value: Any) -> Any: # pylint: disable=too-many-return-statements 185 '''To be replaced with dask builtin.''' 186 if isinstance(value[0], numpy.ndarray): 187 return numpy.concatenate(value, axis=0) 188 if scipy_sparse and isinstance(value[0], scipy_sparse.spmatrix): 189 return scipy_sparse.vstack(value, format='csr') 190 if sparse and isinstance(value[0], sparse.SparseArray): 191 return sparse.concatenate(value, axis=0) 192 if PANDAS_INSTALLED and isinstance(value[0], (DataFrame, Series)): 193 return pandas_concat(value, axis=0) 194 if lazy_isinstance(value[0], 'cudf.core.dataframe', 'DataFrame') or \ 195 lazy_isinstance(value[0], 'cudf.core.series', 'Series'): 196 from cudf import concat as CUDF_concat # pylint: disable=import-error 197 return CUDF_concat(value, axis=0) 198 if lazy_isinstance(value[0], 'cupy._core.core', 'ndarray'): 199 import cupy 200 # pylint: disable=c-extension-no-member,no-member 201 d = cupy.cuda.runtime.getDevice() 202 for v in value: 203 d_v = v.device.id 204 assert d_v == d, 'Concatenating arrays on different devices.' 205 return cupy.concatenate(value, axis=0) 206 return dd.multi.concat(list(value), axis=0) 207 208 209def _xgb_get_client(client: Optional["distributed.Client"]) -> "distributed.Client": 210 '''Simple wrapper around testing None.''' 211 if not isinstance(client, (type(distributed.get_client()), type(None))): 212 raise TypeError( 213 _expect([type(distributed.get_client()), type(None)], type(client))) 214 ret = distributed.get_client() if client is None else client 215 return ret 216 217# From the implementation point of view, DaskDMatrix complicates a lots of 218# things. A large portion of the code base is about syncing and extracting 219# stuffs from DaskDMatrix. But having an independent data structure gives us a 220# chance to perform some specialized optimizations, like building histogram 221# index directly. 222 223 224class DaskDMatrix: 225 # pylint: disable=missing-docstring, too-many-instance-attributes 226 '''DMatrix holding on references to Dask DataFrame or Dask Array. Constructing a 227 `DaskDMatrix` forces all lazy computation to be carried out. Wait for the input data 228 explicitly if you want to see actual computation of constructing `DaskDMatrix`. 229 230 See doc for :py:obj:`xgboost.DMatrix` constructor for other parameters. DaskDMatrix 231 accepts only dask collection. 232 233 .. note:: 234 235 DaskDMatrix does not repartition or move data between workers. It's 236 the caller's responsibility to balance the data. 237 238 .. versionadded:: 1.0.0 239 240 Parameters 241 ---------- 242 client : 243 Specify the dask client used for training. Use default client returned from dask 244 if it's set to None. 245 246 ''' 247 248 @_deprecate_positional_args 249 def __init__( 250 self, 251 client: "distributed.Client", 252 data: _DaskCollection, 253 label: Optional[_DaskCollection] = None, 254 *, 255 weight: Optional[_DaskCollection] = None, 256 base_margin: Optional[_DaskCollection] = None, 257 missing: float = None, 258 silent: bool = False, # pylint: disable=unused-argument 259 feature_names: Optional[Union[str, List[str]]] = None, 260 feature_types: Optional[Union[Any, List[Any]]] = None, 261 group: Optional[_DaskCollection] = None, 262 qid: Optional[_DaskCollection] = None, 263 label_lower_bound: Optional[_DaskCollection] = None, 264 label_upper_bound: Optional[_DaskCollection] = None, 265 feature_weights: Optional[_DaskCollection] = None, 266 enable_categorical: bool = False 267 ) -> None: 268 _assert_dask_support() 269 client = _xgb_get_client(client) 270 271 self.feature_names = feature_names 272 self.feature_types = feature_types 273 self.missing = missing 274 self.enable_categorical = enable_categorical 275 276 if qid is not None and weight is not None: 277 raise NotImplementedError("per-group weight is not implemented.") 278 if group is not None: 279 raise NotImplementedError( 280 "group structure is not implemented, use qid instead." 281 ) 282 283 if len(data.shape) != 2: 284 raise ValueError( 285 f"Expecting 2 dimensional input, got: {data.shape}" 286 ) 287 288 if not isinstance(data, (dd.DataFrame, da.Array)): 289 raise TypeError(_expect((dd.DataFrame, da.Array), type(data))) 290 if not isinstance(label, (dd.DataFrame, da.Array, dd.Series, type(None))): 291 raise TypeError(_expect((dd.DataFrame, da.Array, dd.Series), type(label))) 292 293 self._n_cols = data.shape[1] 294 assert isinstance(self._n_cols, int) 295 self.worker_map: Dict[str, "distributed.Future"] = defaultdict(list) 296 self.is_quantile: bool = False 297 298 self._init = client.sync( 299 self._map_local_data, 300 client, 301 data, 302 label=label, 303 weights=weight, 304 base_margin=base_margin, 305 qid=qid, 306 feature_weights=feature_weights, 307 label_lower_bound=label_lower_bound, 308 label_upper_bound=label_upper_bound, 309 ) 310 311 def __await__(self) -> Generator: 312 return self._init.__await__() 313 314 async def _map_local_data( 315 self, 316 client: "distributed.Client", 317 data: _DaskCollection, 318 label: Optional[_DaskCollection] = None, 319 weights: Optional[_DaskCollection] = None, 320 base_margin: Optional[_DaskCollection] = None, 321 qid: Optional[_DaskCollection] = None, 322 feature_weights: Optional[_DaskCollection] = None, 323 label_lower_bound: Optional[_DaskCollection] = None, 324 label_upper_bound: Optional[_DaskCollection] = None, 325 ) -> "DaskDMatrix": 326 '''Obtain references to local data.''' 327 328 def inconsistent( 329 left: List[Any], left_name: str, right: List[Any], right_name: str 330 ) -> str: 331 msg = (f"Partitions between {left_name} and {right_name} are not " 332 f"consistent: {len(left)} != {len(right)}. " 333 f"Please try to repartition/rechunk your data.") 334 return msg 335 336 def check_columns(parts: Any) -> None: 337 # x is required to be 2 dim in __init__ 338 assert parts.ndim == 1 or parts.shape[1], 'Data should be' \ 339 ' partitioned by row. To avoid this specify the number' \ 340 ' of columns for your dask Array explicitly. e.g.' \ 341 ' chunks=(partition_size, X.shape[1])' 342 343 data = client.persist(data) 344 for meta in [label, weights, base_margin, label_lower_bound, 345 label_upper_bound]: 346 if meta is not None: 347 meta = client.persist(meta) 348 # Breaking data into partitions, a trick borrowed from dask_xgboost. 349 350 # `to_delayed` downgrades high-level objects into numpy or pandas 351 # equivalents. 352 X_parts = data.to_delayed() 353 if isinstance(X_parts, numpy.ndarray): 354 check_columns(X_parts) 355 X_parts = X_parts.flatten().tolist() 356 357 def flatten_meta( 358 meta: Optional[_DaskCollection] 359 ) -> "Optional[List[dask.delayed.Delayed]]": 360 if meta is not None: 361 meta_parts = meta.to_delayed() 362 if isinstance(meta_parts, numpy.ndarray): 363 check_columns(meta_parts) 364 meta_parts = meta_parts.flatten().tolist() 365 return meta_parts 366 return None 367 368 y_parts = flatten_meta(label) 369 w_parts = flatten_meta(weights) 370 margin_parts = flatten_meta(base_margin) 371 qid_parts = flatten_meta(qid) 372 ll_parts = flatten_meta(label_lower_bound) 373 lu_parts = flatten_meta(label_upper_bound) 374 375 parts = [X_parts] 376 meta_names = [] 377 378 def append_meta( 379 m_parts: Optional[List["dask.delayed.delayed"]], name: str 380 ) -> None: 381 if m_parts is not None: 382 assert len(X_parts) == len( 383 m_parts), inconsistent(X_parts, 'X', m_parts, name) 384 parts.append(m_parts) 385 meta_names.append(name) 386 387 append_meta(y_parts, 'labels') 388 append_meta(w_parts, 'weights') 389 append_meta(margin_parts, 'base_margin') 390 append_meta(qid_parts, 'qid') 391 append_meta(ll_parts, 'label_lower_bound') 392 append_meta(lu_parts, 'label_upper_bound') 393 # At this point, `parts` looks like: 394 # [(x0, x1, ..), (y0, y1, ..), ..] in delayed form 395 396 # delay the zipped result 397 parts = list(map(dask.delayed, zip(*parts))) # pylint: disable=no-member 398 # At this point, the mental model should look like: 399 # [(x0, y0, ..), (x1, y1, ..), ..] in delayed form 400 401 parts = client.compute(parts) 402 await distributed.wait(parts) # async wait for parts to be computed 403 404 for part in parts: 405 assert part.status == 'finished', part.status 406 407 # Preserving the partition order for prediction. 408 self.partition_order = {} 409 for i, part in enumerate(parts): 410 self.partition_order[part.key] = i 411 412 key_to_partition = {part.key: part for part in parts} 413 who_has = await client.scheduler.who_has(keys=[part.key for part in parts]) 414 415 worker_map: Dict[str, "distributed.Future"] = defaultdict(list) 416 417 for key, workers in who_has.items(): 418 worker_map[next(iter(workers))].append(key_to_partition[key]) 419 420 self.worker_map = worker_map 421 self.meta_names = meta_names 422 423 if feature_weights is None: 424 self.feature_weights = None 425 else: 426 self.feature_weights = await client.compute(feature_weights).result() 427 428 return self 429 430 def _create_fn_args(self, worker_addr: str) -> Dict[str, Any]: 431 '''Create a dictionary of objects that can be pickled for function 432 arguments. 433 434 ''' 435 return {'feature_names': self.feature_names, 436 'feature_types': self.feature_types, 437 'feature_weights': self.feature_weights, 438 'meta_names': self.meta_names, 439 'missing': self.missing, 440 'enable_categorical': self.enable_categorical, 441 'parts': self.worker_map.get(worker_addr, None), 442 'is_quantile': self.is_quantile} 443 444 def num_col(self) -> int: 445 return self._n_cols 446 447 448_DataParts = List[Tuple[Any, Optional[Any], Optional[Any], Optional[Any], Optional[Any], 449 Optional[Any], Optional[Any]]] 450 451 452def _get_worker_parts_ordered( 453 meta_names: List[str], list_of_parts: _DataParts 454) -> _DataParts: 455 # List of partitions like: [(x3, y3, w3, m3, ..), ..], order is not preserved. 456 assert isinstance(list_of_parts, list) 457 458 result = [] 459 460 for i, _ in enumerate(list_of_parts): 461 data = list_of_parts[i][0] 462 labels = None 463 weights = None 464 base_margin = None 465 qid = None 466 label_lower_bound = None 467 label_upper_bound = None 468 # Iterate through all possible meta info, brings small overhead as in xgboost 469 # there are constant number of meta info available. 470 for j, blob in enumerate(list_of_parts[i][1:]): 471 if meta_names[j] == 'labels': 472 labels = blob 473 elif meta_names[j] == 'weights': 474 weights = blob 475 elif meta_names[j] == 'base_margin': 476 base_margin = blob 477 elif meta_names[j] == 'qid': 478 qid = blob 479 elif meta_names[j] == 'label_lower_bound': 480 label_lower_bound = blob 481 elif meta_names[j] == 'label_upper_bound': 482 label_upper_bound = blob 483 else: 484 raise ValueError('Unknown metainfo:', meta_names[j]) 485 result.append((data, labels, weights, base_margin, qid, label_lower_bound, 486 label_upper_bound)) 487 return result 488 489 490def _unzip(list_of_parts: _DataParts) -> List[Tuple[Any, ...]]: 491 return list(zip(*list_of_parts)) 492 493 494def _get_worker_parts( 495 list_of_parts: _DataParts, meta_names: List[str] 496) -> List[Tuple[Any, ...]]: 497 partitions = _get_worker_parts_ordered(meta_names, list_of_parts) 498 partitions_unzipped = _unzip(partitions) 499 return partitions_unzipped 500 501 502class DaskPartitionIter(DataIter): # pylint: disable=R0902 503 """A data iterator for `DaskDeviceQuantileDMatrix`.""" 504 505 def __init__( 506 self, 507 data: Tuple[Any, ...], 508 label: Optional[Tuple[Any, ...]] = None, 509 weight: Optional[Tuple[Any, ...]] = None, 510 base_margin: Optional[Tuple[Any, ...]] = None, 511 qid: Optional[Tuple[Any, ...]] = None, 512 label_lower_bound: Optional[Tuple[Any, ...]] = None, 513 label_upper_bound: Optional[Tuple[Any, ...]] = None, 514 feature_names: Optional[Union[str, List[str]]] = None, 515 feature_types: Optional[Union[Any, List[Any]]] = None 516 ) -> None: 517 self._data = data 518 self._labels = label 519 self._weights = weight 520 self._base_margin = base_margin 521 self._qid = qid 522 self._label_lower_bound = label_lower_bound 523 self._label_upper_bound = label_upper_bound 524 self._feature_names = feature_names 525 self._feature_types = feature_types 526 527 assert isinstance(self._data, Sequence) 528 529 types = (Sequence, type(None)) 530 assert isinstance(self._labels, types) 531 assert isinstance(self._weights, types) 532 assert isinstance(self._base_margin, types) 533 assert isinstance(self._label_lower_bound, types) 534 assert isinstance(self._label_upper_bound, types) 535 536 self._iter = 0 # set iterator to 0 537 super().__init__() 538 539 def data(self) -> Any: 540 '''Utility function for obtaining current batch of data.''' 541 return self._data[self._iter] 542 543 def labels(self) -> Any: 544 '''Utility function for obtaining current batch of label.''' 545 if self._labels is not None: 546 return self._labels[self._iter] 547 return None 548 549 def weights(self) -> Any: 550 '''Utility function for obtaining current batch of label.''' 551 if self._weights is not None: 552 return self._weights[self._iter] 553 return None 554 555 def qids(self) -> Any: 556 '''Utility function for obtaining current batch of query id.''' 557 if self._qid is not None: 558 return self._qid[self._iter] 559 return None 560 561 def base_margins(self) -> Any: 562 '''Utility function for obtaining current batch of base_margin.''' 563 if self._base_margin is not None: 564 return self._base_margin[self._iter] 565 return None 566 567 def label_lower_bounds(self) -> Any: 568 '''Utility function for obtaining current batch of label_lower_bound. 569 ''' 570 if self._label_lower_bound is not None: 571 return self._label_lower_bound[self._iter] 572 return None 573 574 def label_upper_bounds(self) -> Any: 575 '''Utility function for obtaining current batch of label_upper_bound. 576 ''' 577 if self._label_upper_bound is not None: 578 return self._label_upper_bound[self._iter] 579 return None 580 581 def reset(self) -> None: 582 '''Reset the iterator''' 583 self._iter = 0 584 585 def next(self, input_data: Callable) -> int: 586 '''Yield next batch of data''' 587 if self._iter == len(self._data): 588 # Return 0 when there's no more batch. 589 return 0 590 feature_names: Optional[Union[List[str], str]] = None 591 if self._feature_names: 592 feature_names = self._feature_names 593 else: 594 if hasattr(self.data(), 'columns'): 595 feature_names = self.data().columns.format() 596 else: 597 feature_names = None 598 input_data(data=self.data(), label=self.labels(), 599 weight=self.weights(), group=None, 600 qid=self.qids(), 601 label_lower_bound=self.label_lower_bounds(), 602 label_upper_bound=self.label_upper_bounds(), 603 feature_names=feature_names, 604 feature_types=self._feature_types) 605 self._iter += 1 606 return 1 607 608 609class DaskDeviceQuantileDMatrix(DaskDMatrix): 610 '''Specialized data type for `gpu_hist` tree method. This class is used to reduce the 611 memory usage by eliminating data copies. Internally the all partitions/chunks of data 612 are merged by weighted GK sketching. So the number of partitions from dask may affect 613 training accuracy as GK generates bounded error for each merge. See doc string for 614 :py:obj:`xgboost.DeviceQuantileDMatrix` and :py:obj:`xgboost.DMatrix` for other 615 parameters. 616 617 .. versionadded:: 1.2.0 618 619 Parameters 620 ---------- 621 max_bin : Number of bins for histogram construction. 622 623 ''' 624 @_deprecate_positional_args 625 def __init__( 626 self, 627 client: "distributed.Client", 628 data: _DaskCollection, 629 label: Optional[_DaskCollection] = None, 630 *, 631 weight: Optional[_DaskCollection] = None, 632 base_margin: Optional[_DaskCollection] = None, 633 missing: float = None, 634 silent: bool = False, # disable=unused-argument 635 feature_names: Optional[Union[str, List[str]]] = None, 636 feature_types: Optional[Union[Any, List[Any]]] = None, 637 max_bin: int = 256, 638 group: Optional[_DaskCollection] = None, 639 qid: Optional[_DaskCollection] = None, 640 label_lower_bound: Optional[_DaskCollection] = None, 641 label_upper_bound: Optional[_DaskCollection] = None, 642 feature_weights: Optional[_DaskCollection] = None, 643 enable_categorical: bool = False, 644 ) -> None: 645 super().__init__( 646 client=client, 647 data=data, 648 label=label, 649 weight=weight, 650 base_margin=base_margin, 651 group=group, 652 qid=qid, 653 label_lower_bound=label_lower_bound, 654 label_upper_bound=label_upper_bound, 655 missing=missing, 656 silent=silent, 657 feature_weights=feature_weights, 658 feature_names=feature_names, 659 feature_types=feature_types, 660 enable_categorical=enable_categorical, 661 ) 662 self.max_bin = max_bin 663 self.is_quantile = True 664 665 def _create_fn_args(self, worker_addr: str) -> Dict[str, Any]: 666 args = super()._create_fn_args(worker_addr) 667 args["max_bin"] = self.max_bin 668 return args 669 670 671def _create_device_quantile_dmatrix( 672 feature_names: Optional[Union[str, List[str]]], 673 feature_types: Optional[Union[Any, List[Any]]], 674 feature_weights: Optional[Any], 675 meta_names: List[str], 676 missing: float, 677 parts: Optional[_DataParts], 678 max_bin: int, 679 enable_categorical: bool, 680) -> DeviceQuantileDMatrix: 681 worker = distributed.get_worker() 682 if parts is None: 683 msg = f"worker {worker.address} has an empty DMatrix." 684 LOGGER.warning(msg) 685 import cupy 686 687 d = DeviceQuantileDMatrix( 688 cupy.zeros((0, 0)), 689 feature_names=feature_names, 690 feature_types=feature_types, 691 max_bin=max_bin, 692 enable_categorical=enable_categorical, 693 ) 694 return d 695 696 ( 697 data, 698 labels, 699 weights, 700 base_margin, 701 qid, 702 label_lower_bound, 703 label_upper_bound, 704 ) = _get_worker_parts(parts, meta_names) 705 it = DaskPartitionIter( 706 data=data, 707 label=labels, 708 weight=weights, 709 base_margin=base_margin, 710 qid=qid, 711 label_lower_bound=label_lower_bound, 712 label_upper_bound=label_upper_bound, 713 ) 714 715 dmatrix = DeviceQuantileDMatrix( 716 it, 717 missing=missing, 718 feature_names=feature_names, 719 feature_types=feature_types, 720 nthread=worker.nthreads, 721 max_bin=max_bin, 722 enable_categorical=enable_categorical, 723 ) 724 dmatrix.set_info(feature_weights=feature_weights) 725 return dmatrix 726 727 728def _create_dmatrix( 729 feature_names: Optional[Union[str, List[str]]], 730 feature_types: Optional[Union[Any, List[Any]]], 731 feature_weights: Optional[Any], 732 meta_names: List[str], 733 missing: float, 734 enable_categorical: bool, 735 parts: Optional[_DataParts] 736) -> DMatrix: 737 '''Get data that local to worker from DaskDMatrix. 738 739 Returns 740 ------- 741 A DMatrix object. 742 743 ''' 744 worker = distributed.get_worker() 745 list_of_parts = parts 746 if list_of_parts is None: 747 msg = f"worker {worker.address} has an empty DMatrix." 748 LOGGER.warning(msg) 749 d = DMatrix( 750 numpy.empty((0, 0)), 751 feature_names=feature_names, 752 feature_types=feature_types, 753 enable_categorical=enable_categorical, 754 ) 755 return d 756 757 T = TypeVar('T') 758 759 def concat_or_none(data: Tuple[Optional[T], ...]) -> Optional[T]: 760 if any(part is None for part in data): 761 return None 762 return concat(data) 763 764 (data, labels, weights, base_margin, qid, 765 label_lower_bound, label_upper_bound) = _get_worker_parts(list_of_parts, meta_names) 766 767 _labels = concat_or_none(labels) 768 _weights = concat_or_none(weights) 769 _base_margin = concat_or_none(base_margin) 770 _qid = concat_or_none(qid) 771 _label_lower_bound = concat_or_none(label_lower_bound) 772 _label_upper_bound = concat_or_none(label_upper_bound) 773 774 _data = concat(data) 775 dmatrix = DMatrix( 776 _data, 777 _labels, 778 missing=missing, 779 feature_names=feature_names, 780 feature_types=feature_types, 781 nthread=worker.nthreads, 782 enable_categorical=enable_categorical, 783 ) 784 dmatrix.set_info( 785 base_margin=_base_margin, 786 qid=_qid, 787 weight=_weights, 788 label_lower_bound=_label_lower_bound, 789 label_upper_bound=_label_upper_bound, 790 feature_weights=feature_weights, 791 ) 792 return dmatrix 793 794 795def _dmatrix_from_list_of_parts( 796 is_quantile: bool, **kwargs: Any 797) -> Union[DMatrix, DeviceQuantileDMatrix]: 798 if is_quantile: 799 return _create_device_quantile_dmatrix(**kwargs) 800 return _create_dmatrix(**kwargs) 801 802 803async def _get_rabit_args(n_workers: int, client: "distributed.Client") -> List[bytes]: 804 '''Get rabit context arguments from data distribution in DaskDMatrix.''' 805 env = await client.run_on_scheduler(_start_tracker, n_workers) 806 rabit_args = [f"{k}={v}".encode() for k, v in env.items()] 807 return rabit_args 808 809# train and predict methods are supposed to be "functional", which meets the 810# dask paradigm. But as a side effect, the `evals_result` in single-node API 811# is no longer supported since it mutates the input parameter, and it's not 812# intuitive to sync the mutation result. Therefore, a dictionary containing 813# evaluation history is instead returned. 814 815 816def _get_workers_from_data( 817 dtrain: DaskDMatrix, 818 evals: Optional[List[Tuple[DaskDMatrix, str]]] 819) -> List[str]: 820 X_worker_map: Set[str] = set(dtrain.worker_map.keys()) 821 if evals: 822 for e in evals: 823 assert len(e) == 2 824 assert isinstance(e[0], DaskDMatrix) and isinstance(e[1], str) 825 if e[0] is dtrain: 826 continue 827 worker_map = set(e[0].worker_map.keys()) 828 X_worker_map = X_worker_map.union(worker_map) 829 return list(X_worker_map) 830 831 832async def _train_async( 833 client: "distributed.Client", 834 global_config: Dict[str, Any], 835 params: Dict[str, Any], 836 dtrain: DaskDMatrix, 837 num_boost_round: int, 838 evals: Optional[List[Tuple[DaskDMatrix, str]]], 839 obj: Optional[Objective], 840 feval: Optional[Metric], 841 early_stopping_rounds: Optional[int], 842 verbose_eval: Union[int, bool], 843 xgb_model: Optional[Booster], 844 callbacks: Optional[List[TrainingCallback]], 845) -> Optional[TrainReturnT]: 846 workers = _get_workers_from_data(dtrain, evals) 847 _rabit_args = await _get_rabit_args(len(workers), client) 848 849 if params.get("booster", None) == "gblinear": 850 raise NotImplementedError( 851 f"booster `{params['booster']}` is not yet supported for dask." 852 ) 853 854 def dispatched_train( 855 worker_addr: str, 856 rabit_args: List[bytes], 857 dtrain_ref: Dict, 858 dtrain_idt: int, 859 evals_ref: Dict 860 ) -> Optional[Dict[str, Union[Booster, Dict]]]: 861 '''Perform training on a single worker. A local function prevents pickling. 862 863 ''' 864 LOGGER.debug('Training on %s', str(worker_addr)) 865 worker = distributed.get_worker() 866 with RabitContext(rabit_args), config.config_context(**global_config): 867 local_dtrain = _dmatrix_from_list_of_parts(**dtrain_ref) 868 local_evals = [] 869 if evals_ref: 870 for ref, name, idt in evals_ref: 871 if idt == dtrain_idt: 872 local_evals.append((local_dtrain, name)) 873 continue 874 local_evals.append((_dmatrix_from_list_of_parts(**ref), name)) 875 876 local_history: Dict = {} 877 local_param = params.copy() # just to be consistent 878 msg = 'Overriding `nthreads` defined in dask worker.' 879 override = ['nthread', 'n_jobs'] 880 for p in override: 881 val = local_param.get(p, None) 882 if val is not None and val != worker.nthreads: 883 LOGGER.info(msg) 884 else: 885 local_param[p] = worker.nthreads 886 bst = worker_train(params=local_param, 887 dtrain=local_dtrain, 888 num_boost_round=num_boost_round, 889 evals_result=local_history, 890 evals=local_evals, 891 obj=obj, 892 feval=feval, 893 early_stopping_rounds=early_stopping_rounds, 894 verbose_eval=verbose_eval, 895 xgb_model=xgb_model, 896 callbacks=callbacks) 897 ret: Optional[Dict[str, Union[Booster, Dict]]] = { 898 'booster': bst, 'history': local_history} 899 if local_dtrain.num_row() == 0: 900 ret = None 901 return ret 902 903 # Note for function purity: 904 # XGBoost is deterministic in most of the cases, which means train function is 905 # supposed to be idempotent. One known exception is gblinear with shotgun updater. 906 # We haven't been able to do a full verification so here we keep pure to be False. 907 async with _multi_lock()(workers, client): 908 futures = [] 909 for worker_addr in workers: 910 if evals: 911 # pylint: disable=protected-access 912 evals_per_worker = [ 913 (e._create_fn_args(worker_addr), name, id(e)) for e, name in evals 914 ] 915 else: 916 evals_per_worker = [] 917 f = client.submit( 918 dispatched_train, 919 worker_addr, 920 _rabit_args, 921 # pylint: disable=protected-access 922 dtrain._create_fn_args(worker_addr), 923 id(dtrain), 924 evals_per_worker, 925 pure=False, 926 workers=[worker_addr], 927 allow_other_workers=False 928 ) 929 futures.append(f) 930 931 results = await client.gather(futures, asynchronous=True) 932 933 return list(filter(lambda ret: ret is not None, results))[0] 934 935 936def train( # pylint: disable=unused-argument 937 client: "distributed.Client", 938 params: Dict[str, Any], 939 dtrain: DaskDMatrix, 940 num_boost_round: int = 10, 941 evals: Optional[List[Tuple[DaskDMatrix, str]]] = None, 942 obj: Optional[Objective] = None, 943 feval: Optional[Metric] = None, 944 early_stopping_rounds: Optional[int] = None, 945 xgb_model: Optional[Booster] = None, 946 verbose_eval: Union[int, bool] = True, 947 callbacks: Optional[List[TrainingCallback]] = None, 948) -> Any: 949 """Train XGBoost model. 950 951 .. versionadded:: 1.0.0 952 953 .. note:: 954 955 Other parameters are the same as :py:func:`xgboost.train` except for 956 `evals_result`, which is returned as part of function return value instead of 957 argument. 958 959 Parameters 960 ---------- 961 client : 962 Specify the dask client used for training. Use default client returned from dask 963 if it's set to None. 964 965 Returns 966 ------- 967 results: dict 968 A dictionary containing trained booster and evaluation history. `history` field 969 is the same as `eval_result` from `xgboost.train`. 970 971 .. code-block:: python 972 973 {'booster': xgboost.Booster, 974 'history': {'train': {'logloss': ['0.48253', '0.35953']}, 975 'eval': {'logloss': ['0.480385', '0.357756']}}} 976 977 """ 978 _assert_dask_support() 979 client = _xgb_get_client(client) 980 args = locals() 981 return client.sync(_train_async, global_config=config.get_config(), **args) 982 983 984def _can_output_df(is_df: bool, output_shape: Tuple) -> bool: 985 return is_df and len(output_shape) <= 2 986 987 988def _maybe_dataframe( 989 data: Any, prediction: Any, columns: List[int], is_df: bool 990) -> Any: 991 """Return dataframe for prediction when applicable.""" 992 if _can_output_df(is_df, prediction.shape): 993 # Need to preserve the index for dataframe. 994 # See issue: https://github.com/dmlc/xgboost/issues/6939 995 # In older versions of dask, the partition is actually a numpy array when input is 996 # dataframe. 997 index = getattr(data, "index", None) 998 if lazy_isinstance(data, "cudf.core.dataframe", "DataFrame"): 999 import cudf 1000 1001 prediction = cudf.DataFrame( 1002 prediction, columns=columns, dtype=numpy.float32, index=index 1003 ) 1004 else: 1005 prediction = DataFrame( 1006 prediction, columns=columns, dtype=numpy.float32, index=index 1007 ) 1008 return prediction 1009 1010 1011async def _direct_predict_impl( # pylint: disable=too-many-branches 1012 mapped_predict: Callable, 1013 booster: "distributed.Future", 1014 data: _DaskCollection, 1015 base_margin: Optional[_DaskCollection], 1016 output_shape: Tuple[int, ...], 1017 meta: Dict[int, str], 1018) -> _DaskCollection: 1019 columns = tuple(meta.keys()) 1020 if len(output_shape) >= 3 and isinstance(data, dd.DataFrame): 1021 # Without this check, dask will finish the prediction silently even if output 1022 # dimension is greater than 3. But during map_partitions, dask passes a 1023 # `dd.DataFrame` as local input to xgboost, which is converted to csr_matrix by 1024 # `_convert_unknown_data` since dd.DataFrame is not known to xgboost native 1025 # binding. 1026 raise ValueError( 1027 "Use `da.Array` or `DaskDMatrix` when output has more than 2 dimensions." 1028 ) 1029 if _can_output_df(isinstance(data, dd.DataFrame), output_shape): 1030 if base_margin is not None and isinstance(base_margin, da.Array): 1031 # Easier for map_partitions 1032 base_margin_df: Optional[dd.DataFrame] = base_margin.to_dask_dataframe() 1033 else: 1034 base_margin_df = base_margin 1035 predictions = dd.map_partitions( 1036 mapped_predict, 1037 booster, 1038 data, 1039 True, 1040 columns, 1041 base_margin_df, 1042 meta=dd.utils.make_meta(meta), 1043 ) 1044 # classification can return a dataframe, drop 1 dim when it's reg/binary 1045 if len(output_shape) == 1: 1046 predictions = predictions.iloc[:, 0] 1047 else: 1048 if base_margin is not None and isinstance( 1049 base_margin, (dd.Series, dd.DataFrame) 1050 ): 1051 # Easier for map_blocks 1052 base_margin_array: Optional[da.Array] = base_margin.to_dask_array() 1053 else: 1054 base_margin_array = base_margin 1055 # Input data is 2-dim array, output can be 1(reg, binary)/2(multi-class, 1056 # contrib)/3(contrib, interaction)/4(interaction) dims. 1057 if len(output_shape) == 1: 1058 drop_axis: Union[int, List[int]] = [1] # drop from 2 to 1 dim. 1059 new_axis: Union[int, List[int]] = [] 1060 else: 1061 drop_axis = [] 1062 if isinstance(data, dd.DataFrame): 1063 new_axis = list(range(len(output_shape) - 2)) 1064 else: 1065 new_axis = [i + 2 for i in range(len(output_shape) - 2)] 1066 if len(output_shape) == 2: 1067 # Somehow dask fail to infer output shape change for 2-dim prediction, and 1068 # `chunks = (None, output_shape[1])` doesn't work due to None is not 1069 # supported in map_blocks. 1070 chunks: Optional[List[Tuple]] = list(data.chunks) 1071 assert isinstance(chunks, list) 1072 chunks[1] = (output_shape[1], ) 1073 else: 1074 chunks = None 1075 predictions = da.map_blocks( 1076 mapped_predict, 1077 booster, 1078 data, 1079 False, 1080 columns, 1081 base_margin_array, 1082 1083 chunks=chunks, 1084 drop_axis=drop_axis, 1085 new_axis=new_axis, 1086 dtype=numpy.float32, 1087 ) 1088 return predictions 1089 1090 1091def _infer_predict_output( 1092 booster: Booster, features: int, is_df: bool, inplace: bool, **kwargs: Any 1093) -> Tuple[Tuple[int, ...], Dict[int, str]]: 1094 """Create a dummy test sample to infer output shape for prediction.""" 1095 assert isinstance(features, int) 1096 rng = numpy.random.RandomState(1994) 1097 test_sample = rng.randn(1, features) 1098 if inplace: 1099 kwargs = kwargs.copy() 1100 if kwargs.pop("predict_type") == "margin": 1101 kwargs["output_margin"] = True 1102 m = DMatrix(test_sample) 1103 # generated DMatrix doesn't have feature name, so no validation. 1104 test_predt = booster.predict(m, validate_features=False, **kwargs) 1105 n_columns = test_predt.shape[1] if len(test_predt.shape) > 1 else 1 1106 meta: Dict[int, str] = {} 1107 if _can_output_df(is_df, test_predt.shape): 1108 for i in range(n_columns): 1109 meta[i] = "f4" 1110 return test_predt.shape, meta 1111 1112 1113async def _get_model_future( 1114 client: "distributed.Client", model: Union[Booster, Dict, "distributed.Future"] 1115) -> "distributed.Future": 1116 if isinstance(model, Booster): 1117 booster = await client.scatter(model, broadcast=True) 1118 elif isinstance(model, dict): 1119 booster = await client.scatter(model["booster"], broadcast=True) 1120 elif isinstance(model, distributed.Future): 1121 booster = model 1122 if booster.type is not Booster: 1123 raise TypeError( 1124 f"Underlying type of model future should be `Booster`, got {booster.type}" 1125 ) 1126 else: 1127 raise TypeError(_expect([Booster, dict, distributed.Future], type(model))) 1128 return booster 1129 1130 1131# pylint: disable=too-many-statements 1132async def _predict_async( 1133 client: "distributed.Client", 1134 global_config: Dict[str, Any], 1135 model: Union[Booster, Dict, "distributed.Future"], 1136 data: _DaskCollection, 1137 output_margin: bool, 1138 missing: float, 1139 pred_leaf: bool, 1140 pred_contribs: bool, 1141 approx_contribs: bool, 1142 pred_interactions: bool, 1143 validate_features: bool, 1144 iteration_range: Tuple[int, int], 1145 strict_shape: bool, 1146) -> _DaskCollection: 1147 _booster = await _get_model_future(client, model) 1148 if not isinstance(data, (DaskDMatrix, da.Array, dd.DataFrame)): 1149 raise TypeError(_expect([DaskDMatrix, da.Array, dd.DataFrame], type(data))) 1150 1151 def mapped_predict( 1152 booster: Booster, partition: Any, is_df: bool, columns: List[int], _: Any 1153 ) -> Any: 1154 with config.config_context(**global_config): 1155 m = DMatrix(data=partition, missing=missing) 1156 predt = booster.predict( 1157 data=m, 1158 output_margin=output_margin, 1159 pred_leaf=pred_leaf, 1160 pred_contribs=pred_contribs, 1161 approx_contribs=approx_contribs, 1162 pred_interactions=pred_interactions, 1163 validate_features=validate_features, 1164 iteration_range=iteration_range, 1165 strict_shape=strict_shape, 1166 ) 1167 predt = _maybe_dataframe(partition, predt, columns, is_df) 1168 return predt 1169 1170 # Predict on dask collection directly. 1171 if isinstance(data, (da.Array, dd.DataFrame)): 1172 _output_shape, meta = await client.compute( 1173 client.submit( 1174 _infer_predict_output, 1175 _booster, 1176 features=data.shape[1], 1177 is_df=isinstance(data, dd.DataFrame), 1178 inplace=False, 1179 output_margin=output_margin, 1180 pred_leaf=pred_leaf, 1181 pred_contribs=pred_contribs, 1182 approx_contribs=approx_contribs, 1183 pred_interactions=pred_interactions, 1184 strict_shape=strict_shape, 1185 ) 1186 ) 1187 return await _direct_predict_impl( 1188 mapped_predict, _booster, data, None, _output_shape, meta 1189 ) 1190 1191 output_shape, _ = await client.compute( 1192 client.submit( 1193 _infer_predict_output, 1194 booster=_booster, 1195 features=data.num_col(), 1196 is_df=False, 1197 inplace=False, 1198 output_margin=output_margin, 1199 pred_leaf=pred_leaf, 1200 pred_contribs=pred_contribs, 1201 approx_contribs=approx_contribs, 1202 pred_interactions=pred_interactions, 1203 strict_shape=strict_shape, 1204 ) 1205 ) 1206 # Prediction on dask DMatrix. 1207 partition_order = data.partition_order 1208 feature_names = data.feature_names 1209 feature_types = data.feature_types 1210 missing = data.missing 1211 meta_names = data.meta_names 1212 1213 def dispatched_predict(booster: Booster, part: Tuple) -> numpy.ndarray: 1214 data = part[0] 1215 assert isinstance(part, tuple), type(part) 1216 base_margin = None 1217 for i, blob in enumerate(part[1:]): 1218 if meta_names[i] == "base_margin": 1219 base_margin = blob 1220 with config.config_context(**global_config): 1221 m = DMatrix( 1222 data, 1223 missing=missing, 1224 base_margin=base_margin, 1225 feature_names=feature_names, 1226 feature_types=feature_types, 1227 ) 1228 predt = booster.predict( 1229 m, 1230 output_margin=output_margin, 1231 pred_leaf=pred_leaf, 1232 pred_contribs=pred_contribs, 1233 approx_contribs=approx_contribs, 1234 pred_interactions=pred_interactions, 1235 validate_features=validate_features, 1236 iteration_range=iteration_range, 1237 strict_shape=strict_shape, 1238 ) 1239 return predt 1240 1241 all_parts = [] 1242 all_orders = [] 1243 all_shapes = [] 1244 all_workers: List[str] = [] 1245 workers_address = list(data.worker_map.keys()) 1246 for worker_addr in workers_address: 1247 list_of_parts = data.worker_map[worker_addr] 1248 all_parts.extend(list_of_parts) 1249 all_workers.extend(len(list_of_parts) * [worker_addr]) 1250 all_orders.extend([partition_order[part.key] for part in list_of_parts]) 1251 for w, part in zip(all_workers, all_parts): 1252 s = client.submit(lambda part: part[0].shape[0], part, workers=[w]) 1253 all_shapes.append(s) 1254 all_shapes = await client.gather(all_shapes) 1255 1256 parts_with_order = list(zip(all_parts, all_shapes, all_orders)) 1257 parts_with_order = sorted(parts_with_order, key=lambda p: p[2]) 1258 all_parts = [part for part, shape, order in parts_with_order] 1259 all_shapes = [shape for part, shape, order in parts_with_order] 1260 1261 futures = [] 1262 for w, part in zip(all_workers, all_parts): 1263 f = client.submit(dispatched_predict, _booster, part, workers=[w]) 1264 futures.append(f) 1265 1266 # Constructing a dask array from list of numpy arrays 1267 # See https://docs.dask.org/en/latest/array-creation.html 1268 arrays = [] 1269 for i, rows in enumerate(all_shapes): 1270 arrays.append( 1271 da.from_delayed( 1272 futures[i], shape=(rows,) + output_shape[1:], dtype=numpy.float32 1273 ) 1274 ) 1275 predictions = da.concatenate(arrays, axis=0) 1276 return predictions 1277 1278 1279def predict( # pylint: disable=unused-argument 1280 client: "distributed.Client", 1281 model: Union[TrainReturnT, Booster, "distributed.Future"], 1282 data: Union[DaskDMatrix, _DaskCollection], 1283 output_margin: bool = False, 1284 missing: float = numpy.nan, 1285 pred_leaf: bool = False, 1286 pred_contribs: bool = False, 1287 approx_contribs: bool = False, 1288 pred_interactions: bool = False, 1289 validate_features: bool = True, 1290 iteration_range: Tuple[int, int] = (0, 0), 1291 strict_shape: bool = False, 1292) -> Any: 1293 '''Run prediction with a trained booster. 1294 1295 .. note:: 1296 1297 Using ``inplace_predict`` might be faster when some features are not needed. See 1298 :py:meth:`xgboost.Booster.predict` for details on various parameters. When output 1299 has more than 2 dimensions (shap value, leaf with strict_shape), input should be 1300 ``da.Array`` or ``DaskDMatrix``. 1301 1302 .. versionadded:: 1.0.0 1303 1304 Parameters 1305 ---------- 1306 client: 1307 Specify the dask client used for training. Use default client 1308 returned from dask if it's set to None. 1309 model: 1310 The trained model. It can be a distributed.Future so user can 1311 pre-scatter it onto all workers. 1312 data: 1313 Input data used for prediction. When input is a dataframe object, 1314 prediction output is a series. 1315 missing: 1316 Used when input data is not DaskDMatrix. Specify the value 1317 considered as missing. 1318 1319 Returns 1320 ------- 1321 prediction: dask.array.Array/dask.dataframe.Series 1322 When input data is ``dask.array.Array`` or ``DaskDMatrix``, the return value is an 1323 array, when input data is ``dask.dataframe.DataFrame``, return value can be 1324 ``dask.dataframe.Series``, ``dask.dataframe.DataFrame``, depending on the output 1325 shape. 1326 1327 ''' 1328 _assert_dask_support() 1329 client = _xgb_get_client(client) 1330 return client.sync(_predict_async, global_config=config.get_config(), **locals()) 1331 1332 1333async def _inplace_predict_async( # pylint: disable=too-many-branches 1334 client: "distributed.Client", 1335 global_config: Dict[str, Any], 1336 model: Union[Booster, Dict, "distributed.Future"], 1337 data: _DaskCollection, 1338 iteration_range: Tuple[int, int], 1339 predict_type: str, 1340 missing: float, 1341 validate_features: bool, 1342 base_margin: Optional[_DaskCollection], 1343 strict_shape: bool, 1344) -> _DaskCollection: 1345 client = _xgb_get_client(client) 1346 booster = await _get_model_future(client, model) 1347 if not isinstance(data, (da.Array, dd.DataFrame)): 1348 raise TypeError(_expect([da.Array, dd.DataFrame], type(data))) 1349 if base_margin is not None and not isinstance( 1350 data, (da.Array, dd.DataFrame, dd.Series) 1351 ): 1352 raise TypeError(_expect([da.Array, dd.DataFrame, dd.Series], type(base_margin))) 1353 1354 def mapped_predict( 1355 booster: Booster, partition: Any, is_df: bool, columns: List[int], base_margin: Any 1356 ) -> Any: 1357 with config.config_context(**global_config): 1358 prediction = booster.inplace_predict( 1359 partition, 1360 iteration_range=iteration_range, 1361 predict_type=predict_type, 1362 missing=missing, 1363 base_margin=base_margin, 1364 validate_features=validate_features, 1365 strict_shape=strict_shape, 1366 ) 1367 prediction = _maybe_dataframe(partition, prediction, columns, is_df) 1368 return prediction 1369 1370 # await turns future into value. 1371 shape, meta = await client.compute( 1372 client.submit( 1373 _infer_predict_output, 1374 booster, 1375 features=data.shape[1], 1376 is_df=isinstance(data, dd.DataFrame), 1377 inplace=True, 1378 predict_type=predict_type, 1379 iteration_range=iteration_range, 1380 strict_shape=strict_shape, 1381 ) 1382 ) 1383 return await _direct_predict_impl( 1384 mapped_predict, booster, data, base_margin, shape, meta 1385 ) 1386 1387 1388def inplace_predict( # pylint: disable=unused-argument 1389 client: "distributed.Client", 1390 model: Union[TrainReturnT, Booster, "distributed.Future"], 1391 data: _DaskCollection, 1392 iteration_range: Tuple[int, int] = (0, 0), 1393 predict_type: str = "value", 1394 missing: float = numpy.nan, 1395 validate_features: bool = True, 1396 base_margin: Optional[_DaskCollection] = None, 1397 strict_shape: bool = False, 1398) -> Any: 1399 """Inplace prediction. See doc in :py:meth:`xgboost.Booster.inplace_predict` for details. 1400 1401 .. versionadded:: 1.1.0 1402 1403 Parameters 1404 ---------- 1405 client: 1406 Specify the dask client used for training. Use default client 1407 returned from dask if it's set to None. 1408 model: 1409 See :py:func:`xgboost.dask.predict` for details. 1410 data : 1411 dask collection. 1412 iteration_range: 1413 See :py:meth:`xgboost.Booster.predict` for details. 1414 predict_type: 1415 See :py:meth:`xgboost.Booster.inplace_predict` for details. 1416 missing: 1417 Value in the input data which needs to be present as a missing 1418 value. If None, defaults to np.nan. 1419 base_margin: 1420 See :py:obj:`xgboost.DMatrix` for details. Right now classifier is not well 1421 supported with base_margin as it requires the size of base margin to be `n_classes 1422 * n_samples`. 1423 1424 .. versionadded:: 1.4.0 1425 1426 strict_shape: 1427 See :py:meth:`xgboost.Booster.predict` for details. 1428 1429 .. versionadded:: 1.4.0 1430 1431 Returns 1432 ------- 1433 prediction : 1434 When input data is ``dask.array.Array``, the return value is an array, when input 1435 data is ``dask.dataframe.DataFrame``, return value can be 1436 ``dask.dataframe.Series``, ``dask.dataframe.DataFrame``, depending on the output 1437 shape. 1438 1439 """ 1440 _assert_dask_support() 1441 client = _xgb_get_client(client) 1442 # When used in asynchronous environment, the `client` object should have 1443 # `asynchronous` attribute as True. When invoked by the skl interface, it's 1444 # responsible for setting up the client. 1445 return client.sync( 1446 _inplace_predict_async, global_config=config.get_config(), **locals() 1447 ) 1448 1449 1450async def _async_wrap_evaluation_matrices( 1451 client: "distributed.Client", **kwargs: Any 1452) -> Tuple[DaskDMatrix, Optional[List[Tuple[DaskDMatrix, str]]]]: 1453 """A switch function for async environment.""" 1454 1455 def _inner(**kwargs: Any) -> DaskDMatrix: 1456 m = DaskDMatrix(client=client, **kwargs) 1457 return m 1458 1459 train_dmatrix, evals = _wrap_evaluation_matrices(create_dmatrix=_inner, **kwargs) 1460 train_dmatrix = await train_dmatrix 1461 if evals is None: 1462 return train_dmatrix, evals 1463 awaited = [] 1464 for e in evals: 1465 if e[0] is train_dmatrix: # already awaited 1466 awaited.append(e) 1467 continue 1468 awaited.append((await e[0], e[1])) 1469 return train_dmatrix, awaited 1470 1471 1472@contextmanager 1473def _set_worker_client( 1474 model: "DaskScikitLearnBase", client: "distributed.Client" 1475) -> Generator: 1476 """Temporarily set the client for sklearn model.""" 1477 try: 1478 model.client = client 1479 yield model 1480 finally: 1481 model.client = None 1482 1483 1484class DaskScikitLearnBase(XGBModel): 1485 """Base class for implementing scikit-learn interface with Dask""" 1486 1487 _client = None 1488 1489 async def _predict_async( 1490 self, 1491 data: _DaskCollection, 1492 output_margin: bool, 1493 validate_features: bool, 1494 base_margin: Optional[_DaskCollection], 1495 iteration_range: Optional[Tuple[int, int]], 1496 ) -> Any: 1497 iteration_range = self._get_iteration_range(iteration_range) 1498 if self._can_use_inplace_predict(): 1499 predts = await inplace_predict( 1500 client=self.client, 1501 model=self.get_booster(), 1502 data=data, 1503 iteration_range=iteration_range, 1504 predict_type="margin" if output_margin else "value", 1505 missing=self.missing, 1506 base_margin=base_margin, 1507 validate_features=validate_features, 1508 ) 1509 if isinstance(predts, dd.DataFrame): 1510 predts = predts.to_dask_array() 1511 else: 1512 test_dmatrix = await DaskDMatrix( 1513 self.client, data=data, base_margin=base_margin, missing=self.missing 1514 ) 1515 predts = await predict( 1516 self.client, 1517 model=self.get_booster(), 1518 data=test_dmatrix, 1519 output_margin=output_margin, 1520 validate_features=validate_features, 1521 iteration_range=iteration_range, 1522 ) 1523 return predts 1524 1525 def predict( 1526 self, 1527 X: _DaskCollection, 1528 output_margin: bool = False, 1529 ntree_limit: Optional[int] = None, 1530 validate_features: bool = True, 1531 base_margin: Optional[_DaskCollection] = None, 1532 iteration_range: Optional[Tuple[int, int]] = None, 1533 ) -> Any: 1534 _assert_dask_support() 1535 msg = "`ntree_limit` is not supported on dask, use `iteration_range` instead." 1536 assert ntree_limit is None, msg 1537 return self.client.sync( 1538 self._predict_async, 1539 X, 1540 output_margin=output_margin, 1541 validate_features=validate_features, 1542 base_margin=base_margin, 1543 iteration_range=iteration_range, 1544 ) 1545 1546 async def _apply_async( 1547 self, 1548 X: _DaskCollection, 1549 iteration_range: Optional[Tuple[int, int]] = None, 1550 ) -> Any: 1551 iteration_range = self._get_iteration_range(iteration_range) 1552 test_dmatrix = await DaskDMatrix(self.client, data=X, missing=self.missing) 1553 predts = await predict( 1554 self.client, 1555 model=self.get_booster(), 1556 data=test_dmatrix, 1557 pred_leaf=True, 1558 iteration_range=iteration_range, 1559 ) 1560 return predts 1561 1562 def apply( 1563 self, 1564 X: _DaskCollection, 1565 ntree_limit: Optional[int] = None, 1566 iteration_range: Optional[Tuple[int, int]] = None, 1567 ) -> Any: 1568 _assert_dask_support() 1569 msg = "`ntree_limit` is not supported on dask, use `iteration_range` instead." 1570 assert ntree_limit is None, msg 1571 return self.client.sync(self._apply_async, X, iteration_range=iteration_range) 1572 1573 def __await__(self) -> Awaitable[Any]: 1574 # Generate a coroutine wrapper to make this class awaitable. 1575 async def _() -> Awaitable[Any]: 1576 return self 1577 1578 return self._client_sync(_).__await__() 1579 1580 def __getstate__(self) -> Dict: 1581 this = self.__dict__.copy() 1582 if "_client" in this.keys(): 1583 del this["_client"] 1584 return this 1585 1586 @property 1587 def client(self) -> "distributed.Client": 1588 """The dask client used in this model. The `Client` object can not be serialized for 1589 transmission, so if task is launched from a worker instead of directly from the 1590 client process, this attribute needs to be set at that worker. 1591 1592 """ 1593 1594 client = _xgb_get_client(self._client) 1595 return client 1596 1597 @client.setter 1598 def client(self, clt: "distributed.Client") -> None: 1599 # calling `worker_client' doesn't return the correct `asynchronous` attribute, so 1600 # we have to pass it ourselves. 1601 self._asynchronous = clt.asynchronous if clt is not None else False 1602 self._client = clt 1603 1604 def _client_sync(self, func: Callable, **kwargs: Any) -> Any: 1605 """Get the correct client, when method is invoked inside a worker we 1606 should use `worker_client' instead of default client. 1607 1608 """ 1609 asynchronous = getattr(self, "_asynchronous", False) 1610 if self._client is None: 1611 try: 1612 distributed.get_worker() 1613 in_worker = True 1614 except ValueError: 1615 in_worker = False 1616 if in_worker: 1617 with distributed.worker_client() as client: 1618 with _set_worker_client(self, client) as this: 1619 ret = this.client.sync(func, **kwargs, asynchronous=asynchronous) 1620 return ret 1621 return ret 1622 1623 return self.client.sync(func, **kwargs, asynchronous=asynchronous) 1624 1625 1626@xgboost_model_doc( 1627 """Implementation of the Scikit-Learn API for XGBoost.""", ["estimators", "model"] 1628) 1629class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase): 1630 # pylint: disable=missing-class-docstring 1631 async def _fit_async( 1632 self, 1633 X: _DaskCollection, 1634 y: _DaskCollection, 1635 sample_weight: Optional[_DaskCollection], 1636 base_margin: Optional[_DaskCollection], 1637 eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]], 1638 eval_metric: Optional[Union[str, List[str], Metric]], 1639 sample_weight_eval_set: Optional[List[_DaskCollection]], 1640 base_margin_eval_set: Optional[List[_DaskCollection]], 1641 early_stopping_rounds: int, 1642 verbose: bool, 1643 xgb_model: Optional[Union[Booster, XGBModel]], 1644 feature_weights: Optional[_DaskCollection], 1645 callbacks: Optional[List[TrainingCallback]], 1646 ) -> _DaskCollection: 1647 params = self.get_xgb_params() 1648 dtrain, evals = await _async_wrap_evaluation_matrices( 1649 client=self.client, 1650 X=X, 1651 y=y, 1652 group=None, 1653 qid=None, 1654 sample_weight=sample_weight, 1655 base_margin=base_margin, 1656 feature_weights=feature_weights, 1657 eval_set=eval_set, 1658 sample_weight_eval_set=sample_weight_eval_set, 1659 base_margin_eval_set=base_margin_eval_set, 1660 eval_group=None, 1661 eval_qid=None, 1662 missing=self.missing, 1663 enable_categorical=self.enable_categorical, 1664 ) 1665 1666 if callable(self.objective): 1667 obj: Optional[Callable] = _objective_decorator(self.objective) 1668 else: 1669 obj = None 1670 model, metric, params = self._configure_fit( 1671 booster=xgb_model, eval_metric=eval_metric, params=params 1672 ) 1673 results = await self.client.sync( 1674 _train_async, 1675 asynchronous=True, 1676 client=self.client, 1677 global_config=config.get_config(), 1678 params=params, 1679 dtrain=dtrain, 1680 num_boost_round=self.get_num_boosting_rounds(), 1681 evals=evals, 1682 obj=obj, 1683 feval=metric, 1684 verbose_eval=verbose, 1685 early_stopping_rounds=early_stopping_rounds, 1686 callbacks=callbacks, 1687 xgb_model=model, 1688 ) 1689 self._Booster = results["booster"] 1690 self._set_evaluation_result(results["history"]) 1691 return self 1692 1693 # pylint: disable=missing-docstring, disable=unused-argument 1694 @_deprecate_positional_args 1695 def fit( 1696 self, 1697 X: _DaskCollection, 1698 y: _DaskCollection, 1699 *, 1700 sample_weight: Optional[_DaskCollection] = None, 1701 base_margin: Optional[_DaskCollection] = None, 1702 eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]] = None, 1703 eval_metric: Optional[Union[str, List[str], Metric]] = None, 1704 early_stopping_rounds: Optional[int] = None, 1705 verbose: bool = True, 1706 xgb_model: Optional[Union[Booster, XGBModel]] = None, 1707 sample_weight_eval_set: Optional[List[_DaskCollection]] = None, 1708 base_margin_eval_set: Optional[List[_DaskCollection]] = None, 1709 feature_weights: Optional[_DaskCollection] = None, 1710 callbacks: Optional[List[TrainingCallback]] = None, 1711 ) -> "DaskXGBRegressor": 1712 _assert_dask_support() 1713 args = {k: v for k, v in locals().items() if k not in ("self", "__class__")} 1714 return self._client_sync(self._fit_async, **args) 1715 1716 1717@xgboost_model_doc( 1718 'Implementation of the scikit-learn API for XGBoost classification.', 1719 ['estimators', 'model']) 1720class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase): 1721 # pylint: disable=missing-class-docstring 1722 async def _fit_async( 1723 self, X: _DaskCollection, y: _DaskCollection, 1724 sample_weight: Optional[_DaskCollection], 1725 base_margin: Optional[_DaskCollection], 1726 eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]], 1727 eval_metric: Optional[Union[str, List[str], Metric]], 1728 sample_weight_eval_set: Optional[List[_DaskCollection]], 1729 base_margin_eval_set: Optional[List[_DaskCollection]], 1730 early_stopping_rounds: int, 1731 verbose: bool, 1732 xgb_model: Optional[Union[Booster, XGBModel]], 1733 feature_weights: Optional[_DaskCollection], 1734 callbacks: Optional[List[TrainingCallback]] 1735 ) -> "DaskXGBClassifier": 1736 params = self.get_xgb_params() 1737 dtrain, evals = await _async_wrap_evaluation_matrices( 1738 self.client, 1739 X=X, 1740 y=y, 1741 group=None, 1742 qid=None, 1743 sample_weight=sample_weight, 1744 base_margin=base_margin, 1745 feature_weights=feature_weights, 1746 eval_set=eval_set, 1747 sample_weight_eval_set=sample_weight_eval_set, 1748 base_margin_eval_set=base_margin_eval_set, 1749 eval_group=None, 1750 eval_qid=None, 1751 missing=self.missing, 1752 enable_categorical=self.enable_categorical, 1753 ) 1754 1755 # pylint: disable=attribute-defined-outside-init 1756 if isinstance(y, (da.Array)): 1757 self.classes_ = await self.client.compute(da.unique(y)) 1758 else: 1759 self.classes_ = await self.client.compute(y.drop_duplicates()) 1760 self.n_classes_ = len(self.classes_) 1761 1762 if self.n_classes_ > 2: 1763 params["objective"] = "multi:softprob" 1764 params['num_class'] = self.n_classes_ 1765 else: 1766 params["objective"] = "binary:logistic" 1767 1768 if callable(self.objective): 1769 obj: Optional[Callable] = _objective_decorator(self.objective) 1770 else: 1771 obj = None 1772 model, metric, params = self._configure_fit( 1773 booster=xgb_model, eval_metric=eval_metric, params=params 1774 ) 1775 results = await self.client.sync( 1776 _train_async, 1777 asynchronous=True, 1778 client=self.client, 1779 global_config=config.get_config(), 1780 params=params, 1781 dtrain=dtrain, 1782 num_boost_round=self.get_num_boosting_rounds(), 1783 evals=evals, 1784 obj=obj, 1785 feval=metric, 1786 verbose_eval=verbose, 1787 early_stopping_rounds=early_stopping_rounds, 1788 callbacks=callbacks, 1789 xgb_model=model, 1790 ) 1791 self._Booster = results['booster'] 1792 if not callable(self.objective): 1793 self.objective = params["objective"] 1794 self._set_evaluation_result(results["history"]) 1795 return self 1796 1797 # pylint: disable=unused-argument 1798 def fit( 1799 self, 1800 X: _DaskCollection, 1801 y: _DaskCollection, 1802 *, 1803 sample_weight: Optional[_DaskCollection] = None, 1804 base_margin: Optional[_DaskCollection] = None, 1805 eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]] = None, 1806 eval_metric: Optional[Union[str, List[str], Metric]] = None, 1807 early_stopping_rounds: Optional[int] = None, 1808 verbose: bool = True, 1809 xgb_model: Optional[Union[Booster, XGBModel]] = None, 1810 sample_weight_eval_set: Optional[List[_DaskCollection]] = None, 1811 base_margin_eval_set: Optional[List[_DaskCollection]] = None, 1812 feature_weights: Optional[_DaskCollection] = None, 1813 callbacks: Optional[List[TrainingCallback]] = None 1814 ) -> "DaskXGBClassifier": 1815 _assert_dask_support() 1816 args = {k: v for k, v in locals().items() if k not in ("self", "__class__")} 1817 return self._client_sync(self._fit_async, **args) 1818 1819 async def _predict_proba_async( 1820 self, 1821 X: _DaskCollection, 1822 validate_features: bool, 1823 base_margin: Optional[_DaskCollection], 1824 iteration_range: Optional[Tuple[int, int]], 1825 ) -> _DaskCollection: 1826 predts = await super()._predict_async( 1827 data=X, 1828 output_margin=self.objective == "multi:softmax", 1829 validate_features=validate_features, 1830 base_margin=base_margin, 1831 iteration_range=iteration_range, 1832 ) 1833 vstack = update_wrapper( 1834 partial(da.vstack, allow_unknown_chunksizes=True), da.vstack 1835 ) 1836 return _cls_predict_proba(getattr(self, "n_classes_", None), predts, vstack) 1837 1838 # pylint: disable=missing-function-docstring 1839 def predict_proba( 1840 self, 1841 X: _DaskCollection, 1842 ntree_limit: Optional[int] = None, 1843 validate_features: bool = True, 1844 base_margin: Optional[_DaskCollection] = None, 1845 iteration_range: Optional[Tuple[int, int]] = None, 1846 ) -> Any: 1847 _assert_dask_support() 1848 msg = "`ntree_limit` is not supported on dask, use `iteration_range` instead." 1849 assert ntree_limit is None, msg 1850 return self._client_sync( 1851 self._predict_proba_async, 1852 X=X, 1853 validate_features=validate_features, 1854 base_margin=base_margin, 1855 iteration_range=iteration_range, 1856 ) 1857 1858 predict_proba.__doc__ = XGBClassifier.predict_proba.__doc__ 1859 1860 async def _predict_async( 1861 self, 1862 data: _DaskCollection, 1863 output_margin: bool, 1864 validate_features: bool, 1865 base_margin: Optional[_DaskCollection], 1866 iteration_range: Optional[Tuple[int, int]], 1867 ) -> _DaskCollection: 1868 pred_probs = await super()._predict_async( 1869 data, output_margin, validate_features, base_margin, iteration_range 1870 ) 1871 if output_margin: 1872 return pred_probs 1873 1874 if len(pred_probs.shape) == 1: 1875 preds = (pred_probs > 0.5).astype(int) 1876 else: 1877 assert len(pred_probs.shape) == 2 1878 assert isinstance(pred_probs, da.Array) 1879 # when using da.argmax directly, dask will construct a numpy based return 1880 # array, which runs into error when computing GPU based prediction. 1881 1882 def _argmax(x: Any) -> Any: 1883 return x.argmax(axis=1) 1884 1885 preds = da.map_blocks(_argmax, pred_probs, drop_axis=1) 1886 return preds 1887 1888 1889@xgboost_model_doc( 1890 """Implementation of the Scikit-Learn API for XGBoost Ranking. 1891 1892 .. versionadded:: 1.4.0 1893 1894""", 1895 ["estimators", "model"], 1896 end_note=""" 1897 Note 1898 ---- 1899 For dask implementation, group is not supported, use qid instead. 1900""", 1901) 1902class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn): 1903 @_deprecate_positional_args 1904 def __init__(self, *, objective: str = "rank:pairwise", **kwargs: Any): 1905 if callable(objective): 1906 raise ValueError("Custom objective function not supported by XGBRanker.") 1907 super().__init__(objective=objective, kwargs=kwargs) 1908 1909 async def _fit_async( 1910 self, 1911 X: _DaskCollection, 1912 y: _DaskCollection, 1913 group: Optional[_DaskCollection], 1914 qid: Optional[_DaskCollection], 1915 sample_weight: Optional[_DaskCollection], 1916 base_margin: Optional[_DaskCollection], 1917 eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]], 1918 sample_weight_eval_set: Optional[List[_DaskCollection]], 1919 base_margin_eval_set: Optional[List[_DaskCollection]], 1920 eval_group: Optional[List[_DaskCollection]], 1921 eval_qid: Optional[List[_DaskCollection]], 1922 eval_metric: Optional[Union[str, List[str], Metric]], 1923 early_stopping_rounds: int, 1924 verbose: bool, 1925 xgb_model: Optional[Union[XGBModel, Booster]], 1926 feature_weights: Optional[_DaskCollection], 1927 callbacks: Optional[List[TrainingCallback]], 1928 ) -> "DaskXGBRanker": 1929 msg = "Use `qid` instead of `group` on dask interface." 1930 if not (group is None and eval_group is None): 1931 raise ValueError(msg) 1932 if qid is None: 1933 raise ValueError("`qid` is required for ranking.") 1934 params = self.get_xgb_params() 1935 dtrain, evals = await _async_wrap_evaluation_matrices( 1936 self.client, 1937 X=X, 1938 y=y, 1939 group=None, 1940 qid=qid, 1941 sample_weight=sample_weight, 1942 base_margin=base_margin, 1943 feature_weights=feature_weights, 1944 eval_set=eval_set, 1945 sample_weight_eval_set=sample_weight_eval_set, 1946 base_margin_eval_set=base_margin_eval_set, 1947 eval_group=None, 1948 eval_qid=eval_qid, 1949 missing=self.missing, 1950 enable_categorical=self.enable_categorical, 1951 ) 1952 if eval_metric is not None: 1953 if callable(eval_metric): 1954 raise ValueError( 1955 "Custom evaluation metric is not yet supported for XGBRanker." 1956 ) 1957 model, metric, params = self._configure_fit( 1958 booster=xgb_model, eval_metric=eval_metric, params=params 1959 ) 1960 results = await self.client.sync( 1961 _train_async, 1962 asynchronous=True, 1963 client=self.client, 1964 global_config=config.get_config(), 1965 params=params, 1966 dtrain=dtrain, 1967 num_boost_round=self.get_num_boosting_rounds(), 1968 evals=evals, 1969 obj=None, 1970 feval=metric, 1971 verbose_eval=verbose, 1972 early_stopping_rounds=early_stopping_rounds, 1973 callbacks=callbacks, 1974 xgb_model=model, 1975 ) 1976 self._Booster = results["booster"] 1977 self.evals_result_ = results["history"] 1978 return self 1979 1980 # pylint: disable=unused-argument, arguments-differ 1981 @_deprecate_positional_args 1982 def fit( 1983 self, 1984 X: _DaskCollection, 1985 y: _DaskCollection, 1986 *, 1987 group: Optional[_DaskCollection] = None, 1988 qid: Optional[_DaskCollection] = None, 1989 sample_weight: Optional[_DaskCollection] = None, 1990 base_margin: Optional[_DaskCollection] = None, 1991 eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]] = None, 1992 eval_group: Optional[List[_DaskCollection]] = None, 1993 eval_qid: Optional[List[_DaskCollection]] = None, 1994 eval_metric: Optional[Union[str, List[str], Metric]] = None, 1995 early_stopping_rounds: int = None, 1996 verbose: bool = False, 1997 xgb_model: Optional[Union[XGBModel, Booster]] = None, 1998 sample_weight_eval_set: Optional[List[_DaskCollection]] = None, 1999 base_margin_eval_set: Optional[List[_DaskCollection]] = None, 2000 feature_weights: Optional[_DaskCollection] = None, 2001 callbacks: Optional[List[TrainingCallback]] = None 2002 ) -> "DaskXGBRanker": 2003 _assert_dask_support() 2004 args = {k: v for k, v in locals().items() if k not in ("self", "__class__")} 2005 return self._client_sync(self._fit_async, **args) 2006 2007 # FIXME(trivialfis): arguments differ due to additional parameters like group and qid. 2008 fit.__doc__ = XGBRanker.fit.__doc__ 2009 2010 2011@xgboost_model_doc( 2012 """Implementation of the Scikit-Learn API for XGBoost Random Forest Regressor. 2013 2014 .. versionadded:: 1.4.0 2015 2016""", 2017 ["model", "objective"], 2018 extra_parameters=""" 2019 n_estimators : int 2020 Number of trees in random forest to fit. 2021""", 2022) 2023class DaskXGBRFRegressor(DaskXGBRegressor): 2024 @_deprecate_positional_args 2025 def __init__( 2026 self, 2027 *, 2028 learning_rate: Optional[float] = 1, 2029 subsample: Optional[float] = 0.8, 2030 colsample_bynode: Optional[float] = 0.8, 2031 reg_lambda: Optional[float] = 1e-5, 2032 **kwargs: Any 2033 ) -> None: 2034 super().__init__( 2035 learning_rate=learning_rate, 2036 subsample=subsample, 2037 colsample_bynode=colsample_bynode, 2038 reg_lambda=reg_lambda, 2039 **kwargs 2040 ) 2041 2042 def get_xgb_params(self) -> Dict[str, Any]: 2043 params = super().get_xgb_params() 2044 params["num_parallel_tree"] = self.n_estimators 2045 return params 2046 2047 def get_num_boosting_rounds(self) -> int: 2048 return 1 2049 2050 # pylint: disable=unused-argument 2051 def fit( 2052 self, 2053 X: _DaskCollection, 2054 y: _DaskCollection, 2055 *, 2056 sample_weight: Optional[_DaskCollection] = None, 2057 base_margin: Optional[_DaskCollection] = None, 2058 eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]] = None, 2059 eval_metric: Optional[Union[str, List[str], Metric]] = None, 2060 early_stopping_rounds: Optional[int] = None, 2061 verbose: bool = True, 2062 xgb_model: Optional[Union[Booster, XGBModel]] = None, 2063 sample_weight_eval_set: Optional[List[_DaskCollection]] = None, 2064 base_margin_eval_set: Optional[List[_DaskCollection]] = None, 2065 feature_weights: Optional[_DaskCollection] = None, 2066 callbacks: Optional[List[TrainingCallback]] = None 2067 ) -> "DaskXGBRFRegressor": 2068 _assert_dask_support() 2069 args = {k: v for k, v in locals().items() if k not in ("self", "__class__")} 2070 _check_rf_callback(early_stopping_rounds, callbacks) 2071 super().fit(**args) 2072 return self 2073 2074 2075@xgboost_model_doc( 2076 """Implementation of the Scikit-Learn API for XGBoost Random Forest Classifier. 2077 2078 .. versionadded:: 1.4.0 2079 2080""", 2081 ["model", "objective"], 2082 extra_parameters=""" 2083 n_estimators : int 2084 Number of trees in random forest to fit. 2085""", 2086) 2087class DaskXGBRFClassifier(DaskXGBClassifier): 2088 @_deprecate_positional_args 2089 def __init__( 2090 self, 2091 *, 2092 learning_rate: Optional[float] = 1, 2093 subsample: Optional[float] = 0.8, 2094 colsample_bynode: Optional[float] = 0.8, 2095 reg_lambda: Optional[float] = 1e-5, 2096 **kwargs: Any 2097 ) -> None: 2098 super().__init__( 2099 learning_rate=learning_rate, 2100 subsample=subsample, 2101 colsample_bynode=colsample_bynode, 2102 reg_lambda=reg_lambda, 2103 **kwargs 2104 ) 2105 2106 def get_xgb_params(self) -> Dict[str, Any]: 2107 params = super().get_xgb_params() 2108 params["num_parallel_tree"] = self.n_estimators 2109 return params 2110 2111 def get_num_boosting_rounds(self) -> int: 2112 return 1 2113 2114 # pylint: disable=unused-argument 2115 def fit( 2116 self, 2117 X: _DaskCollection, 2118 y: _DaskCollection, 2119 *, 2120 sample_weight: Optional[_DaskCollection] = None, 2121 base_margin: Optional[_DaskCollection] = None, 2122 eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]] = None, 2123 eval_metric: Optional[Union[str, List[str], Metric]] = None, 2124 early_stopping_rounds: Optional[int] = None, 2125 verbose: bool = True, 2126 xgb_model: Optional[Union[Booster, XGBModel]] = None, 2127 sample_weight_eval_set: Optional[List[_DaskCollection]] = None, 2128 base_margin_eval_set: Optional[List[_DaskCollection]] = None, 2129 feature_weights: Optional[_DaskCollection] = None, 2130 callbacks: Optional[List[TrainingCallback]] = None 2131 ) -> "DaskXGBRFClassifier": 2132 _assert_dask_support() 2133 args = {k: v for k, v in locals().items() if k not in ("self", "__class__")} 2134 _check_rf_callback(early_stopping_rounds, callbacks) 2135 super().fit(**args) 2136 return self 2137