1import warnings
2
3import chainer
4from chainer import backend
5from chainer.backends import cuda
6from chainer import function_node
7from chainer import utils
8from chainer.utils import argument
9from chainer.utils import type_check
10
11
def _enumerate_axes(subscripts):
    """Yield ``(axis, label)`` pairs for one operand's subscript string.

    ``'@'`` is the one-character stand-in for an ellipsis. Labels before it
    get non-negative axis indices, labels after it get negative indices
    (counted from the end), and the ellipsis itself is paired with the slice
    covering every axis in between. Without ``'@'`` this is plain
    ``enumerate(subscripts)``.
    """
    if '@' not in subscripts:
        for pair in enumerate(subscripts):
            yield pair
        return

    head, tail = subscripts.split('@')
    n_head = len(head)
    n_tail = len(tail)

    for axis, label in enumerate(head):
        yield axis, label

    # With nothing after the ellipsis the slice must be open-ended; a stop
    # of ``-0`` would select nothing.
    stop = -n_tail if n_tail else None
    yield slice(n_head, stop), '@'

    for axis, label in enumerate(tail, -n_tail):
        yield axis, label
23
24
def _einsum(xp, dtype, in_subscripts, out_subscript, *inputs, **kwargs):
    """Evaluate an einsum with ``xp.einsum``, summing out an ellipsis if needed.

    Subscripts use ``'@'`` as a one-character stand-in for ``'...'``. When the
    inputs contain an ellipsis but the output does not, the ellipsis axes are
    kept as leading axes of the intermediate result and then summed away.

    Args:
        xp: Array module providing ``einsum`` and ``sum``.
        dtype: dtype handed to ``utils.force_array`` for the result.
        in_subscripts (str): Comma-separated input subscripts.
        out_subscript (str): Output subscript.
        *inputs: Operand arrays passed to ``xp.einsum``.
        **kwargs: Only ``check_undefined_ellipsis_sum`` (default ``False``)
            is accepted. If true, actually summing over ellipsis axes raises
            instead of being performed.

    Returns:
        The value of ``utils.force_array(y, dtype)`` for the computed ``y``.

    Raises:
        ValueError: If ``check_undefined_ellipsis_sum`` is true and the
            ellipsis to be summed over covers at least one axis.
    """
    check_undefined_ellipsis_sum, = argument.parse_kwargs(
        kwargs, ('check_undefined_ellipsis_sum', False))
    sum_ellipsis = '@' in in_subscripts and '@' not in out_subscript
    if sum_ellipsis:
        # einsum does not usually allow summing over '...'
        subscripts = '{}->...{}'.format(
            in_subscripts.replace('@', '...'),
            out_subscript
        )
    else:
        subscripts = '{}->{}'.format(
            in_subscripts,
            out_subscript
        ).replace('@', '...')

    # Use optimize option whenever it is critical in speed.
    # Otherwise avoid bugs in numpy>=1.12,<1.15.
    if len(inputs) >= 3:
        # Only this call passes ``optimize``, so only here can a TypeError
        # mean "optimize is unsupported". Guarding the plain call as well
        # would emit a misleading warning and re-run the same failing call.
        try:
            y = xp.einsum(subscripts, *inputs, optimize=True)
        except TypeError:
            warnings.warn(
                '{xp}.einsum does not support optimize option. '
                'Use newer version of {xp} to speed up.'
                .format(xp=xp.__name__),
                chainer.warnings.PerformanceWarning,
            )
            y = xp.einsum(subscripts, *inputs)
    else:
        y = xp.einsum(subscripts, *inputs)

    if sum_ellipsis:
        # The ellipsis axes are the leading ones: everything beyond the
        # explicitly named output labels.
        sum_ndim = y.ndim - len(out_subscript)
        if check_undefined_ellipsis_sum and sum_ndim > 0:
            raise ValueError(
                'einsum should not support summing over Ellipsis, '
                'while NumPy 1.14 sometimes accidentally supports it. '
                'This feature is no longer supported by Chainer. '
                'See also NumPy issues #10926, #9984.',
            )
        y = xp.sum(y, axis=tuple(range(sum_ndim)))

    return utils.force_array(y, dtype)
69
70
class EinSum(function_node.FunctionNode):
    """Function node that evaluates an Einstein summation.

    ``in_subs`` holds the comma-separated input subscripts and ``out_sub``
    the output subscript; both use ``'@'`` in place of an ellipsis.
    """

    def __init__(self, in_subs, out_sub):
        self.in_subs = in_subs
        self.out_sub = out_sub

    def check_type_forward(self, in_types):
        """Check dtypes, operand count and per-label axis-length agreement."""
        for index, in_type in enumerate(in_types):
            arg_name = 'x{}'.format(index)
            type_check._argname((in_type,), (arg_name,))
            type_check.expect(in_type.dtype.kind == 'f')

        sub_list = self.in_subs.split(',')
        type_check.expect(in_types.size() == len(sub_list))

        # Maps each label to the shape expression where it was first seen;
        # every later occurrence must agree with it.
        seen = {}
        for sub, in_type in zip(sub_list, in_types):
            for axis, label in _enumerate_axes(sub):
                extent = in_type.shape[axis]
                if label not in seen:
                    seen[label] = extent
                    continue
                type_check.expect(seen[label] == extent)

    def forward(self, inputs):
        """Compute the einsum of ``inputs``; returns a one-element tuple."""
        # TODO(kataoka): Do not retain inputs if n_args == 1
        self.retain_inputs(tuple(range(len(inputs))))

        xp = backend.get_array_module(inputs[0])
        out_dtype = xp.result_type(*[x.dtype for x in inputs])
        result = _einsum(
            xp, out_dtype, self.in_subs, self.out_sub, *inputs,
            check_undefined_ellipsis_sum=True)
        return result,

    def backward(self, indices, grad_outputs):
        """Return the gradient w.r.t. each input listed in ``indices``.

        The gradient for input ``i`` is itself an einsum: the output
        gradient takes the place of input ``i`` (both in the operands and in
        the subscripts), and the result is laid out with input ``i``'s
        original subscript and shape.
        """
        inputs = self.get_retained_inputs()
        gy, = grad_outputs

        in_sub_list = self.in_subs.split(',')
        out_sub = self.out_sub

        grads = []
        for i in indices:
            bwd_subs = [
                (out_sub if j == i else s)
                for j, s in enumerate(in_sub_list)
            ]
            bwd_inputs = tuple(
                (gy if j == i else x)
                for j, x in enumerate(inputs)
            )
            node = DiagEinSum(
                in_subs=','.join(bwd_subs),
                out_sub=in_sub_list[i],
                out_shape=inputs[i].shape,
            )
            grads.append(node.apply(bwd_inputs)[0])
        return tuple(grads)
125
126
class DiagEinSum(EinSum):
    """Einsum whose output may repeat labels or use labels absent from inputs.

    Such outputs cannot be produced by a single einsum call, so the result
    is written into a zero-filled array of shape ``out_shape`` through a
    view of it.
    """

    def __init__(self, in_subs, out_sub, out_shape):
        super(DiagEinSum, self).__init__(in_subs, out_sub)
        self.out_shape = out_shape

    def forward(self, inputs):
        """Compute the einsum into ``out_shape``; returns a one-element tuple."""
        # TODO(kataoka): Do not retain inputs if n_args == 1
        self.retain_inputs(tuple(range(len(inputs))))

        xp = backend.get_array_module(inputs[0])
        dtype = xp.result_type(*[x.dtype for x in inputs])

        out_labels = set(self.out_sub)

        # '@' is a single char, ',' is excluded.
        shared = out_labels & set(self.in_subs)

        if len(shared) == len(self.out_sub):
            # Every output label is distinct and present in the inputs, so
            # an ordinary einsum is sufficient.
            return _einsum(xp, dtype, self.in_subs, self.out_sub, *inputs),

        ordered = sorted(out_labels)
        # Labels that can actually be computed from the inputs.
        direct_sub = ''.join(c for c in ordered if c in shared)
        # Every distinct output label, each exactly once.
        inverse_sub = ''.join(ordered)
        # Index that inserts a new axis for each label missing from the
        # inputs, so the computed values broadcast along it.
        expander = tuple(
            slice(None) if c in shared else None for c in ordered)

        y = xp.zeros(self.out_shape, dtype)
        # Reduce ``y`` to one axis per distinct label; this has to be a view
        # of ``y`` for the assignment below to fill ``y`` itself.
        diag_y = _einsum(xp, dtype, self.out_sub, inverse_sub, y)
        if diag_y.base is not y:
            raise ValueError('Update CuPy to close CuPy Issue #1199')
        # Make the view writeable as numpy PR #5410 for numpy<1.10.
        if xp is not cuda.cupy:  # no setflags in cupy
            diag_y.setflags(write=True)
        values = _einsum(xp, dtype, self.in_subs, direct_sub, *inputs)
        diag_y[...] = values[expander]
        return y,
173
174
def einsum(*operands):
    """Einstein summation

    Two calling conventions are accepted:

    - ``einsum(subscripts, op0, op1, ...)``
    - ``einsum(op0, sublist0, op1, sublist1, ..., [sublistout])``

    See also :func:`numpy.einsum`

    Args:
        *operands: Either a subscripts string followed by the operands, or
            the operands alternating with their sublists, optionally
            followed by a sublist for the output.

    Returns:
        The first output of applying the ``EinSum`` function node to the
        parsed operands.

    .. admonition:: Example

        The following example computes a batched application of a bilinear
        function with weight ``w``.

        >>> x1 = np.arange(12).reshape(3, 4).astype(np.float32)
        >>> x2 = np.arange(15).reshape(3, 5).astype(np.float32)
        >>> w = np.arange(120).reshape(4, 5, 6).astype(np.float32)
        >>> y = F.einsum('ij,ik,jkl->il', x1, x2, w)
        >>> y.shape
        (3, 6)

        The batch axes can be denoted by ``...``. If the string of output
        subscripts is omitted, the summation is taken over the subscript
        alphabets with two (or more) occurrences.

        >>> np.allclose(y.array, F.einsum('...j,...k,jkl', x1, x2, w).array)
        True

        In the other format:

        >>> y = F.einsum(x1, [0, 1], x2, [0, 2], w, [1, 2, 3], [0, 3])
        >>> y.shape
        (3, 6)
        >>> y = F.einsum(x1, [Ellipsis, 1], x2, [Ellipsis, 2], w, [1, 2, 3])
        >>> y.shape
        (3, 6)

    """
    in_subs, out_sub, arrays = _parse_einsum_input(operands)
    node = EinSum(in_subs=in_subs, out_sub=out_sub)
    outputs = node.apply(arrays)
    return outputs[0]
220
221
222# #################### cupy.linalg.einsum ####################
223# From cupy PR #873
224
# Characters usable as subscript labels. In the sublist calling convention
# the integer label ``i`` is translated to ``einsum_symbols[i]``.
einsum_symbols = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
# Same characters as a set, for membership tests.
einsum_symbols_set = set(einsum_symbols)


def _sublist_to_subscript(sublist):
    """Translate one sublist of ints / ``Ellipsis`` to a subscript string."""
    labels = []
    for s in sublist:
        if s is Ellipsis:
            labels.append('@')
        elif isinstance(s, int):
            labels.append(einsum_symbols[s])
        else:
            raise TypeError('For this input type lists must contain '
                            'either int or Ellipsis')
    return ''.join(labels)


def _parse_einsum_input(operands):
    """Parses einsum operands.

    This function is based on `numpy.core.einsumfunc._parse_einsum_input`
    function in NumPy 1.14. In the returned strings an ellipsis is written
    as the single character ``'@'``.

    Parameters
    ----------
    operands : sequence
        Either ``(subscripts, op0, op1, ...)`` with ``subscripts`` a string,
        or ``(op0, sublist0, op1, sublist1, ..., [sublistout])``.

    Returns
    -------
    input_strings : str
        Parsed input strings
    output_string : str
        Parsed output string
    operands : list of array_like
        The operands to use in the numpy contraction

    Raises
    ------
    ValueError
        If there are no operands, the subscripts string contains an invalid
        character, a malformed ``->`` or ellipsis, an output label is missing
        from the inputs, or the number of subscript terms differs from the
        number of operands.
    TypeError
        If a sublist contains something other than ``int`` or ``Ellipsis``.

    Examples
    --------
    The operand list is simplified to reduce printing:

    >>> a = np.random.rand(4, 4)
    >>> b = np.random.rand(4, 4, 4)
    >>> _parse_einsum_input(('...a,...a->...', a, b))
    ('@a,@a', '@', [a, b])

    Integer labels index into ``einsum_symbols``, which starts with the
    upper-case letters:

    >>> _parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0]))
    ('@A,@A', '@', [a, b])
    """

    if not operands:
        raise ValueError('No input operands')

    if isinstance(operands[0], str):
        subscripts = operands[0].replace(' ', '')
        # Return a list in both calling conventions, as documented.
        operands = list(operands[1:])

        # Ensure all characters are valid
        for s in subscripts:
            if s in '.,->':
                continue
            if s not in einsum_symbols_set:
                raise ValueError('Character %s is not a valid symbol.' % s)

        # Check for proper "->"
        if '-' in subscripts or '>' in subscripts:
            if (subscripts.count('-') > 1
                    or subscripts.count('>') > 1
                    or subscripts.count('->') != 1):
                raise ValueError('Subscripts can only contain one \'->\'.')

        # Parse "..."
        subscripts = subscripts.replace('...', '@')
        if '.' in subscripts:
            raise ValueError('Invalid Ellipses.')

    else:
        flat = list(operands)
        n_pairs = len(flat) // 2
        paired = flat[:2 * n_pairs]
        leftover = flat[2 * n_pairs:]

        # Operands sit at even positions, their sublists at odd positions;
        # a single trailing item, if any, is the output sublist.
        operands = paired[0::2]
        subscript_list = paired[1::2]
        output_list = leftover[-1] if leftover else None

        subscripts = ','.join(
            _sublist_to_subscript(sub) for sub in subscript_list)
        if output_list is not None:
            subscripts += '->' + _sublist_to_subscript(output_list)

    # Build output string if does not exist
    if '->' in subscripts:
        input_subscripts, output_subscript = subscripts.split('->')

        # Make sure output subscripts are in the input
        for char in output_subscript:
            if char not in input_subscripts:
                raise ValueError(
                    'Output character %s did not appear in the input'
                    % ('...' if char == '@' else char))

    else:
        input_subscripts = subscripts
        # Build output subscripts: the ellipsis (which sorts before every
        # letter) followed by the labels that occur exactly once.
        tmp_subscripts = subscripts.replace(',', '')
        output_subscript = ''.join(
            s for s in sorted(set(tmp_subscripts))
            if s == '@' or tmp_subscripts.count(s) == 1)

    # Make sure number operands is equivalent to the number of terms
    if len(input_subscripts.split(',')) != len(operands):
        raise ValueError('Number of einsum subscripts must be equal to the '
                         'number of operands.')

    return input_subscripts, output_subscript, operands
346