"""Tests for opt_einsum's intermediate-sharing cache (``shared_intermediates``)."""
import itertools
import weakref
from collections import Counter

import numpy as np
import pytest

from opt_einsum import (contract, contract_expression, contract_path, get_symbol, helpers, shared_intermediates)
from opt_einsum.backends import to_cupy, to_torch
from opt_einsum.contract import _einsum
from opt_einsum.parser import parse_einsum_input
from opt_einsum.sharing import (count_cached_ops, currently_sharing, get_sharing_cache)

# Optional backends: when the package cannot be imported, parametrize with a
# skip-marked param so those test cases are reported as skipped, not errors.
try:
    import cupy  # noqa: F401  (imported only to probe availability)
    cupy_if_found = 'cupy'
except ImportError:
    cupy_if_found = pytest.param('cupy', marks=[pytest.mark.skip(reason="CuPy not installed.")])

try:
    import torch  # noqa: F401  (imported only to probe availability)
    torch_if_found = 'torch'
except ImportError:
    torch_if_found = pytest.param('torch', marks=[pytest.mark.skip(reason="PyTorch not installed.")])

backends = ['numpy', torch_if_found, cupy_if_found]
equations = [
    'ab,bc->ca',
    'abc,bcd,dea',
    'abc,def->fedcba',
    'abc,bcd,df->fa',
    # test 'prefer einsum' ops
    'ijk,ikj',
    'i,j->ij',
    'ijk,k->ij',
    'AB,BC->CA',
]
# Converters from numpy arrays to each backend's native array type.
to_backend = {
    'numpy': lambda x: x,
    'torch': to_torch,
    'cupy': to_cupy,
}


@pytest.mark.parametrize('eq', equations)
@pytest.mark.parametrize('backend', backends)
def test_sharing_value(eq, backend):
    """Results computed inside a sharing context must equal those computed outside it."""
    views = helpers.build_views(eq)
    shapes = [v.shape for v in views]
    expr = contract_expression(eq, *shapes)

    expected = expr(*views, backend=backend)
    with shared_intermediates():
        actual = expr(*views, backend=backend)

    assert (actual == expected).all()


@pytest.mark.parametrize('backend', backends)
def test_complete_sharing(backend):
    """Evaluating an expression twice in one context caches no more ops than evaluating it once."""
    eq = 'ab,bc,cd->'
    views = helpers.build_views(eq)
    expr = contract_expression(eq, *(v.shape for v in views))

    print('-' * 40)
    print('Without sharing:')
    with shared_intermediates() as cache:
        expr(*views, backend=backend)
        expected = count_cached_ops(cache)

    print('-' * 40)
    print('With sharing:')
    with shared_intermediates() as cache:
        expr(*views, backend=backend)
        expr(*views, backend=backend)
        actual = count_cached_ops(cache)

    print('-' * 40)
    print('Without sharing: {} expressions'.format(expected))
    print('With sharing: {} expressions'.format(actual))
    assert actual == expected


@pytest.mark.parametrize('backend', backends)
def test_sharing_reused_cache(backend):
    """Passing an existing cache to a new context reuses its entries instead of recomputing."""
    eq = 'ab,bc,cd->'
    views = helpers.build_views(eq)
    expr = contract_expression(eq, *(v.shape for v in views))

    print('-' * 40)
    print('Without sharing:')
    with shared_intermediates() as cache:
        expr(*views, backend=backend)
        expected = count_cached_ops(cache)

    print('-' * 40)
    print('With sharing:')
    with shared_intermediates() as cache:
        expr(*views, backend=backend)
    # Re-enter sharing with the cache from the previous (already exited) context.
    with shared_intermediates(cache):
        expr(*views, backend=backend)
        actual = count_cached_ops(cache)

    print('-' * 40)
    print('Without sharing: {} expressions'.format(expected))
    print('With sharing: {} expressions'.format(actual))
    assert actual == expected


@pytest.mark.parametrize('backend', backends)
def test_no_sharing_separate_cache(backend):
    """Two independent contexts do not share work: op counts are the sum of both caches."""
    eq = 'ab,bc,cd->'
    views = helpers.build_views(eq)
    expr = contract_expression(eq, *(v.shape for v in views))

    print('-' * 40)
    print('Without sharing:')
    with shared_intermediates() as cache:
        expr(*views, backend=backend)
        expected = count_cached_ops(cache)
        expected.update(count_cached_ops(cache))  # we expect double

    print('-' * 40)
    print('With sharing:')
    with shared_intermediates() as cache1:
        expr(*views, backend=backend)
        actual = count_cached_ops(cache1)
    with shared_intermediates() as cache2:
        expr(*views, backend=backend)
        actual.update(count_cached_ops(cache2))

    print('-' * 40)
    print('Without sharing: {} expressions'.format(expected))
    print('With sharing: {} expressions'.format(actual))
    assert actual == expected


@pytest.mark.parametrize('backend', backends)
def test_sharing_nesting(backend):
    """Nested sharing contexts must not keep intermediates alive after they exit.

    Results are tracked through a WeakValueDictionary: once the local names are
    deleted, an entry can only survive if some cache still references it.
    """
    eqs = ['ab,bc,cd->a', 'ab,bc,cd->b', 'ab,bc,cd->c', 'ab,bc,cd->c']
    views = helpers.build_views(eqs[0])
    shapes = [v.shape for v in views]
    refs = weakref.WeakValueDictionary()

    def method1(views):
        with shared_intermediates():
            w = contract_expression(eqs[0], *shapes)(*views, backend=backend)
            x = contract_expression(eqs[2], *shapes)(*views, backend=backend)
            result = contract_expression('a,b->', w.shape, x.shape)(w, x, backend=backend)
            refs['w'] = w
            refs['x'] = x
            del w, x
            # Still alive: the active sharing cache holds them.
            assert 'w' in refs
            assert 'x' in refs
        # Context exited: the cache must have released them.
        assert 'w' not in refs, 'cache leakage'
        assert 'x' not in refs, 'cache leakage'
        return result

    def method2(views):
        with shared_intermediates():
            y = contract_expression(eqs[2], *shapes)(*views, backend=backend)
            z = contract_expression(eqs[3], *shapes)(*views, backend=backend)
            refs['y'] = y
            refs['z'] = z
            result = contract_expression('c,d->', y.shape, z.shape)(y, z, backend=backend)
            result = result + method1(views)  # nest method1 in method2
            del y, z
            assert 'y' in refs
            assert 'z' in refs
        assert 'y' not in refs
        assert 'z' not in refs

    method1(views)
    method2(views)


@pytest.mark.parametrize('eq', equations)
@pytest.mark.parametrize('backend', backends)
def test_sharing_modulo_commutativity(eq, backend):
    """Reordering the operands (with matching subscripts) must hit the same cache entries."""
    ops = helpers.build_views(eq)
    ops = [to_backend[backend](x) for x in ops]
    inputs, output, _ = parse_einsum_input([eq] + ops)
    inputs = inputs.split(',')

    print('-' * 40)
    print('Without sharing:')
    with shared_intermediates() as cache:
        _einsum(eq, *ops, backend=backend)
        expected = count_cached_ops(cache)

    print('-' * 40)
    print('With sharing:')
    with shared_intermediates() as cache:
        for permuted in itertools.permutations(zip(inputs, ops)):
            permuted_inputs = [p[0] for p in permuted]
            permuted_ops = [p[1] for p in permuted]
            permuted_eq = '{}->{}'.format(','.join(permuted_inputs), output)
            _einsum(permuted_eq, *permuted_ops, backend=backend)
        actual = count_cached_ops(cache)

    print('-' * 40)
    print('Without sharing: {} expressions'.format(expected))
    print('With sharing: {} expressions'.format(actual))
    assert actual == expected


@pytest.mark.parametrize('backend', backends)
def test_partial_sharing(backend):
    """Contractions that differ in one operand should still share work on the common operands."""
    eq = 'ab,bc,de->'
    x, y, z1 = helpers.build_views(eq)
    z2 = 2.0 * z1 - 1.0
    expr = contract_expression(eq, x.shape, y.shape, z1.shape)

    print('-' * 40)
    print('Without sharing:')
    num_exprs_nosharing = Counter()
    with shared_intermediates() as cache:
        expr(x, y, z1, backend=backend)
        num_exprs_nosharing.update(count_cached_ops(cache))
    with shared_intermediates() as cache:
        expr(x, y, z2, backend=backend)
        num_exprs_nosharing.update(count_cached_ops(cache))

    print('-' * 40)
    print('With sharing:')
    with shared_intermediates() as cache:
        expr(x, y, z1, backend=backend)
        expr(x, y, z2, backend=backend)
        num_exprs_sharing = count_cached_ops(cache)

    print('-' * 40)
    print('Without sharing: {} expressions'.format(num_exprs_nosharing))
    print('With sharing: {} expressions'.format(num_exprs_sharing))
    assert num_exprs_nosharing['einsum'] > num_exprs_sharing['einsum']


@pytest.mark.parametrize('backend', backends)
def test_sharing_with_constants(backend):
    """Expressions with constant operands give the same results under sharing as without.

    Note: the contractions here do not pass ``backend`` through; the parameter
    only multiplies the test cases.
    """
    inputs = 'ij,jk,kl'
    outputs = 'ijkl'
    # Distinct name: do not shadow the module-level ``equations`` parametrization list.
    output_eqs = ['{}->{}'.format(inputs, output) for output in outputs]
    shapes = (2, 3), (3, 4), (4, 5)
    constants = {0, 2}
    ops = [np.random.rand(*shp) if i in constants else shp for i, shp in enumerate(shapes)]
    var = np.random.rand(*shapes[1])

    expected = [contract_expression(eq, *shapes)(ops[0], var, ops[2]) for eq in output_eqs]

    with shared_intermediates():
        actual = [contract_expression(eq, *ops, constants=constants)(var) for eq in output_eqs]

    for dim, expected_dim, actual_dim in zip(outputs, expected, actual):
        assert np.allclose(expected_dim, actual_dim), 'error at {}'.format(dim)


@pytest.mark.parametrize('size', [3, 4, 5])
@pytest.mark.parametrize('backend', backends)
def test_chain(size, backend):
    """Smoke test: contract a matrix chain down to each single index in one sharing context.

    Prints the chosen contraction paths; makes no assertions.
    """
    xs = [np.random.rand(2, 2) for _ in range(size)]
    shapes = [x.shape for x in xs]
    alphabet = ''.join(get_symbol(i) for i in range(size + 1))
    names = [alphabet[i:i + 2] for i in range(size)]
    inputs = ','.join(names)

    with shared_intermediates():
        print(inputs)
        for i in range(size + 1):
            target = alphabet[i]
            eq = '{}->{}'.format(inputs, target)
            path_info = contract_path(eq, *xs)
            print(path_info[1])
            expr = contract_expression(eq, *shapes)
            expr(*xs, backend=backend)
            print('-' * 40)


@pytest.mark.parametrize('size', [3, 4, 5, 10])
@pytest.mark.parametrize('backend', backends)
def test_chain_2(size, backend):
    """Smoke test: contract a matrix chain down to each adjacent index pair in one context.

    Prints the chosen contraction paths; makes no assertions.
    """
    xs = [np.random.rand(2, 2) for _ in range(size)]
    shapes = [x.shape for x in xs]
    alphabet = ''.join(get_symbol(i) for i in range(size + 1))
    names = [alphabet[i:i + 2] for i in range(size)]
    inputs = ','.join(names)

    with shared_intermediates():
        print(inputs)
        for i in range(size):
            target = alphabet[i:i + 2]
            eq = '{}->{}'.format(inputs, target)
            path_info = contract_path(eq, *xs)
            print(path_info[1])
            expr = contract_expression(eq, *shapes)
            expr(*xs, backend=backend)
            print('-' * 40)


def _compute_cost(cache):
    """Return the number of cached ``einsum`` plus ``tensordot`` ops in *cache*."""
    counts = count_cached_ops(cache)
    return counts['einsum'] + counts['tensordot']


@pytest.mark.parametrize('backend', backends)
def test_chain_2_growth(backend):
    """Report how the cached op count grows with chain length (prints only, no assertion)."""
    sizes = list(range(1, 21))
    costs = []
    for size in sizes:
        xs = [np.random.rand(2, 2) for _ in range(size)]
        alphabet = ''.join(get_symbol(i) for i in range(size + 1))
        names = [alphabet[i:i + 2] for i in range(size)]
        inputs = ','.join(names)

        with shared_intermediates() as cache:
            for i in range(size):
                target = alphabet[i:i + 2]
                eq = '{}->{}'.format(inputs, target)
                expr = contract_expression(eq, *(x.shape for x in xs))
                expr(*xs, backend=backend)
            costs.append(_compute_cost(cache))

    print('sizes = {}'.format(repr(sizes)))
    print('costs = {}'.format(repr(costs)))
    for size, cost in zip(sizes, costs):
        print('{}\t{}'.format(size, cost))


@pytest.mark.parametrize('size', [3, 4, 5])
@pytest.mark.parametrize('backend', backends)
def test_chain_sharing(size, backend):
    """One shared context for all single-index targets must cost fewer ops than one context each."""
    xs = [np.random.rand(2, 2) for _ in range(size)]
    alphabet = ''.join(get_symbol(i) for i in range(size + 1))
    names = [alphabet[i:i + 2] for i in range(size)]
    inputs = ','.join(names)

    # Baseline: a fresh cache per target, so nothing is reused between targets.
    num_exprs_nosharing = 0
    for i in range(size + 1):
        with shared_intermediates() as cache:
            target = alphabet[i]
            eq = '{}->{}'.format(inputs, target)
            expr = contract_expression(eq, *(x.shape for x in xs))
            expr(*xs, backend=backend)
            num_exprs_nosharing += _compute_cost(cache)

    with shared_intermediates() as cache:
        print(inputs)
        for i in range(size + 1):
            target = alphabet[i]
            eq = '{}->{}'.format(inputs, target)
            path_info = contract_path(eq, *xs)
            print(path_info[1])
            expr = contract_expression(eq, *(x.shape for x in xs))
            expr(*xs, backend=backend)
        num_exprs_sharing = _compute_cost(cache)

    print('-' * 40)
    print('Without sharing: {} expressions'.format(num_exprs_nosharing))
    print('With sharing: {} expressions'.format(num_exprs_sharing))
    assert num_exprs_nosharing > num_exprs_sharing


def test_multithreaded_sharing():
    """Sharing contexts in worker threads must match a single-threaded run.

    Checks two things:
    - each of 16 pooled calls reports the same cache size as a call made in
      the main thread;
    - the main thread is not in sharing mode while the workers run.
    """
    from multiprocessing.pool import ThreadPool

    def fn():
        X, Y, Z = helpers.build_views('ab,bc,cd')

        with shared_intermediates():
            contract('ab,bc,cd->a', X, Y, Z)
            contract('ab,bc,cd->b', X, Y, Z)

            return len(get_sharing_cache())

    expected = fn()
    pool = ThreadPool(8)
    try:
        fs = [pool.apply_async(fn) for _ in range(16)]
        assert not currently_sharing()
        assert [f.get() for f in fs] == [expected] * 16
    finally:
        # Always release the worker threads, even if an assertion or a task fails.
        pool.close()
        pool.join()