import chainer
from chainer import backend
from chainer import function_node
from chainer import utils
from chainer.utils import type_check


class CrossCovariance(function_node.FunctionNode):

    """Cross-covariance loss.

    Both inputs are 2-D arrays sharing their first (batch) axis: ``y`` of
    shape ``(B, M)`` and ``z`` of shape ``(B, N)``. Each input is centered
    by subtracting its mean over axis 0, and the ``(M, N)`` cross-covariance
    matrix ``C = y_c^T z_c / B`` is formed. With
    ``reduce='half_squared_sum'`` the output is the scalar
    ``0.5 * sum(C ** 2)``; with ``reduce='no'`` the output is ``C`` itself.
    """

    def __init__(self, reduce='half_squared_sum'):
        # forward/backward only use local variables; these attributes are
        # never populated. Presumably kept so external code that reads them
        # does not break.
        self.y_centered = None
        self.z_centered = None
        self.covariance = None

        if reduce not in ('half_squared_sum', 'no'):
            # ``reduce`` is wrapped in a 1-tuple so that e.g. a tuple
            # argument still produces this ValueError rather than a
            # TypeError from %-formatting.
            raise ValueError(
                'Only \'half_squared_sum\' and \'no\' are valid '
                'for \'reduce\', but \'%s\' is given' % (reduce,))
        self.reduce = reduce

    def check_type_forward(self, in_types):
        type_check._argname(in_types, ('y', 'z'))
        y_type, z_type = in_types

        # Both inputs must be floating-point matrices of the same dtype
        # with the same number of rows (the batch size).
        type_check.expect(
            y_type.dtype.kind == 'f',
            y_type.dtype == z_type.dtype,
            y_type.ndim == 2,
            z_type.ndim == 2,
            y_type.shape[0] == z_type.shape[0]
        )

    def forward(self, inputs):
        """Compute the cross-covariance matrix or half its squared norm.

        Returns:
            tuple: A 1-tuple holding either the scalar
            ``0.5 * sum(C ** 2)`` (``reduce='half_squared_sum'``) or the
            ``(M, N)`` matrix ``C`` (``reduce='no'``).
        """
        y, z = inputs
        # Inputs are needed again in backward to rebuild the centered data.
        self.retain_inputs((0, 1))

        y_centered = y - y.mean(axis=0, keepdims=True)
        z_centered = z - z.mean(axis=0, keepdims=True)
        covariance = y_centered.T.dot(z_centered)
        covariance /= len(y)

        if self.reduce == 'half_squared_sum':
            xp = backend.get_array_module(*inputs)
            # vdot of the matrix with itself is the sum of its squared
            # entries, i.e. the squared Frobenius norm.
            cost = xp.vdot(covariance, covariance)
            cost *= y.dtype.type(0.5)
            return utils.force_array(cost),
        else:
            return covariance,

    def backward(self, indexes, grad_outputs):
        """Gradients w.r.t. ``y`` and/or ``z``, built from chainer functions.

        Because the gradient is expressed with differentiable chainer
        functions of the retained inputs, higher-order gradients are
        supported.

        Note: no explicit term for the mean subtraction is needed.
        ``z_centered`` (resp. ``y_centered``) has zero column sums, so
        ``z_centered @ G`` already has zero column means and is unchanged
        by projecting through the centering operation.
        """
        y, z = self.get_retained_inputs()
        gcost, = grad_outputs

        y_mean = chainer.functions.mean(y, axis=0, keepdims=True)
        z_mean = chainer.functions.mean(z, axis=0, keepdims=True)
        y_centered = y - chainer.functions.broadcast_to(y_mean, y.shape)
        z_centered = z - chainer.functions.broadcast_to(z_mean, z.shape)
        # The 1/B factor from C = y_c^T z_c / B is folded into the upstream
        # gradient once here.
        gcost_div_n = gcost / gcost.dtype.type(len(y))

        ret = []
        if self.reduce == 'half_squared_sum':
            # d(0.5 * ||C||^2)/dC = C, so the upstream matrix gradient is C
            # scaled by the scalar incoming gradient.
            covariance = chainer.functions.matmul(y_centered.T, z_centered)
            covariance /= len(y)
            if 0 in indexes:
                gy = chainer.functions.matmul(z_centered, covariance.T)
                gy *= chainer.functions.broadcast_to(gcost_div_n, gy.shape)
                ret.append(gy)
            if 1 in indexes:
                gz = chainer.functions.matmul(y_centered, covariance)
                gz *= chainer.functions.broadcast_to(gcost_div_n, gz.shape)
                ret.append(gz)
        else:
            # Here gcost is the (M, N) gradient w.r.t. C itself.
            if 0 in indexes:
                gy = chainer.functions.matmul(z_centered, gcost_div_n.T)
                ret.append(gy)
            if 1 in indexes:
                gz = chainer.functions.matmul(y_centered, gcost_div_n)
                ret.append(gz)
        return ret


def cross_covariance(y, z, reduce='half_squared_sum'):
    """Computes the sum-squared cross-covariance penalty of ``y`` and ``z``.

    The output is a variable whose value depends on the value of
    the option ``reduce``. If it is ``'no'``, it holds the covariance
    matrix, which has as many rows (resp. columns) as the dimension of
    ``y`` (resp. ``z``).
    If it is ``'half_squared_sum'``, it holds half of the squared
    Frobenius norm (i.e. half of the squared L2 norm of the matrix
    flattened to a vector) of the covariance matrix.

    Args:
        y (:class:`~chainer.Variable` or :ref:`ndarray`):
            Variable holding a matrix where the first dimension
            corresponds to the batches.
        z (:class:`~chainer.Variable` or :ref:`ndarray`):
            Variable holding a matrix where the first dimension
            corresponds to the batches.
        reduce (str): Reduction option. Its value must be either
            ``'half_squared_sum'`` or ``'no'``.
            Otherwise, :class:`ValueError` is raised.

    Returns:
        ~chainer.Variable:
            A variable holding the cross covariance loss.
            If ``reduce`` is ``'no'``, the output variable holds
            a 2-dimensional array of shape ``(M, N)`` where
            ``M`` (resp. ``N``) is the number of columns of ``y``
            (resp. ``z``).
            If it is ``'half_squared_sum'``, the output variable
            holds a scalar value.

    .. note::

       This cost can be used to disentangle variables.
       See https://arxiv.org/abs/1412.6583v3 for details.

    """
    return CrossCovariance(reduce).apply((y, z))[0]