import numbers

import six

import chainer
from chainer import backend
from chainer import function_node
from chainer.utils import type_check


def _is_integer(value):
    """Return whether ``value`` is an integer scalar usable in ``reps``.

    ``numbers.Integral`` additionally covers NumPy integer scalars (e.g. the
    ``numpy.int64`` returned by ``numpy.prod``), which ``six.integer_types``
    alone rejects. ``bool`` is accepted, as before, since it subclasses
    ``int``.
    """
    return isinstance(value, six.integer_types + (numbers.Integral,))


class Tile(function_node.FunctionNode):

    """Tiling of an array.

    Args:
        reps (int, or tuple or list of ints): Number of repetitions along
            each axis. Every element must be zero or larger.

    """

    def __init__(self, reps):
        if _is_integer(reps):
            reps = (reps,)
        elif not (isinstance(reps, (tuple, list))
                  and all(_is_integer(x) for x in reps)):
            msg = 'reps must be int or tuple of ints.\n' \
                  'Actual: {0}'.format(type(reps))
            raise TypeError(msg)

        # Normalize to a tuple of built-in ints: ``backward`` concatenates
        # ``reps`` with other tuples, which would fail for list input.
        self.reps = tuple(int(x) for x in reps)

        if not all(x >= 0 for x in self.reps):
            raise ValueError('All elements in reps must be zero or larger')

    def check_type_forward(self, in_types):
        type_check.expect(in_types.size() == 1)

    def forward(self, inputs):
        # Keep the input shape; ``backward`` folds the gradient back to it.
        self._in_shape = inputs[0].shape
        xp = backend.get_array_module(*inputs)
        return xp.tile(inputs[0], self.reps),

    def backward(self, indexes, grad_outputs):
        reps = self.reps
        shape = tuple(self._in_shape)
        ndim = len(shape)

        # Ensure input and reps have the same length by padding the shorter
        # one with leading 1's (the same alignment ``tile`` documents).
        if ndim > len(reps):
            reps = (1,) * (ndim - len(reps)) + reps
        elif ndim < len(reps):
            shape = (1,) * (len(reps) - ndim) + shape

        gy, = grad_outputs

        # Each output axis ``i`` has length ``reps[i] * shape[i]``. Splitting
        # it into ``(reps[i], shape[i])`` separates the copy index (even axes)
        # from the base index (odd axes).
        new_shape = tuple(d for pair in zip(reps, shape) for d in pair)
        gy = gy.reshape(new_shape)

        # Every copy contributes to the same input element, so sum over the
        # copy axes.
        reps_axis = tuple(range(0, 2 * len(shape), 2))
        gy = chainer.functions.sum(gy, axis=reps_axis)

        if ndim < len(reps):
            # ``shape`` was padded with leading 1's; restore the input shape.
            return gy.reshape(self._in_shape),
        return gy,


def tile(x, reps):
    """Construct an array by tiling a given array.

    Args:
        x (:class:`~chainer.Variable` or :ref:`ndarray`):
            Input variable. Let the length of ``reps`` be ``d``. If
            ``x.ndim < d``, ``x`` is treated as a ``d``-dimensional array by
            prepending new axes. For example, when the shape of ``x`` is
            ``(2,)`` and it is tiled with 2-dim repetitions, ``x`` is treated
            as having the shape ``(1, 2)``. If ``x.ndim > d``, ``reps`` is
            treated as ``x.ndim``-dimensional by prepending 1's. For example,
            when the shape of ``x`` is ``(2, 3, 2, 3)``, the 2-dim ``reps`` of
            ``(2, 2)`` is treated as ``(1, 1, 2, 2)``.
        reps (:class:`int` or :class:`tuple` of :class:`int` s):
            The number of times ``x`` is replicated along each axis. A
            :class:`list` of integers and NumPy integer scalars are also
            accepted. All elements must be zero or larger.

    Returns:
        ~chainer.Variable: The tiled output Variable.
        Let the length of ``reps`` be ``d``; the output has
        ``max(d, x.ndim)`` dimensions.

    Raises:
        TypeError: If ``reps`` is not an integer or a tuple/list of integers.
        ValueError: If any element of ``reps`` is negative.

    .. admonition:: Example

        >>> x = np.array([0, 1, 2])
        >>> x.shape
        (3,)
        >>> y = F.tile(x, 2)
        >>> y.shape
        (6,)
        >>> y.array
        array([0, 1, 2, 0, 1, 2])
        >>> y = F.tile(x, (2, 2))
        >>> y.shape
        (2, 6)
        >>> y.array
        array([[0, 1, 2, 0, 1, 2],
               [0, 1, 2, 0, 1, 2]])
        >>> y = F.tile(x, (2, 1, 2))
        >>> y.shape
        (2, 1, 6)
        >>> y.array
        array([[[0, 1, 2, 0, 1, 2]],
        <BLANKLINE>
               [[0, 1, 2, 0, 1, 2]]])

        >>> x = np.array([[1, 2], [3, 4]])
        >>> x.shape
        (2, 2)
        >>> y = F.tile(x, 2)
        >>> y.shape
        (2, 4)
        >>> y.array
        array([[1, 2, 1, 2],
               [3, 4, 3, 4]])
        >>> y = F.tile(x, (2, 2))
        >>> y.shape
        (4, 4)
        >>> y.array
        array([[1, 2, 1, 2],
               [3, 4, 3, 4],
               [1, 2, 1, 2],
               [3, 4, 3, 4]])
        >>> y = F.tile(x, (2, 1, 2))
        >>> y.shape
        (2, 2, 4)
        >>> y.array
        array([[[1, 2, 1, 2],
                [3, 4, 3, 4]],
        <BLANKLINE>
               [[1, 2, 1, 2],
                [3, 4, 3, 4]]])

    """
    return Tile(reps).apply((x,))[0]