1import warnings
2
3import optuna
4
5
6with optuna._imports.try_import() as _imports:
7    from pytorch_lightning import LightningModule
8    from pytorch_lightning import Trainer
9    from pytorch_lightning.callbacks import Callback
10
11if not _imports.is_successful():
12    Callback = object  # type: ignore # NOQA
13    LightningModule = object  # type: ignore # NOQA
14    Trainer = object  # type: ignore # NOQA
15
16
class PyTorchLightningPruningCallback(Callback):
    """PyTorch Lightning callback to prune unpromising trials.

    See `the example <https://github.com/optuna/optuna-examples/blob/
    main/pytorch/pytorch_lightning_simple.py>`__
    if you want to add a pruning callback which observes accuracy.

    Args:
        trial:
            A :class:`~optuna.trial.Trial` corresponding to the current evaluation of the
            objective function.
        monitor:
            An evaluation metric for pruning, e.g., ``val_loss`` or
            ``val_acc``. The metric is looked up by name in ``trainer.callback_metrics``
            at the end of each validation run, so the name depends on how the metric was
            logged or returned by the ``LightningModule`` (e.g. from
            ``pytorch_lightning.LightningModule.training_step`` or
            ``pytorch_lightning.LightningModule.validation_epoch_end``).
    """

    def __init__(self, trial: optuna.trial.Trial, monitor: str) -> None:
        # Fail early with a helpful message if pytorch_lightning is not installed.
        _imports.check()
        super().__init__()

        self._trial = trial
        self.monitor = monitor

    def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Report the monitored metric to the trial and prune it if requested.

        Args:
            trainer:
                The running ``Trainer``; its ``callback_metrics`` supply the monitored value.
            pl_module:
                The ``LightningModule`` being trained; its ``current_epoch`` is used as the
                reporting step.

        Raises:
            optuna.TrialPruned:
                If the trial's pruner decides the trial should be stopped.
        """
        # Lightning runs a short "sanity check" validation loop before training starts and
        # fires this hook for it as well. Reporting that partial result would occupy step 0,
        # and Optuna ignores later reports for an already-reported step, so the real
        # epoch-0 value would be dropped. The flag's name differs between Lightning
        # versions, hence the getattr fallbacks.
        if getattr(trainer, "sanity_checking", False) or getattr(
            trainer, "running_sanity_check", False
        ):
            return

        epoch = pl_module.current_epoch

        current_score = trainer.callback_metrics.get(self.monitor)
        if current_score is None:
            message = (
                "The metric '{}' is not in the evaluation logs for pruning. "
                "Please make sure you set the correct metric name.".format(self.monitor)
            )
            warnings.warn(message)
            return

        # Values in ``callback_metrics`` may be framework tensors; report a plain float so
        # the trial stores a Python number rather than a tensor object.
        self._trial.report(float(current_score), step=epoch)
        if self._trial.should_prune():
            message = "Trial was pruned at epoch {}.".format(epoch)
            raise optuna.TrialPruned(message)
59