1import numpy
2import six
3
4from chainer import backend
5from chainer import function_node
6from chainer import utils
7from chainer.utils import collections_abc
8from chainer.utils import type_check
9
10
def _tensordot(a, b, a_axes, b_axes, c_axes=None):
    """Compute a tensor dot product with explicit axis bookkeeping.

    ``a_axes``, ``b_axes`` and (optionally) ``c_axes`` are each a pair
    ``[row_axes, col_axes]``. The column axes of ``a`` (``a_axes[1]``) are
    contracted element-wise against the row axes of ``b`` (``b_axes[0]``)
    with the array module's ``tensordot``.

    When ``c_axes`` is given, the result is transposed so that axis
    ``a_axes[0][i]`` of ``a`` ends up at output position ``c_axes[0][i]``
    and axis ``b_axes[1][j]`` of ``b`` ends up at ``c_axes[1][j]``. When it
    is ``None``, the raw ``tensordot`` layout is returned unchanged.

    Args:
        a: First input array; the array module is chosen from it via
            ``backend.get_array_module``.
        b: Second input array.
        a_axes: ``[free_axes_of_a, contracted_axes_of_a]``.
        b_axes: ``[contracted_axes_of_b, free_axes_of_b]``.
        c_axes: Optional ``[positions_for_a_free, positions_for_b_free]``.

    Returns:
        The contracted (and possibly transposed) array.

    Raises:
        ValueError: If axis counts disagree, if an input has fewer
            dimensions than the number of contracted axes, or if paired
            contracted axes have different lengths.
    """
    a_contract = a_axes[1]
    b_contract = b_axes[0]
    n_contract = len(a_contract)
    if n_contract != len(b_contract):
        raise ValueError('axes count mismatch')
    if a.ndim < n_contract or b.ndim < n_contract:
        raise ValueError('dimension of input tensors must be '
                         'greater equal to dot-axes count ({})'
                         .format(n_contract))
    if any(a.shape[i] != b.shape[j] for i, j in zip(a_contract, b_contract)):
        raise ValueError('shape mismatch')

    xp = backend.get_array_module(a)
    y = xp.tensordot(a, b, axes=(tuple(a_contract), tuple(b_contract)))
    if c_axes is None:
        return y

    a_free = a_axes[0]
    b_free = b_axes[1]
    n_a_free = len(a_free)
    n_b_free = len(b_free)
    n_c_rows = len(c_axes[0])
    n_c_cols = len(c_axes[1])
    if n_a_free != n_c_rows or n_b_free != n_c_cols:
        raise ValueError('axes count mismatch')

    def rank_among(free_axes, ndim):
        # Entry k is the index of axis k among ``free_axes`` taken in
        # ascending order, i.e. where the contraction placed that axis.
        mask = [1 if k in free_axes else 0 for k in six.moves.range(ndim)]
        return numpy.cumsum(mask) - 1

    # perm[dst] = axis of ``y`` that must appear at output position dst.
    perm = [None] * y.ndim
    a_rank = rank_among(a_free, a.ndim)
    for dst, src in zip(c_axes[0], a_free):
        perm[dst] = a_rank[src]
    b_rank = rank_among(b_free, b.ndim)
    for dst, src in zip(c_axes[1], b_free):
        # b's free axes follow all of a's free axes in ``y``.
        perm[dst] = b_rank[src] + n_a_free

    # Only pay for a transpose when the permutation is not the identity.
    if any(pos != ax for pos, ax in enumerate(perm)):
        y = xp.transpose(y, perm)
    return y
52
53
class TensorDot(function_node.FunctionNode):

    """Tensor dot product of two arrays along specified axes.

    Axis bookkeeping uses pairs ``[row_axes, col_axes]``:

    * ``a_axes``: ``a_axes[1]`` are the axes of ``a`` summed over and
      ``a_axes[0]`` are the remaining (free) axes of ``a``.
    * ``b_axes``: ``b_axes[0]`` are the axes of ``b`` summed over (paired
      element-wise with ``a_axes[1]``) and ``b_axes[1]`` are the free axes
      of ``b``.
    * ``c_axes``: output positions of the free axes of ``a`` (``c_axes[0]``)
      and of ``b`` (``c_axes[1]``).

    If ``a_axes`` or ``b_axes`` is ``None``, both are derived from ``axes``
    during :meth:`forward` and stored on the instance. If ``c_axes`` is
    ``None``, it is set during :meth:`forward` to put the free axes of ``a``
    first, followed by those of ``b``.

    Args:
        axes (int or pair of sequences of int): If an integer ``n``, the
            last ``n`` axes of ``a`` are contracted with the first ``n``
            axes of ``b``. If a pair, it lists the contracted axes of ``a``
            and ``b`` respectively; a bare integer in the pair is treated as
            a one-element sequence, and negative indices count from the end.
        a_axes, b_axes, c_axes: Explicit axis pairs as described above
            (``backward`` uses these to build the gradient nodes).
        dtype: Passed to ``utils.force_array`` for the forward output.

    Raises:
        ValueError: If ``axes`` is a sequence whose length is not 2.
        TypeError: If ``axes`` is neither a sequence nor an integer.
    """

    def __init__(self, axes=2, a_axes=None, b_axes=None, c_axes=None,
                 dtype=None):
        self.axes = axes
        self.a_axes = a_axes
        self.b_axes = b_axes
        self.c_axes = c_axes
        self.dtype = dtype

        if isinstance(axes, collections_abc.Sequence):
            if len(axes) != 2:
                raise ValueError('axes must be a pair of sequence of integers '
                                 'when it is a list or tuple.')
        elif isinstance(axes, six.integer_types):
            pass
        else:
            raise TypeError('axes must be a pair of sequence of integers or '
                            'an integer')

    def check_type_forward(self, in_types):
        type_check._argname(in_types, ('a', 'b'))
        a_type, b_type = in_types

        # Only floating-point inputs are accepted.
        type_check.expect(
            a_type.dtype.kind == 'f',
            b_type.dtype.kind == 'f',
        )

    @staticmethod
    def _normalize_axes(axes, ndim):
        """Map negative axis indices into ``[0, ndim)``; keep others as-is.

        Out-of-range indices are left for the shape lookup in
        ``_tensordot`` to reject.
        """
        return [ax + ndim if ax < 0 else ax for ax in axes]

    def _resolve_axes(self, a_ndim, b_ndim):
        """Derive the ``(a_axes, b_axes)`` pairs from ``self.axes``."""
        axes = self.axes
        if isinstance(axes, collections_abc.Sequence):
            a_sum, b_sum = axes
            if numpy.isscalar(a_sum):
                a_sum = a_sum,
            if numpy.isscalar(b_sum):
                b_sum = b_sum,
            # Negative indices must be normalized: the free-axis lists below
            # are built by membership tests over non-negative indices, and
            # ``c_axes`` / ``backward`` rely on those lists being complete.
            a_sum = self._normalize_axes(a_sum, a_ndim)
            b_sum = self._normalize_axes(b_sum, b_ndim)
        else:
            a_sum = six.moves.range(a_ndim - axes, a_ndim)
            b_sum = six.moves.range(axes)
        a_free = [i for i in six.moves.range(a_ndim) if i not in a_sum]
        b_free = [i for i in six.moves.range(b_ndim) if i not in b_sum]
        return [a_free, a_sum], [b_sum, b_free]

    def forward(self, inputs):
        self.retain_inputs((0, 1))
        a, b = inputs

        if self.a_axes is None or self.b_axes is None:
            self.a_axes, self.b_axes = self._resolve_axes(a.ndim, b.ndim)

        c = _tensordot(a, b, self.a_axes, self.b_axes, self.c_axes)

        if self.c_axes is None:
            # Default output layout: a's free axes, then b's free axes.
            c_row_ndim = len(self.a_axes[0])
            c_col_ndim = len(self.b_axes[1])
            self.c_axes = [
                six.moves.range(c_row_ndim),
                six.moves.range(c_row_ndim, c_row_ndim + c_col_ndim),
            ]

        return utils.force_array(c, self.dtype),

    def backward(self, indexes, grad_outputs):
        a, b = self.get_retained_inputs()
        gc, = grad_outputs

        ga = None
        if 0 in indexes:
            # Contract gc's b-derived axes (c_axes[1]) with b's free axes
            # (b_axes[1]); place the results back at a's axis positions.
            ga, = TensorDot(a_axes=self.c_axes,
                            b_axes=[self.b_axes[1], self.b_axes[0]],
                            c_axes=self.a_axes,
                            dtype=a.dtype).apply((gc, b))

        gb = None
        if 1 in indexes:
            # Contract a's free axes (a_axes[0]) with gc's a-derived axes
            # (c_axes[0]); place the results back at b's axis positions.
            gb, = TensorDot(a_axes=[self.a_axes[1], self.a_axes[0]],
                            b_axes=self.c_axes,
                            c_axes=self.b_axes,
                            dtype=b.dtype).apply((a, gc))

        return ga, gb
138
139
def tensordot(a, b, axes=2):
    """Returns the tensor dot product of two arrays along specified axes.

    This is equivalent to computing the dot product along the specified
    axes, where each group of axes is treated as a single axis by
    reshaping.

    Args:
        a (:class:`~chainer.Variable` or :ref:`ndarray`): The first argument.
        b (:class:`~chainer.Variable` or :ref:`ndarray`): The second argument.
        axes:
            - If it is an integer, then ``axes`` axes at the last of ``a`` and
              the first of ``b`` are used.
            - If it is a pair of sequences of integers, then these two
              sequences specify the list of axes for ``a`` and ``b``. The
              corresponding axes are paired for sum-product.

    Returns:
        ~chainer.Variable: The tensor dot product of ``a`` and ``b`` along the
        axes specified by ``axes``.

    .. admonition:: Example

        >>> a = np.random.rand(5, 3, 2)
        >>> b = np.random.rand(3, 2, 4)
        >>> c = F.tensordot(a, b, axes=2)
        >>> c.shape
        (5, 4)

    .. seealso:: :func:`numpy.tensordot`

    """
    node = TensorDot(axes=axes)
    outputs = node.apply((a, b))
    # The node has exactly one output: the contracted tensor.
    return outputs[0]
172