from __future__ import absolute_import
import collections
import heapq
import warnings

import six

import chainer
from chainer import _backprop_utils
from chainer import backend
from chainer.utils import argument
import chainerx


def backward(outputs, grad_outputs=None, **kwargs):
    """backward(outputs, grad_outputs=None, *, enable_double_backprop=False)

    Runs backpropagation from multiple variables simultaneously.

    .. warning::

        This feature is experimental. The interface can change in the future.

    Args:
        outputs (tuple or list of :class:`~chainer.Variable`):
            A sequence of output variables from which backprop starts.
        grad_outputs (None or tuple or list of :class:`~chainer.Variable`):
            A sequence of variables that gives the initial value of each
            output gradient.
            If this argument is ``None``, backprop uses
            :attr:`~chainer.Variable.grad_var` of ``outputs``.
        enable_double_backprop (bool): If ``True``, the computational trace
            of the whole backpropagation procedure is recorded in the
            computational graph so that one can further do backpropagation
            from the resulting gradients. Note that enabling it results in
            larger memory consumption needed to store the gradients w.r.t.
            intermediate variables that are required for the second gradient
            computation.

    .. seealso::
        :meth:`chainer.Variable.backward`
        :func:`chainer.grad`

    """
    enable_double_backprop, = argument.parse_kwargs(
        kwargs, ('enable_double_backprop', False),
        retain_grad='semantics for retain_grad=True is under discussion',
        loss_scale='chainer.backward does not support loss_scale option',
    )
    if not isinstance(outputs, (tuple, list)):
        raise TypeError(
            'outputs must be a tuple or a list, not {}.'.format(type(outputs)))
    for v in outputs:
        if not isinstance(v, chainer.Variable):
            raise TypeError(
                'each output must be a Variable, not {}'.format(type(v)))
    if grad_outputs is not None:
        if not isinstance(grad_outputs, (tuple, list)):
            raise TypeError(
                'grad_outputs must be None, a tuple, or a list, not {}.'
                .format(type(grad_outputs)))
        if len(outputs) != len(grad_outputs):
            raise ValueError(
                'grad_outputs must be of the same length as outputs.\n'
                'len(outputs) = {}, len(grad_outputs) = {}'
                .format(len(outputs), len(grad_outputs)))

    is_chainerx = [v._has_chainerx_array for v in outputs]

    if any(is_chainerx):
        if not all(is_chainerx):
            # The restriction is required as soon as the workarounds below
            # are removed.
            raise ValueError('cannot mix chainerx and other backends')

        # Cannot use chainerx.backward directly, because it does not follow
        # retain_grad=False
        # TODO(kataoka): Fix chainerx.backward and remove this workaround
        if grad_outputs is None:
            grad_outputs = []
            for y in outputs:
                grad_outputs.append(y.grad_var)
                y.grad_var = None

        # The check is required because chainerx.backward sets default grads.
        # TODO(kataoka): Fix chainerx.backward and remove this workaround
        indices = [i for i, gy in enumerate(grad_outputs) if gy is not None]
        outputs = [outputs[i] for i in indices]
        grad_outputs = [grad_outputs[i] for i in indices]

        # Use new variables to start backprop
        # TODO(kataoka): Implement chainerx.backward(output, grad_outputs)
        # and remove this workaround.
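        # (Note: F.identity below produces fresh Variables, so assigning
        # .grad_var on them as the starting gradients for chainerx.backward
        # does not modify grad_var on the caller's original output
        # variables.)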
        outputs = chainer.functions.identity(*outputs)
        if not isinstance(outputs, tuple):
            outputs = outputs,
        grad_outputs = chainer.functions.identity(*grad_outputs)
        if not isinstance(grad_outputs, tuple):
            grad_outputs = grad_outputs,

        # TODO(kataoka): Even after F.identity, non-float grad cannot be set.
        # Move the check to elsewhere and remove this workaround.
        outputs_ = []
        for y, gy in zip(outputs, grad_outputs):
            if not y.requires_grad and gy is not None:
                warnings.warn(
                    'Some of grads are ignored by chainer.backward.\n'
                    'backend: ChainerX, '
                    'output.dtype: {}, grad_output.dtype: {}'.format(
                        y.dtype, gy.dtype),
                    RuntimeWarning)
                continue
            y.grad_var = gy
            outputs_.append(y)
        outputs = outputs_
        del outputs_

        # See also the ChainerX case of Variable.backward
        arrs = []
        for y in outputs:
            arr = y._data[0]
            assert isinstance(arr, chainerx.ndarray)
            arrs.append(arr)
        chainerx.backward(
            arrs, enable_double_backprop=enable_double_backprop)
        return

    if grad_outputs is None:
        grad_outputs = []
        for y in outputs:
            grad_var = y.grad_var
            if grad_var is None:
                warnings.warn(
                    'outputs contains a Variable without grad, or '
                    'duplicate outputs. Note that '
                    'chainer.backward does not set default grad.',
                    RuntimeWarning)
            y.grad_var = None
            grad_outputs.append(grad_var)
    outputs = [
        (y.node, gy) for y, gy in zip(outputs, grad_outputs) if gy is not None]
    with chainer.using_config('enable_backprop', enable_double_backprop):
        _backprop_to_all(outputs, False, None)


def _backprop_to_all(outputs, retain_grad, loss_scale):
    """Backprop to all input variables

    Args:
        outputs (list of tuple): each tuple is (y_node, y_grad_var).
            y_grad_var should not be None.
        retain_grad (bool): see docstring of Variable.backward
        loss_scale (float): see docstring of Variable.backward

    """
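    # Traversal outline (summary of the loop below):
    #   * candidate functions are popped from a max-heap on their rank, so
    #     functions that appear later in the forward pass are processed
    #     first,
    #   * GradTable accumulates the gradient contributions flowing into each
    #     variable node until its creator function is processed,
    #   * variable nodes without a creator are gathered in leaf_nodes and
    #     receive their accumulated gradients after the loop.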
    OrderedDict = chainer.utils._collections.OrderedDict  # fix py2 memory leak

    cand_funcs = []
    seen_set = set()

    def add_cand(cand):
        if cand not in seen_set:
            # Negate since heapq is min-heap
            heapq.heappush(cand_funcs, (-cand.rank, len(seen_set), cand))
            seen_set.add(cand)

    grads = _backprop_utils.GradTable(accumulate_grad_inputs=True)

    leaf_nodes = set()

    for y, gy in outputs:
        grads.accumulate(y, gy)

        func = y.creator_node
        if func is None:  # leaf
            leaf_nodes.add(y)
        else:
            add_cand(func)

    # Fix F812 (Python 2)
    y = None
    del y

    is_debug = chainer.is_debug()
    base_hooks = chainer.get_function_hooks().values()
    while cand_funcs:
        _, _, func = heapq.heappop(cand_funcs)
        inputs = func.inputs
        target_input_indexes = tuple([
            i for i, x in enumerate(inputs) if x.requires_grad
        ])
        outputs = [y() for y in func.outputs]  # access via weak ref
        out_grad = tuple([grads.pop(y)
                          if y is not None and y.creator_node is not None
                          else None
                          for y in outputs])
        if not target_input_indexes:
            continue

        in_data = [x.data for x in inputs]
        out_grad_array = [None if g is None else g.raw_array for g in out_grad]
        if func._n_local_function_hooks != 0:
            local_hooks = collections.OrderedDict(chainer.get_function_hooks())
            local_hooks.update(func.local_function_hooks)
            hooks = local_hooks.values()  # avoid six for performance
        else:
            hooks = base_hooks

        with chainer.using_device(
                backend.get_device_from_array(*(in_data + out_grad_array))):
            for hook in hooks:
                hook.backward_preprocess(
                    func, tuple(in_data), tuple(out_grad_array))

            # Collect the current input gradients.
            target_inputs = [inputs[i] for i in target_input_indexes]
            # Keep the order for the portability, rather than
            # in_grad = {x: grads.get_as_list(x)
            #            for x in set(target_inputs)}
            in_grad = OrderedDict()
            for x in target_inputs:
                if x not in in_grad:
                    in_grad[x] = grads.get_as_list(x)

            _backprop_utils.backprop_step(
                func, target_input_indexes, out_grad, in_grad, is_debug)

            for hook in hooks:
                hook.backward_postprocess(
                    func, tuple(in_data), tuple(out_grad_array))

        if retain_grad:
            # The gradients of the outputs of `func` are final. Store them if
            # retain_grad=True.
            for y, gy in six.moves.zip(outputs, out_grad):
                if y is not None:
                    y._set_grad_var_if_available(gy)
            del gy  # to reduce memory usage
        del out_grad  # to reduce memory usage

        for x, gx in in_grad.items():
            if not gx:  # gradient == None
                continue

            for gx_elem in gx:
                if gx_elem is not None:
                    chainer.variable._check_grad_type(
                        func, x, True, gx_elem.raw_array)
            del gx_elem  # to reduce memory usage

            if x.creator_node is None:  # leaf
                leaf_nodes.add(x)
            else:
                add_cand(x.creator_node)
        del gx, in_grad  # to reduce memory usage

    for x in leaf_nodes:
        x_var = x.get_variable_or_none()
        gx = grads.pop(x)
        if x_var is not None:
            x_var._set_grad_var_without_check(gx)
            x_var._loss_scale = loss_scale
    grads.assert_no_grads()
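

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module API).
# It assumes a NumPy-backed Variable and made-up toy values; running this
# file directly prints the gradient accumulated into ``x`` from two outputs
# at once, which is what chainer.backward() above is for.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import numpy as np

    x = chainer.Variable(np.array([1.0, 2.0, 3.0], dtype=np.float32))
    y1 = x * 2   # dy1/dx = 2
    y2 = x * x   # dy2/dx = 2 * x

    # Initial output gradients (ones, matching each output's shape).
    gy1 = chainer.Variable(np.ones_like(y1.array))
    gy2 = chainer.Variable(np.ones_like(y2.array))

    # Start backprop from both outputs simultaneously; the gradients from
    # the two branches are accumulated into x.grad.
    backward([y1, y2], [gy1, gy2])
    print(x.grad)  # -> [4. 6. 8.]  (2 from y1 plus 2 * x from y2)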