# coding: utf-8
# pylint: disable=invalid-name, too-many-statements, no-self-use
# pylint: disable=too-many-arguments
"""Training Library containing training routines."""
from abc import ABC
import collections
import os
import pickle
from typing import Callable, List, Optional, Union, Dict, Tuple
import numpy

from . import rabit
from .core import EarlyStopException, CallbackEnv, Booster, XGBoostError
from .compat import STRING_TYPES


def _get_callback_context(env):
    """return whether the current callback context is cv or train"""
    if env.model is not None and env.cvfolds is None:
        context = 'train'
    elif env.model is None and env.cvfolds is not None:
        context = 'cv'
    else:
        raise ValueError("Unexpected input with both model and cvfolds.")
    return context


def _fmt_metric(value, show_stdv=True):
    """format metric string"""
    if len(value) == 2:
        return f"{value[0]}:{value[1]:.5f}"
    if len(value) == 3:
        if show_stdv:
            return f"{value[0]}:{value[1]:.5f}+{value[2]:.5f}"
        return f"{value[0]}:{value[1]:.5f}"
    raise ValueError("wrong metric value", value)


def print_evaluation(period=1, show_stdv=True):
    """Create a callback that prints the evaluation results.

    The evaluation results are printed every **period** iterations
    and on the first and the last iterations.

    Parameters
    ----------
    period : int
        The period to log the evaluation results

    show_stdv : bool, optional
        Whether to show the standard deviation if provided

    Returns
    -------
    callback : function
        A callback that prints the evaluation results every **period** iterations.
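
    A minimal usage sketch (assuming a parameter dict ``params`` and a
    ``DMatrix`` called ``dtrain`` are already defined):

    .. code-block:: python

        bst = xgboost.train(
            params, dtrain,
            num_boost_round=10,
            evals=[(dtrain, 'train')],
            callbacks=[xgboost.callback.print_evaluation(period=5)],
        )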
    """
    def callback(env):
        """internal function"""
        if env.rank != 0 or (not env.evaluation_result_list) or period is False or period == 0:
            return
        i = env.iteration
        if i % period == 0 or i + 1 == env.begin_iteration or i + 1 == env.end_iteration:
            msg = '\t'.join([_fmt_metric(x, show_stdv) for x in env.evaluation_result_list])
            rabit.tracker_print(f"[{i}]\t{msg}\n")
    return callback


def record_evaluation(eval_result):
    """Create a callback that records the evaluation history into **eval_result**.

    Parameters
    ----------
    eval_result : dict
       A dictionary to store the evaluation results.

    Returns
    -------
    callback : function
        The requested callback function.
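
    A minimal usage sketch (assuming ``params``, ``dtrain`` and ``dvalid`` are
    already defined):

    .. code-block:: python

        history = {}
        bst = xgboost.train(
            params, dtrain,
            num_boost_round=10,
            evals=[(dtrain, 'train'), (dvalid, 'valid')],
            callbacks=[xgboost.callback.record_evaluation(history)],
        )
        # history now has the form {'train': {'<metric>': [...]},
        #                           'valid': {'<metric>': [...]}}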
    """
    if not isinstance(eval_result, dict):
        raise TypeError('eval_result has to be a dictionary')
    eval_result.clear()

    def init(env):
        """internal function"""
        for k, _ in env.evaluation_result_list:
            pos = k.index('-')
            key = k[:pos]
            metric = k[pos + 1:]
            if key not in eval_result:
                eval_result[key] = {}
            if metric not in eval_result[key]:
                eval_result[key][metric] = []

    def callback(env):
        """internal function"""
        if not eval_result:
            init(env)
        for k, v in env.evaluation_result_list:
            pos = k.index('-')
            key = k[:pos]
            metric = k[pos + 1:]
            eval_result[key][metric].append(v)
    return callback


def reset_learning_rate(learning_rates):
    """Reset learning rate after iteration 1.

    NOTE: the initial learning rate will still take effect on the first iteration.

    Parameters
    ----------
    learning_rates: list or function
        List of learning rates, one for each boosting round,
        or a customized function that calculates eta in terms of the
        current round and the total number of boosting rounds (e.g.
        to implement learning rate decay)

        * list ``l``: ``eta = l[boosting_round]``
        * function ``f``: ``eta = f(boosting_round, num_boost_round)``

    Returns
    -------
    callback : function
        The requested callback function.
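
    A minimal usage sketch with a simple decay schedule (assuming ``params``
    and ``dtrain`` are already defined):

    .. code-block:: python

        def eta_decay(boosting_round, num_boost_round):
            # Hypothetical schedule: start at 0.3 and shrink by 1% each round.
            return 0.3 * (0.99 ** boosting_round)

        bst = xgboost.train(
            params, dtrain,
            num_boost_round=100,
            callbacks=[xgboost.callback.reset_learning_rate(eta_decay)],
        )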
    """
    def get_learning_rate(i, n, learning_rates):
        """helper providing the learning rate"""
        if isinstance(learning_rates, list):
            if len(learning_rates) != n:
                raise ValueError("Length of list 'learning_rates' has to equal 'num_boost_round'.")
            new_learning_rate = learning_rates[i]
        else:
            new_learning_rate = learning_rates(i, n)
        return new_learning_rate

    def callback(env):
        """internal function"""
        context = _get_callback_context(env)

        if context == 'train':
            bst, i, n = env.model, env.iteration, env.end_iteration
            bst.set_param(
                'learning_rate', get_learning_rate(i, n, learning_rates))
        elif context == 'cv':
            i, n = env.iteration, env.end_iteration
            for cvpack in env.cvfolds:
                bst = cvpack.bst
                bst.set_param(
                    'learning_rate', get_learning_rate(i, n, learning_rates))

    callback.before_iteration = False
    return callback


def early_stop(stopping_rounds, maximize=False, verbose=True):
    """Create a callback that activates early stopping.

    Validation error needs to decrease at least
    every **stopping_rounds** round(s) to continue training.
    Requires at least one item in **evals**.
    If there's more than one, the last one will be used.
    Returns the model from the last iteration (not the best one).
    If early stopping occurs, the model will have two additional fields:
    ``bst.best_score`` and ``bst.best_iteration``.

    Parameters
    ----------
    stopping_rounds : int
       Number of rounds without improvement after which training stops.

    maximize : bool
        Whether to maximize evaluation metric.

    verbose : bool, optional
        Whether to print a message about early stopping information.

    Returns
    -------
    callback : function
        The requested callback function.
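
    A minimal usage sketch (assuming ``params``, ``dtrain`` and ``dvalid`` are
    already defined):

    .. code-block:: python

        bst = xgboost.train(
            params, dtrain,
            num_boost_round=1000,
            evals=[(dvalid, 'valid')],
            callbacks=[xgboost.callback.early_stop(stopping_rounds=10)],
        )
        print(bst.best_iteration, bst.best_score)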
    """
    state = {}

    def init(env):
        """internal function"""
        bst = env.model

        if not env.evaluation_result_list:
            raise ValueError('For early stopping you need at least one set in evals.')
        if len(env.evaluation_result_list) > 1 and verbose:
            msg = ("Multiple eval metrics have been passed: "
                   "'{0}' will be used for early stopping.\n\n")
            rabit.tracker_print(msg.format(env.evaluation_result_list[-1][0]))
        maximize_metrics = ('auc', 'aucpr', 'map', 'ndcg')
        maximize_at_n_metrics = ('auc@', 'aucpr@', 'map@', 'ndcg@')
        maximize_score = maximize
        metric_label = env.evaluation_result_list[-1][0]
        metric = metric_label.split('-', 1)[-1]

        if any(metric.startswith(x) for x in maximize_at_n_metrics):
            maximize_score = True

        if any(metric.split(":")[0] == x for x in maximize_metrics):
            maximize_score = True

        if verbose and env.rank == 0:
            msg = "Will train until {} hasn't improved in {} rounds.\n"
            rabit.tracker_print(msg.format(metric_label, stopping_rounds))

        state['maximize_score'] = maximize_score
        state['best_iteration'] = 0
        if maximize_score:
            state['best_score'] = float('-inf')
        else:
            state['best_score'] = float('inf')
        # pylint: disable=consider-using-f-string
        msg = '[%d]\t%s' % (
            env.iteration,
            '\t'.join([_fmt_metric(x) for x in env.evaluation_result_list])
        )
        state['best_msg'] = msg

        if bst is not None:
            if bst.attr('best_score') is not None:
                state['best_score'] = float(bst.attr('best_score'))
                state['best_iteration'] = int(bst.attr('best_iteration'))
                state['best_msg'] = bst.attr('best_msg')
            else:
                bst.set_attr(best_iteration=str(state['best_iteration']))
                bst.set_attr(best_score=str(state['best_score']))
        else:
            assert env.cvfolds is not None

    def callback(env):
        """internal function"""
        if not state:
            init(env)
        score = env.evaluation_result_list[-1][1]
        best_score = state['best_score']
        best_iteration = state['best_iteration']
        maximize_score = state['maximize_score']
        if (maximize_score and score > best_score) or \
                (not maximize_score and score < best_score):
            # pylint: disable=consider-using-f-string
            msg = '[%d]\t%s' % (
                env.iteration,
                '\t'.join([_fmt_metric(x) for x in env.evaluation_result_list]))
            state['best_msg'] = msg
            state['best_score'] = score
            state['best_iteration'] = env.iteration
            # save the property to attributes, so they will occur in checkpoint.
            if env.model is not None:
                env.model.set_attr(best_score=str(state['best_score']),
                                   best_iteration=str(state['best_iteration']),
                                   best_msg=state['best_msg'])
        elif env.iteration - best_iteration >= stopping_rounds:
            best_msg = state['best_msg']
            if verbose and env.rank == 0:
                msg = "Stopping. Best iteration:\n{}\n\n"
                rabit.tracker_print(msg.format(best_msg))
            raise EarlyStopException(best_iteration)
    return callback


# The new implementation of callback functions.
# Breaking:
# - reset learning rate no longer accepts total boosting rounds

# pylint: disable=unused-argument
class TrainingCallback(ABC):
    '''Interface for training callback.

    .. versionadded:: 1.3.0

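    A minimal sketch of a custom callback built on this interface (the
    printing logic is purely illustrative):

    .. code-block:: python

        class Monitor(xgboost.callback.TrainingCallback):
            def after_iteration(self, model, epoch, evals_log):
                print(f'epoch {epoch}: {evals_log}')
                # Returning True would stop training early.
                return False
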
    '''

    EvalsLog = Dict[str, Dict[str, Union[List[float], List[Tuple[float, float]]]]]

    def __init__(self):
        pass

    def before_training(self, model):
        '''Run before training starts.'''
        return model

    def after_training(self, model):
        '''Run after training is finished.'''
        return model

    def before_iteration(self, model, epoch: int, evals_log: EvalsLog) -> bool:
        '''Run before each iteration.  Return True when training should stop.'''
        return False

    def after_iteration(self, model, epoch: int, evals_log: EvalsLog) -> bool:
        '''Run after each iteration.  Return True when training should stop.'''
        return False


def _aggcv(rlist):
    # pylint: disable=invalid-name
    """Aggregate cross-validation results.

    """
    cvmap = {}
    idx = rlist[0].split()[0]
    for line in rlist:
        arr = line.split()
        assert idx == arr[0]
        for metric_idx, it in enumerate(arr[1:]):
            if not isinstance(it, STRING_TYPES):
                it = it.decode()
            k, v = it.split(':')
            if (metric_idx, k) not in cvmap:
                cvmap[(metric_idx, k)] = []
            cvmap[(metric_idx, k)].append(float(v))
    msg = idx
    results = []
    for (metric_idx, k), v in sorted(cvmap.items(), key=lambda x: x[0][0]):
        v = numpy.array(v)
        if not isinstance(msg, STRING_TYPES):
            msg = msg.decode()
        mean, std = numpy.mean(v), numpy.std(v)
        results.append((k, mean, std))
    return results


def _allreduce_metric(score):
    '''Helper function for computing customized metric in distributed
    environment.  Not strictly correct as many functions don't use mean value
    as final result.

    '''
    world = rabit.get_world_size()
    assert world != 0
    if world == 1:
        return score
    if isinstance(score, tuple):  # has mean and stdv
        raise ValueError(
            'xgboost.cv function should not be used in distributed environment.')
    score = numpy.array([score])
    score = rabit.allreduce(score, rabit.Op.SUM) / world
    return score[0]


class CallbackContainer:
    '''A special callback for invoking a list of other callbacks.

    .. versionadded:: 1.3.0

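    A rough sketch of how a training loop might drive the container (the real
    loop lives in :py:func:`xgboost.train`; ``booster``, ``dtrain``, ``evals``
    and ``num_boost_round`` are assumed to be defined already):

    .. code-block:: python

        container = CallbackContainer([xgboost.callback.EvaluationMonitor()])
        booster = container.before_training(booster)
        for epoch in range(num_boost_round):
            if container.before_iteration(booster, epoch, dtrain, evals):
                break
            booster.update(dtrain, epoch)
            if container.after_iteration(booster, epoch, dtrain, evals):
                break
        booster = container.after_training(booster)
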
    '''

    EvalsLog = TrainingCallback.EvalsLog

    def __init__(self,
                 callbacks: List[TrainingCallback],
                 metric: Callable = None,
                 is_cv: bool = False):
        self.callbacks = set(callbacks)
        if metric is not None:
            msg = 'metric must be callable object for monitoring.  For ' + \
                'builtin metrics, passing them in training parameter' + \
                ' will invoke monitor automatically.'
            assert callable(metric), msg
        self.metric = metric
        self.history: TrainingCallback.EvalsLog = collections.OrderedDict()
        self.is_cv = is_cv

        if self.is_cv:
            self.aggregated_cv = None

    def before_training(self, model):
        '''Function called before training.'''
        for c in self.callbacks:
            model = c.before_training(model=model)
            msg = 'before_training should return the model'
            if self.is_cv:
                assert isinstance(model.cvfolds, list), msg
            else:
                assert isinstance(model, Booster), msg
        return model

    def after_training(self, model):
        '''Function called after training.'''
        for c in self.callbacks:
            model = c.after_training(model=model)
            msg = 'after_training should return the model'
            if self.is_cv:
                assert isinstance(model.cvfolds, list), msg
            else:
                assert isinstance(model, Booster), msg
        return model

    def before_iteration(self, model, epoch, dtrain, evals) -> bool:
        '''Function called before training iteration.'''
        return any(c.before_iteration(model, epoch, self.history)
                   for c in self.callbacks)

    def _update_history(self, score, epoch):
        for d in score:
            name, s = d[0], float(d[1])
            if self.is_cv:
                std = float(d[2])
                s = (s, std)
            split_names = name.split('-')
            data_name = split_names[0]
            metric_name = '-'.join(split_names[1:])
            s = _allreduce_metric(s)
            if data_name in self.history:
                data_history = self.history[data_name]
                if metric_name in data_history:
                    data_history[metric_name].append(s)
                else:
                    data_history[metric_name] = [s]
            else:
                self.history[data_name] = collections.OrderedDict()
                self.history[data_name][metric_name] = [s]
        return False

    def after_iteration(self, model, epoch, dtrain, evals) -> bool:
        '''Function called after training iteration.'''
        if self.is_cv:
            scores = model.eval(epoch, self.metric)
            scores = _aggcv(scores)
            self.aggregated_cv = scores
            self._update_history(scores, epoch)
        else:
            evals = [] if evals is None else evals
            for _, name in evals:
                assert name.find('-') == -1, 'Dataset name should not contain `-`'
            score = model.eval_set(evals, epoch, self.metric)
            score = score.split()[1:]  # into datasets
            # split up `test-error:0.1234`
            score = [tuple(s.split(':')) for s in score]
            self._update_history(score, epoch)
        ret = any(c.after_iteration(model, epoch, self.history)
                  for c in self.callbacks)
        return ret


class LearningRateScheduler(TrainingCallback):
    '''Callback function for scheduling learning rate.

    .. versionadded:: 1.3.0

    Parameters
    ----------

    learning_rates : callable/collections.Sequence
        If it's a callable object, then it should accept an integer parameter
        `epoch` and return the corresponding learning rate.  Otherwise it
        should be a sequence like a list or tuple with the same length as the
        number of boosting rounds.

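    A minimal usage sketch (assuming ``params`` and ``dtrain`` are already
    defined):

    .. code-block:: python

        scheduler = xgboost.callback.LearningRateScheduler(
            lambda epoch: 0.3 * (0.99 ** epoch)
        )
        bst = xgboost.train(params, dtrain, num_boost_round=100,
                            callbacks=[scheduler])
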
    '''
    def __init__(self, learning_rates) -> None:
        assert callable(learning_rates) or \
            isinstance(learning_rates, collections.abc.Sequence)
        if callable(learning_rates):
            self.learning_rates = learning_rates
        else:
            self.learning_rates = lambda epoch: learning_rates[epoch]
        super().__init__()

    def after_iteration(self, model, epoch, evals_log) -> bool:
        model.set_param('learning_rate', self.learning_rates(epoch))
        return False


# pylint: disable=too-many-instance-attributes
class EarlyStopping(TrainingCallback):
    """Callback function for early stopping

    .. versionadded:: 1.3.0

    Parameters
    ----------
    rounds
        Early stopping rounds.
    metric_name
        Name of metric that is used for early stopping.
    data_name
        Name of dataset that is used for early stopping.
    maximize
        Whether to maximize evaluation metric.  None means auto (discouraged).
    save_best
        Whether training should return the best model or the last model.
    min_delta
        Minimum absolute change in score to be qualified as an improvement.

        .. versionadded:: 1.5.0

        .. code-block:: python

            from sklearn.datasets import load_digits

            clf = xgboost.XGBClassifier(tree_method="gpu_hist")
            es = xgboost.callback.EarlyStopping(
                rounds=2,
                min_delta=1e-3,
                save_best=True,
                maximize=False,
                data_name="validation_0",
                metric_name="mlogloss",
            )

            X, y = load_digits(return_X_y=True)
            clf.fit(X, y, eval_set=[(X, y)], callbacks=[es])
    """
    def __init__(
        self,
        rounds: int,
        metric_name: Optional[str] = None,
        data_name: Optional[str] = None,
        maximize: Optional[bool] = None,
        save_best: Optional[bool] = False,
        min_delta: float = 0.0
    ) -> None:
        self.data = data_name
        self.metric_name = metric_name
        self.rounds = rounds
        self.save_best = save_best
        self.maximize = maximize
        self.stopping_history: TrainingCallback.EvalsLog = {}
        self._min_delta = min_delta
        if self._min_delta < 0:
            raise ValueError("min_delta must be greater or equal to 0.")

        self.improve_op = None

        self.current_rounds: int = 0
        self.best_scores: dict = {}
        self.starting_round: int = 0
        super().__init__()

    def before_training(self, model):
        self.starting_round = model.num_boosted_rounds()
        return model

    def _update_rounds(self, score, name, metric, model, epoch) -> bool:
        def get_s(x):
            """get score if it's cross validation history."""
            return x[0] if isinstance(x, tuple) else x

        def maximize(new, best):
            """New score should be greater than the old one."""
            return numpy.greater(get_s(new) - self._min_delta, get_s(best))

        def minimize(new, best):
            """New score should be smaller than the old one."""
            return numpy.greater(get_s(best) - self._min_delta, get_s(new))

        if self.maximize is None:
            # Just for compatibility with the behavior before 1.3.  We should
            # let the user decide.
            maximize_metrics = ('auc', 'aucpr', 'map', 'ndcg', 'auc@',
                                'aucpr@', 'map@', 'ndcg@')
            if metric != 'mape' and any(metric.startswith(x) for x in maximize_metrics):
                self.maximize = True
            else:
                self.maximize = False

        if self.maximize:
            self.improve_op = maximize
        else:
            self.improve_op = minimize

        assert self.improve_op

        if not self.stopping_history:  # First round
            self.current_rounds = 0
            self.stopping_history[name] = {}
            self.stopping_history[name][metric] = [score]
            self.best_scores[name] = {}
            self.best_scores[name][metric] = [score]
            model.set_attr(best_score=str(score), best_iteration=str(epoch))
        elif not self.improve_op(score, self.best_scores[name][metric][-1]):
            # Not improved
            self.stopping_history[name][metric].append(score)
            self.current_rounds += 1
        else:  # Improved
            self.stopping_history[name][metric].append(score)
            self.best_scores[name][metric].append(score)
            record = self.stopping_history[name][metric][-1]
            model.set_attr(best_score=str(record), best_iteration=str(epoch))
            self.current_rounds = 0  # reset

        if self.current_rounds >= self.rounds:
            # Should stop
            return True
        return False

    def after_iteration(self, model, epoch: int,
                        evals_log: TrainingCallback.EvalsLog) -> bool:
        epoch += self.starting_round  # training continuation
        msg = 'Must have at least 1 validation dataset for early stopping.'
        assert len(evals_log.keys()) >= 1, msg
        data_name = ''
        if self.data:
            for d, _ in evals_log.items():
                if d == self.data:
                    data_name = d
            if not data_name:
                raise ValueError('No dataset named:', self.data)
        else:
            # Use the last one as default.
            data_name = list(evals_log.keys())[-1]
        assert isinstance(data_name, str) and data_name
        data_log = evals_log[data_name]

        # Filter out scores that can not be used for early stopping.
        if self.metric_name:
            metric_name = self.metric_name
        else:
            # Use last metric by default.
            assert isinstance(data_log, collections.OrderedDict)
            metric_name = list(data_log.keys())[-1]
        score = data_log[metric_name][-1]
        return self._update_rounds(score, data_name, metric_name, model, epoch)

    def after_training(self, model):
        try:
            if self.save_best:
                model = model[: int(model.attr("best_iteration")) + 1]
        except XGBoostError as e:
            raise XGBoostError(
                "`save_best` is not applicable to current booster"
            ) from e
        return model


class EvaluationMonitor(TrainingCallback):
    '''Print the evaluation result at each iteration.

    .. versionadded:: 1.3.0

    Parameters
    ----------

    rank : int
        Which worker should be used for printing the result.
    period : int
        How many epochs between printing.
    show_stdv : bool
        Used in cv to show standard deviation.  Users should not specify it.
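
    A minimal usage sketch (assuming ``params``, ``dtrain`` and ``dvalid`` are
    already defined):

    .. code-block:: python

        bst = xgboost.train(
            params, dtrain,
            num_boost_round=10,
            evals=[(dvalid, 'valid')],
            callbacks=[xgboost.callback.EvaluationMonitor(period=2)],
        )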
    '''
    def __init__(self, rank=0, period=1, show_stdv=False) -> None:
        self.printer_rank = rank
        self.show_stdv = show_stdv
        self.period = period
        assert period > 0
        # last evaluation message, useful when early stopping and period are used together.
        self._latest: Optional[str] = None
        super().__init__()

    def _fmt_metric(
        self, data: str, metric: str, score: float, std: Optional[float]
    ) -> str:
        if std is not None and self.show_stdv:
            msg = f"\t{data + '-' + metric}:{score:.5f}+{std:.5f}"
        else:
            msg = f"\t{data + '-' + metric}:{score:.5f}"
        return msg

    def after_iteration(self, model, epoch: int,
                        evals_log: TrainingCallback.EvalsLog) -> bool:
        if not evals_log:
            return False

        msg: str = f'[{epoch}]'
        if rabit.get_rank() == self.printer_rank:
            for data, metric in evals_log.items():
                for metric_name, log in metric.items():
                    stdv: Optional[float] = None
                    if isinstance(log[-1], tuple):
                        score = log[-1][0]
                        stdv = log[-1][1]
                    else:
                        score = log[-1]
                    msg += self._fmt_metric(data, metric_name, score, stdv)
            msg += '\n'

            if (epoch % self.period) == 0 or self.period == 1:
                rabit.tracker_print(msg)
                self._latest = None
            else:
                # There is a skipped message.
                self._latest = msg
        return False

    def after_training(self, model):
        if rabit.get_rank() == self.printer_rank and self._latest is not None:
            rabit.tracker_print(self._latest)
        return model


class TrainingCheckPoint(TrainingCallback):
    '''Checkpointing operation.

    .. versionadded:: 1.3.0

    Parameters
    ----------

    directory : os.PathLike
        Output model directory.
    name : str
        Pattern of the output model file.  Models will be saved as name_0.json,
        name_1.json, name_2.json ....
    as_pickle : boolean
        When set to True, all training parameters will be saved in pickle format,
        instead of saving only the model.
    iterations : int
        Interval of checkpointing.  Checkpointing is slow, so setting a larger
        number can reduce the performance hit.

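    A minimal usage sketch (assuming ``params`` and ``dtrain`` are already
    defined; the output directory is only an example):

    .. code-block:: python

        checkpoint = xgboost.callback.TrainingCheckPoint(
            directory='/tmp/checkpoints', name='model', iterations=10
        )
        bst = xgboost.train(params, dtrain, num_boost_round=100,
                            callbacks=[checkpoint])
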
    '''
    def __init__(self, directory: os.PathLike, name: str = 'model',
                 as_pickle=False, iterations: int = 100):
        self._path = directory
        self._name = name
        self._as_pickle = as_pickle
        self._iterations = iterations
        self._epoch = 0
        super().__init__()

    def after_iteration(self, model, epoch: int,
                        evals_log: TrainingCallback.EvalsLog) -> bool:
        if self._epoch == self._iterations:
            path = os.path.join(self._path, self._name + '_' + str(epoch) +
                                ('.pkl' if self._as_pickle else '.json'))
            self._epoch = 0
            if rabit.get_rank() == 0:
                if self._as_pickle:
                    with open(path, 'wb') as fd:
                        pickle.dump(model, fd)
                else:
                    model.save_model(path)
        self._epoch += 1
        return False


class LegacyCallbacks:
    '''Adapter for legacy callback functions.

    .. versionadded:: 1.3.0

    Parameters
    ----------

    callbacks : Sequence
        A sequence of legacy callbacks (callbacks that are not instances of
        TrainingCallback)
    start_iteration : int
        Beginning iteration.
    end_iteration : int
        End iteration, normally the number of boosting rounds.
    feval : callable
        Custom evaluation metric.
    cvfolds : Sequence, optional
        CV fold packs used when running cross validation.
    '''
    def __init__(self, callbacks, start_iteration, end_iteration,
                 feval, cvfolds=None):
        self.callbacks_before_iter = [
            cb for cb in callbacks
            if cb.__dict__.get('before_iteration', False)]
        self.callbacks_after_iter = [
            cb for cb in callbacks
            if not cb.__dict__.get('before_iteration', False)]

        self.start_iteration = start_iteration
        self.end_iteration = end_iteration
        self.cvfolds = cvfolds

        self.feval = feval
        assert self.feval is None or callable(self.feval)

        if cvfolds is not None:
            self.aggregated_cv = None

        super().__init__()

    def before_training(self, model):
        '''Nothing to do for legacy callbacks'''
        return model

    def after_training(self, model):
        '''Nothing to do for legacy callbacks'''
        return model

    def before_iteration(self, model, epoch, dtrain, evals):
        '''Called before each iteration.'''
        for cb in self.callbacks_before_iter:
            rank = rabit.get_rank()
            cb(CallbackEnv(model=None if self.cvfolds is not None else model,
                           cvfolds=self.cvfolds,
                           iteration=epoch,
                           begin_iteration=self.start_iteration,
                           end_iteration=self.end_iteration,
                           rank=rank,
                           evaluation_result_list=None))
        return False

    def after_iteration(self, model, epoch, dtrain, evals):
        '''Called after each iteration.'''
        evaluation_result_list = []
        if self.cvfolds is not None:
            # dtrain is not used here.
            scores = model.eval(epoch, self.feval)
            self.aggregated_cv = _aggcv(scores)
            evaluation_result_list = self.aggregated_cv

        if evals:
            # When cv is used, evals are embedded into folds.
            assert self.cvfolds is None
            bst_eval_set = model.eval_set(evals, epoch, self.feval)
            if isinstance(bst_eval_set, STRING_TYPES):
                msg = bst_eval_set
            else:
                msg = bst_eval_set.decode()
            res = [x.split(':') for x in msg.split()]
            evaluation_result_list = [(k, float(v)) for k, v in res[1:]]

        try:
            for cb in self.callbacks_after_iter:
                rank = rabit.get_rank()
                cb(CallbackEnv(model=None if self.cvfolds is not None else model,
                               cvfolds=self.cvfolds,
                               iteration=epoch,
                               begin_iteration=self.start_iteration,
                               end_iteration=self.end_iteration,
                               rank=rank,
                               evaluation_result_list=evaluation_result_list))
        except EarlyStopException:
            return True

        return False