1import optuna
2from optuna.trial import Trial
3
4
5with optuna._imports.try_import() as _imports:
6    from ignite.engine import Engine
7
8
class PyTorchIgnitePruningHandler:
    """PyTorch Ignite handler to prune unpromising trials.

    Each time the handler is invoked, it reads ``metric`` from the metrics of the
    engine it is attached to (e.g. an evaluator). It reports that value to the Optuna
    trial, using the trainer's current epoch as the step. If the trial then decides
    it should be pruned, :class:`optuna.TrialPruned` is raised.

    See `the example <https://github.com/optuna/optuna-examples/blob/main/
    pytorch/pytorch_ignite_simple.py>`__
    if you want to add a pruning handler which observes validation accuracy.

    Args:
        trial:
            A :class:`~optuna.trial.Trial` corresponding to the current evaluation of the
            objective function.
        metric:
            A name of metric for pruning, e.g., ``accuracy`` and ``loss``. It is used as
            a key into ``engine.state.metrics`` of the engine that calls this handler.
        trainer:
            A trainer engine of PyTorch Ignite. Its ``state.epoch`` is used as the
            reporting step. Please refer to `ignite.engine.Engine reference
            <https://pytorch.org/ignite/engine.html#ignite.engine.Engine>`_ for further details.
    """

    def __init__(self, trial: Trial, metric: str, trainer: "Engine") -> None:
        # Fails early if PyTorch Ignite could not be imported by the guarded
        # ``try_import`` block at module level.
        _imports.check()

        self._trial = trial
        self._metric = metric
        self._trainer = trainer

    def __call__(self, engine: "Engine") -> None:
        """Report the monitored metric and prune the trial if Optuna requests it.

        Args:
            engine:
                The Ignite engine that fired the event. Its ``state.metrics`` must
                contain the metric name given at construction time.

        Raises:
            KeyError:
                If the monitored metric is absent from ``engine.state.metrics``.
            optuna.TrialPruned:
                If the trial should be pruned after reporting the value.
        """
        # The metric comes from the calling engine, but the step comes from the
        # trainer, so that intermediate values are indexed by training epoch.
        value = engine.state.metrics[self._metric]
        current_epoch = self._trainer.state.epoch
        self._trial.report(value, current_epoch)

        if not self._trial.should_prune():
            return

        raise optuna.TrialPruned(f"Trial was pruned at {current_epoch} epoch.")
42