1import six
2
3from chainer import backend
4from chainer import function_node
5from chainer import utils
6from chainer.utils import type_check
7
8
class Repeat(function_node.FunctionNode):

    """Repeat elements of an array.

    Args:
        repeats (int or tuple of ints): Number of repetitions. A bare int is
            stored as a one-element tuple, so ``self.repeats`` is always a
            tuple after construction.
        axis (int or None): Axis along which to repeat. It is passed to
            ``xp.repeat`` unchanged.

    Raises:
        TypeError: If ``repeats`` is neither an int nor a tuple of ints, or
            if ``axis`` is neither ``None`` nor a non-bool int.
        ValueError: If any element of ``repeats`` is negative.
    """

    def __init__(self, repeats, axis=None):
        int_types = six.integer_types

        if isinstance(repeats, int_types):
            normalized = (repeats,)
        elif isinstance(repeats, tuple) and all(
                isinstance(r, int_types) for r in repeats):
            # NumPy/CuPy accept bools (and tuples of bools) as `repeats`
            # even though this is not explicitly documented, so the check is
            # against `six.integer_types` alone and `bool` is not excluded.
            normalized = repeats
        else:
            raise TypeError('repeats must be int or tuple of ints')

        if any(r < 0 for r in normalized):
            raise ValueError('all elements in repeats must be zero or larger')
        self.repeats = normalized

        # In contrast to `repeats`, a bool `axis` is rejected explicitly:
        # bool is a subclass of int, so the isinstance check alone would
        # let it through.
        axis_ok = axis is None or (
            isinstance(axis, six.integer_types) and
            not isinstance(axis, bool))
        if not axis_ok:
            raise TypeError('axis must be int or None')
        self.axis = axis

    def check_type_forward(self, in_types):
        type_check._argname(in_types, ('x',))

    def forward(self, inputs):
        # The input is retained because backward needs its shape and dtype.
        self.retain_inputs((0,))
        x, = inputs
        xp = backend.get_array_module(x)

        # Workaround for a NumPy 1.9 bug: a one-element sequence given as
        # `repeats` fails to broadcast, so pass the bare scalar instead.
        counts = self.repeats[0] if len(self.repeats) == 1 else self.repeats

        return xp.repeat(x, counts, self.axis),

    def backward(self, indexes, grad_outputs):
        x, = self.get_retained_inputs()
        grad_node = RepeatGrad(self.repeats, self.axis, x.shape, x.dtype)
        return grad_node.apply(grad_outputs)
56
57
class RepeatGrad(function_node.FunctionNode):

    """Gradient of :class:`Repeat`.

    Each input element of :class:`Repeat` was copied ``r`` times into the
    output, so its gradient is the sum of the ``r`` corresponding entries of
    the output gradient.

    Args:
        repeats (tuple of ints): Repeat counts, as stored by :class:`Repeat`
            (always a tuple).
        axis (int or None): Axis along which elements were repeated, or
            ``None`` if the input was repeated in flattened order. A negative
            axis is normalized against ``len(in_shape)``.
        in_shape (tuple of ints): Shape of the input of :class:`Repeat`.
        in_dtype: dtype of the input of :class:`Repeat`; used for the
            zero-initialized gradient buffers.
    """

    def __init__(self, repeats, axis, in_shape, in_dtype):
        self.repeats = repeats
        # Normalize a negative axis so the index arithmetic in ``forward``
        # (``axis + 1`` and ``[slice(None)] * axis``) is well defined.
        if axis is not None and axis < 0:
            axis += len(in_shape)
        self.axis = axis

        self.in_shape = in_shape
        self.in_dtype = in_dtype

    def forward(self, inputs):
        gy, = inputs
        xp = backend.get_array_module(gy)
        repeats = self.repeats
        axis = self.axis
        shape = list(self.in_shape)
        dtype = self.in_dtype

        # An empty output gradient (all repeat counts zero, or an input with
        # a zero-sized dimension) means no input element received any
        # gradient. Check ``size`` rather than ``len``: ``len`` only looks at
        # the first dimension. It would miss e.g. ``repeats=0`` along
        # ``axis=1``, and the reshape below would then fail.
        if gy.size == 0:
            return xp.zeros(shape, dtype),

        if len(repeats) == 1:
            # Uniform count: every input element owns a contiguous run of
            # ``r`` output elements, so a reshape followed by a sum suffices.
            r = int(repeats[0])
            if axis is None:
                gx = gy.reshape(-1, r).sum(axis=1).reshape(shape)
            else:
                # Split the repeated axis into (original length, r) and
                # reduce over the trailing part. Use the explicit length
                # rather than -1 so the reshape is never ambiguous.
                shape[axis:axis + 1] = [shape[axis], r]
                gx = gy.reshape(shape).sum(axis=axis + 1)
            return gx,

        if axis is None:
            # Per-element counts over the flattened input: element ``i``
            # owns ``gy[pos:pos + r]``.
            gx = xp.zeros(utils.size_of_shape(shape), dtype)
            pos = 0
            for i, r in enumerate(repeats):
                gx[i] = xp.sum(gy[pos:pos + r])
                pos += r
            gx = gx.reshape(shape)
        else:
            # Per-index counts along ``axis``: input slice ``i`` owns the
            # ``r`` consecutive slices of ``gy`` starting at ``pos``.
            gx = xp.zeros(shape, dtype)
            lead = [slice(None)] * axis
            pos = 0
            for i, r in enumerate(repeats):
                src = tuple(lead + [slice(pos, pos + r)])
                dst = tuple(lead + [slice(i, i + 1)])
                gx[dst] = gy[src].sum(axis=axis, keepdims=True)
                pos += r
        return gx,

    def backward(self, indexes, grad_outputs):
        # Summing copies is linear; its gradient repeats the incoming
        # gradient again with the same counts and (normalized) axis.
        return Repeat(self.repeats, self.axis).apply(grad_outputs)
111
112
def repeat(x, repeats, axis=None):
    """Construct an array by repeating a given array.

    Args:
        x (:class:`~chainer.Variable` or :ref:`ndarray`):
            Input variable.
        repeats (:class:`int` or :class:`tuple` of :class:`int` s):
            The number of times each element of ``x`` is repeated. An int
            applies the same count to every element. A tuple gives one count
            per element: per index along ``axis``, or per element of the
            flattened ``x`` when ``axis`` is ``None``. All counts must be
            zero or larger.
        axis (:class:`int` or ``None``):
            The axis along which to repeat values. If ``None`` (the
            default), ``x`` is flattened and a one-dimensional result is
            returned.

    Returns:
        ~chainer.Variable: The repeated output Variable.

    Raises:
        TypeError: If ``repeats`` is not an int or a tuple of ints, or if
            ``axis`` is not an int (bool excluded) or ``None``.
        ValueError: If any element of ``repeats`` is negative.

    .. admonition:: Example

        >>> x = np.array([0, 1, 2])
        >>> x.shape
        (3,)
        >>> y = F.repeat(x, 2)
        >>> y.shape
        (6,)
        >>> y.array
        array([0, 0, 1, 1, 2, 2])
        >>> x = np.array([[1,2], [3,4]])
        >>> x.shape
        (2, 2)
        >>> y = F.repeat(x, 3, axis=1)
        >>> y.shape
        (2, 6)
        >>> y.array
        array([[1, 1, 1, 2, 2, 2],
               [3, 3, 3, 4, 4, 4]])
        >>> y = F.repeat(x, (1, 2), axis=0)
        >>> y.shape
        (3, 2)
        >>> y.array
        array([[1, 2],
               [3, 4],
               [3, 4]])

    """
    node = Repeat(repeats, axis)
    outputs = node.apply((x,))
    return outputs[0]
156