from __future__ import division

from chainer.training import extension


class MultistepShift(extension.Extension):

    """Trainer extension to shift an optimizer attribute in several steps.

    This extension changes an optimizer attribute in several steps; at the
    beginning of each step, the attribute is multiplied by the factor
    ``gamma``.

    For example, suppose that this extension is called at every iteration,
    and ``init = x``, ``gamma = y``, ``step_value = (s1, s2, s3)``.
    Then during the iterations from 0 to (s1 - 1), the attribute will be
    ``x``. During the iterations from s1 to (s2 - 1), it will be ``x * y``.
    During the iterations from s2 to (s3 - 1), it will be ``x * y * y``.
    During the iterations from s3 onwards, it will be ``x * y * y * y``.
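
    Concretely, if ``init = 1.0``, ``gamma = 0.5`` and
    ``step_value = (10, 20, 30)``, the attribute is ``1.0`` during the
    iterations from 0 to 9, ``0.5`` from 10 to 19, ``0.25`` from 20 to 29,
    and ``0.125`` from iteration 30 onwards.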

    This extension is also called before the training loop starts by default.

    Args:
        attr (str): Name of the attribute to shift.
        init (float): Initial value of the attribute. If it is ``None``, the
            extension extracts the attribute at the first call and uses it as
            the initial value.
        gamma (float): The factor by which the attribute is multiplied at the
            beginning of each step.
        step_value (tuple): The iteration numbers at which each new step
            starts.
        optimizer (~chainer.Optimizer): Target optimizer to adjust the
            attribute. If it is ``None``, the main optimizer of the updater is
            used.
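
    A minimal usage sketch (the attribute name ``'lr'`` and the step
    boundaries below are illustrative; ``trainer`` is assumed to be a
    :class:`~chainer.training.Trainer` whose optimizer exposes an ``lr``
    attribute)::

        # Illustrative attribute name and schedule values.
        trainer.extend(
            MultistepShift('lr', gamma=0.1,
                           step_value=(1000, 2000, 3000), init=0.01))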

    """

    def __init__(self, attr, gamma, step_value, init=None, optimizer=None):
        self._attr = attr
        self._gamma = gamma
        self._step_value = step_value
        self._init = init
        self._optimizer = optimizer
        self._stepvalue_size = len(step_value)
        self._current_step = 0
        self._t = 0

    def initialize(self, trainer):
        optimizer = self._optimizer or trainer.updater.get_optimizer('main')
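        # If no initial value was given, adopt the optimizer's current value
        # as the initial one; otherwise overwrite the optimizer's attribute.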
        if self._init is None:
            self._init = getattr(optimizer, self._attr)
        else:
            setattr(optimizer, self._attr, self._init)

    def __call__(self, trainer):
        self._t += 1
        optimizer = self._optimizer or trainer.updater.get_optimizer('main')
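        # Advance to the next step once the iteration count reaches the next
        # boundary in ``step_value``.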
        if (self._current_step < self._stepvalue_size and
                self._t >= self._step_value[self._current_step]):
            self._current_step += 1
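        # The attribute equals ``init`` multiplied by ``gamma`` once per
        # completed step.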
        value = self._init * pow(self._gamma, self._current_step)
        setattr(optimizer, self._attr, value)

    def serialize(self, serializer):
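        # Preserve the iteration count and the current step so that resuming
        # from a snapshot continues the schedule correctly.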
        self._t = serializer('_t', self._t)
        self._current_step = serializer('_current_step', self._current_step)