import six

import chainer
from chainer.backends import cuda
from chainer import function_node
from chainer.functions.math import sum as _sum
from chainer.utils import type_check


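# Computes y[i] = sum_j x[i, j] ** 2 for each sample i of the minibatch,
# flattening all non-batch axes, so the output has shape (N,).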
class BatchL2NormSquared(function_node.FunctionNode):

    def check_type_forward(self, in_types):
        type_check._argname(in_types, ('x',))
        x_type, = in_types

        type_check.expect(
            x_type.dtype.kind == 'f',
            x_type.ndim >= 2,
        )

    def forward_cpu(self, inputs):
        self.retain_inputs((0,))
        # Flatten each sample to a row and sum the squared entries.
        x = inputs[0].reshape(len(inputs[0]), -1)
        return (x * x).sum(axis=1),

    def forward_gpu(self, inputs):
        self.retain_inputs((0,))
        x = inputs[0].reshape(len(inputs[0]), -1)
        # Fused reduction kernel: maps each element to its square and sums
        # along axis 1 (identity 0), producing one value per sample.
        l2normsquared_kernel = cuda.reduce(
            'T x', 'T y', 'x * x', 'a + b', 'y = a', '0', 'l2normsquared'
        )
        return l2normsquared_kernel(x, axis=1),

    def backward(self, indexes, gy):
        x = self.get_retained_inputs()
        return BatchL2NormSquaredGrad().apply((x[0], gy[0]))


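# Gradient of BatchL2NormSquared: gx[i, j] = 2 * x[i, j] * gy[i], with gy
# broadcast over the non-batch axes.  It is a FunctionNode of its own so
# that ``backward`` below can define the second-order gradients needed for
# double backpropagation.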
class BatchL2NormSquaredGrad(function_node.FunctionNode):

    def forward_cpu(self, inputs):
        self.retain_inputs((0, 1))
        x, gy0 = inputs
        # Reshape gy0 to (N, 1, ..., 1) so it broadcasts against x.
        gy0 = gy0.reshape(-1, *((1,) * (x.ndim - 1)))
        gx = 2 * x * gy0
        return gx,

    def forward_gpu(self, inputs):
        self.retain_inputs((0, 1))
        x, gy0 = inputs
        gy0 = gy0.reshape(-1, *((1,) * (x.ndim - 1)))
        # Elementwise kernel computing 2 * x * gy, with gy0 broadcast over
        # the non-batch axes of x.
        kernel = cuda.elementwise(
            'T x, T gy', 'T gx', 'gx = 2 * x * gy',
            'l2normsquared_bwd')
        gx = kernel(x, gy0)
        return gx,

    def backward(self, indexes, grad_outputs):
        x, gy0 = self.get_retained_inputs()
        gy0 = gy0.reshape(-1, *((1,) * (x.ndim - 1)))
        gy0 = chainer.functions.broadcast_to(gy0, x.shape)
        # The forward pass of this node computed gx = 2 * x * gy0, so for
        # the incoming gradient ggx = grad_outputs[0]:
        #   d(gx)/dx  -> gx contribution:  2 * ggx * gy0
        #   d(gx)/dgy -> gy0 contribution: 2 * ggx * x, summed over the
        #                non-batch axes to match gy0's original shape.
        ggx2 = 2 * grad_outputs[0]
        gx = ggx2 * gy0
        ggy0 = ggx2 * x
        return gx, _sum.sum(ggy0, axis=tuple(six.moves.range(1, ggy0.ndim)))


def batch_l2_norm_squared(x):
    """L2 norm (a.k.a.\\ Euclidean norm) squared.

    This function computes the squared L2 norm of each sample in the
    minibatch. No reduction is done along the batch axis.

    Args:
        x (:class:`~chainer.Variable` or :ref:`ndarray`): Input variable.
            The first dimension is assumed to be the *minibatch dimension*.
            If ``x`` has more than two dimensions, all but the first
            dimension are flattened into one dimension.

    Returns:
        ~chainer.Variable: One-dimensional output variable holding the
        squared norm of each sample.

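    .. admonition:: Example

        A minimal sketch of the expected behavior; the values below are
        computed by hand for this 2x3 input (1 + 4 + 9 and 16 + 25 + 36):

        >>> import numpy as np
        >>> x = np.array([[1, 2, 3], [4, 5, 6]], np.float32)
        >>> y = batch_l2_norm_squared(x)
        >>> y.shape
        (2,)
        >>> y.array
        array([14., 77.], dtype=float32)
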
83    """
84    return BatchL2NormSquared().apply((x,))[0]
85