1from typing import Any
2
3import optuna
4
5
6with optuna._imports.try_import() as _imports:
7    from skorch.callbacks import Callback
8    from skorch.net import NeuralNet
9
10if not _imports.is_successful():
11    Callback = object  # NOQA
12
13
class SkorchPruningCallback(Callback):
    """Skorch callback to prune unpromising trials.

    .. versionadded:: 2.1.0

    Args:
        trial:
            A :class:`~optuna.trial.Trial` corresponding to the current evaluation of the
            objective function.
        monitor:
            An evaluation metric for pruning, e.g. ``val_loss`` or
            ``val_acc``. The metric is read from the most recent entry of
            ``net.history`` (i.e. ``net.history[-1, monitor]``), so the name
            must match a key recorded in that history.
    """

    def __init__(self, trial: optuna.trial.Trial, monitor: str) -> None:
        # Ensure the skorch imports guarded at module level succeeded before use.
        _imports.check()

        super().__init__()
        self._trial = trial
        self._monitor = monitor

    def on_epoch_end(self, net: "NeuralNet", **kwargs: Any) -> None:
        """Report the monitored metric of the latest epoch and prune if requested.

        Args:
            net:
                The skorch net being trained; only its ``history`` is read.
            **kwargs:
                Additional keyword arguments supplied to the hook; ignored.

        Raises:
            :class:`optuna.TrialPruned`:
                If ``trial.should_prune()`` returns ``True`` after reporting.
        """
        history = net.history
        if not history:
            # No epoch has been recorded yet, so there is no score to report.
            return
        # The zero-based index of the latest history row is used as the step.
        epoch = len(history) - 1
        current_score = history[-1, self._monitor]
        self._trial.report(current_score, epoch)
        if self._trial.should_prune():
            raise optuna.TrialPruned(f"Trial was pruned at epoch {epoch}.")
48