"""Tests for running opt_einsum contractions through optional array backends.

Each backend-specific test is skipped when the corresponding library cannot be
imported; the dask and sparse tests use ``pytest.importorskip`` instead.
"""
import contextlib
import os

import numpy as np
import pytest

from opt_einsum import (backends, contract, contract_expression, helpers, sharing)
from opt_einsum.contract import Shaped, infer_backend, parse_backend

try:
    import cupy
    found_cupy = True
except ImportError:
    found_cupy = False

try:
    import tensorflow as tf
    # needed so tensorflow doesn't allocate all gpu mem
    _TF_CONFIG = tf.ConfigProto()
    _TF_CONFIG.gpu_options.allow_growth = True
    found_tensorflow = True
except (ImportError, AttributeError):
    # AttributeError: TensorFlow 2.x no longer exposes ``ConfigProto`` (or
    # ``Session``) at the top level. These graph-mode tests cannot run there,
    # so skip them instead of failing collection of the whole module.
    found_tensorflow = False

try:
    # NOTE(review): presumably avoids an MKL threading-layer conflict when
    # theano loads MKL-backed numpy -- confirm. Must be set before the import.
    os.environ['MKL_THREADING_LAYER'] = 'GNU'
    import theano
    found_theano = True
except ImportError:
    found_theano = False

try:
    import torch
    found_torch = True
except ImportError:
    found_torch = False

try:
    import jax
    found_jax = True
except ImportError:
    found_jax = False

try:
    import autograd
    found_autograd = True
except ImportError:
    found_autograd = False

tests = [
    'ab,bc->ca',
    'abc,bcd,dea',
    'abc,def->fedcba',
    'abc,bcd,df->fa',
    # test 'prefer einsum' ops
    'ijk,ikj',
    'i,j->ij',
    'ijk,k->ij',
    'AB,BC->CA',
]


def _build_case(string):
    """Return ``(views, ein, expr)`` for the einsum equation ``string``.

    ``views`` are the numpy operands produced by ``helpers.build_views``,
    ``ein`` is the reference result computed with ``optimize=False,
    use_blas=False``, and ``expr`` is an optimized ``contract_expression``
    built from the operand shapes only.
    """
    views = helpers.build_views(string)
    ein = contract(string, *views, optimize=False, use_blas=False)
    shps = [v.shape for v in views]
    expr = contract_expression(string, *shps, optimize=True)
    return views, ein, expr


def _build_constants_case(constants):
    """Return ``(expr, var, res_exp)`` for the ``*_with_constants`` tests.

    Operands of ``'ij,jk,kl->li'`` whose index is in ``constants`` are passed
    to ``contract_expression`` as numpy arrays via ``constants=``. The single
    remaining operand, ``var``, must be supplied on each call. ``res_exp`` is
    the plain ``contract`` result used as the reference.
    """
    eq = 'ij,jk,kl->li'
    shapes = (2, 3), (3, 4), (4, 5)
    non_const, = {0, 1, 2} - constants
    ops = [np.random.rand(*shp) if i in constants else shp for i, shp in enumerate(shapes)]
    var = np.random.rand(*shapes[non_const])
    res_exp = contract(eq, *(ops[i] if i in constants else var for i in range(3)))
    expr = contract_expression(eq, *ops, constants=constants)
    return expr, var, res_exp


def _all_constants_on_backend(expr, backend):
    """Return True if every cached constant of ``expr`` for ``backend`` is
    ``None`` or an array that ``infer_backend`` attributes to ``backend``."""
    return all(array is None or infer_backend(array) == backend
               for array in expr._evaluated_constants[backend])


@contextlib.contextmanager
def _tf_default_session():
    """Yield a TensorFlow session that is installed as the default session.

    The session is always closed on exit, even if the body raises.
    """
    with contextlib.closing(tf.Session(config=_TF_CONFIG)) as sess, sess.as_default():
        yield sess


@pytest.mark.skipif(not found_tensorflow, reason="Tensorflow not installed.")
@pytest.mark.parametrize("string", tests)
def test_tensorflow(string):
    views, ein, expr = _build_case(string)
    opt = np.empty_like(ein)

    with _tf_default_session():
        expr(*views, backend='tensorflow', out=opt)

    assert np.allclose(ein, opt)

    # test non-conversion mode
    tensorflow_views = [backends.to_tensorflow(view) for view in views]
    tensorflow_opt = expr(*tensorflow_views)
    assert isinstance(tensorflow_opt, tf.Tensor)


@pytest.mark.skipif(not found_tensorflow, reason="Tensorflow not installed.")
@pytest.mark.parametrize("constants", [{0, 1}, {0, 2}, {1, 2}])
def test_tensorflow_with_constants(constants):
    expr, var, res_exp = _build_constants_case(constants)

    # check tensorflow
    with _tf_default_session():
        res_got = expr(var, backend='tensorflow')
        assert _all_constants_on_backend(expr, 'tensorflow')
    assert np.allclose(res_exp, res_got)

    # check can call with numpy still
    res_got2 = expr(var, backend='numpy')
    assert np.allclose(res_exp, res_got2)

    # check tensorflow call returns tensorflow still
    res_got3 = expr(backends.to_tensorflow(var))
    assert isinstance(res_got3, tf.Tensor)


@pytest.mark.skipif(not found_tensorflow, reason="Tensorflow not installed.")
@pytest.mark.parametrize("string", tests)
def test_tensorflow_with_sharing(string):
    views, ein, expr = _build_case(string)

    with _tf_default_session(), sharing.shared_intermediates() as cache:
        tfl1 = expr(*views, backend='tensorflow')
        assert sharing.get_sharing_cache() is cache
        cache_sz = len(cache)
        assert cache_sz > 0
        # a second identical call must be served entirely from the cache
        tfl2 = expr(*views, backend='tensorflow')
        assert len(cache) == cache_sz

    assert all(isinstance(t, tf.Tensor) for t in cache.values())

    assert np.allclose(ein, tfl1)
    assert np.allclose(ein, tfl2)


@pytest.mark.skipif(not found_theano, reason="Theano not installed.")
@pytest.mark.parametrize("string", tests)
def test_theano(string):
    views, ein, expr = _build_case(string)

    opt = expr(*views, backend='theano')
    assert np.allclose(ein, opt)

    # test non-conversion mode
    theano_views = [backends.to_theano(view) for view in views]
    theano_opt = expr(*theano_views)
    assert isinstance(theano_opt, theano.tensor.TensorVariable)


@pytest.mark.skipif(not found_theano, reason="theano not installed.")
@pytest.mark.parametrize("constants", [{0, 1}, {0, 2}, {1, 2}])
def test_theano_with_constants(constants):
    expr, var, res_exp = _build_constants_case(constants)

    # check theano
    res_got = expr(var, backend='theano')
    assert _all_constants_on_backend(expr, 'theano')
    assert np.allclose(res_exp, res_got)

    # check can call with numpy still
    res_got2 = expr(var, backend='numpy')
    assert np.allclose(res_exp, res_got2)

    # check theano call returns theano still
    res_got3 = expr(backends.to_theano(var))
    assert isinstance(res_got3, theano.tensor.TensorVariable)


@pytest.mark.skipif(not found_theano, reason="Theano not installed.")
@pytest.mark.parametrize("string", tests)
def test_theano_with_sharing(string):
    views, ein, expr = _build_case(string)

    with sharing.shared_intermediates() as cache:
        thn1 = expr(*views, backend='theano')
        assert sharing.get_sharing_cache() is cache
        cache_sz = len(cache)
        assert cache_sz > 0
        # a second identical call must be served entirely from the cache
        thn2 = expr(*views, backend='theano')
        assert len(cache) == cache_sz

    assert all(isinstance(t, theano.tensor.TensorVariable) for t in cache.values())

    assert np.allclose(ein, thn1)
    assert np.allclose(ein, thn2)


@pytest.mark.skipif(not found_cupy, reason="Cupy not installed.")
@pytest.mark.parametrize("string", tests)
def test_cupy(string):  # pragma: no cover
    views, ein, expr = _build_case(string)

    opt = expr(*views, backend='cupy')
    assert np.allclose(ein, opt)

    # test non-conversion mode
    cupy_views = [backends.to_cupy(view) for view in views]
    cupy_opt = expr(*cupy_views)
    assert isinstance(cupy_opt, cupy.ndarray)
    assert np.allclose(ein, cupy.asnumpy(cupy_opt))


@pytest.mark.skipif(not found_cupy, reason="Cupy not installed.")
@pytest.mark.parametrize("constants", [{0, 1}, {0, 2}, {1, 2}])
def test_cupy_with_constants(constants):  # pragma: no cover
    expr, var, res_exp = _build_constants_case(constants)

    # check cupy
    res_got = expr(var, backend='cupy')
    # check cupy versions of constants exist
    assert _all_constants_on_backend(expr, 'cupy')
    assert np.allclose(res_exp, res_got)

    # check can call with numpy still
    res_got2 = expr(var, backend='numpy')
    assert np.allclose(res_exp, res_got2)

    # check cupy call returns cupy still
    res_got3 = expr(cupy.asarray(var))
    assert isinstance(res_got3, cupy.ndarray)
    assert np.allclose(res_exp, res_got3.get())


@pytest.mark.skipif(not found_jax, reason="jax not installed.")
@pytest.mark.parametrize("string", tests)
def test_jax(string):  # pragma: no cover
    views, ein, expr = _build_case(string)

    opt = expr(*views, backend='jax')
    assert np.allclose(ein, opt)
    assert isinstance(opt, np.ndarray)


@pytest.mark.skipif(not found_jax, reason="jax not installed.")
@pytest.mark.parametrize("constants", [{0, 1}, {0, 2}, {1, 2}])
def test_jax_with_constants(constants):  # pragma: no cover
    expr, var, res_exp = _build_constants_case(constants)

    # check jax
    res_got = expr(var, backend='jax')
    # check jax versions of constants exist
    assert _all_constants_on_backend(expr, 'jax')

    assert np.allclose(res_exp, res_got)


@pytest.mark.skipif(not found_jax, reason="jax not installed.")
def test_jax_jit_gradient():
    eq = 'ij,jk,kl->'
    shapes = (2, 3), (3, 4), (4, 2)
    views = [np.random.randn(*s) for s in shapes]
    expr = contract_expression(eq, *shapes)
    x0 = expr(*views)

    jit_expr = jax.jit(expr)
    x1 = jit_expr(*views).item()
    assert x1 == pytest.approx(x0, rel=1e-5)

    # jax only takes gradient w.r.t first argument
    grad_expr = jax.jit(jax.grad(lambda views: expr(*views)))
    view_grads = grad_expr(views)
    assert all(v1.shape == v2.shape for v1, v2 in zip(views, view_grads))

    # taking a step along the gradient should reduce our 'loss'
    new_views = [v - 0.001 * dv for v, dv in zip(views, view_grads)]
    x2 = jit_expr(*new_views).item()
    assert x2 < x1


@pytest.mark.skipif(not found_autograd, reason="autograd not installed.")
def test_autograd_gradient():
    eq = 'ij,jk,kl->'
    shapes = (2, 3), (3, 4), (4, 2)
    views = [np.random.randn(*s) for s in shapes]
    expr = contract_expression(eq, *shapes)
    x0 = expr(*views)

    # autograd only takes gradient w.r.t first argument
    grad_expr = autograd.grad(lambda views: expr(*views))
    view_grads = grad_expr(views)
    assert all(v1.shape == v2.shape for v1, v2 in zip(views, view_grads))

    # taking a step along the gradient should reduce our 'loss'
    new_views = [v - 0.001 * dv for v, dv in zip(views, view_grads)]
    x1 = expr(*new_views)
    assert x1 < x0


@pytest.mark.parametrize("string", tests)
def test_dask(string):
    da = pytest.importorskip("dask.array")

    views, ein, expr = _build_case(string)

    # test non-conversion mode
    da_views = [da.from_array(x, chunks=2) for x in views]
    da_opt = expr(*da_views)

    # check type is maintained when not using numpy arrays
    assert isinstance(da_opt, da.Array)

    assert np.allclose(ein, np.array(da_opt))

    # try raw contract
    da_opt = contract(string, *da_views)
    assert isinstance(da_opt, da.Array)
    assert np.allclose(ein, np.array(da_opt))


@pytest.mark.parametrize("string", tests)
def test_sparse(string):
    sparse = pytest.importorskip("sparse")

    views = helpers.build_views(string)

    # sparsify views so they don't become dense during contraction. A single
    # seeded generator gives each operand its own reproducible zero pattern
    # (re-seeding per view made same-shaped operands share one mask) and
    # leaves numpy's global random state untouched.
    rng = np.random.RandomState(42)
    for view in views:
        mask = rng.choice([False, True], size=view.shape, replace=True, p=[0.05, 0.95])
        view[mask] = 0

    ein = contract(string, *views, optimize=False, use_blas=False)
    shps = [v.shape for v in views]
    expr = contract_expression(string, *shps, optimize=True)

    # test non-conversion mode
    sparse_views = [sparse.COO.from_numpy(x) for x in views]
    sparse_opt = expr(*sparse_views)

    # check type is maintained when not using numpy arrays
    assert isinstance(sparse_opt, sparse.COO)

    assert np.allclose(ein, sparse_opt.todense())

    # try raw contract
    sparse_opt = contract(string, *sparse_views)
    assert isinstance(sparse_opt, sparse.COO)
    assert np.allclose(ein, sparse_opt.todense())


@pytest.mark.skipif(not found_torch, reason="Torch not installed.")
@pytest.mark.parametrize("string", tests)
def test_torch(string):
    views, ein, expr = _build_case(string)

    opt = expr(*views, backend='torch')
    assert np.allclose(ein, opt)

    # test non-conversion mode
    torch_views = [backends.to_torch(view) for view in views]
    torch_opt = expr(*torch_views)
    assert isinstance(torch_opt, torch.Tensor)
    assert np.allclose(ein, torch_opt.cpu().numpy())


@pytest.mark.skipif(not found_torch, reason="Torch not installed.")
@pytest.mark.parametrize("constants", [{0, 1}, {0, 2}, {1, 2}])
def test_torch_with_constants(constants):
    expr, var, res_exp = _build_constants_case(constants)

    # check torch
    res_got = expr(var, backend='torch')
    assert _all_constants_on_backend(expr, 'torch')
    assert np.allclose(res_exp, res_got)

    # check can call with numpy still
    res_got2 = expr(var, backend='numpy')
    assert np.allclose(res_exp, res_got2)

    # check torch call returns torch still
    res_got3 = expr(backends.to_torch(var))
    assert isinstance(res_got3, torch.Tensor)
    # .cpu() returns the tensor itself when it already lives on the CPU
    assert np.allclose(res_exp, res_got3.cpu().numpy())


def test_auto_backend_custom_array_no_tensordot():
    x = Shaped((1, 2, 3))
    # Shaped is an array-like object defined by opt_einsum - which has no TDOT
    assert infer_backend(x) == 'opt_einsum'
    assert parse_backend([x], 'auto') == 'numpy'


@pytest.mark.parametrize("string", tests)
def test_object_arrays_backend(string):
    views, ein, expr = _build_case(string)
    assert ein.dtype != object

    obj_views = [view.astype(object) for view in views]

    # try raw contract
    obj_opt = contract(string, *obj_views, backend='object')
    assert obj_opt.dtype == object
    assert np.allclose(ein, obj_opt.astype(float))

    # test expression
    obj_opt = expr(*obj_views, backend='object')
    assert obj_opt.dtype == object
    assert np.allclose(ein, obj_opt.astype(float))