1from chainer import backend
2from chainer import function
3from chainer import utils
4from chainer.utils import type_check
5
6
class DeCov(function.Function):

    """Function node computing the DeCov loss.

    Reference: https://arxiv.org/abs/1511.06068

    The forward pass subtracts the mean over the first axis from the 2-D
    input ``h``, builds ``h_c^T h_c / len(h)`` and zeroes its diagonal.
    With ``reduce='no'`` that matrix is the output. With
    ``reduce='half_squared_sum'`` the output is half of the sum of its
    squared entries.

    Attributes:
        h_centered: Mean-subtracted input cached by :meth:`forward` for
            :meth:`backward` (``None`` before the first forward call).
        covariance: Zero-diagonal covariance matrix cached by
            :meth:`forward` (``None`` before the first forward call).
        reduce (str): Either ``'half_squared_sum'`` or ``'no'``.

    """

    def __init__(self, reduce='half_squared_sum'):
        # Buffers written by ``forward`` and read back by ``backward``.
        self.h_centered = None
        self.covariance = None
        accepted = ('half_squared_sum', 'no')
        if reduce not in accepted:
            raise ValueError(
                "only 'half_squared_sum' and 'no' are valid "
                "for 'reduce', but '%s' is given" % reduce)
        self.reduce = reduce

    def check_type_forward(self, in_types):
        # A single input, named ``h``, that must be a floating-point matrix.
        type_check._argname(in_types, ('h',))
        (x_type,) = in_types

        type_check.expect(
            x_type.dtype.kind == 'f',
            x_type.ndim == 2,
        )

    def forward(self, inputs):
        """Compute the zero-diagonal covariance and optionally reduce it.

        Args:
            inputs: One-element tuple holding the 2-D array ``h``.

        Returns:
            One-element tuple: the covariance matrix when ``reduce`` is not
            ``'half_squared_sum'``, otherwise a 0-dim array holding
            ``0.5 * <C, C>``.
        """
        (h,) = inputs
        xp = backend.get_array_module(h)
        batch_size = len(h)

        centered = h - h.mean(axis=0, keepdims=True)
        cov = centered.T.dot(centered)
        # Only cross-covariances are penalized, so drop the variances.
        xp.fill_diagonal(cov, 0.0)
        cov /= batch_size

        # Cached for the backward pass.
        self.h_centered = centered
        self.covariance = cov

        if self.reduce != 'half_squared_sum':
            return cov,

        # vdot flattens both operands, giving the sum of squared entries.
        loss = xp.vdot(cov, cov)
        loss *= h.dtype.type(0.5)
        return utils.force_array(loss),

    def backward(self, inputs, grad_outputs):
        """Return the gradient with respect to ``h``.

        Uses ``h_centered`` (and, for ``'half_squared_sum'``, ``covariance``)
        cached by :meth:`forward`. The mean subtraction adds no extra term
        because every column of ``h_centered`` sums to zero.
        """
        (h,) = inputs
        (gy,) = grad_outputs
        xp = backend.get_array_module(h)
        scaled = gy / gy.dtype.type(len(h))

        if self.reduce != 'half_squared_sum':
            # The diagonal was forced to zero in forward, so it receives
            # no gradient; ``scaled`` is a fresh array, safe to modify.
            xp.fill_diagonal(scaled, 0.0)
            return self.h_centered.dot(scaled + scaled.T),

        # d(0.5 * ||C||^2)/dh = 2 * h_c . C / N, scaled by upstream grad.
        gh = 2.0 * self.h_centered.dot(self.covariance)
        gh *= scaled
        return gh,
56
57
def decov(h, reduce='half_squared_sum'):
    """Computes the DeCov loss of ``h``

    The DeCov loss penalizes cross-covariances between the columns of
    ``h``. ``h`` is first centered by subtracting its mean over the first
    (batch) axis. Then the covariance matrix ``C = h_c^T h_c / B`` is
    formed, where ``B`` is the batch size. Finally, the diagonal of ``C``
    is set to zero so that only cross-covariances contribute.

    The output is a variable whose value depends on the value of
    the option ``reduce``. If it is ``'no'``, it holds the matrix ``C``,
    whose height and width both equal the number of columns of ``h``.
    If it is ``'half_squared_sum'``, it holds half of the
    squared Frobenius norm of ``C`` (i.e. the square of the L2 norm of the
    matrix flattened to a vector).

    Args:
        h (:class:`~chainer.Variable` or :ref:`ndarray`):
            Variable holding a floating-point matrix (2-dimensional array)
            where the first dimension corresponds to the batches.
        reduce (str): Reduction option. Its value must be either
            ``'half_squared_sum'`` or ``'no'``.
            Otherwise, :class:`ValueError` is raised.

    Returns:
        ~chainer.Variable:
            If ``reduce`` is ``'half_squared_sum'``, a variable holding
            the scalar DeCov loss.
            If ``reduce`` is ``'no'``, a variable holding a
            2-dimensional array of shape ``(N, N)``, where ``N`` is the
            number of columns of ``h``.

    .. note::

       See https://arxiv.org/abs/1511.06068 for details.

    """
    return DeCov(reduce)(h)
91