import six

import chainer
from chainer.backends import cuda
from chainer import function_node
from chainer.functions.math import sum as _sum
from chainer.utils import type_check


class BatchL2NormSquared(function_node.FunctionNode):

    def check_type_forward(self, in_types):
        type_check._argname(in_types, ('x',))
        x_type, = in_types

        type_check.expect(
            x_type.dtype.kind == 'f',
            x_type.ndim >= 2,
        )

    def forward_cpu(self, inputs):
        self.retain_inputs((0,))
        # Flatten all but the batch axis, then sum the squares per sample.
        x = inputs[0].reshape(len(inputs[0]), -1)
        return (x * x).sum(axis=1),

    def forward_gpu(self, inputs):
        self.retain_inputs((0,))
        x = inputs[0].reshape(len(inputs[0]), -1)
        # Fused CuPy reduction kernel: squares each element (map step) and
        # sums the results (reduce step) in one pass over each flattened
        # sample. cuda.reduce memoizes compiled kernels, so constructing it
        # here does not recompile on every call.
        l2normsquared_kernel = cuda.reduce(
            'T x', 'T y', 'x * x', 'a + b', 'y = a', '0', 'l2normsquared'
        )
        return l2normsquared_kernel(x, axis=1),

    def backward(self, indexes, gy):
        x = self.get_retained_inputs()
        return BatchL2NormSquaredGrad().apply((x[0], gy[0]))


class BatchL2NormSquaredGrad(function_node.FunctionNode):

    def forward_cpu(self, inputs):
        self.retain_inputs((0, 1))
        x, gy0 = inputs
        # Reshape gy0 to (N, 1, ..., 1) so it broadcasts against x.
        gy0 = gy0.reshape(-1, *((1,) * (x.ndim - 1)))
        gx = 2 * x * gy0
        return gx,

    def forward_gpu(self, inputs):
        self.retain_inputs((0, 1))
        x, gy0 = inputs
        gy0 = gy0.reshape(-1, *((1,) * (x.ndim - 1)))
        kernel = cuda.elementwise(
            'T x, T gy', 'T gx', 'gx = 2 * x * gy',
            'l2normsquared_bwd')
        gx = kernel(x, gy0)
        return gx,

    def backward(self, indexes, grad_outputs):
        # Double backward: the forward of this node is gx = 2 * x * gy0, so
        # its derivative w.r.t. x is 2 * gy0 and its derivative w.r.t. gy0
        # is 2 * x, each multiplied by the incoming gradient; the gy0 term
        # is summed over the non-batch axes to match its original shape.
        x, gy0 = self.get_retained_inputs()
        gy0 = gy0.reshape(-1, *((1,) * (x.ndim - 1)))
        gy0 = chainer.functions.broadcast_to(gy0, x.shape)
        ggx2 = 2 * grad_outputs[0]
        gx = ggx2 * gy0
        ggy0 = ggx2 * x
        return gx, _sum.sum(ggy0, axis=tuple(six.moves.range(1, ggy0.ndim)))


def batch_l2_norm_squared(x):
    """L2 norm (a.k.a.\\ Euclidean norm) squared.

    This function computes the squared L2 norm of each sample in the
    minibatch. No reduction along the batch axis is done.

    Args:
        x (:class:`~chainer.Variable` or :ref:`ndarray`): Input variable.
            The first dimension is assumed to be the *minibatch dimension*.
            If ``x`` has more than two dimensions, all but the first
            dimension are flattened to one dimension.

    Returns:
        ~chainer.Variable: One dimensional output variable of shape
        ``(N,)``, where ``N`` is the minibatch size.

    """
    return BatchL2NormSquared().apply((x,))[0]
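

# --- Usage sketch (illustrative, not part of the original module) ----------
# A minimal example, assuming NumPy and Chainer are importable: it checks the
# forward result against a manual per-sample computation and exercises the
# first-order backward path (BatchL2NormSquaredGrad) defined above.
if __name__ == '__main__':
    import numpy as np

    data = np.arange(6, dtype=np.float32).reshape(2, 3)

    # Forward: each output element is the squared L2 norm of one sample,
    # i.e. [0 + 1 + 4, 9 + 16 + 25] = [5., 50.].
    y = batch_l2_norm_squared(data)
    assert np.allclose(y.array, (data * data).sum(axis=1))

    # Backward: d/dx sum_i ||x_i||^2 = 2 * x, computed on this CPU input by
    # BatchL2NormSquaredGrad.forward_cpu.
    x = chainer.Variable(data)
    loss = _sum.sum(batch_l2_norm_squared(x))
    loss.backward()
    assert np.allclose(x.grad, 2 * data)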