1import itertools
2import weakref
3from collections import Counter
4
5import numpy as np
6import pytest
7
8from opt_einsum import (contract, contract_expression, contract_path, get_symbol, helpers, shared_intermediates)
9from opt_einsum.backends import to_cupy, to_torch
10from opt_einsum.contract import _einsum
11from opt_einsum.parser import parse_einsum_input
12from opt_einsum.sharing import (count_cached_ops, currently_sharing, get_sharing_cache)
13
# Optional backends.  Each ``*_if_found`` name is the plain backend string when
# the library can be imported, and otherwise a ``pytest.param`` carrying a skip
# mark so the corresponding parametrization is reported as skipped.
try:
    import cupy
except ImportError:
    cupy_if_found = pytest.param('cupy', marks=[pytest.mark.skip(reason="CuPy not installed.")])
else:
    cupy_if_found = 'cupy'

try:
    import torch
except ImportError:
    torch_if_found = pytest.param('torch', marks=[pytest.mark.skip(reason="PyTorch not installed.")])
else:
    torch_if_found = 'torch'

# Backends the parametrized tests in this module are run against.
backends = ['numpy', torch_if_found, cupy_if_found]

# Einsum expressions used by the equation-parametrized tests.
equations = [
    'ab,bc->ca',
    'abc,bcd,dea',
    'abc,def->fedcba',
    'abc,bcd,df->fa',
    # test 'prefer einsum' ops
    'ijk,ikj',
    'i,j->ij',
    'ijk,k->ij',
    'AB,BC->CA',
]

# Per-backend converter applied to operands before they are handed to
# ``_einsum`` directly; numpy operands are passed through untouched.
to_backend = dict(
    numpy=lambda x: x,
    torch=to_torch,
    cupy=to_cupy,
)
43
44
@pytest.mark.parametrize('eq', equations)
@pytest.mark.parametrize('backend', backends)
def test_sharing_value(eq, backend):
    """Check that a contraction gives identical values with and without sharing."""
    operands = helpers.build_views(eq)
    expr = contract_expression(eq, *[op.shape for op in operands])

    # Reference value, computed before the sharing context is entered.
    expected = expr(*operands, backend=backend)

    with shared_intermediates():
        actual = expr(*operands, backend=backend)

    # Exact elementwise equality is required, not just closeness.
    assert (actual == expected).all()
57
58
@pytest.mark.parametrize('backend', backends)
def test_complete_sharing(backend):
    """Check that re-evaluating an expression in one sharing context adds no cached ops.

    The op counts recorded after a single evaluation must equal the counts
    recorded after two evaluations performed under a single cache.
    """
    eq = 'ab,bc,cd->'
    operands = helpers.build_views(eq)
    expr = contract_expression(eq, *[op.shape for op in operands])
    rule = '-' * 40

    print(rule)
    print('Without sharing:')
    with shared_intermediates() as cache:
        expr(*operands, backend=backend)
        expected = count_cached_ops(cache)

    print(rule)
    print('With sharing:')
    with shared_intermediates() as cache:
        # Same expression, same operands, evaluated twice under one cache.
        for _ in range(2):
            expr(*operands, backend=backend)
        actual = count_cached_ops(cache)

    print(rule)
    print('Without sharing: {} expressions'.format(expected))
    print('With sharing: {} expressions'.format(actual))
    assert actual == expected
82
83
@pytest.mark.parametrize('backend', backends)
def test_sharing_reused_cache(backend):
    """Check that a cache passed back into ``shared_intermediates`` is reused.

    A second evaluation, run in a new context that is handed the cache of the
    first one, must leave the cached-op counts equal to those of a single
    evaluation.
    """
    eq = 'ab,bc,cd->'
    operands = helpers.build_views(eq)
    expr = contract_expression(eq, *[op.shape for op in operands])
    rule = '-' * 40

    print(rule)
    print('Without sharing:')
    with shared_intermediates() as cache:
        expr(*operands, backend=backend)
        expected = count_cached_ops(cache)

    print(rule)
    print('With sharing:')
    with shared_intermediates() as cache:
        expr(*operands, backend=backend)
    # Re-enter sharing with the cache produced by the context above.
    with shared_intermediates(cache):
        expr(*operands, backend=backend)
        actual = count_cached_ops(cache)

    print(rule)
    print('Without sharing: {} expressions'.format(expected))
    print('With sharing: {} expressions'.format(actual))
    assert actual == expected
108
109
@pytest.mark.parametrize('backend', backends)
def test_no_sharing_separate_cache(backend):
    """Check that two independent sharing contexts do not share any work.

    The summed op counts of two separate caches must equal twice the counts of
    a single evaluation.
    """
    eq = 'ab,bc,cd->'
    operands = helpers.build_views(eq)
    expr = contract_expression(eq, *[op.shape for op in operands])
    rule = '-' * 40

    print(rule)
    print('Without sharing:')
    with shared_intermediates() as cache:
        expr(*operands, backend=backend)
        expected = count_cached_ops(cache)
        # Count the same cache a second time: we expect double.
        expected.update(count_cached_ops(cache))

    print(rule)
    print('With sharing:')
    with shared_intermediates() as cache1:
        expr(*operands, backend=backend)
        actual = count_cached_ops(cache1)
    with shared_intermediates() as cache2:
        expr(*operands, backend=backend)
        actual.update(count_cached_ops(cache2))

    print(rule)
    print('Without sharing: {} expressions'.format(expected))
    print('With sharing: {} expressions'.format(actual))
    assert actual == expected
136
137
@pytest.mark.parametrize('backend', backends)
def test_sharing_nesting(backend):
    """Check object lifetimes when ``shared_intermediates`` contexts are nested.

    Results are tracked through a ``WeakValueDictionary``: inside a sharing
    context they must still be alive after the local names are deleted, and
    once the context has exited they must be gone.  ``method1`` is run both on
    its own and nested inside ``method2``'s context.
    """
    # NOTE(review): eqs[1] is never used below, and eqs[2] and eqs[3] are the
    # same equation -- confirm whether that is intentional.
    eqs = ['ab,bc,cd->a', 'ab,bc,cd->b', 'ab,bc,cd->c', 'ab,bc,cd->c']
    views = helpers.build_views(eqs[0])
    shapes = [v.shape for v in views]
    # Weak references only: an entry disappears as soon as nothing else holds
    # a strong reference to its value.
    refs = weakref.WeakValueDictionary()

    def method1(views):
        """Contract twice under sharing, combine the results, and check their lifetime."""
        with shared_intermediates():
            w = contract_expression(eqs[0], *shapes)(*views, backend=backend)
            x = contract_expression(eqs[2], *shapes)(*views, backend=backend)
            result = contract_expression('a,b->', w.shape, x.shape)(w, x, backend=backend)
            refs['w'] = w
            refs['x'] = x
            # Drop the local strong references; the assertions below require
            # that something else (presumably the sharing cache) still holds
            # the arrays while the context is open.
            del w, x
            assert 'w' in refs
            assert 'x' in refs
        # After the context exits nothing may keep the arrays alive.
        assert 'w' not in refs, 'cache leakage'
        assert 'x' not in refs, 'cache leakage'
        return result

    def method2(views):
        """Same lifetime checks as ``method1``, with ``method1`` invoked inside the context.

        NOTE(review): ``result`` is computed but never returned.
        """
        with shared_intermediates():
            y = contract_expression(eqs[2], *shapes)(*views, backend=backend)
            z = contract_expression(eqs[3], *shapes)(*views, backend=backend)
            refs['y'] = y
            refs['z'] = z
            result = contract_expression('c,d->', y.shape, z.shape)(y, z, backend=backend)
            result = result + method1(views)  # nest method1 in method2
            del y, z
            # Still alive here even though method1's inner context has
            # already exited.
            assert 'y' in refs
            assert 'z' in refs
        assert 'y' not in refs
        assert 'z' not in refs

    method1(views)
    method2(views)
175
176
@pytest.mark.parametrize('eq', equations)
@pytest.mark.parametrize('backend', backends)
def test_sharing_modulo_commutativity(eq, backend):
    """Check that permuting the einsum inputs does not add cached ops.

    The op counts recorded for one ``_einsum`` call must equal the counts
    recorded after calling ``_einsum`` once per permutation of the
    (subscript, operand) pairs under a single cache.
    """
    convert = to_backend[backend]
    ops = [convert(view) for view in helpers.build_views(eq)]
    input_subscripts, output, _ = parse_einsum_input([eq] + ops)
    terms = input_subscripts.split(',')
    rule = '-' * 40

    print(rule)
    print('Without sharing:')
    with shared_intermediates() as cache:
        _einsum(eq, *ops, backend=backend)
        expected = count_cached_ops(cache)

    print(rule)
    print('With sharing:')
    with shared_intermediates() as cache:
        # Every ordering of the inputs, each with an explicit output so the
        # requested result stays the same.
        for permuted in itertools.permutations(zip(terms, ops)):
            permuted_terms = [term for term, _ in permuted]
            permuted_ops = [op for _, op in permuted]
            permuted_eq = '{}->{}'.format(','.join(permuted_terms), output)
            _einsum(permuted_eq, *permuted_ops, backend=backend)
        actual = count_cached_ops(cache)

    print(rule)
    print('Without sharing: {} expressions'.format(expected))
    print('With sharing: {} expressions'.format(actual))
    assert actual == expected
205
206
@pytest.mark.parametrize('backend', backends)
def test_partial_sharing(backend):
    """Check that two contractions differing in one operand share part of their work.

    Evaluating the expression with ``z1`` and with ``z2`` under one cache must
    record fewer 'einsum' ops than evaluating each under its own cache.
    """
    eq = 'ab,bc,de->'
    x, y, z1 = helpers.build_views(eq)
    z2 = 2.0 * z1 - 1.0
    expr = contract_expression(eq, x.shape, y.shape, z1.shape)
    rule = '-' * 40

    print(rule)
    print('Without sharing:')
    num_exprs_nosharing = Counter()
    # One fresh cache per evaluation; the counts are accumulated.
    for z in (z1, z2):
        with shared_intermediates() as cache:
            expr(x, y, z, backend=backend)
            num_exprs_nosharing.update(count_cached_ops(cache))

    print(rule)
    print('With sharing:')
    with shared_intermediates() as cache:
        expr(x, y, z1, backend=backend)
        expr(x, y, z2, backend=backend)
        num_exprs_sharing = count_cached_ops(cache)

    print(rule)
    print('Without sharing: {} expressions'.format(num_exprs_nosharing))
    print('With sharing: {} expressions'.format(num_exprs_sharing))
    assert num_exprs_nosharing['einsum'] > num_exprs_sharing['einsum']
235
236
@pytest.mark.parametrize('backend', backends)
def test_sharing_with_constants(backend):
    """Check expressions with constant operands evaluated under sharing.

    For each single-index output of 'ij,jk,kl', an expression built with the
    first and last operands supplied as constants is evaluated inside one
    sharing context on the requested ``backend``, and compared against a plain
    expression that receives all three operands explicitly.
    """
    inputs = 'ij,jk,kl'
    outputs = 'ijkl'
    # Named ``eqs`` so the module-level ``equations`` list is not shadowed.
    eqs = ['{}->{}'.format(inputs, output) for output in outputs]
    shapes = (2, 3), (3, 4), (4, 5)
    constants = {0, 2}
    # Constant positions hold real arrays; the variable position keeps its shape.
    ops = [np.random.rand(*shp) if i in constants else shp for i, shp in enumerate(shapes)]
    var = np.random.rand(*shapes[1])

    # Reference values: default backend, no constants, no sharing.
    expected = [contract_expression(eq, *shapes)(ops[0], var, ops[2]) for eq in eqs]

    with shared_intermediates():
        # ``backend`` was previously ignored, so every parametrization ran the
        # exact same numpy-only test; forward it to the evaluation under test.
        actual = [contract_expression(eq, *ops, constants=constants)(var, backend=backend) for eq in eqs]

    for dim, expected_dim, actual_dim in zip(outputs, expected, actual):
        assert np.allclose(expected_dim, actual_dim), 'error at {}'.format(dim)
254
255
@pytest.mark.parametrize('size', [3, 4, 5])
@pytest.mark.parametrize('backend', backends)
def test_chain(size, backend):
    """Smoke test: contract a chain of 2x2 matrices down to each single index under sharing.

    There are no assertions; the test only prints the path information and
    requires every evaluation to complete.
    """
    xs = [np.random.rand(2, 2) for _ in range(size)]
    shapes = [mat.shape for mat in xs]
    alphabet = ''.join(get_symbol(k) for k in range(size + 1))
    # Neighbouring matrices share one index: e.g. 'ab,bc,cd'.
    inputs = ','.join(alphabet[k:k + 2] for k in range(size))

    with shared_intermediates():
        print(inputs)
        for k in range(size + 1):
            eq = '{}->{}'.format(inputs, alphabet[k])
            print(contract_path(eq, *xs)[1])
            expr = contract_expression(eq, *shapes)
            expr(*xs, backend=backend)
        print('-' * 40)
275
276
@pytest.mark.parametrize('size', [3, 4, 5, 10])
@pytest.mark.parametrize('backend', backends)
def test_chain_2(size, backend):
    """Smoke test: contract a chain of 2x2 matrices down to each adjacent index pair under sharing.

    There are no assertions; the test only prints the path information and
    requires every evaluation to complete.
    """
    xs = [np.random.rand(2, 2) for _ in range(size)]
    shapes = [mat.shape for mat in xs]
    alphabet = ''.join(get_symbol(k) for k in range(size + 1))
    # Each term doubles as one of the two-index targets below.
    names = [alphabet[k:k + 2] for k in range(size)]
    inputs = ','.join(names)

    with shared_intermediates():
        print(inputs)
        for k in range(size):
            eq = '{}->{}'.format(inputs, alphabet[k:k + 2])
            print(contract_path(eq, *xs)[1])
            expr = contract_expression(eq, *shapes)
            expr(*xs, backend=backend)
        print('-' * 40)
296
297
def _compute_cost(cache):
    """Return the number of cached 'einsum' plus 'tensordot' ops in ``cache``."""
    op_counts = count_cached_ops(cache)
    return sum(op_counts[op] for op in ('einsum', 'tensordot'))
301
302
@pytest.mark.parametrize('backend', backends)
def test_chain_2_growth(backend):
    """Print how the shared cost grows with chain length for sizes 1 through 20.

    For every size, all adjacent-index-pair contractions of a chain of 2x2
    matrices are evaluated under one cache and the resulting cost is recorded.
    The test makes no assertions.
    """
    sizes = list(range(1, 21))
    costs = []
    for size in sizes:
        xs = [np.random.rand(2, 2) for _ in range(size)]
        alphabet = ''.join(get_symbol(k) for k in range(size + 1))
        inputs = ','.join(alphabet[k:k + 2] for k in range(size))

        with shared_intermediates() as cache:
            for k in range(size):
                eq = '{}->{}'.format(inputs, alphabet[k:k + 2])
                expr = contract_expression(eq, *[mat.shape for mat in xs])
                expr(*xs, backend=backend)
            # Cost is read while the sharing context is still open.
            costs.append(_compute_cost(cache))

    print('sizes = {}'.format(repr(sizes)))
    print('costs = {}'.format(repr(costs)))
    for size, cost in zip(sizes, costs):
        print('{}\t{}'.format(size, cost))
325
326
@pytest.mark.parametrize('size', [3, 4, 5])
@pytest.mark.parametrize('backend', backends)
def test_chain_sharing(size, backend):
    """Check that sharing lowers the total cost of contracting a chain to every single index.

    The summed cost of one cache per target must be strictly greater than the
    cost recorded when all targets are evaluated under a single cache.
    """
    xs = [np.random.rand(2, 2) for _ in range(size)]
    alphabet = ''.join(get_symbol(k) for k in range(size + 1))
    inputs = ','.join(alphabet[k:k + 2] for k in range(size))

    # Baseline: a fresh cache for every target, costs summed.
    num_exprs_nosharing = 0
    for k in range(size + 1):
        with shared_intermediates() as cache:
            eq = '{}->{}'.format(inputs, alphabet[k])
            expr = contract_expression(eq, *[mat.shape for mat in xs])
            expr(*xs, backend=backend)
            num_exprs_nosharing += _compute_cost(cache)

    # Same targets, all evaluated under a single cache.
    with shared_intermediates() as cache:
        print(inputs)
        for k in range(size + 1):
            eq = '{}->{}'.format(inputs, alphabet[k])
            print(contract_path(eq, *xs)[1])
            expr = contract_expression(eq, *[mat.shape for mat in xs])
            expr(*xs, backend=backend)
        num_exprs_sharing = _compute_cost(cache)

    print('-' * 40)
    print('Without sharing: {} expressions'.format(num_exprs_nosharing))
    print('With sharing: {} expressions'.format(num_exprs_sharing))
    assert num_exprs_nosharing > num_exprs_sharing
359
360
def test_multithreaded_sharing():
    """Check sharing when many threads open their own sharing contexts.

    ``fn`` is first run once in the calling thread to obtain the reference
    cache size, then submitted 16 times to a pool of 8 threads.  Every worker
    must report that same size, and the calling thread must not report itself
    as sharing once the jobs have been submitted.
    """
    from multiprocessing.pool import ThreadPool

    def fn():
        """Run two contractions under sharing and return the size of the active cache."""
        X, Y, Z = helpers.build_views('ab,bc,cd')

        with shared_intermediates():
            contract('ab,bc,cd->a', X, Y, Z)
            contract('ab,bc,cd->b', X, Y, Z)

            return len(get_sharing_cache())

    expected = fn()
    pool = ThreadPool(8)
    try:
        fs = [pool.apply_async(fn) for _ in range(16)]
        assert not currently_sharing()
        assert [f.get() for f in fs] == [expected] * 16
    finally:
        # Always release the worker threads, even when an assertion above
        # fails or a worker raises; join() waits for them to exit.
        pool.close()
        pool.join()
379