import six

from chainer import backend
from chainer import function_node
from chainer import utils
from chainer.utils import type_check


class Repeat(function_node.FunctionNode):

    """Repeat elements of an array.

    Args:
        repeats (int or tuple of ints): Number of repetitions for each
            element. A single int is stored as a one-element tuple.
        axis (int or None): Axis along which to repeat. ``None`` repeats
            over the flattened input, following ``xp.repeat``.

    Raises:
        TypeError: If ``repeats`` is not an int or a tuple of ints, or if
            ``axis`` is neither ``None`` nor a (non-bool) int.
        ValueError: If any element of ``repeats`` is negative.
    """

    def __init__(self, repeats, axis=None):
        if isinstance(repeats, six.integer_types):
            self.repeats = (repeats,)
        elif isinstance(repeats, tuple) and all(
                isinstance(x, six.integer_types) for x in repeats):
            # Although it is not explicitly documented, NumPy/CuPy allows
            # specifying bool or tuple of bools as `repeats`.
            # Thus we just check type against `six.integer_types`, without
            # excluding `bool`.
            self.repeats = repeats
        else:
            raise TypeError('repeats must be int or tuple of ints')

        if not all(x >= 0 for x in self.repeats):
            raise ValueError('all elements in repeats must be zero or larger')

        if axis is not None and (
                not isinstance(axis, six.integer_types) or
                isinstance(axis, bool)):
            # `axis` cannot be bool, in contrast to `repeats`.
            raise TypeError('axis must be int or None')
        self.axis = axis

    def check_type_forward(self, in_types):
        type_check._argname(in_types, ('x',))

    def forward(self, inputs):
        # The input is retained because backward needs its shape and dtype.
        self.retain_inputs((0,))
        x, = inputs
        xp = backend.get_array_module(x)
        repeats = self.repeats

        # Workaround for bug in NumPy 1.9 that specifying one element list to
        # `repeats` fails to broadcast.
        if len(repeats) == 1:
            repeats = repeats[0]

        return xp.repeat(x, repeats, self.axis),

    def backward(self, indexes, grad_outputs):
        x, = self.get_retained_inputs()
        return RepeatGrad(self.repeats, self.axis, x.shape, x.dtype).apply(
            grad_outputs)


class RepeatGrad(function_node.FunctionNode):

    """Gradient of :class:`Repeat`.

    Sums each run of repeated output gradients back into the input element
    that produced it. Its own backward is :class:`Repeat` again, which makes
    the pair differentiable any number of times.

    Args:
        repeats (tuple of ints): Repetition counts used by the forward
            :class:`Repeat`.
        axis (int or None): Axis used by the forward :class:`Repeat`. A
            negative axis is normalized against ``len(in_shape)``.
        in_shape (tuple of ints): Shape of the original input ``x``.
        in_dtype: Dtype of the original input ``x``.
    """

    def __init__(self, repeats, axis, in_shape, in_dtype):
        self.repeats = repeats
        self.axis = axis
        if axis is not None and axis < 0:
            self.axis += len(in_shape)

        self.in_shape = in_shape
        self.in_dtype = in_dtype

    def forward(self, inputs):
        gy, = inputs
        xp = backend.get_array_module(gy)
        repeats = self.repeats
        axis = self.axis
        shape = list(self.in_shape)
        dtype = self.in_dtype

        if gy.size == 0:
            # The forward output was empty: either x was empty or the
            # repeat count was zero. No element of x received any gradient.
            # ``size`` is checked rather than ``len`` so that outputs empty
            # only along a non-leading axis (e.g. shape (2, 0)) are also
            # caught. The reshapes below would fail on those, because a
            # ``-1`` dimension cannot be inferred when the total size is 0.
            gx = xp.zeros(shape, dtype)
            return gx,

        if len(repeats) == 1:
            repeats = int(repeats[0])
            if axis is None:
                # Each element of the flattened x appears ``repeats`` times
                # consecutively in gy, so pair them up and sum.
                gx = gy.reshape(-1, repeats).sum(axis=1).reshape(shape)
            else:
                # Split the repeated axis into (original_len, repeats) and
                # sum the repeats dimension away.
                shape[axis:axis + 1] = [-1, repeats]
                gx = gy.reshape(shape).sum(axis=axis + 1)
            return gx,

        if axis is None:
            # One repeat count per element of the flattened x (NumPy requires
            # len(repeats) == x.size in this mode).
            pos = 0
            gx = xp.zeros(utils.size_of_shape(shape), dtype)
            for (i, r) in enumerate(repeats):
                gx[i] = xp.sum(gy[pos:pos + r])
                pos += r
            gx = gx.reshape(shape)
        else:
            gx = xp.zeros(shape, dtype)
            pos = 0
            # Index tuples that select everything on the leading axes. Only
            # the final entry (position ``axis``) is updated per iteration.
            src = [slice(None)] * axis + [None]
            dst = [slice(None)] * axis + [None]
            for (i, r) in enumerate(repeats):
                src[-1] = slice(pos, pos + r)
                dst[-1] = slice(i, i + 1)
                gx[tuple(dst)] = gy[tuple(src)].sum(axis=axis, keepdims=True)
                pos += r
        return gx,

    def backward(self, indexes, grad_outputs):
        return Repeat(self.repeats, self.axis).apply(grad_outputs)


def repeat(x, repeats, axis=None):
    """Construct an array by repeating a given array.

    Args:
        x (:class:`~chainer.Variable` or :ref:`ndarray`):
            Input variable.
        repeats (:class:`int` or :class:`tuple` of :class:`int` s):
            The number of times which each element of ``x`` is repeated.
        axis (:class:`int` or ``None``):
            The axis along which to repeat values. If ``None`` (default),
            ``x`` is flattened and a 1-D result is returned.

    Returns:
        ~chainer.Variable: The repeated output Variable.

    Raises:
        TypeError: If ``repeats`` is not an int or a tuple of ints, or if
            ``axis`` is neither ``None`` nor an int.
        ValueError: If any element of ``repeats`` is negative.

    .. admonition:: Example

        >>> x = np.array([0, 1, 2])
        >>> x.shape
        (3,)
        >>> y = F.repeat(x, 2)
        >>> y.shape
        (6,)
        >>> y.array
        array([0, 0, 1, 1, 2, 2])
        >>> x = np.array([[1,2], [3,4]])
        >>> x.shape
        (2, 2)
        >>> y = F.repeat(x, 3, axis=1)
        >>> y.shape
        (2, 6)
        >>> y.array
        array([[1, 1, 1, 2, 2, 2],
               [3, 3, 3, 4, 4, 4]])
        >>> y = F.repeat(x, (1, 2), axis=0)
        >>> y.shape
        (3, 2)
        >>> y.array
        array([[1, 2],
               [3, 4],
               [3, 4]])

    """
    return Repeat(repeats, axis).apply((x,))[0]