import asyncio
import collections
import types
import warnings
import sys

from .connection import create_connection, _PUBSUB_COMMANDS
from .log import logger
from .util import parse_url, CloseEvent
from .errors import PoolClosedError
from .abc import AbcPool
from .locks import Lock


async def create_pool(address, *, db=None, password=None, ssl=None,
                      encoding=None, minsize=1, maxsize=10,
                      parser=None, loop=None, create_connection_timeout=None,
                      pool_cls=None, connection_cls=None):
    """Create a connection pool and pre-fill it.

    :param address: Redis server address, forwarded to the pool and from
        there to ``create_connection``.  When it is a ``str`` it is parsed
        with ``parse_url``; ``db``, ``password``, ``encoding`` and
        ``timeout`` found in the URL take precedence over the matching
        keyword arguments.  An ``ssl`` URL option is used only when the
        *ssl* argument is falsy; a falsy URL ``ssl`` combined with a truthy
        *ssl* argument is rejected by an assertion.
    :param db: database index forwarded to the pool and used as ``db``
        for new connections.
    :param password: password used for new connections.
    :param ssl: SSL option forwarded to ``create_connection``.
    :param encoding: encoding forwarded to ``create_connection``.
    :param minsize: minimum pool size; with the default pool class this
        many connections are opened before the pool is returned.
    :param maxsize: maximum pool size.
    :param parser: parser class forwarded to ``create_connection``.
    :param loop: only forwarded to the pool constructor;
        ``ConnectionsPool`` emits a ``DeprecationWarning`` for it on
        Python 3.8+.
    :param create_connection_timeout: ``timeout`` for ``create_connection``.
    :param pool_cls: pool class to instantiate instead of
        ``ConnectionsPool``; must be a subclass of ``AbcPool``.
    :param connection_cls: connection class forwarded to
        ``create_connection``.
    :return: the pool instance (``ConnectionsPool`` or *pool_cls*).
    :raises: any exception raised while filling the pool; the pool is
        closed and waited on before the exception propagates.
    """
    if pool_cls:
        assert issubclass(pool_cls, AbcPool),\
            "pool_class does not meet the AbcPool contract"
        cls = pool_cls
    else:
        cls = ConnectionsPool
    if isinstance(address, str):
        address, options = parse_url(address)
        db = options.setdefault('db', db)
        password = options.setdefault('password', password)
        encoding = options.setdefault('encoding', encoding)
        create_connection_timeout = options.setdefault(
            'timeout', create_connection_timeout)
        if 'ssl' in options:
            assert options['ssl'] or (not options['ssl'] and not ssl), (
                "Conflicting ssl options are set", options['ssl'], ssl)
            ssl = ssl or options['ssl']
        # TODO: minsize/maxsize

    pool = cls(address, db, password, encoding,
               minsize=minsize, maxsize=maxsize,
               ssl=ssl, parser=parser,
               create_connection_timeout=create_connection_timeout,
               connection_cls=connection_cls,
               loop=loop)
    try:
        await pool._fill_free(override_min=False)
    except Exception:
        # Do not hand back a half-initialized pool: close whatever was
        # opened before re-raising.
        pool.close()
        await pool.wait_closed()
        raise
    return pool


class ConnectionsPool(AbcPool):
    """Redis connections pool.

    Free connections are kept in a deque bounded by ``maxsize``.  Non
    pub/sub commands sent through :meth:`execute` run directly on a free
    connection without removing it from the deque.  Connections handed out
    by :meth:`acquire` (and a pub/sub connection taken from the free deque)
    are tracked in ``_used``.  ``_acquiring`` counts connections currently
    being created, and :attr:`size` includes all three.
    """

    def __init__(self, address, db=None, password=None, encoding=None,
                 *, minsize, maxsize, ssl=None, parser=None,
                 create_connection_timeout=None,
                 connection_cls=None,
                 loop=None):
        assert isinstance(minsize, int) and minsize >= 0, (
            "minsize must be int >= 0", minsize, type(minsize))
        assert maxsize is not None, "Arbitrary pool size is disallowed."
        assert isinstance(maxsize, int) and maxsize > 0, (
            "maxsize must be int > 0", maxsize, type(maxsize))
        assert minsize <= maxsize, (
            "Invalid pool min/max sizes", minsize, maxsize)
        if loop is not None and sys.version_info >= (3, 8):
            warnings.warn("The loop argument is deprecated",
                          DeprecationWarning)
        self._address = address
        self._db = db
        self._password = password
        self._ssl = ssl
        self._encoding = encoding
        self._parser_class = parser
        self._minsize = minsize
        self._create_connection_timeout = create_connection_timeout
        # maxlen doubles as the pool's maxsize (see the ``maxsize`` property)
        self._pool = collections.deque(maxlen=maxsize)
        self._used = set()
        # number of connections currently being created
        self._acquiring = 0
        self._cond = asyncio.Condition(lock=Lock())
        self._close_state = CloseEvent(self._do_close)
        self._pubsub_conn = None
        self._connection_cls = connection_cls

    def __repr__(self):
        return '<{} [db:{}, size:[{}:{}], free:{}]>'.format(
            self.__class__.__name__, self.db,
            self.minsize, self.maxsize, self.freesize)

    @property
    def minsize(self):
        """Minimum pool size."""
        return self._minsize

    @property
    def maxsize(self):
        """Maximum pool size."""
        return self._pool.maxlen

    @property
    def size(self):
        """Current pool size."""
        return self.freesize + len(self._used) + self._acquiring

    @property
    def freesize(self):
        """Current number of free connections."""
        return len(self._pool)

    @property
    def address(self):
        """Address new connections are created with."""
        return self._address

    async def clear(self):
        """Clear pool connections.

        Close and remove all free connections.
        """
        async with self._cond:
            await self._do_clear()

    async def _do_clear(self):
        """Close every free connection and wait until all are closed.

        The caller is expected to hold ``self._cond``.
        """
        waiters = []
        while self._pool:
            conn = self._pool.popleft()
            conn.close()
            waiters.append(conn.wait_closed())
        await asyncio.gather(*waiters)

    async def _do_close(self):
        """Close free, used and pub/sub connections (run by ``CloseEvent``)."""
        async with self._cond:
            assert not self._acquiring, self._acquiring
            waiters = []
            while self._pool:
                conn = self._pool.popleft()
                conn.close()
                waiters.append(conn.wait_closed())
            for conn in self._used:
                conn.close()
                waiters.append(conn.wait_closed())
            # A pub/sub connection created by _wait_execute_pubsub() is in
            # neither the free deque nor ``_used``, so close it explicitly;
            # otherwise it would outlive the pool.
            pubsub_conn = self._pubsub_conn
            if (pubsub_conn is not None and not pubsub_conn.closed
                    and pubsub_conn not in self._used):
                pubsub_conn.close()
                waiters.append(pubsub_conn.wait_closed())
            await asyncio.gather(*waiters)
            logger.debug("Closed %d connection(s)", len(waiters))

    def close(self):
        """Close all free and in-progress connections and mark pool as closed.
        """
        if not self._close_state.is_set():
            self._close_state.set()

    @property
    def closed(self):
        """True if pool is closed."""
        return self._close_state.is_set()

    async def wait_closed(self):
        """Wait until pool gets closed."""
        await self._close_state.wait()

    @property
    def db(self):
        """Currently selected db index."""
        return self._db or 0

    @property
    def encoding(self):
        """Current set codec or None."""
        return self._encoding

    def execute(self, command, *args, **kw):
        """Executes redis command in a free connection and returns
        future waiting for result.

        Picks connection from free pool and send command through
        that connection.
        If no connection is found, returns coroutine waiting for
        free connection to execute command.
        """
        conn, address = self.get_connection(command, args)
        if conn is not None:
            fut = conn.execute(command, *args, **kw)
            return self._check_result(fut, command, args, kw)
        else:
            coro = self._wait_execute(address, command, args, kw)
            return self._check_result(coro, command, args, kw)

    def execute_pubsub(self, command, *channels):
        """Executes Redis (p)subscribe/(p)unsubscribe commands.

        ConnectionsPool picks separate connection for pub/sub
        and uses it until explicitly closed or disconnected
        (unsubscribing from all channels/patterns will leave connection
        locked for pub/sub use).

        There is no auto-reconnect for this PUB/SUB connection.

        Returns asyncio.gather coroutine waiting for all channels/patterns
        to receive answers.
        """
        conn, address = self.get_connection(command)
        if conn is not None:
            return conn.execute_pubsub(command, *channels)
        else:
            return self._wait_execute_pubsub(address, command, channels, {})

    def get_connection(self, command, args=()):
        """Get a free connection suitable for *command*.

        Returns a ``(connection, address)`` tuple.  For non pub/sub
        commands the connection stays in the free deque and may be shared
        with other callers.  For pub/sub commands the connection is moved
        from the free deque into ``_used`` and remembered as the pool's
        pub/sub connection.  When no usable free connection exists,
        ``(None, pool_address)`` is returned.
        """
        # TODO: find a better way to determine if connection is free
        # and not heavily used.
        command = command.upper().strip()
        is_pubsub = command in _PUBSUB_COMMANDS
        if is_pubsub and self._pubsub_conn:
            if not self._pubsub_conn.closed:
                return self._pubsub_conn, self._pubsub_conn.address
            # A pub/sub connection taken from the free deque was added to
            # ``_used`` and is not released anywhere in this pool; drop the
            # dead one so it stops counting towards ``size`` and shrinking
            # the pool's capacity.
            self._used.discard(self._pubsub_conn)
            self._pubsub_conn = None
        for _ in range(self.freesize):
            conn = self._pool[0]
            self._pool.rotate(1)
            if conn.closed:  # or conn._waiters: (eg: busy connection)
                continue
            if conn.in_pubsub:
                continue
            if is_pubsub:
                self._pubsub_conn = conn
                self._pool.remove(conn)
                self._used.add(conn)
            return conn, conn.address
        return None, self._address  # figure out

    def _check_result(self, fut, *data):
        """Hook to check result or catch exception (like MovedError).

        This method can be coroutine.
        """
        return fut

    async def _wait_execute(self, address, command, args, kw):
        """Acquire connection and execute command."""
        conn = await self.acquire(command, args)
        try:
            return (await conn.execute(command, *args, **kw))
        finally:
            self.release(conn)

    async def _wait_execute_pubsub(self, address, command, args, kw):
        """Create the pub/sub connection if needed and run *command* on it.

        :raises PoolClosedError: if the pool is closed.
        """
        if self.closed:
            raise PoolClosedError("Pool is closed")
        assert self._pubsub_conn is None or self._pubsub_conn.closed, (
            "Expected no or closed connection", self._pubsub_conn)
        async with self._cond:
            # re-check: the pool may have been closed while waiting for lock
            if self.closed:
                raise PoolClosedError("Pool is closed")
            if self._pubsub_conn is None or self._pubsub_conn.closed:
                conn = await self._create_new_connection(address)
                self._pubsub_conn = conn
            conn = self._pubsub_conn
            return (await conn.execute_pubsub(command, *args, **kw))

    async def select(self, db):
        """Changes db index for all free connections.

        All previously acquired connections will be closed when released.
        Returns a falsy value if any free connection's ``select`` returned
        a falsy value.
        """
        res = True
        async with self._cond:
            # Iterate over a snapshot: get_connection() rotates the deque
            # without taking ``_cond``, so indexing the live deque across
            # awaits could process some connections twice and skip others.
            for conn in list(self._pool):
                # No short-circuit: every free connection must switch db
                # even if an earlier ``select`` reported failure.
                ok = await conn.select(db)
                res = res and ok
            self._db = db
        return res

    async def auth(self, password):
        """Set the pool password and authenticate all free connections.

        Connections created afterwards use the new password as well.
        """
        self._password = password
        async with self._cond:
            # snapshot for the same reason as in select()
            for conn in list(self._pool):
                await conn.auth(password)

    @property
    def in_pubsub(self):
        """``in_pubsub`` of the open pub/sub connection, or 0 if none."""
        if self._pubsub_conn and not self._pubsub_conn.closed:
            return self._pubsub_conn.in_pubsub
        return 0

    @property
    def pubsub_channels(self):
        """Channels of the open pub/sub connection (empty mapping if none)."""
        if self._pubsub_conn and not self._pubsub_conn.closed:
            return self._pubsub_conn.pubsub_channels
        return types.MappingProxyType({})

    @property
    def pubsub_patterns(self):
        """Patterns of the open pub/sub connection (empty mapping if none)."""
        if self._pubsub_conn and not self._pubsub_conn.closed:
            return self._pubsub_conn.pubsub_patterns
        return types.MappingProxyType({})

    async def acquire(self, command=None, args=()):
        """Acquires a connection from free pool.

        Creates new connection if needed.

        :raises PoolClosedError: if the pool is closed.
        """
        if self.closed:
            raise PoolClosedError("Pool is closed")
        async with self._cond:
            if self.closed:
                raise PoolClosedError("Pool is closed")
            while True:
                await self._fill_free(override_min=True)
                if self.freesize:
                    conn = self._pool.popleft()
                    assert not conn.closed, conn
                    assert conn not in self._used, (conn, self._used)
                    self._used.add(conn)
                    return conn
                else:
                    # pool is at maxsize; wait for release() to notify
                    await self._cond.wait()

    def release(self, conn):
        """Returns used connection back into pool.

        When returned connection has db index that differs from one in pool
        the connection will be closed and dropped.
        When queue of free connections is full the connection will be dropped.
        """
        assert conn in self._used, (
            "Invalid connection, maybe from other pool", conn)
        self._used.remove(conn)
        if not conn.closed:
            if conn.in_transaction:
                logger.warning(
                    "Connection %r is in transaction, closing it.", conn)
                conn.close()
            elif conn.in_pubsub:
                logger.warning(
                    "Connection %r is in subscribe mode, closing it.", conn)
                conn.close()
            elif conn._waiters:
                logger.warning(
                    "Connection %r has pending commands, closing it.", conn)
                conn.close()
            elif conn.db == self.db:
                if self.maxsize and self.freesize < self.maxsize:
                    self._pool.append(conn)
                else:
                    # consider this connection as old and close it.
                    conn.close()
            else:
                conn.close()
        # FIXME: check event loop is not closed
        asyncio.ensure_future(self._wakeup())

    def _drop_closed(self):
        """Remove closed connections from the free deque, keeping order."""
        for _ in range(self.freesize):
            conn = self._pool[0]
            if conn.closed:
                self._pool.popleft()
            else:
                self._pool.rotate(-1)

    async def _fill_free(self, *, override_min):
        """Open connections until ``size`` reaches ``minsize``.

        With *override_min* set and still no free connection, also open a
        single connection on demand as long as ``size`` is below
        ``maxsize``.
        """
        # drop closed connections first
        self._drop_closed()
        while self.size < self.minsize:
            self._acquiring += 1
            try:
                conn = await self._create_new_connection(self._address)
                # check the health of that connection; if something went
                # wrong, close it (it is tracked nowhere yet, so it would
                # leak otherwise) and propagate the exception
                try:
                    await conn.execute('ping')
                except BaseException:
                    conn.close()
                    raise
                self._pool.append(conn)
            finally:
                self._acquiring -= 1
                # connection may be closed at yield point
                self._drop_closed()
        if self.freesize:
            return
        if override_min:
            while not self._pool and self.size < self.maxsize:
                self._acquiring += 1
                try:
                    conn = await self._create_new_connection(self._address)
                    self._pool.append(conn)
                finally:
                    self._acquiring -= 1
                    # connection may be closed at yield point
                    self._drop_closed()

    def _create_new_connection(self, address):
        """Return the (awaitable) ``create_connection`` call for *address*
        using the pool's current settings."""
        return create_connection(address,
                                 db=self._db,
                                 password=self._password,
                                 ssl=self._ssl,
                                 encoding=self._encoding,
                                 parser=self._parser_class,
                                 timeout=self._create_connection_timeout,
                                 connection_cls=self._connection_cls,
                                 )

    async def _wakeup(self, closing_conn=None):
        """Notify one ``acquire()`` waiter; optionally wait for
        *closing_conn* to close while holding the condition lock."""
        async with self._cond:
            self._cond.notify()
            if closing_conn is not None:
                await closing_conn.wait_closed()

    def __enter__(self):
        raise RuntimeError(
            "'await' should be used as a context manager expression")

    def __exit__(self, *args):
        pass    # pragma: nocover

    def __await__(self):
        # To make `with await pool` work
        conn = yield from self.acquire().__await__()
        return _ConnectionContextManager(self, conn)

    def get(self):
        '''Return async context manager for working with connection.

        async with pool.get() as conn:
            await conn.execute('get', 'my-key')
        '''
        return _AsyncConnectionContextManager(self)


class _ConnectionContextManager:
    """Sync context manager returned by ``await pool``; releases the
    acquired connection on exit."""

    __slots__ = ('_pool', '_conn')

    def __init__(self, pool, conn):
        self._pool = pool
        self._conn = conn

    def __enter__(self):
        return self._conn

    def __exit__(self, exc_type, exc_value, tb):
        try:
            self._pool.release(self._conn)
        finally:
            self._pool = None
            self._conn = None


class _AsyncConnectionContextManager:
    """Async context manager returned by ``pool.get()``; acquires a
    connection on enter and releases it on exit."""

    __slots__ = ('_pool', '_conn')

    def __init__(self, pool):
        self._pool = pool
        self._conn = None

    async def __aenter__(self):
        conn = await self._pool.acquire()
        self._conn = conn
        return self._conn

    async def __aexit__(self, exc_type, exc_value, tb):
        try:
            self._pool.release(self._conn)
        finally:
            self._pool = None
            self._conn = None