import optuna
from optuna.trial import Trial


with optuna._imports.try_import() as _imports:
    from ignite.engine import Engine


class PyTorchIgnitePruningHandler:
    """Handler for PyTorch Ignite that prunes unpromising Optuna trials.

    Each time an instance is called with an engine, it does the following:

    1. Looks up ``metric`` in that engine's ``state.metrics``.
    2. Reports the value to ``trial``, using the ``trainer``'s current epoch as the step.
    3. Raises :class:`optuna.TrialPruned` if the trial then asks to be pruned.

    See `the example <https://github.com/optuna/optuna-examples/blob/main/
    pytorch/pytorch_ignite_simple.py>`__
    for a pruning handler that observes validation accuracy.

    Args:
        trial:
            The :class:`~optuna.trial.Trial` for the current evaluation of the
            objective function.
        metric:
            The name of the metric used for pruning, e.g. ``accuracy`` or ``loss``.
        trainer:
            The PyTorch Ignite trainer engine whose epoch counter is used as the
            reporting step. See the `ignite.engine.Engine reference
            <https://pytorch.org/ignite/engine.html#ignite.engine.Engine>`_ for details.
    """

    def __init__(self, trial: Trial, metric: str, trainer: "Engine") -> None:
        # Presumably raises if the guarded ``ignite`` import above failed,
        # so a missing dependency surfaces at construction time.
        _imports.check()

        self._trial = trial
        self._metric = metric
        self._trainer = trainer

    def __call__(self, engine: "Engine") -> None:
        """Report the tracked metric for this epoch and prune if requested.

        Args:
            engine:
                The engine whose ``state.metrics`` holds the metric value. It may
                be a different engine from the ``trainer`` given at construction.

        Raises:
            KeyError:
                If ``metric`` is not present in ``engine.state.metrics``.
            optuna.TrialPruned:
                If the trial should be pruned after this report.
        """
        value = engine.state.metrics[self._metric]
        # The step comes from the trainer, not from ``engine``, so reports are
        # indexed by training epoch even when ``engine`` is an evaluator.
        step = self._trainer.state.epoch
        self._trial.report(value, step)

        if not self._trial.should_prune():
            return

        raise optuna.TrialPruned("Trial was pruned at {} epoch.".format(step))