1import six
2
3from chainer import backend
4from chainer import function_node
5from chainer.functions.array import flip
6from chainer.utils import type_check
7
8
class Cumsum(function_node.FunctionNode):
    """Function node computing the cumulative sum of ``x`` along an axis.

    When ``axis`` is ``None`` the input is treated as flattened, so the
    forward output is one-dimensional; the gradient is reshaped back to the
    original input shape in :meth:`backward`.
    """

    def __init__(self, axis=None):
        # Only a Python integer or ``None`` is accepted as the axis.
        if axis is not None and not isinstance(axis, six.integer_types):
            raise TypeError('axis must be int or None')
        self.axis = axis

    def check_type_forward(self, in_types):
        type_check._argname(in_types, ('x',))
        x_type = in_types[0]
        type_check.expect(x_type.dtype.kind == 'f')

        axis = self.axis
        if axis is None:
            # No rank constraint applies when the input is flattened.
            return
        if axis < 0:
            # Negative axes count from the end (-1 is the last dimension),
            # so -axis - 1 must be a valid non-negative dimension index.
            type_check.expect(-axis - 1 < x_type.ndim)
        else:
            type_check.expect(axis < x_type.ndim)

    def forward(self, inputs):
        (x,) = inputs
        # Keep the input shape: backward needs it to undo the flattening
        # that happens when ``axis`` is ``None``.
        self._in_shape = x.shape
        xp = backend.get_array_module(x)
        y = xp.cumsum(x, axis=self.axis)
        return y,

    def backward(self, indexes, grad_outputs):
        gy = grad_outputs[0]
        flattened = self.axis is None

        # The gradient of a cumulative sum is a reversed cumulative sum:
        # element i receives the sum of gy over all positions j >= i along
        # the scanned axis, computed here as flip -> cumsum -> flip.
        # In the flattened case gy is 1-D, so axis 0 is scanned.
        scan_axis = 0 if flattened else self.axis
        reversed_gy = flip.flip(gy, scan_axis)
        gx = flip.flip(cumsum(reversed_gy, scan_axis), scan_axis)

        if flattened:
            gx = gx.reshape(self._in_shape)
        return gx,
45
46
def cumsum(x, axis=None):
    """Cumulative sum of array elements over a given axis.

    Args:
        x (:class:`~chainer.Variable` or :ref:`ndarray`):
            Elements to calculate the cumulative sum. Must have a
            floating-point dtype.
        axis (int or None):
            Axis along which the cumulative sum is taken. Negative values
            count from the last axis. If it is not specified, the input is
            flattened and the result is one-dimensional.

    Returns:
        ~chainer.Variable: Output variable.

    Raises:
        TypeError: If ``axis`` is neither an int nor ``None``.

    """
    node = Cumsum(axis)
    outputs = node.apply((x,))
    return outputs[0]
62