"""Tests of opt_einsum's dispatch to non-numpy backends: tensorflow, theano,
cupy, jax, autograd, dask, sparse, torch, and plain object arrays.
"""
import numpy as np
import pytest

from opt_einsum import (backends, contract, contract_expression, helpers, sharing)
from opt_einsum.contract import Shaped, infer_backend, parse_backend
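
# Shaped is a minimal array-like stand-in carrying only a .shape attribute;
# infer_backend names the library an array belongs to, and parse_backend
# resolves which backend a contraction will actually dispatch to (see
# test_auto_backend_custom_array_no_tensordot below).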

try:
    import cupy
    found_cupy = True
except ImportError:
    found_cupy = False

try:
    import tensorflow as tf
    # needed so tensorflow doesn't allocate all gpu mem
    _TF_CONFIG = tf.ConfigProto()
    _TF_CONFIG.gpu_options.allow_growth = True
    found_tensorflow = True
except ImportError:
    found_tensorflow = False

try:
    import os
    # set before importing theano to avoid MKL threading-layer errors
    os.environ['MKL_THREADING_LAYER'] = 'GNU'
    import theano
    found_theano = True
except ImportError:
    found_theano = False

try:
    import torch
    found_torch = True
except ImportError:
    found_torch = False

try:
    import jax
    found_jax = True
except ImportError:
    found_jax = False

try:
    import autograd
    found_autograd = True
except ImportError:
    found_autograd = False
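
# Each found_* flag above gates that backend's tests via pytest.mark.skipif,
# so a missing optional dependency skips its tests rather than failing them.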

tests = [
    'ab,bc->ca',
    'abc,bcd,dea',
    'abc,def->fedcba',
    'abc,bcd,df->fa',
    # test 'prefer einsum' ops
    'ijk,ikj',
    'i,j->ij',
    'ijk,k->ij',
    'AB,BC->CA',
]
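
# Each backend test below follows the same pattern, sketched here with
# backend='numpy' standing in for whichever backend is under test:
#
#     views = helpers.build_views('ab,bc->ca')        # random operands
#     ein = contract('ab,bc->ca', *views,             # plain einsum reference
#                    optimize=False, use_blas=False)
#     expr = contract_expression('ab,bc->ca',
#                                *[v.shape for v in views], optimize=True)
#     opt = expr(*views, backend='numpy')             # evaluate on a backend
#     assert np.allclose(ein, opt)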


@pytest.mark.skipif(not found_tensorflow, reason="Tensorflow not installed.")
@pytest.mark.parametrize("string", tests)
def test_tensorflow(string):
    views = helpers.build_views(string)
    ein = contract(string, *views, optimize=False, use_blas=False)
    opt = np.empty_like(ein)

    shps = [v.shape for v in views]
    expr = contract_expression(string, *shps, optimize=True)

    sess = tf.Session(config=_TF_CONFIG)
    with sess.as_default():
        expr(*views, backend='tensorflow', out=opt)
    sess.close()

    assert np.allclose(ein, opt)

    # test non-conversion mode
    tensorflow_views = [backends.to_tensorflow(view) for view in views]
    expr(*tensorflow_views)


@pytest.mark.skipif(not found_tensorflow, reason="Tensorflow not installed.")
@pytest.mark.parametrize("constants", [{0, 1}, {0, 2}, {1, 2}])
def test_tensorflow_with_constants(constants):
    eq = 'ij,jk,kl->li'
    shapes = (2, 3), (3, 4), (4, 5)
    non_const, = {0, 1, 2} - constants
    ops = [np.random.rand(*shp) if i in constants else shp for i, shp in enumerate(shapes)]
    var = np.random.rand(*shapes[non_const])
    res_exp = contract(eq, *(ops[i] if i in constants else var for i in range(3)))

    expr = contract_expression(eq, *ops, constants=constants)

    # check tensorflow
    with tf.Session(config=_TF_CONFIG).as_default():
        res_got = expr(var, backend='tensorflow')
    assert all(array is None or infer_backend(array) == 'tensorflow'
               for array in expr._evaluated_constants['tensorflow'])
    assert np.allclose(res_exp, res_got)

    # check can call with numpy still
    res_got2 = expr(var, backend='numpy')
    assert np.allclose(res_exp, res_got2)

    # check tensorflow call returns tensorflow still
    res_got3 = expr(backends.to_tensorflow(var))
    assert isinstance(res_got3, tf.Tensor)
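
# A minimal sketch of the constants feature exercised above: operands given as
# arrays (rather than shapes) at the indices named in `constants` are baked
# into the expression, and only the remaining operands are supplied per call.
#
#     expr = contract_expression('ij,jk,kl->li',
#                                np.random.rand(2, 3),  # constant operand 0
#                                (3, 4),                # shape of variable operand 1
#                                np.random.rand(4, 5),  # constant operand 2
#                                constants={0, 2})
#     out = expr(np.random.rand(3, 4))  # supply only operand 1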


@pytest.mark.skipif(not found_tensorflow, reason="Tensorflow not installed.")
@pytest.mark.parametrize("string", tests)
def test_tensorflow_with_sharing(string):
    views = helpers.build_views(string)
    ein = contract(string, *views, optimize=False, use_blas=False)

    shps = [v.shape for v in views]
    expr = contract_expression(string, *shps, optimize=True)

    sess = tf.Session(config=_TF_CONFIG)

    with sess.as_default(), sharing.shared_intermediates() as cache:
        tfl1 = expr(*views, backend='tensorflow')
        assert sharing.get_sharing_cache() is cache
        cache_sz = len(cache)
        assert cache_sz > 0
        tfl2 = expr(*views, backend='tensorflow')
        assert len(cache) == cache_sz

    assert all(isinstance(t, tf.Tensor) for t in cache.values())

    assert np.allclose(ein, tfl1)
    assert np.allclose(ein, tfl2)


@pytest.mark.skipif(not found_theano, reason="Theano not installed.")
@pytest.mark.parametrize("string", tests)
def test_theano(string):
    views = helpers.build_views(string)
    ein = contract(string, *views, optimize=False, use_blas=False)
    shps = [v.shape for v in views]

    expr = contract_expression(string, *shps, optimize=True)

    opt = expr(*views, backend='theano')
    assert np.allclose(ein, opt)

    # test non-conversion mode
    theano_views = [backends.to_theano(view) for view in views]
    theano_opt = expr(*theano_views)
    assert isinstance(theano_opt, theano.tensor.TensorVariable)


@pytest.mark.skipif(not found_theano, reason="Theano not installed.")
@pytest.mark.parametrize("constants", [{0, 1}, {0, 2}, {1, 2}])
def test_theano_with_constants(constants):
    eq = 'ij,jk,kl->li'
    shapes = (2, 3), (3, 4), (4, 5)
    non_const, = {0, 1, 2} - constants
    ops = [np.random.rand(*shp) if i in constants else shp for i, shp in enumerate(shapes)]
    var = np.random.rand(*shapes[non_const])
    res_exp = contract(eq, *(ops[i] if i in constants else var for i in range(3)))

    expr = contract_expression(eq, *ops, constants=constants)

    # check theano
    res_got = expr(var, backend='theano')
    assert all(array is None or infer_backend(array) == 'theano' for array in expr._evaluated_constants['theano'])
    assert np.allclose(res_exp, res_got)

    # check can call with numpy still
    res_got2 = expr(var, backend='numpy')
    assert np.allclose(res_exp, res_got2)

    # check theano call returns theano still
    res_got3 = expr(backends.to_theano(var))
    assert isinstance(res_got3, theano.tensor.TensorVariable)


@pytest.mark.skipif(not found_theano, reason="Theano not installed.")
@pytest.mark.parametrize("string", tests)
def test_theano_with_sharing(string):
    views = helpers.build_views(string)
    ein = contract(string, *views, optimize=False, use_blas=False)

    shps = [v.shape for v in views]
    expr = contract_expression(string, *shps, optimize=True)

    with sharing.shared_intermediates() as cache:
        thn1 = expr(*views, backend='theano')
        assert sharing.get_sharing_cache() is cache
        cache_sz = len(cache)
        assert cache_sz > 0
        thn2 = expr(*views, backend='theano')
        assert len(cache) == cache_sz

    assert all(isinstance(t, theano.tensor.TensorVariable) for t in cache.values())

    assert np.allclose(ein, thn1)
    assert np.allclose(ein, thn2)


@pytest.mark.skipif(not found_cupy, reason="Cupy not installed.")
@pytest.mark.parametrize("string", tests)
def test_cupy(string):  # pragma: no cover
    views = helpers.build_views(string)
    ein = contract(string, *views, optimize=False, use_blas=False)
    shps = [v.shape for v in views]

    expr = contract_expression(string, *shps, optimize=True)

    opt = expr(*views, backend='cupy')
    assert np.allclose(ein, opt)

    # test non-conversion mode
    cupy_views = [backends.to_cupy(view) for view in views]
    cupy_opt = expr(*cupy_views)
    assert isinstance(cupy_opt, cupy.ndarray)
    assert np.allclose(ein, cupy.asnumpy(cupy_opt))


@pytest.mark.skipif(not found_cupy, reason="Cupy not installed.")
@pytest.mark.parametrize("constants", [{0, 1}, {0, 2}, {1, 2}])
def test_cupy_with_constants(constants):  # pragma: no cover
    eq = 'ij,jk,kl->li'
    shapes = (2, 3), (3, 4), (4, 5)
    non_const, = {0, 1, 2} - constants
    ops = [np.random.rand(*shp) if i in constants else shp for i, shp in enumerate(shapes)]
    var = np.random.rand(*shapes[non_const])
    res_exp = contract(eq, *(ops[i] if i in constants else var for i in range(3)))

    expr = contract_expression(eq, *ops, constants=constants)

    # check cupy
    res_got = expr(var, backend='cupy')
    # check cupy versions of constants exist
    assert all(array is None or infer_backend(array) == 'cupy' for array in expr._evaluated_constants['cupy'])
    assert np.allclose(res_exp, res_got)

    # check can call with numpy still
    res_got2 = expr(var, backend='numpy')
    assert np.allclose(res_exp, res_got2)

    # check cupy call returns cupy still
    res_got3 = expr(cupy.asarray(var))
    assert isinstance(res_got3, cupy.ndarray)
    assert np.allclose(res_exp, res_got3.get())


@pytest.mark.skipif(not found_jax, reason="jax not installed.")
@pytest.mark.parametrize("string", tests)
def test_jax(string):  # pragma: no cover
    views = helpers.build_views(string)
    ein = contract(string, *views, optimize=False, use_blas=False)
    shps = [v.shape for v in views]

    expr = contract_expression(string, *shps, optimize=True)

    opt = expr(*views, backend='jax')
    assert np.allclose(ein, opt)
    assert isinstance(opt, np.ndarray)


@pytest.mark.skipif(not found_jax, reason="jax not installed.")
@pytest.mark.parametrize("constants", [{0, 1}, {0, 2}, {1, 2}])
def test_jax_with_constants(constants):  # pragma: no cover
    eq = 'ij,jk,kl->li'
    shapes = (2, 3), (3, 4), (4, 5)
    non_const, = {0, 1, 2} - constants
    ops = [np.random.rand(*shp) if i in constants else shp for i, shp in enumerate(shapes)]
    var = np.random.rand(*shapes[non_const])
    res_exp = contract(eq, *(ops[i] if i in constants else var for i in range(3)))

    expr = contract_expression(eq, *ops, constants=constants)

    # check jax
    res_got = expr(var, backend='jax')
    # check jax versions of constants exist
    assert all(array is None or infer_backend(array) == 'jax' for array in expr._evaluated_constants['jax'])

    assert np.allclose(res_exp, res_got)


@pytest.mark.skipif(not found_jax, reason="jax not installed.")
def test_jax_jit_gradient():
    eq = 'ij,jk,kl->'
    shapes = (2, 3), (3, 4), (4, 2)
    views = [np.random.randn(*s) for s in shapes]
    expr = contract_expression(eq, *shapes)
    x0 = expr(*views)

    jit_expr = jax.jit(expr)
    x1 = jit_expr(*views).item()
    assert x1 == pytest.approx(x0, rel=1e-5)

    # jax only takes the gradient w.r.t. the first argument, so wrap the
    # operands in a single list
    grad_expr = jax.jit(jax.grad(lambda views: expr(*views)))
    view_grads = grad_expr(views)
    assert all(v1.shape == v2.shape for v1, v2 in zip(views, view_grads))

    # taking a step along the gradient should reduce our 'loss'
    new_views = [v - 0.001 * dv for v, dv in zip(views, view_grads)]
    x2 = jit_expr(*new_views).item()
    assert x2 < x1


@pytest.mark.skipif(not found_autograd, reason="autograd not installed.")
def test_autograd_gradient():
    eq = 'ij,jk,kl->'
    shapes = (2, 3), (3, 4), (4, 2)
    views = [np.random.randn(*s) for s in shapes]
    expr = contract_expression(eq, *shapes)
    x0 = expr(*views)

    # autograd only takes the gradient w.r.t. the first argument, so wrap the
    # operands in a single list
    grad_expr = autograd.grad(lambda views: expr(*views))
    view_grads = grad_expr(views)
    assert all(v1.shape == v2.shape for v1, v2 in zip(views, view_grads))

    # taking a step along the gradient should reduce our 'loss'
    new_views = [v - 0.001 * dv for v, dv in zip(views, view_grads)]
    x1 = expr(*new_views)
    assert x1 < x0


@pytest.mark.parametrize("string", tests)
def test_dask(string):
    da = pytest.importorskip("dask.array")

    views = helpers.build_views(string)
    ein = contract(string, *views, optimize=False, use_blas=False)
    shps = [v.shape for v in views]
    expr = contract_expression(string, *shps, optimize=True)

    # test non-conversion mode
    da_views = [da.from_array(x, chunks=2) for x in views]
    da_opt = expr(*da_views)

    # check type is maintained when not using numpy arrays
    assert isinstance(da_opt, da.Array)

    assert np.allclose(ein, np.array(da_opt))

    # try raw contract
    da_opt = contract(string, *da_views)
    assert isinstance(da_opt, da.Array)
    assert np.allclose(ein, np.array(da_opt))


@pytest.mark.parametrize("string", tests)
def test_sparse(string):
    sparse = pytest.importorskip("sparse")

    views = helpers.build_views(string)

    # sparsify views so they don't become dense during contraction
    for view in views:
        np.random.seed(42)
        mask = np.random.choice([False, True], size=view.shape, replace=True, p=[0.05, 0.95])
        view[mask] = 0

    ein = contract(string, *views, optimize=False, use_blas=False)
    shps = [v.shape for v in views]
    expr = contract_expression(string, *shps, optimize=True)

    # test non-conversion mode
    sparse_views = [sparse.COO.from_numpy(x) for x in views]
    sparse_opt = expr(*sparse_views)

    # check type is maintained when not using numpy arrays
    assert isinstance(sparse_opt, sparse.COO)

    assert np.allclose(ein, sparse_opt.todense())

    # try raw contract
    sparse_opt = contract(string, *sparse_views)
    assert isinstance(sparse_opt, sparse.COO)
    assert np.allclose(ein, sparse_opt.todense())


@pytest.mark.skipif(not found_torch, reason="Torch not installed.")
@pytest.mark.parametrize("string", tests)
def test_torch(string):
    views = helpers.build_views(string)
    ein = contract(string, *views, optimize=False, use_blas=False)
    shps = [v.shape for v in views]

    expr = contract_expression(string, *shps, optimize=True)

    opt = expr(*views, backend='torch')
    assert np.allclose(ein, opt)

    # test non-conversion mode
    torch_views = [backends.to_torch(view) for view in views]
    torch_opt = expr(*torch_views)
    assert isinstance(torch_opt, torch.Tensor)
    assert np.allclose(ein, torch_opt.cpu().numpy())


@pytest.mark.skipif(not found_torch, reason="Torch not installed.")
@pytest.mark.parametrize("constants", [{0, 1}, {0, 2}, {1, 2}])
def test_torch_with_constants(constants):
    eq = 'ij,jk,kl->li'
    shapes = (2, 3), (3, 4), (4, 5)
    non_const, = {0, 1, 2} - constants
    ops = [np.random.rand(*shp) if i in constants else shp for i, shp in enumerate(shapes)]
    var = np.random.rand(*shapes[non_const])
    res_exp = contract(eq, *(ops[i] if i in constants else var for i in range(3)))

    expr = contract_expression(eq, *ops, constants=constants)

    # check torch
    res_got = expr(var, backend='torch')
    assert all(array is None or infer_backend(array) == 'torch' for array in expr._evaluated_constants['torch'])
    assert np.allclose(res_exp, res_got)

    # check can call with numpy still
    res_got2 = expr(var, backend='numpy')
    assert np.allclose(res_exp, res_got2)

    # check torch call returns torch still
    res_got3 = expr(backends.to_torch(var))
    assert isinstance(res_got3, torch.Tensor)
    res_got3 = res_got3.cpu().numpy()  # .cpu() is a no-op for tensors already on the CPU
    assert np.allclose(res_exp, res_got3)


def test_auto_backend_custom_array_no_tensordot():
    x = Shaped((1, 2, 3))
    # Shaped is an array-like object defined by opt_einsum which provides no
    # tensordot, so backend resolution falls back to numpy
    assert infer_backend(x) == 'opt_einsum'
    assert parse_backend([x], 'auto') == 'numpy'


@pytest.mark.parametrize("string", tests)
def test_object_arrays_backend(string):
    views = helpers.build_views(string)
    ein = contract(string, *views, optimize=False, use_blas=False)
    assert ein.dtype != object

    shps = [v.shape for v in views]
    expr = contract_expression(string, *shps, optimize=True)

    obj_views = [view.astype(object) for view in views]

    # try raw contract
    obj_opt = contract(string, *obj_views, backend='object')
    assert obj_opt.dtype == object
    assert np.allclose(ein, obj_opt.astype(float))

    # test expression
    obj_opt = expr(*obj_views, backend='object')
    assert obj_opt.dtype == object
    assert np.allclose(ein, obj_opt.astype(float))