"""
Tests the accuracy of the paths found by the opt_einsum optimizers, together
with unit tests for the various path helper functions.
"""
5
6import itertools
7import sys
8
9import numpy as np
10import pytest
11
12import opt_einsum as oe
13
# Hand-built cases keyed by name. Each value is a 3-tuple that
# ``assert_contract_order`` forwards positionally to a path function:
# (list of per-operand index sets, output index set, index -> size mapping).
explicit_path_tests = {
    'GEMM1': (
        [set('abd'), set('ac'), set('bdc')],
        set(),
        dict(a=1, b=2, c=3, d=4),
    ),
    'Inner1': (
        [set('abcd'), set('abc'), set('bc')],
        set(),
        dict(a=5, b=2, c=3, d=4),
    ),
}
28
# Each entry is [optimizer name, einsum expression, expected contraction path]
# and is consumed by the parametrized ``test_path_edge_cases``.
# Note that these tests have no unique solution due to the chosen dimensions.
# Every (optimizer, expression) pair is listed exactly once so that no case is
# run twice.
path_edge_tests = [
    ['greedy', 'eb,cb,fb->cef', ((0, 2), (0, 1))],
    ['branch-all', 'eb,cb,fb->cef', ((0, 2), (0, 1))],
    ['branch-2', 'eb,cb,fb->cef', ((0, 2), (0, 1))],
    ['optimal', 'eb,cb,fb->cef', ((0, 2), (0, 1))],
    ['dp', 'eb,cb,fb->cef', ((1, 2), (0, 1))],
    ['greedy', 'dd,fb,be,cdb->cef', ((0, 3), (0, 1), (0, 1))],
    ['branch-all', 'dd,fb,be,cdb->cef', ((0, 3), (0, 1), (0, 1))],
    ['branch-2', 'dd,fb,be,cdb->cef', ((0, 3), (0, 1), (0, 1))],
    ['optimal', 'dd,fb,be,cdb->cef', ((0, 3), (0, 1), (0, 1))],
    ['dp', 'dd,fb,be,cdb->cef', ((0, 3), (0, 2), (0, 1))],
    ['greedy', 'bca,cdb,dbf,afc->', ((1, 2), (0, 2), (0, 1))],
    ['branch-all', 'bca,cdb,dbf,afc->', ((1, 2), (0, 2), (0, 1))],
    ['branch-2', 'bca,cdb,dbf,afc->', ((1, 2), (0, 2), (0, 1))],
    ['optimal', 'bca,cdb,dbf,afc->', ((1, 2), (0, 2), (0, 1))],
    ['dp', 'bca,cdb,dbf,afc->', ((1, 2), (1, 2), (0, 1))],
    ['greedy', 'dcc,fce,ea,dbf->ab', ((1, 2), (0, 1), (0, 1))],
    ['branch-all', 'dcc,fce,ea,dbf->ab', ((1, 2), (0, 2), (0, 1))],
    ['branch-2', 'dcc,fce,ea,dbf->ab', ((1, 2), (0, 2), (0, 1))],
    ['optimal', 'dcc,fce,ea,dbf->ab', ((1, 2), (0, 2), (0, 1))],
    ['dp', 'dcc,fce,ea,dbf->ab', ((1, 2), (0, 2), (0, 1))],
]
53
54
def check_path(test_output, benchmark, bypass=False):
    """Return True when ``test_output`` matches the expected path ``benchmark``.

    Parameters
    ----------
    test_output : object
        Candidate path. Anything that is not a ``list`` is rejected.
    benchmark : sequence of tuple
        Expected contraction steps, compared position by position.
    bypass : bool, optional
        Accepted for signature compatibility; it is not consulted here.

    Returns
    -------
    bool
        True only if both have the same length and every entry of
        ``test_output`` is a ``tuple`` equal to its ``benchmark`` counterpart.
    """
    if not isinstance(test_output, list) or len(test_output) != len(benchmark):
        return False

    return all(
        isinstance(found, tuple) and found == expected
        for found, expected in zip(test_output, benchmark)
    )
67
68
def assert_contract_order(func, test_data, max_size, benchmark):
    """Run the path function ``func`` on ``test_data`` and assert its result.

    ``test_data`` supplies the first three positional arguments of ``func``
    (its elements 0, 1 and 2); ``max_size`` is passed as the fourth. The
    returned path must satisfy ``check_path`` against ``benchmark``.
    """
    first_arg, second_arg, third_arg = test_data[0], test_data[1], test_data[2]
    found_path = func(first_arg, second_arg, third_arg, max_size)
    assert check_path(found_path, benchmark)
73
74
def test_size_by_dict():
    """Check ``compute_size_by_dict`` against hand-computed expected sizes."""
    sizes_dict = dict(zip('abcdez', [2, 5, 9, 11, 13, 0]))
    compute_size = oe.helpers.compute_size_by_dict

    expected_sizes = [
        # Empty index string and single indices.
        (1, ''),
        (2, 'a'),
        (5, 'b'),
        # Any index of size zero makes the whole size zero.
        (0, 'z'),
        (0, 'az'),
        (0, 'zbc'),
        # Repeated indices count every time: 2 * 2 * 2 * 13 == 104.
        (104, 'aaae'),
        # 2 * 5 * 9 * 11 * 13 == 12870.
        (12870, 'abcde'),
    ]

    for expected, indices in expected_sizes:
        assert expected == compute_size(indices, sizes_dict)
93
94
def test_flop_cost():
    """Check ``flop_count`` for a handful of contraction shapes, all indices of size 10."""
    size_dict = dict.fromkeys("abcdef", 10)
    flop_count = oe.helpers.flop_count

    # Loop over an array
    assert flop_count("a", False, 1, size_dict) == 10

    # Hadamard product (*)
    assert flop_count("a", False, 2, size_dict) == 10
    assert flop_count("ab", False, 2, size_dict) == 100

    # Inner product (+, *)
    assert flop_count("a", True, 2, size_dict) == 20
    assert flop_count("ab", True, 2, size_dict) == 200

    # Inner product x3 (+, *, *)
    assert flop_count("a", True, 3, size_dict) == 30

    # GEMM
    assert flop_count("abc", True, 2, size_dict) == 2000
115
116
def test_bad_path_option():
    """An unknown ``optimize`` string (the misspelt 'optimall') must raise ``KeyError``."""
    operands = ([1], [2], [3])
    with pytest.raises(KeyError):
        oe.contract("a,b,c", *operands, optimize='optimall')
120
121
def test_explicit_path():
    """A hand-written list of contraction steps is accepted as ``optimize``."""
    explicit_path = [(1, 2), (0, 1)]
    result = oe.contract("a,b,c", [1], [2], [3], optimize=explicit_path)

    # 1 * 2 * 3
    assert result.item() == 6
125
126
def test_path_optimal():
    """Check ``oe.paths.optimal`` on the 'GEMM1' case for two size limits."""
    gemm_case = explicit_path_tests['GEMM1']

    expectations = [
        # A generous limit yields pairwise contractions.
        (5000, [(0, 2), (0, 1)]),
        # A limit of zero yields one contraction over all three operands.
        (0, [(0, 1, 2)]),
    ]
    for max_size, expected_path in expectations:
        assert_contract_order(oe.paths.optimal, gemm_case, max_size, expected_path)
134
135
def test_path_greedy():
    """Check ``oe.paths.greedy`` on the 'GEMM1' case for two size limits."""
    gemm_case = explicit_path_tests['GEMM1']

    expectations = [
        # A generous limit yields pairwise contractions.
        (5000, [(0, 2), (0, 1)]),
        # A limit of zero yields one contraction over all three operands.
        (0, [(0, 1, 2)]),
    ]
    for max_size, expected_path in expectations:
        assert_contract_order(oe.paths.greedy, gemm_case, max_size, expected_path)
143
144
def test_memory_paths():
    """Check how ``memory_limit`` shapes the 'optimal' and 'greedy' paths."""
    expression = "abc,bdef,fghj,cem,mhk,ljk->adgl"
    views = oe.helpers.build_views(expression)

    single_contraction = [(0, 1, 2, 3, 4, 5)]
    pairwise_contractions = [(0, 3), (0, 4), (0, 2), (0, 2), (0, 1)]

    scenarios = [
        # Tiny memory limit: everything is contracted in one step.
        (5, single_contraction),
        # memory_limit=-1 (presumably "no limit" -- confirm against the
        # contract_path documentation): both algorithms give the same path.
        (-1, pairwise_contractions),
    ]

    for memory_limit, expected_path in scenarios:
        for algorithm in ("optimal", "greedy"):
            path_ret = oe.contract_path(expression, *views, optimize=algorithm, memory_limit=memory_limit)
            assert check_path(path_ret[0], expected_path)
164
165
@pytest.mark.parametrize("alg,expression,order", path_edge_tests)
def test_path_edge_cases(alg, expression, order):
    """The optimizer ``alg`` must produce exactly ``order`` for ``expression``."""
    views = oe.helpers.build_views(expression)

    # No memory limit is passed here; only the optimizer varies between cases.
    found_path = oe.contract_path(expression, *views, optimize=alg)[0]
    assert check_path(found_path, order)
173
174
def test_optimal_edge_cases():
    """With ``memory_limit='max_input'`` greedy and optimal give the same two-step path."""
    expression = 'a,ac,ab,ad,cd,bd,bc->'
    dimensions = dict.fromkeys("abcd", 20)
    views = oe.helpers.build_views(expression, dimension_dict=dimensions)

    # One pairwise step, then a single contraction of the six remaining operands.
    expected_path = [(0, 1), (0, 1, 2, 3, 4, 5)]

    for algorithm in ('greedy', 'optimal'):
        path, _ = oe.contract_path(expression, *views, optimize=algorithm, memory_limit='max_input')
        assert check_path(path, expected_path)
185
186
def test_greedy_edge_cases():
    """Check the greedy path for a four-operand expression under two memory limits."""
    expression = "abc,cfd,dbe,efa"
    dim_dict = dict.fromkeys(expression.replace(",", ""), 20)
    tensors = oe.helpers.build_views(expression, dimension_dict=dim_dict)

    scenarios = [
        # Limited to the largest input: one contraction over all operands.
        ('max_input', [(0, 1, 2, 3)]),
        # memory_limit=-1: pairwise contractions.
        (-1, [(0, 1), (0, 2), (0, 1)]),
    ]

    for memory_limit, expected_path in scenarios:
        path, _ = oe.contract_path(expression, *tensors, optimize='greedy', memory_limit=memory_limit)
        assert check_path(path, expected_path)
198
199
def test_dp_edge_cases_dimension_1():
    """The 'dp' optimizer must cope with an expression whose dimensions are all 1."""
    eq = 'nlp,nlq,pl->n'
    shapes = [(1, 1, 1), (1, 1, 1), (1, 1)]

    _, info = oe.contract_path(eq, *shapes, shapes=True, optimize='dp')

    assert max(info.scale_list) == 3
205
206
def test_dp_edge_cases_all_singlet_indices():
    """The 'dp' optimizer must cope when every index appears on exactly one operand."""
    eq = 'a,bcd,efg->'
    shapes = [(2, ), (2, 2, 2), (2, 2, 2)]

    _, info = oe.contract_path(eq, *shapes, shapes=True, optimize='dp')

    assert max(info.scale_list) == 3
212
213
def test_custom_dp_can_optimize_for_outer_products():
    """``search_outer=True`` must find a cheaper path than ``search_outer=False`` here."""
    eq = "a,b,abc->c"

    da, db, dc = 2, 2, 3
    shapes = [(da, ), (db, ), (da, db, dc)]

    # Build both optimizers first, then run them in the same order.
    optimizers = [oe.DynamicProgramming(search_outer=flag) for flag in (False, True)]
    info_without_outer, info_with_outer = [
        oe.contract_path(eq, *shapes, shapes=True, optimize=opt)[1] for opt in optimizers
    ]

    assert info_with_outer.opt_cost < info_without_outer.opt_cost
227
228
def test_custom_dp_can_optimize_for_size():
    """Minimizing 'flops' versus 'size' must trade cost against intermediate size."""
    eq, shapes = oe.helpers.rand_equation(10, 4, seed=43)

    # Build both optimizers first, then run them in the same order.
    optimizers = [oe.DynamicProgramming(minimize=target) for target in ('flops', 'size')]
    flops_info, size_info = [
        oe.contract_path(eq, *shapes, shapes=True, optimize=opt)[1] for opt in optimizers
    ]

    # The flops-minimizing path is cheaper but needs a larger intermediate.
    assert flops_info.opt_cost < size_info.opt_cost
    assert flops_info.largest_intermediate > size_info.largest_intermediate
240
241
def test_custom_dp_can_set_cost_cap():
    """Every ``cost_cap`` setting (True, False, a number) must reach the same cost."""
    eq, shapes = oe.helpers.rand_equation(5, 3, seed=42)

    # Build all optimizers first, then run them in the same order.
    optimizers = [oe.DynamicProgramming(cost_cap=cap) for cap in (True, False, 100)]
    infos = [oe.contract_path(eq, *shapes, shapes=True, optimize=opt)[1] for opt in optimizers]

    capped, uncapped, numeric_cap = infos
    assert capped.opt_cost == uncapped.opt_cost == numeric_cap.opt_cost
251
252
@pytest.mark.parametrize("optimize", ['greedy', 'branch-2', 'branch-all', 'optimal', 'dp'])
def test_can_optimize_outer_products(optimize):
    """Every listed optimizer must pick the same path for an expression with outer products."""
    # Three 10x10 operands followed by one 10x2 operand.
    operands = [np.random.randn(10, 10) for _ in range(3)]
    operands.append(np.random.randn(10, 2))

    found_path = oe.contract_path("ab,cd,ef,fg", *operands, optimize=optimize)[0]

    assert found_path == [(2, 3), (0, 2), (0, 1)]
258
259
@pytest.mark.parametrize('num_symbols', [2, 3, 26, 26 + 26, 256 - 140, 300])
def test_large_path(num_symbols):
    """Building a greedy path over a chain of ``num_symbols`` indices must not raise."""
    symbols = ''.join(map(oe.get_symbol, range(num_symbols)))

    # Sizes repeat 2, 3, 4, 2, 3, 4, ... along the symbol list.
    dimension_dict = dict(zip(symbols, itertools.cycle([2, 3, 4])))

    # Each term is a two-symbol window, so neighbouring terms share one symbol.
    terms = [symbols[start:start + 2] for start in range(num_symbols - 1)]
    expression = ','.join(terms)
    tensors = oe.helpers.build_views(expression, dimension_dict=dimension_dict)

    # Check that path construction does not crash
    oe.contract_path(expression, *tensors, optimize='greedy')
269
270
def test_custom_random_greedy():
    """Exercise a reusable ``oe.RandomGreedy`` optimizer instance.

    Covers rejection of an unknown ``minimize`` target, bookkeeping after a
    first run, re-running with changed settings, and the error raised when
    the same instance is reused on a different expression.
    """
    eq, shapes = oe.helpers.rand_equation(10, 4, seed=42)
    views = [np.ones(shape) for shape in shapes]

    with pytest.raises(ValueError):
        oe.RandomGreedy(minimize='something')

    optimizer = oe.RandomGreedy(max_repeats=10, minimize='flops')
    path, path_info = oe.contract_path(eq, *views, optimize=optimizer)

    # One recorded cost and one recorded size per repeat.
    for history in (optimizer.costs, optimizer.sizes):
        assert len(history) == 10

    assert path == optimizer.path
    assert optimizer.best['flops'] == min(optimizer.costs)
    assert path_info.largest_intermediate == optimizer.best['size']
    assert path_info.opt_cost == optimizer.best['flops']

    # check can change settings and run again
    optimizer.temperature = 0.0
    optimizer.max_repeats = 6
    path, path_info = oe.contract_path(eq, *views, optimize=optimizer)

    # The 6 new trials are appended to the 10 recorded by the first run.
    for history in (optimizer.costs, optimizer.sizes):
        assert len(history) == 16

    assert path == optimizer.path
    assert optimizer.best['size'] == min(optimizer.sizes)
    assert path_info.largest_intermediate == optimizer.best['size']
    assert path_info.opt_cost == optimizer.best['flops']

    # check error if we try and reuse the optimizer on a different expression
    other_eq, other_shapes = oe.helpers.rand_equation(10, 4, seed=41)
    other_views = [np.ones(shape) for shape in other_shapes]
    with pytest.raises(ValueError):
        oe.contract_path(other_eq, *other_views, optimize=optimizer)
307
308
def test_custom_branchbound():
    """Exercise a reusable ``oe.BranchBound`` optimizer instance.

    Covers agreement between the returned path/info and the optimizer's own
    record, re-running with changed settings, and the error raised when the
    same instance is reused on a different expression.
    """
    eq, shapes = oe.helpers.rand_equation(8, 4, seed=42)
    views = [np.ones(shape) for shape in shapes]
    optimizer = oe.BranchBound(nbranch=2, cutoff_flops_factor=10, minimize='size')

    path, path_info = oe.contract_path(eq, *views, optimize=optimizer)

    assert path == optimizer.path
    assert path_info.largest_intermediate == optimizer.best['size']
    assert path_info.opt_cost == optimizer.best['flops']

    # tweak settings and run again
    optimizer.nbranch = 3
    optimizer.cutoff_flops_factor = 4
    path, path_info = oe.contract_path(eq, *views, optimize=optimizer)

    assert path == optimizer.path
    assert path_info.largest_intermediate == optimizer.best['size']
    assert path_info.opt_cost == optimizer.best['flops']

    # check error if we try and reuse the optimizer on a different expression
    other_eq, other_shapes = oe.helpers.rand_equation(8, 4, seed=41)
    other_views = [np.ones(shape) for shape in other_shapes]
    with pytest.raises(ValueError):
        oe.contract_path(other_eq, *other_views, optimize=optimizer)
334
335
@pytest.mark.skipif(sys.version_info < (3, 2), reason="requires python3.2 or higher")
def test_parallel_random_greedy():
    """Exercise ``oe.RandomGreedy`` with the different ``parallel`` settings.

    A caller-supplied executor, an integer worker count and ``True`` are tried
    in turn. The executor created by this test is always shut down on exit so
    that its worker processes do not outlive the test, even when an assertion
    fails.
    """
    from concurrent.futures import ProcessPoolExecutor
    pool = ProcessPoolExecutor(2)

    try:
        eq, shapes = oe.helpers.rand_equation(10, 4, seed=42)
        views = list(map(np.ones, shapes))

        optimizer = oe.RandomGreedy(max_repeats=10, parallel=pool)
        path, path_info = oe.contract_path(eq, *views, optimize=optimizer)

        assert len(optimizer.costs) == 10
        assert len(optimizer.sizes) == 10

        assert path == optimizer.path
        # A supplied executor must be used as-is rather than copied or wrapped.
        assert optimizer.parallel is pool
        assert optimizer._executor is pool
        assert optimizer.best['flops'] == min(optimizer.costs)
        assert path_info.largest_intermediate == optimizer.best['size']
        assert path_info.opt_cost == optimizer.best['flops']

        # now switch to max time algorithm
        optimizer.max_repeats = int(1e6)
        optimizer.max_time = 0.2
        optimizer.parallel = 2

        path, path_info = oe.contract_path(eq, *views, optimize=optimizer)

        # New trials are added on top of the 10 recorded by the first run.
        assert len(optimizer.costs) > 10
        assert len(optimizer.sizes) > 10

        assert path == optimizer.path
        assert optimizer.best['flops'] == min(optimizer.costs)
        assert path_info.largest_intermediate == optimizer.best['size']
        assert path_info.opt_cost == optimizer.best['flops']

        # ``parallel=True`` must give the optimizer an executor of its own.
        optimizer.parallel = True
        assert optimizer._executor is not None
        assert optimizer._executor is not pool

        # No future from the previous run may still be waiting to start.
        are_done = [f.running() or f.done() for f in optimizer._futures]
        assert all(are_done)
    finally:
        # This test created ``pool``, so it is responsible for releasing it;
        # calling shutdown on an already shut-down executor is harmless.
        pool.shutdown()
378
379
def test_custom_path_optimizer():
    """A user-defined ``PathOptimizer`` subclass must be called and give the right result.

    The result is compared with ``np.allclose`` so the check stays valid when
    the contraction output has more than one element (a bare ``==`` on such
    arrays cannot be used as an assertion).
    """
    class NaiveOptimizer(oe.paths.PathOptimizer):
        # Class-level default so the final assertion fails cleanly, rather
        # than raising AttributeError, if ``__call__`` is never invoked.
        was_used = False

        def __call__(self, inputs, output, size_dict, memory_limit=None):
            self.was_used = True
            # Always contract the first two remaining operands.
            return [(0, 1)] * (len(inputs) - 1)

    eq, shapes = oe.helpers.rand_equation(5, 3, seed=42, d_max=3)
    views = list(map(np.ones, shapes))

    # Reference result computed without any path optimization.
    exp = oe.contract(eq, *views, optimize=False)

    optimizer = NaiveOptimizer()
    out = oe.contract(eq, *views, optimize=optimizer)
    assert np.allclose(exp, out)
    assert optimizer.was_used
395
396
def test_custom_random_optimizer():
    """A user-defined ``RandomOptimizer`` subclass must be called and give the right result.

    The result is compared with ``np.allclose`` so the check stays valid when
    the contraction output has more than one element (a bare ``==`` on such
    arrays cannot be used as an assertion).
    """
    class NaiveRandomOptimizer(oe.path_random.RandomOptimizer):
        # Class-level default so the final assertion fails cleanly, rather
        # than raising AttributeError, if ``setup`` is never invoked.
        was_used = False

        @staticmethod
        def random_path(r, n, inputs, output, size_dict):
            """Picks a completely random contraction order.
            """
            # The trial number seeds numpy's global RNG.
            np.random.seed(r)
            ssa_path = []
            remaining = set(range(n))
            while len(remaining) > 1:
                i, j = np.random.choice(list(remaining), size=2, replace=False)
                # The new intermediate gets the next unused id, ``n`` plus the
                # number of contractions made so far.
                remaining.add(n + len(ssa_path))
                remaining.remove(i)
                remaining.remove(j)
                ssa_path.append((i, j))
            cost, size = oe.path_random.ssa_path_compute_cost(ssa_path, inputs, output, size_dict)
            return ssa_path, cost, size

        def setup(self, inputs, output, size_dict):
            """Return the trial function and the arguments shared by every trial."""
            self.was_used = True
            n = len(inputs)
            trial_fn = self.random_path
            trial_args = (n, inputs, output, size_dict)
            return trial_fn, trial_args

    eq, shapes = oe.helpers.rand_equation(5, 3, seed=42, d_max=3)
    views = list(map(np.ones, shapes))

    # Reference result computed without any path optimization.
    exp = oe.contract(eq, *views, optimize=False)

    optimizer = NaiveRandomOptimizer(max_repeats=16)
    out = oe.contract(eq, *views, optimize=optimizer)
    assert np.allclose(exp, out)
    assert optimizer.was_used

    # One recorded cost per repeat.
    assert len(optimizer.costs) == 16
433
434
def test_optimizer_registration():
    """Check ``register_path_fn`` and that a registered name works as ``optimize``.

    The temporary 'custom' entry is removed in a ``finally`` clause so that a
    failing assertion cannot leave it registered for later tests.
    """
    def custom_optimizer(inputs, output, size_dict, memory_limit):
        # Always contract the first two remaining operands.
        return [(0, 1)] * (len(inputs) - 1)

    # Registering under an existing name must be refused.
    with pytest.raises(KeyError):
        oe.paths.register_path_fn('optimal', custom_optimizer)

    oe.paths.register_path_fn('custom', custom_optimizer)
    try:
        assert 'custom' in oe.paths._PATH_OPTIONS

        eq = 'ab,bc,cd'
        shapes = [(2, 3), (3, 4), (4, 5)]
        path, path_info = oe.contract_path(eq, *shapes, shapes=True, optimize='custom')
        assert path == [(0, 1), (0, 1)]
    finally:
        # ``pop`` with a default so the cleanup cannot mask the real failure.
        oe.paths._PATH_OPTIONS.pop('custom', None)
450