1import numpy
2import six
3
4import chainer
5from chainer import backend
6from chainer.backends import intel64
7from chainer import function_node
8from chainer.utils import type_check
9import chainerx
10
11
class Concat(function_node.FunctionNode):

    """Concatenate multiple tensors towards specified axis."""

    # The channel dimension (axis=1) is the conventional default.
    def __init__(self, axis=1):
        if not isinstance(axis, six.integer_types):
            raise TypeError('axis must be int')

        self.axis = axis

    def check_type_forward(self, in_types):
        # At least one input is required.
        type_check.expect(in_types.size() > 0)
        type_check.expect(in_types[0].ndim >
                          type_check.make_variable(self.axis, 'axis'))

        # The axis must address a valid dimension (negative indexing allowed).
        type_check.expect(
            -in_types[0].ndim <= self.axis,
            self.axis < in_types[0].ndim
        )
        ndim = type_check.eval(in_types[0].ndim)
        axis = self.axis % ndim
        n_inputs = type_check.eval(in_types.size())
        for idx in six.moves.range(1, n_inputs):
            # Every input must agree with the first one in dtype and rank.
            type_check.expect(
                in_types[0].dtype == in_types[idx].dtype,
                in_types[0].ndim == in_types[idx].ndim,
            )
            # All dimensions except the concatenation axis must match.
            for dim in six.moves.range(ndim):
                if dim == axis:
                    continue
                type_check.expect(
                    in_types[0].shape[dim] == in_types[idx].shape[dim])

    def forward(self, xs):
        # Prefer the iDeep backend when it is enabled and all inputs
        # are ready 4-D arrays.
        use_ideep = (intel64.should_use_ideep('>=auto')
                     and intel64.inputs_all_ready(xs, (4,)))
        if use_ideep:
            return self._forward_ideep(xs)

        # Generic NumPy/CuPy implementation.
        xp = backend.get_array_module(*xs)
        return xp.concatenate(xs, self.axis),

    def forward_chainerx(self, xs):
        return chainerx.concatenate(xs, self.axis),

    def _forward_ideep(self, xs):
        # Pack the inputs into an iDeep array vector.
        mvec = intel64.ideep.mdarrayVector()
        for arr in xs:
            mvec.push_back(intel64.ideep.array(arr))
        # iDeep requires a non-negative axis index.
        axis = self.axis % xs[0].ndim
        return intel64.ideep.concat.Forward(mvec, axis),

    def backward(self, indexes, grad_outputs):
        # A single input means the gradient passes through unchanged.
        if len(self.inputs) == 1:
            return grad_outputs

        # Split boundaries are the running sums of input extents
        # along the concatenation axis (last input's extent implied).
        boundaries = numpy.cumsum(
            [v.shape[self.axis] for v in self.inputs[:-1]])
        gy, = grad_outputs
        return chainer.functions.split_axis(gy, boundaries, self.axis)
74
75
def concat(xs, axis=1):
    """Concatenates given variables along an axis.

    Args:
        xs (tuple of :class:`~chainer.Variable` or :ref:`ndarray`):
            Input variables to be concatenated. The variables must have the \
            same shape, except in the dimension corresponding to axis.
        axis (int): The axis along which the arrays will be joined. Default \
            is 1.

    Returns:
        ~chainer.Variable: The concatenated variable.

    .. admonition:: Example

        >>> x = np.arange(0, 12).reshape(3, 4)
        >>> x
        array([[ 0,  1,  2,  3],
               [ 4,  5,  6,  7],
               [ 8,  9, 10, 11]])
        >>> y = np.arange(0, 3).reshape(3, 1)
        >>> y
        array([[0],
               [1],
               [2]])
        >>> z = F.concat((x, y), axis=1)
        >>> z.array
        array([[ 0,  1,  2,  3,  0],
               [ 4,  5,  6,  7,  1],
               [ 8,  9, 10, 11,  2]])

    """
    # apply() always returns a tuple; concat has exactly one output.
    outputs = Concat(axis).apply(xs)
    return outputs[0]
110