from chainer import backend
from chainer import function_node
import chainer.functions
from chainer.utils import type_check


class Contrastive(function_node.FunctionNode):

    """Contrastive loss function."""

    def __init__(self, margin, reduce='mean'):
        # ``margin`` must be strictly positive: it is the radius inside
        # which dissimilar pairs are penalized.
        if margin <= 0:
            raise ValueError('margin should be positive value.')
        self.margin = margin

        # ``reduce`` selects between a scalar mean loss and per-pair losses.
        if reduce not in ('mean', 'no'):
            raise ValueError(
                'only \'mean\' and \'no\' are valid for \'reduce\', but '
                '\'%s\' is given' % reduce)
        self.reduce = reduce

    def check_type_forward(self, in_types):
        # Validate the three inputs: two float feature matrices of equal
        # shape ``(N, K)`` and an integer label vector of shape ``(N,)``.
        type_check._argname(in_types, ('x0', 'x1', 'y'))

        x0_type, x1_type, y_type = in_types
        type_check.expect(
            x0_type.dtype.kind == 'f',
            x0_type.dtype == x1_type.dtype,
            y_type.dtype.kind == 'i',
            x0_type.shape == x1_type.shape,
            x1_type.shape[0] == y_type.shape[0],
            x1_type.shape[0] > 0,
            x0_type.ndim == 2,
            x1_type.ndim == 2,
            y_type.ndim == 1
        )

    def forward(self, inputs):
        """Compute the contrastive loss on raw arrays (NumPy or CuPy).

        Inputs are retained so ``backward`` can recompute the pairwise
        distances from the same values.
        """
        xp = backend.get_array_module(*inputs)
        self.retain_inputs((0, 1, 2))
        x0, x1, y = inputs

        diff = x0 - x1
        # Squared Euclidean distance d_n^2 of each sample pair.
        dist_sq = xp.sum(diff ** 2, axis=1)
        dist = xp.sqrt(dist_sq)
        mdist = self.margin - dist
        # NOTE: ``dist`` is reused here; from this point it holds
        # max(margin - d_n, 0), the hinged term for dissimilar pairs.
        dist = xp.maximum(mdist, 0)
        # L_n = 0.5 * (y_n * d_n^2 + (1 - y_n) * max(margin - d_n, 0)^2)
        loss = (y * dist_sq + (1 - y) * dist * dist) * .5
        if self.reduce == 'mean':
            loss = xp.sum(loss) / x0.shape[0]
        return xp.array(loss, dtype=x0.dtype),

    def backward(self, indexes, grad_outputs):
        """Compute gradients w.r.t. ``x0`` and ``x1`` (none for ``y``).

        Uses chainer.functions (not raw xp ops) so the returned gradients
        stay connected to the graph and support double backprop.
        """
        x0, x1, y = self.get_retained_inputs()
        gy, = grad_outputs
        xp = backend.get_array_module(gy.data)

        # Recompute intermediate variables as in forward.
        diff = x0 - x1
        dist_sq = chainer.functions.sum(diff ** 2, axis=1)
        dist = chainer.functions.sqrt(dist_sq)
        mdist = self.margin - dist

        y = y.data
        x_dim = x0.shape[1]
        # Broadcast the per-pair label to every feature dimension.
        y = xp.repeat(y[:, None], x_dim, axis=1)
        if self.reduce == 'mean':
            # Forward divided the summed loss by the batch size, so the
            # incoming scalar gradient is scaled the same way.
            alpha = gy / y.shape[0]
        else:
            alpha = gy[:, None]
        alpha = chainer.functions.broadcast_to(alpha, y.shape)
        dist = chainer.functions.repeat(dist[:, None], x_dim, axis=1)
        # Avoid division by zero when d_n == 0. 1e-7 is not a suitable
        # floor for half precision because 1e-7 effectively vanishes in
        # float16, so a larger epsilon (5e-3) is used for that dtype.
        eps = 5e-3 if dist.dtype == 'float16' else 1e-7
        dist = chainer.functions.maximum(
            dist, xp.full(dist.shape, eps, dtype=dist.dtype))
        # Similar pair (y == 1): gradient of 0.5 * d_n^2 is ``diff``.
        gx0 = alpha * y.astype(alpha.dtype) * diff
        # Dissimilar pair (y == 0): gradient of the hinged term,
        # -max(margin - d_n, 0) * diff / d_n, active only inside the margin.
        d = chainer.functions.repeat(mdist[:, None], x_dim, axis=1)
        mdist = chainer.functions.maximum(
            d, xp.zeros(shape=d.shape, dtype=d.dtype))
        gx0 += alpha * (1 - y) * mdist * -(diff / dist)
        gx0 = chainer.functions.cast(gx0, x0.dtype)

        # d loss / d x1 is the exact negative of d loss / d x0;
        # the integer labels receive no gradient.
        return gx0, -gx0, None


def contrastive(x0, x1, y, margin=1, reduce='mean'):
    """Computes contrastive loss.

    It takes a pair of samples and a label as inputs.
    The label is :math:`1` when those samples are similar,
    or :math:`0` when they are dissimilar.

    Let :math:`N` and :math:`K` denote mini-batch size and the dimension
    of input variables, respectively. The shape of both input variables
    ``x0`` and ``x1`` should be ``(N, K)``.
    The loss value of the :math:`n`-th sample pair :math:`L_n` is

    .. math::
        L_n = \\frac{1}{2} \\left( y_n d_n^2
        + (1 - y_n) \\max ({\\rm margin} - d_n, 0)^2 \\right)

    where :math:`d_n = \\| {\\bf x_0}_n - {\\bf x_1}_n \\|_2`,
    :math:`{\\bf x_0}_n` and :math:`{\\bf x_1}_n` are :math:`n`-th
    K-dimensional vectors of ``x0`` and ``x1``.

    The output is a variable whose value depends on the value of
    the option ``reduce``. If it is ``'no'``, it holds the elementwise
    loss values. If it is ``'mean'``, this function takes a mean of
    loss values.

    Args:
        x0 (:class:`~chainer.Variable` or :ref:`ndarray`): The first input
            variable. The shape should be (N, K), where N denotes the
            mini-batch size, and K denotes the dimension of ``x0``.
        x1 (:class:`~chainer.Variable` or :ref:`ndarray`): The second input
            variable. The shape should be the same as ``x0``.
        y (:class:`~chainer.Variable` or :ref:`ndarray`): Labels. All values
            should be 0 or 1. The shape should be ``(N,)``, where N denotes the
            mini-batch size.
        margin (float): A parameter for contrastive loss. It should be positive
            value.
        reduce (str): Reduction option. Its value must be either
            ``'mean'`` or ``'no'``. Otherwise, :class:`ValueError` is raised.

    Returns:
        ~chainer.Variable:
            A variable holding the loss value(s) calculated by the
            above equation.
            If ``reduce`` is ``'no'``, the output variable holds array
            whose shape is same as one of (hence both of) input variables.
            If it is ``'mean'``, the output variable holds a scalar value.

    .. note::
        This cost can be used to train siamese networks. See `Learning a
        Similarity Metric Discriminatively, with Application to Face
        Verification <http://yann.lecun.com/exdb/publis/pdf/chopra-05.pdf>`_
        for details.

    .. admonition:: Example

        >>> x0 = np.array([[-2.0, 3.0, 0.5], [5.0, 2.0, -0.5]]).\
astype(np.float32)
        >>> x1 = np.array([[-1.0, 3.0, 1.0], [3.5, 0.5, -2.0]]).\
astype(np.float32)
        >>> y = np.array([1, 0]).astype(np.int32)
        >>> F.contrastive(x0, x1, y)
        variable(0.3125)
        >>> F.contrastive(x0, x1, y, margin=3.0) # harder penalty
        variable(0.3528857)
        >>> z = F.contrastive(x0, x1, y, reduce='no')
        >>> z.shape
        (2,)
        >>> z.array
        array([0.625, 0.   ], dtype=float32)

    """
    return Contrastive(margin, reduce).apply((x0, x1, y))[0]