import numpy
import six

import chainer
from chainer import backend
from chainer import function_node
from chainer import utils
from chainer.utils import type_check
import chainerx


class Sum(function_node.FunctionNode):
    """Function node that sums array elements over the given axes."""

    keepdims = False

    def __init__(self, axis=None, keepdims=False):
        """Validate and normalize ``axis``, then store the configuration.

        Args:
            axis (None, int, or tuple of int): Axes to reduce. A single
                int is stored as a one-element tuple; ``None`` is stored
                unchanged.
            keepdims (bool): Forwarded to the underlying array ``sum``.

        Raises:
            TypeError: If ``axis`` is not ``None``, an int, or a tuple
                of ints.
            ValueError: If a tuple ``axis`` contains a repeated value.
        """
        if axis is not None:
            if isinstance(axis, six.integer_types):
                axis = (axis,)
            elif not (isinstance(axis, tuple) and all(
                    isinstance(a, six.integer_types) for a in axis)):
                raise TypeError('None, int or tuple of int are required')
            elif len(set(axis)) != len(axis):
                raise ValueError('duplicate value in axis: ({})'.format(
                    ', '.join(map(str, axis))))

        self.axis = axis
        self.keepdims = keepdims

    def check_type_forward(self, in_types):
        """Expect one float input and every axis within its ``ndim``."""
        type_check._argname(in_types, ('x',))
        x_type = in_types[0]
        type_check.expect(x_type.dtype.kind == 'f')

        if self.axis is None:
            return

        for axis in self.axis:
            # A negative axis ``-k`` is only valid when ``k <= ndim``,
            # i.e. ``-axis - 1 < ndim``.
            bound = axis if axis >= 0 else -axis - 1
            type_check.expect(bound < x_type.ndim)

    def forward_chainerx(self, inputs):
        """Compute the sum with ``chainerx.sum``."""
        (x,) = inputs
        y = chainerx.sum(x, axis=self.axis, keepdims=self.keepdims)
        return (y,)

    def forward(self, inputs):
        """Compute the sum with the input array's own ``sum`` method."""
        (x,) = inputs
        total = x.sum(axis=self.axis, keepdims=self.keepdims)
        if backend.get_array_module(x) is numpy:
            # Pass the NumPy result through ``asarray`` so the output is
            # always an array object.
            total = numpy.asarray(total)
        return (total,)

    def backward(self, indexes, grad_outputs):
        """Broadcast the output gradient back to the input's shape."""
        (gy,) = grad_outputs
        in_shape = self.inputs[0].shape
        ndim = len(in_shape)

        reduced_axes_dropped = (
            ndim != 0 and self.axis is not None and not self.keepdims)
        if reduced_axes_dropped:
            # Re-insert a length-one dimension for each reduced axis, in
            # ascending order, so that ``gy`` can be broadcast.
            shape = list(gy.shape)
            for pos in sorted(
                    a if a >= 0 else a + ndim for a in self.axis):
                shape.insert(pos, 1)
            gy = chainer.functions.reshape(gy, shape)

        return (chainer.functions.broadcast_to(gy, in_shape),)


def sum(x, axis=None, keepdims=False):
    """Sum of array elements over a given axis.

    Args:
        x (:class:`~chainer.Variable` or :ref:`ndarray`): Elements to sum.
            A :math:`(s_1, s_2, ..., s_N)` -shaped float array.
        axis (None, int, or tuple of int): Axis along which the sum is
            performed. The default (``axis=None``) performs a sum over all
            the dimensions of the input array.
        keepdims (bool): If ``True``, the specified axes are retained as
            axes of length one.

    Returns:
        ~chainer.Variable: Output variable.

    .. admonition:: Example

        >>> x = np.arange(6).reshape(2,3).astype(np.float32)
        >>> x
        array([[0., 1., 2.],
               [3., 4., 5.]], dtype=float32)
        >>> y = F.sum(x)
        >>> y.shape
        ()
        >>> y.array
        array(15., dtype=float32)
        >>> y = F.sum(x, axis=1)
        >>> y.shape
        (2,)
        >>> y.array
        array([ 3., 12.], dtype=float32)
        >>> y = F.sum(x, keepdims=True)
        >>> y.shape
        (1, 1)
        >>> y.array
        array([[15.]], dtype=float32)

    """
    node = Sum(axis, keepdims)
    (y,) = node.apply((x,))
    return y


class SumTo(function_node.FunctionNode):

    """Sum axes to output an array of a given shape."""

    def __init__(self, shape):
        """Remember the target shape of the reduction."""
        self._shape = shape

    def forward(self, inputs):
        """Reduce the input to the target shape via ``utils.sum_to``."""
        (x,) = inputs
        reduced = utils.sum_to(x, self._shape)
        return (reduced,)

    def backward(self, indexes, grad_outputs):
        """Broadcast the output gradient back to the input's shape."""
        (gy,) = grad_outputs
        (x_node,) = self.inputs
        gx = chainer.functions.broadcast_to(gy, x_node.shape)
        return (gx,)


def sum_to(x, shape):
    """Sum elements along axes to output an array of a given shape.

    Args:
        x (:class:`~chainer.Variable` or :ref:`ndarray`): Input variable.
        shape (tuple of int): The target shape.

    Returns:
        ~chainer.Variable: Output variable of shape ``shape``.

    .. admonition:: Example

        >>> x = np.array([[1., 2., 3.], [4., 5., 6.]])
        >>> x
        array([[1., 2., 3.],
               [4., 5., 6.]])
        >>> y = F.sum_to(x, (1, 3))
        >>> y
        variable([[5., 7., 9.]])
        >>> z = F.sum_to(x, (2, 1))
        >>> z
        variable([[ 6.],
                  [15.]])

    """
    # Nothing to reduce when the shapes already agree; just wrap the input.
    if x.shape == shape:
        return chainer.as_variable(x)

    node = SumTo(shape)
    (y,) = node.apply((x,))
    return y