# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=line-too-long
"""Parameter optimizer."""
__all__ = ['Trainer']

from .. import optimizer as opt
from ..model import _create_kvstore, _create_sparse_kvstore
from .parameter import ParameterDict, Parameter
from ..kvstore import KVStore

class Trainer(object):
    """Applies an `Optimizer` on a set of Parameters. Trainer should
    be used together with `autograd`.

    .. note::

        For the following cases, updates will always happen on kvstore,
        i.e., you cannot set update_on_kvstore=False.

        - dist kvstore with sparse weights or sparse gradients
        - dist async kvstore
        - `optimizer.lr_scheduler` is not None

    Parameters
    ----------
    params : ParameterDict
        The set of parameters to optimize.
    optimizer : str or Optimizer
        The optimizer to use. See
        `help <https://mxnet.apache.org/api/python/docs/api/optimizer/index.html#mxnet.optimizer.Optimizer>`_
        on Optimizer for a list of available optimizers.
    optimizer_params : dict
        Keyword arguments to be passed to the optimizer constructor. For example,
        `{'learning_rate': 0.1}`. All optimizers accept learning_rate, wd (weight decay),
        clip_gradient, and lr_scheduler. See each optimizer's
        constructor for a list of additional supported arguments.
    kvstore : str or KVStore
        KVStore type for multi-GPU and distributed training. See help on
        :any:`mxnet.kvstore.create` for more information.
    compression_params : dict
        Specifies the type of gradient compression and additional arguments depending
        on the type of compression being used. For example, 2bit compression requires a threshold.
        Arguments would then be `{'type':'2bit', 'threshold':0.5}`.
        See the mxnet.KVStore.set_gradient_compression method for more details on gradient compression.
    update_on_kvstore : bool, default None
        Whether to perform parameter updates on kvstore. If None, then trainer will choose the more
        suitable option depending on the type of kvstore. If the `update_on_kvstore` argument is
        provided, environment variable `MXNET_UPDATE_ON_KVSTORE` will be ignored.

    Properties
    ----------
    learning_rate : float
        The current learning rate of the optimizer. Given an Optimizer object
        optimizer, its learning rate can be accessed as optimizer.learning_rate.
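
    Examples
    --------
    A minimal usage sketch, assuming a toy ``Dense`` block and synthetic random
    data; adapt the block, loss, and data pipeline to your own model::

        import mxnet as mx
        from mxnet import autograd, gluon

        net = gluon.nn.Dense(1)
        net.initialize()
        trainer = gluon.Trainer(net.collect_params(), 'sgd',
                                {'learning_rate': 0.1})
        x = mx.nd.random.uniform(shape=(8, 4))
        y = mx.nd.random.uniform(shape=(8, 1))
        with autograd.record():
            loss = ((net(x) - y) ** 2).mean()
        loss.backward()
        trainer.step(batch_size=x.shape[0])   # gradients are normalized by 1/batch_size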
    """
    def __init__(self, params, optimizer, optimizer_params=None, kvstore='device',
                 compression_params=None, update_on_kvstore=None):
        param_list = []
        if isinstance(params, (dict, ParameterDict)):
            for key in sorted(list(params.keys())):
                param_list.append(params[key])
            params = param_list
        if not isinstance(params, (list, tuple)):
            raise ValueError(
                "First argument must be a list or dict of Parameters, " \
                "got %s."%(type(params)))
        self._params = []
        # parameters to initialize on the kvstore
        self._contains_sparse_weight = False
        self._contains_sparse_grad = False
        self._param2idx = {}
        for i, param in enumerate(params):
            if not isinstance(param, Parameter):
                raise ValueError(
                    "First argument must be a list or dict of Parameters, " \
                    "got list of %s."%(type(param)))
            self._param2idx[param.name] = i
            self._params.append(param)
            param._set_trainer(self)
            if param._stype != 'default':
                self._contains_sparse_weight = True
            if param._grad_stype != 'default':
                self._contains_sparse_grad = True
        self._compression_params = compression_params
        self._contexts = self._check_contexts()
        optimizer_params = optimizer_params if optimizer_params else {}
        self._init_optimizer(optimizer, optimizer_params)
        self._scale = self._optimizer.rescale_grad
        self._kvstore_params = {'kvstore': kvstore, 'update_on_kvstore': update_on_kvstore}
        self._kv_initialized = False
        self._kvstore = None
        self._update_on_kvstore = None
        self._distributed = None
        self._params_to_init = []
        self._reset_kvstore()

    def _check_contexts(self):
        contexts = None
        for param in self._params:
            ctx = param.list_ctx()
            assert contexts is None or contexts == ctx, \
                "All Parameters must be initialized on the same set of contexts, " \
                "but Parameter %s is initialized on %s while previous Parameters " \
                "are initialized on %s."%(param.name, str(ctx), str(contexts))
            contexts = ctx
        return contexts

    def _init_optimizer(self, optimizer, optimizer_params):
        param_dict = {i: param for i, param in enumerate(self._params)}
        if isinstance(optimizer, opt.Optimizer):
            assert not optimizer_params, \
                "optimizer_params must be None if optimizer is an instance of " \
                "Optimizer instead of str"
            self._optimizer = optimizer
            # param_dict must not be deep copied, so that if the user mutates the lr_mult
            # or wd_mult of some parameters, the change takes effect.
            self._optimizer.param_dict = param_dict
        else:
            self._optimizer = opt.create(optimizer, param_dict=param_dict,
                                         **optimizer_params)
        self._updaters = [opt.get_updater(self._optimizer) \
                            for _ in self._contexts]

    def _init_params(self):
        """Initialize parameters in the KVStore.

        Parameters with incomplete initialization are ignored.

        """
        assert self._kv_initialized, "Cannot initialize parameters in KVStore " \
                                     "when KVStore is not initialized."
        params_to_init = []
        if self._kvstore:
            for param in self._params_to_init:
                if param._deferred_init:
                    params_to_init.append(param)
                else:
                    param_arrays = param._check_and_get(param._data, list)
                    idx = self._param2idx[param.name]
                    if param._stype != 'default':
                        self._kvstore.init(idx, param_arrays[0])
                    else:
                        self._kvstore.broadcast(idx, param_arrays[0], param_arrays)

        self._params_to_init = params_to_init

    def _reset_kvstore(self):
        """Reset kvstore."""
        if self._kvstore and 'dist' in self._kvstore.type:
            raise RuntimeError("Cannot reset distributed KVStore.")
        self._kv_initialized = False
        self._kvstore = None
        self._distributed = None
        self._update_on_kvstore = None
        self._params_to_init = [param for param in self._params]

    def _init_kvstore(self):
        """Create kvstore."""
        config = self._kvstore_params
        # configure kvstore, update_on_kvstore and self._distributed in three cases:
        if self._contains_sparse_weight:
            # If weight is sparse, kvstore must be present and the weight must be updated on kvstore.
            # The training loop is the following:
            #    - row_sparse_pull(sparse_weight)
            #    - forward()
            #    - backward()
            #    - push_and_update(grad)
            #    - pull(weight)
            kvstore, update_on_kvstore = _create_sparse_kvstore(config['kvstore'])
            self._distributed = 'dist' in kvstore.type
            # raise err if user provides unsupported configs
            if config['update_on_kvstore'] is False:
                raise ValueError("Cannot set update_on_kvstore=False when sparse weights "
                                 "are present.")

        elif self._contains_sparse_grad:
            # For single node training with dense weight and sparse grad,
            # we prefer update_on_kvstore=False because this is usually faster.
            # This means we push and pull sparse gradients, and we do not store weight in kvstore.
            # The training loop is the following:
            #    - forward()
            #    - backward()
            #    - push(grad)
            #    - pull(grad)
            #    - update(grad, weight)
            #
            # For multi-node training with dense weight and sparse grad,
            # only update_on_kvstore=True is supported, due to the fact that
            # kv.row_sparse_pull(grad) is not implemented.
            # Therefore, we push sparse gradients and pull dense weights.
            # The training loop contains:
            #    - forward()
            #    - backward()
            #    - push_and_update(grad)
            #    - pull(weight)
            arg_arrays = {param.name: param.data(self._contexts[0]) for param in self._params}
            kvstore, _ = _create_kvstore(config['kvstore'], len(self._contexts), arg_arrays)
            self._distributed = 'dist' in kvstore.type if kvstore else False
            update_on_kvstore = self._distributed
            # raise err if user provides unsupported configs
            if config['update_on_kvstore'] is not None:
                if config['update_on_kvstore'] is False and self._distributed:
                    raise ValueError("Cannot set update_on_kvstore=False on dist kvstore "
                                     "when sparse gradients are present.")
                update_on_kvstore = config['update_on_kvstore']
            # raise err if a custom kvstore is used for sparse training
            if kvstore is not None and not isinstance(kvstore, KVStore):
                raise ValueError("Cannot use {} for multi-device training with sparse gradients"
                                 .format(type(kvstore)))

        else:
            # Training with dense weight and dense gradients.
            # The only unsupported mode is async with update_on_kvstore=False
            arg_arrays = {param.name: param.data(self._contexts[0]) for param in self._params}
            kvstore, update_on_kvstore = _create_kvstore(config['kvstore'], len(self._contexts),
                                                         arg_arrays)
            self._distributed = 'dist' in kvstore.type if kvstore else False
            if self._distributed and 'async' in kvstore.type:
                update_on_kvstore = True
                # raise err if user provides unsupported configs
                if config['update_on_kvstore'] is False:
                    raise ValueError("Please set update_on_kvstore=True "
                                     "when training in async mode.")
            if config['update_on_kvstore'] is not None:
                update_on_kvstore = config['update_on_kvstore']
            # raise err if update_on_kvstore is set to True with kvstores that do not support optimizers
            if update_on_kvstore and not kvstore.is_capable('optimizer'):
                if config['update_on_kvstore']:
                    raise ValueError("Please set update_on_kvstore=False "
                                     "when training with {}".format(type(kvstore)))
                update_on_kvstore = False

        # set grad compression and optimizers
        if kvstore:
            if self._compression_params:
                kvstore.set_gradient_compression(self._compression_params)
            if update_on_kvstore:
                # optimizer preferably needs to be set before init for multiprecision
                kvstore.set_optimizer(self._optimizer)
            self._kvstore = kvstore
            self._update_on_kvstore = update_on_kvstore
        else:
            self._kvstore = None
            self._update_on_kvstore = None

        self._kv_initialized = True

    @property
    def learning_rate(self):
        if not isinstance(self._optimizer, opt.Optimizer):
            raise UserWarning("Optimizer has to be defined before its learning "
                              "rate can be accessed.")

        return self._optimizer.learning_rate

    @property
    def optimizer(self):
        if isinstance(self._optimizer, opt.Optimizer):
            return self._optimizer
        else:
            raise UserWarning("Optimizer has not been initialized yet")

    def set_learning_rate(self, lr):
        """Sets a new learning rate of the optimizer.

        Parameters
        ----------
        lr : float
            The new learning rate of the optimizer.
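
        A minimal sketch of manual learning-rate decay (assumes an existing
        ``trainer`` and an ``epoch`` counter from your own training loop)::

            if epoch % 10 == 0:
                trainer.set_learning_rate(trainer.learning_rate * 0.5)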
        """
        if not isinstance(self._optimizer, opt.Optimizer):
            raise UserWarning("Optimizer has to be defined before its learning "
                              "rate is mutated.")

        self._optimizer.set_learning_rate(lr)

    def _row_sparse_pull(self, parameter, out, row_id, full_idx=False):
        """Internal method to invoke pull operations on KVStore. If `full_idx` is set to True,
        `kv.pull` is preferred over `kv.row_sparse_pull`.
        """
        # initialize kv and params if not already
        if not self._kv_initialized:
            self._init_kvstore()
        if self._params_to_init:
            self._init_params()
        idx = self._param2idx[parameter.name]
        if full_idx and 'dist' not in self._kvstore.type:
            assert row_id.size == out.shape[0]
            self._kvstore.pull(idx, out=out, priority=-idx, ignore_sparse=False)
        else:
            self._kvstore.row_sparse_pull(idx, out=out, row_ids=row_id, priority=-idx)

    def _check_and_rescale_grad(self, scale):
        if self._update_on_kvstore and self._distributed and self._kv_initialized:
            if self._optimizer.rescale_grad != scale:
                raise UserWarning('Possible change in the `batch_size` from previous '
                                  '`step` detected. Optimizer gradient normalizing '
                                  'factor will not change w.r.t new batch_size when '
                                  'update_on_kvstore=True and when distributed kvstore '
                                  'is used.')
        self._optimizer.rescale_grad = scale

    def step(self, batch_size, ignore_stale_grad=False):
        """Makes one step of parameter update. Should be called after
        `autograd.backward()` and outside of `record()` scope.

        For normal parameter updates, `step()` should be used, which internally calls
        `allreduce_grads()` and then `update()`. However, if you need to get the reduced
        gradients to perform certain transformations, such as gradient clipping, then
        you may want to manually call `allreduce_grads()` and `update()` separately.

        Parameters
        ----------
        batch_size : int
            Batch size of data processed. Gradient will be normalized by `1/batch_size`.
            Set this to 1 if you normalized loss manually with `loss = mean(loss)`.
        ignore_stale_grad : bool, optional, default=False
            If true, ignores Parameters with a stale gradient (a gradient that has not
            been updated by `backward` since the last `step`) and skips updating them.
        """
        rescale_grad = self._scale / batch_size
        self._check_and_rescale_grad(rescale_grad)

        if not self._kv_initialized:
            self._init_kvstore()
        if self._params_to_init:
            self._init_params()

        self._allreduce_grads()
        self._update(ignore_stale_grad)

    def allreduce_grads(self):
        """For each parameter, reduce the gradients from different contexts.

        Should be called after `autograd.backward()`, outside of `record()` scope,
        and before `trainer.update()`.

        For normal parameter updates, `step()` should be used, which internally calls
        `allreduce_grads()` and then `update()`. However, if you need to get the reduced
        gradients to perform certain transformations, such as gradient clipping, then
        you may want to manually call `allreduce_grads()` and `update()` separately.
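
        A sketch of this manual workflow with global-norm gradient clipping in
        between (assumes a single default context, ``update_on_kvstore=False``,
        and existing ``trainer``, ``net``, ``x`` and ``y`` objects from a loop
        like the one in the class-level example)::

            with autograd.record():
                loss = ((net(x) - y) ** 2).mean()
            loss.backward()
            trainer.allreduce_grads()
            grads = [p.grad() for p in net.collect_params().values()
                     if p.grad_req != 'null']
            gluon.utils.clip_global_norm(grads, max_norm=1.0)
            trainer.update(batch_size=x.shape[0])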
        """
        if not self._kv_initialized:
            self._init_kvstore()
        if self._params_to_init:
            self._init_params()
        assert not (self._kvstore and self._update_on_kvstore), \
                'allreduce_grads() when parameters are updated on kvstore ' \
                'is not supported. Try setting `update_on_kvstore` ' \
                'to False when creating trainer.'

        self._allreduce_grads()

    def _allreduce_grads(self):
        # nothing to reduce
        if not self._kvstore:
            return
        for i, param in enumerate(self._params):
            if param.grad_req != 'null':
                grad_list = param.list_grad()
                # sparse gradients, call push and pull separately
                if grad_list[0].stype != 'default':
                    self._kvstore.push(i, grad_list, priority=-i)
                    if param._stype == 'default':
                        if self._update_on_kvstore:
                            pull_list = param.list_data()
                        else:
                            pull_list = param.list_grad()
                        self._kvstore.pull(i, pull_list, priority=-i,
                                           ignore_sparse=self._distributed)
                else:
                    # allreduce dense gradients if not update_on_kvstore,
                    # otherwise push dense gradients, pull dense weights
                    if self._update_on_kvstore:
                        self._kvstore.pushpull(i, grad_list, out=param.list_data(), priority=-i)
                    else:
                        self._kvstore.pushpull(i, grad_list, priority=-i)

    def update(self, batch_size, ignore_stale_grad=False):
        """Makes one step of parameter update.

        Should be called after `autograd.backward()` and outside of `record()` scope,
        and after `trainer.allreduce_grads()`.

        For normal parameter updates, `step()` should be used, which internally calls
        `allreduce_grads()` and then `update()`. However, if you need to get the reduced
        gradients to perform certain transformations, such as gradient clipping, then
        you may want to manually call `allreduce_grads()` and `update()` separately.

        Parameters
        ----------
        batch_size : int
            Batch size of data processed. Gradient will be normalized by `1/batch_size`.
            Set this to 1 if you normalized loss manually with `loss = mean(loss)`.
        ignore_stale_grad : bool, optional, default=False
            If true, ignores Parameters with a stale gradient (a gradient that has not
            been updated by `backward` since the last `step`) and skips updating them.
        """
        if not self._kv_initialized:
            self._init_kvstore()
        if self._params_to_init:
            self._init_params()
        assert not (self._kvstore and self._update_on_kvstore), \
                'update() when parameters are updated on kvstore ' \
                'is not supported. Try setting `update_on_kvstore` ' \
                'to False when creating trainer.'

        self._check_and_rescale_grad(self._scale / batch_size)
        self._update(ignore_stale_grad)

    def _update(self, ignore_stale_grad=False):
        loss_scaler = getattr(self, '_amp_loss_scaler', None)
        if loss_scaler is not None:
            if loss_scaler.has_overflow(self._params):
                return  # skip on overflow

        updates = [[] for _ in self._updaters]

        for i, param in enumerate(self._params):
            if param.grad_req == 'null':
                continue

            if not ignore_stale_grad:
                for data in param._check_and_get(param._data, list):
                    if not data._fresh_grad:
                        raise UserWarning(
                            "Gradient of Parameter `%s` on context %s has not been updated "
                            "by backward since last `step`. This could mean a bug in your "
                            "model that made it only use a subset of the Parameters (Blocks) "
                            "for this iteration. If you are intentionally only using a subset, "
                            "call step with ignore_stale_grad=True to suppress this "
                            "warning and skip updating of Parameters with stale gradient" \
                            %(param.name, str(data.context)))

            if self._kvstore and self._update_on_kvstore:
                continue

            for upd, arr, grad in zip(updates, param.list_data(), param.list_grad()):
                if not ignore_stale_grad or arr._fresh_grad:
                    upd.append((i, grad, arr))
                    arr._fresh_grad = False

        if not (self._kvstore and self._update_on_kvstore):
            for updater, upd in zip(self._updaters, updates):
                if upd:
                    i, w, g = zip(*upd)
                    updater(i, w, g)

    def save_states(self, fname):
        """Saves trainer states (e.g. optimizer, momentum) to a file.

        Parameters
        ----------
        fname : str
            Path to output states file.

        Note
        ----
        `optimizer.param_dict`, which contains Parameter information (such as
        `lr_mult` and `wd_mult`) will not be saved.
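
        A short checkpointing sketch (assumes an existing ``trainer``; the file
        name is illustrative)::

            trainer.save_states('trainer.states')
            # ... later, after building an identical Trainer:
            trainer.load_states('trainer.states')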
        """
        assert self._optimizer is not None

        if not self._kv_initialized:
            self._init_kvstore()
        if self._params_to_init:
            self._init_params()

        if self._update_on_kvstore:
            assert not self._params_to_init, "Cannot save trainer states when some " \
                                             "parameters are not yet initialized in kvstore."
            self._kvstore.save_optimizer_states(fname, dump_optimizer=True)
        else:
            with open(fname, 'wb') as fout:
                fout.write(self._updaters[0].get_states(dump_optimizer=True))

    def load_states(self, fname):
        """Loads trainer states (e.g. optimizer, momentum) from a file.

        Parameters
        ----------
        fname : str
            Path to input states file.

        Note
        ----
        `optimizer.param_dict`, which contains Parameter information (such as
        `lr_mult` and `wd_mult`) will not be loaded from the file, but rather set
        based on the current Trainer's parameters.
        """
        if not self._kv_initialized:
            self._init_kvstore()
        if self._params_to_init:
            self._init_params()

        if self._update_on_kvstore:
            self._kvstore.load_optimizer_states(fname)
            self._optimizer = self._kvstore._updater.optimizer
        else:
            with open(fname, 'rb') as f:
                states = f.read()
            for updater in self._updaters:
                updater.set_states(states)
                updater.optimizer = self._updaters[0].optimizer
            self._optimizer = self._updaters[0].optimizer
        param_dict = {i: param for i, param in enumerate(self._params)}
        self._optimizer.param_dict = param_dict