From 73b54eabccbf0f2931d8e06d147fdfb1254e0a89 Mon Sep 17 00:00:00 2001
From: Eoghan Murray
Date: Wed, 26 May 2021 14:01:03 +0100
Subject: [PATCH] SCAN: Iterate over cluster nodes in parallel rather than
 sequentially - with a MATCH scan I measured this as iterating over 2.5x keys
 per second compared to previous version (with 8 cluster nodes)

The consumer counts one end-of-node marker per scan task instead of
polling task.done(), so keys that are still queued when the last task
finishes are not dropped and the loop cannot block forever on an empty
queue. An error raised by a node scan is re-raised to the caller, and
outstanding scans are cancelled when iteration stops.
---
 aredis/commands/iter.py | 71 ++++++++++++++++++++++++++++++++++-------
 1 file changed, 60 insertions(+), 11 deletions(-)

diff --git a/aredis/commands/iter.py b/aredis/commands/iter.py
index 5fabd7fb..3837b375 100644
--- a/aredis/commands/iter.py
+++ b/aredis/commands/iter.py
@@ -1,6 +1,7 @@
 #!/usr/bin/python
 # -*- coding: utf-8 -*-
 from collections import defaultdict
+import asyncio
 
 
 class IterCommandMixin:
@@ -81,17 +82,65 @@ async def zscan_iter(self, name, match=None, count=None,
 class ClusterIterCommandMixin(IterCommandMixin):
 
     async def scan_iter(self, match=None, count=None):
+        """
+        Scan every node flagged 'master' concurrently and yield
+        the keys as they arrive (keys of different nodes interleave).
+
+        ``match`` and ``count`` are passed on as the MATCH and
+        COUNT arguments of each SCAN call.
+        """
+
+        async def iterate_node(node):
+            # Queue entries are (is_end, payload) pairs. Each node
+            # queues exactly one (True, error-or-None) entry when it
+            # is finished, so the consumer can count completed nodes.
+            try:
+                cursor = '0'
+                while cursor != 0:
+                    pieces = [cursor]
+                    if match is not None:
+                        pieces.extend(['MATCH', match])
+                    if count is not None:
+                        pieces.extend(['COUNT', count])
+                    response = await self.execute_command_on_nodes(
+                        [node], 'SCAN', *pieces)
+                    cursor, data = list(response.values())[0]
+                    for item in data:
+                        # blocks if queue is full
+                        await queue.put((False, item))
+            except asyncio.CancelledError:
+                raise
+            except Exception as exc:
+                await queue.put((True, exc))
+            else:
+                await queue.put((True, None))
+
+        # maxsize ensures we don't pull too much data into
+        # memory if we are not processing it yet
+        queue = asyncio.Queue(maxsize=1000)
+        tasks = []
         nodes = await self.cluster_nodes()
         for node in nodes:
             if 'master' in node['flags']:
-                cursor = '0'
-                while cursor != 0:
-                    pieces = [cursor]
-                    if match is not None:
-                        pieces.extend(['MATCH', match])
-                    if count is not None:
-                        pieces.extend(['COUNT', count])
-                    response = await self.execute_command_on_nodes([node], 'SCAN', *pieces)
-                    cursor, data = list(response.values())[0]
-                    for item in data:
-                        yield item
+                tasks.append(asyncio.create_task(iterate_node(node)))
+
+        # Count end markers rather than polling task.done(): a task
+        # may finish while its last keys are still queued (they would
+        # be dropped) or just before queue.get() is awaited on an
+        # empty queue (which would then never return).
+        running = len(tasks)
+        try:
+            while running:
+                is_end, payload = await queue.get()
+                if not is_end:
+                    yield payload
+                elif payload is not None:
+                    raise payload
+                else:
+                    running -= 1
+        finally:
+            # stop the remaining scans if the caller stops iterating
+            # early or a node failed
+            for task in tasks:
+                task.cancel()
+            await asyncio.gather(*tasks, return_exceptions=True)