import six

from chainer import backend
from chainer import function_node
from chainer.functions.array import flip
from chainer.utils import type_check


class Cumsum(function_node.FunctionNode):

    """Function node computing a running (cumulative) sum along one axis.

    When ``axis`` is ``None`` the computation runs over the flattened input,
    mirroring ``xp.cumsum`` of the array module selected at forward time.
    """

    def __init__(self, axis=None):
        # Only Python integers (or None, meaning "flatten") are accepted.
        is_valid_axis = axis is None or isinstance(axis, six.integer_types)
        if not is_valid_axis:
            raise TypeError('axis must be int or None')
        self.axis = axis

    def check_type_forward(self, in_types):
        type_check._argname(in_types, ('x',))
        x_type = in_types[0]
        type_check.expect(x_type.dtype.kind == 'f')

        axis = self.axis
        if axis is None:
            # Flattened cumsum: any rank is fine.
            return
        if axis < 0:
            # Negative axes count from the end; -1 .. -ndim are valid.
            type_check.expect(-axis - 1 < x_type.ndim)
        else:
            type_check.expect(axis < x_type.ndim)

    def forward(self, inputs):
        x, = inputs
        xp = backend.get_array_module(x)
        # Remember the original shape so backward can undo the flattening
        # that happens when axis is None.
        self._in_shape = x.shape
        result = xp.cumsum(x, axis=self.axis)
        return result,

    def backward(self, indexes, grad_outputs):
        g = grad_outputs[0]
        flattened = self.axis is None

        # With axis=None the output (and hence its gradient) is 1-D, so the
        # accumulation axis is 0.
        ax = 0 if flattened else self.axis

        # y_j = sum_{k <= j} x_k, so dL/dx_i = sum_{j >= i} gy_j: a cumsum
        # taken in reverse order, i.e. flip -> cumsum -> flip.
        gx = flip.flip(cumsum(flip.flip(g, ax), ax), ax)

        if flattened:
            gx = gx.reshape(self._in_shape)
        return gx,


def cumsum(x, axis=None):
    """Cumulative sum of array elements over a given axis.

    Args:
        x (:class:`~chainer.Variable` or :ref:`ndarray`):
            Elements to calculate the cumulative sum. The type check
            requires a floating-point dtype.
        axis (int or None):
            Axis along which the cumulative sum is taken. Negative values
            count from the last axis. If it is not specified, the input is
            flattened.

    Returns:
        ~chainer.Variable: Output variable.

    Raises:
        TypeError: If ``axis`` is neither an integer nor ``None``.

    """
    node = Cumsum(axis)
    outputs = node.apply((x,))
    return outputs[0]