from chainer import backend
from chainer import function
from chainer import utils
from chainer.utils import type_check


class DeCov(function.Function):
    """DeCov loss (https://arxiv.org/abs/1511.06068).

    The forward pass centers the input along its first axis, forms the
    product ``h_centered.T.dot(h_centered)``, zeroes its diagonal and
    divides it by ``len(h)``.  Depending on ``reduce`` the result is
    either returned as is (``'no'``) or reduced to half of the sum of
    its squared entries (``'half_squared_sum'``).

    The centered input and the masked, scaled matrix are kept on the
    instance (``h_centered`` and ``covariance``) so that the backward
    pass can reuse them.
    """

    def __init__(self, reduce='half_squared_sum'):
        # Buffers filled in by ``forward`` and read by ``backward``.
        self.h_centered = None
        self.covariance = None

        is_valid = reduce in ('half_squared_sum', 'no')
        if not is_valid:
            raise ValueError(
                'only \'half_squared_sum\' and \'no\' are valid '
                'for \'reduce\', but \'%s\' is given' % reduce)
        self.reduce = reduce

    def check_type_forward(self, in_types):
        """Require a single floating-point, 2-dimensional input ``h``."""
        type_check._argname(in_types, ('h',))
        (h_type,) = in_types

        is_floating = h_type.dtype.kind == 'f'
        is_matrix = h_type.ndim == 2
        type_check.expect(is_floating, is_matrix)

    def forward(self, inputs):
        """Compute the loss (or the masked covariance matrix).

        Args:
            inputs: One-element tuple holding the 2-dimensional array
                ``h``.

        Returns:
            One-element tuple.  With ``reduce == 'half_squared_sum'`` it
            holds the scalar loss as an array; otherwise it holds the
            square matrix stored in ``self.covariance``.
        """
        h, = inputs
        xp = backend.get_array_module(*inputs)

        centered = h - h.mean(axis=0, keepdims=True)
        self.h_centered = centered

        cov = centered.T.dot(centered)
        self.covariance = cov
        # Only the off-diagonal entries take part in the loss.
        xp.fill_diagonal(cov, 0.0)
        cov /= len(h)

        if self.reduce != 'half_squared_sum':
            return cov,

        # vdot of the matrix with itself is the sum of its squared
        # entries; the 0.5 factor is cast to the dtype of the input.
        loss = xp.vdot(cov, cov)
        loss *= h.dtype.type(0.5)
        return utils.force_array(loss),

    def backward(self, inputs, grad_outputs):
        """Compute the gradient with respect to ``h``.

        Args:
            inputs: One-element tuple holding the array ``h`` given to
                ``forward``.
            grad_outputs: One-element tuple holding the gradient with
                respect to the output of ``forward``.

        Returns:
            One-element tuple holding the gradient with respect to
            ``h``.
        """
        h, = inputs
        upstream, = grad_outputs
        xp = backend.get_array_module(*inputs)

        # Fold the 1 / len(h) scaling of the forward pass into the
        # incoming gradient.
        scaled = upstream / upstream.dtype.type(len(h))

        if self.reduce == 'half_squared_sum':
            grad_h = 2.0 * self.h_centered.dot(self.covariance)
            grad_h *= scaled
            return grad_h,

        # The diagonal was zeroed in the forward pass, so no gradient
        # flows through it.
        xp.fill_diagonal(scaled, 0.0)
        symmetrized = scaled + scaled.T
        return self.h_centered.dot(symmetrized),


def decov(h, reduce='half_squared_sum'):
    """Computes the DeCov loss of ``h``

    The output is a variable whose value depends on the value of
    the option ``reduce``. If it is ``'no'``, it holds a square matrix
    whose number of rows and columns equals the number of columns of
    ``h``. If it is ``'half_squared_sum'``, it holds half of the squared
    Frobenius norm of that matrix (i.e. the squared L2 norm of the
    matrix flattened to a vector, divided by two).

    Args:
        h (:class:`~chainer.Variable` or :ref:`ndarray`):
            Variable holding a matrix where the first dimension
            corresponds to the batches.
        reduce (str): Reduction option. Its value must be either
            ``'half_squared_sum'`` or ``'no'``.
            Otherwise, :class:`ValueError` is raised.

    Returns:
        ~chainer.Variable:
            A variable holding the DeCov loss.
            If ``reduce`` is ``'no'``, the output variable holds a
            2-dimensional array of shape ``(N, N)`` where
            ``N`` is the number of columns of ``h``.
            If it is ``'half_squared_sum'``, the output variable
            holds a scalar value.

    .. note::

       See https://arxiv.org/abs/1511.06068 for details.

    """
    loss_function = DeCov(reduce)
    return loss_function(h)