from typing import Any

import optuna


with optuna._imports.try_import() as _imports:
    from skorch.callbacks import Callback
    from skorch.net import NeuralNet

if not _imports.is_successful():
    # Fall back to a plain base class so this module can still be imported
    # when the guarded skorch import did not succeed.
    Callback = object  # NOQA


class SkorchPruningCallback(Callback):
    """Skorch callback to prune unpromising trials.

    .. versionadded:: 2.1.0

    Args:
        trial:
            A :class:`~optuna.trial.Trial` corresponding to the current evaluation of the
            objective function.
        monitor:
            An evaluation metric for pruning, e.g. ``val_loss`` or
            ``val_acc``. The metrics are obtained from the returned dictionaries,
            i.e., ``net.history``. The names thus depend on how this dictionary
            is formatted.
    """

    def __init__(self, trial: optuna.trial.Trial, monitor: str) -> None:
        """Store the trial to report to and the name of the metric to watch."""

        # Validate the optional-import state before doing anything else
        # (presumably raises when the guarded skorch import failed).
        _imports.check()

        super().__init__()
        self._monitor = monitor
        self._trial = trial

    def on_epoch_end(self, net: "NeuralNet", **kwargs: Any) -> None:
        """Report the latest monitored value and prune the trial if requested.

        Args:
            net:
                The network whose ``history`` is read for the monitored value.
            kwargs:
                Additional keyword arguments; accepted but not used here.

        Raises:
            optuna.TrialPruned:
                If the trial reports that it should be pruned after the
                latest value has been recorded.
        """
        records = net.history
        if not records:
            # Nothing has been recorded yet, so there is nothing to report.
            return

        # Zero-based step index derived from the number of history entries.
        step = len(records) - 1
        # Look up the monitored key in the last history entry.
        latest_value = records[-1, self._monitor]
        self._trial.report(latest_value, step)

        if not self._trial.should_prune():
            return

        raise optuna.TrialPruned("Trial was pruned at epoch {}.".format(step))