1# Copyright 2018 John Reese
2# Licensed under the MIT license
3
4import asyncio
5import operator
6from unittest import TestCase
7
8import aioitertools as ait
9from .helpers import async_test
10
# Module-level fixtures shared by the tests below: a list of three
# one-letter strings, and the integers 1, 2, 3 (``range`` excludes its stop).
slist = list("ABC")
srange = range(1, 4)
13
14
class ItertoolsTest(TestCase):
    """Tests for the async counterparts of :mod:`itertools` in aioitertools.

    Most tests build an iterator, pull from it one step at a time with
    ``ait.next`` and compare against an expected sequence via
    :meth:`_assert_yields`; a few collect results with ``async for`` or
    ``ait.list`` instead.
    """

    async def _assert_yields(self, it, expected, *, exhausted=True):
        """Assert that successive ``ait.next(it)`` calls produce ``expected``.

        :param it: the iterator under test.
        :param expected: the values ``it`` must produce, in order.
        :param exhausted: if true (the default), one further ``ait.next(it)``
            must raise ``StopAsyncIteration``. Pass ``False`` for infinite
            iterators (``count``, ``cycle``, unbounded ``repeat``) or when the
            caller checks what happens after ``expected`` itself.
        """
        for k in expected:
            self.assertEqual(await ait.next(it), k)
        if exhausted:
            with self.assertRaises(StopAsyncIteration):
                await ait.next(it)

    @async_test
    async def test_accumulate_range_default(self):
        # With no func, the expected values are the running sums of 1, 2, 3.
        it = ait.accumulate(srange)
        await self._assert_yields(it, [1, 3, 6])

    @async_test
    async def test_accumulate_range_function(self):
        it = ait.accumulate(srange, func=operator.mul)
        await self._assert_yields(it, [1, 2, 6])

    @async_test
    async def test_accumulate_range_coroutine(self):
        async def mul(a, b):
            return a * b

        it = ait.accumulate(srange, func=mul)
        await self._assert_yields(it, [1, 2, 6])

    @async_test
    async def test_accumulate_gen_function(self):
        async def gen():
            yield 1
            yield 2
            yield 4

        it = ait.accumulate(gen(), func=operator.mul)
        await self._assert_yields(it, [1, 2, 8])

    @async_test
    async def test_accumulate_gen_coroutine(self):
        async def mul(a, b):
            return a * b

        async def gen():
            yield 1
            yield 2
            yield 4

        it = ait.accumulate(gen(), func=mul)
        await self._assert_yields(it, [1, 2, 8])

    @async_test
    async def test_accumulate_empty(self):
        values = []
        async for value in ait.accumulate([]):
            values.append(value)

        self.assertEqual(values, [])

    @async_test
    async def test_chain_lists(self):
        it = ait.chain(slist, srange)
        await self._assert_yields(it, ["A", "B", "C", 1, 2, 3])

    @async_test
    async def test_chain_list_gens(self):
        async def gen():
            for k in range(2, 9, 2):
                yield k

        it = ait.chain(slist, gen())
        await self._assert_yields(it, ["A", "B", "C", 2, 4, 6, 8])

    @async_test
    async def test_chain_from_iterable(self):
        async def gen():
            for k in range(2, 9, 2):
                yield k

        it = ait.chain.from_iterable([slist, gen()])
        await self._assert_yields(it, ["A", "B", "C", 2, 4, 6, 8])

    @async_test
    async def test_chain_from_iterable_parameter_expansion_gen(self):
        async def gen():
            for k in range(2, 9, 2):
                yield k

        # The outer iterable is itself an async generator of iterables.
        async def parameters_gen():
            yield slist
            yield gen()

        it = ait.chain.from_iterable(parameters_gen())
        await self._assert_yields(it, ["A", "B", "C", 2, 4, 6, 8])

    @async_test
    async def test_combinations(self):
        it = ait.combinations(range(4), 3)
        await self._assert_yields(it, [(0, 1, 2), (0, 1, 3), (0, 2, 3), (1, 2, 3)])

    @async_test
    async def test_combinations_with_replacement(self):
        it = ait.combinations_with_replacement(slist, 2)
        await self._assert_yields(
            it,
            [
                ("A", "A"),
                ("A", "B"),
                ("A", "C"),
                ("B", "B"),
                ("B", "C"),
                ("C", "C"),
            ],
        )

    @async_test
    async def test_compress_list(self):
        data = range(10)
        selectors = [0, 1, 1, 0, 0, 0, 1, 0, 1, 0]

        it = ait.compress(data, selectors)
        await self._assert_yields(it, [1, 2, 6, 8])

    @async_test
    async def test_compress_gen(self):
        data = "abcdefghijkl"
        # The selectors never run out; iteration must still end with `data`.
        selectors = ait.cycle([1, 0, 0])

        it = ait.compress(data, selectors)
        await self._assert_yields(it, ["a", "d", "g", "j"])

    @async_test
    async def test_count_bare(self):
        it = ait.count()
        await self._assert_yields(it, [0, 1, 2, 3], exhausted=False)

    @async_test
    async def test_count_start(self):
        it = ait.count(42)
        await self._assert_yields(it, [42, 43, 44, 45], exhausted=False)

    @async_test
    async def test_count_start_step(self):
        it = ait.count(42, 3)
        await self._assert_yields(it, [42, 45, 48, 51], exhausted=False)

    @async_test
    async def test_count_negative(self):
        it = ait.count(step=-2)
        await self._assert_yields(it, [0, -2, -4, -6], exhausted=False)

    @async_test
    async def test_cycle_list(self):
        it = ait.cycle(slist)
        await self._assert_yields(
            it, ["A", "B", "C", "A", "B", "C", "A", "B"], exhausted=False
        )

    @async_test
    async def test_cycle_gen(self):
        async def gen():
            yield 1
            yield 2
            yield 42

        it = ait.cycle(gen())
        await self._assert_yields(it, [1, 2, 42, 1, 2, 42, 1, 2], exhausted=False)

    @async_test
    async def test_dropwhile_empty(self):
        def pred(x):
            return x < 2

        result = await ait.list(ait.dropwhile(pred, []))
        self.assertEqual(result, [])

    @async_test
    async def test_dropwhile_function_list(self):
        def pred(x):
            return x < 2

        it = ait.dropwhile(pred, srange)
        await self._assert_yields(it, [2, 3])

    @async_test
    async def test_dropwhile_function_gen(self):
        def pred(x):
            return x < 2

        async def gen():
            yield 1
            yield 2
            yield 42

        it = ait.dropwhile(pred, gen())
        await self._assert_yields(it, [2, 42])

    @async_test
    async def test_dropwhile_coroutine_list(self):
        async def pred(x):
            return x < 2

        it = ait.dropwhile(pred, srange)
        await self._assert_yields(it, [2, 3])

    @async_test
    async def test_dropwhile_coroutine_gen(self):
        async def pred(x):
            return x < 2

        async def gen():
            yield 1
            yield 2
            yield 42

        it = ait.dropwhile(pred, gen())
        await self._assert_yields(it, [2, 42])

    @async_test
    async def test_filterfalse_function_list(self):
        def pred(x):
            return x % 2 == 0

        it = ait.filterfalse(pred, srange)
        await self._assert_yields(it, [1, 3])

    @async_test
    async def test_filterfalse_coroutine_list(self):
        async def pred(x):
            return x % 2 == 0

        it = ait.filterfalse(pred, srange)
        await self._assert_yields(it, [1, 3])

    @async_test
    async def test_groupby_list(self):
        data = "aaabba"

        it = ait.groupby(data)
        await self._assert_yields(
            it, [("a", ["a", "a", "a"]), ("b", ["b", "b"]), ("a", ["a"])]
        )

    @async_test
    async def test_groupby_list_key(self):
        data = "aAabBA"

        it = ait.groupby(data, key=str.lower)
        await self._assert_yields(
            it, [("a", ["a", "A", "a"]), ("b", ["b", "B"]), ("a", ["A"])]
        )

    @async_test
    async def test_groupby_gen(self):
        async def gen():
            for c in "aaabba":
                yield c

        it = ait.groupby(gen())
        await self._assert_yields(
            it, [("a", ["a", "a", "a"]), ("b", ["b", "b"]), ("a", ["a"])]
        )

    @async_test
    async def test_groupby_gen_key(self):
        async def gen():
            for c in "aAabBA":
                yield c

        it = ait.groupby(gen(), key=str.lower)
        await self._assert_yields(
            it, [("a", ["a", "A", "a"]), ("b", ["b", "B"]), ("a", ["A"])]
        )

    @async_test
    async def test_groupby_empty(self):
        async def gen():
            for _ in range(0):
                yield  # Force generator with no actual iteration

        async for _ in ait.groupby(gen()):
            self.fail("No iteration should have happened")

    @async_test
    async def test_islice_bad_range(self):
        with self.assertRaisesRegex(ValueError, "must pass stop index"):
            async for _ in ait.islice([1, 2]):
                pass

        with self.assertRaisesRegex(ValueError, "too many arguments"):
            async for _ in ait.islice([1, 2], 1, 2, 3, 4):
                pass

    @async_test
    async def test_islice_stop_zero(self):
        values = []
        async for value in ait.islice(range(5), 0):
            values.append(value)
        self.assertEqual(values, [])

    @async_test
    async def test_islice_range_stop(self):
        it = ait.islice(srange, 2)
        await self._assert_yields(it, [1, 2])

    @async_test
    async def test_islice_range_start_step(self):
        it = ait.islice(srange, 0, None, 2)
        await self._assert_yields(it, [1, 3])

    @async_test
    async def test_islice_range_start_stop(self):
        it = ait.islice(srange, 1, 3)
        await self._assert_yields(it, [2, 3])

    @async_test
    async def test_islice_range_start_stop_step(self):
        it = ait.islice(srange, 1, 3, 2)
        await self._assert_yields(it, [2])

    @async_test
    async def test_islice_gen_stop(self):
        async def gen():
            yield 1
            yield 2
            yield 3
            yield 4

        gen_it = gen()
        it = ait.islice(gen_it, 2)
        await self._assert_yields(it, [1, 2])
        # islice must not pull anything from the source past the stop index.
        self.assertEqual(await ait.list(gen_it), [3, 4])

    @async_test
    async def test_islice_gen_start_step(self):
        async def gen():
            yield 1
            yield 2
            yield 3
            yield 4

        it = ait.islice(gen(), 1, None, 2)
        await self._assert_yields(it, [2, 4])

    @async_test
    async def test_islice_gen_start_stop(self):
        async def gen():
            yield 1
            yield 2
            yield 3
            yield 4

        it = ait.islice(gen(), 1, 3)
        await self._assert_yields(it, [2, 3])

    @async_test
    async def test_islice_gen_start_stop_step(self):
        async def gen():
            yield 1
            yield 2
            yield 3
            yield 4

        gen_it = gen()
        it = ait.islice(gen_it, 1, 3, 2)
        await self._assert_yields(it, [2])
        # Items before the stop index (1, 2, 3) are consumed; 4 must remain.
        self.assertEqual(await ait.list(gen_it), [4])

    @async_test
    async def test_permutations_list(self):
        it = ait.permutations(srange, r=2)
        await self._assert_yields(
            it, [(1, 2), (1, 3), (2, 1), (2, 3), (3, 1), (3, 2)]
        )

    @async_test
    async def test_permutations_gen(self):
        async def gen():
            yield 1
            yield 2
            yield 3

        it = ait.permutations(gen(), r=2)
        await self._assert_yields(
            it, [(1, 2), (1, 3), (2, 1), (2, 3), (3, 1), (3, 2)]
        )

    @async_test
    async def test_product_list(self):
        it = ait.product([1, 2], [6, 7])
        await self._assert_yields(it, [(1, 6), (1, 7), (2, 6), (2, 7)])

    @async_test
    async def test_product_gen(self):
        async def gen(x):
            yield x
            yield x + 1

        it = ait.product(gen(1), gen(6))
        await self._assert_yields(it, [(1, 6), (1, 7), (2, 6), (2, 7)])

    @async_test
    async def test_repeat(self):
        it = ait.repeat(42)
        await self._assert_yields(it, [42] * 10, exhausted=False)

    @async_test
    async def test_repeat_limit(self):
        it = ait.repeat(42, 5)
        await self._assert_yields(it, [42] * 5)

    @async_test
    async def test_starmap_function_list(self):
        data = [slist[:2], slist[1:], slist]

        def concat(*args):
            return "".join(args)

        it = ait.starmap(concat, data)
        await self._assert_yields(it, ["AB", "BC", "ABC"])

    @async_test
    async def test_starmap_function_gen(self):
        # Note: a plain synchronous generator, unlike most *_gen tests here.
        def gen():
            yield slist[:2]
            yield slist[1:]
            yield slist

        def concat(*args):
            return "".join(args)

        it = ait.starmap(concat, gen())
        await self._assert_yields(it, ["AB", "BC", "ABC"])

    @async_test
    async def test_starmap_coroutine_list(self):
        data = [slist[:2], slist[1:], slist]

        async def concat(*args):
            return "".join(args)

        it = ait.starmap(concat, data)
        await self._assert_yields(it, ["AB", "BC", "ABC"])

    @async_test
    async def test_starmap_coroutine_gen(self):
        async def gen():
            yield slist[:2]
            yield slist[1:]
            yield slist

        async def concat(*args):
            return "".join(args)

        it = ait.starmap(concat, gen())
        await self._assert_yields(it, ["AB", "BC", "ABC"])

    @async_test
    async def test_takewhile_empty(self):
        def pred(x):
            return x < 3

        values = await ait.list(ait.takewhile(pred, []))
        self.assertEqual(values, [])

    @async_test
    async def test_takewhile_function_list(self):
        def pred(x):
            return x < 3

        it = ait.takewhile(pred, srange)
        await self._assert_yields(it, [1, 2])

    @async_test
    async def test_takewhile_function_gen(self):
        async def gen():
            yield 1
            yield 2
            yield 3

        def pred(x):
            return x < 3

        it = ait.takewhile(pred, gen())
        await self._assert_yields(it, [1, 2])

    @async_test
    async def test_takewhile_coroutine_list(self):
        async def pred(x):
            return x < 3

        it = ait.takewhile(pred, srange)
        await self._assert_yields(it, [1, 2])

    @async_test
    async def test_takewhile_coroutine_gen(self):
        # Note: a plain synchronous generator, unlike most *_gen tests here.
        def gen():
            yield 1
            yield 2
            yield 3

        async def pred(x):
            return x < 3

        it = ait.takewhile(pred, gen())
        await self._assert_yields(it, [1, 2])

    @async_test
    async def test_tee_list_two(self):
        it1, it2 = ait.tee(slist * 2)

        # Advance both branches concurrently; each must see every item.
        for k in slist * 2:
            a, b = await asyncio.gather(ait.next(it1), ait.next(it2))
            self.assertEqual(a, b)
            self.assertEqual(a, k)
            self.assertEqual(b, k)
        for it in [it1, it2]:
            with self.assertRaises(StopAsyncIteration):
                await ait.next(it)

    @async_test
    async def test_tee_list_six(self):
        itrs = ait.tee(slist * 2, n=6)

        for k in slist * 2:
            values = await asyncio.gather(*[ait.next(it) for it in itrs])
            for value in values:
                self.assertEqual(value, k)
        for it in itrs:
            with self.assertRaises(StopAsyncIteration):
                await ait.next(it)

    @async_test
    async def test_tee_gen_two(self):
        async def gen():
            yield 1
            yield 4
            yield 9
            yield 16

        it1, it2 = ait.tee(gen())

        for k in [1, 4, 9, 16]:
            a, b = await asyncio.gather(ait.next(it1), ait.next(it2))
            self.assertEqual(a, b)
            self.assertEqual(a, k)
            self.assertEqual(b, k)
        for it in [it1, it2]:
            with self.assertRaises(StopAsyncIteration):
                await ait.next(it)

    @async_test
    async def test_tee_gen_six(self):
        async def gen():
            yield 1
            yield 4
            yield 9
            yield 16

        itrs = ait.tee(gen(), n=6)

        for k in [1, 4, 9, 16]:
            values = await asyncio.gather(*[ait.next(it) for it in itrs])
            for value in values:
                self.assertEqual(value, k)
        for it in itrs:
            with self.assertRaises(StopAsyncIteration):
                await ait.next(it)

    @async_test
    async def test_tee_propagate_exception(self):
        class MyError(Exception):
            pass

        async def gen():
            yield 1
            yield 2
            raise MyError

        async def consumer(it):
            result = 0
            async for item in it:
                result += item
            return result

        it1, it2 = ait.tee(gen())

        # return_exceptions=True collects each consumer's outcome, so both
        # branches can be checked for the source's exception.
        values = await asyncio.gather(
            consumer(it1),
            consumer(it2),
            return_exceptions=True,
        )

        for value in values:
            self.assertIsInstance(value, MyError)

    @async_test
    async def test_zip_longest_range(self):
        a = range(3)
        b = range(5)

        it = ait.zip_longest(a, b)
        await self._assert_yields(it, [(0, 0), (1, 1), (2, 2), (None, 3), (None, 4)])

    @async_test
    async def test_zip_longest_fillvalue(self):
        async def gen():
            yield 1
            yield 4
            yield 9
            yield 16

        a = gen()
        b = range(5)

        it = ait.zip_longest(a, b, fillvalue=42)
        await self._assert_yields(it, [(1, 0), (4, 1), (9, 2), (16, 3), (42, 4)])

    @async_test
    async def test_zip_longest_exception(self):
        async def gen():
            yield 1
            yield 2
            raise Exception("fake error")

        a = gen()
        b = ait.repeat(5)

        it = ait.zip_longest(a, b)

        # `b` is infinite, so the only way out is the error raised by `a`.
        await self._assert_yields(it, [(1, 5), (2, 5)], exhausted=False)
        with self.assertRaisesRegex(Exception, "fake error"):
            await ait.next(it)
770