import warnings

import chainer
from chainer import backend
from chainer.backends import cuda
from chainer import function_node
from chainer import utils
from chainer.utils import argument
from chainer.utils import type_check


def _enumerate_axes(subscripts):
    """Yield ``(axis, char)`` pairs for one operand's subscript string.

    ``subscripts`` uses the module-internal convention where a single
    ``'@'`` character stands for an ellipsis (``'...'``).  Characters to
    the left of ``'@'`` are yielded with non-negative axis indices, the
    ``'@'`` itself is yielded with a ``slice`` covering every axis it
    matches, and characters to the right are yielded with negative axis
    indices (counted from the end of the shape).
    """
    if '@' in subscripts:
        left_sub, right_sub = subscripts.split('@')
        for i, s in enumerate(left_sub):
            yield i, s
        # ``-len(right_sub) or None`` maps 0 to None, so an empty right
        # part yields slice(len(left_sub), None), i.e. "to the end".
        yield slice(len(left_sub), -len(right_sub) or None), '@'
        for i, s in enumerate(right_sub):
            yield i - len(right_sub), s
    else:
        for i, s in enumerate(subscripts):
            yield i, s


def _einsum(xp, dtype, in_subscripts, out_subscript, *inputs, **kwargs):
    """Call ``xp.einsum`` with the module-internal ``'@'`` notation.

    ``in_subscripts`` (comma-separated) and ``out_subscript`` may contain
    ``'@'`` as a one-character stand-in for ``'...'``.  When the inputs
    have an ellipsis but the output does not, the ellipsis axes are summed
    out explicitly with ``xp.sum`` afterwards, because einsum itself does
    not usually allow summing over ``'...'``.

    The result is forced to an array of ``dtype`` via
    ``utils.force_array``.
    """
    check_undefined_ellipsis_sum, = argument.parse_kwargs(
        kwargs, ('check_undefined_ellipsis_sum', False))
    sum_ellipsis = '@' in in_subscripts and '@' not in out_subscript
    if sum_ellipsis:
        # einsum does not usually allow summing over '...'
        # so keep the ellipsis axes in front of the output here and sum
        # them out below.
        subscripts = '{}->...{}'.format(
            in_subscripts.replace('@', '...'),
            out_subscript
        )
    else:
        subscripts = '{}->{}'.format(
            in_subscripts,
            out_subscript
        ).replace('@', '...')

    # Use optimize option whenever it is critical in speed.
    # Otherwise avoid bugs in numpy>=1.12,<1.15.
    einsum_kwargs = {}
    if len(inputs) >= 3:
        einsum_kwargs['optimize'] = True
    try:
        y = xp.einsum(subscripts, *inputs, **einsum_kwargs)
    except TypeError:
        # Older numpy/cupy do not accept the ``optimize`` keyword; retry
        # without it and warn about the missed speed-up.
        warnings.warn(
            '{xp}.einsum does not support optimize option. '
            'Use newer version of {xp} to speed up.'
            .format(xp=xp.__name__),
            chainer.warnings.PerformanceWarning,
        )
        y = xp.einsum(subscripts, *inputs)

    if sum_ellipsis:
        # The ellipsis axes were kept as the leading axes of ``y``; any
        # leftover leading axes are the ones to sum out.
        sum_ndim = y.ndim - len(out_subscript)
        if check_undefined_ellipsis_sum and sum_ndim > 0:
            raise ValueError(
                'einsum should not support summing over Ellipsis, '
                'while NumPy 1.14 sometimes accidentally supports it. '
                'This feature is no longer supported by Chainer. '
                'See also NumPy issues #10926, #9984.',
            )
        y = xp.sum(y, axis=tuple(range(sum_ndim)))

    return utils.force_array(y, dtype)


class EinSum(function_node.FunctionNode):

    """Einstein summation as a differentiable FunctionNode.

    ``in_subs`` is the comma-separated input subscripts and ``out_sub``
    the output subscript, both using ``'@'`` for an ellipsis.
    """

    def __init__(self, in_subs, out_sub):
        self.in_subs = in_subs
        self.out_sub = out_sub

    def check_type_forward(self, in_types):
        # All inputs must be floating-point arrays, one per input
        # subscript, and every subscript character must map to a
        # consistent axis length across all operands.
        for i, in_type in enumerate(in_types):
            type_check._argname((in_type,), ('x{}'.format(i),))
            type_check.expect(in_type.dtype.kind == 'f')

        in_subs = self.in_subs.split(',')
        type_check.expect(in_types.size() == len(in_subs))

        shape_dict = {}
        for in_sub, in_type in zip(in_subs, in_types):
            for axis, char in _enumerate_axes(in_sub):
                shape = in_type.shape[axis]
                if char in shape_dict:
                    type_check.expect(shape_dict[char] == shape)
                else:
                    shape_dict[char] = shape

    def forward(self, inputs):
        n_args = len(inputs)
        # TODO(kataoka): Do not retain inputs if n_args == 1
        # Inputs are retained because the backward pass re-contracts them
        # with the output gradient.
        self.retain_inputs(tuple(range(n_args)))

        xp = backend.get_array_module(inputs[0])
        dtype = xp.result_type(*[x.dtype for x in inputs])
        y = _einsum(xp, dtype, self.in_subs, self.out_sub, *inputs,
                    check_undefined_ellipsis_sum=True)
        return y,

    def backward(self, indices, grad_outputs):
        # The gradient w.r.t. input ``i`` is itself an einsum: replace the
        # i-th input subscript with the forward output subscript (the
        # gradient takes that input's place) and contract to the i-th
        # input's subscript.  DiagEinSum is used because the target
        # subscript may contain repeated characters.
        inputs = self.get_retained_inputs()
        g, = grad_outputs

        fwd_in_subs = self.in_subs.split(',')
        fwd_out_sub = self.out_sub
        return tuple(
            DiagEinSum(
                in_subs=','.join([
                    (fwd_out_sub if j == i else s)
                    for j, s in enumerate(fwd_in_subs)
                ]),
                out_sub=fwd_in_subs[i],
                out_shape=inputs[i].shape,
            ).apply(tuple(
                (g if j == i else x)
                for j, x in enumerate(inputs)
            ))[0]
            for i in indices
        )


class DiagEinSum(EinSum):

    """EinSum variant whose output subscript may repeat characters.

    When ``out_sub`` contains a character more than once (e.g. ``'ii'``),
    a plain einsum cannot produce it.  Instead, a zero array of
    ``out_shape`` is allocated and the einsum result is written into its
    (generalized) diagonal view.  Used by ``EinSum.backward``.
    """

    def __init__(self, in_subs, out_sub, out_shape):
        self.in_subs = in_subs
        self.out_sub = out_sub
        self.out_shape = out_shape

    def forward(self, inputs):
        n_args = len(inputs)
        # TODO(kataoka): Do not retain inputs if n_args == 1
        self.retain_inputs(tuple(range(n_args)))

        xp = backend.get_array_module(inputs[0])
        dtype = xp.result_type(*[x.dtype for x in inputs])

        out_set = set(self.out_sub)

        # '@' is a single char, ',' is excluded.
        io_set = out_set.intersection(set(self.in_subs))

        if len(io_set) == len(self.out_sub):
            # No repeated output characters: an ordinary einsum suffices.
            y = _einsum(xp, dtype, self.in_subs, self.out_sub, *inputs)
        else:
            # ``direct_sub``: output chars computable directly from the
            # inputs.  ``inverse_sub``: chars only present in the output;
            # ``expander`` inserts broadcast axes (None) for them so the
            # direct einsum result can be assigned into the diagonal view.
            direct_sub = []
            inverse_sub = []
            expander = []
            for c in sorted(out_set):
                if c in io_set:
                    direct_sub.append(c)
                    expander.append(slice(None))
                else:
                    expander.append(None)
                    inverse_sub.append(c)

            y = xp.zeros(self.out_shape, dtype)
            # Einsum of ``y`` with repeated subscripts yields a view of
            # its generalized diagonal; writing through it fills ``y``.
            diag_y = _einsum(
                xp, dtype, self.out_sub, ''.join(inverse_sub), y)
            if diag_y.base is not y:
                raise ValueError('Update CuPy to close CuPy Issue #1199')
            # Make the view writeable as numpy PR #5410 for numpy<1.10.
            if xp is not cuda.cupy:  # no setflags in cupy
                diag_y.setflags(write=True)
            diag_y[...] = _einsum(
                xp, dtype, self.in_subs, ''.join(direct_sub), *inputs
            )[tuple(expander)]
        return y,


def einsum(*operands):
    """Einstein summation

    This function supports two formats of inputs:

    - ``einsum(subscripts, op0, op1, ...)``
    - ``einsum(op0, sublist0, op1, sublist1, ..., [sublistout])``

    See also :func:`numpy.einsum`

    Returns:
        ~chainer.Variable: Output variable.

    .. admonition:: Example

        The following example computes a batched application of a bilinear
        function with weight ``w``.

        >>> x1 = np.arange(12).reshape(3, 4).astype(np.float32)
        >>> x2 = np.arange(15).reshape(3, 5).astype(np.float32)
        >>> w = np.arange(120).reshape(4, 5, 6).astype(np.float32)
        >>> y = F.einsum('ij,ik,jkl->il', x1, x2, w)
        >>> y.shape
        (3, 6)

        The batch axes can be denoted by ``...``. If the string of output
        subscripts is omitted, the summation is taken over the subscript
        alphabets with two (or more) occurrences.

        >>> np.allclose(y.array, F.einsum('...j,...k,jkl', x1, x2, w).array)
        True

        In the other format:

        >>> y = F.einsum(x1, [0, 1], x2, [0, 2], w, [1, 2, 3], [0, 3])
        >>> y.shape
        (3, 6)
        >>> y = F.einsum(x1, [Ellipsis, 1], x2, [Ellipsis, 2], w, [1, 2, 3])
        >>> y.shape
        (3, 6)

    """
    input_subscripts, output_subscript, ioperands = \
        _parse_einsum_input(operands)
    return EinSum(
        in_subs=input_subscripts,
        out_sub=output_subscript,
    ).apply(ioperands)[0]


# #################### cupy.linalg.einsum ####################
# From cupy PR #873

einsum_symbols = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
einsum_symbols_set = set(einsum_symbols)


def _parse_einsum_input(operands):
    """Parses einsum operands.

    This function is based on `numpy.core.einsumfunc._parse_einsum_input`
    function in NumPy 1.14.

    Returns
    -------
    input_strings : str
        Parsed input strings
    output_string : str
        Parsed output string
    operands : list of array_like
        The operands to use in the numpy contraction

    Examples
    --------
    The operand list is simplified to reduce printing:

    >>> a = np.random.rand(4, 4)
    >>> b = np.random.rand(4, 4, 4)
    >>> _parse_einsum_input(('...a,...a->...', a, b))
    ('@a,@a', '@', [a, b])

    >>> _parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0]))
    ('@a,@a', '@', [a, b])
    """

    if not operands:
        raise ValueError('No input operands')

    if isinstance(operands[0], str):
        # Subscript-string format: einsum(subscripts, op0, op1, ...)
        subscripts = operands[0].replace(' ', '')
        operands = operands[1:]

        # Ensure all characters are valid
        for s in subscripts:
            if s in '.,->':
                continue
            if s not in einsum_symbols:
                raise ValueError('Character %s is not a valid symbol.' % s)

        # Check for proper "->"
        if ('-' in subscripts) or ('>' in subscripts):
            if any((
                    subscripts.count('-') > 1,
                    subscripts.count('>') > 1,
                    subscripts.count('->') != 1,
            )):
                raise ValueError('Subscripts can only contain one \'->\'.')

        # Parse "..."
        # ``'...'`` is replaced by the single char ``'@'``; any remaining
        # lone dot is malformed.
        subscripts = subscripts.replace('...', '@')
        if '.' in subscripts:
            raise ValueError('Invalid Ellipses.')

    else:
        # Interleaved format: einsum(op0, sublist0, ..., [sublistout]).
        # Subscript lists contain ints (mapped to letters) or Ellipsis.
        tmp_operands = list(operands)
        operand_list = []
        subscript_list = []
        for p in range(len(operands) // 2):
            operand_list.append(tmp_operands.pop(0))
            subscript_list.append(tmp_operands.pop(0))

        # An odd trailing element, if any, is the output subscript list.
        output_list = tmp_operands[-1] if len(tmp_operands) else None
        operands = operand_list
        subscripts = ''
        last = len(subscript_list) - 1
        for num, sub in enumerate(subscript_list):
            for s in sub:
                if s is Ellipsis:
                    subscripts += '@'
                elif isinstance(s, int):
                    subscripts += einsum_symbols[s]
                else:
                    raise TypeError('For this input type lists must contain '
                                    'either int or Ellipsis')
            if num != last:
                subscripts += ','

        if output_list is not None:
            subscripts += '->'
            for s in output_list:
                if s is Ellipsis:
                    subscripts += '@'
                elif isinstance(s, int):
                    subscripts += einsum_symbols[s]
                else:
                    raise TypeError('For this input type lists must contain '
                                    'either int or Ellipsis')

    # Build output string if does not exist
    if '->' in subscripts:
        input_subscripts, output_subscript = subscripts.split('->')

        # Make sure output subscripts are in the input
        for char in output_subscript:
            if char not in input_subscripts:
                raise ValueError(
                    'Output character %s did not appear in the input'
                    % ('...' if char == '@' else char))

    else:
        input_subscripts = subscripts
        # Build output subscripts
        # Implicit-mode output: keep, in sorted order, the ellipsis and
        # every subscript char that occurs exactly once across all inputs.
        tmp_subscripts = subscripts.replace(',', '')
        output_subscript = ''
        for s in sorted(set(tmp_subscripts)):
            if s == '@' or tmp_subscripts.count(s) == 1:
                output_subscript += s

    # Make sure number operands is equivalent to the number of terms
    if len(input_subscripts.split(',')) != len(operands):
        raise ValueError('Number of einsum subscripts must be equal to the '
                         'number of operands.')

    return input_subscripts, output_subscript, operands