1import datetime
2import functools
3import operator
4import pickle
5
6import pytest
7from tlz import curry
8
9from dask import get
10from dask.highlevelgraph import HighLevelGraph
11from dask.optimization import SubgraphCallable
12from dask.utils import (
13    Dispatch,
14    M,
15    SerializableLock,
16    _deprecated,
17    asciitable,
18    derived_from,
19    ensure_dict,
20    extra_titles,
21    format_bytes,
22    funcname,
23    getargspec,
24    has_keyword,
25    ignoring,
26    is_arraylike,
27    itemgetter,
28    iter_chunks,
29    memory_repr,
30    methodcaller,
31    ndeepmap,
32    noop_context,
33    parse_bytes,
34    parse_timedelta,
35    partial_by_order,
36    random_state_data,
37    skip_doctest,
38    stringify,
39    stringify_collection_keys,
40    takes_multiple_arguments,
41    typename,
42)
43from dask.utils_test import inc
44
45
def test_getargspec():
    """Check ``getargspec`` on a function, a partial, a wrapper and a class."""

    def two_args(x, y):
        pass

    positional = ["x", "y"]
    assert getargspec(two_args).args == positional

    # this is a bit of a lie, but maybe close enough: the partial already
    # binds ``x`` yet both names are still expected in the result
    bound = functools.partial(two_args, 2)
    assert getargspec(bound).args == positional

    def passthrough(*args, **kwargs):
        pass

    # With ``__wrapped__`` set, the wrapped function's arguments are expected
    passthrough.__wrapped__ = two_args
    assert getargspec(passthrough).args == positional

    class MyType:
        def __init__(self, x, y):
            pass

    # For a class the expected arguments are those of ``__init__``, with ``self``
    assert getargspec(MyType).args == ["self", "x", "y"]
67
68
def test_takes_multiple_arguments():
    """Check ``takes_multiple_arguments`` on builtins, functions and classes."""

    def multi(a, b, c):
        return a, b, c

    def f():
        pass

    def vararg(*args):
        pass

    class Singular:
        def __init__(self, a):
            pass

    class Multi:
        def __init__(self, a, b):
            pass

    # Builtins
    assert takes_multiple_arguments(map)
    assert not takes_multiple_arguments(sum)

    # Plain function and classes (judged through ``__init__``)
    assert takes_multiple_arguments(multi)
    assert not takes_multiple_arguments(Singular)
    assert takes_multiple_arguments(Multi)

    # Zero-argument function
    assert not takes_multiple_arguments(f)

    # ``*args`` counts as multiple arguments unless ``varargs=False``
    assert takes_multiple_arguments(vararg)
    assert not takes_multiple_arguments(vararg, varargs=False)
98
99
def test_dispatch():
    """Exercise type-based dispatch, a recursive handler and ``__doc__``."""

    def f(a):
        """My Docstring"""
        return a

    class Bar:
        pass

    foo = Dispatch()
    foo.register(int, lambda a: a + 1)
    foo.register(float, lambda a: a - 1)
    # The tuple handler dispatches again on every element
    foo.register(tuple, lambda a: tuple(foo(i) for i in a))
    # Fallback for everything else
    foo.register(object, f)

    b = Bar()
    int_handler = foo.dispatch(int)

    assert foo(1) == 2
    assert int_handler(1) == 2
    assert foo(1.0) == 0.0
    assert foo(b) == b
    assert foo((1, 2.0, b)) == (2, 1.0, b)

    # The dispatcher is expected to expose the docstring of ``f``
    assert foo.__doc__ == f.__doc__
123
124
def test_dispatch_kwargs():
    """Check that keyword arguments reach the dispatched handler."""

    def add(a, b=10):
        return a + b

    foo = Dispatch()
    foo.register(int, add)

    result = foo(1, b=20)
    assert result == 21
130
131
def test_dispatch_variadic_on_first_argument():
    """Check that the first argument's type selects the handler."""

    def add(a, b):
        return a + b

    def subtract(a, b):
        return a - b

    foo = Dispatch()
    foo.register(int, add)
    foo.register(float, subtract)

    int_result = foo(1, 2)
    float_result = foo(1.0, 2.0)

    assert int_result == 3
    assert float_result == -1
139
140
def test_dispatch_lazy():
    """Check that a handler registered through ``register_lazy`` is found."""
    # this tests the recursive component of dispatch
    import decimal

    foo = Dispatch()
    foo.register(int, lambda a: a)

    # keep it outside lazy dec for test
    def add_one(a):
        return a + 1

    @foo.register_lazy("decimal")
    def register_decimal():
        import decimal

        foo.register(decimal.Decimal, add_one)

    # This test needs to be *before* any other calls
    resolved = foo.dispatch(decimal.Decimal)
    assert resolved == add_one

    one = decimal.Decimal(1)
    assert foo(one) == decimal.Decimal(2)
    assert foo(1) == 1
162
163
def test_dispatch_lazy_walks_mro():
    """Check that subclasses of classes with lazily registered handlers still
    use their parent class's handler by default"""
    import decimal

    class Lazy(decimal.Decimal):
        pass

    class Eager(Lazy):
        pass

    foo = Dispatch()

    def lazy_handler(a):
        return "lazy"

    @foo.register(Eager)
    def eager_handler(x):
        return "eager"

    @foo.register_lazy("decimal")
    def register_decimal():
        foo.register(decimal.Decimal, lazy_handler)

    # Subclass without its own handler: expected to use the lazy parent handler
    assert foo.dispatch(Lazy) == lazy_handler
    assert foo(Lazy(1)) == "lazy"

    # The lazily registered class itself
    assert foo.dispatch(decimal.Decimal) == lazy_handler
    assert foo(decimal.Decimal(1)) == "lazy"

    # Subclass with an explicitly registered handler
    assert foo.dispatch(Eager) == eager_handler
    assert foo(Eager(1)) == "eager"
194
195
def test_random_state_data():
    """Check ``random_state_data`` for seed types and consistent ordering.

    An integer seed and an equivalent ``RandomState`` must produce the same
    state arrays, and asking for more states must not change the first ones.
    The lengths of both sides are asserted before every ``zip`` because
    ``zip`` stops at the shorter input, which would otherwise let a short or
    empty result pass the element-wise comparison vacuously.
    """
    np = pytest.importorskip("numpy")
    seed = 37
    state = np.random.RandomState(seed)
    n = 10000

    # Use an integer
    states = random_state_data(n, seed)
    assert len(states) == n

    # Use RandomState object
    states2 = random_state_data(n, state)
    assert len(states2) == n
    for s1, s2 in zip(states, states2):
        assert s1.shape == (624,)
        assert (s1 == s2).all()

    # Consistent ordering
    states = random_state_data(10, 1234)
    states2 = random_state_data(20, 1234)[:10]
    assert len(states) == len(states2) == 10

    for s1, s2 in zip(states, states2):
        assert (s1 == s2).all()
218
219
def test_memory_repr():
    """Check ``memory_repr`` at exact powers of 1024."""
    expected_by_power = ["1.0 bytes", "1.0 KB", "1.0 MB", "1.0 GB"]
    for power, mem_repr in enumerate(expected_by_power):
        size = 1024 ** power
        assert memory_repr(size) == mem_repr
223
224
def test_method_caller():
    """Check calling, identity, pickling and text forms of ``methodcaller``."""
    values = [1, 2, 3, 3, 3]
    counter = methodcaller("count")

    assert counter(values, 3) == values.count(3)

    # The same object is expected for the same method name, however obtained
    assert methodcaller("count") is counter
    assert M.count is counter
    roundtripped = pickle.loads(pickle.dumps(counter))
    assert roundtripped is counter
    assert "count" in dir(M)

    # Both text forms mention the method name
    for render in (str, repr):
        assert "count" in render(methodcaller("count"))
236
237
def test_skip_doctest():
    """Check that ``skip_doctest`` adds a ``+SKIP`` directive to doctest lines."""
    # Prompt lines carrying code are expected to gain the SKIP directive; the
    # empty prompt and the comment-only prompt are expected back unchanged.
    example = """>>> xxx
>>>
>>> # comment
>>> xxx"""

    res = skip_doctest(example)
    assert (
        res
        == """>>> xxx  # doctest: +SKIP
>>>
>>> # comment
>>> xxx  # doctest: +SKIP"""
    )

    # ``None`` is expected to turn into an empty string
    assert skip_doctest(None) == ""

    # A line that already has a doctest directive is expected to get ``+SKIP``
    # appended to that directive; the output line is left untouched.
    example = """
>>> 1 + 2  # doctest: +ELLIPSES
3"""

    expected = """
>>> 1 + 2  # doctest: +ELLIPSES, +SKIP
3"""
    res = skip_doctest(example)
    assert res == expected
264
265
def test_extra_titles():
    """Check that ``extra_titles`` renames a repeated section title."""
    # "Notes" appears twice; "Foo" appears once.
    example = """

    Notes
    -----
    hello

    Foo
    ---

    Notes
    -----
    bar
    """

    # Only the second "Notes" is expected to change: it gains an "Extra "
    # prefix and its underline grows to match the new title length.
    expected = """

    Notes
    -----
    hello

    Foo
    ---

    Extra Notes
    -----------
    bar
    """

    assert extra_titles(example) == expected
296
297
def test_asciitable():
    """Check the exact text table rendered by ``asciitable``."""
    columns = ["fruit", "color"]
    rows = [
        ("apple", "red"),
        ("banana", "yellow"),
        ("tomato", "red"),
        ("pear", "green"),
    ]
    expected = (
        "+--------+--------+\n"
        "| fruit  | color  |\n"
        "+--------+--------+\n"
        "| apple  | red    |\n"
        "| banana | yellow |\n"
        "| tomato | red    |\n"
        "| pear   | green  |\n"
        "+--------+--------+"
    )

    res = asciitable(columns, rows)
    assert res == expected
313
314
def test_SerializableLock():
    """Check that pickled copies of a ``SerializableLock`` share one lock.

    Copies of ``a`` must block each other, while ``a`` and ``b`` (and their
    copies) must be independent and nestable in either order.
    """
    a = SerializableLock()
    b = SerializableLock()

    # Basic context-manager use, alone and nested with an unrelated lock
    with a:
        pass

    with a, b:
        pass

    # While held, a non-blocking acquire must fail
    with a:
        assert not a.acquire(False)

    a2 = pickle.loads(pickle.dumps(a))
    a3 = pickle.loads(pickle.dumps(a))
    a4 = pickle.loads(pickle.dumps(a2))
    a_family = [a, a2, a3, a4]

    # Holding any copy must block every copy, including the holder itself
    for holder in a_family:
        for contender in a_family:
            with holder:
                assert not contender.acquire(False)

    b2 = pickle.loads(pickle.dumps(b))
    b3 = pickle.loads(pickle.dumps(b2))
    b_family = [b, b2, b3]

    # Locks from the two families can be nested in both orders
    for x in a_family:
        for y in b_family:
            with x, y:
                pass
            with y, x:
                pass
348
349
def test_SerializableLock_name_collision():
    """Check that equal names share a lock and different names do not."""
    a = SerializableLock("a")
    b = SerializableLock("b")
    c = SerializableLock("a")
    d = SerializableLock()

    named_locks = (a.lock, b.lock, c.lock)

    # Different names -> different underlying locks
    assert a.lock is not b.lock
    # Same name -> the very same underlying lock
    assert a.lock is c.lock
    # An unnamed lock does not collide with any named one
    assert d.lock not in named_locks
359
360
def test_SerializableLock_locked():
    """Check that ``locked()`` reports the state before, during and after use."""
    lock = SerializableLock("a")

    assert not lock.locked()
    with lock:
        held = lock.locked()
        assert held
    assert not lock.locked()
367
368
def test_SerializableLock_acquire_blocking():
    """Check the ``blocking`` keyword of ``SerializableLock.acquire``.

    A blocking acquire on a free lock must succeed and a non-blocking acquire
    on the then-held lock must fail.
    """
    a = SerializableLock("a")
    assert a.acquire(blocking=True)
    try:
        assert not a.acquire(blocking=False)
    finally:
        # Locks created with the same name share one underlying lock, so a
        # failed assertion must not leave "a" held for the other tests in
        # this module that also use that name.
        a.release()
374
375
def test_funcname():
    """Check ``funcname`` on functions, partials, lambdas, classes, instances."""

    def foo(a, b, c):
        pass

    class Foo:
        pass

    partial_foo = functools.partial(foo, a=1)

    assert funcname(foo) == "foo"
    assert funcname(partial_foo) == "foo"
    assert funcname(M.sum) == "sum"
    assert funcname(lambda: 1) == "lambda"

    assert funcname(Foo) == "Foo"
    # For an instance only containment of the class name is required
    instance_name = funcname(Foo())
    assert "Foo" in instance_name
390
391
def test_funcname_long():
    """Check that a very long function name is shortened but recognisable."""

    def a_long_function_name_11111111111111111111111111111111111111111111111():
        pass

    target = a_long_function_name_11111111111111111111111111111111111111111111111
    result = funcname(target)

    assert "a_long_function_name" in result
    assert len(result) < 60
401
402
def test_funcname_toolz():
    """Check ``funcname`` on a curried function and on a partial application."""

    @curry
    def foo(a, b, c):
        pass

    partially_applied = foo(1)

    assert funcname(foo) == "foo"
    assert funcname(partially_applied) == "foo"
410
411
def test_funcname_multipledispatch():
    """Check ``funcname`` on a ``multipledispatch`` function and its partial."""
    md = pytest.importorskip("multipledispatch")

    @md.dispatch(int, int, int)
    def foo(a, b, c):
        pass

    with_a_bound = functools.partial(foo, a=1)

    assert funcname(foo) == "foo"
    assert funcname(with_a_bound) == "foo"
421
422
def test_funcname_numpy_vectorize():
    """Check ``funcname`` on ``np.vectorize`` objects, including partials."""
    np = pytest.importorskip("numpy")

    vectorized_int = np.vectorize(int)
    assert funcname(vectorized_int) == "vectorize_int"

    # Regression test for https://github.com/pydata/xarray/issues/3303
    # Partial functions don't have a __name__ attribute
    add_partial = functools.partial(np.add, out=None)
    vectorized_add = np.vectorize(add_partial)
    assert funcname(vectorized_add) == "vectorize_add"
434
435
def test_ndeepmap():
    """Check ``ndeepmap`` with ``inc`` at nesting depths 0 through 3."""
    # (depth, input, expected result)
    cases = [
        (0, 1, 2),
        (0, [1], 2),
        (1, [1, 2, 3], [2, 3, 4]),
        (2, [[1, 2], [3, 4]], [[2, 3], [4, 5]]),
        (
            3,
            [[[1, 2], [3, 4, 5]], [[6], []]],
            [[[2, 3], [4, 5, 6]], [[7], []]],
        ),
    ]
    for depth, L, expected in cases:
        assert ndeepmap(depth, inc, L) == expected
451
452
def test_ensure_dict():
    """Check when ``ensure_dict`` returns its input and when a new dict."""

    class mydict(dict):
        pass

    d = {"x": 1}

    # A plain dict is returned as-is when no copy is requested
    assert ensure_dict(d) is d

    converted = (
        ensure_dict(d, copy=True),
        ensure_dict(HighLevelGraph.from_collections("x", d)),
        ensure_dict(mydict(d)),
    )

    # Every other input yields a distinct, equal, plain ``dict``
    for di in converted:
        assert type(di) is dict
        assert di is not d
        assert di == d
468
469
def test_itemgetter():
    """Check calling, pickling and equality of ``itemgetter``."""
    data = [1, 2, 3]

    getter = itemgetter(1)
    assert getter(data) == 2

    # The pickled copy behaves the same and keeps its index
    restored = pickle.loads(pickle.dumps(getter))
    assert restored(data) == 2
    assert restored.index == 1

    # Equality depends on the index; other types never compare equal
    assert itemgetter(1) == itemgetter(1)
    assert itemgetter(1) != itemgetter(2)
    assert itemgetter(1) != 123
481
482
def test_partial_by_order():
    """Check that ``partial_by_order`` places the extra argument by position."""
    # 20 is inserted at position 1, so the call amounts to ``add(5, 20)``
    result = partial_by_order(5, function=operator.add, other=[(1, 20)])
    assert result == 25
485
486
def test_has_keyword():
    """Check ``has_keyword`` on a function and on a partial of it."""

    def foo(a, b, c=None):
        pass

    for name in ("a", "b", "c"):
        assert has_keyword(foo, name)

    # The remaining parameters are still reported after binding ``a``
    bar = functools.partial(foo, a=1)
    for name in ("b", "c"):
        assert has_keyword(bar, name)
498
499
def test_derived_from():
    """Check the docstring that ``derived_from`` builds for a method."""

    class Foo:
        def f(a, b):
            """A super docstring

            An explanation

            Parameters
            ----------
            a: int
                an explanation of a
            b: float
                an explanation of b
            """

    class Bar:
        @derived_from(Foo)
        def f(a, c):
            pass

    class Zap:
        @derived_from(Foo)
        def f(a, c):
            "extra docstring"
            pass

    bar_doc = Bar.f.__doc__
    bar_lines = bar_doc.split("\n")

    # The original text is kept and a disclaimer shows up near the top
    assert bar_doc.strip().startswith("A super docstring")
    assert "Foo.f" in bar_doc
    assert any("inconsistencies" in line for line in bar_lines[:7])

    # ``b`` is absent from ``Bar.f``, so its entry must be marked accordingly
    [b_line] = [line for line in bar_lines if "b:" in line]
    b_lowered = b_line.lower()
    assert "not supported" in b_lowered
    assert "dask" in b_lowered

    # A docstring already on the decorated function is kept too
    assert "  extra docstring\n\n" in Zap.f.__doc__
535
536
def test_derived_from_func():
    """Check ``derived_from`` when the original lives in a module."""
    import builtins

    @derived_from(builtins)
    def sum():
        "extra docstring"
        pass

    doc = sum.__doc__
    expected_fragments = (
        "extra docstring\n\n",
        "Return the sum of",
        "This docstring was copied from builtins.sum",
    )
    for fragment in expected_fragments:
        assert fragment in doc
548
549
def test_derived_from_dask_dataframe():
    """Check the derived docstring of ``dask.dataframe.DataFrame.dropna``."""
    dd = pytest.importorskip("dask.dataframe")

    doc = dd.DataFrame.dropna.__doc__
    assert "inconsistencies" in doc

    # Exactly one line documents ``axis`` and it must carry the marker
    [axis_arg] = [line for line in doc.split("\n") if "axis :" in line]
    lowered = axis_arg.lower()
    assert "not supported" in lowered
    assert "dask" in lowered
560
561
def test_parse_bytes():
    """Check ``parse_bytes`` on decimal, binary and unit-only inputs."""
    # (input, expected number of bytes)
    cases = [
        ("100", 100),
        ("100 MB", 100000000),
        ("100M", 100000000),
        ("5kB", 5000),
        ("5.4 kB", 5400),
        ("1kiB", 1024),
        ("1Mi", 2 ** 20),
        ("1e6", 1000000),
        ("1e6 kB", 1000000000),
        ("MB", 1000000),
        (123, 123),
        (".5GB", 500000000),
    ]
    for text, expected in cases:
        assert parse_bytes(text) == expected
575
576
def test_parse_timedelta():
    """Check ``parse_timedelta`` on strings, numbers, ``None`` and timedeltas."""
    # (input, expected number of seconds)
    cases = [
        ("1s", 1),
        ("100ms", 0.1),
        ("5S", 5),
        ("5.5s", 5.5),
        ("5.5 s", 5.5),
        ("1 second", 1),
        ("3.3 seconds", 3.3),
        ("3.3 milliseconds", 0.0033),
        ("3500 us", 0.0035),
        ("1 ns", 1e-9),
        ("2m", 120),
        ("2 minutes", 120),
        (None, None),
        (3, 3),
        (datetime.timedelta(seconds=2), 2),
        (datetime.timedelta(milliseconds=100), 0.1),
    ]
    for text, value in cases:
        result = parse_timedelta(text)
        # Exact match first (this also covers ``None``); otherwise allow
        # for floating-point rounding
        if result != value:
            assert abs(result - value) < 1e-14

    # (input, default unit, expected): ``default`` only applies when the
    # input carries no unit of its own
    default_cases = [
        ("1ms", "seconds", 0.001),
        ("1", "seconds", 1),
        ("1", "ms", 0.001),
        (1, "ms", 0.001),
    ]
    for text, unit, expected in default_cases:
        assert parse_timedelta(text, default=unit) == expected
603
604
def test_is_arraylike():
    """Check that ``is_arraylike`` accepts NumPy arrays and rejects builtins.

    The result is compared with ``is`` so that a merely truthy or falsy
    return value does not pass. Each non-array case is listed exactly once.
    """
    np = pytest.importorskip("numpy")

    # Scalars and builtin containers are not array-like
    for non_array in (0, (), [], [0]):
        assert is_arraylike(non_array) is False

    # NumPy arrays of rank 0, 1 and 2 are array-like, even when empty
    for shape in ((), (0,), (0, 0)):
        assert is_arraylike(np.empty(shape)) is True
617
618
def test_iter_chunks():
    """Check the grouping done by ``iter_chunks`` for three size limits."""
    sizes = [14, 8, 5, 9, 7, 9, 1, 19, 8, 19]

    # The sum of each expected group stays within the limit
    grouped_19 = list(iter_chunks(sizes, 19))
    assert grouped_19 == [
        [14],
        [8, 5],
        [9, 7],
        [9, 1],
        [19],
        [8],
        [19],
    ]

    grouped_28 = list(iter_chunks(sizes, 28))
    assert grouped_28 == [[14, 8, 5], [9, 7, 9, 1], [19, 8], [19]]

    grouped_67 = list(iter_chunks(sizes, 67))
    assert grouped_67 == [[14, 8, 5, 9, 7, 9, 1], [19, 8, 19]]
632
633
def test_stringify():
    """Check ``stringify`` on plain values, graphs and ``SubgraphCallable``."""
    # str and bytes objects come back as the very same object
    for text in ("Hello", b"Hello"):
        assert stringify(text) is text

    # Without ``exclusive`` the whole object is stringified; with an empty
    # ``exclusive`` the expected result equals the input
    simple = {"x": 1}
    assert stringify(simple) == str(simple)
    assert stringify(simple, exclusive=()) == simple

    single_task = {("x", 1): (inc, 1)}
    assert stringify(single_task) == str({("x", 1): (inc, 1)})
    assert stringify(single_task, exclusive=()) == {("x", 1): (inc, 1)}

    # With the graph as ``exclusive``, only the key reference in the task
    # arguments is expected to be stringified
    chained = {("x", 1): (inc, 1), ("x", 2): (inc, ("x", 1))}
    assert stringify(chained, exclusive=chained) == {
        ("x", 1): (inc, 1),
        ("x", 2): (inc, str(("x", 1))),
    }

    # A graph with stringified keys and values must compute the same results
    graphs = [
        {"x": 1},
        {("x", 1): (inc, 1), ("x", 2): (inc, ("x", 1))},
        {("x", 1): (sum, [1, 2, 3]), ("x", 2): (sum, [("x", 1), ("x", 1)])},
    ]
    for graph in graphs:
        string_graph = {
            stringify(k): stringify(v, exclusive=graph) for k, v in graph.items()
        }
        keys = list(graph)
        string_keys = [str(k) for k in keys]
        assert all(isinstance(k, str) for k in string_graph)
        assert get(graph, keys) == get(string_graph, string_keys)

    # Keys inside a SubgraphCallable and in its arguments are stringified too
    inner = SubgraphCallable({"x": ("y", 1)}, "x", (("y", 1),))
    nested = {("y", 1): (inner, (("z", 1),))}
    nested = stringify(nested, exclusive=set(nested) | {("z", 1)})
    task = nested[("y", 1)]
    assert task[0].dsk["x"] == "('y', 1)"
    assert task[1][0] == "('z', 1)"
670
671
def test_stringify_collection_keys():
    """Check which items ``stringify_collection_keys`` turns into strings."""
    text = "Hello"
    assert stringify_collection_keys(text) is text

    keys = [("a", 0), (b"a", 0), (1, 1)]
    res = stringify_collection_keys(keys)

    # The first two tuples are expected to be stringified
    for position in (0, 1):
        assert res[position] == str(keys[position])
    # ... while the last one is expected back unchanged
    assert res[2] == keys[2]
681
682
@pytest.mark.parametrize(
    ("n", "expect"),
    [
        (0, "0 B"),
        (920, "920 B"),
        (930, "0.91 kiB"),
        (921.23 * 2 ** 10, "921.23 kiB"),
        (931.23 * 2 ** 10, "0.91 MiB"),
        (921.23 * 2 ** 20, "921.23 MiB"),
        (931.23 * 2 ** 20, "0.91 GiB"),
        (921.23 * 2 ** 30, "921.23 GiB"),
        (931.23 * 2 ** 30, "0.91 TiB"),
        (921.23 * 2 ** 40, "921.23 TiB"),
        (931.23 * 2 ** 40, "0.91 PiB"),
        (2 ** 60, "1024.00 PiB"),
    ],
)
def test_format_bytes(n, expect):
    """Check the text ``format_bytes`` produces around each unit boundary."""
    # Parameters may be floats; the function is exercised with an int
    formatted = format_bytes(int(n))
    assert formatted == expect
702
703
def test_deprecated():
    """Check the default warning emitted by a ``_deprecated`` function."""

    @_deprecated()
    def foo():
        return "bar"

    with pytest.warns(FutureWarning) as caught:
        outcome = foo()

    # The wrapped function still returns its value
    assert outcome == "bar"

    # Exactly one warning, carrying the default wording
    assert len(caught) == 1
    text = str(caught[0].message)
    for fragment in ("foo is deprecated", "removed in a future release"):
        assert fragment in text
716
717
def test_deprecated_version():
    """Check that ``version=`` is mentioned in the deprecation warning."""

    @_deprecated(version="1.2.3")
    def foo():
        return "bar"

    with pytest.warns(FutureWarning, match="deprecated in version 1.2.3"):
        outcome = foo()
    assert outcome == "bar"
725
726
def test_deprecated_after_version():
    """Check that ``after_version=`` is mentioned in the deprecation warning."""

    @_deprecated(after_version="1.2.3")
    def foo():
        return "bar"

    with pytest.warns(FutureWarning, match="deprecated after version 1.2.3"):
        outcome = foo()
    assert outcome == "bar"
734
735
def test_deprecated_category():
    """Check that ``category=`` selects the warning class that is raised."""

    @_deprecated(category=DeprecationWarning)
    def foo():
        return "bar"

    with pytest.warns(DeprecationWarning):
        outcome = foo()
    assert outcome == "bar"
743
744
def test_deprecated_message():
    """Check that ``message=`` replaces the whole warning text."""

    @_deprecated(message="woohoo")
    def foo():
        return "bar"

    with pytest.warns(FutureWarning) as caught:
        outcome = foo()

    assert outcome == "bar"
    assert len(caught) == 1
    (only_warning,) = caught
    assert str(only_warning.message) == "woohoo"
755
756
def test_ignoring_deprecated():
    """Check that ``ignoring`` warns and points to ``contextlib.suppress``."""
    with pytest.warns(FutureWarning, match="contextlib.suppress"), ignoring(
        ValueError
    ):
        pass
761
762
def test_noop_context_deprecated():
    """Check that ``noop_context`` warns and points to ``contextlib.nullcontext``."""
    expected_hint = "contextlib.nullcontext"
    with pytest.warns(FutureWarning, match=expected_hint), noop_context():
        pass
767
768
def test_typename():
    """Check the long and the short form of ``typename`` for a class."""
    full_name = typename(HighLevelGraph)
    short_name = typename(HighLevelGraph, short=True)

    assert full_name == "dask.highlevelgraph.HighLevelGraph"
    assert short_name == "dask.HighLevelGraph"
772
773
class MyType:
    """Empty module-level class used by ``test_typename_on_instances``."""
776
777
def test_typename_on_instances():
    """Check that ``typename`` gives the same result for an instance and its class."""
    instance = MyType()
    from_instance = typename(instance)
    from_class = typename(MyType)
    assert from_instance == from_class
781