from __future__ import absolute_import
import collections
import heapq
import warnings

import six

import chainer
from chainer import _backprop_utils
from chainer import backend
from chainer.utils import argument
import chainerx


def backward(outputs, grad_outputs=None, **kwargs):
    """backward(outputs, grad_outputs=None, *, enable_double_backprop=False)

    Runs backpropagation from multiple output variables simultaneously.

    .. warning::

        This feature is experimental. The interface can change in the future.

    Args:
        outputs (tuple or list of :class:`~chainer.Variable`):
            A sequence of output variables from which backprop starts.
        grad_outputs (None or tuple or list of :class:`~chainer.Variable`):
            A sequence of variables that gives the initial value of each output
            gradient.
            If this argument is ``None``, backprop uses
            :attr:`~chainer.Variable.grad_var` of ``outputs``.
        enable_double_backprop (bool): If ``True``,
            a computational trace of the whole backpropagation procedure is
            recorded in the computational graph so that one can further do
            backpropagation from the resulting gradients. Note that
            enabling it results in larger memory consumption needed to
            store the gradients w.r.t. intermediate variables that are
            required for the second gradient computation.

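    .. admonition:: Example

        The following is a minimal usage sketch; the variables ``x``, ``y0``
        and ``y1`` are illustrative only and not part of this API::

            import numpy as np
            import chainer

            x = chainer.Variable(np.array([1., 2., 3.], dtype=np.float32))
            y0 = x * 2
            y1 = x * x
            # Set the initial gradient of each output. When ``grad_outputs``
            # is ``None``, chainer.backward reads ``grad_var`` of the outputs.
            y0.grad = np.ones_like(y0.array)
            y1.grad = np.ones_like(y1.array)
            chainer.backward([y0, y1])
            # x.grad now holds d(y0)/dx + d(y1)/dx = 2 + 2 * x.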

    .. seealso::
       :meth:`chainer.Variable.backward`
       :func:`chainer.grad`

    """
    enable_double_backprop, = argument.parse_kwargs(
        kwargs, ('enable_double_backprop', False),
        retain_grad='semantics for retain_grad=True is under discussion',
        loss_scale='chainer.backward does not support loss_scale option',
    )
    if not isinstance(outputs, (tuple, list)):
        raise TypeError(
            'outputs must be a tuple or a list, not {}.'.format(type(outputs)))
    for v in outputs:
        if not isinstance(v, chainer.Variable):
            raise TypeError(
                'each output must be a Variable, not {}.'.format(type(v)))
    if grad_outputs is not None:
        if not isinstance(grad_outputs, (tuple, list)):
            raise TypeError(
                'grad_outputs must be None, a tuple, or a list, not {}.'
                .format(type(grad_outputs)))
        if len(outputs) != len(grad_outputs):
            raise ValueError(
                'grad_outputs must be of the same length as outputs.\n'
                'len(outputs) = {}, len(grad_outputs) = {}'
                .format(len(outputs), len(grad_outputs)))

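    # ChainerX arrays are handled by chainerx.backward below; every other
    # backend falls through to the generic _backprop_to_all path.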
    is_chainerx = [v._has_chainerx_array for v in outputs]

    if any(is_chainerx):
        if not all(is_chainerx):
            # The restriction is required as soon as the workarounds below
            # are removed.
            raise ValueError('cannot mix chainerx and other backends')

        # Cannot use chainerx.backward directly, because it does not follow
        # retain_grad=False
        # TODO(kataoka): Fix chainerx.backward and remove this workaround
        if grad_outputs is None:
            grad_outputs = []
            for y in outputs:
                grad_outputs.append(y.grad_var)
                y.grad_var = None

        # The check is required because chainerx.backward sets default grads.
        # TODO(kataoka): Fix chainerx.backward and remove this workaround
        indices = [i for i, gy in enumerate(grad_outputs) if gy is not None]
        outputs = [outputs[i] for i in indices]
        grad_outputs = [grad_outputs[i] for i in indices]

        # Use new variables to start backprop
        # TODO(kataoka): Implement chainerx.backward(output, grad_outputs)
        # and remove this workaround.
        outputs = chainer.functions.identity(*outputs)
        if not isinstance(outputs, tuple):
            outputs = outputs,
        grad_outputs = chainer.functions.identity(*grad_outputs)
        if not isinstance(grad_outputs, tuple):
            grad_outputs = grad_outputs,

        # TODO(kataoka): Even after F.identity, non-float grad cannot be set.
        # Move the check elsewhere and remove this workaround.
        outputs_ = []
        for y, gy in zip(outputs, grad_outputs):
            if not y.requires_grad and gy is not None:
                warnings.warn(
                    'Some gradients are ignored by chainer.backward.\n'
                    'backend: ChainerX, '
                    'output.dtype: {}, grad_output.dtype: {}'.format(
                        y.dtype, gy.dtype),
                    RuntimeWarning)
                continue
            y.grad_var = gy
            outputs_.append(y)
        outputs = outputs_
        del outputs_

        # See also the ChainerX case of Variable.backward
        arrs = []
        for y in outputs:
            arr = y._data[0]
            assert isinstance(arr, chainerx.ndarray)
            arrs.append(arr)
        chainerx.backward(
            arrs, enable_double_backprop=enable_double_backprop)
        return

    if grad_outputs is None:
        grad_outputs = []
        for y in outputs:
            grad_var = y.grad_var
            if grad_var is None:
                warnings.warn(
                    'outputs contains a Variable without grad, or '
                    'duplicate outputs. Note that '
                    'chainer.backward does not set default grad.',
                    RuntimeWarning)
            y.grad_var = None
            grad_outputs.append(grad_var)
    outputs = [
        (y.node, gy) for y, gy in zip(outputs, grad_outputs) if gy is not None]
    with chainer.using_config('enable_backprop', enable_double_backprop):
        _backprop_to_all(outputs, False, None)


def _backprop_to_all(outputs, retain_grad, loss_scale):
    """Backprop to all input variables.

    Args:
        outputs (list of tuple): Each tuple is ``(y_node, y_grad_var)``.
            ``y_grad_var`` must not be ``None``.
        retain_grad (bool): See the docstring of ``Variable.backward``.
        loss_scale (float): See the docstring of ``Variable.backward``.

    """
    OrderedDict = chainer.utils._collections.OrderedDict  # fix py2 memory leak

    cand_funcs = []
    seen_set = set()

    def add_cand(cand):
        if cand not in seen_set:
            # Negate the rank since heapq is a min-heap; len(seen_set) is an
            # insertion counter that breaks ties without comparing nodes.
            heapq.heappush(cand_funcs, (-cand.rank, len(seen_set), cand))
            seen_set.add(cand)

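    # GradTable maps each variable node to a list of partial gradients that
    # are summed when the entry is popped. With accumulate_grad_inputs=True,
    # a gradient already set on a leaf variable joins that list so that new
    # gradients accumulate onto it.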
    grads = _backprop_utils.GradTable(accumulate_grad_inputs=True)

    leaf_nodes = set()

    for y, gy in outputs:
        grads.accumulate(y, gy)

        func = y.creator_node
        if func is None:  # leaf
            leaf_nodes.add(y)
        else:
            add_cand(func)

    # Fix F812 (Python 2)
    y = None
    del y

    is_debug = chainer.is_debug()
    base_hooks = chainer.get_function_hooks().values()
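    # Pop function nodes in reverse topological order (highest rank first)
    # and backpropagate through each one.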
    while cand_funcs:
        _, _, func = heapq.heappop(cand_funcs)
        inputs = func.inputs
        target_input_indexes = tuple([
            i for i, x in enumerate(inputs) if x.requires_grad
        ])
        outputs = [y() for y in func.outputs]  # access via weak ref
        out_grad = tuple([grads.pop(y)
                          if y is not None and y.creator_node is not None
                          else None
                          for y in outputs])
        if not target_input_indexes:
            continue

        in_data = [x.data for x in inputs]
        out_grad_array = [None if g is None else g.raw_array for g in out_grad]
        if func._n_local_function_hooks != 0:
            local_hooks = collections.OrderedDict(chainer.get_function_hooks())
            local_hooks.update(func.local_function_hooks)
            hooks = local_hooks.values()  # avoid six for performance
        else:
            hooks = base_hooks

        with chainer.using_device(
                backend.get_device_from_array(*(in_data + out_grad_array))):
            for hook in hooks:
                hook.backward_preprocess(
                    func, tuple(in_data), tuple(out_grad_array))

            # Collect the current input gradients.
            target_inputs = [inputs[i] for i in target_input_indexes]
            # Keep the order for portability, rather than
            # in_grad = {x: grads.get_as_list(x)
            #            for x in set(target_inputs)}
            in_grad = OrderedDict()
            for x in target_inputs:
                if x not in in_grad:
                    in_grad[x] = grads.get_as_list(x)

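            # Delegate the actual gradient computation to the function node;
            # backprop_step appends the new gradients to the lists in in_grad.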
            _backprop_utils.backprop_step(
                func, target_input_indexes, out_grad, in_grad, is_debug)

            for hook in hooks:
                hook.backward_postprocess(
                    func, tuple(in_data), tuple(out_grad_array))

        if retain_grad:
            # The gradients of the outputs of `func` are final. Store them if
            # retain_grad=True.
            for y, gy in six.moves.zip(outputs, out_grad):
                if y is not None:
                    y._set_grad_var_if_available(gy)
            del gy  # to reduce memory usage
        del out_grad  # to reduce memory usage

        for x, gx in in_grad.items():
            if not gx:  # no gradient accumulated for this input
                continue

            for gx_elem in gx:
                if gx_elem is not None:
                    chainer.variable._check_grad_type(
                        func, x, True, gx_elem.raw_array)
            del gx_elem  # to reduce memory usage

            if x.creator_node is None:  # leaf
                leaf_nodes.add(x)
            else:
                add_cand(x.creator_node)
        del gx, in_grad  # to reduce memory usage

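    # Finally, write the accumulated gradients back to the leaf variables.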
    for x in leaf_nodes:
        x_var = x.get_variable_or_none()
        gx = grads.pop(x)
        if x_var is not None:
            x_var._set_grad_var_without_check(gx)
            x_var._loss_scale = loss_scale
    grads.assert_no_grads()