# -*- coding: utf-8 -*-

"""
The MIT License (MIT)

Copyright (c) 2015-present Rapptz

Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""

import asyncio
import datetime

from .errors import NoMoreItems
from .utils import time_snowflake, maybe_coroutine
from .object import Object
from .audit_logs import AuditLogEntry

# Object with id 0 (the smallest possible snowflake); used as the default
# ``after`` bound so that "after OLDEST_OBJECT" covers every item.
OLDEST_OBJECT = Object(id=0)


class _AsyncIterator:
    """Base class for the paginated async iterators in this module.

    Subclasses provide an ``async def next(self)`` coroutine that returns the
    next element or raises :exc:`NoMoreItems` once exhausted. This class
    layers ``async for`` support and a few combinators on top of it.
    """
    __slots__ = ()

    def get(self, **attrs):
        """Return an awaitable resolving to the first element whose attributes
        equal every given keyword argument, or ``None`` if none match.

        A double underscore in a keyword name walks nested attributes, e.g.
        ``author__name='x'`` compares ``elem.author.name`` against ``'x'``.
        """
        def predicate(elem):
            for attr, val in attrs.items():
                nested = attr.split('__')
                obj = elem
                for attribute in nested:
                    obj = getattr(obj, attribute)

                if obj != val:
                    return False
            return True

        return self.find(predicate)

    async def find(self, predicate):
        """Consume elements until ``predicate(elem)`` is truthy.

        ``predicate`` is invoked through :func:`utils.maybe_coroutine`, so it
        may presumably be either a regular or a coroutine function.

        Returns the matching element, or ``None`` if the iterator is exhausted
        first.
        """
        while True:
            try:
                elem = await self.next()
            except NoMoreItems:
                return None

            ret = await maybe_coroutine(predicate, elem)
            if ret:
                return elem

    def chunk(self, max_size):
        """Return an iterator yielding lists of up to ``max_size`` elements.

        Raises
        -------
        ValueError
            ``max_size`` is not positive.
        """
        if max_size <= 0:
            raise ValueError('async iterator chunk sizes must be greater than 0.')
        return _ChunkedAsyncIterator(self, max_size)

    def map(self, func):
        """Return an iterator yielding ``func(elem)`` for each element."""
        return _MappedAsyncIterator(self, func)

    def filter(self, predicate):
        """Return an iterator yielding only elements where ``predicate`` is truthy."""
        return _FilteredAsyncIterator(self, predicate)

    async def flatten(self):
        """Consume every remaining element and return them as a list."""
        ret = []
        while True:
            try:
                item = await self.next()
            except NoMoreItems:
                return ret
            else:
                ret.append(item)

    def __aiter__(self):
        return self

    async def __anext__(self):
        # Translate the library's NoMoreItems into the async-iteration protocol.
        try:
            msg = await self.next()
        except NoMoreItems:
            raise StopAsyncIteration()
        else:
            return msg


def _identity(x):
    """Return ``x`` unchanged (default predicate for filtering by truthiness)."""
    return x


class _ChunkedAsyncIterator(_AsyncIterator):
    """Groups elements of another iterator into lists of at most ``max_size``."""

    def __init__(self, iterator, max_size):
        self.iterator = iterator
        self.max_size = max_size

    async def next(self):
        """Return the next chunk; the final chunk may be shorter than ``max_size``.

        Raises :exc:`NoMoreItems` only when the underlying iterator is already
        exhausted before any element of a new chunk is read.
        """
        ret = []
        n = 0
        while n < self.max_size:
            try:
                item = await self.iterator.next()
            except NoMoreItems:
                if ret:
                    return ret
                raise
            else:
                ret.append(item)
                n += 1
        return ret


class _MappedAsyncIterator(_AsyncIterator):
    """Applies ``func`` to every element of another iterator."""

    def __init__(self, iterator, func):
        self.iterator = iterator
        self.func = func

    async def next(self):
        # this raises NoMoreItems and will propagate appropriately
        item = await self.iterator.next()
        return await maybe_coroutine(self.func, item)


class _FilteredAsyncIterator(_AsyncIterator):
    """Yields only the elements of another iterator that satisfy ``predicate``.

    A ``None`` predicate keeps elements that are themselves truthy.
    """

    def __init__(self, iterator, predicate):
        self.iterator = iterator

        if predicate is None:
            predicate = _identity

        self.predicate = predicate

    async def next(self):
        getter = self.iterator.next
        pred = self.predicate
        while True:
            # propagate NoMoreItems similar to _MappedAsyncIterator
            item = await getter()
            ret = await maybe_coroutine(pred, item)
            if ret:
                return item


class ReactionIterator(_AsyncIterator):
    """Iterator over the users who reacted to a message with an emoji.

    Users are requested through ``state.http.get_reaction_users`` in pages of
    at most 100, paginating forward with the ``after`` parameter. When the
    message has a real guild (not ``None`` and not a bare :class:`Object`),
    the result of ``guild.get_member`` is yielded when available; otherwise a
    :class:`User` is built from the raw payload.

    Parameters
    -----------
    message: :class:`Message`
        The message whose reaction users are fetched.
    emoji
        The emoji value passed through unchanged to the HTTP call.
    limit: :class:`int`
        Maximum number of users to retrieve.
    after: Optional[:class:`abc.Snowflake`]
        Only users after this object are requested.
    """

    def __init__(self, message, emoji, limit=100, after=None):
        self.message = message
        self.limit = limit
        self.after = after
        state = message._state
        self.getter = state.http.get_reaction_users
        self.state = state
        self.emoji = emoji
        self.guild = message.guild
        self.channel_id = message.channel.id
        self.users = asyncio.Queue()

    async def next(self):
        """Return the next user, fetching a new page when the queue is empty.

        Raises :exc:`NoMoreItems` when no further users are available.
        """
        if self.users.empty():
            await self.fill_users()

        try:
            return self.users.get_nowait()
        except asyncio.QueueEmpty:
            raise NoMoreItems()

    async def fill_users(self):
        """Fetch one page of reaction users into the queue."""
        # this is a hack because >circular imports<
        from .user import User

        if self.limit > 0:
            retrieve = self.limit if self.limit <= 100 else 100

            after = self.after.id if self.after else None
            data = await self.getter(self.channel_id, self.message.id, self.emoji, retrieve, after=after)

            if data:
                self.limit -= retrieve
                # continue the next page after the last user of this page
                self.after = Object(id=int(data[-1]['id']))

            if self.guild is None or isinstance(self.guild, Object):
                for element in reversed(data):
                    await self.users.put(User(state=self.state, data=element))
            else:
                for element in reversed(data):
                    member_id = int(element['id'])
                    member = self.guild.get_member(member_id)
                    if member is not None:
                        await self.users.put(member)
                    else:
                        await self.users.put(User(state=self.state, data=element))


class HistoryIterator(_AsyncIterator):
    """Iterator for receiving a channel's message history.

    The messages endpoint has two behaviours we care about here:
    If ``before`` is specified, the messages endpoint returns the ``limit``
    newest messages before ``before``, sorted with newest first. For filling over
    100 messages, update the ``before`` parameter to the oldest message received.
    Messages will be returned in order by time.
    If ``after`` is specified, it returns the ``limit`` oldest messages after
    ``after``, sorted with newest first. For filling over 100 messages, update the
    ``after`` parameter to the newest message received. If messages are not
    reversed, they will be out of order (99-0, 199-100, so on).

    Note that if both ``before`` and ``after`` are specified, ``before`` is ignored by the
    messages endpoint.

    Parameters
    -----------
    messageable: :class:`abc.Messageable`
        Messageable class to retrieve message history from.
    limit: :class:`int`
        Maximum number of messages to retrieve. ``None`` means no limit
        (not allowed together with ``around``).
    before: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]]
        Message before which all messages must be.
    after: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]]
        Message after which all messages must be.
    around: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]]
        Message around which all messages must be. Limit max 101. Note that if
        limit is an even number, this will return at most limit+1 messages.
    oldest_first: Optional[:class:`bool`]
        If set to ``True``, return messages in oldest->newest order. Defaults to
        ``True`` if ``after`` is specified, otherwise ``False``.

    Raises
    -------
    ValueError
        ``around`` was given with ``limit=None`` or ``limit > 101``.
    """

    def __init__(self, messageable, limit,
                 before=None, after=None, around=None, oldest_first=None):

        if isinstance(before, datetime.datetime):
            before = Object(id=time_snowflake(before, high=False))
        if isinstance(after, datetime.datetime):
            after = Object(id=time_snowflake(after, high=True))
        if isinstance(around, datetime.datetime):
            around = Object(id=time_snowflake(around))

        if oldest_first is None:
            self.reverse = after is not None
        else:
            self.reverse = oldest_first

        self.messageable = messageable
        self.limit = limit
        self.before = before
        self.after = after or OLDEST_OBJECT
        self.around = around

        self._filter = None  # message dict -> bool

        self.state = self.messageable._state
        self.logs_from = self.state.http.logs_from
        self.messages = asyncio.Queue()

        if self.around:
            if self.limit is None:
                raise ValueError('history does not support around with limit=None')
            if self.limit > 101:
                raise ValueError("history max limit 101 when specifying around parameter")
            elif self.limit == 101:
                self.limit = 100  # Thanks discord

            self._retrieve_messages = self._retrieve_messages_around_strategy
            # self.after defaults to OLDEST_OBJECT, so a lower bound always exists.
            if self.before and self.after:
                self._filter = lambda m: self.after.id < int(m['id']) < self.before.id
            elif self.before:
                self._filter = lambda m: int(m['id']) < self.before.id
            elif self.after:
                self._filter = lambda m: self.after.id < int(m['id'])
        else:
            if self.reverse:
                self._retrieve_messages = self._retrieve_messages_after_strategy
                if self.before:
                    self._filter = lambda m: int(m['id']) < self.before.id
            else:
                self._retrieve_messages = self._retrieve_messages_before_strategy
                if self.after and self.after != OLDEST_OBJECT:
                    self._filter = lambda m: int(m['id']) > self.after.id

    async def next(self):
        """Return the next message, fetching a new page when the queue is empty.

        Raises :exc:`NoMoreItems` when no further messages are available.
        """
        if self.messages.empty():
            await self.fill_messages()

        try:
            return self.messages.get_nowait()
        except asyncio.QueueEmpty:
            raise NoMoreItems()

    def _get_retrieve(self):
        """Store the next page size (``min(limit, 100)``; 100 when unlimited) in
        ``self.retrieve`` and return whether another request should be made."""
        l = self.limit
        if l is None or l > 100:
            r = 100
        else:
            r = l
        self.retrieve = r
        return r > 0

    async def flatten(self):
        """Fetch every remaining page and return the messages as a list.

        NOTE: this bypasses ``self.messages``; messages already queued by a
        previous :meth:`next` call are not included in the result.
        """
        # this is similar to fill_messages except it uses a list instead
        # of a queue to place the messages in.
        result = []
        channel = await self.messageable._get_channel()
        self.channel = channel
        while self._get_retrieve():
            data = await self._retrieve_messages(self.retrieve)
            if len(data) < 100:
                self.limit = 0  # terminate the infinite loop

            if self.reverse:
                data = reversed(data)
            if self._filter:
                data = filter(self._filter, data)

            for element in data:
                result.append(self.state.create_message(channel=channel, data=element))
        return result

    async def fill_messages(self):
        """Fetch one page of messages into the queue."""
        if not hasattr(self, 'channel'):
            # do the required set up
            channel = await self.messageable._get_channel()
            self.channel = channel

        if self._get_retrieve():
            data = await self._retrieve_messages(self.retrieve)
            if len(data) < 100:
                self.limit = 0  # terminate the infinite loop

            if self.reverse:
                data = reversed(data)
            if self._filter:
                data = filter(self._filter, data)

            channel = self.channel
            for element in data:
                await self.messages.put(self.state.create_message(channel=channel, data=element))

    async def _retrieve_messages(self, retrieve):
        """Retrieve messages and update next parameters.

        Placeholder only: ``__init__`` replaces it on the instance with one of
        the ``_retrieve_messages_*_strategy`` methods.
        """
        pass

    async def _retrieve_messages_before_strategy(self, retrieve):
        """Retrieve messages using before parameter."""
        before = self.before.id if self.before else None
        data = await self.logs_from(self.channel.id, retrieve, before=before)
        if len(data):
            if self.limit is not None:
                self.limit -= retrieve
            self.before = Object(id=int(data[-1]['id']))
        return data

    async def _retrieve_messages_after_strategy(self, retrieve):
        """Retrieve messages using after parameter."""
        after = self.after.id if self.after else None
        data = await self.logs_from(self.channel.id, retrieve, after=after)
        if len(data):
            if self.limit is not None:
                self.limit -= retrieve
            self.after = Object(id=int(data[0]['id']))
        return data

    async def _retrieve_messages_around_strategy(self, retrieve):
        """Retrieve messages using around parameter.

        Only one request is made: ``self.around`` is cleared afterwards, so
        later calls return an empty list.
        """
        if self.around:
            around = self.around.id
            data = await self.logs_from(self.channel.id, retrieve, around=around)
            self.around = None
            return data
        return []


class AuditLogIterator(_AsyncIterator):
    """Iterator over a guild's audit log entries.

    Entries are requested through ``http.get_audit_logs`` in pages of at most
    100 and yielded as :class:`AuditLogEntry` objects. Users included in each
    response are cached in ``self._users`` and shared with every entry.

    Parameters
    -----------
    guild: :class:`Guild`
        The guild whose audit log is read.
    limit: Optional[:class:`int`]
        Maximum number of entries to retrieve. ``None`` means no limit.
    before: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]]
        Entry before which all entries must be.
    after: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]]
        Entry after which all entries must be.
    oldest_first: Optional[:class:`bool`]
        If ``True``, paginate forwards from ``after``. Defaults to ``True`` if
        ``after`` is specified, otherwise ``False``.
    user_id
        Forwarded unchanged to the HTTP request.
    action_type
        Forwarded unchanged to the HTTP request.
    """

    def __init__(self, guild, limit=None, before=None, after=None, oldest_first=None, user_id=None, action_type=None):
        if isinstance(before, datetime.datetime):
            before = Object(id=time_snowflake(before, high=False))
        if isinstance(after, datetime.datetime):
            after = Object(id=time_snowflake(after, high=True))

        if oldest_first is None:
            self.reverse = after is not None
        else:
            self.reverse = oldest_first

        self.guild = guild
        self.loop = guild._state.loop
        self.request = guild._state.http.get_audit_logs
        self.limit = limit
        self.before = before
        self.user_id = user_id
        self.action_type = action_type
        # Honour the caller's ``after`` bound (previously it was discarded and
        # always replaced by OLDEST_OBJECT), mirroring HistoryIterator.
        self.after = after or OLDEST_OBJECT
        self._users = {}
        self._state = guild._state

        self._filter = None  # entry dict -> bool

        self.entries = asyncio.Queue()

        if self.reverse:
            self._strategy = self._after_strategy
            if self.before:
                self._filter = lambda m: int(m['id']) < self.before.id
        else:
            self._strategy = self._before_strategy
            if self.after and self.after != OLDEST_OBJECT:
                self._filter = lambda m: int(m['id']) > self.after.id

    async def _before_strategy(self, retrieve):
        """Request a page before ``self.before``; return ``(users, entries)``."""
        before = self.before.id if self.before else None
        data = await self.request(self.guild.id, limit=retrieve, user_id=self.user_id,
                                  action_type=self.action_type, before=before)

        entries = data.get('audit_log_entries', [])
        if len(data) and entries:
            if self.limit is not None:
                self.limit -= retrieve
            self.before = Object(id=int(entries[-1]['id']))
        return data.get('users', []), entries

    async def _after_strategy(self, retrieve):
        """Request a page after ``self.after``; return ``(users, entries)``."""
        after = self.after.id if self.after else None
        data = await self.request(self.guild.id, limit=retrieve, user_id=self.user_id,
                                  action_type=self.action_type, after=after)
        entries = data.get('audit_log_entries', [])
        if len(data) and entries:
            if self.limit is not None:
                self.limit -= retrieve
            self.after = Object(id=int(entries[0]['id']))
        return data.get('users', []), entries

    async def next(self):
        """Return the next entry, fetching a new page when the queue is empty.

        Raises :exc:`NoMoreItems` when no further entries are available.
        """
        if self.entries.empty():
            await self._fill()

        try:
            return self.entries.get_nowait()
        except asyncio.QueueEmpty:
            raise NoMoreItems()

    def _get_retrieve(self):
        """Store the next page size (``min(limit, 100)``; 100 when unlimited) in
        ``self.retrieve`` and return whether another request should be made."""
        l = self.limit
        if l is None or l > 100:
            r = 100
        else:
            r = l
        self.retrieve = r
        return r > 0

    async def _fill(self):
        """Fetch one page of entries into the queue and cache its users."""
        from .user import User

        if self._get_retrieve():
            users, data = await self._strategy(self.retrieve)
            if len(data) < 100:
                self.limit = 0  # terminate the infinite loop

            if self.reverse:
                data = reversed(data)
            if self._filter:
                data = filter(self._filter, data)

            for user in users:
                u = User(data=user, state=self._state)
                self._users[u.id] = u

            for element in data:
                # TODO: remove this if statement later
                if element['action_type'] is None:
                    continue

                await self.entries.put(AuditLogEntry(data=element, users=self._users, guild=self.guild))


class GuildIterator(_AsyncIterator):
    """Iterator for receiving the client's guilds.

    The guilds endpoint has the same two behaviours as described
    in :class:`HistoryIterator`:
    If ``before`` is specified, the guilds endpoint returns the ``limit``
    newest guilds before ``before``, sorted with newest first. For filling over
    100 guilds, update the ``before`` parameter to the oldest guild received.
    Guilds will be returned in order by time.
    If ``after`` is specified, it returns the ``limit`` oldest guilds after ``after``,
    sorted with newest first. For filling over 100 guilds, update the ``after``
    parameter to the newest guild received. If guilds are not reversed, they
    will be out of order (99-0, 199-100, so on).

    Note that if both ``before`` and ``after`` are specified, ``before`` is ignored by the
    guilds endpoint.

    Parameters
    -----------
    bot: :class:`discord.Client`
        The client to retrieve the guilds from.
    limit: :class:`int`
        Maximum number of guilds to retrieve.
    before: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]]
        Object before which all guilds must be.
    after: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]]
        Object after which all guilds must be.
    """
    def __init__(self, bot, limit, before=None, after=None):

        if isinstance(before, datetime.datetime):
            before = Object(id=time_snowflake(before, high=False))
        if isinstance(after, datetime.datetime):
            after = Object(id=time_snowflake(after, high=True))

        self.bot = bot
        self.limit = limit
        self.before = before
        self.after = after

        self._filter = None

        self.state = self.bot._connection
        self.get_guilds = self.bot.http.get_guilds
        self.guilds = asyncio.Queue()

        if self.before and self.after:
            # The endpoint ignores ``before`` when both are given, so page
            # backwards from ``before`` and enforce ``after`` client-side.
            self._retrieve_guilds = self._retrieve_guilds_before_strategy
            self._filter = lambda m: int(m['id']) > self.after.id
        elif self.after:
            self._retrieve_guilds = self._retrieve_guilds_after_strategy
        else:
            self._retrieve_guilds = self._retrieve_guilds_before_strategy

    async def next(self):
        """Return the next guild, fetching a new page when the queue is empty.

        Raises :exc:`NoMoreItems` when no further guilds are available.
        """
        if self.guilds.empty():
            await self.fill_guilds()

        try:
            return self.guilds.get_nowait()
        except asyncio.QueueEmpty:
            raise NoMoreItems()

    def _get_retrieve(self):
        """Store the next page size (``min(limit, 100)``; 100 when unlimited) in
        ``self.retrieve`` and return whether another request should be made."""
        l = self.limit
        if l is None or l > 100:
            r = 100
        else:
            r = l
        self.retrieve = r
        return r > 0

    def create_guild(self, data):
        """Build a :class:`Guild` from a raw payload."""
        from .guild import Guild
        return Guild(state=self.state, data=data)

    async def flatten(self):
        """Fetch every remaining page and return the guilds as a list."""
        result = []
        while self._get_retrieve():
            data = await self._retrieve_guilds(self.retrieve)
            if len(data) < 100:
                self.limit = 0

            if self._filter:
                data = filter(self._filter, data)

            for element in data:
                result.append(self.create_guild(element))
        return result

    async def fill_guilds(self):
        """Fetch one page of guilds into the queue."""
        if self._get_retrieve():
            data = await self._retrieve_guilds(self.retrieve)
            # NOTE(review): unlike flatten(), this stops after a single page
            # when limit is None. Left unchanged because the pagination
            # direction of the guild strategies cannot be confirmed here and
            # removing the guard might cause repeated pages — confirm intent.
            if self.limit is None or len(data) < 100:
                self.limit = 0

            if self._filter:
                data = filter(self._filter, data)

            for element in data:
                await self.guilds.put(self.create_guild(element))

    async def _retrieve_guilds(self, retrieve):
        """Retrieve guilds and update next parameters.

        Placeholder only: ``__init__`` replaces it on the instance with one of
        the ``_retrieve_guilds_*_strategy`` methods.
        """
        pass

    async def _retrieve_guilds_before_strategy(self, retrieve):
        """Retrieve guilds using before parameter."""
        before = self.before.id if self.before else None
        data = await self.get_guilds(retrieve, before=before)
        if len(data):
            if self.limit is not None:
                self.limit -= retrieve
            self.before = Object(id=int(data[-1]['id']))
        return data

    async def _retrieve_guilds_after_strategy(self, retrieve):
        """Retrieve guilds using after parameter."""
        after = self.after.id if self.after else None
        data = await self.get_guilds(retrieve, after=after)
        if len(data):
            if self.limit is not None:
                self.limit -= retrieve
            self.after = Object(id=int(data[0]['id']))
        return data


class MemberIterator(_AsyncIterator):
    """Iterator over a guild's members.

    Members are requested through ``http.get_members`` in pages of at most
    1000, paginating forward with ``after`` set to the last member's user id.

    Parameters
    -----------
    guild: :class:`Guild`
        The guild whose members are fetched.
    limit: Optional[:class:`int`]
        Maximum number of members to retrieve. ``None`` means no limit.
    after: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]]
        Only members after this object are requested.
    """

    def __init__(self, guild, limit=1000, after=None):

        if isinstance(after, datetime.datetime):
            after = Object(id=time_snowflake(after, high=True))

        self.guild = guild
        self.limit = limit
        self.after = after or OLDEST_OBJECT

        self.state = self.guild._state
        self.get_members = self.state.http.get_members
        self.members = asyncio.Queue()

    async def next(self):
        """Return the next member, fetching a new page when the queue is empty.

        Raises :exc:`NoMoreItems` when no further members are available.
        """
        if self.members.empty():
            await self.fill_members()

        try:
            return self.members.get_nowait()
        except asyncio.QueueEmpty:
            raise NoMoreItems()

    def _get_retrieve(self):
        """Store the next page size (``min(limit, 1000)``; 1000 when unlimited)
        in ``self.retrieve`` and return whether another request should be made."""
        l = self.limit
        if l is None or l > 1000:
            r = 1000
        else:
            r = l
        self.retrieve = r
        return r > 0

    async def fill_members(self):
        """Fetch one page of members into the queue."""
        if self._get_retrieve():
            after = self.after.id if self.after else None
            data = await self.get_members(self.guild.id, self.retrieve, after)
            if not data:
                # no data, terminate
                return

            # Count this page against the limit (as the other iterators do);
            # otherwise a full page leaves the limit untouched and the next
            # call fetches past the requested maximum.
            if self.limit is not None:
                self.limit -= self.retrieve

            if len(data) < 1000:
                self.limit = 0  # terminate loop

            self.after = Object(id=int(data[-1]['user']['id']))

            for element in reversed(data):
                await self.members.put(self.create_member(element))

    def create_member(self, data):
        """Build a :class:`Member` of this guild from a raw payload."""
        from .member import Member
        return Member(data=data, guild=self.guild, state=self.state)