1# -*- coding: utf-8 -*-
2
3"""
4The MIT License (MIT)
5
6Copyright (c) 2015-present Rapptz
7
8Permission is hereby granted, free of charge, to any person obtaining a
9copy of this software and associated documentation files (the "Software"),
10to deal in the Software without restriction, including without limitation
11the rights to use, copy, modify, merge, publish, distribute, sublicense,
12and/or sell copies of the Software, and to permit persons to whom the
13Software is furnished to do so, subject to the following conditions:
14
15The above copyright notice and this permission notice shall be included in
16all copies or substantial portions of the Software.
17
18THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
19OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
23FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
24DEALINGS IN THE SOFTWARE.
25"""
26
27import asyncio
28import datetime
29
30from .errors import NoMoreItems
31from .utils import time_snowflake, maybe_coroutine
32from .object import Object
33from .audit_logs import AuditLogEntry
34
# Sentinel with snowflake id 0; the iterators below use it as the default
# ``after`` bound when no lower bound is given.
OLDEST_OBJECT = Object(id=0)
36
class _AsyncIterator:
    """Common base for the paginated async iterators in this module.

    Subclasses supply an ``async def next()`` coroutine that returns the next
    item, or raises :exc:`NoMoreItems` once exhausted. Every helper here
    (searching, chunking, mapping, filtering, flattening and the ``async for``
    protocol) is built on that single primitive.
    """
    __slots__ = ()

    def get(self, **attrs):
        """Return a coroutine resolving to the first element whose attributes
        equal ``attrs``; ``a__b`` keys look up the nested attribute ``a.b``."""
        def _matches(elem):
            for path, expected in attrs.items():
                current = elem
                for part in path.split('__'):
                    current = getattr(current, part)
                if current != expected:
                    return False
            return True

        return self.find(_matches)

    async def find(self, predicate):
        """Return the first element for which ``predicate`` (sync or async) is
        truthy, or ``None`` if the iterator runs out first."""
        while True:
            try:
                candidate = await self.next()
            except NoMoreItems:
                return None
            if await maybe_coroutine(predicate, candidate):
                return candidate

    def chunk(self, max_size):
        """Wrap this iterator so it yields lists of up to ``max_size`` items."""
        if max_size <= 0:
            raise ValueError('async iterator chunk sizes must be greater than 0.')
        return _ChunkedAsyncIterator(self, max_size)

    def map(self, func):
        """Wrap this iterator so ``func`` (sync or async) is applied to each item."""
        return _MappedAsyncIterator(self, func)

    def filter(self, predicate):
        """Wrap this iterator so only items passing ``predicate`` are yielded."""
        return _FilteredAsyncIterator(self, predicate)

    async def flatten(self):
        """Drain the iterator and return every remaining item in a list."""
        collected = []
        while True:
            try:
                collected.append(await self.next())
            except NoMoreItems:
                return collected

    def __aiter__(self):
        return self

    async def __anext__(self):
        # translate this module's exhaustion signal into the async-for protocol
        try:
            return await self.next()
        except NoMoreItems:
            raise StopAsyncIteration()
96
97def _identity(x):
98    return x
99
class _ChunkedAsyncIterator(_AsyncIterator):
    """Groups the items of ``iterator`` into lists of at most ``max_size``."""

    def __init__(self, iterator, max_size):
        self.iterator = iterator
        self.max_size = max_size

    async def next(self):
        chunk = []
        while len(chunk) < self.max_size:
            try:
                chunk.append(await self.iterator.next())
            except NoMoreItems:
                # a partial final chunk is still handed out; only an empty
                # chunk signals exhaustion to the caller
                if not chunk:
                    raise
                return chunk
        return chunk
119
class _MappedAsyncIterator(_AsyncIterator):
    """Applies ``func`` (sync or async) to every item of ``iterator``."""

    def __init__(self, iterator, func):
        self.iterator = iterator
        self.func = func

    async def next(self):
        # NoMoreItems raised by the wrapped iterator is deliberately not caught:
        # it propagates and signals exhaustion to our own caller.
        source_item = await self.iterator.next()
        pending = maybe_coroutine(self.func, source_item)
        return await pending
129
class _FilteredAsyncIterator(_AsyncIterator):
    """Yields only the items of ``iterator`` for which ``predicate`` is truthy.

    A ``None`` predicate falls back to the item's own truthiness.
    """

    def __init__(self, iterator, predicate):
        self.iterator = iterator
        self.predicate = _identity if predicate is None else predicate

    async def next(self):
        fetch = self.iterator.next
        keep = self.predicate
        while True:
            # NoMoreItems from fetch() propagates, as in _MappedAsyncIterator
            candidate = await fetch()
            if await maybe_coroutine(keep, candidate):
                return candidate
148
class ReactionIterator(_AsyncIterator):
    """Iterator over the users who reacted to a message with a given emoji.

    Users are requested in pages of at most 100 via ``get_reaction_users``,
    using ``after`` to continue from the last user of the previous page. When
    the message's guild is a real guild (not ``None`` or a bare
    :class:`Object`), the guild's member object from ``guild.get_member`` is
    yielded when available; otherwise a :class:`User` is built from the payload.

    Parameters
    -----------
    message:
        The message whose reactions are listed.
    emoji:
        The emoji, passed unchanged to ``get_reaction_users``.
    limit: :class:`int`
        Maximum number of users to retrieve.
    after: Optional[:class:`abc.Snowflake`]
        Pagination cursor; only users after this one are requested.
    """

    def __init__(self, message, emoji, limit=100, after=None):
        self.message = message
        self.limit = limit
        self.after = after
        state = message._state
        self.getter = state.http.get_reaction_users
        self.state = state
        self.emoji = emoji
        self.guild = message.guild
        self.channel_id = message.channel.id
        self.users = asyncio.Queue()

    async def next(self):
        if self.users.empty():
            await self.fill_users()

        try:
            return self.users.get_nowait()
        except asyncio.QueueEmpty:
            raise NoMoreItems()

    async def fill_users(self):
        """Fetch the next page of users (if any remain) into ``self.users``."""
        # this is a hack because >circular imports<
        from .user import User

        if self.limit <= 0:
            return

        retrieve = min(self.limit, 100)

        after = self.after.id if self.after else None
        data = await self.getter(self.channel_id, self.message.id, self.emoji, retrieve, after=after)

        if data:
            self.limit -= retrieve
            self.after = Object(id=int(data[-1]['id']))

        if len(data) < retrieve:
            # A short page means nothing is left; without this the next fill
            # would issue one more request only to get an empty page back.
            self.limit = 0

        guild = self.guild
        use_members = guild is not None and not isinstance(guild, Object)
        for element in reversed(data):
            member = guild.get_member(int(element['id'])) if use_members else None
            if member is not None:
                await self.users.put(member)
            else:
                await self.users.put(User(state=self.state, data=element))
196
class HistoryIterator(_AsyncIterator):
    """Iterator for receiving a channel's message history.

    The messages endpoint has two behaviours we care about here:
    If ``before`` is specified, the messages endpoint returns the ``limit``
    newest messages before ``before``, sorted with newest first. For filling over
    100 messages, update the ``before`` parameter to the oldest message received.
    Messages will be returned in order by time.
    If ``after`` is specified, it returns the ``limit`` oldest messages after
    ``after``, sorted with newest first. For filling over 100 messages, update the
    ``after`` parameter to the newest message received. If messages are not
    reversed, they will be out of order (99-0, 199-100, so on).

    Note that if both ``before`` and ``after`` are specified, ``before`` is ignored by the
    messages endpoint. This iterator therefore sends only the bound it paginates
    with and applies the other one as a client-side filter on each page.

    Parameters
    -----------
    messageable: :class:`abc.Messageable`
        Messageable class to retrieve message history from.
    limit: Optional[:class:`int`]
        Maximum number of messages to retrieve. ``None`` retrieves every
        message; it is not allowed together with ``around``.
    before: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]]
        Message before which all messages must be.
    after: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]]
        Message after which all messages must be.
    around: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]]
        Message around which all messages must be. Limit max 101. Note that if
        limit is an even number, this will return at most limit+1 messages.
    oldest_first: Optional[:class:`bool`]
        If set to ``True``, return messages in oldest->newest order. Defaults to
        ``True`` if ``after`` is specified, otherwise ``False``.
    """

    def __init__(self, messageable, limit,
                 before=None, after=None, around=None, oldest_first=None):

        if isinstance(before, datetime.datetime):
            before = Object(id=time_snowflake(before, high=False))
        if isinstance(after, datetime.datetime):
            after = Object(id=time_snowflake(after, high=True))
        if isinstance(around, datetime.datetime):
            around = Object(id=time_snowflake(around))

        if oldest_first is None:
            self.reverse = after is not None
        else:
            self.reverse = oldest_first

        self.messageable = messageable
        self.limit = limit
        self.before = before
        self.after = after or OLDEST_OBJECT
        self.around = around

        self._filter = None  # message dict -> bool

        self.state = self.messageable._state
        self.logs_from = self.state.http.logs_from
        self.messages = asyncio.Queue()

        if self.around:
            if self.limit is None:
                raise ValueError('history does not support around with limit=None')
            if self.limit > 101:
                raise ValueError("history max limit 101 when specifying around parameter")
            elif self.limit == 101:
                self.limit = 100  # Thanks discord

            self._retrieve_messages = self._retrieve_messages_around_strategy
            if self.before and self.after:
                self._filter = lambda m: self.after.id < int(m['id']) < self.before.id
            elif self.before:
                self._filter = lambda m: int(m['id']) < self.before.id
            elif self.after:
                self._filter = lambda m: self.after.id < int(m['id'])
        else:
            if self.reverse:
                self._retrieve_messages = self._retrieve_messages_after_strategy
                if self.before:
                    self._filter = lambda m: int(m['id']) < self.before.id
            else:
                self._retrieve_messages = self._retrieve_messages_before_strategy
                if self.after and self.after != OLDEST_OBJECT:
                    self._filter = lambda m: int(m['id']) > self.after.id

    async def next(self):
        if self.messages.empty():
            await self.fill_messages()

        try:
            return self.messages.get_nowait()
        except asyncio.QueueEmpty:
            raise NoMoreItems()

    def _get_retrieve(self):
        """Set ``self.retrieve`` to the next page size (max 100); return whether it is > 0."""
        l = self.limit
        if l is None or l > 100:
            r = 100
        else:
            r = l
        self.retrieve = r
        return r > 0

    async def _fetch_page(self):
        """Retrieve one page of raw message payloads and return them as a list.

        The active strategy updates the pagination cursor and ``self.limit``.
        The page is reversed to oldest-first when ``self.reverse`` is set and
        passed through ``self._filter`` when one is configured. A page shorter
        than 100 sets ``self.limit`` to 0 so no further requests are made.
        """
        data = await self._retrieve_messages(self.retrieve)
        if len(data) < 100:
            self.limit = 0 # terminate the infinite loop

        if self.reverse:
            data = reversed(data)
        if self._filter:
            data = filter(self._filter, data)
        return list(data)

    async def flatten(self):
        # this is similar to fill_messages except it collects everything into
        # a list instead of placing the messages in the queue.
        result = []
        channel = await self.messageable._get_channel()
        self.channel = channel
        while self._get_retrieve():
            page = await self._fetch_page()
            if not page:
                # Nothing in this page survived the filter (or the page was empty).
                # ``async for`` iteration stops at exactly this point as well, so
                # stop here rather than paging through the rest of the channel's
                # history for messages that would all be discarded.
                break
            result.extend(self.state.create_message(channel=channel, data=element) for element in page)
        return result

    async def fill_messages(self):
        if not hasattr(self, 'channel'):
            # do the required set up
            channel = await self.messageable._get_channel()
            self.channel = channel

        if self._get_retrieve():
            page = await self._fetch_page()
            channel = self.channel
            for element in page:
                await self.messages.put(self.state.create_message(channel=channel, data=element))

    async def _retrieve_messages(self, retrieve):
        """Retrieve messages and update next parameters.

        Placeholder; ``__init__`` replaces it on the instance with one of the
        strategy methods below.
        """
        pass

    async def _retrieve_messages_before_strategy(self, retrieve):
        """Retrieve messages using before parameter."""
        before = self.before.id if self.before else None
        data = await self.logs_from(self.channel.id, retrieve, before=before)
        if len(data):
            if self.limit is not None:
                self.limit -= retrieve
            # data is newest first, so the last element is the oldest message
            self.before = Object(id=int(data[-1]['id']))
        return data

    async def _retrieve_messages_after_strategy(self, retrieve):
        """Retrieve messages using after parameter."""
        after = self.after.id if self.after else None
        data = await self.logs_from(self.channel.id, retrieve, after=after)
        if len(data):
            if self.limit is not None:
                self.limit -= retrieve
            # data is newest first, so the first element is the newest message
            self.after = Object(id=int(data[0]['id']))
        return data

    async def _retrieve_messages_around_strategy(self, retrieve):
        """Retrieve messages using around parameter.

        ``around`` is consumed by the first call; every later call returns an
        empty list, which ends iteration.
        """
        if self.around:
            data = await self.logs_from(self.channel.id, retrieve, around=self.around.id)
            self.around = None
            return data
        return []
373
class AuditLogIterator(_AsyncIterator):
    """Iterator over a guild's audit log entries.

    Entries are requested in pages of at most 100. When iterating oldest
    first, pages are requested with ``after`` and reversed; otherwise they are
    requested with ``before``. The bound that is not sent to the endpoint is
    applied as a client-side filter on each page.

    Parameters
    -----------
    guild:
        The guild whose audit log is read.
    limit: Optional[:class:`int`]
        Maximum number of entries to retrieve; ``None`` means no limit.
    before: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]]
        Only entries before this are retrieved.
    after: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]]
        Only entries after this are retrieved.
    oldest_first: Optional[:class:`bool`]
        Return entries oldest first. Defaults to ``True`` if ``after`` is
        given, otherwise ``False``.
    user_id, action_type:
        Passed through unchanged to the audit log HTTP request.
    """

    def __init__(self, guild, limit=None, before=None, after=None, oldest_first=None, user_id=None, action_type=None):
        if isinstance(before, datetime.datetime):
            before = Object(id=time_snowflake(before, high=False))
        if isinstance(after, datetime.datetime):
            after = Object(id=time_snowflake(after, high=True))

        if oldest_first is None:
            self.reverse = after is not None
        else:
            self.reverse = oldest_first

        self.guild = guild
        self.loop = guild._state.loop
        self.request = guild._state.http.get_audit_logs
        self.limit = limit
        self.before = before
        self.user_id = user_id
        self.action_type = action_type
        # Honour the caller's ``after`` bound (previously it was always
        # discarded in favour of OLDEST_OBJECT, so oldest-first pagination
        # started from the very beginning and the filter below was dead code).
        self.after = after or OLDEST_OBJECT
        self._users = {}
        self._state = guild._state

        self._filter = None  # entry dict -> bool

        self.entries = asyncio.Queue()

        if self.reverse:
            self._strategy = self._after_strategy
            if self.before:
                self._filter = lambda m: int(m['id']) < self.before.id
        else:
            self._strategy = self._before_strategy
            if self.after and self.after != OLDEST_OBJECT:
                self._filter = lambda m: int(m['id']) > self.after.id

    async def _before_strategy(self, retrieve):
        """Request a page ending before ``self.before``; return ``(users, entries)``."""
        before = self.before.id if self.before else None
        data = await self.request(self.guild.id, limit=retrieve, user_id=self.user_id,
                                  action_type=self.action_type, before=before)

        entries = data.get('audit_log_entries', [])
        if entries:
            if self.limit is not None:
                self.limit -= retrieve
            self.before = Object(id=int(entries[-1]['id']))
        return data.get('users', []), entries

    async def _after_strategy(self, retrieve):
        """Request a page starting after ``self.after``; return ``(users, entries)``."""
        after = self.after.id if self.after else None
        data = await self.request(self.guild.id, limit=retrieve, user_id=self.user_id,
                                  action_type=self.action_type, after=after)
        entries = data.get('audit_log_entries', [])
        if entries:
            if self.limit is not None:
                self.limit -= retrieve
            self.after = Object(id=int(entries[0]['id']))
        return data.get('users', []), entries

    async def next(self):
        if self.entries.empty():
            await self._fill()

        try:
            return self.entries.get_nowait()
        except asyncio.QueueEmpty:
            raise NoMoreItems()

    def _get_retrieve(self):
        """Set ``self.retrieve`` to the next page size (max 100); return whether it is > 0."""
        l = self.limit
        if l is None or l > 100:
            r = 100
        else:
            r = l
        self.retrieve = r
        return r > 0

    async def _fill(self):
        """Fetch the next page and queue its entries as :class:`AuditLogEntry` objects."""
        from .user import User

        if self._get_retrieve():
            users, data = await self._strategy(self.retrieve)
            if len(data) < 100:
                self.limit = 0 # terminate the infinite loop

            if self.reverse:
                data = reversed(data)
            if self._filter:
                data = filter(self._filter, data)

            # cache the users referenced by this page so entries can resolve them
            for user in users:
                u = User(data=user, state=self._state)
                self._users[u.id] = u

            for element in data:
                # TODO: remove this if statement later
                if element['action_type'] is None:
                    continue

                await self.entries.put(AuditLogEntry(data=element, users=self._users, guild=self.guild))
477
478
class GuildIterator(_AsyncIterator):
    """Iterator for receiving the client's guilds.

    The guilds endpoint has the same two behaviours as described
    in :class:`HistoryIterator`:
    If ``before`` is specified, the guilds endpoint returns the ``limit``
    newest guilds before ``before``, sorted with newest first. For filling over
    100 guilds, update the ``before`` parameter to the oldest guild received.
    Guilds will be returned in order by time.
    If ``after`` is specified, it returns the ``limit`` oldest guilds after ``after``,
    sorted with newest first. For filling over 100 guilds, update the ``after``
    parameter to the newest guild received. If guilds are not reversed, they
    will be out of order (99-0, 199-100, so on).

    Note that if both ``before`` and ``after`` are specified, ``before`` is ignored by the
    guilds endpoint; this iterator then pages with ``before`` and applies
    ``after`` as a client-side filter.

    Parameters
    -----------
    bot: :class:`discord.Client`
        The client to retrieve the guilds from.
    limit: :class:`int`
        Maximum number of guilds to retrieve.
    before: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]]
        Object before which all guilds must be.
    after: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]]
        Object after which all guilds must be.
    """
    def __init__(self, bot, limit, before=None, after=None):

        if isinstance(before, datetime.datetime):
            before = Object(id=time_snowflake(before, high=False))
        if isinstance(after, datetime.datetime):
            after = Object(id=time_snowflake(after, high=True))

        self.bot = bot
        self.limit = limit
        self.before = before
        self.after = after

        self._filter = None

        self.state = self.bot._connection
        self.get_guilds = self.bot.http.get_guilds
        self.guilds = asyncio.Queue()

        if self.after and not self.before:
            # only a lower bound: page forwards from it
            self._retrieve_guilds = self._retrieve_guilds_after_strategy
        else:
            self._retrieve_guilds = self._retrieve_guilds_before_strategy
            if self.before and self.after:
                # both bounds: page backwards from ``before`` and enforce
                # ``after`` client-side
                self._filter = lambda m: int(m['id']) > self.after.id

    async def next(self):
        queue = self.guilds
        if queue.empty():
            await self.fill_guilds()

        try:
            guild = queue.get_nowait()
        except asyncio.QueueEmpty:
            raise NoMoreItems()
        return guild

    def _get_retrieve(self):
        """Set ``self.retrieve`` to the next page size (max 100); return whether it is > 0."""
        remaining = self.limit
        self.retrieve = 100 if remaining is None or remaining > 100 else remaining
        return self.retrieve > 0

    def create_guild(self, data):
        """Build a :class:`Guild` from a raw payload (imported lazily)."""
        from .guild import Guild
        return Guild(state=self.state, data=data)

    async def flatten(self):
        guilds = []
        while self._get_retrieve():
            page = await self._retrieve_guilds(self.retrieve)
            if len(page) < 100:
                self.limit = 0

            if self._filter:
                page = filter(self._filter, page)

            guilds.extend(self.create_guild(raw) for raw in page)
        return guilds

    async def fill_guilds(self):
        if not self._get_retrieve():
            return

        page = await self._retrieve_guilds(self.retrieve)
        # NOTE(review): unlike flatten(), this also stops after the first page
        # when limit is None, so ``async for`` yields at most one page in that
        # case -- confirm whether that is intended.
        if self.limit is None or len(page) < 100:
            self.limit = 0

        if self._filter:
            page = filter(self._filter, page)

        for raw in page:
            await self.guilds.put(self.create_guild(raw))

    async def _retrieve_guilds(self, retrieve):
        """Retrieve guilds and update next parameters.

        Placeholder; ``__init__`` replaces it on the instance with one of the
        strategy methods below.
        """
        pass

    async def _retrieve_guilds_before_strategy(self, retrieve):
        """Retrieve guilds using before parameter."""
        before_id = None if not self.before else self.before.id
        page = await self.get_guilds(retrieve, before=before_id)
        if page:
            if self.limit is not None:
                self.limit -= retrieve
            self.before = Object(id=int(page[-1]['id']))
        return page

    async def _retrieve_guilds_after_strategy(self, retrieve):
        """Retrieve guilds using after parameter."""
        after_id = None if not self.after else self.after.id
        page = await self.get_guilds(retrieve, after=after_id)
        if page:
            if self.limit is not None:
                self.limit -= retrieve
            self.after = Object(id=int(page[0]['id']))
        return page
604
class MemberIterator(_AsyncIterator):
    """Iterator over a guild's members, requested in pages of up to 1000.

    Parameters
    -----------
    guild:
        The guild whose members are listed.
    limit: Optional[:class:`int`]
        Maximum number of members to retrieve; ``None`` means no limit.
    after: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]]
        Starting cursor, passed as ``after`` to ``get_members``.
    """
    def __init__(self, guild, limit=1000, after=None):

        if isinstance(after, datetime.datetime):
            after = Object(id=time_snowflake(after, high=True))

        self.guild = guild
        self.limit = limit
        self.after = after or OLDEST_OBJECT

        self.state = self.guild._state
        self.get_members = self.state.http.get_members
        self.members = asyncio.Queue()

    async def next(self):
        if self.members.empty():
            await self.fill_members()

        try:
            return self.members.get_nowait()
        except asyncio.QueueEmpty:
            raise NoMoreItems()

    def _get_retrieve(self):
        """Set ``self.retrieve`` to the next page size (max 1000); return whether it is > 0."""
        l = self.limit
        if l is None or l > 1000:
            r = 1000
        else:
            r = l
        self.retrieve = r
        return r > 0

    async def fill_members(self):
        """Fetch the next page of members (if any remain) into ``self.members``."""
        if not self._get_retrieve():
            return

        after = self.after.id if self.after else None
        data = await self.get_members(self.guild.id, self.retrieve, after)
        if not data:
            # no data, terminate
            return

        if self.limit is not None:
            # Count this page against the limit; previously the limit was never
            # decremented, so any limit above one page fetched every member.
            self.limit -= len(data)
        if len(data) < 1000:
            self.limit = 0 # a short page means nothing is left; terminate loop

        self.after = Object(id=int(data[-1]['user']['id']))

        for element in reversed(data):
            await self.members.put(self.create_member(element))

    def create_member(self, data):
        """Build a :class:`Member` of ``self.guild`` from a raw payload (imported lazily)."""
        from .member import Member
        return Member(data=data, guild=self.guild, state=self.state)
656