import numpy
import six

import chainer
from chainer import backend
from chainer import function_node
from chainer.utils import type_check


class Vstack(function_node.FunctionNode):

    """Concatenate multiple tensors vertically (row wise)."""

    def check_type_forward(self, in_types):
        # At least one input is required, and every input must share the
        # dtype and ndim of the first one.
        type_check.expect(in_types.size() > 0)

        ndim = type_check.eval(in_types[0].ndim)
        for i in six.moves.range(1, type_check.eval(in_types.size())):
            type_check.expect(
                in_types[0].dtype == in_types[i].dtype,
                in_types[0].ndim == in_types[i].ndim,
            )
            if ndim <= 1:
                # 0-d / 1-d inputs each become a single row of the output,
                # so their shapes must match exactly.
                type_check.expect(in_types[0].shape == in_types[i].shape)
                continue
            # For ndim >= 2 only the first axis may differ between inputs.
            for d in six.moves.range(1, ndim):
                type_check.expect(in_types[0].shape[d] == in_types[i].shape[d])

    def forward(self, xs):
        xp = backend.get_array_module(*xs)
        return xp.vstack(xs),

    def backward(self, indexes, grad_outputs):
        gy, = grad_outputs
        ndim = len(self.inputs[0].shape)
        if len(self.inputs) == 1:
            if ndim <= 1:
                # vstack promoted the 0-d / 1-d input to 2-d; undo that.
                return gy.reshape(self.inputs[0].shape),
            return gy,

        if ndim <= 1:
            # Each input contributed exactly one row: split into equal
            # pieces and restore the original (0-d or 1-d) input shape.
            gxs = chainer.functions.split_axis(gy, len(self.inputs), 0)
            return [gx.reshape(self.inputs[0].shape) for gx in gxs]

        # Split the gradient at the cumulative first-axis offsets of the
        # inputs (the last input's size is implied by the remainder).
        sizes = numpy.array([x.shape[0] for x in self.inputs[:-1]]).cumsum()
        return chainer.functions.split_axis(gy, sizes, 0)


def vstack(xs):
    """Concatenate variables vertically (row wise).

    Args:
        xs (list of :class:`~chainer.Variable` or :ref:`ndarray`):
            Input variables to be concatenated. The variables must have the
            same ``ndim``. When the variables have the second axis (i.e.
            :math:`ndim \\geq 2`), the variables must have the same shape
            along all but the first axis. When the variables do not have the
            second axis (i.e. :math:`ndim < 2`), the variables must have the
            same shape.

    Returns:
        ~chainer.Variable:
            Output variable. When the input variables have the second axis
            (i.e. :math:`ndim \\geq 2`), the shapes of inputs and output are
            the same along all but the first axis. The length of first axis
            is the sum of the lengths of inputs' first axis.
            When the variables do not have the second axis (i.e.
            :math:`ndim < 2`), the shape of output is ``(M, N)``, where
            ``M`` is the number of input variables and ``N`` is the size of
            each input variable.

    .. admonition:: Example

        >>> x1 = np.array((1, 2, 3))
        >>> x1.shape
        (3,)
        >>> x2 = np.array((2, 3, 4))
        >>> x2.shape
        (3,)
        >>> y = F.vstack((x1, x2))
        >>> y.shape
        (2, 3)
        >>> y.array
        array([[1, 2, 3],
               [2, 3, 4]])
        >>> x1 = np.arange(0, 12).reshape(3, 4)
        >>> x1.shape
        (3, 4)
        >>> x1
        array([[ 0,  1,  2,  3],
               [ 4,  5,  6,  7],
               [ 8,  9, 10, 11]])
        >>> x2 = np.arange(12, 20).reshape(2, 4)
        >>> x2.shape
        (2, 4)
        >>> x2
        array([[12, 13, 14, 15],
               [16, 17, 18, 19]])
        >>> y = F.vstack([x1, x2])
        >>> y.shape
        (5, 4)
        >>> y.array
        array([[ 0,  1,  2,  3],
               [ 4,  5,  6,  7],
               [ 8,  9, 10, 11],
               [12, 13, 14, 15],
               [16, 17, 18, 19]])

    """

    return Vstack().apply(xs)[0]