"""
Tests a series of opt_einsum contraction paths to ensure the results are the same for different paths,
and that ``ContractExpression`` rejects invalid construction arguments and invalid calls.
"""

import numpy as np
import pytest

from opt_einsum import contract, contract_expression


def test_contract_expression_checks():
    """Invalid construction or invalid calls of a ``ContractExpression`` raise ``ValueError``."""
    # Construction must fail when optimization is disabled, when more shapes
    # than operands are supplied, or when an ``out`` array is passed up front.
    invalid_builds = (
        lambda: contract_expression("ab,bc->ac", (2, 3), (3, 4), optimize=False),
        lambda: contract_expression("ab,bc->ac", (2, 3), (3, 4), (42, 42)),
        lambda: contract_expression("ab,bc->ac", (2, 3), (3, 4), out=np.empty((2, 4))),
    )
    for build in invalid_builds:
        with pytest.raises(ValueError):
            build()

    rand = np.random.rand
    expr = contract_expression("ab,bc->ac", (2, 3), (3, 4))

    # Wrong number of operands: one too few, then one too many.
    arity_msg = "`ContractExpression` takes exactly 2"
    for operands in ((rand(2, 3),), (rand(2, 3), rand(2, 3), rand(2, 3))):
        with pytest.raises(ValueError) as err:
            expr(*operands)
        assert arity_msg in str(err.value)

    # Operands (or the ``out`` array) whose shapes do not fit the expression.
    shape_msg = "Internal error while evaluating `ContractExpression`"
    mismatched_calls = (
        lambda: expr(rand(2, 3, 4), rand(3, 4)),
        lambda: expr(rand(2, 4), rand(3, 4, 5)),
        lambda: expr(rand(2, 3), rand(3, 4), out=rand(2, 4, 6)),
    )
    for call in mismatched_calls:
        with pytest.raises(ValueError) as err:
            call()
        assert shape_msg in str(err.value)

    # ``out`` is the only keyword the expression accepts when called.
    with pytest.raises(ValueError) as err:
        expr(rand(2, 3), rand(3, 4), order='F')
    assert "only valid keyword arguments to a `ContractExpression`" in str(err.value)
def _contract_both_ways(eq, *operands):
    """Return ``(unoptimized, optimized)`` results of ``contract(eq, *operands)``.

    The first element is computed with ``optimize=False`` and the second with
    ``optimize=True``, so callers can check that the chosen path does not
    change the numerical result.
    """
    unoptimized = contract(eq, *operands, optimize=False)
    optimized = contract(eq, *operands, optimize=True)
    return unoptimized, optimized


def test_broadcasting_contraction():
    """Size-1 ``i`` on ``a`` broadcasts against the length-10 ``d`` in the output."""
    a, b, c, d = (np.random.rand(*shape) for shape in [(1, 5, 4), (4, 6), (5, 6), (10,)])

    scalar_ref, scalar_opt = _contract_both_ways('ijk,kl,jl', a, b, c)
    assert np.allclose(scalar_ref, scalar_opt)

    expected = scalar_ref * d
    for result in _contract_both_ways('ijk,kl,jl,i->i', a, b, c, d):
        assert np.allclose(result, expected)


def test_broadcasting_contraction2():
    """Two size-1 axes ``a``/``b`` on the first operand broadcast against the 7x7 ``d``."""
    a, b, c, d = (np.random.rand(*shape) for shape in [(1, 1, 5, 4), (4, 6), (5, 6), (7, 7)])

    scalar_ref, scalar_opt = _contract_both_ways('abjk,kl,jl', a, b, c)
    assert np.allclose(scalar_ref, scalar_opt)

    expected = scalar_ref * d
    for result in _contract_both_ways('abjk,kl,jl,ab->ab', a, b, c, d):
        assert np.allclose(result, expected)


def test_broadcasting_contraction3():
    """Size-1 broadcast axes spread over different operands give path-independent results."""
    a, b, c, d = (np.random.rand(*shape) for shape in [(1, 5, 4), (4, 1, 6), (5, 6), (7, 7)])

    reference, optimized = _contract_both_ways('ajk,kbl,jl,ab->ab', a, b, c, d)
    assert np.allclose(reference, optimized)


def test_broadcasting_contraction4():
    """The same integer array used twice gives path-independent results."""
    arr = np.arange(64).reshape(2, 4, 8)

    reference, optimized = _contract_both_ways('obk,ijk->ioj', arr, arr)
    assert np.allclose(reference, optimized)


def test_can_blas_on_healed_broadcast_dimensions():
    """BLAS is reported only once a size-1 broadcast axis has been contracted away."""
    # 'b' has size 4 in "ab" and "bd" but size 1 in "bc", so it must broadcast.
    expr = contract_expression("ab,bc,bd->acd", (5, 4), (1, 5), (4, 20))
    steps = expr.contraction_list

    # The first pairwise step involves the broadcast axis and reports no BLAS kernel.
    first_step = steps[0]
    assert first_step[2] == 'bc,ab->bca'
    assert first_step[-1] is False

    # After that step the broadcast axis is healed, so the next step can use GEMM.
    second_step = steps[1]
    assert second_step[2] == 'bca,bd->acd'
    assert second_step[-1] == 'GEMM'