1import numpy
2
3from chainer import backend
4from chainer import function_node
5from chainer.utils import type_check
6
7
class Pad(function_node.FunctionNode):

    """Padding of an array.

    Thin :class:`~chainer.FunctionNode` wrapper around :func:`numpy.pad` /
    :func:`cupy.pad`.

    Args:
        pad_width: Padding amounts, in any form accepted by ``numpy.pad``
            (a scalar, ``(before, after)``, ``((before, after),)`` or one
            ``(before, after)`` pair per axis).
        mode (str): Padding mode forwarded unchanged to ``xp.pad``.
        **keywords: Extra keyword arguments forwarded unchanged to
            ``xp.pad`` (e.g. ``constant_values``).

    .. note::
        The backward pass returns the slice of the output gradient that
        covers the original input and discards the gradient flowing into the
        padded cells. That is exact when the padded values do not depend on
        the input (``mode='constant'``). Modes that copy input values into
        the padding (e.g. ``'edge'``, ``'reflect'``) would additionally need
        those contributions accumulated, which this implementation does not
        do.
    """

    def __init__(self, pad_width, mode, **keywords):
        self.mode = mode
        self.keywords = keywords
        # Raw value handed to xp.pad in forward().
        self.pad_width = pad_width
        # Array form used only to locate the input region in backward().
        self.pad_bw = numpy.asarray(pad_width)
        if self.pad_bw.size == 1:
            # A single value means the same before/after amount everywhere.
            self.pad_bw = numpy.repeat(self.pad_bw, 2)

    def check_type_forward(self, in_types):
        # Depending on pad_width and keywords, the input may be unsuitable.
        # numpy.pad / cupy.pad raise in that case, so this method only checks
        # the number of inputs and that the dtype is floating point.
        type_check._argname(in_types, ('x',))
        x_type = in_types[0]
        type_check.expect(x_type.dtype.kind == 'f')

    def forward(self, inputs):
        xp = backend.get_array_module(*inputs)
        return xp.pad(inputs[0], self.pad_width, mode=self.mode,
                      **self.keywords),

    def backward(self, inputs, grad_outputs):
        gy, = grad_outputs
        in_shape = self.inputs[0].shape
        # Expand the stored widths to one (before, after) row per axis using
        # the same broadcasting numpy.pad applies. Tiling only 1-D specs would
        # leave a ``((before, after),)`` spec with a single row, and zip()
        # below would then slice axis 0 only. The result is kept in a local so
        # backward() does not mutate the node.
        pad_bw = numpy.broadcast_to(self.pad_bw, (len(in_shape), 2))
        # Gradient w.r.t. x is the region of gy that the input occupies.
        input_idxs = tuple(
            slice(before, before + dim)
            for dim, (before, _) in zip(in_shape, pad_bw))
        return gy[input_idxs],
41
42
def pad(x, pad_width, mode, **keywords):
    """Pad an input variable.

    Args:
        x (:class:`~chainer.Variable` or :ref:`ndarray`):
            Input data. Its dtype must be a floating-point type.
        pad_width (int or array-like):
            Number of values padded to the edges of each axis, in any form
            accepted by :func:`numpy.pad` (a scalar, ``(before, after)``,
            ``((before, after),)`` or one ``(before, after)`` pair per axis).
        mode (str):
            Specifies how the function fills the periphery of the array.
            The mode is passed to :func:`numpy.pad` or :func:`cupy.pad`.
            If it is ``'constant'``, the input is padded by a constant value
            specified by ``constant_values``.
        **keywords:
            Additional keyword arguments forwarded unchanged to
            :func:`numpy.pad` / :func:`cupy.pad`. For example,
            ``constant_values`` (int or array-like) gives the values used
            to fill the periphery in the ``'constant'`` mode.

    Returns:
        ~chainer.Variable: Output variable.

    """
    node = Pad(pad_width, mode, **keywords)
    padded, = node.apply((x,))
    return padded
64