# Copyright 2018 John Reese
# Licensed under the MIT license

"""Tests for the async itertools equivalents exposed by ``aioitertools``.

Each test drives an ``ait.*`` iterator with ``ait.next`` and checks the
yielded values. For finite inputs, it also checks that exhaustion raises
``StopAsyncIteration``.
"""

import asyncio
import operator
from unittest import TestCase

import aioitertools as ait
from .helpers import async_test

slist = ["A", "B", "C"]
srange = range(1, 4)


class ItertoolsTest(TestCase):
    """Exercise each ``aioitertools`` function with several kinds of input.

    Inputs include plain iterables (lists, ranges, strings), async
    generators and, where noted, plain synchronous generators. Callbacks
    are either ordinary functions or coroutine functions.
    """

    @async_test
    async def test_accumulate_range_default(self):
        it = ait.accumulate(srange)
        for k in [1, 3, 6]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_accumulate_range_function(self):
        it = ait.accumulate(srange, func=operator.mul)
        for k in [1, 2, 6]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_accumulate_range_coroutine(self):
        async def mul(a, b):
            return a * b

        it = ait.accumulate(srange, func=mul)
        for k in [1, 2, 6]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_accumulate_gen_function(self):
        async def gen():
            yield 1
            yield 2
            yield 4

        it = ait.accumulate(gen(), func=operator.mul)
        for k in [1, 2, 8]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_accumulate_gen_coroutine(self):
        async def mul(a, b):
            return a * b

        async def gen():
            yield 1
            yield 2
            yield 4

        it = ait.accumulate(gen(), func=mul)
        for k in [1, 2, 8]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_accumulate_empty(self):
        values = []
        async for value in ait.accumulate([]):
            values.append(value)

        self.assertEqual(values, [])

    @async_test
    async def test_chain_lists(self):
        it = ait.chain(slist, srange)
        for k in ["A", "B", "C", 1, 2, 3]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_chain_list_gens(self):
        async def gen():
            for k in range(2, 9, 2):
                yield k

        it = ait.chain(slist, gen())
        for k in ["A", "B", "C", 2, 4, 6, 8]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_chain_from_iterable(self):
        async def gen():
            for k in range(2, 9, 2):
                yield k

        it = ait.chain.from_iterable([slist, gen()])
        for k in ["A", "B", "C", 2, 4, 6, 8]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_chain_from_iterable_parameter_expansion_gen(self):
        async def gen():
            for k in range(2, 9, 2):
                yield k

        # The outer iterable is itself an async generator of iterables.
        async def parameters_gen():
            yield slist
            yield gen()

        it = ait.chain.from_iterable(parameters_gen())
        for k in ["A", "B", "C", 2, 4, 6, 8]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_combinations(self):
        it = ait.combinations(range(4), 3)
        for k in [(0, 1, 2), (0, 1, 3), (0, 2, 3), (1, 2, 3)]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_combinations_with_replacement(self):
        it = ait.combinations_with_replacement(slist, 2)
        for k in [
            ("A", "A"),
            ("A", "B"),
            ("A", "C"),
            ("B", "B"),
            ("B", "C"),
            ("C", "C"),
        ]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_compress_list(self):
        data = range(10)
        selectors = [0, 1, 1, 0, 0, 0, 1, 0, 1, 0]

        it = ait.compress(data, selectors)
        for k in [1, 2, 6, 8]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_compress_gen(self):
        data = "abcdefghijkl"
        # Infinite selectors: compress must stop when `data` runs out.
        selectors = ait.cycle([1, 0, 0])

        it = ait.compress(data, selectors)
        for k in ["a", "d", "g", "j"]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    # count(), cycle() and the unbounded repeat() never terminate, so these
    # tests only check a finite prefix and do not expect StopAsyncIteration.

    @async_test
    async def test_count_bare(self):
        it = ait.count()
        for k in [0, 1, 2, 3]:
            self.assertEqual(await ait.next(it), k)

    @async_test
    async def test_count_start(self):
        it = ait.count(42)
        for k in [42, 43, 44, 45]:
            self.assertEqual(await ait.next(it), k)

    @async_test
    async def test_count_start_step(self):
        it = ait.count(42, 3)
        for k in [42, 45, 48, 51]:
            self.assertEqual(await ait.next(it), k)

    @async_test
    async def test_count_negative(self):
        it = ait.count(step=-2)
        for k in [0, -2, -4, -6]:
            self.assertEqual(await ait.next(it), k)

    @async_test
    async def test_cycle_list(self):
        it = ait.cycle(slist)
        for k in ["A", "B", "C", "A", "B", "C", "A", "B"]:
            self.assertEqual(await ait.next(it), k)

    @async_test
    async def test_cycle_gen(self):
        async def gen():
            yield 1
            yield 2
            yield 42

        it = ait.cycle(gen())
        for k in [1, 2, 42, 1, 2, 42, 1, 2]:
            self.assertEqual(await ait.next(it), k)

    @async_test
    async def test_dropwhile_empty(self):
        def pred(x):
            return x < 2

        result = await ait.list(ait.dropwhile(pred, []))
        self.assertEqual(result, [])

    @async_test
    async def test_dropwhile_function_list(self):
        def pred(x):
            return x < 2

        it = ait.dropwhile(pred, srange)
        for k in [2, 3]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_dropwhile_function_gen(self):
        def pred(x):
            return x < 2

        async def gen():
            yield 1
            yield 2
            yield 42

        it = ait.dropwhile(pred, gen())
        for k in [2, 42]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_dropwhile_coroutine_list(self):
        async def pred(x):
            return x < 2

        it = ait.dropwhile(pred, srange)
        for k in [2, 3]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_dropwhile_coroutine_gen(self):
        async def pred(x):
            return x < 2

        async def gen():
            yield 1
            yield 2
            yield 42

        it = ait.dropwhile(pred, gen())
        for k in [2, 42]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_filterfalse_function_list(self):
        def pred(x):
            return x % 2 == 0

        it = ait.filterfalse(pred, srange)
        for k in [1, 3]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_filterfalse_coroutine_list(self):
        async def pred(x):
            return x % 2 == 0

        it = ait.filterfalse(pred, srange)
        for k in [1, 3]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_groupby_list(self):
        data = "aaabba"

        # Groups are formed from consecutive runs, so "a" appears twice.
        it = ait.groupby(data)
        for k in [("a", ["a", "a", "a"]), ("b", ["b", "b"]), ("a", ["a"])]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_groupby_list_key(self):
        data = "aAabBA"

        it = ait.groupby(data, key=str.lower)
        for k in [("a", ["a", "A", "a"]), ("b", ["b", "B"]), ("a", ["A"])]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_groupby_gen(self):
        async def gen():
            for c in "aaabba":
                yield c

        it = ait.groupby(gen())
        for k in [("a", ["a", "a", "a"]), ("b", ["b", "b"]), ("a", ["a"])]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_groupby_gen_key(self):
        async def gen():
            for c in "aAabBA":
                yield c

        it = ait.groupby(gen(), key=str.lower)
        for k in [("a", ["a", "A", "a"]), ("b", ["b", "B"]), ("a", ["A"])]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_groupby_empty(self):
        async def gen():
            for _ in range(0):
                yield  # Force generator with no actual iteration

        async for _ in ait.groupby(gen()):
            self.fail("No iteration should have happened")

    @async_test
    async def test_islice_bad_range(self):
        with self.assertRaisesRegex(ValueError, "must pass stop index"):
            async for _ in ait.islice([1, 2]):
                pass

        with self.assertRaisesRegex(ValueError, "too many arguments"):
            async for _ in ait.islice([1, 2], 1, 2, 3, 4):
                pass

    @async_test
    async def test_islice_stop_zero(self):
        values = []
        async for value in ait.islice(range(5), 0):
            values.append(value)
        self.assertEqual(values, [])

    @async_test
    async def test_islice_range_stop(self):
        it = ait.islice(srange, 2)
        for k in [1, 2]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_islice_range_start_step(self):
        it = ait.islice(srange, 0, None, 2)
        for k in [1, 3]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_islice_range_start_stop(self):
        it = ait.islice(srange, 1, 3)
        for k in [2, 3]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_islice_range_start_stop_step(self):
        it = ait.islice(srange, 1, 3, 2)
        for k in [2]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_islice_gen_stop(self):
        async def gen():
            yield 1
            yield 2
            yield 3
            yield 4

        gen_it = gen()
        it = ait.islice(gen_it, 2)
        for k in [1, 2]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)
        # islice must not pull items past `stop` from the shared source.
        # Use assertEqual (not a bare assert) so the check survives `python -O`.
        self.assertEqual(await ait.list(gen_it), [3, 4])

    @async_test
    async def test_islice_gen_start_step(self):
        async def gen():
            yield 1
            yield 2
            yield 3
            yield 4

        it = ait.islice(gen(), 1, None, 2)
        for k in [2, 4]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_islice_gen_start_stop(self):
        async def gen():
            yield 1
            yield 2
            yield 3
            yield 4

        it = ait.islice(gen(), 1, 3)
        for k in [2, 3]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_islice_gen_start_stop_step(self):
        async def gen():
            yield 1
            yield 2
            yield 3
            yield 4

        gen_it = gen()
        it = ait.islice(gen_it, 1, 3, 2)
        for k in [2]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)
        # Only the items before `stop` may be consumed; 4 must remain.
        # Use assertEqual (not a bare assert) so the check survives `python -O`.
        self.assertEqual(await ait.list(gen_it), [4])

    @async_test
    async def test_permutations_list(self):
        it = ait.permutations(srange, r=2)
        for k in [(1, 2), (1, 3), (2, 1), (2, 3), (3, 1), (3, 2)]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_permutations_gen(self):
        async def gen():
            yield 1
            yield 2
            yield 3

        it = ait.permutations(gen(), r=2)
        for k in [(1, 2), (1, 3), (2, 1), (2, 3), (3, 1), (3, 2)]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_product_list(self):
        it = ait.product([1, 2], [6, 7])
        for k in [(1, 6), (1, 7), (2, 6), (2, 7)]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_product_gen(self):
        async def gen(x):
            yield x
            yield x + 1

        it = ait.product(gen(1), gen(6))
        for k in [(1, 6), (1, 7), (2, 6), (2, 7)]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_repeat(self):
        # Unbounded repeat: only a finite prefix is checked.
        it = ait.repeat(42)
        for k in [42] * 10:
            self.assertEqual(await ait.next(it), k)

    @async_test
    async def test_repeat_limit(self):
        it = ait.repeat(42, 5)
        for k in [42] * 5:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_starmap_function_list(self):
        data = [slist[:2], slist[1:], slist]

        def concat(*args):
            return "".join(args)

        it = ait.starmap(concat, data)
        for k in ["AB", "BC", "ABC"]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_starmap_function_gen(self):
        # A plain (synchronous) generator is used as the input here.
        def gen():
            yield slist[:2]
            yield slist[1:]
            yield slist

        def concat(*args):
            return "".join(args)

        it = ait.starmap(concat, gen())
        for k in ["AB", "BC", "ABC"]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_starmap_coroutine_list(self):
        data = [slist[:2], slist[1:], slist]

        async def concat(*args):
            return "".join(args)

        it = ait.starmap(concat, data)
        for k in ["AB", "BC", "ABC"]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_starmap_coroutine_gen(self):
        async def gen():
            yield slist[:2]
            yield slist[1:]
            yield slist

        async def concat(*args):
            return "".join(args)

        it = ait.starmap(concat, gen())
        for k in ["AB", "BC", "ABC"]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_takewhile_empty(self):
        def pred(x):
            return x < 3

        values = await ait.list(ait.takewhile(pred, []))
        self.assertEqual(values, [])

    @async_test
    async def test_takewhile_function_list(self):
        def pred(x):
            return x < 3

        it = ait.takewhile(pred, srange)
        for k in [1, 2]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_takewhile_function_gen(self):
        async def gen():
            yield 1
            yield 2
            yield 3

        def pred(x):
            return x < 3

        it = ait.takewhile(pred, gen())
        for k in [1, 2]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_takewhile_coroutine_list(self):
        async def pred(x):
            return x < 3

        it = ait.takewhile(pred, srange)
        for k in [1, 2]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_takewhile_coroutine_gen(self):
        # A plain (synchronous) generator is used as the input here.
        def gen():
            yield 1
            yield 2
            yield 3

        async def pred(x):
            return x < 3

        it = ait.takewhile(pred, gen())
        for k in [1, 2]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    # The tee tests advance every copy concurrently via asyncio.gather so
    # that the copies compete for the same underlying source.

    @async_test
    async def test_tee_list_two(self):
        it1, it2 = ait.tee(slist * 2)

        for k in slist * 2:
            a, b = await asyncio.gather(ait.next(it1), ait.next(it2))
            self.assertEqual(a, b)
            self.assertEqual(a, k)
            self.assertEqual(b, k)
        for it in [it1, it2]:
            with self.assertRaises(StopAsyncIteration):
                await ait.next(it)

    @async_test
    async def test_tee_list_six(self):
        itrs = ait.tee(slist * 2, n=6)

        for k in slist * 2:
            values = await asyncio.gather(*[ait.next(it) for it in itrs])
            for value in values:
                self.assertEqual(value, k)
        for it in itrs:
            with self.assertRaises(StopAsyncIteration):
                await ait.next(it)

    @async_test
    async def test_tee_gen_two(self):
        async def gen():
            yield 1
            yield 4
            yield 9
            yield 16

        it1, it2 = ait.tee(gen())

        for k in [1, 4, 9, 16]:
            a, b = await asyncio.gather(ait.next(it1), ait.next(it2))
            self.assertEqual(a, b)
            self.assertEqual(a, k)
            self.assertEqual(b, k)
        for it in [it1, it2]:
            with self.assertRaises(StopAsyncIteration):
                await ait.next(it)

    @async_test
    async def test_tee_gen_six(self):
        async def gen():
            yield 1
            yield 4
            yield 9
            yield 16

        itrs = ait.tee(gen(), n=6)

        for k in [1, 4, 9, 16]:
            values = await asyncio.gather(*[ait.next(it) for it in itrs])
            for value in values:
                self.assertEqual(value, k)
        for it in itrs:
            with self.assertRaises(StopAsyncIteration):
                await ait.next(it)

    @async_test
    async def test_tee_propagate_exception(self):
        class MyError(Exception):
            pass

        async def gen():
            yield 1
            yield 2
            raise MyError

        async def consumer(it):
            result = 0
            async for item in it:
                result += item
            return result

        it1, it2 = ait.tee(gen())

        # return_exceptions=True collects each consumer's exception as a
        # result instead of letting the first one abort the gather.
        values = await asyncio.gather(
            consumer(it1),
            consumer(it2),
            return_exceptions=True,
        )

        # Every tee copy must observe the source's exception.
        for value in values:
            self.assertIsInstance(value, MyError)

    @async_test
    async def test_zip_longest_range(self):
        a = range(3)
        b = range(5)

        it = ait.zip_longest(a, b)

        for k in [(0, 0), (1, 1), (2, 2), (None, 3), (None, 4)]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_zip_longest_fillvalue(self):
        async def gen():
            yield 1
            yield 4
            yield 9
            yield 16

        a = gen()
        b = range(5)

        it = ait.zip_longest(a, b, fillvalue=42)

        for k in [(1, 0), (4, 1), (9, 2), (16, 3), (42, 4)]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_zip_longest_exception(self):
        async def gen():
            yield 1
            yield 2
            raise Exception("fake error")

        a = gen()
        # Infinite partner: termination can only come from `a`'s exception.
        b = ait.repeat(5)

        it = ait.zip_longest(a, b)

        for k in [(1, 5), (2, 5)]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaisesRegex(Exception, "fake error"):
            await ait.next(it)