1import numpy
2import six
3
4import chainer
5from chainer import backend
6from chainer import function_node
7from chainer.utils import type_check
8
9
class Vstack(function_node.FunctionNode):

    """Concatenate multiple tensors vertically (row wise).

    The forward pass delegates to the ``vstack`` routine of the array module
    that matches the inputs; the backward pass splits the output gradient
    along the first axis and hands one piece back to each input.
    """

    def check_type_forward(self, in_types):
        """Validate that the inputs can be stacked together.

        At least one input is required. Every further input must share the
        dtype and ``ndim`` of the first input. When ``ndim <= 1`` the full
        shapes must be equal; otherwise only the axes after the first one
        must agree.
        """
        type_check.expect(in_types.size() > 0)

        first = in_types[0]
        ndim = type_check.eval(first.ndim)
        n_inputs = type_check.eval(in_types.size())
        for i in six.moves.range(1, n_inputs):
            other = in_types[i]
            type_check.expect(
                first.dtype == other.dtype,
                first.ndim == other.ndim,
            )
            if ndim <= 1:
                # Inputs without a second axis must match exactly.
                type_check.expect(first.shape == other.shape)
            else:
                # The first axis is the one being concatenated, so it is the
                # only axis allowed to differ.
                for d in six.moves.range(1, ndim):
                    type_check.expect(first.shape[d] == other.shape[d])

    def forward(self, xs):
        """Stack the input arrays and return the result as a 1-tuple."""
        array_module = backend.get_array_module(*xs)
        stacked = array_module.vstack(xs)
        return stacked,

    def backward(self, indexes, grad_outputs):
        """Distribute the output gradient back onto the inputs.

        Args:
            indexes: Not used by this implementation; a gradient is returned
                for every input.
            grad_outputs: Sequence holding exactly one gradient, the one
                with respect to the stacked output.

        Returns:
            A sequence with one gradient per input, in input order.
        """
        gy, = grad_outputs
        input_nodes = self.inputs
        n_inputs = len(input_nodes)
        first_shape = input_nodes[0].shape
        # Inputs with ndim <= 1 each end up as a single row of the output,
        # so their gradients have to be reshaped back to the input shape.
        is_flat = len(first_shape) <= 1

        if n_inputs == 1:
            gx = gy.reshape(first_shape) if is_flat else gy
            return gx,

        if is_flat:
            rows = chainer.functions.split_axis(gy, n_inputs, 0)
            return [row.reshape(first_shape) for row in rows]

        # Split points along the first axis: the running total of the row
        # counts of all inputs but the last.
        row_counts = [node.shape[0] for node in input_nodes[:-1]]
        split_points = numpy.array(row_counts).cumsum()
        return chainer.functions.split_axis(gy, split_points, 0)
47
48
def vstack(xs):
    """Concatenate variables vertically (row wise).

    Args:
        xs (list of :class:`~chainer.Variable` or :ref:`ndarray`):
            Input variables to be concatenated. All of them must have the
            same ``ndim``. If the variables have a second axis (i.e.
            :math:`ndim \\geq 2`), their shapes must agree along every axis
            except the first. If the variables do not have a second axis
            (i.e. :math:`ndim < 2`), their shapes must be identical.

    Returns:
        ~chainer.Variable:
            Output variable. If the input variables have a second axis
            (i.e. :math:`ndim \\geq 2`), the output has the same shape as the
            inputs along every axis except the first, and the length of its
            first axis is the sum of the lengths of the inputs' first axes.
            If the input variables do not have a second axis (i.e.
            :math:`ndim < 2`), the shape of the output is ``(len(xs), N)``
            (``N`` is the size of each input variable).

    .. admonition:: Example

        >>> x1 = np.array((1, 2, 3))
        >>> x1.shape
        (3,)
        >>> x2 = np.array((2, 3, 4))
        >>> x2.shape
        (3,)
        >>> y = F.vstack((x1, x2))
        >>> y.shape
        (2, 3)
        >>> y.array
        array([[1, 2, 3],
               [2, 3, 4]])
        >>> x1 = np.arange(0, 12).reshape(3, 4)
        >>> x1.shape
        (3, 4)
        >>> x1
        array([[ 0,  1,  2,  3],
               [ 4,  5,  6,  7],
               [ 8,  9, 10, 11]])
        >>> x2 = np.arange(12, 20).reshape(2, 4)
        >>> x2.shape
        (2, 4)
        >>> x2
        array([[12, 13, 14, 15],
               [16, 17, 18, 19]])
        >>> y = F.vstack([x1, x2])
        >>> y.shape
        (5, 4)
        >>> y.array
        array([[ 0,  1,  2,  3],
               [ 4,  5,  6,  7],
               [ 8,  9, 10, 11],
               [12, 13, 14, 15],
               [16, 17, 18, 19]])

    """

    # The function node produces a single output; unwrap it from the tuple
    # returned by ``apply``.
    outputs = Vstack().apply(xs)
    return outputs[0]
111