"""
Tests the accuracy of the opt_einsum paths in addition to unit tests for
the various path helper functions.
"""

import itertools
import sys

import numpy as np
import pytest

import opt_einsum as oe

# Each entry is ``(inputs, output, size_dict)``: a list of per-input index
# sets, the output index set, and the dimension of every index. The tuple is
# unpacked positionally into the raw path functions by ``assert_contract_order``.
explicit_path_tests = {
    'GEMM1': ([set('abd'), set('ac'), set('bdc')], set(''), {
        'a': 1,
        'b': 2,
        'c': 3,
        'd': 4
    }),
    'Inner1': ([set('abcd'), set('abc'), set('bc')], set(''), {
        'a': 5,
        'b': 2,
        'c': 3,
        'd': 4
    }),
}

# Each entry is ``[optimizer, expression, expected_path]``.
# Note that these tests have no unique solution due to the chosen dimensions,
# so each expected path is simply the one that optimizer currently returns.
path_edge_tests = [
    ['greedy', 'eb,cb,fb->cef', ((0, 2), (0, 1))],
    ['branch-all', 'eb,cb,fb->cef', ((0, 2), (0, 1))],
    ['branch-2', 'eb,cb,fb->cef', ((0, 2), (0, 1))],
    ['optimal', 'eb,cb,fb->cef', ((0, 2), (0, 1))],
    ['dp', 'eb,cb,fb->cef', ((1, 2), (0, 1))],
    ['greedy', 'dd,fb,be,cdb->cef', ((0, 3), (0, 1), (0, 1))],
    ['branch-all', 'dd,fb,be,cdb->cef', ((0, 3), (0, 1), (0, 1))],
    ['branch-2', 'dd,fb,be,cdb->cef', ((0, 3), (0, 1), (0, 1))],
    ['optimal', 'dd,fb,be,cdb->cef', ((0, 3), (0, 1), (0, 1))],
    ['dp', 'dd,fb,be,cdb->cef', ((0, 3), (0, 2), (0, 1))],
    ['greedy', 'bca,cdb,dbf,afc->', ((1, 2), (0, 2), (0, 1))],
    ['branch-all', 'bca,cdb,dbf,afc->', ((1, 2), (0, 2), (0, 1))],
    ['branch-2', 'bca,cdb,dbf,afc->', ((1, 2), (0, 2), (0, 1))],
    ['optimal', 'bca,cdb,dbf,afc->', ((1, 2), (0, 2), (0, 1))],
    ['dp', 'bca,cdb,dbf,afc->', ((1, 2), (1, 2), (0, 1))],
    ['greedy', 'dcc,fce,ea,dbf->ab', ((1, 2), (0, 1), (0, 1))],
    ['branch-all', 'dcc,fce,ea,dbf->ab', ((1, 2), (0, 2), (0, 1))],
    ['branch-2', 'dcc,fce,ea,dbf->ab', ((1, 2), (0, 2), (0, 1))],
    ['optimal', 'dcc,fce,ea,dbf->ab', ((1, 2), (0, 2), (0, 1))],
    ['dp', 'dcc,fce,ea,dbf->ab', ((1, 2), (0, 2), (0, 1))],
]


def check_path(test_output, benchmark, bypass=False):
    """Return True if ``test_output`` is a list of tuples matching ``benchmark``.

    Parameters
    ----------
    test_output : object
        The path produced by an optimizer; must be a ``list`` to match.
    benchmark : sequence of tuple
        The expected contraction path.
    bypass : bool, optional
        Unused; kept only so existing callers passing it keep working.

    Returns
    -------
    bool
        True when ``test_output`` is a list of the same length as
        ``benchmark`` whose every element is a tuple equal to the
        corresponding benchmark element.
    """
    if not isinstance(test_output, list):
        return False

    if len(test_output) != len(benchmark):
        return False

    return all(isinstance(step, tuple) and step == expected for step, expected in zip(test_output, benchmark))


def assert_contract_order(func, test_data, max_size, benchmark):
    """Run raw path function ``func`` on ``test_data`` and compare with ``benchmark``.

    ``test_data`` is an ``(inputs, output, size_dict)`` tuple (see
    ``explicit_path_tests``); ``max_size`` is passed as the fourth positional
    argument of ``func``.
    """
    test_output = func(test_data[0], test_data[1], test_data[2], max_size)
    assert check_path(test_output, benchmark)


def test_size_by_dict():
    """``compute_size_by_dict`` returns the product of the sizes of the given indices."""

    sizes_dict = dict(zip('abcdez', [2, 5, 9, 11, 13, 0]))

    path_func = oe.helpers.compute_size_by_dict

    # The empty product is 1.
    assert 1 == path_func('', sizes_dict)
    assert 2 == path_func('a', sizes_dict)
    assert 5 == path_func('b', sizes_dict)

    # Any zero-sized index makes the whole product zero.
    assert 0 == path_func('z', sizes_dict)
    assert 0 == path_func('az', sizes_dict)
    assert 0 == path_func('zbc', sizes_dict)

    # Repeated indices are counted once per occurrence: 2 * 2 * 2 * 13 == 104.
    assert 104 == path_func('aaae', sizes_dict)
    assert 12870 == path_func('abcde', sizes_dict)


def test_flop_cost():
    """``flop_count`` for various index sets, inner/outer products and term counts."""

    size_dict = {v: 10 for v in "abcdef"}

    # Loop over an array
    assert 10 == oe.helpers.flop_count("a", False, 1, size_dict)

    # Hadamard product (*)
    assert 10 == oe.helpers.flop_count("a", False, 2, size_dict)
    assert 100 == oe.helpers.flop_count("ab", False, 2, size_dict)

    # Inner product (+, *)
    assert 20 == oe.helpers.flop_count("a", True, 2, size_dict)
    assert 200 == oe.helpers.flop_count("ab", True, 2, size_dict)

    # Inner product x3 (+, *, *)
    assert 30 == oe.helpers.flop_count("a", True, 3, size_dict)

    # GEMM
    assert 2000 == oe.helpers.flop_count("abc", True, 2, size_dict)


def test_bad_path_option():
    """An unknown ``optimize`` name raises KeyError."""
    with pytest.raises(KeyError):
        oe.contract("a,b,c", [1], [2], [3], optimize='optimall')


def test_explicit_path():
    """An explicit list of contraction tuples is accepted as ``optimize``."""
    x = oe.contract("a,b,c", [1], [2], [3], optimize=[(1, 2), (0, 1)])
    assert x.item() == 6


def test_path_optimal():
    """``paths.optimal`` on GEMM1; a size limit of 0 forces one single contraction."""

    test_func = oe.paths.optimal

    test_data = explicit_path_tests['GEMM1']
    assert_contract_order(test_func, test_data, 5000, [(0, 2), (0, 1)])
    assert_contract_order(test_func, test_data, 0, [(0, 1, 2)])


def test_path_greedy():
    """``paths.greedy`` on GEMM1; a size limit of 0 forces one single contraction."""

    test_func = oe.paths.greedy

    test_data = explicit_path_tests['GEMM1']
    assert_contract_order(test_func, test_data, 5000, [(0, 2), (0, 1)])
    assert_contract_order(test_func, test_data, 0, [(0, 1, 2)])


def test_memory_paths():
    """``memory_limit`` handling of the optimal and greedy optimizers."""

    expression = "abc,bdef,fghj,cem,mhk,ljk->adgl"

    views = oe.helpers.build_views(expression)

    # Test tiny memory limit: everything is contracted in a single step.
    path_ret = oe.contract_path(expression, *views, optimize="optimal", memory_limit=5)
    assert check_path(path_ret[0], [(0, 1, 2, 3, 4, 5)])

    path_ret = oe.contract_path(expression, *views, optimize="greedy", memory_limit=5)
    assert check_path(path_ret[0], [(0, 1, 2, 3, 4, 5)])

    # With memory_limit=-1 both optimizers are expected to find the same
    # pairwise contraction path.
    path_ret = oe.contract_path(expression, *views, optimize="optimal", memory_limit=-1)
    assert check_path(path_ret[0], [(0, 3), (0, 4), (0, 2), (0, 2), (0, 1)])

    path_ret = oe.contract_path(expression, *views, optimize="greedy", memory_limit=-1)
    assert check_path(path_ret[0], [(0, 3), (0, 4), (0, 2), (0, 2), (0, 1)])


@pytest.mark.parametrize("alg,expression,order", path_edge_tests)
def test_path_edge_cases(alg, expression, order):
    """Each optimizer returns the recorded path for the ambiguous cases above."""
    views = oe.helpers.build_views(expression)

    # No memory limit: check the default path found by ``alg``.
    path_ret = oe.contract_path(expression, *views, optimize=alg)
    assert check_path(path_ret[0], order)


def test_optimal_edge_cases():
    """With ``memory_limit='max_input'``, greedy and optimal both produce one
    pairwise contraction followed by a single contraction of all remaining terms."""

    expression = 'a,ac,ab,ad,cd,bd,bc->'
    views = oe.helpers.build_views(expression, dimension_dict={"a": 20, "b": 20, "c": 20, "d": 20})
    path, _ = oe.contract_path(expression, *views, optimize='greedy', memory_limit='max_input')
    assert check_path(path, [(0, 1), (0, 1, 2, 3, 4, 5)])

    path, _ = oe.contract_path(expression, *views, optimize='optimal', memory_limit='max_input')
    assert check_path(path, [(0, 1), (0, 1, 2, 3, 4, 5)])


def test_greedy_edge_cases():
    """Greedy on a ring of four tensors with and without a ``max_input`` memory limit."""

    expression = "abc,cfd,dbe,efa"
    dim_dict = {k: 20 for k in expression.replace(",", "")}
    tensors = oe.helpers.build_views(expression, dimension_dict=dim_dict)

    path, _ = oe.contract_path(expression, *tensors, optimize='greedy', memory_limit='max_input')
    assert check_path(path, [(0, 1, 2, 3)])

    path, _ = oe.contract_path(expression, *tensors, optimize='greedy', memory_limit=-1)
    assert check_path(path, [(0, 1), (0, 2), (0, 1)])


def test_dp_edge_cases_dimension_1():
    """The dp optimizer copes with every dimension being 1."""
    eq = 'nlp,nlq,pl->n'
    shapes = [(1, 1, 1), (1, 1, 1), (1, 1)]
    info = oe.contract_path(eq, *shapes, shapes=True, optimize='dp')[1]
    assert max(info.scale_list) == 3


def test_dp_edge_cases_all_singlet_indices():
    """The dp optimizer copes with indices that each appear on only one input."""
    eq = 'a,bcd,efg->'
    shapes = [(2, ), (2, 2, 2), (2, 2, 2)]
    info = oe.contract_path(eq, *shapes, shapes=True, optimize='dp')[1]
    assert max(info.scale_list) == 3


def test_custom_dp_can_optimize_for_outer_products():
    """``search_outer=True`` finds a cheaper path when an outer product helps."""
    eq = "a,b,abc->c"

    da, db, dc = 2, 2, 3
    shapes = [(da, ), (db, ), (da, db, dc)]

    opt1 = oe.DynamicProgramming(search_outer=False)
    opt2 = oe.DynamicProgramming(search_outer=True)

    info1 = oe.contract_path(eq, *shapes, shapes=True, optimize=opt1)[1]
    info2 = oe.contract_path(eq, *shapes, shapes=True, optimize=opt2)[1]

    assert info2.opt_cost < info1.opt_cost


def test_custom_dp_can_optimize_for_size():
    """Minimizing 'flops' vs 'size' trades cost against largest intermediate."""
    eq, shapes = oe.helpers.rand_equation(10, 4, seed=43)

    opt1 = oe.DynamicProgramming(minimize='flops')
    opt2 = oe.DynamicProgramming(minimize='size')

    info1 = oe.contract_path(eq, *shapes, shapes=True, optimize=opt1)[1]
    info2 = oe.contract_path(eq, *shapes, shapes=True, optimize=opt2)[1]

    assert info1.opt_cost < info2.opt_cost
    assert info1.largest_intermediate > info2.largest_intermediate


def test_custom_dp_can_set_cost_cap():
    """Different ``cost_cap`` settings all reach the same optimal cost here."""
    eq, shapes = oe.helpers.rand_equation(5, 3, seed=42)
    opt1 = oe.DynamicProgramming(cost_cap=True)
    opt2 = oe.DynamicProgramming(cost_cap=False)
    opt3 = oe.DynamicProgramming(cost_cap=100)
    info1 = oe.contract_path(eq, *shapes, shapes=True, optimize=opt1)[1]
    info2 = oe.contract_path(eq, *shapes, shapes=True, optimize=opt2)[1]
    info3 = oe.contract_path(eq, *shapes, shapes=True, optimize=opt3)[1]
    assert info1.opt_cost == info2.opt_cost == info3.opt_cost


@pytest.mark.parametrize("optimize", ['greedy', 'branch-2', 'branch-all', 'optimal', 'dp'])
def test_can_optimize_outer_products(optimize):
    """Every optimizer starts by contracting the two small tensors sharing 'f'."""
    a, b, c = [np.random.randn(10, 10) for _ in range(3)]
    d = np.random.randn(10, 2)
    assert oe.contract_path("ab,cd,ef,fg", a, b, c, d, optimize=optimize)[0] == [(2, 3), (0, 2), (0, 1)]


@pytest.mark.parametrize('num_symbols', [2, 3, 26, 26 + 26, 256 - 140, 300])
def test_large_path(num_symbols):
    """Greedy path construction does not crash on long chains of many symbols."""
    symbols = ''.join(oe.get_symbol(i) for i in range(num_symbols))
    dimension_dict = dict(zip(symbols, itertools.cycle([2, 3, 4])))
    expression = ','.join(symbols[t:t + 2] for t in range(num_symbols - 1))
    tensors = oe.helpers.build_views(expression, dimension_dict=dimension_dict)

    # Check that path construction does not crash
    oe.contract_path(expression, *tensors, optimize='greedy')


def test_custom_random_greedy():
    """RandomGreedy records per-trial costs/sizes, can be re-run with new
    settings, and refuses to be reused on a different expression."""
    eq, shapes = oe.helpers.rand_equation(10, 4, seed=42)
    views = list(map(np.ones, shapes))

    with pytest.raises(ValueError):
        oe.RandomGreedy(minimize='something')

    optimizer = oe.RandomGreedy(max_repeats=10, minimize='flops')
    path, path_info = oe.contract_path(eq, *views, optimize=optimizer)

    assert len(optimizer.costs) == 10
    assert len(optimizer.sizes) == 10

    assert path == optimizer.path
    assert optimizer.best['flops'] == min(optimizer.costs)
    assert path_info.largest_intermediate == optimizer.best['size']
    assert path_info.opt_cost == optimizer.best['flops']

    # check can change settings and run again; trial records accumulate
    optimizer.temperature = 0.0
    optimizer.max_repeats = 6
    path, path_info = oe.contract_path(eq, *views, optimize=optimizer)

    assert len(optimizer.costs) == 16
    assert len(optimizer.sizes) == 16

    assert path == optimizer.path
    assert optimizer.best['size'] == min(optimizer.sizes)
    assert path_info.largest_intermediate == optimizer.best['size']
    assert path_info.opt_cost == optimizer.best['flops']

    # check error if we try and reuse the optimizer on a different expression
    eq, shapes = oe.helpers.rand_equation(10, 4, seed=41)
    views = list(map(np.ones, shapes))
    with pytest.raises(ValueError):
        path, path_info = oe.contract_path(eq, *views, optimize=optimizer)


def test_custom_branchbound():
    """BranchBound exposes its best path, can be re-run with new settings, and
    refuses to be reused on a different expression."""
    eq, shapes = oe.helpers.rand_equation(8, 4, seed=42)
    views = list(map(np.ones, shapes))
    optimizer = oe.BranchBound(nbranch=2, cutoff_flops_factor=10, minimize='size')

    path, path_info = oe.contract_path(eq, *views, optimize=optimizer)

    assert path == optimizer.path
    assert path_info.largest_intermediate == optimizer.best['size']
    assert path_info.opt_cost == optimizer.best['flops']

    # tweak settings and run again
    optimizer.nbranch = 3
    optimizer.cutoff_flops_factor = 4
    path, path_info = oe.contract_path(eq, *views, optimize=optimizer)

    assert path == optimizer.path
    assert path_info.largest_intermediate == optimizer.best['size']
    assert path_info.opt_cost == optimizer.best['flops']

    # check error if we try and reuse the optimizer on a different expression
    eq, shapes = oe.helpers.rand_equation(8, 4, seed=41)
    views = list(map(np.ones, shapes))
    with pytest.raises(ValueError):
        path, path_info = oe.contract_path(eq, *views, optimize=optimizer)
@pytest.mark.skipif(sys.version_info < (3, 2), reason="requires python3.2 or higher")
def test_parallel_random_greedy():
    """RandomGreedy runs its trials on a supplied executor, on an integer
    number of workers, and with ``parallel=True``."""
    from concurrent.futures import ProcessPoolExecutor

    eq, shapes = oe.helpers.rand_equation(10, 4, seed=42)
    views = list(map(np.ones, shapes))

    # This test creates ``pool`` and therefore owns it: the ``with`` block
    # shuts it down (joining its worker processes) even if an assertion fails.
    with ProcessPoolExecutor(2) as pool:
        optimizer = oe.RandomGreedy(max_repeats=10, parallel=pool)
        path, path_info = oe.contract_path(eq, *views, optimize=optimizer)

        assert len(optimizer.costs) == 10
        assert len(optimizer.sizes) == 10

        assert path == optimizer.path
        assert optimizer.parallel is pool
        assert optimizer._executor is pool
        assert optimizer.best['flops'] == min(optimizer.costs)
        assert path_info.largest_intermediate == optimizer.best['size']
        assert path_info.opt_cost == optimizer.best['flops']

        # now switch to max time algorithm
        optimizer.max_repeats = int(1e6)
        optimizer.max_time = 0.2
        optimizer.parallel = 2

        path, path_info = oe.contract_path(eq, *views, optimize=optimizer)

        assert len(optimizer.costs) > 10
        assert len(optimizer.sizes) > 10

        assert path == optimizer.path
        assert optimizer.best['flops'] == min(optimizer.costs)
        assert path_info.largest_intermediate == optimizer.best['size']
        assert path_info.opt_cost == optimizer.best['flops']

        optimizer.parallel = True
        assert optimizer._executor is not None
        assert optimizer._executor is not pool

        are_done = [f.running() or f.done() for f in optimizer._futures]
        assert all(are_done)


def test_custom_path_optimizer():
    """A ``PathOptimizer`` subclass can be passed as ``optimize`` and gives the
    same result as an unoptimized contraction."""
    class NaiveOptimizer(oe.paths.PathOptimizer):
        # Defined on the class so the final assertion fails cleanly (instead of
        # raising AttributeError) if ``__call__`` is never invoked.
        was_used = False

        def __call__(self, inputs, output, size_dict, memory_limit=None):
            self.was_used = True
            # Always contract the first two remaining terms.
            return [(0, 1)] * (len(inputs) - 1)

    eq, shapes = oe.helpers.rand_equation(5, 3, seed=42, d_max=3)
    views = list(map(np.ones, shapes))

    exp = oe.contract(eq, *views, optimize=False)

    optimizer = NaiveOptimizer()
    out = oe.contract(eq, *views, optimize=optimizer)
    assert exp == out
    assert optimizer.was_used


def test_custom_random_optimizer():
    """A ``RandomOptimizer`` subclass defining ``setup`` (returning a trial
    function and its extra arguments) can be passed as ``optimize``."""

    class NaiveRandomOptimizer(oe.path_random.RandomOptimizer):
        # Defined on the class so the final assertion fails cleanly (instead of
        # raising AttributeError) if ``setup`` is never invoked.
        was_used = False

        @staticmethod
        def random_path(r, n, inputs, output, size_dict):
            """Picks a completely random contraction order.

            ``r`` seeds a local generator, so each trial is reproducible
            without modifying numpy's global random state.
            """
            rng = np.random.RandomState(r)
            ssa_path = []
            remaining = set(range(n))
            while len(remaining) > 1:
                i, j = rng.choice(list(remaining), size=2, replace=False)
                # In SSA form the new intermediate gets the next unused id.
                remaining.add(n + len(ssa_path))
                remaining.remove(i)
                remaining.remove(j)
                ssa_path.append((i, j))
            cost, size = oe.path_random.ssa_path_compute_cost(ssa_path, inputs, output, size_dict)
            return ssa_path, cost, size

        def setup(self, inputs, output, size_dict):
            self.was_used = True
            n = len(inputs)
            trial_fn = self.random_path
            trial_args = (n, inputs, output, size_dict)
            return trial_fn, trial_args

    eq, shapes = oe.helpers.rand_equation(5, 3, seed=42, d_max=3)
    views = list(map(np.ones, shapes))

    exp = oe.contract(eq, *views, optimize=False)

    optimizer = NaiveRandomOptimizer(max_repeats=16)
    out = oe.contract(eq, *views, optimize=optimizer)
    assert exp == out
    assert optimizer.was_used

    assert len(optimizer.costs) == 16


def test_optimizer_registration():
    """Custom path functions can be registered by name; registering a
    built-in name such as 'optimal' raises KeyError."""
    def custom_optimizer(inputs, output, size_dict, memory_limit):
        return [(0, 1)] * (len(inputs) - 1)

    with pytest.raises(KeyError):
        oe.paths.register_path_fn('optimal', custom_optimizer)

    oe.paths.register_path_fn('custom', custom_optimizer)
    try:
        assert 'custom' in oe.paths._PATH_OPTIONS

        eq = 'ab,bc,cd'
        shapes = [(2, 3), (3, 4), (4, 5)]
        path, path_info = oe.contract_path(eq, *shapes, shapes=True, optimize='custom')
        assert path == [(0, 1), (0, 1)]
    finally:
        # Always unregister so a failure above cannot leak 'custom' into
        # other tests; ``pop`` avoids masking the original error if the
        # registration somehow did not take effect.
        oe.paths._PATH_OPTIONS.pop('custom', None)