1import six
2
3import chainer
4from chainer import backend
5from chainer import function_node
6from chainer.utils import type_check
7
8
class Tile(function_node.FunctionNode):

    """Tiling of an array.

    The forward pass delegates to ``tile`` of the array module returned by
    :func:`chainer.backend.get_array_module`; the backward pass sums the
    replicated copies of the gradient back onto the input shape.

    Args:
        reps (int or tuple of ints): Number of repetitions along each axis.
            A bare int is wrapped into a one-element tuple. Every entry must
            be zero or larger.

    """

    def __init__(self, reps):
        int_types = six.integer_types
        if isinstance(reps, int_types):
            normalized = (reps,)
        elif isinstance(reps, tuple) and all(
                isinstance(r, int_types) for r in reps):
            normalized = reps
        else:
            raise TypeError(
                'reps must be int or tuple of ints.\n'
                'Actual: {0}'.format(type(reps)))
        self.reps = normalized

        # Negative repetition counts are meaningless for tiling.
        if any(r < 0 for r in self.reps):
            raise ValueError('All elements in reps must be zero or larger')

    def check_type_forward(self, in_types):
        # Exactly one input (the array to be tiled) is accepted.
        type_check.expect(in_types.size() == 1)

    def forward(self, inputs):
        x = inputs[0]
        # Remember the input shape; backward needs it to undo the tiling.
        self._in_shape = x.shape
        xp = backend.get_array_module(*inputs)
        y = xp.tile(x, self.reps)
        return y,

    def backward(self, indexes, grad_outputs):
        in_shape = tuple(self._in_shape)
        reps = self.reps

        # Left-pad whichever of (reps, input shape) is shorter with ones so
        # that both describe every axis of the output gradient.
        extra = len(reps) - len(in_shape)
        if extra < 0:
            reps = (1,) * (-extra) + reps
        elif extra > 0:
            in_shape = (1,) * extra + in_shape

        (gy,) = grad_outputs

        # Split each output axis k into a (reps[k], in_shape[k]) pair so the
        # repetition index and the original index become separate axes, then
        # sum out every repetition axis (the even-numbered ones).
        split_shape = tuple(
            dim for k in range(gy.ndim) for dim in (reps[k], in_shape[k]))
        rep_axes = tuple(range(0, 2 * gy.ndim, 2))
        gx = chainer.functions.sum(gy.reshape(split_shape), axis=rep_axes)

        if extra > 0:
            # Drop the leading singleton axes introduced by the padding above.
            gx = gx.reshape(self._in_shape)
        return gx,
64
65
def tile(x, reps):
    """Construct an array by tiling a given array.

    Args:
        x (:class:`~chainer.Variable` or :ref:`ndarray`):
            Input variable. Let the length of ``reps`` be ``d``. If
            ``x.ndim < d``, ``x`` is treated as a ``d``-dimensional array by
            prepending new axes. For example, when the shape of ``x`` is
            ``(2,)`` and it is tiled with 2-dim repetitions, ``x`` is treated
            as having the shape ``(1, 2)``. If ``x.ndim > d``, ``reps`` is
            treated as ``x.ndim``-dimensional by prepending 1's. For example,
            when the shape of ``x`` is ``(2, 3, 2, 3)``, the 2-dim ``reps`` of
            ``(2, 2)`` is treated as ``(1, 1, 2, 2)``.
        reps (:class:`int` or :class:`tuple` of :class:`int` s):
            The number of times ``x`` is replicated along each axis. A single
            :class:`int` is treated as a tuple of length one. Every element
            must be zero or larger.

    Returns:
        ~chainer.Variable: The tiled output variable. Its number of
        dimensions is ``max(d, x.ndim)``, where ``d`` is the length of
        ``reps``.

    Raises:
        TypeError: If ``reps`` is neither an :class:`int` nor a
            :class:`tuple` of :class:`int` s.
        ValueError: If any element of ``reps`` is negative.

    .. admonition:: Example

        >>> x = np.array([0, 1, 2])
        >>> x.shape
        (3,)
        >>> y = F.tile(x, 2)
        >>> y.shape
        (6,)
        >>> y.array
        array([0, 1, 2, 0, 1, 2])
        >>> y = F.tile(x, (2, 2))
        >>> y.shape
        (2, 6)
        >>> y.array
        array([[0, 1, 2, 0, 1, 2],
               [0, 1, 2, 0, 1, 2]])
        >>> y = F.tile(x, (2, 1, 2))
        >>> y.shape
        (2, 1, 6)
        >>> y.array
        array([[[0, 1, 2, 0, 1, 2]],
        <BLANKLINE>
               [[0, 1, 2, 0, 1, 2]]])

        >>> x = np.array([[1, 2], [3, 4]])
        >>> x.shape
        (2, 2)
        >>> y = F.tile(x, 2)
        >>> y.shape
        (2, 4)
        >>> y.array
        array([[1, 2, 1, 2],
               [3, 4, 3, 4]])
        >>> y = F.tile(x, (2, 2))
        >>> y.shape
        (4, 4)
        >>> y.array
        array([[1, 2, 1, 2],
               [3, 4, 3, 4],
               [1, 2, 1, 2],
               [3, 4, 3, 4]])
        >>> y = F.tile(x, (2, 1, 2))
        >>> y.shape
        (2, 2, 4)
        >>> y.array
        array([[[1, 2, 1, 2],
                [3, 4, 3, 4]],
        <BLANKLINE>
               [[1, 2, 1, 2],
                [3, 4, 3, 4]]])

    """
    # Tile validates ``reps`` on construction; apply returns a 1-tuple.
    return Tile(reps).apply((x,))[0]
141