1from chainer import backend
2from chainer.backends import cuda
3from chainer import function_node
4import chainer.functions
5from chainer import utils
6from chainer.utils import type_check
7
8
class Minimum(function_node.FunctionNode):
    """Element-wise minimum of two input variables.

    Both inputs must have the same floating-point dtype, and their shapes
    must be broadcast-compatible.
    """

    def check_type_forward(self, in_types):
        type_check._argname(in_types, ('x1', 'x2'))
        lhs_type, rhs_type = in_types
        type_check.expect(
            lhs_type.dtype.kind == 'f',
            lhs_type.dtype == rhs_type.dtype,
        )
        type_check.expect_broadcast_shapes(lhs_type.shape, rhs_type.shape)

    def forward(self, inputs):
        # Keep both operands: backward compares them again to decide which
        # one receives the gradient. They may have different (broadcastable)
        # shapes.
        self.retain_inputs((0, 1))
        lhs, rhs = inputs
        xp = backend.get_array_module(lhs, rhs)
        smaller = xp.minimum(lhs, rhs)
        return utils.force_array(smaller),

    def backward(self, indexes, grad_outputs):
        lhs, rhs = self.get_retained_inputs()
        grad_node = MinimumGrad(lhs.data, rhs.data)
        return grad_node.apply((grad_outputs[0],))
31
32
class MinimumGrad(function_node.FunctionNode):
    """Gradient of :class:`Minimum` with respect to both of its inputs.

    The incoming gradient ``gy`` is routed to ``x1`` where ``x1 <= x2`` and
    to ``x2`` where ``x1 > x2`` (ties go to ``x1``). Each result is then
    reduced with ``sum_to`` so that it matches the shape of its input.

    Args:
        x1: Raw array of the first input of the forward computation.
        x2: Raw array of the second input of the forward computation.
    """

    def __init__(self, x1, x2):
        self.x1 = x1
        self.x2 = x2

    def forward_cpu(self, inputs):
        gy, = inputs
        x1, x2 = self.x1, self.x2
        # Multiplying by a boolean mask zeroes the gradient wherever the other
        # operand was selected. The products have the broadcast shape, so
        # they are summed back to each input's shape.
        gx1 = utils.force_array(gy * (x1 <= x2))
        gx2 = utils.force_array(gy * (x1 > x2))
        return utils.sum_to(gx1, x1.shape), utils.sum_to(gx2, x2.shape)

    def forward_gpu(self, inputs):
        gy, = inputs
        x1, x2 = self.x1, self.x2
        gx1 = cuda.elementwise(
            'T x1, T x2, T gy', 'T gx1',
            'gx1 = (x1 <= x2) ? gy : (T)0.0',
            'minimum_bwd1')(x1, x2, gy)
        gx2 = cuda.elementwise(
            'T x1, T x2, T gy', 'T gx2',
            'gx2 = (x1 > x2) ? gy : (T)0.0',
            'minimum_bwd2')(x1, x2, gy)
        return utils.sum_to(gx1, x1.shape), utils.sum_to(gx2, x2.shape)

    def backward(self, indexes, grad_outputs):
        ggx1, ggx2 = grad_outputs
        if ggx1 is None and ggx2 is None:
            # No gradient reaches either output, so none flows back to gy.
            return None,
        x1, x2 = self.x1, self.x2
        # Same selection rule as the forward pass: x1 wins ties.
        cond = utils.force_array(x1 <= x2)
        if ggx1 is None or ggx2 is None:
            # Chainer passes None for an output whose gradient was not given,
            # e.g. when only one of gx1/gx2 was used downstream. Such an
            # output contributes nothing, so substitute a zero scalar of the
            # same dtype; F.where broadcasts it against ``cond``.
            given = ggx2 if ggx1 is None else ggx1
            xp = backend.get_array_module(x1, x2)
            zero = xp.zeros((), dtype=given.dtype)
            if ggx1 is None:
                ggx1 = zero
            else:
                ggx2 = zero
        ggy = chainer.functions.where(cond, ggx1, ggx2)
        return ggy,
64
65
def minimum(x1, x2):
    """Element-wise minimum of input variables.

    The two inputs must share the same floating-point dtype, and their
    shapes must be broadcast-compatible. Where ``x1`` and ``x2`` are equal,
    the gradient is propagated to ``x1``.

    Args:
        x1 (:class:`~chainer.Variable` or :ref:`ndarray`):
            First input variable to be compared.
        x2 (:class:`~chainer.Variable` or :ref:`ndarray`):
            Second input variable to be compared.

    Returns:
        ~chainer.Variable: Element-wise minimum of ``x1`` and ``x2``, with
        the broadcast shape of the inputs.
    """
    y, = Minimum().apply((x1, x2))
    return y
79