1import pickle
2import types
3from collections import namedtuple
4from functools import partial
5from operator import add, setitem
6from random import random
7
8import cloudpickle
9import pytest
10from tlz import merge
11
12import dask
13import dask.bag as db
14from dask import compute
15from dask.delayed import Delayed, delayed, to_task_dask
16from dask.highlevelgraph import HighLevelGraph
17from dask.utils_test import inc
18
19try:
20    from operator import matmul
21except ImportError:
22    matmul = None
23
24
class Tuple:
    """Minimal third-party dask collection whose computed value is a tuple.

    It stores a raw graph ``dsk`` and the list of output ``keys``. Tests use it
    to exercise ``delayed``/``to_task_dask`` on objects that only implement the
    dask collection protocol.
    """

    __dask_scheduler__ = staticmethod(dask.threaded.get)

    def __init__(self, dsk, keys):
        self._dask = dsk
        self._keys = keys

    def __dask_graph__(self):
        # The stored graph is returned unchanged.
        return self._dask

    def __dask_keys__(self):
        return self._keys

    def __dask_tokenize__(self):
        # The token is derived from the output keys alone.
        return self._keys

    def __dask_postcompute__(self):
        # Pack the computed values of the keys into a tuple; no extra args.
        finalizer, extra_args = tuple, ()
        return finalizer, extra_args
43
44
@pytest.mark.filterwarnings("ignore:The dask.delayed:UserWarning")
def test_to_task_dask():
    """``to_task_dask`` turns Delayed-containing values into (task, graph) pairs."""
    a = delayed(1, name="a")
    b = delayed(2, name="b")
    deps = merge(a.dask, b.dask)

    # A list keeps its shape; every Delayed is replaced by its key.
    task, dsk = to_task_dask([a, b, 3])
    assert task == ["a", "b", 3]

    # A tuple is rebuilt from a list at compute time.
    task, dsk = to_task_dask((a, b, 3))
    assert task == (tuple, ["a", "b", 3])
    assert dict(dsk) == deps

    # A dict becomes ``(dict, [[k, v], ...])``; item order is not fixed.
    task, dsk = to_task_dask({a: 1, b: 2})
    assert task in (
        (dict, [["b", 2], ["a", 1]]),
        (dict, [["a", 1], ["b", 2]]),
    )
    assert dict(dsk) == deps

    # A namedtuple without Delayed members passes through untouched.
    Pair = namedtuple("f", ["x", "y"])
    pair = Pair(1, 2)
    task, dsk = to_task_dask(pair)
    assert task == pair
    assert dict(dsk) == {}

    # Slices are rebuilt with the ``slice`` constructor.
    task, dsk = to_task_dask(slice(a, b, 3))
    assert task == (slice, "a", "b", 3)
    assert dict(dsk) == deps

    # dict subclasses keep their exact type.
    # Issue https://github.com/dask/dask/issues/2107
    class MyClass(dict):
        pass

    task, dsk = to_task_dask(MyClass())
    assert type(task) is MyClass
    assert dict(dsk) == {}

    # Custom dask collections gain one finalize task on top of their graph.
    coll = Tuple({"a": 1, "b": 2, "c": (add, "a", "b")}, ["a", "b", "c"])
    task, dsk = to_task_dask(coll)
    assert task in dsk
    finalize = dsk.pop(task)
    assert finalize == (tuple, ["a", "b", "c"])
    assert dsk == coll._dask
85
86
def test_delayed():
    """Wrapping a function or a value with ``delayed`` builds a lazy graph."""
    lazy_add = delayed(add)
    assert lazy_add(1, 2).compute() == 3
    assert (lazy_add(1, 2) + 3).compute() == 6
    assert lazy_add(lazy_add(1, 2), 3).compute() == 6

    one = delayed(1)
    assert one.compute() == 1
    assert 1 in one.dask.values()
    nested = lazy_add(lazy_add(one, 2), 3)
    # The leaf's key must be carried into the graph of anything built on it.
    assert one.key in nested.dask
98
99
def test_delayed_with_dataclass():
    """Delayed values nested inside dataclass instances are found and computed."""
    dataclasses = pytest.importorskip("dataclasses")

    # Build the class dynamically rather than with ``@dataclass``: Python < 3.7
    # fails to interpret the type hints.
    spec = [("a", int), ("b", int, dataclasses.field(init=False))]
    ADataClass = dataclasses.make_dataclass("ADataClass", spec)

    literal = dask.delayed(3)
    wrapped = dask.delayed({"a": ADataClass(a=literal)})

    def return_nested(obj):
        return obj["a"].a

    result = delayed(return_nested)(wrapped)
    assert result.compute() == 3
117
118
def test_operators():
    """Python operators applied to Delayed objects build lazy tasks."""
    a = delayed([1, 2, 3])
    assert a[0].compute() == 1
    assert (a + a).compute() == [1, 2, 3, 1, 2, 3]
    b = delayed(2)
    # Delayed values are accepted as slice bounds too.
    assert a[:b].compute() == [1, 2]

    a = delayed(10)
    assert (a + 1).compute() == 11
    assert (1 + a).compute() == 11  # reflected operator
    assert (a >> 1).compute() == 5
    assert (a > 2).compute()
    assert (a ** 2).compute() == 100

    if matmul:
        # ``@`` (PEP 465) is ordinary syntax on Python >= 3.5. Use it directly
        # instead of through ``eval`` so linters see the uses of ``c``/``d``
        # and no ``# noqa`` markers are needed.

        class dummy:
            def __matmul__(self, other):
                return 4

        c = delayed(dummy())
        d = delayed(dummy())

        assert (c @ d).compute() == 4
143
144
def test_methods():
    """Method calls on a Delayed are delayed and accept dask keyword options."""
    text = delayed("a b c d e")
    assert text.split(" ").compute() == ["a", "b", "c", "d", "e"]
    n_a = text.upper().replace("B", "A").split().count("A")
    assert n_a.compute() == 2
    # pure=True gives identical keys for identical calls.
    assert text.split(" ", pure=True).key == text.split(" ", pure=True).key
    named = text.split(" ", dask_key_name="test")
    assert named.key == "test"
152
153
def test_attributes():
    """Attribute access on a Delayed is lazy, deterministic and composable."""
    z = delayed(2 + 1j)
    # Repeated attribute access yields the same key.
    assert z.real._key == z.real._key
    real, imag = z.real, z.imag
    assert real.compute() == 2
    assert imag.compute() == 1
    assert (real + imag).compute() == 3
160
161
def test_method_getattr_call_same_task():
    """A method call must fuse attribute lookup and call into one task."""
    lst = delayed([1, 2, 3])
    call = lst.index(1)
    task_heads = set()
    for task in call.__dask_graph__().values():
        task_heads.add(task[0])
    # ``getattr`` must never appear as a standalone task head.
    assert getattr not in task_heads
167
168
def test_np_dtype_of_delayed():
    """``np.dtype`` rejects a Delayed cleanly while ``.dtype`` stays lazy."""
    # Regression: this used to segfault through recursion, see
    # https://github.com/dask/dask/pull/4374#issuecomment-454381465
    np = pytest.importorskip("numpy")
    lazy_one = delayed(1)
    with pytest.raises(TypeError):
        np.dtype(lazy_one)
    lazy_dtype = delayed(np.array([1], dtype="f8")).dtype
    assert lazy_dtype.compute() == np.dtype("f8")
177
178
def test_delayed_visualise_warn():
    """The misspelt ``visualise`` warns rather than silently going lazy."""

    def inc(x):
        return x + 1

    z = dask.delayed(inc)(1)
    z.compute()

    expected = "dask.delayed objects have no `visualise` method"
    # With keyword arguments ...
    with pytest.warns(UserWarning, match=expected):
        z.visualise(file_name="desk_graph.svg")
    # ... and with none.
    with pytest.warns(UserWarning, match=expected):
        z.visualise()
198
199
def test_delayed_errors():
    """Operations that need a concrete value raise instead of going lazy."""
    lazy = delayed([1, 2, 3])

    # Delayed objects are immutable.
    with pytest.raises(TypeError):
        setattr(lazy, "foo", 1)
    with pytest.raises(TypeError):
        setitem(lazy, 1, 0)

    # Membership tests and iteration are not supported.
    with pytest.raises(TypeError):
        1 in lazy  # noqa: B015
    with pytest.raises(TypeError):
        list(lazy)

    # Underscore-prefixed names are not turned into lazy attributes.
    with pytest.raises(AttributeError):
        lazy._hidden()

    # Truthiness of a Delayed is forbidden.
    with pytest.raises(TypeError):
        bool(lazy)
212
213
def test_common_subexpressions():
    """Identical sub-expressions share a key, so the graph stays minimal."""
    lst = delayed([1, 2, 3])
    total = lst[0] + lst[0]
    graph = total.dask
    assert lst[0].key in graph
    assert lst.key in graph
    # The list, the shared getitem, and the addition.
    assert len(graph) == 3
220
221
def test_delayed_optimize():
    """``dask.optimize`` culls keys the Delayed does not depend on."""
    graph = {"a": 1, "b": (inc, "a"), "c": (inc, "b")}
    x = Delayed("b", graph)
    (optimized,) = dask.optimize(x)
    # Delayed's __dask_optimize__ culls out 'c', which "b" does not need.
    assert sorted(optimized.dask.keys()) == ["a", "b"]
227
228
def test_lists():
    """A list of Delayed arguments is computed element-wise before the call."""
    items = [delayed(1), delayed(2)]
    total = delayed(sum)(items)
    assert total.compute() == 3
234
235
def test_literates():
    """Delayed objects inside literal containers are computed in place."""
    a = delayed(1)
    b = a + 1
    keyed_by_delayed = {a: "a", b: "b", 3: "c"}
    valued_by_delayed = {"a": a, "b": b, "c": 3}
    cases = [
        ((a, b, 3), (1, 2, 3)),
        ([a, b, 3], [1, 2, 3]),
        ({a, b, 3}, {1, 2, 3}),
        (keyed_by_delayed, {1: "a", 2: "b", 3: "c"}),
        (valued_by_delayed, {"a": 1, "b": 2, "c": 3}),
    ]
    for literal, expected in cases:
        assert delayed(literal).compute() == expected

    # Indexing a delayed dict works with Delayed keys as well as plain ones.
    assert delayed(keyed_by_delayed)[a].compute() == "a"
    assert delayed(valued_by_delayed)["a"].compute() == 1
251
252
def test_literates_keys():
    """Literal containers get random keys unless ``pure=True`` is passed."""
    a = delayed(1)
    literal = (a, a + 1, 3)
    assert delayed(literal).key != delayed(literal).key
    assert delayed(literal, pure=True).key == delayed(literal, pure=True).key
259
260
def test_lists_are_concrete():
    """Nested lists inside call arguments are traversed and computed."""
    pairs = [[delayed(1), 10], [delayed(2), 20]]
    best = delayed(max)(pairs, key=lambda pair: pair[0])
    assert best[1].compute() == 20
267
268
def test_iterators():
    """Iterators of Delayed objects are accepted as call arguments."""
    a = delayed(1)
    b = delayed(2)

    assert delayed(sum)(iter([a, b])).compute() == 3

    def f(seq):
        return sum(seq)

    assert delayed(f)(iter([a, b])).compute() == 3
281
282
def test_traverse_false():
    """With ``traverse=False`` nested Delayed values are returned uncomputed."""

    def fail(*args):
        raise ValueError("shouldn't have computed")

    a = delayed(fail)()

    # Plain list: the Delayed element comes back as-is.
    items = [a, 1, 2, 3]
    out = delayed(items, traverse=False).compute()
    assert len(out) == 4
    assert out[0] is a
    assert out[1:] == items[1:]

    # A tuple shaped like a task is not executed as one.
    task_like = (fail, a, (fail, a))
    out = delayed(task_like, traverse=False).compute()
    assert isinstance(out, tuple)
    assert out[0] == fail
    assert out[1] is a

    # Task-like tuples nested in a list are left alone as well.
    mixed = [1, (fail, a), a]
    out = delayed(mixed, traverse=False).compute()
    assert isinstance(out, list)
    assert out[0] == 1
    assert out[1][0] == fail and out[1][1] is a
    assert out[2] is a

    # A top-level Delayed is still unwrapped even with traverse=False.
    inner = delayed(1)
    assert delayed(inner, traverse=False).compute() == 1
316
317
def test_pure():
    """``pure=True`` yields deterministic keys; impure calls stay unique."""
    first, second = [delayed(add, pure=True)(1, 2).key for _ in range(2)]
    assert first == second

    lazy_random = delayed(random)
    assert lazy_random().key != lazy_random().key
325
326
def test_pure_global_setting():
    """The ``delayed_pure`` config sets the default; an explicit ``pure=`` wins."""

    def _keys_match(build):
        # Build two objects the same way and compare their keys.
        return build().key == build().key

    # delayed functions
    func = delayed(add)
    with dask.config.set(delayed_pure=True):
        assert _keys_match(lambda: func(1, 2))
    with dask.config.set(delayed_pure=False):
        assert not _keys_match(lambda: func(1, 2))

    func = delayed(add, pure=True)
    with dask.config.set(delayed_pure=False):
        assert _keys_match(lambda: func(1, 2))

    # delayed objects
    assert not _keys_match(lambda: delayed(1))
    with dask.config.set(delayed_pure=True):
        assert _keys_match(lambda: delayed(1))
    with dask.config.set(delayed_pure=False):
        assert _keys_match(lambda: delayed(1, pure=True))

    # delayed methods
    data = delayed([1, 2, 3])
    assert not _keys_match(lambda: data.index(1))
    with dask.config.set(delayed_pure=True):
        assert _keys_match(lambda: data.index(1))
        assert not _keys_match(lambda: data.index(1, pure=False))
    with dask.config.set(delayed_pure=False):
        assert _keys_match(lambda: data.index(1, pure=True))

    # Attribute access, indexing and operators are pure even when the
    # config default is impure.
    with dask.config.set(delayed_pure=False):
        assert _keys_match(lambda: data.index)
        element = data[0]
        assert _keys_match(lambda: element + element)
365
366
def test_nout():
    """``nout`` makes a delayed call unpackable into that many results."""
    pair = delayed(lambda x: (x, -x), nout=2, pure=True)(1)
    assert len(pair) == 2
    first, second = pair
    assert compute(first, second) == (1, -1)
    # The unpacked elements have unknown length themselves.
    for item in (first, second):
        assert item._length is None
    with pytest.raises(TypeError):
        len(first)
    with pytest.raises(TypeError):
        list(first)

    # Invalid nout values are rejected up front.
    with pytest.raises(ValueError):
        delayed(add, nout=-1)
    with pytest.raises(ValueError):
        delayed(add, nout=True)

    # nout=None leaves the length unknown.
    unknown = delayed(add, nout=None)(1)
    assert unknown._length is None
    with pytest.raises(TypeError):
        list(unknown)
    with pytest.raises(TypeError):
        len(unknown)

    # nout=1 still yields a one-element iterable.
    single = delayed(lambda x: (x,), nout=1, pure=True)(1)
    assert len(single) == 1
    (only,) = single
    assert only.compute() == 1
    assert only._length is None
    with pytest.raises(TypeError):
        len(only)

    # nout=0 yields an empty but computable result.
    empty = delayed(lambda x: tuple(), nout=0, pure=True)(1)
    assert len(empty) == 0
    assert empty.compute() == tuple()
399
400
@pytest.mark.parametrize(
    "x",
    [[1, 2], (1, 2), (add, 1, 2), [], ()],
)
def test_nout_with_tasks(x):
    """Literal containers, including task-shaped tuples, honour ``nout=len(x)``."""
    n = len(x)
    lazy = delayed(x, nout=n)
    assert len(lazy) == n
    assert len(list(lazy)) == n
    assert lazy.compute() == x
410
411
def test_kwargs():
    """Delayed values in keyword arguments are computed and affect pure keys."""

    def mysum(a, b, c=(), **kwargs):
        return a + b + sum(c) + sum(kwargs.values())

    lazy_sum = delayed(mysum)
    ten = lazy_sum(1, 2, c=[delayed(3), 0], four=lazy_sum(2, 2))
    assert ten.compute() == 10

    pure_sum = delayed(mysum, pure=True)
    c = [delayed(3), 0]
    ten = pure_sum(1, 2, c=c, four=pure_sum(2, 2))
    assert ten.compute() == 10
    # Same inputs give the same key; any differing input changes it.
    assert pure_sum(1, 2, c=c, four=pure_sum(2, 2)).key == ten.key
    assert pure_sum(1, 2, c=c, four=pure_sum(2, 3)).key != ten.key
    assert pure_sum(1, 2, c=c, four=4).key != ten.key
    assert pure_sum(1, 2, c=c, four=4).key != pure_sum(2, 2, c=c, four=4).key
427
428
def test_custom_delayed():
    """Custom dask collections can be passed as arguments to delayed calls."""
    coll = Tuple({"a": 1, "b": 2, "c": (add, "a", "b")}, ["a", "b", "c"])
    pure_len = delayed(len, pure=True)
    concatenated = delayed(add, pure=True)(coll, (4, 5, 6))
    length = pure_len(coll)
    assert pure_len(coll).key == length.key
    assert concatenated.compute() == (1, 2, 3, 4, 5, 6)
    expected = (3, (1, 2, 3, 4, 5, 6), (1, 2, 3))
    assert compute(length, concatenated, coll) == expected
436
437
@pytest.mark.filterwarnings("ignore:The dask.delayed:UserWarning")
def test_array_delayed():
    """Dask arrays mix with numpy arrays inside delayed calls."""
    np = pytest.importorskip("numpy")
    da = pytest.importorskip("dask.array")

    arr = np.arange(100).reshape((10, 10))
    darr = da.from_array(arr, chunks=(5, 5))
    expected = arr + arr + 1

    val = delayed(sum)([arr, darr, 1])
    assert isinstance(val, Delayed)
    assert np.allclose(val.compute(), expected)
    assert val.sum().compute() == expected.sum()
    assert val[0, 0].compute() == expected[0, 0]

    # to_task_dask keeps every array key and adds exactly one new key.
    task, dsk = to_task_dask(darr)
    array_keys = darr.dask.keys()
    assert not array_keys - dsk.keys()
    assert len(dsk.keys() - array_keys) == 1

    assert (delayed(darr).compute() == arr).all()
458
459
def test_array_bag_delayed():
    """Reductions of numpy, dask.array and dask.bag objects mix in one call."""
    da = pytest.importorskip("dask.array")
    np = pytest.importorskip("numpy")

    arr1 = np.arange(100).reshape((10, 10))
    arr2 = arr1.dot(arr1.T)
    parts = [
        arr1,
        arr2,
        da.from_array(arr1, chunks=(5, 5)),
        da.from_array(arr2, chunks=(5, 5)),
        db.from_sequence([1, 2, 3]),
    ]
    total = delayed(sum)([part.sum() for part in parts])
    expected = 2 * arr1.sum() + 2 * arr2.sum() + sum([1, 2, 3])
    assert total.compute() == expected
472
473
def test_delayed_picklable():
    """Delayed, DelayedLeaf and DelayedAttr survive a pickle round trip."""

    def roundtrip(obj):
        return pickle.loads(pickle.dumps(obj))

    # Delayed
    call = delayed(divmod, nout=2, pure=True)(1, 2)
    restored = roundtrip(call)
    assert call.dask == restored.dask
    assert call._key == restored._key
    assert call._length == restored._length

    # DelayedLeaf
    leaf = delayed(1j + 2)
    restored = roundtrip(leaf)
    assert leaf.dask == restored.dask
    assert leaf._key == restored._key
    assert leaf._nout == restored._nout
    assert leaf._pure == restored._pure

    # DelayedAttr
    attr = leaf.real
    restored = roundtrip(attr)
    assert attr._obj._key == restored._obj._key
    assert attr._obj.dask == restored._obj.dask
    assert attr._attr == restored._attr
    assert attr._key == restored._key
495
496
def test_delayed_compute_forward_kwargs():
    """``compute`` tolerates extra keyword arguments it does not recognise."""
    x = delayed(1) + 2
    # Besides not raising, the unknown keyword must not change the result.
    # The original test discarded the return value.
    assert x.compute(bogus_keyword=10) == 3
500
501
def test_delayed_method_descriptor():
    """Unbound method descriptors such as ``bytes.decode`` can be delayed."""
    lazy = delayed(bytes.decode)(b"")  # building the call must not raise
    # Also check that the call actually executes: b"".decode() == "".
    assert lazy.compute() == ""
504
505
def test_delayed_callable():
    """A delayed function is itself a Delayed whose graph holds the function."""
    lazy_add = delayed(add, pure=True)
    call = lazy_add(1, 2)
    assert call.dask == {call.key: (add, 1, 2)}

    assert lazy_add.dask == {lazy_add.key: add}
    assert lazy_add.compute() == add
513
514
def test_delayed_name_on_call():
    """``dask_key_name`` overrides the generated key of a call."""
    call = delayed(add, pure=True)(1, 2, dask_key_name="foo")
    assert call._key == "foo"
518
519
def test_callable_obj():
    """A callable instance is delayed as a value, not invoked eagerly."""

    class Foo:
        def __init__(self, a):
            self.a = a

        def __call__(self):
            return 2

    obj = Foo(1)
    lazy = delayed(obj)
    assert lazy.compute() is obj
    assert lazy.a.compute() == 1
    assert lazy().compute() == 2
533
534
def identity(x):
    """Return ``x`` unchanged; used by test_name_consistent_across_instances."""
    result = x
    return result
537
538
def test_name_consistent_across_instances():
    """Pure keys are deterministic and match fixed values across runs."""
    func = delayed(identity, pure=True)

    data = {"x": 1, "y": 25, "z": [1, 2, 3]}
    assert func(data)._key == "identity-02129ed1acaffa7039deee80c5da547c"

    # A dict with mixed-type keys must still tokenize deterministically.
    mixed = {"x": 1, 1: "x"}
    assert func(mixed)._key == func(mixed)._key
    assert func(1)._key == "identity-ca2fae46a3b938016331acac1908ae45"
548
549
def test_sensitive_to_partials():
    """Pure keys distinguish partials that differ only in bound arguments."""
    key_10 = delayed(partial(add, 10), pure=True)(2)._key
    key_20 = delayed(partial(add, 20), pure=True)(2)._key
    assert key_10 != key_20
555
556
def test_delayed_name():
    """Keys are prefixed with the type or function name unless ``name`` is given."""
    int_keys = [delayed(1)._key, delayed(1, pure=True)._key]
    assert all(key.startswith("int-") for key in int_keys)
    assert delayed(1, name="X")._key == "X"

    def myfunc(x):
        return x + 1

    call = delayed(myfunc)(1)
    assert call.key.startswith("myfunc")
566
567
def test_finalize_name():
    """Keys created when finalizing a collection inside a literal are named sensibly."""
    da = pytest.importorskip("dask.array")

    x = da.ones(10, chunks=5)
    v = delayed([x])
    assert set(x.dask).issubset(v.dask)

    def prefix(key):
        # Tuple keys carry their name in position 0.
        name = key[0] if isinstance(key, tuple) else key
        # Ignore _ in 'ones_like'
        return name.split("-")[0].replace("_", "")

    assert all(prefix(k).isalpha() for k in v.dask)
582
583
def test_keys_from_array():
    """Delayed calls on array chunks produce graphs that pass ``_check_dsk``."""
    da = pytest.importorskip("dask.array")
    from dask.array.utils import _check_dsk

    chunks = da.ones((10, 10), chunks=5).to_delayed().flatten()
    incremented = [delayed(inc)(chunk) for chunk in chunks]

    _check_dsk(incremented[0].dask)
592
593
594# Mostly copied from https://github.com/pytoolz/toolz/pull/220
def test_delayed_decorator_on_method():
    """``@delayed`` works on instance, class and static methods.

    Instance and class methods must still receive ``self``/``cls`` from
    attribute access. A static method must remain a plain Delayed object.
    """

    class A:
        BASE = 10

        def __init__(self, base):
            # Instance attribute shadows the class-level BASE.
            self.BASE = base

        @delayed
        def addmethod(self, x, y):
            return self.BASE + x + y

        # Decorator order matters: classmethod/staticmethod are applied last,
        # on top of the Delayed produced by @delayed.
        @classmethod
        @delayed
        def addclass(cls, x, y):
            return cls.BASE + x + y

        @staticmethod
        @delayed
        def addstatic(x, y):
            return x + y

    a = A(100)
    # Instance method sees the per-instance BASE (100), bound or unbound.
    assert a.addmethod(3, 4).compute() == 107
    assert A.addmethod(a, 3, 4).compute() == 107

    # Class method sees the class-level BASE (10), even via an instance.
    assert a.addclass(3, 4).compute() == 17
    assert A.addclass(3, 4).compute() == 17

    assert a.addstatic(3, 4).compute() == 7
    assert A.addstatic(3, 4).compute() == 7

    # We want the decorated methods to be actual methods for instance methods
    # and class methods since their first arguments are the object and the
    # class respectively. Or in other words, the first argument is generated by
    # the runtime based on the object/class before the dot.
    assert isinstance(a.addmethod, types.MethodType)
    assert isinstance(A.addclass, types.MethodType)

    # For static methods (and regular functions), the decorated methods should
    # be Delayed objects.
    assert isinstance(A.addstatic, Delayed)
636
637
def test_attribute_of_attribute():
    """Chained attribute access on a Delayed stays lazy at every level."""
    node = delayed(123)
    for attr in ("a", "b", "c"):
        node = getattr(node, attr)
        assert isinstance(node, Delayed)
643
644
def test_check_meta_flag():
    """``from_delayed(..., verify_meta=False)`` accepts partitions whose
    metadata differs (here: different categorical categories)."""
    dd = pytest.importorskip("dask.dataframe")
    from pandas import Series

    parts = [
        Series(["a", "b", "a"], dtype="category"),
        Series(["a", "c", "a"], dtype="category"),
    ]
    # One fresh identity lambda per partition, as in separate delayed calls.
    lazy_parts = [delayed(lambda x: x)(part) for part in parts]

    ddf = dd.from_delayed(lazy_parts, verify_meta=False)
    dd.utils.assert_eq(ddf, ddf)
656
657
def modlevel_eager(x):
    """Return ``x + 1``; a plain module-level function used in pickling tests."""
    incremented = x + 1
    return incremented
660
661
@delayed
def modlevel_delayed1(x):
    """Lazily compute ``x + 1``; decorated with a bare ``@delayed``."""
    result = x + 1
    return result
665
666
@delayed(pure=False)
def modlevel_delayed2(x):
    """Lazily compute ``x + 1``; decorated with ``@delayed(pure=False)``."""
    result = x + 1
    return result
670
671
@pytest.mark.parametrize(
    "f",
    [
        delayed(modlevel_eager),
        pytest.param(modlevel_delayed1, marks=pytest.mark.xfail(reason="#3369")),
        pytest.param(modlevel_delayed2, marks=pytest.mark.xfail(reason="#3369")),
    ],
)
def test_pickle(f):
    """A delayed call round-trips through the standard ``pickle`` module."""
    payload = pickle.dumps(f(2), protocol=pickle.HIGHEST_PROTOCOL)
    restored = pickle.loads(payload)
    assert restored.compute() == 3
684
685
@pytest.mark.parametrize(
    "f", [delayed(modlevel_eager), modlevel_delayed1, modlevel_delayed2]
)
def test_cloudpickle(f):
    """A delayed call round-trips through ``cloudpickle``, decorated or not."""
    payload = cloudpickle.dumps(f(2), protocol=pickle.HIGHEST_PROTOCOL)
    restored = cloudpickle.loads(payload)
    assert restored.compute() == 3
693
694
def test_dask_layers():
    """Each Delayed contributes one HighLevelGraph layer named after its key."""
    leaf = delayed(1)
    graph = leaf.dask
    assert graph.layers.keys() == {leaf.key}
    assert graph.dependencies == {leaf.key: set()}
    assert leaf.__dask_layers__() == (leaf.key,)

    call = modlevel_delayed1(leaf)
    graph = call.dask
    assert graph.layers.keys() == {leaf.key, call.key}
    assert graph.dependencies == {leaf.key: set(), call.key: {leaf.key}}
    assert call.__dask_layers__() == (call.key,)
704
705
def test_dask_layers_to_delayed():
    """``Array.to_delayed`` output lives in a single ``delayed-<name>`` layer.

    ``to_delayed`` squashes the array graph, so the layer name does not match
    the chunk key.
    """
    da = pytest.importorskip("dask.array")
    d = da.ones(1).to_delayed()[0]
    name = d.key[0]
    layer = "delayed-" + name
    assert d.key[1:] == (0,)
    assert d.dask.layers.keys() == {layer}
    assert d.dask.dependencies == {layer: set()}
    assert d.__dask_layers__() == (layer,)
716
717
def test_annotations_survive_optimization():
    """Layer annotations are preserved when a Delayed is optimized."""
    with dask.annotate(foo="bar"):
        graph = HighLevelGraph.from_collections(
            "b",
            {"a": 1, "b": (inc, "a"), "c": (inc, "b")},
            [],
        )
        d = Delayed("b", graph)

    def check(hlg, n_tasks):
        # One annotated layer named "b" holding ``n_tasks`` tasks.
        assert type(hlg) is HighLevelGraph
        assert len(hlg.layers) == 1
        assert len(hlg.layers["b"]) == n_tasks
        assert hlg.layers["b"].annotations == {"foo": "bar"}

    check(d.dask, 3)

    # Optimizing must return a HighLevelGraph and must not lose annotations.
    # "c" is culled, leaving two tasks.
    (d_opt,) = dask.optimize(d)
    check(d_opt.dask, 2)
739