1import numpy
2import six
3
4import chainer
5from chainer import backend
6from chainer import function_node
7from chainer import utils
8from chainer.utils import type_check
9import chainerx
10
11
class Sum(function_node.FunctionNode):
    """Function node summing array elements over the given axes.

    Args:
        axis (None, int or tuple of int): Axes to reduce. ``None`` reduces
            over every axis; a single int is stored as a one-element tuple.
        keepdims (bool): Whether the reduced axes are kept as length-one
            axes in the output.

    Raises:
        TypeError: If ``axis`` is not ``None``, an int or a tuple of ints.
        ValueError: If ``axis`` is a tuple containing a repeated value.

    """

    keepdims = False

    def __init__(self, axis=None, keepdims=False):
        if axis is not None:
            if isinstance(axis, six.integer_types):
                # Normalize a single axis to the tuple form used below.
                axis = (axis,)
            elif not (isinstance(axis, tuple) and all(
                    isinstance(a, six.integer_types) for a in axis)):
                raise TypeError('None, int or tuple of int are required')
            elif len(set(axis)) != len(axis):
                raise ValueError('duplicate value in axis: ({})'.format(
                    ', '.join(map(str, axis))))
        self.axis = axis
        self.keepdims = keepdims

    def check_type_forward(self, in_types):
        """Require one float input whose rank covers every requested axis."""
        type_check._argname(in_types, ('x',))
        x_type = in_types[0]
        type_check.expect(x_type.dtype.kind == 'f')

        if self.axis is None:
            return
        # A non-negative axis must be < ndim; a negative axis counts from
        # the end, so it is valid when -axis <= ndim.
        for ax in self.axis:
            if ax < 0:
                type_check.expect(-ax - 1 < x_type.ndim)
            else:
                type_check.expect(ax < x_type.ndim)

    def forward_chainerx(self, inputs):
        """Reduce a ChainerX array with ``chainerx.sum``."""
        (x,) = inputs
        y = chainerx.sum(x, axis=self.axis, keepdims=self.keepdims)
        return (y,)

    def forward(self, inputs):
        """Reduce a NumPy/CuPy-style array via its ``sum`` method."""
        (x,) = inputs
        y = x.sum(axis=self.axis, keepdims=self.keepdims)
        # NumPy can hand back a scalar instead of an ndarray for a full
        # reduction; numpy.asarray turns it into a (0-dim) array.
        if backend.get_array_module(x) is numpy:
            y = numpy.asarray(y)
        return (y,)

    def backward(self, indexes, grad_outputs):
        """Broadcast the upstream gradient back to the input shape."""
        (gy,) = grad_outputs
        in_shape = self.inputs[0].shape
        ndim = len(in_shape)
        # The reduced axes were dropped from gy only when the input had at
        # least one dim, specific axes were given and keepdims was off.
        dropped_axes = (
            ndim != 0 and self.axis is not None and not self.keepdims)
        if dropped_axes:
            positions = sorted(
                ax + ndim if ax < 0 else ax for ax in self.axis)
            restored = list(gy.shape)
            # Re-insert length-one axes (ascending order keeps later
            # positions correct) so gy can be broadcast.
            for pos in positions:
                restored.insert(pos, 1)
            gy = chainer.functions.reshape(gy, restored)
        return (chainer.functions.broadcast_to(gy, in_shape),)
71
72
def sum(x, axis=None, keepdims=False):
    """Sum of array elements over a given axis.

    Args:
        x (:class:`~chainer.Variable` or :ref:`ndarray`): Elements to sum.
            A :math:`(s_1, s_2, ..., s_N)` -shaped float array.
        axis (None, int, or tuple of int): Axis or axes along which the sum
            is performed. The default (``axis=None``) sums over all the
            dimensions of the input array.
        keepdims (bool): If ``True``, the reduced axes are retained in the
            output as axes of length one.

    Returns:
        ~chainer.Variable: Output variable.

    Raises:
        TypeError: If ``axis`` is not ``None``, an int or a tuple of int.
        ValueError: If ``axis`` is a tuple containing duplicate values.

    .. admonition:: Example

        >>> x = np.arange(6).reshape(2,3).astype(np.float32)
        >>> x
        array([[0., 1., 2.],
               [3., 4., 5.]], dtype=float32)
        >>> y = F.sum(x)
        >>> y.shape
        ()
        >>> y.array
        array(15., dtype=float32)
        >>> y = F.sum(x, axis=1)
        >>> y.shape
        (2,)
        >>> y.array
        array([ 3., 12.], dtype=float32)
        >>> y = F.sum(x, keepdims=True)
        >>> y.shape
        (1, 1)
        >>> y.array
        array([[15.]], dtype=float32)

    """
    node = Sum(axis, keepdims)
    outputs = node.apply((x,))
    (result,) = outputs
    return result
113
114
class SumTo(function_node.FunctionNode):

    """Function node reducing its input to a fixed target shape.

    Args:
        shape (tuple of int): Shape of the output, forwarded to
            ``chainer.utils.sum_to``.

    """

    def __init__(self, shape):
        self._shape = shape

    def forward(self, inputs):
        """Delegate the reduction to ``chainer.utils.sum_to``."""
        (x,) = inputs
        reduced = utils.sum_to(x, self._shape)
        return (reduced,)

    def backward(self, indexes, grad_outputs):
        """Broadcast the gradient back to the shape of the input node."""
        (gy,) = grad_outputs
        (x_node,) = self.inputs
        gx = chainer.functions.broadcast_to(gy, x_node.shape)
        return (gx,)
130
131
def sum_to(x, shape):
    """Sum elements along axes to output an array of a given shape.

    ``shape`` must be broadcastable to ``x.shape``. It may have fewer
    dimensions than ``x``, and each of its dimensions must either equal the
    corresponding trailing dimension of ``x`` or be ``1``. The leading axes
    of ``x`` that ``shape`` lacks, and the axes where ``shape`` has ``1``,
    are summed out.

    Args:
        x (:class:`~chainer.Variable` or :ref:`ndarray`): Input variable.
        shape (tuple of int): The target shape. Any sequence of int is
            accepted and converted to a tuple.

    Returns:
        ~chainer.Variable: Output variable of shape ``shape``. If ``x``
        already has shape ``shape``, it is only wrapped with
        :func:`chainer.as_variable` and no function node is added.

    Raises:
        ValueError: If ``shape`` is not broadcastable to ``x.shape``, i.e.
            it cannot be obtained from ``x`` by summing along axes.

    .. admonition:: Example

        >>> x = np.array([[1., 2., 3.], [4., 5., 6.]])
        >>> x
        array([[1., 2., 3.],
               [4., 5., 6.]])
        >>> y = F.sum_to(x, (1, 3))
        >>> y
        variable([[5., 7., 9.]])
        >>> z = F.sum_to(x, (2, 1))
        >>> z
        variable([[ 6.],
                  [15.]])

    """
    # Normalize so that list shapes also hit the no-op fast path below.
    shape = tuple(shape)
    x_shape = tuple(x.shape)
    lead = len(x_shape) - len(shape)
    # SumTo has no check_type_forward of its own, so reject target shapes
    # that no summation over axes of ``x`` can produce.
    if lead < 0 or any(
            s != 1 and s != xs for s, xs in zip(shape, x_shape[lead:])):
        raise ValueError(
            'cannot sum an array of shape {} to shape {}'.format(
                x_shape, shape))
    if x_shape == shape:
        return chainer.as_variable(x)
    y, = SumTo(shape).apply((x,))
    return y
161