From deadd06b4786967e74b9d23a21c3c7e9495e249d Mon Sep 17 00:00:00 2001 From: Gal Topper Date: Wed, 11 Feb 2026 19:18:31 +0700 Subject: [PATCH 1/5] Add timing metadata for streaming ParallelExecution responses [ML-11879](https://iguazio.atlassian.net/browse/ML-11879) Streaming responses from `ParallelExecution` were missing the `when` and `microsec` timing metadata that non-streaming responses include. This metadata is required for model monitoring in MLRun. Changes * Add `_StreamingResult` class to wrap streaming generators with timing info * Set timing metadata on events before emitting streaming chunks * Handle both in-process streaming (`_StreamingResult`) and process-based streaming (raw generators) Notes * For streaming, `microsec` is set to `None` since total runtime isn't available until streaming completes * For process-based streaming, `when` uses the timestamp at which chunks start arriving (timing from the subprocess isn't available) --- storey/flow.py | 53 ++++++++++++++++++++++++++++++++++------- tests/test_streaming.py | 34 ++++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 9 deletions(-) diff --git a/storey/flow.py b/storey/flow.py index f7a0fc80..e0a36ba1 100644 --- a/storey/flow.py +++ b/storey/flow.py @@ -1728,6 +1728,20 @@ def __init__(self, runnable_name: str, data: Any, runtime: float, timestamp: dat self.timestamp = timestamp +class _StreamingResult: + """Wraps a streaming generator with timing metadata for model monitoring.""" + + def __init__( + self, + runnable_name: str, + generator: Generator | AsyncGenerator, + timestamp: datetime.datetime, + ): + self.runnable_name = runnable_name + self.generator = generator + self.timestamp = timestamp + + class ParallelExecutionMechanisms(str, enum.Enum): process_pool = "process_pool" dedicated_process = "dedicated_process" @@ -1840,9 +1854,9 @@ def _run(self, body: Any, path: str, origin_name: Optional[str] = None) -> Any: start = time.monotonic() try: result = self.run(body, path, origin_name) - # 
Return generator directly for streaming support + # Return streaming result with timing metadata for streaming support if _is_generator(result): - return result + return _StreamingResult(origin_name or self.name, result, timestamp) body = result except Exception as e: if self._raise_exception: @@ -1858,9 +1872,9 @@ async def _async_run(self, body: Any, path: str, origin_name: Optional[str] = No try: result = self.run_async(body, path, origin_name) - # Return generator directly for streaming support + # Return streaming result with timing metadata for streaming support if _is_generator(result): - return result + return _StreamingResult(origin_name or self.name, result, timestamp) # Await if coroutine if asyncio.iscoroutine(result): @@ -1902,7 +1916,10 @@ def _streaming_run_wrapper( sending each chunk through the multiprocessing queue. """ try: - for chunk in runnable._run(input, path, origin_name): + result = runnable._run(input, path, origin_name) + # Unwrap _StreamingResult to get the generator + generator = result.generator if isinstance(result, _StreamingResult) else result + for chunk in generator: queue.put(("chunk", chunk)) queue.put(("done", None)) except Exception as e: @@ -2270,15 +2287,33 @@ async def _do(self, event): # Check for streaming response (only when a single runnable is selected) if len(runnables) == 1 and results: result = results[0] - # Check if the result is a generator (streaming response) - if _is_generator(result): + # Check if the result is a streaming result (contains generator with timing metadata) + if isinstance(result, _StreamingResult): + # Set timing metadata on the event before emitting chunks + # For streaming, microsec is None since we don't have total runtime + metadata = { + "microsec": None, + "when": result.timestamp.isoformat(sep=" ", timespec="microseconds"), + } + self.set_event_metadata(event, metadata) + await self._emit_streaming_chunks(event, result.generator) + return None + # Handle raw generator from process-based 
streaming (no timing info from subprocess) + elif _is_generator(result): + # Use current timestamp as fallback for process-based streaming + timestamp = datetime.datetime.now(tz=datetime.timezone.utc) + metadata = { + "microsec": None, + "when": timestamp.isoformat(sep=" ", timespec="microseconds"), + } + self.set_event_metadata(event, metadata) await self._emit_streaming_chunks(event, result) return None # Non-streaming path - # Check if any results are generators (not allowed with multiple runnables) + # Check if any results are streaming (not allowed with multiple runnables) for result in results: - if _is_generator(result): + if isinstance(result, _StreamingResult): raise StreamingError( "Streaming is not supported when multiple runnables are selected. " "Streaming runnables must be the only runnable selected for an event." diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 84dee94b..350f6938 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -1338,6 +1338,40 @@ def test_parallel_execution_streaming_error_propagation(self, execution_mechanis with pytest.raises(expected_error, match="Simulated streaming error"): controller.await_termination() + def test_parallel_execution_streaming_single_runnable_sets_metadata(self): + """Test that streaming ParallelExecution with single runnable sets timing metadata. + + This mirrors the non-streaming behavior where _metadata includes 'when' and 'microsec'. + After Collector aggregates chunks, the collected event should have timing metadata. 
+ """ + runnable = StreamingRunnable(name="streamer") + controller = build_flow( + [ + SyncEmitSource(), + ParallelExecution( + runnables=[runnable], + execution_mechanism_by_runnable_name={"streamer": ParallelExecutionMechanisms.naive}, + ), + Collector(), + Reduce([], lambda acc, x: acc + [x], full_event=True), + ] + ).run() + + try: + controller.emit("test") + finally: + controller.terminate() + result = controller.await_termination() + + assert len(result) == 1 + event = result[0] + assert hasattr(event, "_metadata"), "Expected event to have _metadata attribute" + metadata = event._metadata + assert "when" in metadata, "Expected _metadata to include 'when' field" + assert "microsec" in metadata, "Expected _metadata to include 'microsec' field" + # Verify 'when' is a valid ISO timestamp string + assert isinstance(metadata["when"], str), "Expected 'when' to be a string" + class TestStreamingGraphSplits: """Tests for streaming through branching graph topologies.""" From 375cb12b55a21c6a6612bb84e9d310405b126beb Mon Sep 17 00:00:00 2001 From: Gal Topper Date: Thu, 12 Feb 2026 13:47:47 +0700 Subject: [PATCH 2/5] Implement latency metadata in Collector --- storey/steps/collector.py | 50 +++++++++++++++++++++++++++++++++++++++ tests/test_streaming.py | 4 ++++ 2 files changed, 54 insertions(+) diff --git a/storey/steps/collector.py b/storey/steps/collector.py index 3f0b1ba4..c9dee4f9 100644 --- a/storey/steps/collector.py +++ b/storey/steps/collector.py @@ -13,6 +13,7 @@ # limitations under the License. # import copy +import datetime from collections import defaultdict from ..dtypes import StreamCompletion, _termination_obj @@ -46,6 +47,51 @@ def __init__(self, expected_completions: int = 1, **kwargs): lambda: {"chunks": [], "completions": 0, "first_event": None} ) + def _calculate_streaming_duration(self, event): + """ + Calculate total streaming duration and update event metadata with microsec. 
+ + Uses the 'when' timestamp from the first chunk's metadata (set by ParallelExecution) + to calculate total elapsed time from stream start to completion. + """ + if not hasattr(event, "_metadata") or not event._metadata: + return + + # Get the start timestamp - could be at top level or nested under model name + when_str = None + if "when" in event._metadata: + when_str = event._metadata.get("when") + else: + # For multi-model (ModelRunnerStep), metadata is nested under model name + for value in event._metadata.values(): + if isinstance(value, dict) and "when" in value: + when_str = value.get("when") + break + + if not when_str: + return + + try: + # Parse the ISO format timestamp + start_time = datetime.datetime.fromisoformat(when_str) + now = datetime.datetime.now(tz=datetime.timezone.utc) + elapsed_microsec = int((now - start_time).total_seconds() * 1_000_000) + + # Update metadata with calculated microsec + if "when" in event._metadata: + event._metadata["microsec"] = elapsed_microsec + else: + # For nested metadata (ModelRunnerStep), update in the same nested dict + for value in event._metadata.values(): + if isinstance(value, dict) and "when" in value: + value["microsec"] = elapsed_microsec + break + except (ValueError, TypeError) as exc: + if self.logger: + self.logger.warning( + f"Failed to calculate streaming duration from 'when' timestamp '{when_str}': {exc}" + ) + async def _do(self, event): if event is _termination_obj: return await self._do_downstream(_termination_obj) @@ -73,6 +119,10 @@ async def _do(self, event): del collected_event.streaming_step if hasattr(collected_event, "chunk_id"): del collected_event.chunk_id + + # Calculate total streaming duration (microsec) if timing metadata exists + self._calculate_streaming_duration(collected_event) + await self._do_downstream(collected_event) # Clean up diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 350f6938..f95a547d 100644 --- a/tests/test_streaming.py +++ 
b/tests/test_streaming.py @@ -1343,6 +1343,7 @@ def test_parallel_execution_streaming_single_runnable_sets_metadata(self): This mirrors the non-streaming behavior where _metadata includes 'when' and 'microsec'. After Collector aggregates chunks, the collected event should have timing metadata. + The 'microsec' field should contain the total streaming duration calculated by Collector. """ runnable = StreamingRunnable(name="streamer") controller = build_flow( @@ -1371,6 +1372,9 @@ def test_parallel_execution_streaming_single_runnable_sets_metadata(self): assert "microsec" in metadata, "Expected _metadata to include 'microsec' field" # Verify 'when' is a valid ISO timestamp string assert isinstance(metadata["when"], str), "Expected 'when' to be a string" + # Verify 'microsec' is a positive integer (total streaming duration calculated by Collector) + assert isinstance(metadata["microsec"], int), "Expected 'microsec' to be an integer" + assert metadata["microsec"] >= 0, "Expected 'microsec' to be non-negative" class TestStreamingGraphSplits: From 7d187cf82b7e55a8581b7bc90a3f158d63381ddf Mon Sep 17 00:00:00 2001 From: Gal Topper Date: Thu, 12 Feb 2026 15:08:34 +0700 Subject: [PATCH 3/5] Lint --- storey/steps/collector.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/storey/steps/collector.py b/storey/steps/collector.py index c9dee4f9..6d5e9562 100644 --- a/storey/steps/collector.py +++ b/storey/steps/collector.py @@ -88,9 +88,7 @@ def _calculate_streaming_duration(self, event): break except (ValueError, TypeError) as exc: if self.logger: - self.logger.warning( - f"Failed to calculate streaming duration from 'when' timestamp '{when_str}': {exc}" - ) + self.logger.warning(f"Failed to calculate streaming duration from 'when' timestamp '{when_str}': {exc}") async def _do(self, event): if event is _termination_obj: From 3bacea13610a669e8120453246286b63d2af5165 Mon Sep 17 00:00:00 2001 From: Gal Topper Date: Thu, 12 Feb 2026 16:03:45 +0700 Subject: 
[PATCH 4/5] Remove code that handles unsupported multi-model stream result --- storey/steps/collector.py | 28 +++++----------------------- 1 file changed, 5 insertions(+), 23 deletions(-) diff --git a/storey/steps/collector.py b/storey/steps/collector.py index 6d5e9562..76a67d57 100644 --- a/storey/steps/collector.py +++ b/storey/steps/collector.py @@ -53,39 +53,21 @@ def _calculate_streaming_duration(self, event): Uses the 'when' timestamp from the first chunk's metadata (set by ParallelExecution) to calculate total elapsed time from stream start to completion. + + Streaming is only supported with a single selected runnable, so metadata is always + flat (top-level 'when' and 'microsec'), never nested under model names. """ if not hasattr(event, "_metadata") or not event._metadata: return - # Get the start timestamp - could be at top level or nested under model name - when_str = None - if "when" in event._metadata: - when_str = event._metadata.get("when") - else: - # For multi-model (ModelRunnerStep), metadata is nested under model name - for value in event._metadata.values(): - if isinstance(value, dict) and "when" in value: - when_str = value.get("when") - break - + when_str = event._metadata.get("when") if not when_str: return try: - # Parse the ISO format timestamp start_time = datetime.datetime.fromisoformat(when_str) now = datetime.datetime.now(tz=datetime.timezone.utc) - elapsed_microsec = int((now - start_time).total_seconds() * 1_000_000) - - # Update metadata with calculated microsec - if "when" in event._metadata: - event._metadata["microsec"] = elapsed_microsec - else: - # For nested metadata (ModelRunnerStep), update in the same nested dict - for value in event._metadata.values(): - if isinstance(value, dict) and "when" in value: - value["microsec"] = elapsed_microsec - break + event._metadata["microsec"] = int((now - start_time).total_seconds() * 1_000_000) except (ValueError, TypeError) as exc: if self.logger: self.logger.warning(f"Failed to 
calculate streaming duration from 'when' timestamp '{when_str}': {exc}") From f130d88b9c7911d7384ee1d542e783f38428dce9 Mon Sep 17 00:00:00 2001 From: Gal Topper Date: Thu, 12 Feb 2026 16:08:00 +0700 Subject: [PATCH 5/5] Add missing unit test for `StreamingError` --- tests/test_streaming.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/test_streaming.py b/tests/test_streaming.py index f95a547d..35667ebe 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -1084,6 +1084,32 @@ def stream_chunks(x): asyncio.run(_test()) + def test_streaming_with_multiple_runnables_raises_error(self): + """Test that streaming raises an error when multiple runnables are selected.""" + streaming = StreamingRunnable(name="streamer") + non_streaming = NonStreamingRunnable(name="non_streamer") + + controller = build_flow( + [ + SyncEmitSource(), + ParallelExecution( + runnables=[streaming, non_streaming], + execution_mechanism_by_runnable_name={ + "streamer": ParallelExecutionMechanisms.naive, + "non_streamer": ParallelExecutionMechanisms.naive, + }, + ), + Reduce([], lambda acc, x: acc + [x]), + ] + ).run() + + try: + controller.emit("test") + finally: + controller.terminate() + with pytest.raises(StreamingError, match="Streaming is not supported when multiple runnables are selected"): + controller.await_termination() + class TestStreamingWithIntermediateSteps: """Tests for streaming through intermediate non-streaming steps."""