1# coding: utf-8 2# pylint: disable=invalid-name, too-many-statements, no-self-use 3# pylint: disable=too-many-arguments 4"""Training Library containing training routines.""" 5from abc import ABC 6import collections 7import os 8import pickle 9from typing import Callable, List, Optional, Union, Dict, Tuple 10import numpy 11 12from . import rabit 13from .core import EarlyStopException, CallbackEnv, Booster, XGBoostError 14from .compat import STRING_TYPES 15 16 17def _get_callback_context(env): 18 """return whether the current callback context is cv or train""" 19 if env.model is not None and env.cvfolds is None: 20 context = 'train' 21 elif env.model is None and env.cvfolds is not None: 22 context = 'cv' 23 else: 24 raise ValueError("Unexpected input with both model and cvfolds.") 25 return context 26 27 28def _fmt_metric(value, show_stdv=True): 29 """format metric string""" 30 if len(value) == 2: 31 return f"{value[0]}:{value[1]:.5f}" 32 if len(value) == 3: 33 if show_stdv: 34 return f"{value[0]}:{value[1]:.5f}+{value[2]:.5f}" 35 return f"{value[0]}:{value[1]:.5f}" 36 raise ValueError("wrong metric value", value) 37 38 39def print_evaluation(period=1, show_stdv=True): 40 """Create a callback that print evaluation result. 41 42 We print the evaluation results every **period** iterations 43 and on the first and the last iterations. 44 45 Parameters 46 ---------- 47 period : int 48 The period to log the evaluation results 49 50 show_stdv : bool, optional 51 Whether show stdv if provided 52 53 Returns 54 ------- 55 callback : function 56 A callback that print evaluation every period iterations. 57 """ 58 def callback(env): 59 """internal function""" 60 if env.rank != 0 or (not env.evaluation_result_list) or period is False or period == 0: 61 return 62 i = env.iteration 63 if i % period == 0 or i + 1 == env.begin_iteration or i + 1 == env.end_iteration: 64 msg = '\t'.join([_fmt_metric(x, show_stdv) for x in env.evaluation_result_list]) 65 rabit.tracker_print(f"{i}\t{msg}\n") 66 return callback 67 68 69def record_evaluation(eval_result): 70 """Create a call back that records the evaluation history into **eval_result**. 71 72 Parameters 73 ---------- 74 eval_result : dict 75 A dictionary to store the evaluation results. 76 77 Returns 78 ------- 79 callback : function 80 The requested callback function. 81 """ 82 if not isinstance(eval_result, dict): 83 raise TypeError('eval_result has to be a dictionary') 84 eval_result.clear() 85 86 def init(env): 87 """internal function""" 88 for k, _ in env.evaluation_result_list: 89 pos = k.index('-') 90 key = k[:pos] 91 metric = k[pos + 1:] 92 if key not in eval_result: 93 eval_result[key] = {} 94 if metric not in eval_result[key]: 95 eval_result[key][metric] = [] 96 97 def callback(env): 98 """internal function""" 99 if not eval_result: 100 init(env) 101 for k, v in env.evaluation_result_list: 102 pos = k.index('-') 103 key = k[:pos] 104 metric = k[pos + 1:] 105 eval_result[key][metric].append(v) 106 return callback 107 108 109def reset_learning_rate(learning_rates): 110 """Reset learning rate after iteration 1 111 112 NOTE: the initial learning rate will still take in-effect on first iteration. 113 114 Parameters 115 ---------- 116 learning_rates: list or function 117 List of learning rate for each boosting round 118 or a customized function that calculates eta in terms of 119 current number of round and the total number of boosting round (e.g. 120 yields learning rate decay) 121 122 * list ``l``: ``eta = l[boosting_round]`` 123 * function ``f``: ``eta = f(boosting_round, num_boost_round)`` 124 125 Returns 126 ------- 127 callback : function 128 The requested callback function. 129 """ 130 def get_learning_rate(i, n, learning_rates): 131 """helper providing the learning rate""" 132 if isinstance(learning_rates, list): 133 if len(learning_rates) != n: 134 raise ValueError("Length of list 'learning_rates' has to equal 'num_boost_round'.") 135 new_learning_rate = learning_rates[i] 136 else: 137 new_learning_rate = learning_rates(i, n) 138 return new_learning_rate 139 140 def callback(env): 141 """internal function""" 142 context = _get_callback_context(env) 143 144 if context == 'train': 145 bst, i, n = env.model, env.iteration, env.end_iteration 146 bst.set_param( 147 'learning_rate', get_learning_rate(i, n, learning_rates)) 148 elif context == 'cv': 149 i, n = env.iteration, env.end_iteration 150 for cvpack in env.cvfolds: 151 bst = cvpack.bst 152 bst.set_param( 153 'learning_rate', get_learning_rate(i, n, learning_rates)) 154 155 callback.before_iteration = False 156 return callback 157 158 159def early_stop(stopping_rounds, maximize=False, verbose=True): 160 """Create a callback that activates early stoppping. 161 162 Validation error needs to decrease at least 163 every **stopping_rounds** round(s) to continue training. 164 Requires at least one item in **evals**. 165 If there's more than one, will use the last. 166 Returns the model from the last iteration (not the best one). 167 If early stopping occurs, the model will have three additional fields: 168 ``bst.best_score``, ``bst.best_iteration``. 169 170 Parameters 171 ---------- 172 stopping_rounds : int 173 The stopping rounds before the trend occur. 174 175 maximize : bool 176 Whether to maximize evaluation metric. 177 178 verbose : optional, bool 179 Whether to print message about early stopping information. 180 181 Returns 182 ------- 183 callback : function 184 The requested callback function. 185 """ 186 state = {} 187 188 def init(env): 189 """internal function""" 190 bst = env.model 191 192 if not env.evaluation_result_list: 193 raise ValueError('For early stopping you need at least one set in evals.') 194 if len(env.evaluation_result_list) > 1 and verbose: 195 msg = ("Multiple eval metrics have been passed: " 196 "'{0}' will be used for early stopping.\n\n") 197 rabit.tracker_print(msg.format(env.evaluation_result_list[-1][0])) 198 maximize_metrics = ('auc', 'aucpr', 'map', 'ndcg') 199 maximize_at_n_metrics = ('auc@', 'aucpr@', 'map@', 'ndcg@') 200 maximize_score = maximize 201 metric_label = env.evaluation_result_list[-1][0] 202 metric = metric_label.split('-', 1)[-1] 203 204 if any(metric.startswith(x) for x in maximize_at_n_metrics): 205 maximize_score = True 206 207 if any(metric.split(":")[0] == x for x in maximize_metrics): 208 maximize_score = True 209 210 if verbose and env.rank == 0: 211 msg = "Will train until {} hasn't improved in {} rounds.\n" 212 rabit.tracker_print(msg.format(metric_label, stopping_rounds)) 213 214 state['maximize_score'] = maximize_score 215 state['best_iteration'] = 0 216 if maximize_score: 217 state['best_score'] = float('-inf') 218 else: 219 state['best_score'] = float('inf') 220 # pylint: disable=consider-using-f-string 221 msg = '[%d]\t%s' % ( 222 env.iteration, 223 '\t'.join([_fmt_metric(x) for x in env.evaluation_result_list]) 224 ) 225 state['best_msg'] = msg 226 227 if bst is not None: 228 if bst.attr('best_score') is not None: 229 state['best_score'] = float(bst.attr('best_score')) 230 state['best_iteration'] = int(bst.attr('best_iteration')) 231 state['best_msg'] = bst.attr('best_msg') 232 else: 233 bst.set_attr(best_iteration=str(state['best_iteration'])) 234 bst.set_attr(best_score=str(state['best_score'])) 235 else: 236 assert env.cvfolds is not None 237 238 def callback(env): 239 """internal function""" 240 if not state: 241 init(env) 242 score = env.evaluation_result_list[-1][1] 243 best_score = state['best_score'] 244 best_iteration = state['best_iteration'] 245 maximize_score = state['maximize_score'] 246 if (maximize_score and score > best_score) or \ 247 (not maximize_score and score < best_score): 248 # pylint: disable=consider-using-f-string 249 msg = '[%d]\t%s' % ( 250 env.iteration, 251 '\t'.join([_fmt_metric(x) for x in env.evaluation_result_list])) 252 state['best_msg'] = msg 253 state['best_score'] = score 254 state['best_iteration'] = env.iteration 255 # save the property to attributes, so they will occur in checkpoint. 256 if env.model is not None: 257 env.model.set_attr(best_score=str(state['best_score']), 258 best_iteration=str(state['best_iteration']), 259 best_msg=state['best_msg']) 260 elif env.iteration - best_iteration >= stopping_rounds: 261 best_msg = state['best_msg'] 262 if verbose and env.rank == 0: 263 msg = "Stopping. Best iteration:\n{}\n\n" 264 rabit.tracker_print(msg.format(best_msg)) 265 raise EarlyStopException(best_iteration) 266 return callback 267 268 269# The new implementation of callback functions. 270# Breaking: 271# - reset learning rate no longer accepts total boosting rounds 272 273# pylint: disable=unused-argument 274class TrainingCallback(ABC): 275 '''Interface for training callback. 276 277 .. versionadded:: 1.3.0 278 279 ''' 280 281 EvalsLog = Dict[str, Dict[str, Union[List[float], List[Tuple[float, float]]]]] 282 283 def __init__(self): 284 pass 285 286 def before_training(self, model): 287 '''Run before training starts.''' 288 return model 289 290 def after_training(self, model): 291 '''Run after training is finished.''' 292 return model 293 294 def before_iteration(self, model, epoch: int, evals_log: EvalsLog) -> bool: 295 '''Run before each iteration. Return True when training should stop.''' 296 return False 297 298 def after_iteration(self, model, epoch: int, evals_log: EvalsLog) -> bool: 299 '''Run after each iteration. Return True when training should stop.''' 300 return False 301 302 303def _aggcv(rlist): 304 # pylint: disable=invalid-name 305 """Aggregate cross-validation results. 306 307 """ 308 cvmap = {} 309 idx = rlist[0].split()[0] 310 for line in rlist: 311 arr = line.split() 312 assert idx == arr[0] 313 for metric_idx, it in enumerate(arr[1:]): 314 if not isinstance(it, STRING_TYPES): 315 it = it.decode() 316 k, v = it.split(':') 317 if (metric_idx, k) not in cvmap: 318 cvmap[(metric_idx, k)] = [] 319 cvmap[(metric_idx, k)].append(float(v)) 320 msg = idx 321 results = [] 322 for (metric_idx, k), v in sorted(cvmap.items(), key=lambda x: x[0][0]): 323 v = numpy.array(v) 324 if not isinstance(msg, STRING_TYPES): 325 msg = msg.decode() 326 mean, std = numpy.mean(v), numpy.std(v) 327 results.extend([(k, mean, std)]) 328 return results 329 330 331def _allreduce_metric(score): 332 '''Helper function for computing customized metric in distributed 333 environment. Not strictly correct as many functions don't use mean value 334 as final result. 335 336 ''' 337 world = rabit.get_world_size() 338 assert world != 0 339 if world == 1: 340 return score 341 if isinstance(score, tuple): # has mean and stdv 342 raise ValueError( 343 'xgboost.cv function should not be used in distributed environment.') 344 score = numpy.array([score]) 345 score = rabit.allreduce(score, rabit.Op.SUM) / world 346 return score[0] 347 348 349class CallbackContainer: 350 '''A special callback for invoking a list of other callbacks. 351 352 .. versionadded:: 1.3.0 353 354 ''' 355 356 EvalsLog = TrainingCallback.EvalsLog 357 358 def __init__(self, 359 callbacks: List[TrainingCallback], 360 metric: Callable = None, 361 is_cv: bool = False): 362 self.callbacks = set(callbacks) 363 if metric is not None: 364 msg = 'metric must be callable object for monitoring. For ' + \ 365 'builtin metrics, passing them in training parameter' + \ 366 ' will invoke monitor automatically.' 367 assert callable(metric), msg 368 self.metric = metric 369 self.history: TrainingCallback.EvalsLog = collections.OrderedDict() 370 self.is_cv = is_cv 371 372 if self.is_cv: 373 self.aggregated_cv = None 374 375 def before_training(self, model): 376 '''Function called before training.''' 377 for c in self.callbacks: 378 model = c.before_training(model=model) 379 msg = 'before_training should return the model' 380 if self.is_cv: 381 assert isinstance(model.cvfolds, list), msg 382 else: 383 assert isinstance(model, Booster), msg 384 return model 385 386 def after_training(self, model): 387 '''Function called after training.''' 388 for c in self.callbacks: 389 model = c.after_training(model=model) 390 msg = 'after_training should return the model' 391 if self.is_cv: 392 assert isinstance(model.cvfolds, list), msg 393 else: 394 assert isinstance(model, Booster), msg 395 return model 396 397 def before_iteration(self, model, epoch, dtrain, evals) -> bool: 398 '''Function called before training iteration.''' 399 return any(c.before_iteration(model, epoch, self.history) 400 for c in self.callbacks) 401 402 def _update_history(self, score, epoch): 403 for d in score: 404 name, s = d[0], float(d[1]) 405 if self.is_cv: 406 std = float(d[2]) 407 s = (s, std) 408 splited_names = name.split('-') 409 data_name = splited_names[0] 410 metric_name = '-'.join(splited_names[1:]) 411 s = _allreduce_metric(s) 412 if data_name in self.history: 413 data_history = self.history[data_name] 414 if metric_name in data_history: 415 data_history[metric_name].append(s) 416 else: 417 data_history[metric_name] = [s] 418 else: 419 self.history[data_name] = collections.OrderedDict() 420 self.history[data_name][metric_name] = [s] 421 return False 422 423 def after_iteration(self, model, epoch, dtrain, evals) -> bool: 424 '''Function called after training iteration.''' 425 if self.is_cv: 426 scores = model.eval(epoch, self.metric) 427 scores = _aggcv(scores) 428 self.aggregated_cv = scores 429 self._update_history(scores, epoch) 430 else: 431 evals = [] if evals is None else evals 432 for _, name in evals: 433 assert name.find('-') == -1, 'Dataset name should not contain `-`' 434 score = model.eval_set(evals, epoch, self.metric) 435 score = score.split()[1:] # into datasets 436 # split up `test-error:0.1234` 437 score = [tuple(s.split(':')) for s in score] 438 self._update_history(score, epoch) 439 ret = any(c.after_iteration(model, epoch, self.history) 440 for c in self.callbacks) 441 return ret 442 443 444class LearningRateScheduler(TrainingCallback): 445 '''Callback function for scheduling learning rate. 446 447 .. versionadded:: 1.3.0 448 449 Parameters 450 ---------- 451 452 learning_rates : callable/collections.Sequence 453 If it's a callable object, then it should accept an integer parameter 454 `epoch` and returns the corresponding learning rate. Otherwise it 455 should be a sequence like list or tuple with the same size of boosting 456 rounds. 457 458 ''' 459 def __init__(self, learning_rates) -> None: 460 assert callable(learning_rates) or \ 461 isinstance(learning_rates, collections.abc.Sequence) 462 if callable(learning_rates): 463 self.learning_rates = learning_rates 464 else: 465 self.learning_rates = lambda epoch: learning_rates[epoch] 466 super().__init__() 467 468 def after_iteration(self, model, epoch, evals_log) -> bool: 469 model.set_param('learning_rate', self.learning_rates(epoch)) 470 return False 471 472 473# pylint: disable=too-many-instance-attributes 474class EarlyStopping(TrainingCallback): 475 """Callback function for early stopping 476 477 .. versionadded:: 1.3.0 478 479 Parameters 480 ---------- 481 rounds 482 Early stopping rounds. 483 metric_name 484 Name of metric that is used for early stopping. 485 data_name 486 Name of dataset that is used for early stopping. 487 maximize 488 Whether to maximize evaluation metric. None means auto (discouraged). 489 save_best 490 Whether training should return the best model or the last model. 491 min_delta 492 Minimum absolute change in score to be qualified as an improvement. 493 494 .. versionadded:: 1.5.0 495 496 .. code-block:: python 497 498 clf = xgboost.XGBClassifier(tree_method="gpu_hist") 499 es = xgboost.callback.EarlyStopping( 500 rounds=2, 501 abs_tol=1e-3, 502 save_best=True, 503 maximize=False, 504 data_name="validation_0", 505 metric_name="mlogloss", 506 ) 507 508 X, y = load_digits(return_X_y=True) 509 clf.fit(X, y, eval_set=[(X, y)], callbacks=[es]) 510 """ 511 def __init__( 512 self, 513 rounds: int, 514 metric_name: Optional[str] = None, 515 data_name: Optional[str] = None, 516 maximize: Optional[bool] = None, 517 save_best: Optional[bool] = False, 518 min_delta: float = 0.0 519 ) -> None: 520 self.data = data_name 521 self.metric_name = metric_name 522 self.rounds = rounds 523 self.save_best = save_best 524 self.maximize = maximize 525 self.stopping_history: TrainingCallback.EvalsLog = {} 526 self._min_delta = min_delta 527 if self._min_delta < 0: 528 raise ValueError("min_delta must be greater or equal to 0.") 529 530 self.improve_op = None 531 532 self.current_rounds: int = 0 533 self.best_scores: dict = {} 534 self.starting_round: int = 0 535 super().__init__() 536 537 def before_training(self, model): 538 self.starting_round = model.num_boosted_rounds() 539 return model 540 541 def _update_rounds(self, score, name, metric, model, epoch) -> bool: 542 def get_s(x): 543 """get score if it's cross validation history.""" 544 return x[0] if isinstance(x, tuple) else x 545 546 def maximize(new, best): 547 """New score should be greater than the old one.""" 548 return numpy.greater(get_s(new) - self._min_delta, get_s(best)) 549 550 def minimize(new, best): 551 """New score should be smaller than the old one.""" 552 return numpy.greater(get_s(best) - self._min_delta, get_s(new)) 553 554 if self.maximize is None: 555 # Just to be compatibility with old behavior before 1.3. We should let 556 # user to decide. 557 maximize_metrics = ('auc', 'aucpr', 'map', 'ndcg', 'auc@', 558 'aucpr@', 'map@', 'ndcg@') 559 if metric != 'mape' and any(metric.startswith(x) for x in maximize_metrics): 560 self.maximize = True 561 else: 562 self.maximize = False 563 564 if self.maximize: 565 self.improve_op = maximize 566 else: 567 self.improve_op = minimize 568 569 assert self.improve_op 570 571 if not self.stopping_history: # First round 572 self.current_rounds = 0 573 self.stopping_history[name] = {} 574 self.stopping_history[name][metric] = [score] 575 self.best_scores[name] = {} 576 self.best_scores[name][metric] = [score] 577 model.set_attr(best_score=str(score), best_iteration=str(epoch)) 578 elif not self.improve_op(score, self.best_scores[name][metric][-1]): 579 # Not improved 580 self.stopping_history[name][metric].append(score) 581 self.current_rounds += 1 582 else: # Improved 583 self.stopping_history[name][metric].append(score) 584 self.best_scores[name][metric].append(score) 585 record = self.stopping_history[name][metric][-1] 586 model.set_attr(best_score=str(record), best_iteration=str(epoch)) 587 self.current_rounds = 0 # reset 588 589 if self.current_rounds >= self.rounds: 590 # Should stop 591 return True 592 return False 593 594 def after_iteration(self, model, epoch: int, 595 evals_log: TrainingCallback.EvalsLog) -> bool: 596 epoch += self.starting_round # training continuation 597 msg = 'Must have at least 1 validation dataset for early stopping.' 598 assert len(evals_log.keys()) >= 1, msg 599 data_name = '' 600 if self.data: 601 for d, _ in evals_log.items(): 602 if d == self.data: 603 data_name = d 604 if not data_name: 605 raise ValueError('No dataset named:', self.data) 606 else: 607 # Use the last one as default. 608 data_name = list(evals_log.keys())[-1] 609 assert isinstance(data_name, str) and data_name 610 data_log = evals_log[data_name] 611 612 # Filter out scores that can not be used for early stopping. 613 if self.metric_name: 614 metric_name = self.metric_name 615 else: 616 # Use last metric by default. 617 assert isinstance(data_log, collections.OrderedDict) 618 metric_name = list(data_log.keys())[-1] 619 score = data_log[metric_name][-1] 620 return self._update_rounds(score, data_name, metric_name, model, epoch) 621 622 def after_training(self, model): 623 try: 624 if self.save_best: 625 model = model[: int(model.attr("best_iteration")) + 1] 626 except XGBoostError as e: 627 raise XGBoostError( 628 "`save_best` is not applicable to current booster" 629 ) from e 630 return model 631 632 633class EvaluationMonitor(TrainingCallback): 634 '''Print the evaluation result at each iteration. 635 636 .. versionadded:: 1.3.0 637 638 Parameters 639 ---------- 640 641 metric : callable 642 Extra user defined metric. 643 rank : int 644 Which worker should be used for printing the result. 645 period : int 646 How many epoches between printing. 647 show_stdv : bool 648 Used in cv to show standard deviation. Users should not specify it. 649 ''' 650 def __init__(self, rank=0, period=1, show_stdv=False) -> None: 651 self.printer_rank = rank 652 self.show_stdv = show_stdv 653 self.period = period 654 assert period > 0 655 # last error message, useful when early stopping and period are used together. 656 self._latest: Optional[str] = None 657 super().__init__() 658 659 def _fmt_metric( 660 self, data: str, metric: str, score: float, std: Optional[float] 661 ) -> str: 662 if std is not None and self.show_stdv: 663 msg = f"\t{data + '-' + metric}:{score:.5f}+{std:.5f}" 664 else: 665 msg = f"\t{data + '-' + metric}:{score:.5f}" 666 return msg 667 668 def after_iteration(self, model, epoch: int, 669 evals_log: TrainingCallback.EvalsLog) -> bool: 670 if not evals_log: 671 return False 672 673 msg: str = f'[{epoch}]' 674 if rabit.get_rank() == self.printer_rank: 675 for data, metric in evals_log.items(): 676 for metric_name, log in metric.items(): 677 stdv: Optional[float] = None 678 if isinstance(log[-1], tuple): 679 score = log[-1][0] 680 stdv = log[-1][1] 681 else: 682 score = log[-1] 683 msg += self._fmt_metric(data, metric_name, score, stdv) 684 msg += '\n' 685 686 if (epoch % self.period) == 0 or self.period == 1: 687 rabit.tracker_print(msg) 688 self._latest = None 689 else: 690 # There is skipped message 691 self._latest = msg 692 return False 693 694 def after_training(self, model): 695 if rabit.get_rank() == self.printer_rank and self._latest is not None: 696 rabit.tracker_print(self._latest) 697 return model 698 699 700class TrainingCheckPoint(TrainingCallback): 701 '''Checkpointing operation. 702 703 .. versionadded:: 1.3.0 704 705 Parameters 706 ---------- 707 708 directory : os.PathLike 709 Output model directory. 710 name : str 711 pattern of output model file. Models will be saved as name_0.json, name_1.json, 712 name_2.json .... 713 as_pickle : boolean 714 When set to Ture, all training parameters will be saved in pickle format, instead 715 of saving only the model. 716 iterations : int 717 Interval of checkpointing. Checkpointing is slow so setting a larger number can 718 reduce performance hit. 719 720 ''' 721 def __init__(self, directory: os.PathLike, name: str = 'model', 722 as_pickle=False, iterations: int = 100): 723 self._path = directory 724 self._name = name 725 self._as_pickle = as_pickle 726 self._iterations = iterations 727 self._epoch = 0 728 super().__init__() 729 730 def after_iteration(self, model, epoch: int, 731 evals_log: TrainingCallback.EvalsLog) -> bool: 732 if self._epoch == self._iterations: 733 path = os.path.join(self._path, self._name + '_' + str(epoch) + 734 ('.pkl' if self._as_pickle else '.json')) 735 self._epoch = 0 736 if rabit.get_rank() == 0: 737 if self._as_pickle: 738 with open(path, 'wb') as fd: 739 pickle.dump(model, fd) 740 else: 741 model.save_model(path) 742 self._epoch += 1 743 return False 744 745 746class LegacyCallbacks: 747 '''Adapter for legacy callback functions. 748 749 .. versionadded:: 1.3.0 750 751 Parameters 752 ---------- 753 754 callbacks : Sequence 755 A sequence of legacy callbacks (callbacks that are not instance of 756 TrainingCallback) 757 start_iteration : int 758 Begining iteration. 759 end_iteration : int 760 End iteration, normally is the number of boosting rounds. 761 evals : Sequence 762 Sequence of evaluation dataset tuples. 763 feval : Custom evaluation metric. 764 ''' 765 def __init__(self, callbacks, start_iteration, end_iteration, 766 feval, cvfolds=None): 767 self.callbacks_before_iter = [ 768 cb for cb in callbacks 769 if cb.__dict__.get('before_iteration', False)] 770 self.callbacks_after_iter = [ 771 cb for cb in callbacks 772 if not cb.__dict__.get('before_iteration', False)] 773 774 self.start_iteration = start_iteration 775 self.end_iteration = end_iteration 776 self.cvfolds = cvfolds 777 778 self.feval = feval 779 assert self.feval is None or callable(self.feval) 780 781 if cvfolds is not None: 782 self.aggregated_cv = None 783 784 super().__init__() 785 786 def before_training(self, model): 787 '''Nothing to do for legacy callbacks''' 788 return model 789 790 def after_training(self, model): 791 '''Nothing to do for legacy callbacks''' 792 return model 793 794 def before_iteration(self, model, epoch, dtrain, evals): 795 '''Called before each iteration.''' 796 for cb in self.callbacks_before_iter: 797 rank = rabit.get_rank() 798 cb(CallbackEnv(model=None if self.cvfolds is not None else model, 799 cvfolds=self.cvfolds, 800 iteration=epoch, 801 begin_iteration=self.start_iteration, 802 end_iteration=self.end_iteration, 803 rank=rank, 804 evaluation_result_list=None)) 805 return False 806 807 def after_iteration(self, model, epoch, dtrain, evals): 808 '''Called after each iteration.''' 809 evaluation_result_list = [] 810 if self.cvfolds is not None: 811 # dtrain is not used here. 812 scores = model.eval(epoch, self.feval) 813 self.aggregated_cv = _aggcv(scores) 814 evaluation_result_list = self.aggregated_cv 815 816 if evals: 817 # When cv is used, evals are embedded into folds. 818 assert self.cvfolds is None 819 bst_eval_set = model.eval_set(evals, epoch, self.feval) 820 if isinstance(bst_eval_set, STRING_TYPES): 821 msg = bst_eval_set 822 else: 823 msg = bst_eval_set.decode() 824 res = [x.split(':') for x in msg.split()] 825 evaluation_result_list = [(k, float(v)) for k, v in res[1:]] 826 827 try: 828 for cb in self.callbacks_after_iter: 829 rank = rabit.get_rank() 830 cb(CallbackEnv(model=None if self.cvfolds is not None else model, 831 cvfolds=self.cvfolds, 832 iteration=epoch, 833 begin_iteration=self.start_iteration, 834 end_iteration=self.end_iteration, 835 rank=rank, 836 evaluation_result_list=evaluation_result_list)) 837 except EarlyStopException: 838 return True 839 840 return False 841