1""" 2Gradually warm-up(increasing) learning rate for pytorch's optimizer. 3Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'. 4Code adapted from https://github.com/ildoonet/pytorch-gradual-warmup-lr 5""" 6# pylint: disable=missing-function-docstring, line-too-long, inconsistent-return-statements 7from torch.optim.lr_scheduler import _LRScheduler 8from torch.optim.lr_scheduler import ReduceLROnPlateau 9 10 11class GradualWarmupScheduler(_LRScheduler): 12 """ Gradually warm-up(increasing) learning rate in optimizer. 13 Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'. 14 15 Args: 16 optimizer (Optimizer): Wrapped optimizer. 17 multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr. 18 total_epoch: target learning rate is reached at total_epoch, gradually 19 after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau) 20 """ 21 22 def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None): 23 self.multiplier = multiplier 24 if self.multiplier < 1.: 25 raise ValueError('multiplier should be greater thant or equal to 1.') 26 self.total_epoch = total_epoch 27 self.after_scheduler = after_scheduler 28 self.finished = False 29 super(GradualWarmupScheduler, self).__init__(optimizer) 30 31 def get_lr(self): 32 if self.last_epoch > self.total_epoch: 33 if self.after_scheduler: 34 if not self.finished: 35 self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs] 36 self.finished = True 37 return self.after_scheduler.get_last_lr() 38 return [base_lr * self.multiplier for base_lr in self.base_lrs] 39 40 if self.multiplier == 1.0: 41 return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs] 42 else: 43 return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs] 44 45 def step_ReduceLROnPlateau(self, metrics, epoch=None): 46 # pylint: disable=access-member-before-definition 47 if epoch is None: 48 epoch = self.last_epoch + 1 49 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning 50 self.last_epoch = epoch if epoch != 0 else 1 51 if self.last_epoch <= self.total_epoch: 52 warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs] 53 for param_group, lr in zip(self.optimizer.param_groups, warmup_lr): 54 param_group['lr'] = lr 55 else: 56 if epoch is None: 57 self.after_scheduler.step(metrics, None) 58 else: 59 self.after_scheduler.step(metrics, epoch - self.total_epoch) 60 61 def step(self, epoch=None, metrics=None): 62 if not isinstance(self.after_scheduler, ReduceLROnPlateau): 63 if self.finished and self.after_scheduler: 64 if epoch is None: 65 self.after_scheduler.step(None) 66 else: 67 self.after_scheduler.step(epoch - self.total_epoch) 68 self._last_lr = self.after_scheduler.get_last_lr() 69 else: 70 return super(GradualWarmupScheduler, self).step(epoch) 71 else: 72 self.step_ReduceLROnPlateau(metrics, epoch) 73