1"""
Tests a series of opt_einsum edge cases (expression argument checking and broadcasting) to ensure the results are the same for different contraction paths
3"""
4
5import numpy as np
6import pytest
7
8from opt_einsum import contract, contract_expression
9
10
def test_contract_expression_checks():
    """Invalid arguments to ``contract_expression`` or to the resulting
    ``ContractExpression`` must raise ``ValueError`` with an informative message."""
    # Building an expression without a path optimizer is rejected.
    with pytest.raises(ValueError):
        contract_expression("ab,bc->ac", (2, 3), (3, 4), optimize=False)

    # Shapes are still validated: a third shape for a two-term equation is rejected.
    with pytest.raises(ValueError):
        contract_expression("ab,bc->ac", (2, 3), (3, 4), (42, 42))

    # ``out`` may not be supplied when the expression is built.
    with pytest.raises(ValueError):
        contract_expression("ab,bc->ac", (2, 3), (3, 4), out=np.empty((2, 4)))

    # The built expression must keep validating the operands it is called with.
    expr = contract_expression("ab,bc->ac", (2, 3), (3, 4))

    # Wrong number of operands: first too few, then too many.
    arity_msg = "`ContractExpression` takes exactly 2"
    with pytest.raises(ValueError, match=arity_msg):
        expr(np.random.rand(2, 3))
    with pytest.raises(ValueError, match=arity_msg):
        expr(np.random.rand(2, 3), np.random.rand(2, 3), np.random.rand(2, 3))

    # Operands (or ``out``) whose ranks do not match the expression.
    internal_msg = "Internal error while evaluating `ContractExpression`"
    with pytest.raises(ValueError, match=internal_msg):
        expr(np.random.rand(2, 3, 4), np.random.rand(3, 4))
    with pytest.raises(ValueError, match=internal_msg):
        expr(np.random.rand(2, 4), np.random.rand(3, 4, 5))
    with pytest.raises(ValueError, match=internal_msg):
        expr(np.random.rand(2, 3), np.random.rand(3, 4), out=np.random.rand(2, 4, 6))

    # ``out`` is the only keyword argument accepted when calling the expression.
    kwarg_msg = "only valid keyword arguments to a `ContractExpression`"
    with pytest.raises(ValueError, match=kwarg_msg):
        expr(np.random.rand(2, 3), np.random.rand(3, 4), order='F')
53
54
def test_broadcasting_contraction():
    """The size-1 axis ``i`` of the first operand must broadcast against ``d``
    identically with and without path optimization."""
    a = np.random.rand(1, 5, 4)
    b = np.random.rand(4, 6)
    c = np.random.rand(5, 6)
    d = np.random.rand(10)

    # Reference: contract without ``d`` on both paths, then scale by ``d``.
    scalars = [contract('ijk,kl,jl', a, b, c, optimize=flag) for flag in (False, True)]
    assert np.allclose(*scalars)
    expected = scalars[0] * d

    # Folding ``d`` into the equation must reproduce the reference on every path.
    outputs = [contract('ijk,kl,jl,i->i', a, b, c, d, optimize=flag) for flag in (False, True)]
    for got in outputs:
        assert np.allclose(got, expected)
73
74
def test_broadcasting_contraction2():
    """Two leading size-1 axes (``a`` and ``b``) of the first operand must
    broadcast against the 2-D ``d`` identically on every contraction path."""
    a = np.random.rand(1, 1, 5, 4)
    b = np.random.rand(4, 6)
    c = np.random.rand(5, 6)
    d = np.random.rand(7, 7)

    # Reference: contract without ``d`` on both paths, then scale by ``d``.
    scalars = [contract('abjk,kl,jl', a, b, c, optimize=flag) for flag in (False, True)]
    assert np.allclose(*scalars)
    expected = scalars[0] * d

    # Folding ``d`` into the equation must reproduce the reference on every path.
    outputs = [contract('abjk,kl,jl,ab->ab', a, b, c, d, optimize=flag) for flag in (False, True)]
    for got in outputs:
        assert np.allclose(got, expected)
93
94
def test_broadcasting_contraction3():
    """Size-1 axes on different operands (``a`` on the first, ``b`` on the
    second) must broadcast against ``d`` the same way on every path."""
    shapes = [(1, 5, 4), (4, 1, 6), (5, 6), (7, 7)]
    operands = [np.random.rand(*shape) for shape in shapes]

    equation = 'ajk,kbl,jl,ab->ab'
    unoptimized, optimized = (contract(equation, *operands, optimize=flag) for flag in (False, True))

    assert np.allclose(unoptimized, optimized)
106
107
def test_broadcasting_contraction4():
    """Index ``b`` appears only in the first operand and is summed away, while
    ``k`` is shared; integer input must give the same result on both paths."""
    a = np.arange(64).reshape(2, 4, 8)

    by_flag = {flag: contract('obk,ijk->ioj', a, a, optimize=flag) for flag in (False, True)}

    assert np.allclose(by_flag[False], by_flag[True])
115
116
def test_can_blas_on_healed_broadcast_dimensions():
    """Broadcasting in an early pairwise step must not prevent a later step
    from using a BLAS routine."""
    expr = contract_expression("ab,bc,bd->acd", (5, 4), (1, 5), (4, 20))
    steps = expr.contraction_list

    # Step 1 pairs ``bc`` (whose ``b`` axis has size 1) with ``ab``, which
    # requires broadcasting, so its last entry (the BLAS choice) is False.
    first_step = steps[0]
    assert first_step[2] == 'bc,ab->bca'
    assert first_step[-1] is False

    # Once the broadcast ``b`` dimension is healed to full size, the
    # follow-up contraction is able to use GEMM.
    second_step = steps[1]
    assert second_step[2] == 'bca,bd->acd'
    assert second_step[-1] == 'GEMM'
126