From e344f2436bc7f013d60a938e946588296aed5b61 Mon Sep 17 00:00:00 2001 From: Minghao Liu Date: Thu, 29 Jan 2026 00:53:44 -0500 Subject: [PATCH 1/6] Implement time delay plugin --- strax/plugins/__init__.py | 1 + strax/plugins/time_delay_plugin.py | 338 +++++++++++++++++++++++++++++ tests/test_time_delay_plugin.py | 276 +++++++++++++++++++++++ 3 files changed, 615 insertions(+) create mode 100644 strax/plugins/time_delay_plugin.py create mode 100644 tests/test_time_delay_plugin.py diff --git a/strax/plugins/__init__.py b/strax/plugins/__init__.py index ff9c0ce34..52e289d81 100644 --- a/strax/plugins/__init__.py +++ b/strax/plugins/__init__.py @@ -6,3 +6,4 @@ from .parrallel_source_plugin import * from .down_chunking_plugin import * from .exhaust_plugin import * +from .time_delay_plugin import * diff --git a/strax/plugins/time_delay_plugin.py b/strax/plugins/time_delay_plugin.py new file mode 100644 index 000000000..89135ec55 --- /dev/null +++ b/strax/plugins/time_delay_plugin.py @@ -0,0 +1,338 @@ +"""Plugin base class for algorithms that add time delays to output.""" + +import numpy as np +import strax +from .plugin import Plugin + +export, __all__ = strax.exporter() + + +@export +class TimeDelayPlugin(Plugin): + """Plugin base class for algorithms that add time delays to output. + + Use this when your algorithm shifts output timestamps forward in time, + potentially beyond input chunk boundaries. Handles variable delays with + known maximum, re-sorting, buffering across chunk boundaries, and + multi-output plugins. + + Subclasses must implement: + get_max_delay(): Return maximum possible delay in nanoseconds + compute_with_delay(**kwargs): Return delayed output data (arrays, not Chunks) + + For multi-output plugins, compute_with_delay should return a dict + mapping data_type names to numpy arrays. + + """ + + parallel = False + + def __init__(self): + super().__init__() + self._init_buffers() + + def _init_buffers(self): + """Initialize/reset all buffer state.""" + self.output_buffer = {} + self.last_output_end = 0 + self.first_output = True + self._cached_superrun = None + self._cached_subruns = None + self._min_buffered_time = float("inf") + + def get_max_delay(self): + """Return the maximum possible delay in nanoseconds.""" + raise NotImplementedError("Subclasses must implement get_max_delay()") + + def compute_with_delay(self, **kwargs): + """Compute output data with time delays already applied. + + Input arrays are numpy arrays (not Chunks). Output arrays do NOT + need to be sorted. For multi-output, return a dict mapping + data_type to arrays. 
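+
+        A minimal sketch of the multi-output case (the "delayed_a" and
+        "delayed_b" data_type names are illustrative only):
+
+            def compute_with_delay(self, records):
+                shifted = records.copy()
+                shifted["time"] += 10  # constant 10 ns shift, illustrative
+                return {"delayed_a": shifted, "delayed_b": shifted.copy()}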
+ + """ + raise NotImplementedError("Subclasses must implement compute_with_delay()") + + def iter(self, iters, executor=None): + """Override iter to flush buffer at end of processing.""" + yield from super().iter(iters, executor=executor) + final_result = self._flush_buffers() + if final_result is not None: + yield final_result + + def _flush_buffers(self): + """Flush all remaining data from buffers.""" + if self.multi_output: + return self._flush_multi_output() + else: + return self._flush_single_output() + + def _flush_single_output(self): + """Flush buffer for single-output plugin.""" + buf = self.output_buffer.get(None) + if buf is None or len(buf) == 0: + return None + + buf.sort(order="time") + data_end = int(strax.endtime(buf).max()) + chunk_end = max(self.last_output_end, data_end) + + result = self._make_chunk( + data=buf, + data_type=self.provides[0], + start=self.last_output_end, + end=chunk_end, + ) + result = self.superrun_transformation( + result, self._cached_superrun, self._cached_subruns + ) + + self.output_buffer = {} + return result + + def _flush_multi_output(self): + """Flush buffers for multi-output plugin.""" + has_data = any( + len(self.output_buffer.get(dt, [])) > 0 for dt in self.provides + ) + if not has_data: + return None + + chunk_end = self.last_output_end + for data_type in self.provides: + buf = self.output_buffer.get(data_type) + if buf is not None and len(buf) > 0: + buf.sort(order="time") + self.output_buffer[data_type] = buf + data_end = int(strax.endtime(buf).max()) + chunk_end = max(chunk_end, data_end) + + result = {} + for data_type in self.provides: + buf = self.output_buffer.get( + data_type, np.empty(0, self.dtype_for(data_type)) + ) + result[data_type] = self._make_chunk( + data=buf, + data_type=data_type, + start=self.last_output_end, + end=chunk_end, + ) + + result = self.superrun_transformation( + result, self._cached_superrun, self._cached_subruns + ) + + self.output_buffer = {} + return result + + def do_compute(self, chunk_i=None, **kwargs): + """Process input, buffer output, return safe portion.""" + input_start, input_end = self._get_input_timing(kwargs) + + self._cached_superrun = self._check_subruns_uniqueness( + kwargs, {k: v.superrun for k, v in kwargs.items()} + ) + self._cached_subruns = self._check_subruns_uniqueness( + kwargs, {k: v.subruns for k, v in kwargs.items()} + ) + + input_data = {k: v.data for k, v in kwargs.items()} + new_output = self.compute_with_delay(**input_data) + + self._add_to_buffers(new_output) + + safe_boundary = input_end + + if self.multi_output: + return self._process_multi_output(safe_boundary) + else: + return self._process_single_output(safe_boundary) + + def _get_input_timing(self, kwargs): + """Extract input chunk timing.""" + if not kwargs: + raise RuntimeError("TimeDelayPlugin must have dependencies") + first_chunk = next(iter(kwargs.values())) + return first_chunk.start, first_chunk.end + + def _add_to_buffers(self, new_output): + """Add new output to appropriate buffers.""" + if self.multi_output: + self._add_to_buffers_multi(new_output) + else: + self._add_to_buffers_single(new_output) + + def _add_to_buffers_single(self, new_output): + """Add output to buffer for single-output plugin.""" + if isinstance(new_output, dict): + raise ValueError( + f"{self.__class__.__name__} is single-output, " + "compute_with_delay should not return a dict" + ) + if not isinstance(new_output, np.ndarray): + new_output = strax.dict_to_rec(new_output, dtype=self.dtype) + + if None not in self.output_buffer: + 
self.output_buffer[None] = new_output + elif len(new_output) > 0: + self.output_buffer[None] = np.concatenate( + [self.output_buffer[None], new_output] + ) + + def _add_to_buffers_multi(self, new_output): + """Add output to buffers for multi-output plugin.""" + if not isinstance(new_output, dict): + raise ValueError( + f"{self.__class__.__name__} is multi-output, " + "compute_with_delay must return a dict" + ) + for data_type in self.provides: + arr = new_output.get(data_type, np.empty(0, self.dtype_for(data_type))) + if not isinstance(arr, np.ndarray): + arr = strax.dict_to_rec(arr, dtype=self.dtype_for(data_type)) + + if data_type not in self.output_buffer: + self.output_buffer[data_type] = arr + elif len(arr) > 0: + self.output_buffer[data_type] = np.concatenate( + [self.output_buffer[data_type], arr] + ) + + def _process_single_output(self, safe_boundary): + """Process buffer for single-output plugin.""" + buf = self.output_buffer.get(None, np.empty(0, self.dtype)) + + if len(buf) > 0: + buf.sort(order="time") + self.output_buffer[None] = buf + + safe_data, remaining = self._split_buffer(buf, safe_boundary) + self.output_buffer[None] = remaining + + self._update_min_buffered_time() + + chunk_start, chunk_end = self._get_chunk_boundaries(safe_data, safe_boundary) + + self.last_output_end = chunk_end + self.first_output = False + + result = self._make_chunk( + data=safe_data, + data_type=self.provides[0], + start=chunk_start, + end=chunk_end, + ) + + return self.superrun_transformation( + result, self._cached_superrun, self._cached_subruns + ) + + def _process_multi_output(self, safe_boundary): + """Process buffers for multi-output plugin.""" + for data_type in self.provides: + buf = self.output_buffer.get(data_type) + if buf is not None and len(buf) > 0: + buf.sort(order="time") + self.output_buffer[data_type] = buf + + safe_data_dict = {} + for data_type in self.provides: + buf = self.output_buffer.get( + data_type, np.empty(0, self.dtype_for(data_type)) + ) + safe_data, remaining = self._split_buffer(buf, safe_boundary) + self.output_buffer[data_type] = remaining + safe_data_dict[data_type] = safe_data + + self._update_min_buffered_time() + + chunk_start = None + chunk_end = None + + for data_type in self.provides: + safe_data = safe_data_dict[data_type] + dt_start, dt_end = self._get_chunk_boundaries(safe_data, safe_boundary) + + if chunk_start is None: + chunk_start = dt_start + chunk_end = dt_end + else: + chunk_start = min(chunk_start, dt_start) + chunk_end = max(chunk_end, dt_end) + + result = {} + for data_type in self.provides: + result[data_type] = self._make_chunk( + data=safe_data_dict[data_type], + data_type=data_type, + start=chunk_start, + end=chunk_end, + ) + + self.last_output_end = chunk_end + self.first_output = False + + return self.superrun_transformation( + result, self._cached_superrun, self._cached_subruns + ) + + def _split_buffer(self, buf, safe_boundary): + """Split buffer into safe portion (endtime <= boundary) and remainder.""" + if len(buf) == 0: + empty = np.empty(0, buf.dtype) + return empty, empty + + endtimes = strax.endtime(buf) + safe_mask = endtimes <= safe_boundary + + safe_data = buf[safe_mask].copy() + remaining = buf[~safe_mask].copy() + + return safe_data, remaining + + def _update_min_buffered_time(self): + """Recalculate minimum time across all buffered data.""" + min_time = float("inf") + for key, buf in self.output_buffer.items(): + if buf is not None and len(buf) > 0: + min_time = min(min_time, buf["time"].min()) + self._min_buffered_time = 
min_time + + def _get_chunk_boundaries(self, safe_data, safe_boundary): + """Determine chunk start/end ensuring buffered data fits in next chunk.""" + if self.first_output: + if len(safe_data) > 0: + chunk_start = int(safe_data[0]["time"]) + else: + chunk_start = 0 + else: + chunk_start = self.last_output_end + + if len(safe_data) > 0: + data_end = int(strax.endtime(safe_data).max()) + chunk_end = max(data_end, safe_boundary) + else: + chunk_end = safe_boundary + + # Don't advance chunk_end past minimum buffered time + if self._min_buffered_time < float("inf"): + chunk_end = min(chunk_end, int(self._min_buffered_time)) + + chunk_end = max(chunk_start, chunk_end) + + return chunk_start, chunk_end + + def _make_chunk(self, data, data_type, start, end): + """Create a strax Chunk with proper metadata.""" + return strax.Chunk( + start=start, + end=end, + data=data, + data_type=data_type, + data_kind=self.data_kind_for(data_type), + dtype=self.dtype_for(data_type), + run_id=self._run_id, + target_size_mb=self.chunk_target_size_mb, + ) diff --git a/tests/test_time_delay_plugin.py b/tests/test_time_delay_plugin.py new file mode 100644 index 000000000..7abb97a7a --- /dev/null +++ b/tests/test_time_delay_plugin.py @@ -0,0 +1,276 @@ +"""Tests for TimeDelayPlugin.""" + +import numpy as np +import strax +import pytest + + +def simple_interval_dtype(): + return [ + ("time", np.int64), + ("length", np.int32), + ("dt", np.int16), + ("value", np.int32), + ] + + +class ChunkedSource(strax.Plugin): + """Source plugin that yields pre-defined chunks.""" + + depends_on = tuple() + provides = "source_data" + dtype = simple_interval_dtype() + rechunk_on_save = False + chunks_data = [] + + def is_ready(self, chunk_i): + return chunk_i < len(self.chunks_data) + + def source_finished(self): + return True + + def compute(self, chunk_i): + start, end, data = self.chunks_data[chunk_i] + return self.chunk(start=start, end=end, data=data) + + +class ConstantDelayPlugin(strax.TimeDelayPlugin): + """Test plugin that adds a constant delay to all records.""" + + depends_on = ("source_data",) + provides = "delayed_data" + dtype = simple_interval_dtype() + data_kind = "delayed_data" + delay = 0 + + def get_max_delay(self): + return self.delay + + def compute_with_delay(self, source_data): + result = source_data.copy() + result["time"] = result["time"] + self.delay + return result + + +class VariableDelayPlugin(strax.TimeDelayPlugin): + """Test plugin that adds variable delays based on a pattern.""" + + depends_on = ("source_data",) + provides = "variable_delayed_data" + dtype = simple_interval_dtype() + data_kind = "variable_delayed_data" + max_delay = 100 + delay_pattern = [0] + + def get_max_delay(self): + return self.max_delay + + def compute_with_delay(self, source_data): + result = source_data.copy() + delays = np.array( + [self.delay_pattern[i % len(self.delay_pattern)] for i in range(len(result))] + ) + result["time"] = result["time"] + delays + return result + + +class MultiOutputDelayPlugin(strax.TimeDelayPlugin): + """Test plugin with multiple outputs.""" + + depends_on = ("source_data",) + provides = ("delayed_output_a", "delayed_output_b") + data_kind = { + "delayed_output_a": "delayed_output_a", + "delayed_output_b": "delayed_output_b", + } + delay_a = 0 + delay_b = 0 + max_delay = 100 + + def infer_dtype(self): + return { + "delayed_output_a": simple_interval_dtype(), + "delayed_output_b": simple_interval_dtype(), + } + + def get_max_delay(self): + return self.max_delay + + def compute_with_delay(self, 
source_data): + result_a = source_data.copy() + result_a["time"] = result_a["time"] + self.delay_a + result_b = source_data.copy() + result_b["time"] = result_b["time"] + self.delay_b + return {"delayed_output_a": result_a, "delayed_output_b": result_b} + + +def make_test_data(times, length=1, dt=1, values=None): + """Create test data array with given times.""" + n = len(times) + data = np.zeros(n, dtype=simple_interval_dtype()) + data["time"] = times + data["length"] = length + data["dt"] = dt + data["value"] = values if values is not None else np.arange(n) + return data + + +def create_context_with_source(chunks_data): + """Create a strax context with ChunkedSource configured.""" + + class TestSource(ChunkedSource): + pass + + TestSource.chunks_data = chunks_data + st = strax.Context(storage=[]) + st.register(TestSource) + return st + + +def test_constant_delay_across_chunks(): + """Test constant delay with buffering across chunk boundaries.""" + delay = 30 + + data1 = make_test_data(np.array([10, 40]), values=np.array([0, 1])) + data2 = make_test_data(np.array([60, 90]), values=np.array([2, 3])) + + chunks_data = [ + (0, 50, data1), + (50, 100, data2), + ] + st = create_context_with_source(chunks_data) + + class TestDelayPlugin(ConstantDelayPlugin): + delay = 30 + + st.register(TestDelayPlugin) + result = st.get_array(run_id="test", targets="delayed_data") + + expected_times = np.array([10, 40, 60, 90]) + delay + np.testing.assert_array_equal(sorted(result["time"]), sorted(expected_times)) + assert len(result) == 4 + + +def test_variable_delay_reorders_and_buffers(): + """Test variable delays with reordering and buffering.""" + data1 = make_test_data(np.array([0, 10, 20]), values=np.array([0, 1, 2])) + data2 = make_test_data(np.array([50, 60, 70]), values=np.array([3, 4, 5])) + + chunks_data = [ + (0, 50, data1), + (50, 100, data2), + ] + st = create_context_with_source(chunks_data) + + class TestVariableDelay(VariableDelayPlugin): + max_delay = 100 + delay_pattern = [0, 80, 20] + + st.register(TestVariableDelay) + result = st.get_array(run_id="test", targets="variable_delayed_data") + + assert len(result) == 6 + assert np.all(np.diff(result["time"]) >= 0), "Output must be sorted" + + +def test_empty_input_chunk(): + """Test handling of empty input chunks.""" + data1 = make_test_data(np.array([10, 20]), values=np.array([0, 1])) + empty_data = make_test_data(np.array([], dtype=np.int64)) + data3 = make_test_data(np.array([110, 120]), values=np.array([2, 3])) + + chunks_data = [ + (0, 50, data1), + (50, 100, empty_data), + (100, 150, data3), + ] + st = create_context_with_source(chunks_data) + + class TestDelayPlugin(ConstantDelayPlugin): + delay = 20 + + st.register(TestDelayPlugin) + result = st.get_array(run_id="test", targets="delayed_data") + + assert len(result) == 4 + + +def test_multi_output_different_delays(): + """Test multi-output plugin with different delays per output.""" + data = make_test_data(np.array([10, 50]), values=np.array([0, 1])) + chunks_data = [(0, 100, data)] + st = create_context_with_source(chunks_data) + + class TestMultiOutput(MultiOutputDelayPlugin): + delay_a = 20 + delay_b = 60 + max_delay = 60 + + st.register(TestMultiOutput) + + result_a = st.get_array(run_id="test", targets="delayed_output_a") + result_b = st.get_array(run_id="test", targets="delayed_output_b") + + np.testing.assert_array_equal(result_a["time"], [30, 70]) + np.testing.assert_array_equal(result_b["time"], [70, 110]) + + +def test_chunk_continuity(): + """Test that output chunks maintain 
proper continuity.""" + data1 = make_test_data(np.array([10, 20]), values=np.array([0, 1])) + data2 = make_test_data(np.array([60, 70]), values=np.array([2, 3])) + data3 = make_test_data(np.array([110, 120]), values=np.array([4, 5])) + + chunks_data = [ + (0, 50, data1), + (50, 100, data2), + (100, 150, data3), + ] + st = create_context_with_source(chunks_data) + + class TestDelayPlugin(ConstantDelayPlugin): + delay = 30 + + st.register(TestDelayPlugin) + chunks = list(st.get_iter(run_id="test", targets="delayed_data")) + + for i in range(1, len(chunks)): + assert chunks[i].start == chunks[i - 1].end, ( + f"Chunk {i} start ({chunks[i].start}) != chunk {i-1} end ({chunks[i-1].end})" + ) + + +def test_straddling_data_across_boundary(): + """Test data that straddles chunk boundary (time < boundary < endtime).""" + # After delay: time=95, endtime=105 (straddles boundary at 100) + data1 = np.zeros(1, dtype=simple_interval_dtype()) + data1["time"] = 85 + data1["length"] = 10 + data1["dt"] = 1 + data1["value"] = 1 + + data2 = np.zeros(1, dtype=simple_interval_dtype()) + data2["time"] = 120 + data2["length"] = 10 + data2["dt"] = 1 + data2["value"] = 2 + + chunks_data = [ + (0, 100, data1), + (100, 200, data2), + ] + st = create_context_with_source(chunks_data) + + class TestDelayPlugin(ConstantDelayPlugin): + delay = 10 + + st.register(TestDelayPlugin) + result = st.get_array(run_id="test", targets="delayed_data") + + assert len(result) == 2 + np.testing.assert_array_equal(result["time"], [95, 130]) + np.testing.assert_array_equal(strax.endtime(result), [105, 140]) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From 7a77ff6bfb7137777421203d59444fe5a6a8f6cf Mon Sep 17 00:00:00 2001 From: Minghao Liu Date: Thu, 29 Jan 2026 01:16:11 -0500 Subject: [PATCH 2/6] Add documents --- docs/source/advanced/plugin_dev.rst | 31 +++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/docs/source/advanced/plugin_dev.rst b/docs/source/advanced/plugin_dev.rst index 3fec7efab..be7b541d3 100644 --- a/docs/source/advanced/plugin_dev.rst +++ b/docs/source/advanced/plugin_dev.rst @@ -35,6 +35,7 @@ There are several plugin types: * ``CutPlugin``: Plugin type where using ``def cut_by(self, )`` inside the plugin a user can return a boolean array that can be used to select data. * ``MergeOnlyPlugin``: This is for internal use and only merges two plugins into a new one. See as an example in straxen the ``EventInfo`` plugin where the following datatypes are merged ``'events', 'event_basics', 'event_positions', 'corrected_areas', 'energy_estimates'``. * ``ParallelSourcePlugin``: For internal use only to parallelize the processing of low level plugins. This can be activated using stating ``parallel = 'process'`` in a plugin. + * ``TimeDelayPlugin``: For plugins that add variable time delays to output data, causing output timestamps to potentially exceed input chunk boundaries. Useful for simulation plugins (e.g., adding electron drift time). The user must define ``get_max_delay(self)`` returning the maximum possible delay in nanoseconds, and ``compute_with_delay(self, )`` returning the delayed output arrays. Minimal examples @@ -178,6 +179,36 @@ ________ st.get_array(run_id, 'merged_data') +strax.TimeDelayPlugin +_________________________ +.. code-block:: python + + class VariableDelayPlugin(strax.TimeDelayPlugin): + """ + Plugin that adds random delays, simulating e.g. drift time. + Output timestamps may exceed input chunk boundaries. 
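+        TimeDelayPlugin buffers, re-sorts and re-chunks the output,
+        so compute_with_delay may return unsorted arrays.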
+ """ + depends_on = 'records' + provides = 'delayed_records' + data_kind = 'delayed_records' + dtype = strax.record_dtype() + max_delay = 100 + + def get_max_delay(self): + # Return maximum possible delay in nanoseconds + return self.max_delay + + def compute_with_delay(self, records): + result = records.copy() + # Simulate variable drift time + delays = np.random.randint(0, self.max_delay, size=len(result)) + result['time'] = result['time'] + delays + return result + + st.register(VariableDelayPlugin) + st.get_array(run_id, 'delayed_records') + + Plugin inheritance ---------------------- It is possible to inherit the ``compute()`` method of an already existing plugin with another plugin. We call these types of plugins child plugins. Child plugins are recognized by strax when the ``child_plugin`` attribute of the plugin is set to ``True``. Below you can find a simple example of a child plugin with its parent plugin: From d874f57e3578f2a9df36970e059f9f9513e994ed Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 29 Jan 2026 06:18:01 +0000 Subject: [PATCH 3/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- strax/plugins/time_delay_plugin.py | 41 ++++++++---------------------- tests/test_time_delay_plugin.py | 6 ++--- 2 files changed, 14 insertions(+), 33 deletions(-) diff --git a/strax/plugins/time_delay_plugin.py b/strax/plugins/time_delay_plugin.py index 89135ec55..0594aa167 100644 --- a/strax/plugins/time_delay_plugin.py +++ b/strax/plugins/time_delay_plugin.py @@ -47,9 +47,8 @@ def get_max_delay(self): def compute_with_delay(self, **kwargs): """Compute output data with time delays already applied. - Input arrays are numpy arrays (not Chunks). Output arrays do NOT - need to be sorted. For multi-output, return a dict mapping - data_type to arrays. + Input arrays are numpy arrays (not Chunks). Output arrays do NOT need to be sorted. For + multi-output, return a dict mapping data_type to arrays. 
""" raise NotImplementedError("Subclasses must implement compute_with_delay()") @@ -84,18 +83,14 @@ def _flush_single_output(self): start=self.last_output_end, end=chunk_end, ) - result = self.superrun_transformation( - result, self._cached_superrun, self._cached_subruns - ) + result = self.superrun_transformation(result, self._cached_superrun, self._cached_subruns) self.output_buffer = {} return result def _flush_multi_output(self): """Flush buffers for multi-output plugin.""" - has_data = any( - len(self.output_buffer.get(dt, [])) > 0 for dt in self.provides - ) + has_data = any(len(self.output_buffer.get(dt, [])) > 0 for dt in self.provides) if not has_data: return None @@ -110,9 +105,7 @@ def _flush_multi_output(self): result = {} for data_type in self.provides: - buf = self.output_buffer.get( - data_type, np.empty(0, self.dtype_for(data_type)) - ) + buf = self.output_buffer.get(data_type, np.empty(0, self.dtype_for(data_type))) result[data_type] = self._make_chunk( data=buf, data_type=data_type, @@ -120,9 +113,7 @@ def _flush_multi_output(self): end=chunk_end, ) - result = self.superrun_transformation( - result, self._cached_superrun, self._cached_subruns - ) + result = self.superrun_transformation(result, self._cached_superrun, self._cached_subruns) self.output_buffer = {} return result @@ -177,9 +168,7 @@ def _add_to_buffers_single(self, new_output): if None not in self.output_buffer: self.output_buffer[None] = new_output elif len(new_output) > 0: - self.output_buffer[None] = np.concatenate( - [self.output_buffer[None], new_output] - ) + self.output_buffer[None] = np.concatenate([self.output_buffer[None], new_output]) def _add_to_buffers_multi(self, new_output): """Add output to buffers for multi-output plugin.""" @@ -196,9 +185,7 @@ def _add_to_buffers_multi(self, new_output): if data_type not in self.output_buffer: self.output_buffer[data_type] = arr elif len(arr) > 0: - self.output_buffer[data_type] = np.concatenate( - [self.output_buffer[data_type], arr] - ) + self.output_buffer[data_type] = np.concatenate([self.output_buffer[data_type], arr]) def _process_single_output(self, safe_boundary): """Process buffer for single-output plugin.""" @@ -225,9 +212,7 @@ def _process_single_output(self, safe_boundary): end=chunk_end, ) - return self.superrun_transformation( - result, self._cached_superrun, self._cached_subruns - ) + return self.superrun_transformation(result, self._cached_superrun, self._cached_subruns) def _process_multi_output(self, safe_boundary): """Process buffers for multi-output plugin.""" @@ -239,9 +224,7 @@ def _process_multi_output(self, safe_boundary): safe_data_dict = {} for data_type in self.provides: - buf = self.output_buffer.get( - data_type, np.empty(0, self.dtype_for(data_type)) - ) + buf = self.output_buffer.get(data_type, np.empty(0, self.dtype_for(data_type))) safe_data, remaining = self._split_buffer(buf, safe_boundary) self.output_buffer[data_type] = remaining safe_data_dict[data_type] = safe_data @@ -274,9 +257,7 @@ def _process_multi_output(self, safe_boundary): self.last_output_end = chunk_end self.first_output = False - return self.superrun_transformation( - result, self._cached_superrun, self._cached_subruns - ) + return self.superrun_transformation(result, self._cached_superrun, self._cached_subruns) def _split_buffer(self, buf, safe_boundary): """Split buffer into safe portion (endtime <= boundary) and remainder.""" diff --git a/tests/test_time_delay_plugin.py b/tests/test_time_delay_plugin.py index 7abb97a7a..ccd348ce3 100644 --- 
a/tests/test_time_delay_plugin.py +++ b/tests/test_time_delay_plugin.py @@ -235,9 +235,9 @@ class TestDelayPlugin(ConstantDelayPlugin): chunks = list(st.get_iter(run_id="test", targets="delayed_data")) for i in range(1, len(chunks)): - assert chunks[i].start == chunks[i - 1].end, ( - f"Chunk {i} start ({chunks[i].start}) != chunk {i-1} end ({chunks[i-1].end})" - ) + assert ( + chunks[i].start == chunks[i - 1].end + ), f"Chunk {i} start ({chunks[i].start}) != chunk {i-1} end ({chunks[i-1].end})" def test_straddling_data_across_boundary(): From b822b2676f96a7e389c261e9c7e6bd5b344ff56a Mon Sep 17 00:00:00 2001 From: Minghao Liu Date: Thu, 29 Jan 2026 01:26:17 -0500 Subject: [PATCH 4/6] Fix pre-commit failure --- tests/test_time_delay_plugin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_time_delay_plugin.py b/tests/test_time_delay_plugin.py index ccd348ce3..20745dd73 100644 --- a/tests/test_time_delay_plugin.py +++ b/tests/test_time_delay_plugin.py @@ -21,7 +21,7 @@ class ChunkedSource(strax.Plugin): provides = "source_data" dtype = simple_interval_dtype() rechunk_on_save = False - chunks_data = [] + chunks_data: list = [] def is_ready(self, chunk_i): return chunk_i < len(self.chunks_data) @@ -237,7 +237,7 @@ class TestDelayPlugin(ConstantDelayPlugin): for i in range(1, len(chunks)): assert ( chunks[i].start == chunks[i - 1].end - ), f"Chunk {i} start ({chunks[i].start}) != chunk {i-1} end ({chunks[i-1].end})" + ), f"Chunk {i} start ({chunks[i].start}) != chunk {i - 1} end ({chunks[i - 1].end})" def test_straddling_data_across_boundary(): From 65a405fb5fcc9cf18dde50d4bfde251298c77d30 Mon Sep 17 00:00:00 2001 From: Minghao Liu Date: Thu, 29 Jan 2026 10:51:57 -0500 Subject: [PATCH 5/6] Reduce code redundancy --- docs/source/advanced/plugin_dev.rst | 8 +- strax/plugins/time_delay_plugin.py | 169 ++++++++-------------------- tests/test_time_delay_plugin.py | 13 --- 3 files changed, 49 insertions(+), 141 deletions(-) diff --git a/docs/source/advanced/plugin_dev.rst b/docs/source/advanced/plugin_dev.rst index be7b541d3..665a6c0c4 100644 --- a/docs/source/advanced/plugin_dev.rst +++ b/docs/source/advanced/plugin_dev.rst @@ -35,7 +35,7 @@ There are several plugin types: * ``CutPlugin``: Plugin type where using ``def cut_by(self, )`` inside the plugin a user can return a boolean array that can be used to select data. * ``MergeOnlyPlugin``: This is for internal use and only merges two plugins into a new one. See as an example in straxen the ``EventInfo`` plugin where the following datatypes are merged ``'events', 'event_basics', 'event_positions', 'corrected_areas', 'energy_estimates'``. * ``ParallelSourcePlugin``: For internal use only to parallelize the processing of low level plugins. This can be activated using stating ``parallel = 'process'`` in a plugin. - * ``TimeDelayPlugin``: For plugins that add variable time delays to output data, causing output timestamps to potentially exceed input chunk boundaries. Useful for simulation plugins (e.g., adding electron drift time). The user must define ``get_max_delay(self)`` returning the maximum possible delay in nanoseconds, and ``compute_with_delay(self, )`` returning the delayed output arrays. + * ``TimeDelayPlugin``: For plugins that add variable time delays to output data, causing output timestamps to potentially exceed input chunk boundaries. Useful for simulation plugins (e.g., adding electron drift time). The user must define ``compute_with_delay(self, )`` returning the delayed output arrays. 
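+   The base class buffers delayed output across chunk boundaries and re-sorts it before emitting contiguous chunks; see the minimal example in the ``strax.TimeDelayPlugin`` section below.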
Minimal examples @@ -192,11 +192,7 @@ _________________________ provides = 'delayed_records' data_kind = 'delayed_records' dtype = strax.record_dtype() - max_delay = 100 - - def get_max_delay(self): - # Return maximum possible delay in nanoseconds - return self.max_delay + max_delay = 100 # for use in compute_with_delay def compute_with_delay(self, records): result = records.copy() diff --git a/strax/plugins/time_delay_plugin.py b/strax/plugins/time_delay_plugin.py index 0594aa167..a1060a22e 100644 --- a/strax/plugins/time_delay_plugin.py +++ b/strax/plugins/time_delay_plugin.py @@ -17,7 +17,6 @@ class TimeDelayPlugin(Plugin): multi-output plugins. Subclasses must implement: - get_max_delay(): Return maximum possible delay in nanoseconds compute_with_delay(**kwargs): Return delayed output data (arrays, not Chunks) For multi-output plugins, compute_with_delay should return a dict @@ -29,10 +28,6 @@ class TimeDelayPlugin(Plugin): def __init__(self): super().__init__() - self._init_buffers() - - def _init_buffers(self): - """Initialize/reset all buffer state.""" self.output_buffer = {} self.last_output_end = 0 self.first_output = True @@ -40,10 +35,6 @@ def _init_buffers(self): self._cached_subruns = None self._min_buffered_time = float("inf") - def get_max_delay(self): - """Return the maximum possible delay in nanoseconds.""" - raise NotImplementedError("Subclasses must implement get_max_delay()") - def compute_with_delay(self, **kwargs): """Compute output data with time delays already applied. @@ -62,47 +53,20 @@ def iter(self, iters, executor=None): def _flush_buffers(self): """Flush all remaining data from buffers.""" - if self.multi_output: - return self._flush_multi_output() - else: - return self._flush_single_output() - - def _flush_single_output(self): - """Flush buffer for single-output plugin.""" - buf = self.output_buffer.get(None) - if buf is None or len(buf) == 0: - return None - - buf.sort(order="time") - data_end = int(strax.endtime(buf).max()) - chunk_end = max(self.last_output_end, data_end) - - result = self._make_chunk( - data=buf, - data_type=self.provides[0], - start=self.last_output_end, - end=chunk_end, - ) - result = self.superrun_transformation(result, self._cached_superrun, self._cached_subruns) - - self.output_buffer = {} - return result - - def _flush_multi_output(self): - """Flush buffers for multi-output plugin.""" has_data = any(len(self.output_buffer.get(dt, [])) > 0 for dt in self.provides) if not has_data: return None + # Sort buffers and compute chunk_end chunk_end = self.last_output_end for data_type in self.provides: buf = self.output_buffer.get(data_type) if buf is not None and len(buf) > 0: buf.sort(order="time") - self.output_buffer[data_type] = buf data_end = int(strax.endtime(buf).max()) chunk_end = max(chunk_end, data_end) + # Build result dict result = {} for data_type in self.provides: buf = self.output_buffer.get(data_type, np.empty(0, self.dtype_for(data_type))) @@ -114,13 +78,16 @@ def _flush_multi_output(self): ) result = self.superrun_transformation(result, self._cached_superrun, self._cached_subruns) - self.output_buffer = {} - return result + + return self._unwrap_result(result) def do_compute(self, chunk_i=None, **kwargs): """Process input, buffer output, return safe portion.""" - input_start, input_end = self._get_input_timing(kwargs) + if not kwargs: + raise RuntimeError("TimeDelayPlugin must have dependencies") + first_chunk = next(iter(kwargs.values())) + input_end = first_chunk.end self._cached_superrun = 
self._check_subruns_uniqueness( kwargs, {k: v.superrun for k, v in kwargs.items()} @@ -134,94 +101,53 @@ def do_compute(self, chunk_i=None, **kwargs): self._add_to_buffers(new_output) - safe_boundary = input_end + return self._process_output(safe_boundary=input_end) + def _unwrap_result(self, result): + """Unwrap result dict to single Chunk for single-output plugins.""" if self.multi_output: - return self._process_multi_output(safe_boundary) - else: - return self._process_single_output(safe_boundary) - - def _get_input_timing(self, kwargs): - """Extract input chunk timing.""" - if not kwargs: - raise RuntimeError("TimeDelayPlugin must have dependencies") - first_chunk = next(iter(kwargs.values())) - return first_chunk.start, first_chunk.end + return result + return result[self.provides[0]] def _add_to_buffers(self, new_output): - """Add new output to appropriate buffers.""" + """Add new output to buffers.""" + # Normalize output to dict format if self.multi_output: - self._add_to_buffers_multi(new_output) + if not isinstance(new_output, dict): + raise ValueError( + f"{self.__class__.__name__} is multi-output, " + "compute_with_delay must return a dict" + ) + output_dict = new_output else: - self._add_to_buffers_single(new_output) - - def _add_to_buffers_single(self, new_output): - """Add output to buffer for single-output plugin.""" - if isinstance(new_output, dict): - raise ValueError( - f"{self.__class__.__name__} is single-output, " - "compute_with_delay should not return a dict" - ) - if not isinstance(new_output, np.ndarray): - new_output = strax.dict_to_rec(new_output, dtype=self.dtype) - - if None not in self.output_buffer: - self.output_buffer[None] = new_output - elif len(new_output) > 0: - self.output_buffer[None] = np.concatenate([self.output_buffer[None], new_output]) - - def _add_to_buffers_multi(self, new_output): - """Add output to buffers for multi-output plugin.""" - if not isinstance(new_output, dict): - raise ValueError( - f"{self.__class__.__name__} is multi-output, " - "compute_with_delay must return a dict" - ) + if isinstance(new_output, dict): + raise ValueError( + f"{self.__class__.__name__} is single-output, " + "compute_with_delay should not return a dict" + ) + output_dict = {self.provides[0]: new_output} + for data_type in self.provides: - arr = new_output.get(data_type, np.empty(0, self.dtype_for(data_type))) + arr = output_dict.get(data_type, np.empty(0, self.dtype_for(data_type))) if not isinstance(arr, np.ndarray): arr = strax.dict_to_rec(arr, dtype=self.dtype_for(data_type)) if data_type not in self.output_buffer: self.output_buffer[data_type] = arr elif len(arr) > 0: - self.output_buffer[data_type] = np.concatenate([self.output_buffer[data_type], arr]) - - def _process_single_output(self, safe_boundary): - """Process buffer for single-output plugin.""" - buf = self.output_buffer.get(None, np.empty(0, self.dtype)) - - if len(buf) > 0: - buf.sort(order="time") - self.output_buffer[None] = buf - - safe_data, remaining = self._split_buffer(buf, safe_boundary) - self.output_buffer[None] = remaining + self.output_buffer[data_type] = np.concatenate( + [self.output_buffer[data_type], arr] + ) - self._update_min_buffered_time() - - chunk_start, chunk_end = self._get_chunk_boundaries(safe_data, safe_boundary) - - self.last_output_end = chunk_end - self.first_output = False - - result = self._make_chunk( - data=safe_data, - data_type=self.provides[0], - start=chunk_start, - end=chunk_end, - ) - - return self.superrun_transformation(result, self._cached_superrun, 
self._cached_subruns) - - def _process_multi_output(self, safe_boundary): - """Process buffers for multi-output plugin.""" + def _process_output(self, safe_boundary): + """Process buffers and return safe portion.""" + # Sort all buffers for data_type in self.provides: buf = self.output_buffer.get(data_type) if buf is not None and len(buf) > 0: buf.sort(order="time") - self.output_buffer[data_type] = buf + # Split buffers into safe and remaining portions safe_data_dict = {} for data_type in self.provides: buf = self.output_buffer.get(data_type, np.empty(0, self.dtype_for(data_type))) @@ -229,15 +155,19 @@ def _process_multi_output(self, safe_boundary): self.output_buffer[data_type] = remaining safe_data_dict[data_type] = safe_data - self._update_min_buffered_time() + # Update minimum buffered time + min_time = float("inf") + for buf in self.output_buffer.values(): + if buf is not None and len(buf) > 0: + min_time = min(min_time, buf["time"].min()) + self._min_buffered_time = min_time + # Compute unified chunk boundaries across all data types chunk_start = None chunk_end = None - for data_type in self.provides: safe_data = safe_data_dict[data_type] dt_start, dt_end = self._get_chunk_boundaries(safe_data, safe_boundary) - if chunk_start is None: chunk_start = dt_start chunk_end = dt_end @@ -245,6 +175,7 @@ def _process_multi_output(self, safe_boundary): chunk_start = min(chunk_start, dt_start) chunk_end = max(chunk_end, dt_end) + # Build result dict result = {} for data_type in self.provides: result[data_type] = self._make_chunk( @@ -257,7 +188,9 @@ def _process_multi_output(self, safe_boundary): self.last_output_end = chunk_end self.first_output = False - return self.superrun_transformation(result, self._cached_superrun, self._cached_subruns) + result = self.superrun_transformation(result, self._cached_superrun, self._cached_subruns) + + return self._unwrap_result(result) def _split_buffer(self, buf, safe_boundary): """Split buffer into safe portion (endtime <= boundary) and remainder.""" @@ -273,14 +206,6 @@ def _split_buffer(self, buf, safe_boundary): return safe_data, remaining - def _update_min_buffered_time(self): - """Recalculate minimum time across all buffered data.""" - min_time = float("inf") - for key, buf in self.output_buffer.items(): - if buf is not None and len(buf) > 0: - min_time = min(min_time, buf["time"].min()) - self._min_buffered_time = min_time - def _get_chunk_boundaries(self, safe_data, safe_boundary): """Determine chunk start/end ensuring buffered data fits in next chunk.""" if self.first_output: diff --git a/tests/test_time_delay_plugin.py b/tests/test_time_delay_plugin.py index 20745dd73..ffefd6d8d 100644 --- a/tests/test_time_delay_plugin.py +++ b/tests/test_time_delay_plugin.py @@ -43,9 +43,6 @@ class ConstantDelayPlugin(strax.TimeDelayPlugin): data_kind = "delayed_data" delay = 0 - def get_max_delay(self): - return self.delay - def compute_with_delay(self, source_data): result = source_data.copy() result["time"] = result["time"] + self.delay @@ -59,12 +56,8 @@ class VariableDelayPlugin(strax.TimeDelayPlugin): provides = "variable_delayed_data" dtype = simple_interval_dtype() data_kind = "variable_delayed_data" - max_delay = 100 delay_pattern = [0] - def get_max_delay(self): - return self.max_delay - def compute_with_delay(self, source_data): result = source_data.copy() delays = np.array( @@ -85,7 +78,6 @@ class MultiOutputDelayPlugin(strax.TimeDelayPlugin): } delay_a = 0 delay_b = 0 - max_delay = 100 def infer_dtype(self): return { @@ -93,9 +85,6 @@ def 
infer_dtype(self): "delayed_output_b": simple_interval_dtype(), } - def get_max_delay(self): - return self.max_delay - def compute_with_delay(self, source_data): result_a = source_data.copy() result_a["time"] = result_a["time"] + self.delay_a @@ -163,7 +152,6 @@ def test_variable_delay_reorders_and_buffers(): st = create_context_with_source(chunks_data) class TestVariableDelay(VariableDelayPlugin): - max_delay = 100 delay_pattern = [0, 80, 20] st.register(TestVariableDelay) @@ -204,7 +192,6 @@ def test_multi_output_different_delays(): class TestMultiOutput(MultiOutputDelayPlugin): delay_a = 20 delay_b = 60 - max_delay = 60 st.register(TestMultiOutput) From a48929acf90106e5528183aa039c6bde773c50ef Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 29 Jan 2026 15:52:34 +0000 Subject: [PATCH 6/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- strax/plugins/time_delay_plugin.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/strax/plugins/time_delay_plugin.py b/strax/plugins/time_delay_plugin.py index a1060a22e..e980048d9 100644 --- a/strax/plugins/time_delay_plugin.py +++ b/strax/plugins/time_delay_plugin.py @@ -135,9 +135,7 @@ def _add_to_buffers(self, new_output): if data_type not in self.output_buffer: self.output_buffer[data_type] = arr elif len(arr) > 0: - self.output_buffer[data_type] = np.concatenate( - [self.output_buffer[data_type], arr] - ) + self.output_buffer[data_type] = np.concatenate([self.output_buffer[data_type], arr]) def _process_output(self, safe_boundary): """Process buffers and return safe portion."""