Skip to content

Commit 02e7c51

Browse files
committed
Allow for map_async to run items in parallel
Use an asyncio.Queue of the tasks to ensure that arrival and departure order of elements match. Asserts back pressure when a new value arrives via update but the work queue is full. Because asyncio.Queue cannot peak, the parallelism factor is not precise as the worker callback can have either zero or one task in hand but it must free up a slot in the queue to do so. Under pressure, the parallelism will generally be `(parallelism + 1)` instead of `parallelism` as given in the `__init__` as one Future will be in the awaited in the worker callback while the queue fills up from update calls.
1 parent cbdfd2b commit 02e7c51

File tree

2 files changed

+37
-16
lines changed

2 files changed

+37
-16
lines changed

streamz/core.py

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -720,15 +720,16 @@ def update(self, x, who=None, metadata=None):
720720

721721
@Stream.register_api()
722722
class map_async(Stream):
723-
""" Apply an async function to every element in the stream
723+
""" Apply an async function to every element in the stream, preserving order
724+
even when evaluating multiple inputs in parallel.
724725
725726
Parameters
726727
----------
727728
func: async callable
728729
*args :
729730
The arguments to pass to the function.
730-
buffer_size:
731-
The max size of the input buffer, default value is unlimited
731+
parallelism:
732+
The maximum number of parallel Tasks for evaluating func, default value is 1
732733
**kwargs:
733734
Keyword arguments to pass to func
734735
@@ -747,32 +748,31 @@ class map_async(Stream):
747748
4
748749
6
749750
8
750-
751751
"""
752-
def __init__(self, upstream, func, *args, buffer_size=0, **kwargs):
752+
def __init__(self, upstream, func, *args, parallelism=1, **kwargs):
753753
self.func = func
754754
stream_name = kwargs.pop('stream_name', None)
755755
self.kwargs = kwargs
756756
self.args = args
757-
self.input_queue = asyncio.Queue(maxsize=buffer_size)
757+
self.work_queue = asyncio.Queue(maxsize=parallelism)
758758

759759
Stream.__init__(self, upstream, stream_name=stream_name, ensure_io_loop=True)
760-
self.input_task = self._create_task(self.input_callback())
760+
self.work_task = self._create_task(self.work_callback())
761761

762762
def update(self, x, who=None, metadata=None):
763-
coro = self.func(x, *self.args, **self.kwargs)
764-
self._retain_refs(metadata)
765-
return self._create_task(self.input_queue.put((coro, metadata)))
763+
return self._create_task(self._insert_job(x, metadata))
766764

767765
def _create_task(self, coro):
766+
if gen.is_future(coro):
767+
return coro
768768
return self.loop.asyncio_loop.create_task(coro)
769769

770-
async def input_callback(self):
770+
async def work_callback(self):
771771
while True:
772772
try:
773-
coro, metadata = await self.input_queue.get()
774-
self.input_queue.task_done()
775-
result = await coro
773+
task, metadata = await self.work_queue.get()
774+
self.work_queue.task_done()
775+
result = await task
776776
except Exception as e:
777777
logger.exception(e)
778778
raise
@@ -782,6 +782,21 @@ async def input_callback(self):
782782
await asyncio.gather(*results)
783783
self._release_refs(metadata)
784784

785+
async def _wait_for_work_slot(self):
786+
while self.work_queue.full():
787+
await asyncio.sleep(0)
788+
789+
async def _insert_job(self, x, metadata):
790+
try:
791+
await self._wait_for_work_slot()
792+
coro = self.func(x, *self.args, **self.kwargs)
793+
task = self._create_task(coro)
794+
await self.work_queue.put((task, metadata))
795+
self._retain_refs(metadata)
796+
except Exception as e:
797+
logger.exception(e)
798+
raise
799+
785800

786801
@Stream.register_api()
787802
class starmap(Stream):

streamz/tests/test_core.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,11 +133,13 @@ def add_tor(x=0, y=0):
133133
return x + y
134134

135135
async def add_native(x=0, y=0):
136+
await asyncio.sleep(0.1)
136137
return x + y
137138

138139
source = Stream(asynchronous=True)
139-
L = source.map_async(add_tor, y=1).map_async(add_native, y=2).sink_to_list()
140+
L = source.map_async(add_tor, y=1).map_async(add_native, parallelism=2, y=2).buffer(1).sink_to_list()
140141

142+
start = time()
141143
yield source.emit(0)
142144
yield source.emit(1)
143145
yield source.emit(2)
@@ -146,6 +148,7 @@ def fail_func():
146148
assert L == [3, 4, 5]
147149

148150
yield await_for(lambda: L == [3, 4, 5], 1, fail_func=fail_func)
151+
assert (time() - start) == pytest.approx(0.1, abs=4e-3)
149152

150153

151154
@pytest.mark.asyncio
@@ -155,11 +158,13 @@ def add_tor(x=0, y=0):
155158
return x + y
156159

157160
async def add_native(x=0, y=0):
161+
await asyncio.sleep(0.1)
158162
return x + y
159163

160164
source = Stream(asynchronous=True)
161-
L = source.map_async(add_tor, y=1).map_async(add_native, y=2).sink_to_list()
165+
L = source.map_async(add_tor, y=1).map_async(add_native, parallelism=2, y=2).sink_to_list()
162166

167+
start = time()
163168
await source.emit(0)
164169
await source.emit(1)
165170
await source.emit(2)
@@ -168,6 +173,7 @@ def fail_func():
168173
assert L == [3, 4, 5]
169174

170175
await await_for(lambda: L == [3, 4, 5], 1, fail_func=fail_func)
176+
assert (time() - start) == pytest.approx(0.1, abs=4e-3)
171177

172178

173179
def test_map_args():

0 commit comments

Comments
 (0)