import numpy
import six

from chainer import backend
from chainer import function_node
from chainer import utils
from chainer.utils import collections_abc
from chainer.utils import type_check


def _normalize_axes(axes, ndim):
    """Returns ``axes`` as a list with in-range negative indices made positive.

    Args:
        axes: A single axis index or a sequence of axis indices.
        ndim (int): Number of dimensions of the array the axes refer to.

    Returns:
        list: The axes, with each ``axis`` in ``[-ndim, 0)`` replaced by
        ``axis + ndim``. Out-of-range values are returned unchanged so that
        the shape checks in :func:`_tensordot` still report them.
    """
    if numpy.isscalar(axes):
        axes = axes,
    return [axis + ndim if -ndim <= axis < 0 else axis for axis in axes]


def _tensordot(a, b, a_axes, b_axes, c_axes=None):
    """Computes a tensor dot product with an optional output axis permutation.

    Each ``*_axes`` argument is a pair ``[row_axes, col_axes]``. ``a`` is
    summed over ``a_axes[1]`` against ``b`` over ``b_axes[0]``; the remaining
    axes ``a_axes[0]`` and ``b_axes[1]`` form the output.

    Args:
        a: First input array.
        b: Second input array.
        a_axes: ``[free axes of a, contracted axes of a]``.
        b_axes: ``[contracted axes of b, free axes of b]``.
        c_axes: Optional ``[output positions of a's free axes, output
            positions of b's free axes]``. If ``None``, the result keeps the
            layout produced by ``xp.tensordot``.

    Returns:
        The contracted array.

    Raises:
        ValueError: If the axis counts disagree, an input has fewer
            dimensions than contracted axes, or contracted dimensions have
            different sizes.
    """
    a_col_ndim = len(a_axes[1])
    b_row_ndim = len(b_axes[0])
    if a_col_ndim != b_row_ndim:
        raise ValueError('axes count mismatch')
    if a.ndim < a_col_ndim or b.ndim < b_row_ndim:
        raise ValueError('dimension of input tensors must be '
                         'greater equal to dot-axes count ({})'
                         .format(a_col_ndim))
    for a_axis, b_axis in zip(a_axes[1], b_axes[0]):
        if a.shape[a_axis] != b.shape[b_axis]:
            raise ValueError('shape mismatch')

    xp = backend.get_array_module(a)
    y = xp.tensordot(a, b, axes=(tuple(a_axes[1]), tuple(b_axes[0])))

    if c_axes is not None:
        a_row_ndim = len(a_axes[0])
        b_col_ndim = len(b_axes[1])
        c_row_ndim = len(c_axes[0])
        c_col_ndim = len(c_axes[1])
        if a_row_ndim != c_row_ndim:
            raise ValueError('axes count mismatch')
        if b_col_ndim != c_col_ndim:
            raise ValueError('axes count mismatch')

        # trans[k] is the axis of ``y`` that must end up at output axis k.
        trans = [None] * y.ndim
        # Rank of each free axis of ``a`` among a's free axes; that rank is
        # its position in the xp.tensordot result.
        table_a = [1 if i in a_axes[0] else 0
                   for i in six.moves.range(a.ndim)]
        table_a = numpy.cumsum(table_a) - 1
        for c_axis, a_axis in zip(c_axes[0], a_axes[0]):
            trans[c_axis] = table_a[a_axis]
        # b's free axes follow a's free axes in the xp.tensordot result.
        table_b = [1 if i in b_axes[1] else 0
                   for i in six.moves.range(b.ndim)]
        table_b = numpy.cumsum(table_b) - 1
        for c_axis, b_axis in zip(c_axes[1], b_axes[1]):
            trans[c_axis] = table_b[b_axis] + len(a_axes[0])
        # Skip the transpose when the permutation is the identity.
        if any(i != axis for i, axis in enumerate(trans)):
            y = xp.transpose(y, trans)

    return y


class TensorDot(function_node.FunctionNode):
    """Function node computing a tensor dot product and its gradients.

    ``a_axes``, ``b_axes`` and ``c_axes`` are pairs ``[row_axes, col_axes]``
    (see :func:`_tensordot`). When ``a_axes`` or ``b_axes`` is ``None`` they
    are derived from ``axes`` in :meth:`forward`; :meth:`backward` passes
    all of them explicitly when it builds the gradient nodes.
    """

    def __init__(self, axes=2, a_axes=None, b_axes=None, c_axes=None,
                 dtype=None):
        self.axes = axes
        self.a_axes = a_axes
        self.b_axes = b_axes
        self.c_axes = c_axes
        # Forwarded to utils.force_array for the output array.
        self.dtype = dtype

        if isinstance(axes, collections_abc.Sequence):
            if len(axes) != 2:
                raise ValueError('axes must be a pair of sequence of integers '
                                 'when it is a list or tuple.')
        elif isinstance(axes, six.integer_types):
            pass
        else:
            raise TypeError('axes must be a pair of sequence of integers or '
                            'an integer')

    def check_type_forward(self, in_types):
        type_check._argname(in_types, ('a', 'b'))
        a_type, b_type = in_types

        type_check.expect(
            a_type.dtype.kind == 'f',
            b_type.dtype.kind == 'f',
        )

    def forward(self, inputs):
        """Computes the product, deriving the axis pairs on first use.

        The derived ``a_axes``, ``b_axes`` and ``c_axes`` are stored on the
        node because :meth:`backward` needs them.
        """
        self.retain_inputs((0, 1))
        a, b = inputs

        if self.a_axes is None or self.b_axes is None:
            a_axes = [[], []]  # 0:row axes, 1:col axes
            b_axes = [[], []]  # 0:row axes, 1:col axes
            axes = self.axes
            if isinstance(axes, collections_abc.Sequence):
                a_contract, b_contract = axes
                # Normalize negative indices. Otherwise the membership tests
                # below would treat e.g. -1 as different from ndim - 1 and
                # mark a contracted axis as free, which breaks backward.
                a_axes[1] = _normalize_axes(a_contract, a.ndim)
                b_axes[0] = _normalize_axes(b_contract, b.ndim)
            else:
                # An integer contracts the last ``axes`` axes of ``a`` with
                # the first ``axes`` axes of ``b``.
                a_axes[1] = six.moves.range(a.ndim - axes, a.ndim)
                b_axes[0] = six.moves.range(axes)
            a_range = six.moves.range(a.ndim)
            a_axes[0] = [i for i in a_range if i not in a_axes[1]]
            b_range = six.moves.range(b.ndim)
            b_axes[1] = [i for i in b_range if i not in b_axes[0]]
            self.a_axes = a_axes
            self.b_axes = b_axes

        c = _tensordot(a, b, self.a_axes, self.b_axes, self.c_axes)

        if self.c_axes is None:
            # Default output layout: a's free axes, then b's free axes.
            c_axes = [[], []]  # 0:row axes, 1:col axes
            c_row_ndim = len(self.a_axes[0])
            c_col_ndim = len(self.b_axes[1])
            c_axes[0] = six.moves.range(c_row_ndim)
            c_axes[1] = six.moves.range(c_row_ndim, c_row_ndim + c_col_ndim)
            self.c_axes = c_axes

        return utils.force_array(c, self.dtype),

    def backward(self, indexes, grad_outputs):
        """Returns the gradients as tensor dot products with ``gc``."""
        a, b = self.get_retained_inputs()
        gc, = grad_outputs

        ga = None
        if 0 in indexes:
            # Contract gc's b-derived axes with b's free axes, then place
            # the result axes back at a's row and col axis positions.
            ga, = TensorDot(a_axes=self.c_axes,
                            b_axes=[self.b_axes[1], self.b_axes[0]],
                            c_axes=self.a_axes,
                            dtype=a.dtype).apply((gc, b))

        gb = None
        if 1 in indexes:
            # Contract a's free axes with gc's a-derived axes, then place
            # the result axes back at b's row and col axis positions.
            gb, = TensorDot(a_axes=[self.a_axes[1], self.a_axes[0]],
                            b_axes=self.c_axes,
                            c_axes=self.b_axes,
                            dtype=b.dtype).apply((a, gc))

        return ga, gb


def tensordot(a, b, axes=2):
    """Returns the tensor dot product of two arrays along specified axes.

    This is equivalent to computing the dot product along the specified axes,
    which are treated as a single axis by reshaping.

    Args:
        a (:class:`~chainer.Variable` or :ref:`ndarray`): The first argument.
        b (:class:`~chainer.Variable` or :ref:`ndarray`): The second argument.
        axes:
            - If it is an integer, then the last ``axes`` axes of ``a`` and
              the first ``axes`` axes of ``b`` are used.
            - If it is a pair of sequences of integers, then these two
              sequences specify the list of axes for ``a`` and ``b``. The
              corresponding axes are paired for sum-product. Either entry
              may also be a single integer, and negative axes count from
              the end, as in :func:`numpy.tensordot`.

    Returns:
        ~chainer.Variable: The tensor dot product of ``a`` and ``b`` along the
        axes specified by ``axes``.

    .. admonition:: Example

        >>> a = np.random.rand(5, 3, 2)
        >>> b = np.random.rand(3, 2, 4)
        >>> c = F.tensordot(a, b, axes=2)
        >>> c.shape
        (5, 4)

    .. seealso:: :func:`numpy.tensordot`

    """
    return TensorDot(axes=axes).apply((a, b))[0]