1"""
2Gradually warm-up(increasing) learning rate for pytorch's optimizer.
3Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
4Code adapted from https://github.com/ildoonet/pytorch-gradual-warmup-lr
5"""
6# pylint: disable=missing-function-docstring, line-too-long, inconsistent-return-statements
7from torch.optim.lr_scheduler import _LRScheduler
8from torch.optim.lr_scheduler import ReduceLROnPlateau
9
10
11class GradualWarmupScheduler(_LRScheduler):
    """Gradually warm up (increase) the learning rate in the wrapped optimizer.

    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
    See the usage sketch at the bottom of this module.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier (float): If multiplier > 1.0, the learning rate is ramped from base_lr
            up to base_lr * multiplier. If multiplier == 1.0, it is ramped from 0 up to base_lr.
        total_epoch (int): Number of epochs over which the target learning rate is reached.
        after_scheduler: Scheduler to switch to after total_epoch (e.g. ReduceLROnPlateau).
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier < 1.:
            raise ValueError('multiplier should be greater than or equal to 1.')
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        self.finished = False
        super().__init__(optimizer)

    def get_lr(self):
        if self.last_epoch > self.total_epoch:
            # Warmup is over: defer to after_scheduler if one was given.
            if self.after_scheduler:
                if not self.finished:
                    # Rescale the follow-up scheduler's base LRs to the warmed-up target once.
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_last_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        if self.multiplier == 1.0:
            # Linear warmup from 0 to base_lr.
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        # Linear warmup from base_lr to base_lr * multiplier.
        return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        # pylint: disable=access-member-before-definition
        if epoch is None:
            epoch = self.last_epoch + 1
        # ReduceLROnPlateau is stepped at the end of an epoch, whereas other schedulers are stepped at the beginning.
        self.last_epoch = epoch if epoch != 0 else 1
        if self.last_epoch <= self.total_epoch:
            # Still warming up: set the learning rate directly on the optimizer.
            warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                param_group['lr'] = lr
        else:
            # epoch is never None here (it was filled in above), so step the
            # follow-up scheduler relative to the end of the warmup phase.
            self.after_scheduler.step(metrics, epoch - self.total_epoch)

    def step(self, epoch=None, metrics=None):
        if not isinstance(self.after_scheduler, ReduceLROnPlateau):
            if self.finished and self.after_scheduler:
                # Warmup finished: delegate to the follow-up scheduler.
                if epoch is None:
                    self.after_scheduler.step(None)
                else:
                    self.after_scheduler.step(epoch - self.total_epoch)
                self._last_lr = self.after_scheduler.get_last_lr()
            else:
                return super().step(epoch)
        else:
            # ReduceLROnPlateau needs the validation metric, so it gets its own path.
            self.step_ReduceLROnPlateau(metrics, epoch)
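

if __name__ == "__main__":
    # Minimal usage sketch (not part of the upstream library): the model, learning
    # rate, and scheduler settings below are assumed values chosen for illustration.
    # It warms up from 0 to 0.1 over 5 epochs (multiplier=1.0), then hands control
    # to a StepLR schedule.
    import torch
    from torch import nn
    from torch.optim import SGD
    from torch.optim.lr_scheduler import StepLR

    model = nn.Linear(2, 1)
    optimizer = SGD(model.parameters(), lr=0.1)

    scheduler_steplr = StepLR(optimizer, step_size=10, gamma=0.1)
    scheduler_warmup = GradualWarmupScheduler(
        optimizer, multiplier=1.0, total_epoch=5, after_scheduler=scheduler_steplr)

    # A dummy optimizer step avoids PyTorch's "scheduler.step() called before
    # optimizer.step()" warning in this toy loop.
    optimizer.zero_grad()
    optimizer.step()

    for epoch in range(1, 20):
        # Passing the epoch explicitly mirrors the upstream README; newer PyTorch
        # versions may emit a deprecation warning for the epoch argument.
        scheduler_warmup.step(epoch)
        print(epoch, optimizer.param_groups[0]['lr'])
        optimizer.step()  # stand-in for one epoch of training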