1import chainer
2from chainer import backend
3from chainer import function_node
4from chainer import utils
5from chainer.utils import type_check
6
7
class CrossCovariance(function_node.FunctionNode):

    """Cross-covariance loss.

    Given two batches ``y`` of shape ``(n, M)`` and ``z`` of shape
    ``(n, N)``, computes the cross-covariance matrix
    ``C = y_centered^T z_centered / n`` (shape ``(M, N)``) and, depending
    on ``reduce``, returns either ``C`` itself (``'no'``) or the scalar
    ``0.5 * ||C||_F^2`` (``'half_squared_sum'``).
    """

    def __init__(self, reduce='half_squared_sum'):
        # Validate the reduction mode eagerly so misuse fails at
        # construction time rather than inside forward().
        # NOTE: the old placeholder attributes (y_centered, z_centered,
        # covariance) were dead state — never read or written by any
        # method of this FunctionNode — and have been removed.
        if reduce not in ('half_squared_sum', 'no'):
            raise ValueError(
                'Only \'half_squared_sum\' and \'no\' are valid '
                'for \'reduce\', but \'%s\' is given' % reduce)
        self.reduce = reduce

    def check_type_forward(self, in_types):
        # Both inputs must be 2-D float matrices of the same dtype with
        # matching batch (first) dimension.
        type_check._argname(in_types, ('y', 'z'))
        y_type, z_type = in_types

        type_check.expect(
            y_type.dtype.kind == 'f',
            y_type.dtype == z_type.dtype,
            y_type.ndim == 2,
            z_type.ndim == 2,
            y_type.shape[0] == z_type.shape[0]
        )

    def forward(self, inputs):
        y, z = inputs
        # The raw inputs are needed again in backward() to rebuild the
        # centered matrices, so retain them instead of intermediates.
        self.retain_inputs((0, 1))

        # Center each column (feature) over the batch dimension.
        y_centered = y - y.mean(axis=0, keepdims=True)
        z_centered = z - z.mean(axis=0, keepdims=True)
        # (M, N) cross-covariance matrix, normalized by the batch size.
        covariance = y_centered.T.dot(z_centered)
        covariance /= len(y)

        if self.reduce == 'half_squared_sum':
            xp = backend.get_array_module(*inputs)
            # vdot flattens both operands, so this is sum(C * C), i.e.
            # the squared Frobenius norm of the covariance matrix.
            cost = xp.vdot(covariance, covariance)
            cost *= y.dtype.type(0.5)
            return utils.force_array(cost),
        else:
            return covariance,

    def backward(self, indexes, grad_outputs):
        y, z = self.get_retained_inputs()
        gcost, = grad_outputs

        # Rebuild the centered matrices with chainer functions (not raw
        # arrays) so the returned gradients support double backprop.
        y_mean = chainer.functions.mean(y, axis=0, keepdims=True)
        z_mean = chainer.functions.mean(z, axis=0, keepdims=True)
        y_centered = y - chainer.functions.broadcast_to(y_mean, y.shape)
        z_centered = z - chainer.functions.broadcast_to(z_mean, z.shape)
        gcost_div_n = gcost / gcost.dtype.type(len(y))

        ret = []
        if self.reduce == 'half_squared_sum':
            # d(0.5 * ||C||_F^2)/dC = C, so the input gradients are
            # gy = z_centered @ C^T * gcost / n and
            # gz = y_centered @ C   * gcost / n.
            covariance = chainer.functions.matmul(y_centered.T, z_centered)
            covariance /= len(y)
            if 0 in indexes:
                gy = chainer.functions.matmul(z_centered, covariance.T)
                gy *= chainer.functions.broadcast_to(gcost_div_n, gy.shape)
                ret.append(gy)
            if 1 in indexes:
                gz = chainer.functions.matmul(y_centered, covariance)
                gz *= chainer.functions.broadcast_to(gcost_div_n, gz.shape)
                ret.append(gz)
        else:
            # reduce == 'no': gcost is the full (M, N) matrix gradient.
            if 0 in indexes:
                gy = chainer.functions.matmul(z_centered, gcost_div_n.T)
                ret.append(gy)
            if 1 in indexes:
                gz = chainer.functions.matmul(y_centered, gcost_div_n)
                ret.append(gz)
        return ret
82
83
def cross_covariance(y, z, reduce='half_squared_sum'):
    """Computes the sum-squared cross-covariance penalty between ``y`` and ``z``

    The output is a variable whose value depends on the value of
    the option ``reduce``. If it is ``'no'``, it holds the covariance
    matrix that has as many rows (resp. columns) as the dimension of
    ``y`` (resp. ``z``).
    If it is ``'half_squared_sum'``, it holds the half of the
    Frobenius norm (i.e. L2 norm of a matrix flattened to a vector)
    of the covariance matrix.

    Args:
        y (:class:`~chainer.Variable` or :ref:`ndarray`):
            Variable holding a matrix where the first dimension
            corresponds to the batches.
        z (:class:`~chainer.Variable` or :ref:`ndarray`):
            Variable holding a matrix where the first dimension
            corresponds to the batches.
        reduce (str): Reduction option. Its value must be either
            ``'half_squared_sum'`` or ``'no'``.
            Otherwise, :class:`ValueError` is raised.

    Returns:
        ~chainer.Variable:
            A variable holding the cross covariance loss.
            If ``reduce`` is ``'no'``, the output variable holds
            2-dimensional array matrix of shape ``(M, N)`` where
            ``M`` (resp. ``N``) is the number of columns of ``y``
            (resp. ``z``).
            If it is ``'half_squared_sum'``, the output variable
            holds a scalar value.

    .. note::

       This cost can be used to disentangle variables.
       See https://arxiv.org/abs/1412.6583v3 for details.

    """
    return CrossCovariance(reduce).apply((y, z))[0]
123