1from chainer import backend
2from chainer import function_node
3import chainer.functions
4from chainer.utils import type_check
5
6
class Contrastive(function_node.FunctionNode):

    """Contrastive loss function.

    For each sample pair the loss is
    ``0.5 * (y * d**2 + (1 - y) * max(margin - d, 0)**2)``, where ``d`` is
    the Euclidean distance between the two input vectors and ``y`` is 1 for
    a similar pair and 0 for a dissimilar pair.
    """

    def __init__(self, margin, reduce='mean'):
        # ``margin`` is the distance below which dissimilar pairs are
        # penalized; it must be strictly positive.
        if margin <= 0:
            raise ValueError('margin should be positive value.')
        self.margin = margin

        # 'mean' averages the per-pair losses over the batch;
        # 'no' keeps the per-pair loss vector as-is.
        if reduce not in ('mean', 'no'):
            raise ValueError(
                'only \'mean\' and \'no\' are valid for \'reduce\', but '
                '\'%s\' is given' % reduce)
        self.reduce = reduce

    def check_type_forward(self, in_types):
        """Check input types: ``x0``/``x1`` float ``(N, K)``, ``y`` int ``(N,)``."""
        type_check._argname(in_types, ('x0', 'x1', 'y'))

        x0_type, x1_type, y_type = in_types
        type_check.expect(
            x0_type.dtype.kind == 'f',
            x0_type.dtype == x1_type.dtype,
            y_type.dtype.kind == 'i',
            x0_type.shape == x1_type.shape,
            x1_type.shape[0] == y_type.shape[0],
            x1_type.shape[0] > 0,
            x0_type.ndim == 2,
            x1_type.ndim == 2,
            y_type.ndim == 1
        )

    def forward(self, inputs):
        """Compute the per-pair (or batch-mean) contrastive loss.

        All three inputs are retained because ``backward`` recomputes the
        distances from them.
        """
        xp = backend.get_array_module(*inputs)
        self.retain_inputs((0, 1, 2))
        x0, x1, y = inputs

        diff = x0 - x1
        # Squared Euclidean distance per pair, shape (N,).
        dist_sq = xp.sum(diff ** 2, axis=1)
        dist = xp.sqrt(dist_sq)
        mdist = self.margin - dist
        # Hinge: a dissimilar pair contributes only while dist < margin.
        # Note ``dist`` is reused here to hold max(margin - dist, 0).
        dist = xp.maximum(mdist, 0)
        loss = (y * dist_sq + (1 - y) * dist * dist) * .5
        if self.reduce == 'mean':
            loss = xp.sum(loss) / x0.shape[0]
        # Cast back to the input dtype; the integer ``y`` in the arithmetic
        # above may have promoted ``loss`` to a wider float dtype.
        return xp.array(loss, dtype=x0.dtype),

    def backward(self, indexes, grad_outputs):
        """Compute gradients w.r.t. ``x0`` and ``x1``; ``y`` gets none.

        Built from chainer.functions ops so double backprop works through
        the recomputed graph.
        """
        x0, x1, y = self.get_retained_inputs()
        gy, = grad_outputs
        xp = backend.get_array_module(gy.data)

        # Recompute intermediate variables as in forward.
        diff = x0 - x1
        dist_sq = chainer.functions.sum(diff ** 2, axis=1)
        dist = chainer.functions.sqrt(dist_sq)
        mdist = self.margin - dist

        y = y.data
        x_dim = x0.shape[1]
        # Tile labels to (N, K) so they broadcast against per-feature grads.
        y = xp.repeat(y[:, None], x_dim, axis=1)
        if self.reduce == 'mean':
            # gy is a scalar here; fold the 1/N of the mean into alpha.
            alpha = gy / y.shape[0]
        else:
            # gy has shape (N,); add a feature axis for broadcasting.
            alpha = gy[:, None]
        alpha = chainer.functions.broadcast_to(alpha, y.shape)
        dist = chainer.functions.repeat(dist[:, None], x_dim, axis=1)
        # Avoid division by zero. 1e-7 is unusable for float16 because the
        # clamped quotient diff / eps can reach ~1e7, which overflows half
        # precision, so a larger epsilon is used for that dtype.
        eps = 5e-3 if dist.dtype == 'float16' else 1e-7
        dist = chainer.functions.maximum(
            dist, xp.full(dist.shape, eps, dtype=dist.dtype))
        # similar pair: d/dx0 of 0.5 * d^2 is (x0 - x1)
        gx0 = alpha * y.astype(alpha.dtype) * diff
        # dissimilar pair: hinge max(margin - d, 0) scaled by -diff/d
        d = chainer.functions.repeat(mdist[:, None], x_dim, axis=1)
        mdist = chainer.functions.maximum(
            d, xp.zeros(shape=d.shape, dtype=d.dtype))
        gx0 += alpha * (1 - y) * mdist * -(diff / dist)
        gx0 = chainer.functions.cast(gx0, x0.dtype)

        # x1's gradient is the exact negation of x0's; labels are not
        # differentiable.
        return gx0, -gx0, None
88
89
def contrastive(x0, x1, y, margin=1, reduce='mean'):
    r"""Computes contrastive loss.

    The inputs are a pair of samples and a label: the label is :math:`1`
    when the two samples are similar and :math:`0` when they are
    dissimilar.

    With :math:`N` the mini-batch size and :math:`K` the dimensionality of
    the inputs, both ``x0`` and ``x1`` have shape ``(N, K)``. The loss for
    the :math:`n`-th pair is

    .. math::
        L_n = \frac{1}{2} \left( y_n d_n^2
        + (1 - y_n) \max ({\rm margin} - d_n, 0)^2 \right)

    with :math:`d_n = \| {\bf x_0}_n - {\bf x_1}_n \|_2`, where
    :math:`{\bf x_0}_n` and :math:`{\bf x_1}_n` denote the :math:`n`-th
    K-dimensional rows of ``x0`` and ``x1``.

    The ``reduce`` option controls the output: with ``'no'`` the
    elementwise loss values are returned, while ``'mean'`` averages them
    into a single scalar.

    Args:
        x0 (:class:`~chainer.Variable` or :ref:`ndarray`): The first input
            variable. The shape should be (N, K), where N denotes the
            mini-batch size, and K denotes the dimension of ``x0``.
        x1 (:class:`~chainer.Variable` or :ref:`ndarray`): The second input
            variable. The shape should be the same as ``x0``.
        y (:class:`~chainer.Variable` or :ref:`ndarray`): Labels. All values
            should be 0 or 1. The shape should be ``(N,)``, where N denotes the
            mini-batch size.
        margin (float): A parameter for contrastive loss. It should be positive
            value.
        reduce (str): Reduction option. Its value must be either
            ``'mean'`` or ``'no'``. Otherwise, :class:`ValueError` is raised.

    Returns:
        ~chainer.Variable:
            A variable holding the loss value(s) computed by the formula
            above. With ``reduce='no'`` the array has the same shape as
            either input variable; with ``'mean'`` it holds a scalar.

    .. note::
        This cost can be used to train siamese networks. See `Learning a
        Similarity Metric Discriminatively, with Application to Face
        Verification <http://yann.lecun.com/exdb/publis/pdf/chopra-05.pdf>`_
        for details.

    .. admonition:: Example

        >>> x0 = np.array([[-2.0, 3.0, 0.5], [5.0, 2.0, -0.5]]).\
astype(np.float32)
        >>> x1 = np.array([[-1.0, 3.0, 1.0], [3.5, 0.5, -2.0]]).\
astype(np.float32)
        >>> y = np.array([1, 0]).astype(np.int32)
        >>> F.contrastive(x0, x1, y)
        variable(0.3125)
        >>> F.contrastive(x0, x1, y, margin=3.0)  # harder penalty
        variable(0.3528857)
        >>> z = F.contrastive(x0, x1, y, reduce='no')
        >>> z.shape
        (2,)
        >>> z.array
        array([0.625, 0.   ], dtype=float32)

    """
    loss, = Contrastive(margin, reduce).apply((x0, x1, y))
    return loss
162