diff --git a/aredis/cache.py b/aredis/cache.py
index 91c65d92..36eba1e8 100644
--- a/aredis/cache.py
+++ b/aredis/cache.py
@@ -175,13 +175,16 @@ async def delete_pattern(self, pattern, count=None):
         Deletes cache according to pattern in redis,
         delete `count` keys each time
         """
-        cursor = '0'
+        cursor = 0
         count_deleted = 0
-        while cursor != 0:
+        while True:
             cursor, identities = await self.client.scan(
                 cursor=cursor, match=pattern, count=count
             )
-            count_deleted += await self.client.delete(*identities)
+            if identities:
+                count_deleted += await self.client.delete(*identities)
+            if cursor == 0:
+                break
         return count_deleted

     async def exist(self, key, param=None):
diff --git a/aredis/commands/transaction.py b/aredis/commands/transaction.py
index 2171de80..e1953d9d 100644
--- a/aredis/commands/transaction.py
+++ b/aredis/commands/transaction.py
@@ -33,8 +33,7 @@ async def transaction(self, func, *watches, **kwargs):
             except WatchError:
                 if watch_delay is not None and watch_delay > 0:
                     await asyncio.sleep(
-                        watch_delay,
-                        loop=self.connection_pool.loop
+                        watch_delay
                     )
                 continue

@@ -75,7 +74,6 @@ async def transaction(self, func, *watches, **kwargs):
             except WatchError:
                 if watch_delay is not None and watch_delay > 0:
                     await asyncio.sleep(
-                        watch_delay,
-                        loop=self.connection_pool.loop
+                        watch_delay
                     )
                 continue
diff --git a/aredis/connection.py b/aredis/connection.py
index 6dc79050..f674a2e9 100755
--- a/aredis/connection.py
+++ b/aredis/connection.py
@@ -416,7 +416,7 @@ async def connect(self):
         except aredis.compat.CancelledError:
             raise
         except Exception as exc:
-            raise ConnectionError()
+            raise ConnectionError() from exc
         # run any user callbacks. right now the only internal callback
         # is for pubsub channel/pattern resubscription
         for callback in self._connect_callbacks:
@@ -593,13 +593,23 @@ def __init__(self, host='127.0.0.1', port=6379, password=None,
         self.socket_keepalive_options = socket_keepalive_options or {}

     async def _connect(self):
+        if LOOP_DEPRECATED:
+            conn_coro = asyncio.open_connection(
+                host=self.host,
+                port=self.port,
+                ssl=self.ssl_context,
+            )
+        else:
+            conn_coro = asyncio.open_connection(
+                host=self.host,
+                port=self.port,
+                ssl=self.ssl_context,
+                loop=self.loop,
+            )
         reader, writer = await exec_with_timeout(
-            asyncio.open_connection(host=self.host,
-                                    port=self.port,
-                                    ssl=self.ssl_context,
-                                    loop=self.loop),
+            conn_coro,
             self._connect_timeout,
-            loop=self.loop
+            loop=self.loop,
         )
         self._reader = reader
         self._writer = writer
@@ -642,12 +652,21 @@ def __init__(self, path='', password=None,
         }

     async def _connect(self):
+        if LOOP_DEPRECATED:
+            conn_coro = asyncio.open_unix_connection(
+                path=self.path,
+                ssl=self.ssl_context,
+            )
+        else:
+            conn_coro = asyncio.open_unix_connection(
+                path=self.path,
+                ssl=self.ssl_context,
+                loop=self.loop,
+            )
         reader, writer = await exec_with_timeout(
-            asyncio.open_unix_connection(path=self.path,
-                                         ssl=self.ssl_context,
-                                         loop=self.loop),
+            conn_coro,
             self._connect_timeout,
-            loop=self.loop
+            loop=self.loop,
         )
         self._reader = reader
         self._writer = writer
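Note: both `_connect()` hunks rely on a module-level `LOOP_DEPRECATED` flag that is defined elsewhere in aredis and is not part of this diff. A minimal sketch of one plausible derivation, shown only to make the branches above self-explanatory (the actual constant in aredis may be computed differently):

    import sys

    # Hypothetical derivation; aredis defines its own LOOP_DEPRECATED constant.
    # The loop= argument to asyncio.open_connection()/asyncio.sleep() emits a
    # DeprecationWarning on 3.8-3.9 and raises TypeError on 3.10+, so the
    # no-loop call path should be preferred from 3.8 onward.
    LOOP_DEPRECATED = sys.version_info >= (3, 8)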
diff --git a/aredis/lock.py b/aredis/lock.py
index 2cd47d54..f643b519 100644
--- a/aredis/lock.py
+++ b/aredis/lock.py
@@ -119,7 +119,7 @@ async def acquire(self, blocking=None, blocking_timeout=None):
                 return False
             if stop_trying_at is not None and mod_time.time() > stop_trying_at:
                 return False
-            await asyncio.sleep(sleep, loop=self.redis.connection_pool.loop)
+            await asyncio.sleep(sleep)

     async def do_acquire(self, token):
         if self.timeout:
@@ -347,7 +347,7 @@ async def acquire(self, blocking=None, blocking_timeout=None):
                 return False
             if not blocking or mod_time.time() > stop_trying_at:
                 return False
-            await asyncio.sleep(sleep, loop=self.redis.connection_pool.loop)
+            await asyncio.sleep(sleep)

     async def do_release(self, expected_token):
         await super(ClusterLock, self).do_release(expected_token)
diff --git a/aredis/pipeline.py b/aredis/pipeline.py
index 90368d5a..d1313dda 100644
--- a/aredis/pipeline.py
+++ b/aredis/pipeline.py
@@ -200,7 +200,7 @@ async def _execute_transaction(self, connection, commands, raise_on_error):
             # typing.Awaitable is not available in Python3.5
             # so use inspect.isawaitable instead
             # according to issue https://github.com/NoneGG/aredis/issues/77
-            if inspect.isawaitable(response):
+            if inspect.isawaitable(r):
                 r = await r
             data.append(r)
         return data
diff --git a/aredis/pool.py b/aredis/pool.py
index 1fbb2cc7..821402f2 100644
--- a/aredis/pool.py
+++ b/aredis/pool.py
@@ -174,6 +174,34 @@ def __init__(self, connection_class=Connection, max_connections=None,

         self.reset()

+    def _schedule_idle_check(self, connection):
+        """
+        Schedule an idle-connection reaper task on the right loop.
+
+        We avoid asyncio.ensure_future() without an explicit loop because it
+        can attach to an unexpected loop (or fail) on modern asyncio.
+        """
+        coro = self.disconnect_on_idle_time_exceeded(connection)
+        try:
+            if self.loop is not None:
+                task = self.loop.create_task(coro)
+            else:
+                try:
+                    loop = asyncio.get_running_loop()
+                except RuntimeError:
+                    loop = asyncio.get_event_loop()
+                task = loop.create_task(coro)
+        except Exception:
+            # If we can't schedule a background task (e.g. no event loop),
+            # close the coroutine to avoid a "never awaited" warning and skip
+            # idle reaping. The connection is still cleaned up when the pool
+            # is disconnected or the process exits.
+            coro.close()
+            return
+
+        self._idle_check_tasks.add(task)
+        task.add_done_callback(lambda t: self._idle_check_tasks.discard(t))
+
     def __repr__(self):
         return '{}<{}>'.format(
             type(self).__name__,
@@ -199,6 +225,7 @@ def reset(self):
         self._available_connections = []
         self._in_use_connections = set()
         self._check_lock = threading.Lock()
+        self._idle_check_tasks = set()

     def _checkpid(self):
         if self.pid != os.getpid():
@@ -227,8 +254,7 @@ def make_connection(self):
         self._created_connections += 1
         connection = self.connection_class(**self.connection_kwargs)
         if self.max_idle_time > self.idle_check_interval > 0:
-            # do not await the future
-            asyncio.ensure_future(self.disconnect_on_idle_time_exceeded(connection))
+            self._schedule_idle_check(connection)
         return connection

     def release(self, connection):
@@ -246,11 +272,16 @@ def release(self, connection):

     def disconnect(self):
         """Closes all connections in the pool"""
-        all_conns = chain(self._available_connections,
-                          self._in_use_connections)
+        for task in list(self._idle_check_tasks):
+            task.cancel()
+        self._idle_check_tasks.clear()
+
+        all_conns = list(chain(self._available_connections, self._in_use_connections))
         for connection in all_conns:
             connection.disconnect()
-            self._created_connections -= 1
+        self._available_connections = []
+        self._in_use_connections = set()
+        self._created_connections = 0


 class ClusterConnectionPool(ConnectionPool):
@@ -301,6 +332,7 @@ def __init__(self, startup_nodes=None, connection_class=ClusterConnection,
         self.readonly = readonly
         self.max_idle_time = max_idle_time
         self.idle_check_interval = idle_check_interval
+        self.loop = self.connection_kwargs.get('loop')
         self.reset()

         if "stream_timeout" not in self.connection_kwargs:
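For context, a rough usage sketch of the reaper path changed above, assuming a local Redis on the default port (`max_idle_time` and `idle_check_interval` are existing aredis client options; the values here are arbitrary):

    import asyncio
    from aredis import StrictRedis

    async def main():
        # max_idle_time > idle_check_interval > 0 activates idle reaping, so
        # make_connection() now routes through _schedule_idle_check() instead
        # of a bare asyncio.ensure_future() call.
        client = StrictRedis(host='127.0.0.1', port=6379,
                             max_idle_time=30, idle_check_interval=10)
        await client.ping()
        # disconnect() now also cancels pending idle-check tasks and resets
        # the pool's bookkeeping instead of decrementing a shared counter.
        client.connection_pool.disconnect()

    asyncio.run(main())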
@@ -328,7 +360,12 @@ async def disconnect_on_idle_time_exceeded(self, connection):
                     and not connection.awaiting_response):
                 connection.disconnect()
                 node = connection.node
-                self._available_connections[node['name']].remove(connection)
+                conn_list = self._available_connections.get(node['name'])
+                if conn_list is not None:
+                    try:
+                        conn_list.remove(connection)
+                    except ValueError:
+                        pass
                 self._created_connections_per_node[node['name']] -= 1
                 break
             await asyncio.sleep(self.idle_check_interval)
@@ -340,6 +377,7 @@ def reset(self):
         self._available_connections = {}  # Dict(Node, List)
         self._in_use_connections = {}  # Dict(Node, Set)
         self._check_lock = threading.Lock()
+        self._idle_check_tasks = set()
         self.initialized = False

     def _checkpid(self):
@@ -398,8 +436,7 @@ def make_connection(self, node):
         # Must store node in the connection to make it eaiser to track
         connection.node = node
         if self.max_idle_time > self.idle_check_interval > 0:
-            # do not await the future
-            asyncio.ensure_future(self.disconnect_on_idle_time_exceeded(connection))
+            self._schedule_idle_check(connection)
         return connection

     def release(self, connection):
@@ -427,15 +464,22 @@ def release(self, connection):

     def disconnect(self):
         """Closes all connectins in the pool"""
+        for task in list(self._idle_check_tasks):
+            task.cancel()
+        self._idle_check_tasks.clear()
+
         all_conns = chain(
             self._available_connections.values(),
             self._in_use_connections.values(),
         )
-
         for node_connections in all_conns:
-            for connection in node_connections:
+            for connection in list(node_connections):
                 connection.disconnect()

+        self._available_connections = {}
+        self._in_use_connections = {}
+        self._created_connections_per_node = {}
+
     def count_all_num_connections(self, node):
         if self.max_connections_per_node:
             return self._created_connections_per_node.get(node['name'], 0)
diff --git a/aredis/pubsub.py b/aredis/pubsub.py
index 7dbb9627..bd81b72f 100644
--- a/aredis/pubsub.py
+++ b/aredis/pubsub.py
@@ -1,4 +1,5 @@
 import asyncio
+import concurrent.futures
 import threading

 from aredis.compat import CancelledError
@@ -324,10 +325,7 @@ def stop(self):
         if self.loop:
             unsubscribed = asyncio.run_coroutine_threadsafe(self.pubsub.unsubscribe(), self.loop)
             punsubscribed = asyncio.run_coroutine_threadsafe(self.pubsub.punsubscribe(), self.loop)
-            asyncio.wait(
-                [unsubscribed, punsubscribed],
-                loop=self.loop
-            )
+            concurrent.futures.wait([unsubscribed, punsubscribed])


 class ClusterPubSub(PubSub):
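The pubsub change is correct because `asyncio.run_coroutine_threadsafe()` returns `concurrent.futures.Future` objects, and `stop()` runs on a thread other than the loop's, so the blocking primitive has to be `concurrent.futures.wait()`; `asyncio.wait()` expects awaitables on a running loop and has lost its `loop=` argument on modern Python. A standalone sketch of the pattern (names are illustrative, not aredis API):

    import asyncio
    import concurrent.futures
    import threading

    async def work(tag):
        await asyncio.sleep(0.1)
        return tag

    def demo():
        loop = asyncio.new_event_loop()
        threading.Thread(target=loop.run_forever, daemon=True).start()
        # Submit coroutines to the loop thread; each call returns a
        # concurrent.futures.Future, which this (non-loop) thread can block on.
        futures = [asyncio.run_coroutine_threadsafe(work(i), loop)
                   for i in range(2)]
        done, _ = concurrent.futures.wait(futures)
        print(sorted(f.result() for f in done))
        loop.call_soon_threadsafe(loop.stop)

    demo()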