import numpy
import six

import chainer
from chainer import backend
from chainer.backends import intel64
from chainer import function_node
from chainer.utils import type_check
import chainerx


class Concat(function_node.FunctionNode):

    """Concatenate multiple tensors towards specified axis."""

    # concat along the channel dimension by default
    def __init__(self, axis=1):
        """Create the function node.

        Args:
            axis (int): Axis along which inputs are joined. Negative values
                count from the last axis. A NumPy integer scalar is also
                accepted and is converted to a Python ``int``.

        Raises:
            TypeError: If ``axis`` is not an integer.

        """
        # NumPy integer scalars (e.g. numpy.int64) are not instances of
        # ``six.integer_types`` on Python 3, so normalize them first. Plain
        # Python ints are left untouched to keep the previous behavior.
        if isinstance(axis, numpy.integer):
            axis = int(axis)
        if not isinstance(axis, six.integer_types):
            raise TypeError('axis must be int')

        self.axis = axis

    def check_type_forward(self, in_types):
        """Validate the inputs before forward propagation.

        Requires at least one input and ``-ndim <= axis < ndim`` for the
        first input. Every other input must match the first one in dtype
        and ndim, and in every dimension except the concatenation axis.
        """
        type_check.expect(in_types.size() > 0)
        type_check.expect(in_types[0].ndim >
                          type_check.make_variable(self.axis, 'axis'))

        type_check.expect(
            -in_types[0].ndim <= self.axis,
            self.axis < in_types[0].ndim
        )
        ndim = type_check.eval(in_types[0].ndim)
        # Normalize a negative axis so it can be compared with ``d`` below.
        axis = self.axis % ndim
        for i in six.moves.range(1, type_check.eval(in_types.size())):
            type_check.expect(
                in_types[0].dtype == in_types[i].dtype,
                in_types[0].ndim == in_types[i].ndim,
            )
            for d in six.moves.range(0, ndim):
                if d == axis:
                    continue
                type_check.expect(in_types[0].shape[d] == in_types[i].shape[d])

    def forward(self, xs):
        """Concatenate ``xs`` along ``self.axis``.

        The iDeep path is taken when ``intel64.should_use_ideep('>=auto')``
        holds and ``intel64.inputs_all_ready(xs, (4,))`` accepts the inputs.
        NOTE(review): ``(4,)`` presumably lists the ndims iDeep supports
        here; confirm against ``intel64.inputs_all_ready``. Otherwise the
        array module of the inputs performs ``concatenate``.

        Returns:
            tuple: A one-element tuple holding the concatenated array.

        """
        if (intel64.should_use_ideep('>=auto')
                and intel64.inputs_all_ready(xs, (4,))):
            # iDeep implementation
            return self._forward_ideep(xs)

        # Generic implementation
        xp = backend.get_array_module(*xs)
        return xp.concatenate(xs, self.axis),

    def forward_chainerx(self, xs):
        """Concatenate ChainerX arrays with ``chainerx.concatenate``."""
        return chainerx.concatenate(xs, self.axis),

    def _forward_ideep(self, xs):
        """Concatenate ``xs`` with iDeep's ``concat.Forward`` primitive."""
        xs_mdarray = intel64.ideep.mdarrayVector()
        for x in xs:
            xs_mdarray.push_back(intel64.ideep.array(x))
        ndim = xs[0].ndim
        # The axis is passed to iDeep in its non-negative form.
        axis = self.axis % ndim
        return intel64.ideep.concat.Forward(xs_mdarray, axis),

    def backward(self, indexes, grad_outputs):
        """Split the output gradient back into one piece per input.

        With a single input, the gradient is returned unchanged. Otherwise
        the split points are the cumulative sizes, along ``self.axis``, of
        every input except the last. The final piece takes the remainder.
        """
        if len(self.inputs) == 1:
            return grad_outputs

        sizes = numpy.array(
            [v.shape[self.axis] for v in self.inputs[:-1]]
        ).cumsum()
        gx, = grad_outputs
        return chainer.functions.split_axis(gx, sizes, self.axis)


def concat(xs, axis=1):
    """Concatenates given variables along an axis.

    Args:
        xs (tuple of :class:`~chainer.Variable` or :ref:`ndarray`):
            Input variables to be concatenated. The variables must have the \
            same shape, except in the dimension corresponding to axis.
        axis (int): The axis along which the arrays will be joined. Default \
            is 1. A NumPy integer scalar is also accepted.

    Returns:
        ~chainer.Variable: The concatenated variable.

    .. admonition:: Example

        >>> x = np.arange(0, 12).reshape(3, 4)
        >>> x
        array([[ 0,  1,  2,  3],
               [ 4,  5,  6,  7],
               [ 8,  9, 10, 11]])
        >>> y = np.arange(0, 3).reshape(3, 1)
        >>> y
        array([[0],
               [1],
               [2]])
        >>> z = F.concat((x, y), axis=1)
        >>> z.array
        array([[ 0,  1,  2,  3,  0],
               [ 4,  5,  6,  7,  1],
               [ 8,  9, 10, 11,  2]])

    """
    y, = Concat(axis).apply(xs)
    return y