1"""
2Tests the input parsing for opt_einsum. Duplicates the np.einsum input tests.
3"""
4
5import numpy as np
6import pytest
7
8from opt_einsum import contract, contract_path
9
10
def build_views(string):
    """Return one random operand per input term of an einsum subscript string.

    Each index letter ``a``-``j`` maps to a fixed dimension size.
    An ellipsis is expanded to the two indices ``ij``.
    Anything after ``->`` is ignored.
    Operands are drawn with ``np.random.rand`` in term order.
    """
    dim_of = dict(zip('abcdefghij', np.array([2, 3, 4, 5, 4, 3, 2, 6, 5, 4])))

    lhs = string.replace('...', 'ij').split('->')[0]
    return [np.random.rand(*(dim_of[label] for label in term)) for term in lhs.split(',')]
25
26
def test_type_errors():
    """Malformed arguments to ``contract`` raise ``TypeError`` or ``ValueError``.

    Covers:
    - non-string subscripts;
    - invalid ``out``/``order``/``casting``/``dtype`` or unknown keyword arguments;
    - malformed subscript strings;
    - unhashable labels in the interleaved (operand, sublist) input format.
    """
    # subscripts must be a string
    with pytest.raises(TypeError):
        contract(0, 0)

    # out parameter must be an array
    with pytest.raises(TypeError):
        contract("", 0, out='test')

    # order parameter must be a valid order
    # changed in Numpy 1.19, see https://github.com/numpy/numpy/commit/35b0a051c19265f5643f6011ee11e31d30c8bc4c
    with pytest.raises((TypeError, ValueError)):
        contract("", 0, order='W')

    # casting parameter must be a valid casting
    with pytest.raises(ValueError):
        contract("", 0, casting='blah')

    # dtype parameter must be a valid dtype
    with pytest.raises(TypeError):
        contract("", 0, dtype='bad_data_type')

    # other keyword arguments are rejected
    with pytest.raises(TypeError):
        contract("", 0, bad_arg=0)

    # issue 4528 revealed a segfault with this call
    with pytest.raises(TypeError):
        contract(*(None, ) * 63)

    # Cannot have two ->
    with pytest.raises(ValueError):
        contract("->,->", 0, 5)

    # Undefined symbol lhs
    with pytest.raises(ValueError):
        contract("&,a->", 0, 5)

    # Undefined symbol rhs
    with pytest.raises(ValueError):
        contract("a,a->&", 0, 5)

    # Interleaved (operand, sublist) format: build a single operand for '...a'
    string = '...a->...a'
    views = build_views(string)

    # Subscript list must contain Ellipsis or (hashable && comparable) objects;
    # a list or a dict as a label is unhashable.
    with pytest.raises(TypeError):
        contract(views[0], [Ellipsis, 0], [Ellipsis, ['a']])

    with pytest.raises(TypeError):
        contract(views[0], [Ellipsis, dict()], [Ellipsis, 'a'])
82
83
def test_value_errors():
    """Subscript strings inconsistent with their operands raise ``ValueError``.

    The non-string-subscripts ``TypeError`` case is exercised in
    ``test_type_errors``; each scenario below is checked exactly once.
    """
    # a subscripts string with no operands is rejected
    with pytest.raises(ValueError):
        contract("")

    # invalid subscript character
    with pytest.raises(ValueError):
        contract("i%...", [0, 0])
    with pytest.raises(ValueError):
        contract("...j$", [0, 0])
    with pytest.raises(ValueError):
        contract("i->&", [0, 0])

    # number of operands must match count in subscripts string
    with pytest.raises(ValueError):
        contract("", 0, 0)
    with pytest.raises(ValueError):
        contract(",", 0, [0], [0])
    with pytest.raises(ValueError):
        contract(",", [0])

    # can't have more subscripts than dimensions in the operand
    with pytest.raises(ValueError):
        contract("i", 0)
    with pytest.raises(ValueError):
        contract("ij", [0, 0])
    with pytest.raises(ValueError):
        contract("...i", 0)
    with pytest.raises(ValueError):
        contract("i...j", [0, 0])
    with pytest.raises(ValueError):
        contract("i...", 0)
    with pytest.raises(ValueError):
        contract("ij...", [0, 0])

    # invalid ellipsis
    with pytest.raises(ValueError):
        contract("i..", [0, 0])
    with pytest.raises(ValueError):
        contract(".i...", [0, 0])
    with pytest.raises(ValueError):
        contract("j->..j", [0, 0])
    with pytest.raises(ValueError):
        contract("j->.j...", [0, 0])

    # output subscripts must appear in input
    with pytest.raises(ValueError):
        contract("i->ij", [0, 0])

    # output subscripts may only be specified once
    with pytest.raises(ValueError):
        contract("ij->jij", [[0, 0], [0, 0]])

    # dimensions must match when being collapsed
    with pytest.raises(ValueError):
        contract("ii", np.arange(6).reshape(2, 3))
    with pytest.raises(ValueError):
        contract("ii->i", np.arange(6).reshape(2, 3))

    # broadcasting to new dimensions must be enabled explicitly
    with pytest.raises(ValueError):
        contract("i", np.arange(6).reshape(2, 3))
    with pytest.raises(ValueError):
        contract("i->i", [[0, 1], [0, 1]], out=np.arange(4).reshape(2, 2))
161
162
def test_contract_inputs():
    """``contract_path`` rejects unknown keyword arguments and a negative memory limit."""
    operand = [[0, 1], [0, 1]]

    with pytest.raises(TypeError):
        contract_path("i->i", operand, bad_kwarg=True)

    with pytest.raises(ValueError):
        contract_path("i->i", operand, memory_limit=-1)
170
171
@pytest.mark.parametrize(
    "string",
    [
        # Ellipsis-containing and plain subscript strings
        '...a->...',
        'a...->...',
        'a...a->...a',
        '...,...',
        'a,b',
        '...a,...b',
    ])
def test_compare(string):
    """The default path and the 'optimal' path both reproduce the unoptimized result."""
    operands = build_views(string)

    reference = contract(string, *operands, optimize=False)
    for extra_kwargs in ({}, {'optimize': 'optimal'}):
        assert np.allclose(reference, contract(string, *operands, **extra_kwargs))
192
193
def test_ellipse_input1():
    """'...a->...' agrees with the sublist form ``[Ellipsis, 0] -> [Ellipsis]``."""
    subscripts = '...a->...'
    operands = build_views(subscripts)
    arr = operands[0]

    expected = contract(subscripts, *operands, optimize=False)
    result = contract(arr, [Ellipsis, 0], [Ellipsis])
    assert np.allclose(expected, result)
201
202
def test_ellipse_input2():
    """Implicit-output '...a' agrees with the sublist form ``[Ellipsis, 0]``."""
    subscripts = '...a'
    operands = build_views(subscripts)
    arr = operands[0]

    expected = contract(subscripts, *operands, optimize=False)
    result = contract(arr, [Ellipsis, 0])
    assert np.allclose(expected, result)
210
211
def test_ellipse_input3():
    """'...a->...a' agrees with the sublist form ``[Ellipsis, 0] -> [Ellipsis, 0]``."""
    subscripts = '...a->...a'
    operands = build_views(subscripts)
    arr = operands[0]

    expected = contract(subscripts, *operands, optimize=False)
    result = contract(arr, [Ellipsis, 0], [Ellipsis, 0])
    assert np.allclose(expected, result)
219
220
def test_ellipse_input4():
    """Two-operand '...b,...a->...' agrees with its interleaved sublist form."""
    subscripts = '...b,...a->...'
    operands = build_views(subscripts)
    first, second = operands[0], operands[1]

    expected = contract(subscripts, *operands, optimize=False)
    result = contract(first, [Ellipsis, 1], second, [Ellipsis, 0], [Ellipsis])
    assert np.allclose(expected, result)
228
229
def test_singleton_dimension_broadcast():
    """Size-1 dimensions broadcast against larger ones (gh-10343).

    Checked with and without path optimization for two cases:
    - a summed index shared by two same-rank operands;
    - an ellipsis matrix product.
    """
    # 'i' is summed away; q's size-1 'i' broadcasts against p's size-10 'i'.
    p = np.ones((10, 2))
    q = np.ones((1, 2))

    ein = contract('ij,ij->j', p, q, optimize=False)
    opt = contract('ij,ij->j', p, q, optimize=True)
    assert np.allclose(ein, opt)
    assert np.allclose(opt, [10., 10.])

    p = np.ones((1, 5))
    q = np.ones((5, 5))

    for optimize in (True, False):
        # In (p, p) the second operand's 'j' has size 1 and broadcasts
        # against the first operand's 'j' of size 5.
        # There is deliberately no trailing comma here: one would wrap res1
        # in a 1-tuple, and np.allclose would then silently broadcast over
        # the extra axis.
        res1 = contract("...ij,...jk->...ik", p, p, optimize=optimize)
        res2 = contract("...ij,...jk->...ik", p, q, optimize=optimize)
        assert res1.shape == res2.shape == (1, 5)
        assert np.allclose(res1, res2)
        assert np.allclose(res2, np.full((1, 5), 5))
248
249
def test_large_int_input_format():
    """Integer labels, including large ones, work like letters in the sublist format."""
    subscripts = 'ab,bc,cd'
    a, b, c = build_views(subscripts)

    from_string = contract(subscripts, a, b, c)
    from_ints = contract(a, (1000, 1001), b, (1001, 1002), c, (1002, 1003))
    assert np.allclose(from_string, from_ints)

    # Labelling the axes (lo + 1, lo) must yield the transpose.
    for lo in range(10):
        assert np.allclose(contract(a, (lo + 1, lo)), a.T)
259
260
def test_hashable_object_input_format():
    """Arbitrary hashable labels (strings here) work in the sublist format.

    Checks both implicit and explicit output.
    """
    subscripts = 'ab,bc,cd'
    a, b, c = build_views(subscripts)

    from_string = contract(subscripts, a, b, c)
    labelled = (a, ('left', 'bond1'), b, ('bond1', 'bond2'), c, ('bond2', 'right'))
    implicit_out = contract(*labelled)
    explicit_out = contract(*labelled, ('left', 'right'))
    assert np.allclose(from_string, implicit_out)
    assert np.allclose(implicit_out, explicit_out)

    # Labelling the axes ('b' * n, 'a' * n) must yield the transpose.
    for n in range(1, 10):
        assert np.allclose(contract(a, ('b' * n, 'a' * n)), a.T)
272