import datetime
import functools
import operator
import pickle

import pytest
from tlz import curry

from dask import get
from dask.highlevelgraph import HighLevelGraph
from dask.optimization import SubgraphCallable
from dask.utils import (
    Dispatch,
    M,
    SerializableLock,
    _deprecated,
    asciitable,
    derived_from,
    ensure_dict,
    extra_titles,
    format_bytes,
    funcname,
    getargspec,
    has_keyword,
    ignoring,
    is_arraylike,
    itemgetter,
    iter_chunks,
    memory_repr,
    methodcaller,
    ndeepmap,
    noop_context,
    parse_bytes,
    parse_timedelta,
    partial_by_order,
    random_state_data,
    skip_doctest,
    stringify,
    stringify_collection_keys,
    takes_multiple_arguments,
    typename,
)
from dask.utils_test import inc


def test_getargspec():
    """getargspec reports argument names for functions, partials, wrappers and classes."""

    def func(x, y):
        pass

    assert getargspec(func).args == ["x", "y"]

    func2 = functools.partial(func, 2)
    # this is a bit of a lie, but maybe close enough
    assert getargspec(func2).args == ["x", "y"]

    def wrapper(*args, **kwargs):
        pass

    # The spec is expected to come from ``__wrapped__`` rather than the wrapper.
    wrapper.__wrapped__ = func
    assert getargspec(wrapper).args == ["x", "y"]

    class MyType:
        def __init__(self, x, y):
            pass

    assert getargspec(MyType).args == ["self", "x", "y"]


def test_takes_multiple_arguments():
    """takes_multiple_arguments distinguishes unary callables from n-ary ones."""
    assert takes_multiple_arguments(map)
    assert not takes_multiple_arguments(sum)

    def multi(a, b, c):
        return a, b, c

    class Singular:
        def __init__(self, a):
            pass

    class Multi:
        def __init__(self, a, b):
            pass

    assert takes_multiple_arguments(multi)
    assert not takes_multiple_arguments(Singular)
    assert takes_multiple_arguments(Multi)

    def f():
        pass

    assert not takes_multiple_arguments(f)

    def vararg(*args):
        pass

    assert takes_multiple_arguments(vararg)
    assert not takes_multiple_arguments(vararg, varargs=False)


def test_dispatch():
    """Dispatch selects a handler by argument type and exposes the ``object`` handler's doc."""
    foo = Dispatch()
    foo.register(int, lambda a: a + 1)
    foo.register(float, lambda a: a - 1)
    foo.register(tuple, lambda a: tuple(foo(i) for i in a))

    # NOTE: this docstring is compared against ``foo.__doc__`` below; do not edit it.
    def f(a):
        """My Docstring"""
        return a

    foo.register(object, f)

    class Bar:
        pass

    b = Bar()
    assert foo(1) == 2
    assert foo.dispatch(int)(1) == 2
    assert foo(1.0) == 0.0
    assert foo(b) == b
    assert foo((1, 2.0, b)) == (2, 1.0, b)

    assert foo.__doc__ == f.__doc__


def test_dispatch_kwargs():
    """Keyword arguments are forwarded to the dispatched handler."""
    foo = Dispatch()
    foo.register(int, lambda a, b=10: a + b)

    assert foo(1, b=20) == 21


def test_dispatch_variadic_on_first_argument():
    """Only the first positional argument's type selects the handler."""
    foo = Dispatch()
    foo.register(int, lambda a, b: a + b)
    foo.register(float, lambda a, b: a - b)

    assert foo(1, 2) == 3
    assert foo(1.0, 2.0) == -1


def test_dispatch_lazy():
    """A handler registered via ``register_lazy`` is found by ``dispatch``."""
    # this tests the recursive component of dispatch
    foo = Dispatch()
    foo.register(int, lambda a: a)

    import decimal

    # keep it outside lazy dec for test
    def foo_dec(a):
        return a + 1

    @foo.register_lazy("decimal")
    def register_decimal():
        import decimal

        foo.register(decimal.Decimal, foo_dec)

    # This test needs to be *before* any other calls
    assert foo.dispatch(decimal.Decimal) == foo_dec
    assert foo(decimal.Decimal(1)) == decimal.Decimal(2)
    assert foo(1) == 1


def test_dispatch_lazy_walks_mro():
    """Check that subclasses of classes with lazily registered handlers still
    use their parent class's handler by default"""
    import decimal

    class Lazy(decimal.Decimal):
        pass

    class Eager(Lazy):
        pass

    foo = Dispatch()

    @foo.register(Eager)
    def eager_handler(x):
        return "eager"

    def lazy_handler(a):
        return "lazy"

    @foo.register_lazy("decimal")
    def register_decimal():
        foo.register(decimal.Decimal, lazy_handler)

    assert foo.dispatch(Lazy) == lazy_handler
    assert foo(Lazy(1)) == "lazy"
    assert foo.dispatch(decimal.Decimal) == lazy_handler
    assert foo(decimal.Decimal(1)) == "lazy"
    assert foo.dispatch(Eager) == eager_handler
    assert foo(Eager(1)) == "eager"


def test_random_state_data():
    """random_state_data is reproducible for int and RandomState seeds alike."""
    np = pytest.importorskip("numpy")
    seed = 37
    state = np.random.RandomState(seed)
    n = 10000

    # Use an integer
    states = random_state_data(n, seed)
    assert len(states) == n

    # Use RandomState object
    states2 = random_state_data(n, state)
    for s1, s2 in zip(states, states2):
        assert s1.shape == (624,)
        assert (s1 == s2).all()

    # Consistent ordering
    states = random_state_data(10, 1234)
    states2 = random_state_data(20, 1234)[:10]

    for s1, s2 in zip(states, states2):
        assert (s1 == s2).all()


def test_memory_repr():
    """memory_repr renders successive powers of 1024 with the expected unit."""
    for power, mem_repr in enumerate(["1.0 bytes", "1.0 KB", "1.0 MB", "1.0 GB"]):
        assert memory_repr(1024 ** power) == mem_repr


def test_method_caller():
    """methodcaller instances are cached per name, picklable, and reachable via ``M``."""
    a = [1, 2, 3, 3, 3]
    f = methodcaller("count")
    assert f(a, 3) == a.count(3)
    # Identity checks: the same name must yield the very same object.
    assert methodcaller("count") is f
    assert M.count is f
    assert pickle.loads(pickle.dumps(f)) is f
    assert "count" in dir(M)

    assert "count" in str(methodcaller("count"))
    assert "count" in repr(methodcaller("count"))


def test_skip_doctest():
    """skip_doctest appends ``+SKIP`` to executable doctest lines only."""
    # The multi-line literals below start at column 0 on purpose: their
    # content is compared verbatim.
    example = """>>> xxx
>>>
>>> # comment
>>> xxx"""

    res = skip_doctest(example)
    assert (
        res
        == """>>> xxx  # doctest: +SKIP
>>>
>>> # comment
>>> xxx  # doctest: +SKIP"""
    )

    assert skip_doctest(None) == ""

    example = """
>>> 1 + 2  # doctest: +ELLIPSES
3"""

    expected = """
>>> 1 + 2  # doctest: +ELLIPSES, +SKIP
3"""
    res = skip_doctest(example)
    assert res == expected


def test_extra_titles():
    """A repeated section title is renamed to ``Extra <title>`` with a matching underline."""
    example = """

    Notes
    -----
    hello

    Foo
    ---

    Notes
    -----
    bar
    """

    expected = """

    Notes
    -----
    hello

    Foo
    ---

    Extra Notes
    -----------
    bar
    """

    assert extra_titles(example) == expected


def test_asciitable():
    """asciitable draws a bordered table with cells padded to the column width."""
    res = asciitable(
        ["fruit", "color"],
        [("apple", "red"), ("banana", "yellow"), ("tomato", "red"), ("pear", "green")],
    )
    # Cell padding is significant: the comparison is on the exact string.
    assert res == (
        "+--------+--------+\n"
        "| fruit  | color  |\n"
        "+--------+--------+\n"
        "| apple  | red    |\n"
        "| banana | yellow |\n"
        "| tomato | red    |\n"
        "| pear   | green  |\n"
        "+--------+--------+"
    )


def test_SerializableLock():
    """Pickled copies of a SerializableLock contend with each other; distinct locks do not."""
    a = SerializableLock()
    b = SerializableLock()
    with a:
        pass

    with a:
        with b:
            pass

    with a:
        assert not a.acquire(False)

    a2 = pickle.loads(pickle.dumps(a))
    a3 = pickle.loads(pickle.dumps(a))
    a4 = pickle.loads(pickle.dumps(a2))

    # Every round-tripped copy of ``a`` must block every other copy.
    for x in [a, a2, a3, a4]:
        for y in [a, a2, a3, a4]:
            with x:
                assert not y.acquire(False)

    b2 = pickle.loads(pickle.dumps(b))
    b3 = pickle.loads(pickle.dumps(b2))

    # Copies of ``a`` and copies of ``b`` can be nested in either order.
    for x in [a, a2, a3, a4]:
        for y in [b, b2, b3]:
            with x:
                with y:
                    pass
            with y:
                with x:
                    pass


def test_SerializableLock_name_collision():
    """Locks created with the same token share one underlying lock object."""
    a = SerializableLock("a")
    b = SerializableLock("b")
    c = SerializableLock("a")
    d = SerializableLock()

    assert a.lock is not b.lock
    assert a.lock is c.lock
    assert d.lock not in (a.lock, b.lock, c.lock)


def test_SerializableLock_locked():
    """``locked()`` reflects whether the lock is currently held."""
    a = SerializableLock("a")
    assert not a.locked()
    with a:
        assert a.locked()
    assert not a.locked()


def test_SerializableLock_acquire_blocking():
    """``acquire`` accepts ``blocking`` as a keyword argument."""
    a = SerializableLock("a")
    assert a.acquire(blocking=True)
    assert not a.acquire(blocking=False)
    a.release()


def test_funcname():
    """funcname names plain functions, partials, method callers, lambdas and classes."""

    def foo(a, b, c):
        pass

    assert funcname(foo) == "foo"
    assert funcname(functools.partial(foo, a=1)) == "foo"
    assert funcname(M.sum) == "sum"
    assert funcname(lambda: 1) == "lambda"

    class Foo:
        pass

    assert funcname(Foo) == "Foo"
    assert "Foo" in funcname(Foo())


def test_funcname_long():
    """Very long function names are shortened to fewer than 60 characters."""

    def a_long_function_name_11111111111111111111111111111111111111111111111():
        pass

    result = funcname(
        a_long_function_name_11111111111111111111111111111111111111111111111
    )
    assert "a_long_function_name" in result
    assert len(result) < 60


def test_funcname_toolz():
    """funcname sees through ``curry`` wrappers, including partially applied ones."""

    @curry
    def foo(a, b, c):
        pass

    assert funcname(foo) == "foo"
    assert funcname(foo(1)) == "foo"


def test_funcname_multipledispatch():
    """funcname handles multipledispatch functions and partials of them."""
    md = pytest.importorskip("multipledispatch")

    @md.dispatch(int, int, int)
    def foo(a, b, c):
        pass

    assert funcname(foo) == "foo"
    assert funcname(functools.partial(foo, a=1)) == "foo"


def test_funcname_numpy_vectorize():
    """funcname prefixes ``vectorize_`` for ``np.vectorize`` objects."""
    np = pytest.importorskip("numpy")

    vfunc = np.vectorize(int)
    assert funcname(vfunc) == "vectorize_int"

    # Regression test for https://github.com/pydata/xarray/issues/3303
    # Partial functions don't have a __name__ attribute
    func = functools.partial(np.add, out=None)
    vfunc = np.vectorize(func)
    assert funcname(vfunc) == "vectorize_add"


def test_ndeepmap():
    """ndeepmap applies a function at the requested nesting depth."""
    L = 1
    assert ndeepmap(0, inc, L) == 2

    L = [1]
    assert ndeepmap(0, inc, L) == 2

    L = [1, 2, 3]
    assert ndeepmap(1, inc, L) == [2, 3, 4]

    L = [[1, 2], [3, 4]]
    assert ndeepmap(2, inc, L) == [[2, 3], [4, 5]]

    L = [[[1, 2], [3, 4, 5]], [[6], []]]
    assert ndeepmap(3, inc, L) == [[[2, 3], [4, 5, 6]], [[7], []]]


def test_ensure_dict():
    """ensure_dict returns plain dicts, passing real dicts through unless ``copy=True``."""
    d = {"x": 1}
    assert ensure_dict(d) is d

    class mydict(dict):
        pass

    d2 = ensure_dict(d, copy=True)
    d3 = ensure_dict(HighLevelGraph.from_collections("x", d))
    d4 = ensure_dict(mydict(d))

    for di in (d2, d3, d4):
        assert type(di) is dict
        assert di is not d
        assert di == d


def test_itemgetter():
    """itemgetter survives pickling and compares equal by index."""
    data = [1, 2, 3]
    g = itemgetter(1)
    assert g(data) == 2
    g2 = pickle.loads(pickle.dumps(g))
    assert g2(data) == 2
    assert g2.index == 1

    assert itemgetter(1) == itemgetter(1)
    assert itemgetter(1) != itemgetter(2)
    assert itemgetter(1) != 123


def test_partial_by_order():
    """partial_by_order inserts ``other`` arguments at the given positions."""
    assert partial_by_order(5, function=operator.add, other=[(1, 20)]) == 25


def test_has_keyword():
    """has_keyword finds parameters by name, including on partials."""

    def foo(a, b, c=None):
        pass

    assert has_keyword(foo, "a")
    assert has_keyword(foo, "b")
    assert has_keyword(foo, "c")

    bar = functools.partial(foo, a=1)
    assert has_keyword(bar, "b")
    assert has_keyword(bar, "c")


def test_derived_from():
    """derived_from copies the original docstring and annotates unsupported arguments."""
    # The inner methods deliberately have exactly the docstrings shown:
    # ``derived_from`` reads them, so nothing may be added to them.

    class Foo:
        def f(a, b):
            """A super docstring

            An explanation

            Parameters
            ----------
            a: int
                an explanation of a
            b: float
                an explanation of b
            """

    class Bar:
        @derived_from(Foo)
        def f(a, c):
            pass

    class Zap:
        @derived_from(Foo)
        def f(a, c):
            "extra docstring"
            pass

    assert Bar.f.__doc__.strip().startswith("A super docstring")
    assert "Foo.f" in Bar.f.__doc__
    assert any("inconsistencies" in line for line in Bar.f.__doc__.split("\n")[:7])

    [b_arg] = [line for line in Bar.f.__doc__.split("\n") if "b:" in line]
    assert "not supported" in b_arg.lower()
    assert "dask" in b_arg.lower()

    assert " extra docstring\n\n" in Zap.f.__doc__


def test_derived_from_func():
    """derived_from also works on a module-level function looked up by name."""
    import builtins

    # The function must be called ``sum`` so it is matched to ``builtins.sum``.
    @derived_from(builtins)
    def sum():
        "extra docstring"
        pass

    assert "extra docstring\n\n" in sum.__doc__
    assert "Return the sum of" in sum.__doc__
    assert "This docstring was copied from builtins.sum" in sum.__doc__


def test_derived_from_dask_dataframe():
    """``DataFrame.dropna`` carries the copied-docstring notes."""
    dd = pytest.importorskip("dask.dataframe")

    assert "inconsistencies" in dd.DataFrame.dropna.__doc__

    [axis_arg] = [
        line for line in dd.DataFrame.dropna.__doc__.split("\n") if "axis :" in line
    ]
    assert "not supported" in axis_arg.lower()
    assert "dask" in axis_arg.lower()


def test_parse_bytes():
    """parse_bytes understands decimal and binary suffixes, floats and bare ints."""
    assert parse_bytes("100") == 100
    assert parse_bytes("100 MB") == 100000000
    assert parse_bytes("100M") == 100000000
    assert parse_bytes("5kB") == 5000
    assert parse_bytes("5.4 kB") == 5400
    assert parse_bytes("1kiB") == 1024
    assert parse_bytes("1Mi") == 2 ** 20
    assert parse_bytes("1e6") == 1000000
    assert parse_bytes("1e6 kB") == 1000000000
    assert parse_bytes("MB") == 1000000
    assert parse_bytes(123) == 123
    assert parse_bytes(".5GB") == 500000000


def test_parse_timedelta():
    """parse_timedelta converts strings, numbers and timedeltas to seconds."""
    for text, value in [
        ("1s", 1),
        ("100ms", 0.1),
        ("5S", 5),
        ("5.5s", 5.5),
        ("5.5 s", 5.5),
        ("1 second", 1),
        ("3.3 seconds", 3.3),
        ("3.3 milliseconds", 0.0033),
        ("3500 us", 0.0035),
        ("1 ns", 1e-9),
        ("2m", 120),
        ("2 minutes", 120),
        (None, None),
        (3, 3),
        (datetime.timedelta(seconds=2), 2),
        (datetime.timedelta(milliseconds=100), 0.1),
    ]:
        result = parse_timedelta(text)
        # Exact match first (this also covers ``None``); otherwise allow
        # floating-point rounding noise.
        assert result == value or abs(result - value) < 1e-14

    assert parse_timedelta("1ms", default="seconds") == 0.001
    assert parse_timedelta("1", default="seconds") == 1
    assert parse_timedelta("1", default="ms") == 0.001
    assert parse_timedelta(1, default="ms") == 0.001


def test_is_arraylike():
    """is_arraylike is True for numpy arrays and False for scalars and builtin sequences."""
    np = pytest.importorskip("numpy")

    assert is_arraylike(0) is False
    assert is_arraylike(()) is False
    assert is_arraylike([]) is False
    assert is_arraylike([0]) is False

    assert is_arraylike(np.empty(())) is True
    assert is_arraylike(np.empty((0,))) is True
    assert is_arraylike(np.empty((0, 0))) is True


def test_iter_chunks():
    """iter_chunks groups consecutive sizes without exceeding the given maximum."""
    sizes = [14, 8, 5, 9, 7, 9, 1, 19, 8, 19]
    assert list(iter_chunks(sizes, 19)) == [
        [14],
        [8, 5],
        [9, 7],
        [9, 1],
        [19],
        [8],
        [19],
    ]
    assert list(iter_chunks(sizes, 28)) == [[14, 8, 5], [9, 7, 9, 1], [19, 8], [19]]
    assert list(iter_chunks(sizes, 67)) == [[14, 8, 5, 9, 7, 9, 1], [19, 8, 19]]


def test_stringify():
    """stringify converts keys to strings, honouring the ``exclusive`` key set."""
    obj = "Hello"
    assert stringify(obj) is obj
    obj = b"Hello"
    assert stringify(obj) is obj
    dsk = {"x": 1}

    assert stringify(dsk) == str(dsk)
    assert stringify(dsk, exclusive=()) == dsk

    dsk = {("x", 1): (inc, 1)}
    assert stringify(dsk) == str({("x", 1): (inc, 1)})
    assert stringify(dsk, exclusive=()) == {("x", 1): (inc, 1)}

    dsk = {("x", 1): (inc, 1), ("x", 2): (inc, ("x", 1))}
    assert stringify(dsk, exclusive=dsk) == {
        ("x", 1): (inc, 1),
        ("x", 2): (inc, str(("x", 1))),
    }

    dsks = [
        {"x": 1},
        {("x", 1): (inc, 1), ("x", 2): (inc, ("x", 1))},
        {("x", 1): (sum, [1, 2, 3]), ("x", 2): (sum, [("x", 1), ("x", 1)])},
    ]
    # A fully stringified graph must compute the same results as the original.
    for dsk in dsks:
        sdsk = {stringify(k): stringify(v, exclusive=dsk) for k, v in dsk.items()}
        keys = list(dsk)
        skeys = [str(k) for k in keys]
        assert all(isinstance(k, str) for k in sdsk)
        assert get(dsk, keys) == get(sdsk, skeys)

    dsk = {("y", 1): (SubgraphCallable({"x": ("y", 1)}, "x", (("y", 1),)), (("z", 1),))}
    dsk = stringify(dsk, exclusive=set(dsk) | {("z", 1)})
    assert dsk[("y", 1)][0].dsk["x"] == "('y', 1)"
    assert dsk[("y", 1)][1][0] == "('z', 1)"


def test_stringify_collection_keys():
    """Only keys that look like ``(str-or-bytes, ...)`` tuples are stringified."""
    obj = "Hello"
    assert stringify_collection_keys(obj) is obj

    obj = [("a", 0), (b"a", 0), (1, 1)]
    res = stringify_collection_keys(obj)
    assert res[0] == str(obj[0])
    assert res[1] == str(obj[1])
    assert res[2] == obj[2]


@pytest.mark.parametrize(
    "n,expect",
    [
        (0, "0 B"),
        (920, "920 B"),
        (930, "0.91 kiB"),
        (921.23 * 2 ** 10, "921.23 kiB"),
        (931.23 * 2 ** 10, "0.91 MiB"),
        (921.23 * 2 ** 20, "921.23 MiB"),
        (931.23 * 2 ** 20, "0.91 GiB"),
        (921.23 * 2 ** 30, "921.23 GiB"),
        (931.23 * 2 ** 30, "0.91 TiB"),
        (921.23 * 2 ** 40, "921.23 TiB"),
        (931.23 * 2 ** 40, "0.91 PiB"),
        (2 ** 60, "1024.00 PiB"),
    ],
)
def test_format_bytes(n, expect):
    """format_bytes renders an integer byte count using binary (kiB, MiB, ...) units."""
    assert format_bytes(int(n)) == expect


def test_deprecated():
    """``_deprecated`` emits a single FutureWarning naming the function by default."""

    @_deprecated()
    def foo():
        return "bar"

    with pytest.warns(FutureWarning) as record:
        assert foo() == "bar"

    assert len(record) == 1
    msg = str(record[0].message)
    assert "foo is deprecated" in msg
    assert "removed in a future release" in msg


def test_deprecated_version():
    """The ``version`` argument is mentioned in the warning message."""

    @_deprecated(version="1.2.3")
    def foo():
        return "bar"

    with pytest.warns(FutureWarning, match="deprecated in version 1.2.3"):
        assert foo() == "bar"


def test_deprecated_after_version():
    """The ``after_version`` argument is mentioned in the warning message."""

    @_deprecated(after_version="1.2.3")
    def foo():
        return "bar"

    with pytest.warns(FutureWarning, match="deprecated after version 1.2.3"):
        assert foo() == "bar"


def test_deprecated_category():
    """The ``category`` argument selects the warning class."""

    @_deprecated(category=DeprecationWarning)
    def foo():
        return "bar"

    with pytest.warns(DeprecationWarning):
        assert foo() == "bar"


def test_deprecated_message():
    """A custom ``message`` replaces the generated warning text entirely."""

    @_deprecated(message="woohoo")
    def foo():
        return "bar"

    with pytest.warns(FutureWarning) as record:
        assert foo() == "bar"

    assert len(record) == 1
    assert str(record[0].message) == "woohoo"


def test_ignoring_deprecated():
    """``ignoring`` warns and points users at ``contextlib.suppress``."""
    with pytest.warns(FutureWarning, match="contextlib.suppress"):
        with ignoring(ValueError):
            pass


def test_noop_context_deprecated():
    """``noop_context`` warns and points users at ``contextlib.nullcontext``."""
    with pytest.warns(FutureWarning, match="contextlib.nullcontext"):
        with noop_context():
            pass


def test_typename():
    """typename returns the dotted path, or a shortened form with ``short=True``."""
    assert typename(HighLevelGraph) == "dask.highlevelgraph.HighLevelGraph"
    assert typename(HighLevelGraph, short=True) == "dask.HighLevelGraph"


# Module-level helper type used by ``test_typename_on_instances``.
class MyType:
    pass


def test_typename_on_instances():
    """typename gives the same result for an instance as for its class."""
    instance = MyType()
    assert typename(instance) == typename(MyType)