"""
Tests the input parsing for opt_einsum. Duplicates the np.einsum input tests.
"""

import numpy as np
import pytest

from opt_einsum import contract, contract_path


def build_views(string):
    """Build random operands whose shapes match the input terms of ``string``.

    Each index character ``a``-``j`` has a fixed dimension size, so an index
    shared between terms always receives the same size. An ellipsis (``...``)
    is expanded to the indices ``ij``, i.e. two extra dimensions of sizes 5
    and 4.

    Parameters
    ----------
    string : str
        An einsum subscript string; only the part before ``->`` is used.

    Returns
    -------
    list
        One ``np.random.rand`` array per comma-separated input term.
    """
    chars = 'abcdefghij'
    sizes = dict(zip(chars, [2, 3, 4, 5, 4, 3, 2, 6, 5, 4]))

    # Stand in concrete indices for the ellipsis so shapes can be looked up.
    string = string.replace('...', 'ij')

    terms = string.split('->')[0].split(',')
    return [np.random.rand(*[sizes[x] for x in term]) for term in terms]


def test_type_errors():
    """Malformed arguments to ``contract`` are rejected with TypeError/ValueError."""
    # subscripts must be a string
    with pytest.raises(TypeError):
        contract(0, 0)

    # out parameter must be an array
    with pytest.raises(TypeError):
        contract("", 0, out='test')

    # order parameter must be a valid order
    # changed in Numpy 1.19, see https://github.com/numpy/numpy/commit/35b0a051c19265f5643f6011ee11e31d30c8bc4c
    with pytest.raises((TypeError, ValueError)):
        contract("", 0, order='W')

    # casting parameter must be a valid casting
    with pytest.raises(ValueError):
        contract("", 0, casting='blah')

    # dtype parameter must be a valid dtype
    with pytest.raises(TypeError):
        contract("", 0, dtype='bad_data_type')

    # other keyword arguments are rejected
    with pytest.raises(TypeError):
        contract("", 0, bad_arg=0)

    # issue 4528 revealed a segfault with this call
    with pytest.raises(TypeError):
        contract(*(None, ) * 63)

    # Cannot have two ->
    with pytest.raises(ValueError):
        contract("->,->", 0, 5)

    # Undefined symbol lhs
    with pytest.raises(ValueError):
        contract("&,a->", 0, 5)

    # Undefined symbol rhs
    with pytest.raises(ValueError):
        contract("a,a->&", 0, 5)

    # Operand for the interleaved (operand, subscript-list) input format,
    # whose subscript lists contain Ellipsis.
    string = '...a->...a'
    views = build_views(string)

    # Subscript list must contain Ellipsis or (hashable && comparable) object
    with pytest.raises(TypeError):
        contract(views[0], [Ellipsis, 0], [Ellipsis, ['a']])

    with pytest.raises(TypeError):
        contract(views[0], [Ellipsis, dict()], [Ellipsis, 'a'])


def test_value_errors():
    """Malformed subscript strings or mismatched operands raise ValueError."""
    # empty subscripts with no operands
    with pytest.raises(ValueError):
        contract("")

    # invalid subscript character
    with pytest.raises(ValueError):
        contract("i%...", [0, 0])
    with pytest.raises(ValueError):
        contract("...j$", [0, 0])
    with pytest.raises(ValueError):
        contract("i->&", [0, 0])

    # number of operands must match count in subscripts string
    with pytest.raises(ValueError):
        contract("", 0, 0)
    with pytest.raises(ValueError):
        contract(",", 0, [0], [0])
    with pytest.raises(ValueError):
        contract(",", [0])

    # can't have more subscripts than dimensions in the operand
    with pytest.raises(ValueError):
        contract("i", 0)
    with pytest.raises(ValueError):
        contract("ij", [0, 0])
    with pytest.raises(ValueError):
        contract("...i", 0)
    with pytest.raises(ValueError):
        contract("i...j", [0, 0])
    with pytest.raises(ValueError):
        contract("i...", 0)
    with pytest.raises(ValueError):
        contract("ij...", [0, 0])

    # invalid ellipsis
    with pytest.raises(ValueError):
        contract("i..", [0, 0])
    with pytest.raises(ValueError):
        contract(".i...", [0, 0])
    with pytest.raises(ValueError):
        contract("j->..j", [0, 0])
    with pytest.raises(ValueError):
        contract("j->.j...", [0, 0])

    # output subscripts must appear in input
    with pytest.raises(ValueError):
        contract("i->ij", [0, 0])

    # output subscripts may only be specified once
    with pytest.raises(ValueError):
        contract("ij->jij", [[0, 0], [0, 0]])

    # dimensions must match when being collapsed
    with pytest.raises(ValueError):
        contract("ii", np.arange(6).reshape(2, 3))
    with pytest.raises(ValueError):
        contract("ii->i", np.arange(6).reshape(2, 3))

    # broadcasting to new dimensions must be enabled explicitly
    with pytest.raises(ValueError):
        contract("i", np.arange(6).reshape(2, 3))
    with pytest.raises(ValueError):
        contract("i->i", [[0, 1], [0, 1]], out=np.arange(4).reshape(2, 2))


def test_contract_inputs():
    """``contract_path`` rejects unknown keywords and a negative memory limit."""
    with pytest.raises(TypeError):
        contract_path("i->i", [[0, 1], [0, 1]], bad_kwarg=True)

    with pytest.raises(ValueError):
        contract_path("i->i", [[0, 1], [0, 1]], memory_limit=-1)


@pytest.mark.parametrize(
    "string",
    [
        # Ellipse
        '...a->...',
        'a...->...',
        'a...a->...a',
        '...,...',
        'a,b',
        '...a,...b',
    ])
def test_compare(string):
    """Unoptimized, default and 'optimal' contractions give the same result."""
    views = build_views(string)

    ein = contract(string, *views, optimize=False)
    opt = contract(string, *views)
    assert np.allclose(ein, opt)

    opt = contract(string, *views, optimize='optimal')
    assert np.allclose(ein, opt)


def test_ellipse_input1():
    """Interleaved input with an explicit Ellipsis-only output matches '...a->...'."""
    string = '...a->...'
    views = build_views(string)

    ein = contract(string, *views, optimize=False)
    opt = contract(views[0], [Ellipsis, 0], [Ellipsis])
    assert np.allclose(ein, opt)


def test_ellipse_input2():
    """Interleaved input with an implicit output matches '...a'."""
    string = '...a'
    views = build_views(string)

    ein = contract(string, *views, optimize=False)
    opt = contract(views[0], [Ellipsis, 0])
    assert np.allclose(ein, opt)


def test_ellipse_input3():
    """Interleaved input keeping the labelled axis matches '...a->...a'."""
    string = '...a->...a'
    views = build_views(string)

    ein = contract(string, *views, optimize=False)
    opt = contract(views[0], [Ellipsis, 0], [Ellipsis, 0])
    assert np.allclose(ein, opt)


def test_ellipse_input4():
    """Two interleaved operands with an Ellipsis output match '...b,...a->...'."""
    string = '...b,...a->...'
    views = build_views(string)

    ein = contract(string, *views, optimize=False)
    opt = contract(views[0], [Ellipsis, 1], views[1], [Ellipsis, 0], [Ellipsis])
    assert np.allclose(ein, opt)


def test_singleton_dimension_broadcast():
    """Size-1 dimensions broadcast against larger ones (gh-10343)."""
    p = np.ones((10, 2))
    q = np.ones((1, 2))

    ein = contract('ij,ij->j', p, q, optimize=False)
    opt = contract('ij,ij->j', p, q, optimize=True)
    assert np.allclose(ein, opt)
    assert np.allclose(opt, [10., 10.])

    p = np.ones((1, 5))
    q = np.ones((5, 5))

    for optimize in (True, False):
        # No trailing comma: res1 must be the array itself, not a 1-tuple,
        # otherwise the comparison below only passes via accidental broadcasting.
        res1 = contract("...ij,...jk->...ik", p, p, optimize=optimize)
        res2 = contract("...ij,...jk->...ik", p, q, optimize=optimize)
        assert np.allclose(res1, res2)
        assert np.allclose(res2, np.full((1, 5), 5))


def test_large_int_input_format():
    """Arbitrary (including large) integer labels work in the interleaved format."""
    string = 'ab,bc,cd'
    x, y, z = build_views(string)
    string_output = contract(string, x, y, z)
    int_output = contract(x, (1000, 1001), y, (1001, 1002), z, (1002, 1003))
    assert np.allclose(string_output, int_output)
    # With no explicit output, the test expects labels (i + 1, i) to yield x.T.
    for i in range(10):
        transpose_output = contract(x, (i + 1, i))
        assert np.allclose(transpose_output, x.T)


def test_hashable_object_input_format():
    """Arbitrary hashable labels (here strings) work in the interleaved format."""
    string = 'ab,bc,cd'
    x, y, z = build_views(string)
    string_output = contract(string, x, y, z)
    hash_output1 = contract(x, ('left', 'bond1'), y, ('bond1', 'bond2'), z, ('bond2', 'right'))
    hash_output2 = contract(x, ('left', 'bond1'), y, ('bond1', 'bond2'), z, ('bond2', 'right'), ('left', 'right'))
    assert np.allclose(string_output, hash_output1)
    assert np.allclose(hash_output1, hash_output2)
    # With no explicit output, the test expects labels ('b'*i, 'a'*i) to yield x.T.
    for i in range(1, 10):
        transpose_output = contract(x, ('b' * i, 'a' * i))
        assert np.allclose(transpose_output, x.T)