"""Tests for ``dask.delayed``: graph construction, keys and purity, operators,
pickling, optimization, and interaction with other dask collections."""

import pickle
import types
from collections import namedtuple
from functools import partial
from operator import add, setitem
from random import random

import cloudpickle
import pytest
from tlz import merge

import dask
import dask.bag as db
from dask import compute
from dask.delayed import Delayed, delayed, to_task_dask
from dask.highlevelgraph import HighLevelGraph
from dask.utils_test import inc

try:
    from operator import matmul
except ImportError:
    matmul = None


class Tuple:
    """Minimal custom dask collection whose computed result is a tuple.

    Wraps a raw graph ``dsk`` and a list of output ``keys``. On compute, the
    values for ``keys`` are produced by the threaded scheduler and handed to
    ``tuple`` (see ``__dask_postcompute__``).
    """

    __dask_scheduler__ = staticmethod(dask.threaded.get)

    def __init__(self, dsk, keys):
        self._dask = dsk
        self._keys = keys

    def __dask_tokenize__(self):
        # Token is derived from the output keys only.
        return self._keys

    def __dask_graph__(self):
        return self._dask

    def __dask_keys__(self):
        return self._keys

    def __dask_postcompute__(self):
        # (finalize function, extra args): the computed values are wrapped
        # into a tuple.
        return tuple, ()


@pytest.mark.filterwarnings("ignore:The dask.delayed:UserWarning")
def test_to_task_dask():
    """``to_task_dask`` turns (possibly nested) values into a task plus the
    merged graph of every dask object they reference."""
    a = delayed(1, name="a")
    b = delayed(2, name="b")
    # ``dsk`` (not ``dask``) so the imported ``dask`` module is not shadowed.
    task, dsk = to_task_dask([a, b, 3])
    assert task == ["a", "b", 3]

    task, dsk = to_task_dask((a, b, 3))
    assert task == (tuple, ["a", "b", 3])
    assert dict(dsk) == merge(a.dask, b.dask)

    task, dsk = to_task_dask({a: 1, b: 2})
    assert task == (dict, [["b", 2], ["a", 1]]) or task == (dict, [["a", 1], ["b", 2]])
    assert dict(dsk) == merge(a.dask, b.dask)

    f = namedtuple("f", ["x", "y"])
    x = f(1, 2)
    task, dsk = to_task_dask(x)
    assert task == x
    assert dict(dsk) == {}

    task, dsk = to_task_dask(slice(a, b, 3))
    assert task == (slice, "a", "b", 3)
    assert dict(dsk) == merge(a.dask, b.dask)

    # Issue https://github.com/dask/dask/issues/2107
    class MyClass(dict):
        pass

    task, dsk = to_task_dask(MyClass())
    assert type(task) is MyClass
    assert dict(dsk) == {}

    # Custom dask objects
    x = Tuple({"a": 1, "b": 2, "c": (add, "a", "b")}, ["a", "b", "c"])
    task, dsk = to_task_dask(x)
    assert task in dsk
    finalize_task = dsk.pop(task)
    assert finalize_task == (tuple, ["a", "b", "c"])
    assert dsk == x._dask


def test_delayed():
    """Delayed functions build lazy calls that compute correctly, and leaf
    values are stored in the graph."""
    add2 = delayed(add)
    assert add2(1, 2).compute() == 3
    assert (add2(1, 2) + 3).compute() == 6
    assert add2(add2(1, 2), 3).compute() == 6

    a = delayed(1)
    assert a.compute() == 1
    assert 1 in a.dask.values()
    b = add2(add2(a, 2), 3)
    assert a.key in b.dask


def test_delayed_with_dataclass():
    """Delayed objects stored in dataclass fields are replaced by their
    computed values before the consuming function runs."""
    dataclasses = pytest.importorskip("dataclasses")

    # Avoid @dataclass decorator as Python < 3.7 fail to interpret the type hints
    ADataClass = dataclasses.make_dataclass(
        "ADataClass", [("a", int), ("b", int, dataclasses.field(init=False))]
    )

    literal = dask.delayed(3)
    with_class = dask.delayed({"a": ADataClass(a=literal)})

    def return_nested(obj):
        return obj["a"].a

    final = delayed(return_nested)(with_class)

    assert final.compute() == 3


def test_operators():
    """Python operators applied to Delayed objects build lazy tasks."""
    a = delayed([1, 2, 3])
    assert a[0].compute() == 1
    assert (a + a).compute() == [1, 2, 3, 1, 2, 3]
    b = delayed(2)
    assert a[:b].compute() == [1, 2]

    a = delayed(10)
    assert (a + 1).compute() == 11
    assert (1 + a).compute() == 11
    assert (a >> 1).compute() == 5
    assert (a > 2).compute()
    assert (a ** 2).compute() == 100

    if matmul:

        class dummy:
            def __matmul__(self, other):
                return 4

        c = delayed(dummy())
        d = delayed(dummy())

        # The ``@`` operator is plain syntax on every supported Python, so no
        # ``eval`` indirection is needed.
        assert (c @ d).compute() == 4


def test_methods():
    """Method calls on a Delayed are lazy; ``pure`` and ``dask_key_name``
    control the resulting key."""
    a = delayed("a b c d e")
    assert a.split(" ").compute() == ["a", "b", "c", "d", "e"]
    assert a.upper().replace("B", "A").split().count("A").compute() == 2
    assert a.split(" ", pure=True).key == a.split(" ", pure=True).key
    o = a.split(" ", dask_key_name="test")
    assert o.key == "test"


def test_attributes():
    """Attribute access on a Delayed is lazy and yields a deterministic key."""
    a = delayed(2 + 1j)
    assert a.real._key == a.real._key
    assert a.real.compute() == 2
    assert a.imag.compute() == 1
    assert (a.real + a.imag).compute() == 3


def test_method_getattr_call_same_task():
    a = delayed([1, 2, 3])
    o = a.index(1)
    # Don't getattr the method, then call in separate task
    assert getattr not in {v[0] for v in o.__dask_graph__().values()}


def test_np_dtype_of_delayed():
    # This used to result in a segfault due to recursion, see
    # https://github.com/dask/dask/pull/4374#issuecomment-454381465
    np = pytest.importorskip("numpy")
    x = delayed(1)
    with pytest.raises(TypeError):
        np.dtype(x)
    assert delayed(np.array([1], dtype="f8")).dtype.compute() == np.dtype("f8")


def test_delayed_visualise_warn():
    # Raise a warning when user calls visualise()
    # instead of visualize()
    # Uses the module-level ``inc`` from dask.utils_test rather than
    # redefining (and shadowing) it locally.
    z = dask.delayed(inc)(1)
    z.compute()

    with pytest.warns(
        UserWarning, match="dask.delayed objects have no `visualise` method"
    ):
        z.visualise(file_name="desk_graph.svg")

    # with no args
    with pytest.warns(
        UserWarning, match="dask.delayed objects have no `visualise` method"
    ):
        z.visualise()


def test_delayed_errors():
    """Delayed objects are immutable and refuse operations that would need a
    concrete value."""
    a = delayed([1, 2, 3])
    # Immutable
    pytest.raises(TypeError, lambda: setattr(a, "foo", 1))
    pytest.raises(TypeError, lambda: setitem(a, 1, 0))
    # Can't iterate, or check if contains
    pytest.raises(TypeError, lambda: 1 in a)
    pytest.raises(TypeError, lambda: list(a))
    # No dynamic generation of magic/hidden methods
    pytest.raises(AttributeError, lambda: a._hidden())
    # Truth of delayed forbidden
    pytest.raises(TypeError, lambda: bool(a))


def test_common_subexpressions():
    """Identical sub-expressions share one key, so the graph has only three
    tasks: ``a``, ``a[0]`` and the sum."""
    a = delayed([1, 2, 3])
    res = a[0] + a[0]
    assert a[0].key in res.dask
    assert a.key in res.dask
    assert len(res.dask) == 3


def test_delayed_optimize():
    x = Delayed("b", {"a": 1, "b": (inc, "a"), "c": (inc, "b")})
    (x2,) = dask.optimize(x)
    # Delayed's __dask_optimize__ culls out 'c'
    assert sorted(x2.dask.keys()) == ["a", "b"]


def test_lists():
    """Delayed objects inside a list argument are computed before the call."""
    a = delayed(1)
    b = delayed(2)
    c = delayed(sum)([a, b])
    assert c.compute() == 3


def test_literates():
    """Delayed objects nested in tuple/list/set/dict literals are computed in
    place when the literal itself is wrapped."""
    a = delayed(1)
    b = a + 1
    lit = (a, b, 3)
    assert delayed(lit).compute() == (1, 2, 3)
    lit = [a, b, 3]
    assert delayed(lit).compute() == [1, 2, 3]
    lit = {a, b, 3}
    assert delayed(lit).compute() == {1, 2, 3}
    lit = {a: "a", b: "b", 3: "c"}
    assert delayed(lit).compute() == {1: "a", 2: "b", 3: "c"}
    assert delayed(lit)[a].compute() == "a"
    lit = {"a": a, "b": b, "c": 3}
    assert delayed(lit).compute() == {"a": 1, "b": 2, "c": 3}
    assert delayed(lit)["a"].compute() == 1


def test_literates_keys():
    """Wrapped literals get distinct keys by default and equal keys with
    ``pure=True``."""
    a = delayed(1)
    b = a + 1
    lit = (a, b, 3)
    assert delayed(lit).key != delayed(lit).key
    assert delayed(lit, pure=True).key == delayed(lit, pure=True).key


def test_lists_are_concrete():
    """Nested lists reach the function as real lists, so ``key=`` can index
    into them."""
    a = delayed(1)
    b = delayed(2)
    c = delayed(max)([[a, 10], [b, 20]], key=lambda x: x[0])[1]

    assert c.compute() == 20


def test_iterators():
    """Iterators over Delayed objects are accepted as call arguments."""
    a = delayed(1)
    b = delayed(2)
    c = delayed(sum)(iter([a, b]))

    assert c.compute() == 3

    def f(seq):
        return sum(seq)

    c = delayed(f)(iter([a, b]))
    assert c.compute() == 3


def test_traverse_false():
    # Create a list with a dask value, and test that it's not computed
    def fail(*args):
        raise ValueError("shouldn't have computed")

    a = delayed(fail)()

    # list
    x = [a, 1, 2, 3]
    res = delayed(x, traverse=False).compute()
    assert len(res) == 4
    assert res[0] is a
    assert res[1:] == x[1:]

    # tuple that looks like a task
    x = (fail, a, (fail, a))
    res = delayed(x, traverse=False).compute()
    assert isinstance(res, tuple)
    assert res[0] == fail
    assert res[1] is a

    # list containing task-like-things
    x = [1, (fail, a), a]
    res = delayed(x, traverse=False).compute()
    assert isinstance(res, list)
    assert res[0] == 1
    assert res[1][0] == fail and res[1][1] is a
    assert res[2] is a

    # traverse=False still hits top level
    b = delayed(1)
    x = delayed(b, traverse=False)
    assert x.compute() == 1


def test_pure():
    """``pure=True`` gives equal keys for equal calls; by default each call
    gets its own key."""
    v1 = delayed(add, pure=True)(1, 2)
    v2 = delayed(add, pure=True)(1, 2)
    assert v1.key == v2.key

    myrand = delayed(random)
    assert myrand().key != myrand().key


def test_pure_global_setting():
    """The ``delayed_pure`` config option sets the default purity; an explicit
    ``pure=`` argument overrides it."""
    # delayed functions
    func = delayed(add)

    with dask.config.set(delayed_pure=True):
        assert func(1, 2).key == func(1, 2).key

    with dask.config.set(delayed_pure=False):
        assert func(1, 2).key != func(1, 2).key

    func = delayed(add, pure=True)
    with dask.config.set(delayed_pure=False):
        assert func(1, 2).key == func(1, 2).key

    # delayed objects
    assert delayed(1).key != delayed(1).key
    with dask.config.set(delayed_pure=True):
        assert delayed(1).key == delayed(1).key

    with dask.config.set(delayed_pure=False):
        assert delayed(1, pure=True).key == delayed(1, pure=True).key

    # delayed methods
    data = delayed([1, 2, 3])
    assert data.index(1).key != data.index(1).key

    with dask.config.set(delayed_pure=True):
        assert data.index(1).key == data.index(1).key
        assert data.index(1, pure=False).key != data.index(1, pure=False).key

    with dask.config.set(delayed_pure=False):
        assert data.index(1, pure=True).key == data.index(1, pure=True).key

    # magic methods always pure
    with dask.config.set(delayed_pure=False):
        assert data.index.key == data.index.key
        element = data[0]
        assert (element + element).key == (element + element).key


def test_nout():
    """``nout`` lets a call's result be unpacked into that many Delayed
    objects; invalid ``nout`` values raise ``ValueError``."""
    func = delayed(lambda x: (x, -x), nout=2, pure=True)
    x = func(1)
    assert len(x) == 2
    a, b = x
    assert compute(a, b) == (1, -1)
    assert a._length is None
    assert b._length is None
    pytest.raises(TypeError, lambda: len(a))
    pytest.raises(TypeError, lambda: list(a))

    pytest.raises(ValueError, lambda: delayed(add, nout=-1))
    pytest.raises(ValueError, lambda: delayed(add, nout=True))

    func = delayed(add, nout=None)
    a = func(1)
    assert a._length is None
    pytest.raises(TypeError, lambda: list(a))
    pytest.raises(TypeError, lambda: len(a))

    func = delayed(lambda x: (x,), nout=1, pure=True)
    x = func(1)
    assert len(x) == 1
    (a,) = x
    assert a.compute() == 1
    assert a._length is None
    pytest.raises(TypeError, lambda: len(a))

    func = delayed(lambda x: tuple(), nout=0, pure=True)
    x = func(1)
    assert len(x) == 0
    assert x.compute() == tuple()


@pytest.mark.parametrize(
    "x",
    [[1, 2], (1, 2), (add, 1, 2), [], ()],
)
def test_nout_with_tasks(x):
    """``nout`` on a wrapped literal (including task-like tuples) allows
    unpacking and computes back to the literal."""
    length = len(x)
    d = delayed(x, nout=length)
    assert len(d) == len(list(d)) == length
    assert d.compute() == x


def test_kwargs():
    """Keyword arguments may contain Delayed objects and contribute to pure
    keys."""

    def mysum(a, b, c=(), **kwargs):
        return a + b + sum(c) + sum(kwargs.values())

    dmysum = delayed(mysum)
    ten = dmysum(1, 2, c=[delayed(3), 0], four=dmysum(2, 2))
    assert ten.compute() == 10
    dmysum = delayed(mysum, pure=True)
    c = [delayed(3), 0]
    ten = dmysum(1, 2, c=c, four=dmysum(2, 2))
    assert ten.compute() == 10
    assert dmysum(1, 2, c=c, four=dmysum(2, 2)).key == ten.key
    assert dmysum(1, 2, c=c, four=dmysum(2, 3)).key != ten.key
    assert dmysum(1, 2, c=c, four=4).key != ten.key
    assert dmysum(1, 2, c=c, four=4).key != dmysum(2, 2, c=c, four=4).key


def test_custom_delayed():
    """Custom dask collections can be used as arguments to delayed calls."""
    x = Tuple({"a": 1, "b": 2, "c": (add, "a", "b")}, ["a", "b", "c"])
    x2 = delayed(add, pure=True)(x, (4, 5, 6))
    n = delayed(len, pure=True)(x)
    assert delayed(len, pure=True)(x).key == n.key
    assert x2.compute() == (1, 2, 3, 4, 5, 6)
    assert compute(n, x2, x) == (3, (1, 2, 3, 4, 5, 6), (1, 2, 3))


@pytest.mark.filterwarnings("ignore:The dask.delayed:UserWarning")
def test_array_delayed():
    """dask arrays can be passed to delayed calls and wrapped by delayed."""
    np = pytest.importorskip("numpy")
    da = pytest.importorskip("dask.array")

    arr = np.arange(100).reshape((10, 10))
    darr = da.from_array(arr, chunks=(5, 5))
    val = delayed(sum)([arr, darr, 1])
    assert isinstance(val, Delayed)
    assert np.allclose(val.compute(), arr + arr + 1)
    assert val.sum().compute() == (arr + arr + 1).sum()
    assert val[0, 0].compute() == (arr + arr + 1)[0, 0]

    task, dsk = to_task_dask(darr)
    assert not darr.dask.keys() - dsk.keys()
    diff = dsk.keys() - darr.dask.keys()
    assert len(diff) == 1

    delayed_arr = delayed(darr)
    assert (delayed_arr.compute() == arr).all()


def test_array_bag_delayed():
    """numpy arrays, dask arrays and a dask bag can be mixed in one delayed
    call."""
    da = pytest.importorskip("dask.array")
    np = pytest.importorskip("numpy")

    arr1 = np.arange(100).reshape((10, 10))
    arr2 = arr1.dot(arr1.T)
    darr1 = da.from_array(arr1, chunks=(5, 5))
    darr2 = da.from_array(arr2, chunks=(5, 5))
    b = db.from_sequence([1, 2, 3])
    seq = [arr1, arr2, darr1, darr2, b]
    out = delayed(sum)([i.sum() for i in seq])
    assert out.compute() == 2 * arr1.sum() + 2 * arr2.sum() + sum([1, 2, 3])


def test_delayed_picklable():
    """Delayed, DelayedLeaf and DelayedAttr survive a pickle round trip."""
    # Delayed
    x = delayed(divmod, nout=2, pure=True)(1, 2)
    y = pickle.loads(pickle.dumps(x))
    assert x.dask == y.dask
    assert x._key == y._key
    assert x._length == y._length
    # DelayedLeaf
    x = delayed(1j + 2)
    y = pickle.loads(pickle.dumps(x))
    assert x.dask == y.dask
    assert x._key == y._key
    assert x._nout == y._nout
    assert x._pure == y._pure
    # DelayedAttr
    x = x.real
    y = pickle.loads(pickle.dumps(x))
    assert x._obj._key == y._obj._key
    assert x._obj.dask == y._obj.dask
    assert x._attr == y._attr
    assert x._key == y._key


def test_delayed_compute_forward_kwargs():
    """Extra keyword arguments to ``compute`` are accepted (passed on rather
    than rejected) and do not change the result."""
    x = delayed(1) + 2
    # Previously this only checked that no exception was raised; also pin
    # the computed value.
    assert x.compute(bogus_keyword=10) == 3


def test_delayed_method_descriptor():
    delayed(bytes.decode)(b"")  # does not err


def test_delayed_callable():
    """A delayed call is a single task; the wrapped function is itself a
    leaf in its own graph."""
    f = delayed(add, pure=True)
    v = f(1, 2)
    assert v.dask == {v.key: (add, 1, 2)}

    assert f.dask == {f.key: add}
    assert f.compute() == add


def test_delayed_name_on_call():
    """``dask_key_name`` sets the key of a delayed call."""
    f = delayed(add, pure=True)
    assert f(1, 2, dask_key_name="foo")._key == "foo"


def test_callable_obj():
    """A wrapped callable object computes to itself, while attribute access
    and calls on it stay lazy."""

    class Foo:
        def __init__(self, a):
            self.a = a

        def __call__(self):
            return 2

    foo = Foo(1)
    f = delayed(foo)
    assert f.compute() is foo
    assert f.a.compute() == 1
    assert f().compute() == 2


def identity(x):
    """Return ``x`` unchanged."""
    return x


def test_name_consistent_across_instances():
    """Pure keys are reproducible: they match hard-coded expected tokens."""
    func = delayed(identity, pure=True)

    data = {"x": 1, "y": 25, "z": [1, 2, 3]}
    assert func(data)._key == "identity-02129ed1acaffa7039deee80c5da547c"

    data = {"x": 1, 1: "x"}
    assert func(data)._key == func(data)._key
    assert func(1)._key == "identity-ca2fae46a3b938016331acac1908ae45"


def test_sensitive_to_partials():
    """Partials with different bound arguments produce different pure keys."""
    assert (
        delayed(partial(add, 10), pure=True)(2)._key
        != delayed(partial(add, 20), pure=True)(2)._key
    )


def test_delayed_name():
    """Key prefixes default to the type/function name; ``name=`` overrides."""
    assert delayed(1)._key.startswith("int-")
    assert delayed(1, pure=True)._key.startswith("int-")
    assert delayed(1, name="X")._key == "X"

    def myfunc(x):
        return x + 1

    assert delayed(myfunc)(1).key.startswith("myfunc")


def test_finalize_name():
    """Wrapping a dask array in delayed keeps its graph and uses purely
    alphabetic key-name prefixes."""
    da = pytest.importorskip("dask.array")

    x = da.ones(10, chunks=5)
    v = delayed([x])
    assert set(x.dask).issubset(v.dask)

    def key(s):
        if isinstance(s, tuple):
            s = s[0]
        # Ignore _ in 'ones_like'
        return s.split("-")[0].replace("_", "")

    assert all(key(k).isalpha() for k in v.dask)


def test_keys_from_array():
    """Graphs of delayed calls on ``to_delayed`` chunks pass ``_check_dsk``."""
    da = pytest.importorskip("dask.array")
    from dask.array.utils import _check_dsk

    X = da.ones((10, 10), chunks=5).to_delayed().flatten()
    xs = [delayed(inc)(x) for x in X]

    _check_dsk(xs[0].dask)


# Mostly copied from https://github.com/pytoolz/toolz/pull/220
def test_delayed_decorator_on_method():
    class A:
        BASE = 10

        def __init__(self, base):
            self.BASE = base

        @delayed
        def addmethod(self, x, y):
            return self.BASE + x + y

        @classmethod
        @delayed
        def addclass(cls, x, y):
            return cls.BASE + x + y

        @staticmethod
        @delayed
        def addstatic(x, y):
            return x + y

    a = A(100)
    assert a.addmethod(3, 4).compute() == 107
    assert A.addmethod(a, 3, 4).compute() == 107

    assert a.addclass(3, 4).compute() == 17
    assert A.addclass(3, 4).compute() == 17

    assert a.addstatic(3, 4).compute() == 7
    assert A.addstatic(3, 4).compute() == 7

    # We want the decorated methods to be actual methods for instance methods
    # and class methods since their first arguments are the object and the
    # class respectively. Or in other words, the first argument is generated by
    # the runtime based on the object/class before the dot.
    assert isinstance(a.addmethod, types.MethodType)
    assert isinstance(A.addclass, types.MethodType)

    # For static methods (and regular functions), the decorated methods should
    # be Delayed objects.
    assert isinstance(A.addstatic, Delayed)


def test_attribute_of_attribute():
    """Chained attribute access stays Delayed at every level."""
    x = delayed(123)
    assert isinstance(x.a, Delayed)
    assert isinstance(x.a.b, Delayed)
    assert isinstance(x.a.b.c, Delayed)


def test_check_meta_flag():
    """``from_delayed(..., verify_meta=False)`` accepts parts whose
    categorical values differ."""
    dd = pytest.importorskip("dask.dataframe")
    from pandas import Series

    a = Series(["a", "b", "a"], dtype="category")
    b = Series(["a", "c", "a"], dtype="category")
    # Names chosen so the module-level ``db`` (dask.bag) import is not
    # shadowed.
    delayed_a = delayed(lambda x: x)(a)
    delayed_b = delayed(lambda x: x)(b)

    c = dd.from_delayed([delayed_a, delayed_b], verify_meta=False)
    dd.utils.assert_eq(c, c)


# Module-level functions used by the (cloud)pickle tests below.
def modlevel_eager(x):
    return x + 1


@delayed
def modlevel_delayed1(x):
    return x + 1


@delayed(pure=False)
def modlevel_delayed2(x):
    return x + 1


@pytest.mark.parametrize(
    "f",
    [
        delayed(modlevel_eager),
        pytest.param(modlevel_delayed1, marks=pytest.mark.xfail(reason="#3369")),
        pytest.param(modlevel_delayed2, marks=pytest.mark.xfail(reason="#3369")),
    ],
)
def test_pickle(f):
    """Delayed calls round-trip through stdlib pickle (decorator forms are
    expected to fail, see #3369)."""
    d = f(2)
    d = pickle.loads(pickle.dumps(d, protocol=pickle.HIGHEST_PROTOCOL))
    assert d.compute() == 3


@pytest.mark.parametrize(
    "f", [delayed(modlevel_eager), modlevel_delayed1, modlevel_delayed2]
)
def test_cloudpickle(f):
    """Delayed calls, including decorator forms, round-trip through
    cloudpickle."""
    d = f(2)
    d = cloudpickle.loads(cloudpickle.dumps(d, protocol=pickle.HIGHEST_PROTOCOL))
    assert d.compute() == 3


def test_dask_layers():
    """Each delayed object contributes one HighLevelGraph layer named by its
    key, with dependencies between layers."""
    d1 = delayed(1)
    assert d1.dask.layers.keys() == {d1.key}
    assert d1.dask.dependencies == {d1.key: set()}
    assert d1.__dask_layers__() == (d1.key,)
    d2 = modlevel_delayed1(d1)
    assert d2.dask.layers.keys() == {d1.key, d2.key}
    assert d2.dask.dependencies == {d1.key: set(), d2.key: {d1.key}}
    assert d2.__dask_layers__() == (d2.key,)


def test_dask_layers_to_delayed():
    # da.Array.to_delayed squashes the dask graph and causes the layer name not to
    # match the key
    da = pytest.importorskip("dask.array")
    d = da.ones(1).to_delayed()[0]
    name = d.key[0]
    assert d.key[1:] == (0,)
    assert d.dask.layers.keys() == {"delayed-" + name}
    assert d.dask.dependencies == {"delayed-" + name: set()}
    assert d.__dask_layers__() == ("delayed-" + name,)


def test_annotations_survive_optimization():
    """Layer annotations are preserved when a Delayed graph is optimized."""
    with dask.annotate(foo="bar"):
        graph = HighLevelGraph.from_collections(
            "b",
            {"a": 1, "b": (inc, "a"), "c": (inc, "b")},
            [],
        )
        d = Delayed("b", graph)

    assert type(d.dask) is HighLevelGraph
    assert len(d.dask.layers) == 1
    assert len(d.dask.layers["b"]) == 3
    assert d.dask.layers["b"].annotations == {"foo": "bar"}

    # Ensure optimizing a Delayed object returns a HighLevelGraph
    # and doesn't lose annotations
    (d_opt,) = dask.optimize(d)
    assert type(d_opt.dask) is HighLevelGraph
    assert len(d_opt.dask.layers) == 1
    assert len(d_opt.dask.layers["b"]) == 2  # c is culled
    assert d_opt.dask.layers["b"].annotations == {"foo": "bar"}