1from chainer import backend
2from chainer import function_node
3import chainer.functions
4from chainer import utils
5from chainer.utils import type_check
6
7
class Fmod(function_node.FunctionNode):
    """Function node computing the elementwise ``fmod`` of two arrays.

    The forward pass delegates to the ``fmod`` routine of the array module
    that owns the inputs. Both inputs are retained because the gradient
    with respect to the divisor depends on ``x / divisor``.
    """

    @property
    def label(self):
        """Short name identifying this node."""
        return 'fmod'

    def check_type_forward(self, in_types):
        """Require two floating-point inputs sharing a single dtype.

        Only dtypes are validated here; shapes are not compared.
        """
        type_check._argname(in_types, ('x', 'divisor'))
        x_type = in_types[0]
        divisor_type = in_types[1]
        type_check.expect(
            x_type.dtype == divisor_type.dtype,
            x_type.dtype.kind == 'f',
            divisor_type.dtype.kind == 'f',
        )

    def forward(self, inputs):
        """Compute ``fmod(x, divisor)`` and return it as a 1-tuple."""
        # Keep both inputs: backward needs the quotient x / divisor.
        self.retain_inputs((0, 1))
        x, divisor = inputs
        xp = backend.get_array_module(x, divisor)
        remainder = xp.fmod(x, divisor)
        # force_array guarantees an array of x's dtype is returned.
        return (utils.force_array(remainder, x.dtype),)

    def backward(self, indexes, grad_outputs):
        """Return the gradients with respect to ``x`` and ``divisor``.

        The upstream gradient is passed through unchanged for ``x``; for
        ``divisor`` it is scaled by ``-fix(x / divisor)``. Both gradients
        are always computed; ``indexes`` is not consulted.
        """
        x, divisor = self.get_retained_inputs()
        (gy,) = grad_outputs
        truncated_quotient = chainer.functions.fix(x / divisor)
        return gy, -truncated_quotient * gy
33
34
def fmod(x, divisor):
    """Elementwise remainder of ``x`` divided by ``divisor``.

    .. math::
       y_i = x_i \\bmod \\mathrm{divisor}.

    The computation is carried out by :class:`Fmod`, which requires both
    inputs to be floating point and to share the same dtype.

    Args:
        x (:class:`~chainer.Variable` or :ref:`ndarray`): Input variable
            (the dividend).
        divisor (:class:`~chainer.Variable` or :ref:`ndarray`): Input
            divisor.

    Returns:
        ~chainer.Variable: Output variable holding the remainder.
    """
    # apply() returns a tuple of outputs; this node produces exactly one.
    outputs = Fmod().apply((x, divisor))
    return outputs[0]
48