from chainer import backend
from chainer.backends import cuda
from chainer import function_node
import chainer.functions
from chainer import utils
from chainer.utils import type_check


class Minimum(function_node.FunctionNode):
    """Function node computing the element-wise minimum of two arrays.

    Both inputs must have the same floating-point dtype, and their shapes
    must be broadcastable against each other.
    """

    def check_type_forward(self, in_types):
        type_check._argname(in_types, ('x1', 'x2'))
        lhs_type = in_types[0]
        rhs_type = in_types[1]

        # Floating-point only, and both operands must agree on the dtype.
        type_check.expect(
            lhs_type.dtype.kind == 'f',
            lhs_type.dtype == rhs_type.dtype,
        )
        # Shapes need not match exactly; they only have to broadcast.
        type_check.expect_broadcast_shapes(lhs_type.shape, rhs_type.shape)

    def forward(self, inputs):
        # Keep both operands: backward compares them to decide which input
        # each gradient element belongs to. The result may be broadcast.
        self.retain_inputs((0, 1))
        lhs, rhs = inputs
        array_module = backend.get_array_module(lhs, rhs)
        smaller = array_module.minimum(lhs, rhs)
        return utils.force_array(smaller),

    def backward(self, indexes, grad_outputs):
        lhs_var, rhs_var = self.get_retained_inputs()
        grad_node = MinimumGrad(lhs_var.data, rhs_var.data)
        return grad_node.apply((grad_outputs[0],))


class MinimumGrad(function_node.FunctionNode):
    """Gradient of :class:`Minimum` with respect to both of its inputs.

    Given the upstream gradient ``gy``, the forward pass returns
    ``(gx1, gx2)``: ``gx1`` carries ``gy`` where ``x1 <= x2`` and zero
    otherwise, ``gx2`` carries ``gy`` where ``x1 > x2`` and zero otherwise.
    Each result is reduced with ``utils.sum_to`` to the shape of the
    corresponding input.
    """

    def __init__(self, x1, x2):
        # The arrays that were compared in the forward pass of Minimum
        # (Minimum.backward passes the ``.data`` of its retained inputs).
        self.x1, self.x2 = x1, x2

    def forward_cpu(self, inputs):
        upstream, = inputs
        lhs, rhs = self.x1, self.x2

        # Multiplying by the boolean mask zeroes the entries that belong
        # to the other operand.
        grad_lhs = utils.force_array(upstream * (lhs <= rhs))
        grad_rhs = utils.force_array(upstream * (lhs > rhs))

        return (
            utils.sum_to(grad_lhs, lhs.shape),
            utils.sum_to(grad_rhs, rhs.shape),
        )

    def forward_gpu(self, inputs):
        upstream, = inputs
        lhs, rhs = self.x1, self.x2

        # Same selection as forward_cpu, expressed as element-wise kernels.
        lhs_kernel = cuda.elementwise(
            'T x1, T x2, T gy', 'T gx1',
            'gx1 = (x1 <= x2) ? gy : (T)0.0',
            'minimum_bwd1')
        grad_lhs = lhs_kernel(lhs, rhs, upstream)

        rhs_kernel = cuda.elementwise(
            'T x1, T x2, T gy', 'T gx1',
            'gx1 = (x1 > x2) ? gy : (T)0.0',
            'minimum_bwd2')
        grad_rhs = rhs_kernel(lhs, rhs, upstream)

        return (
            utils.sum_to(grad_lhs, lhs.shape),
            utils.sum_to(grad_rhs, rhs.shape),
        )

    def backward(self, indexes, grad_outputs):
        # Where x1 <= x2 the first output received gy, so the gradient
        # w.r.t. gy is taken from the first output's gradient there and
        # from the second output's gradient everywhere else.
        takes_lhs = utils.force_array(self.x1 <= self.x2)
        gg_lhs, gg_rhs = grad_outputs[0], grad_outputs[1]
        return chainer.functions.where(takes_lhs, gg_lhs, gg_rhs),


def minimum(x1, x2):
    """Compute the element-wise minimum of two inputs.

    The inputs must have the same floating-point dtype and shapes that can
    be broadcast against each other.

    Args:
        x1 (:class:`~chainer.Variable` or :ref:`ndarray`):
            First input to compare.
        x2 (:class:`~chainer.Variable` or :ref:`ndarray`):
            Second input to compare.

    Returns:
        ~chainer.Variable: Output variable.
    """
    outputs = Minimum().apply((x1, x2))
    return outputs[0]