import warnings

import optuna


with optuna._imports.try_import() as _imports:
    from pytorch_lightning import LightningModule
    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import Callback

if not _imports.is_successful():
    # Placeholder bases so this module can still be imported (and the class below
    # defined) when PyTorch Lightning is missing; ``_imports.check()`` in
    # ``PyTorchLightningPruningCallback.__init__`` is what surfaces the import failure.
    Callback = object  # type: ignore # NOQA
    LightningModule = object  # type: ignore # NOQA
    Trainer = object  # type: ignore # NOQA


def _is_sanity_checking(trainer: Trainer) -> bool:
    """Return whether ``trainer`` is currently running the validation sanity check.

    Newer PyTorch Lightning releases expose ``Trainer.sanity_checking``; older releases
    used ``Trainer.running_sanity_check``. If neither attribute exists, ``False`` is
    returned so the callback behaves as if no sanity check were running.
    """
    sanity_checking = getattr(trainer, "sanity_checking", None)
    if sanity_checking is None:
        sanity_checking = getattr(trainer, "running_sanity_check", False)
    return bool(sanity_checking)


class PyTorchLightningPruningCallback(Callback):
    """PyTorch Lightning callback to prune unpromising trials.

    See `the example <https://github.com/optuna/optuna-examples/blob/
    main/pytorch/pytorch_lightning_simple.py>`__
    if you want to add a pruning callback which observes accuracy.

    Args:
        trial:
            A :class:`~optuna.trial.Trial` corresponding to the current evaluation of the
            objective function.
        monitor:
            An evaluation metric for pruning, e.g., ``val_loss`` or
            ``val_acc``. It is looked up by name in ``trainer.callback_metrics``, i.e. among
            the metrics produced by e.g.
            ``pytorch_lightning.LightningModule.training_step`` or
            ``pytorch_lightning.LightningModule.validation_epoch_end``, so the name depends on
            how those metrics are named there.
    """

    def __init__(self, trial: optuna.trial.Trial, monitor: str) -> None:
        _imports.check()
        super().__init__()

        self._trial = trial
        self.monitor = monitor

    def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Report the monitored metric for the current epoch and prune if requested.

        If the monitored metric is absent from ``trainer.callback_metrics``, a warning is
        emitted and nothing is reported.

        Args:
            trainer:
                The Lightning trainer whose ``callback_metrics`` hold the monitored value.
            pl_module:
                The module being trained; its ``current_epoch`` is used as the report step.

        Raises:
            optuna.TrialPruned:
                If the trial should be pruned after reporting the current score.
        """
        # Lightning runs a few validation batches as a "sanity check" before training and
        # invokes this hook for it as well, at epoch 0. Reporting there would record a
        # score from an incomplete validation pass for step 0, and the genuine epoch-0
        # report would then hit an already-reported step. See
        # https://github.com/PyTorchLightning/pytorch-lightning/issues/1391.
        if _is_sanity_checking(trainer):
            return

        epoch = pl_module.current_epoch

        current_score = trainer.callback_metrics.get(self.monitor)
        if current_score is None:
            message = (
                "The metric '{}' is not in the evaluation logs for pruning. "
                "Please make sure you set the correct metric name.".format(self.monitor)
            )
            warnings.warn(message)
            return

        self._trial.report(current_score, step=epoch)
        if self._trial.should_prune():
            message = "Trial was pruned at epoch {}.".format(epoch)
            raise optuna.TrialPruned(message)