from __future__ import division

from chainer.training import extension


class MultistepShift(extension.Extension):

    """Trainer extension to shift an optimizer attribute in several steps.

    This extension changes an optimizer attribute in several steps: at the
    beginning of each step the attribute is multiplied by a factor
    ``gamma``.

    For example, suppose that this extension is called at every iteration,
    and ``init = x``, ``gamma = y``, ``step_value = (s1, s2, s3)``.
    Then during the iterations from 0 to (s1 - 1), the attr will be ``x``.
    During the iterations from s1 to (s2 - 1), the attr will be ``x * y``.
    During the iterations from s2 to (s3 - 1), the attr will be ``x * y * y``.
    During the iterations after s3, the attr will be ``x * y * y * y``.

    This extension is also called before the training loop starts by default.

    Args:
        attr (str): Name of the attribute to shift.
        gamma (float): The factor by which the attr is multiplied at the
            beginning of each step.
        step_value (tuple): The first iterations of each step.
        init (float): Initial value of the attribute. If it is ``None``, the
            extension extracts the attribute at the first call and uses it as
            the initial value.
        optimizer (~chainer.Optimizer): Target optimizer to adjust the
            attribute. If it is ``None``, the main optimizer of the updater is
            used.

    """

    def __init__(self, attr, gamma, step_value, init, optimizer=None):
        self._attr = attr
        self._gamma = gamma
        self._step_value = step_value
        self._init = init
        self._optimizer = optimizer
        # Number of scheduled boundaries; _current_step indexes into
        # _step_value and also serves as the exponent applied to gamma.
        self._stepvalue_size = len(step_value)
        self._current_step = 0
        # Iteration counter, incremented once per __call__ invocation.
        self._t = 0

    def initialize(self, trainer):
        """Initialize the attribute before training starts.

        If ``init`` was ``None``, the current value of the optimizer
        attribute is captured as the initial value; otherwise the attribute
        is overwritten with ``init``.
        """
        optimizer = self._optimizer or trainer.updater.get_optimizer('main')
        if self._init is None:
            self._init = getattr(optimizer, self._attr)
        else:
            setattr(optimizer, self._attr, self._init)

    def __call__(self, trainer):
        """Update the optimizer attribute for the current iteration."""
        self._t += 1
        optimizer = self._optimizer or trainer.updater.get_optimizer('main')
        # Advance to the next step once the iteration count reaches the
        # next scheduled boundary; at most one boundary is crossed per call,
        # so this extension is expected to be invoked every iteration.
        if (self._current_step < self._stepvalue_size and
                self._t >= self._step_value[self._current_step]):
            self._current_step += 1
        value = self._init * pow(self._gamma, self._current_step)
        setattr(optimizer, self._attr, value)

    def serialize(self, serializer):
        """Serialize the iteration counter and current step for resuming."""
        self._t = serializer('_t', self._t)
        self._current_step = serializer('_current_step', self._current_step)