diff --git a/pathwaysutils/elastic/elastic.py b/pathwaysutils/elastic/elastic.py new file mode 100644 index 0000000..6dd8af6 --- /dev/null +++ b/pathwaysutils/elastic/elastic.py @@ -0,0 +1,286 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Elasticity manager. + +This class provides a utility for elastic training. It provides a decorator that +retries a function in case of `jax.errors.JaxRuntimeError` caused by slice down +events. It also provides a utility for waiting for slices to become active. +""" + +import collections +from collections.abc import Mapping, Sequence +import logging +import time + +import jax +import numpy as np +from pathwaysutils.debug import timing + + +_logger = logging.getLogger(__name__) + +_SIMPLE_EXECUTION_TEST_VALUE = 100 +_ELASTIC_DOWN_ERROR_TYPES = frozenset( + "DATA_LOSS", +) +_ELASTIC_DOWN_ADDITIONAL_ERROR_TYPES = frozenset( + "DEADLINE_EXCEEDED", + "NOT_FOUND", + "INTERNAL", +) + + +def _plus_one(x: jax.Array) -> jax.Array: + """Adds one to each element in the array. + + Used to test if a slice is active. + + Args: + x: The array to add one to. + + Returns: + The array with one added to each element. + """ + return x + 1 + + +def _simple_execution(devices: Sequence[jax.Device]) -> jax.Array: + """Simple execution to test if a slice is active. + + This function is used to test if a slice is active. It executes a simple + computation on the devices and returns the result. If any of the devices are + not active, the returned array will fail with a JaxRuntimeError used. + + Simply executing this function is not enough to determine if the slice is + active. We also need to check the value of the returned array. + + Args: + devices: The devices to execute on. + + Returns: + The result of the execution. + """ + if not devices: + raise ValueError("No devices") + + test_input = np.zeros(len(devices), dtype=float) + ( + _SIMPLE_EXECUTION_TEST_VALUE - 1 + ) + + return jax.pmap(_plus_one, devices=devices)(test_input) + + +def get_slice_to_devices( + devices: Sequence[jax.Device], +) -> dict[int, Sequence[jax.Device]]: + """Returns the mapping from slice index to devices.""" + slice_to_devices = collections.defaultdict(list) + for d in devices: + slice_to_devices[d.slice_index].append(d) + return dict(slice_to_devices) + + +@timing.timeit +def get_active_slice_indices( + slice_to_devices: Mapping[int, Sequence[jax.Device]] | None = None, +) -> set[int]: + """Returns the set of active slices indices. + + Args: + slice_to_devices: A mapping from slice index to devices. If None, + `get_slice_to_devices(jax.devices())` is used to gather all available + devices and group them by slice. + + Returns: + A set of integers representing the indices of the active slices. + """ + if slice_to_devices is None: + _logger.debug("slice_to_devices is None. Getting from jax.devices().") + slice_to_devices = get_slice_to_devices(tuple(jax.devices())) + + _logger.debug( + "Getting active slice indices for slices: %s", + sorted(list(slice_to_devices.keys())), + ) + + active_slice_indices = set() + + results = { + slice_index: _simple_execution(devices) + for slice_index, devices in slice_to_devices.items() + } + + for slice_index, x in results.items(): + _logger.debug("Checking slice_index=%s", slice_index) + expected = ( + np.zeros(len(slice_to_devices[slice_index]), dtype=float) + + _SIMPLE_EXECUTION_TEST_VALUE + ) + try: + with timing.Timer(f"Checking {slice_index=}"): + _logger.debug("Blocking until ready for slice_index=%s", slice_index) + jax.block_until_ready(x) + _logger.debug("Execution finished for slice_index=%s", slice_index) + if np.allclose(x, expected): + active_slice_indices.add(slice_index) + _logger.debug("slice_index=%s active", slice_index) + else: + _logger.error( + "Error with _simple_execution for slice_index=%s. " + "This should never happen. Expected: %r, Actual: %r", + slice_index, + expected, + x, + ) + raise ValueError( + f"Error with _simple_execution for slice_index={slice_index}." + ) + except jax.errors.JaxRuntimeError as error: + _logger.debug( + "Caught JaxRuntimeError for slice_index=%s: %s", slice_index, error + ) + if not is_error_due_to_slice_down(error): + _logger.info("Re-raising error for slice_index=%s", slice_index) + raise + _logger.debug("slice_index=%s bad", slice_index) + + _logger.debug("active_slice_indices=%s", active_slice_indices) + + return active_slice_indices + + +def wait_for_slices( + slice_count: int, + poll_interval: float | int = 10, + timeout: float | int | None = None, + slice_to_devices: Mapping[int, Sequence[jax.Device]] | None = None, +) -> set[int]: + """Waits until after at least `slice_count` slices become active. + + Args: + slice_count: The number of slices to wait for. + poll_interval: The minimum number of seconds to wait between availability + checks. If the check takes longer than this, the next check will start + immediately after the current check completes. Defaults to 10 seconds. + timeout: The maximum number of seconds to wait. If None, there is no + timeout. + slice_to_devices: A mapping from slice index to devices. If None, + `get_slice_to_devices(jax.devices())` is used. + + Returns: + The active slice indices + + Raises: + TimeoutError: If the timeout is reached before the slices become + active. + """ + if slice_to_devices is None: + _logger.debug("slice_to_devices is None. Getting from jax.devices().") + slice_to_devices = get_slice_to_devices(jax.devices()) + + _logger.info( + "Waiting for %s slices. Poll interval: %s, Timeout: %s", + slice_count, + poll_interval, + timeout, + ) + start_time = time.time() + + while True: + check_start_time = time.time() + + _logger.debug("Checking active slices...") + active_slice_indices = get_active_slice_indices(slice_to_devices) + if len(active_slice_indices) >= slice_count: + _logger.info( + "Sufficient slices active: %s >= %s. Active indices: %s", + len(active_slice_indices), + slice_count, + active_slice_indices, + ) + return active_slice_indices + + _logger.info( + "%s slices active. Wanting at least %s. Active indices: %s", + len(active_slice_indices), + slice_count, + active_slice_indices, + ) + + time_to_sleep = max(0, poll_interval - (time.time() - check_start_time)) + + if timeout is not None: + elapsed_time = time.time() - start_time + if elapsed_time + time_to_sleep >= timeout: + raise TimeoutError( + f"Timed out waiting for {slice_count} slices. Only" + f" {len(active_slice_indices)} active after" + f" {elapsed_time:.2f} seconds." + f" Next check would occur after the timeout of {timeout}" + " seconds." + ) + + if time_to_sleep > 0: + _logger.debug("Sleeping for %.2f seconds.", time_to_sleep) + + time.sleep(time_to_sleep) + + +def is_error_due_to_slice_down(error: Exception) -> bool: + """Returns True if the error is due to slice down. + + The error types that are considered due to slice down are + jax.errors.JaxRuntimeError with the following error kind in the message: + - DATA_LOSS + - DEADLINE_EXCEEDED + - NOT_FOUND + - INTERNAL + + Args: + error: The error to check. + """ + error_due_to_slice_down = False + traceback_logging_level = logging.DEBUG + + if isinstance(error, jax.errors.JaxRuntimeError): + _logger.debug("Checking if JaxRuntimeError is due to slice down: %s", error) + if any( + error_type in str(error) for error_type in _ELASTIC_DOWN_ERROR_TYPES + ): + _logger.debug( + "Caught an error due to slice down (matched" + " _ELASTIC_DOWN_ERROR_TYPES)" + ) + + error_due_to_slice_down = True + + elif any( + error_type in str(error) + for error_type in _ELASTIC_DOWN_ADDITIONAL_ERROR_TYPES + ): + _logger.warning( + "Caught an error that may or may not be due to slice down (matched" + " _ELASTIC_DOWN_ADDITIONAL_ERROR_TYPES). This error will be treated" + " as due to slice down." + ) + traceback_logging_level = logging.WARNING + + error_due_to_slice_down = True + + if not error_due_to_slice_down: + _logger.debug("Caught an error not due to slice down") + + _logger.log(traceback_logging_level, "Error details:", exc_info=True) + + return error_due_to_slice_down diff --git a/pathwaysutils/elastic/manager.py b/pathwaysutils/elastic/manager.py index 8bd712e..4c1bc79 100644 --- a/pathwaysutils/elastic/manager.py +++ b/pathwaysutils/elastic/manager.py @@ -18,58 +18,54 @@ events. It also provides a utility for waiting for slices to become active. """ -import collections -from collections.abc import Mapping, Sequence +import _thread +from collections.abc import Callable, Mapping, Sequence import functools -import itertools import logging -import time -import traceback -from typing import Any +import threading +from typing import Any, TypeVar import jax -import numpy as np -from pathwaysutils.debug import timing +from pathwaysutils.elastic import elastic _logger = logging.getLogger(__name__) -def _plus_one(x: jax.Array) -> jax.Array: - """Adds one to each element in the array. +class ElasticRuntimeError(RuntimeError): + """Error raised when elasticity cannot continue.""" - Used to test if a slice is active. - Args: - x: The array to add one to. +class NewSliceAvailableError(RuntimeError): + """Error raised when a new slice is available.""" - Returns: - The array with one added to each element. - """ - return x + 1 +_F = TypeVar("_F", bound=Callable[..., Any]) -class ElasticRuntimeError(RuntimeError): - """Error raised when elasticity cannot continue.""" + +def _elastic_event_cleanup(): + """Cleans up JAX profiles, caches, and live arrays.""" + try: + _logger.info("Cleaning up any ongoing traces") + jax.profiler.stop_trace() + except (RuntimeError, ValueError) as e: + _logger.info("No ongoing traces to clean up") + except Exception: + _logger.exception("Error cleaning up ongoing traces") + raise + + jax.clear_caches() + for array in jax.live_arrays(): + array.delete() class Manager: """Utility class for elastic training.""" - _devices: Sequence[jax.Device] _total_slice_count: int | None = None slice_to_devices: Mapping[int, Sequence[jax.Device]] active_slice_indices: set[int] - - _SIMPLE_EXECUTION_TEST_VALUE = 100 - _ELASTIC_DOWN_ERROR_TYPES = [ - "DATA_LOSS", - ] - _ELASTIC_DOWN_ADDITIONAL_ERROR_TYPES = [ - "DEADLINE_EXCEEDED", - "NOT_FOUND", - "INTERNAL", - ] + new_slice_event: threading.Event def __init__(self, devices: Sequence[jax.Device] | None = None) -> None: """Initializes the manager. @@ -79,24 +75,14 @@ def __init__(self, devices: Sequence[jax.Device] | None = None) -> None: """ if devices is None: devices = jax.devices() - self.devices = devices + self.slice_to_devices = elastic.get_slice_to_devices(devices) - self.active_slice_indices = self.get_active_slice_indices() + self.all_slice_indices = set(self.slice_to_devices.keys()) - @property - def devices(self) -> Sequence[jax.Device]: - """Returns the devices.""" - return self._devices - - @devices.setter - def devices(self, devices: Sequence[jax.Device]) -> None: - """Sets the devices.""" - self._devices = devices - - self.slice_to_devices = collections.defaultdict(list) - for d in self._devices: - self.slice_to_devices[d.slice_index].append(d) - self.slice_to_devices = dict(self.slice_to_devices) + self.active_slice_indices = elastic.get_active_slice_indices( + slice_to_devices=self.slice_to_devices + ) + self.new_slice_event = threading.Event() @property def total_slice_count(self) -> int: @@ -105,143 +91,6 @@ def total_slice_count(self) -> int: self._total_slice_count = len(self.slice_to_devices) return self._total_slice_count - def slice_device_count(self, slice_index: int) -> int: - """Returns the number of devices in a slice.""" - try: - return len(self.slice_to_devices[slice_index]) - except KeyError as error: - raise ValueError( - f"Slice {slice_index=} not found in {self.slice_to_devices=}" - ) from error - - def is_error_due_to_slice_down(self, error: Exception) -> bool: - """Returns True if the error is due to slice down. - - The error types that are considered due to slice down are - jax.errors.JaxRuntimeError with the following error kind in the message: - - DATA_LOSS - - DEADLINE_EXCEEDED - - NOT_FOUND - - INTERNAL - - Args: - error: The error to check. - """ - error_due_to_slice_down = False - traceback_logging_level = logging.DEBUG - - if isinstance(error, jax.errors.JaxRuntimeError): - if any( - error_type in str(error) - for error_type in self._ELASTIC_DOWN_ERROR_TYPES - ): - _logger.info("Caught an error due to slice down") - - error_due_to_slice_down = True - - elif any( - error_type in str(error) - for error_type in self._ELASTIC_DOWN_ADDITIONAL_ERROR_TYPES - ): - _logger.warning( - "Caught an error due that may or may not be due to slice down. This" - " error will be treated as due to slice down." - ) - traceback_logging_level = logging.WARNING - - error_due_to_slice_down = True - - if not error_due_to_slice_down: - _logger.info("Caught an error not due to slice down") - - _logger.log( - traceback_logging_level, "\n".join(traceback.format_exception(error)) - ) - - return error_due_to_slice_down - - def _simple_execution(self, devices: Sequence[jax.Device]) -> jax.Array: - """Simple execution to test if a slice is active. - - This function is used to test if a slice is active. It executes a simple - computation on the devices and returns the result. If any of the devices are - not active, the returned array will fail with a JaxRuntimeError used. - - Simply executing this function is not enough to determine if the slice is - active. We also need to check the value of the returned array. - - Args: - devices: The devices to execute on. - - Returns: - The result of the execution. - """ - if not devices: - raise ValueError("No devices") - - test_input = np.zeros(len(devices), dtype=float) + ( - self._SIMPLE_EXECUTION_TEST_VALUE - 1 - ) - - return jax.pmap(_plus_one, devices=devices)(test_input) - - @timing.timeit - def get_active_slice_indices(self) -> set[int]: - """Returns the set of active slices indices.""" - active_slice_indices = set() - - results = { - slice_index: self._simple_execution(devices) - for slice_index, devices in self.slice_to_devices.items() - } - - for slice_index, x in results.items(): - _logger.info("Checking slice_index=%s", slice_index) - expected = ( - np.zeros(self.slice_device_count(slice_index), dtype=float) - + self._SIMPLE_EXECUTION_TEST_VALUE - ) - try: - with timing.Timer(f"Checking {slice_index=}"): - jax.block_until_ready(x) - if np.allclose(x, expected): - active_slice_indices.add(slice_index) - _logger.info("slice_index=%s good", slice_index) - else: - _logger.error( - "Error with _simple_execution for slice_index=%s. " - "This should never happen. Expected: %s, Actual: %s", - slice_index, - expected, - x, - ) - raise ValueError( - f"Error with _simple_execution for slice_index={slice_index}." - ) - except jax.errors.JaxRuntimeError as error: - if not self.is_error_due_to_slice_down(error): - raise - _logger.info("slice_index=%s bad", slice_index) - - _logger.info("active_slice_indices=%s", active_slice_indices) - - return active_slice_indices - - @property - def active_slice_to_devices(self) -> dict[int, Sequence[jax.Device]]: - """The mapping from a active slice to its devices.""" - return { - slice_index: self.slice_to_devices[slice_index] - for slice_index in self.active_slice_indices - } - - @property - def active_devices(self) -> list[jax.Device]: - """Returns the active slice indices.""" - return list( - itertools.chain.from_iterable(self.active_slice_to_devices.values()) - ) - @property def default_device(self) -> jax.Device: """Returns the device that should be set to the default device. @@ -258,15 +107,20 @@ def active_slice_count(self) -> int: """Returns the number of slices.""" return len(self.active_slice_indices) + @property + def inactive_slice_indices(self) -> set[int]: + """Returns the set of inactive slice indices.""" + return self.all_slice_indices - self.active_slice_indices + def scale_by_active_slices(self, x: int | float) -> int | float: - """Scale x by the number of good slices.""" + """Scale x by the number of active slices.""" if isinstance(x, int): quotient, remainder = divmod( x * self.active_slice_count, self.total_slice_count ) if remainder: raise ValueError( - f"Cannot scale {x=} by good slices because it will result in a " + f"Cannot scale {x=} by active slices because it will result in a " f"remainder of {remainder=}." ) return quotient @@ -275,81 +129,78 @@ def scale_by_active_slices(self, x: int | float) -> int | float: else: raise ValueError(f"Unsupported type: {type(x)=}") - def wait_for_slices( + def _cleanup_on_retry(self): + """Cleans up JAX caches and traces on retry.""" + try: + _logger.debug("Cleaning up any ongoing traces") + jax.profiler.stop_trace() + except (RuntimeError, ValueError): + _logger.debug("No ongoing traces to clean up") + except Exception: # pylint: disable=broad-exception-caught + _logger.exception("Error cleaning up ongoing traces") + + jax.clear_caches() + for array in jax.live_arrays(): + array.delete() + + def _elasticity_retry_decorator( self, - slice_count: int | None = None, - poll_interval: float | int = 10, - timeout: float | int | None = None, - ) -> set[int]: - """Waits until after at least `slice_count` slices become active. - - Args: - slice_count: The number of slices to wait for. If None, waits for all - slices to become active. - poll_interval: The minimum number of seconds to wait between availability - checks. If the check takes longer than this, the next check will start - immediately after the current check completes. Defaults to 10 seconds. - timeout: The maximum number of seconds to wait. If None, there is no - timeout. + max_retries: int, + pre_callback: Callable[..., Any] | None = None, + on_elastic_event_callback: Callable[..., Any] | None = None, + ) -> Callable[[_F], _F]: + """Retries a function with elasticity fault tolerance.""" - Returns: - The active slice indices + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + for retry_index in range(max_retries): + try: + _logger.info( + "Elastic attempt %d out of %d", retry_index + 1, max_retries + ) + if pre_callback is not None: + pre_callback() - Raises: - TimeoutError: If the timeout is reached before the slices become - active. - """ - if slice_count is None: - slice_count = self.total_slice_count + with jax.default_device(self.default_device): + return func(*args, **kwargs) + except NewSliceAvailableError: + _logger.info("New slice available. Retrying.") + _elastic_event_cleanup() - start_time = time.time() + if on_elastic_event_callback is not None: + on_elastic_event_callback() + except jax.errors.JaxRuntimeError as error: + if not elastic.is_error_due_to_slice_down(error): + raise - while True: - check_start_time = time.time() + if self.new_slice_event.is_set(): + _logger.info( + "Slice down event and new slice available detected. Retrying." + ) + else: + _logger.info("Slice down event detected. Retrying.") - active_slice_indices = self.get_active_slice_indices() - if len(active_slice_indices) >= slice_count: - _logger.info( - "%s/%s slices are active", - len(active_slice_indices), - self.total_slice_count, - ) - return active_slice_indices - - _logger.info( - "%s/%s slices active. Wanting at least %s/%s.", - len(active_slice_indices), - self.total_slice_count, - slice_count, - self.total_slice_count, - ) + _elastic_event_cleanup() - time_to_sleep = max(0, poll_interval - (time.time() - check_start_time)) - - if ( - timeout is not None - and (elapsed_time := time.time() - start_time) + time_to_sleep - >= timeout - ): - raise TimeoutError( - f"Timed out waiting for {slice_count} slices. Only" - f" {len(active_slice_indices)} active after" - f" {elapsed_time:.2f} seconds." - f" Next check would occur after the timeout of {timeout}" - " seconds." + if on_elastic_event_callback is not None: + on_elastic_event_callback() + raise ElasticRuntimeError( + f"Elastic attempt {max_retries} out of {max_retries} failed." ) - if time_to_sleep > 0: - _logger.info("Sleeping for %.2f seconds.", time_to_sleep) + return wrapper - time.sleep(time_to_sleep) + return decorator def pause_resume( self, max_retries: int, poll_interval: float | int = 10, timeout: float | None = None, - ) -> Any: + pre_callback: Callable[..., Any] | None = None, + on_elastic_event_callback: Callable[..., Any] | None = None, + ) -> Callable[[_F], _F]: """Retries a function with pause/resume fault tolerance. This decorator wraps a function to automatically retry execution in case of @@ -370,6 +221,9 @@ def pause_resume( Defaults to 10 seconds. timeout: The maximum number of seconds to wait for slices to become active before each retry attempt. If None, there is no timeout. + pre_callback: A callback to call before the function is attempted. + on_elastic_event_callback: A callback to call after an elastic failure + occurs. Returns: The result of the wrapped function. @@ -379,39 +233,115 @@ def pause_resume( Exception: Any other exception raised by the wrapped function that is not due to a slice down event. """ - def decorator(func): - @functools.wraps(func) - def wrapper(*args, **kwargs): - for retry_index in range(max_retries): - try: - _logger.info( - "Elastic attempt %d out of %d", retry_index + 1, max_retries - ) + def internal_pre_callback(): + self.active_slice_indices = elastic.wait_for_slices( + slice_count=self.total_slice_count, + slice_to_devices=self.slice_to_devices, + poll_interval=poll_interval, + timeout=timeout, + ) + if pre_callback is not None: + pre_callback() - self.active_slice_indices = self.wait_for_slices( - poll_interval=poll_interval, timeout=timeout - ) + return self._elasticity_retry_decorator( + max_retries=max_retries, + pre_callback=internal_pre_callback, + on_elastic_event_callback=on_elastic_event_callback, + ) - return func(*args, **kwargs) - except jax.errors.JaxRuntimeError as error: - if not self.is_error_due_to_slice_down(error): - raise + def _monitor_new_slices( + self, stop_event: threading.Event, poll_interval: float | int + ): + """Monitors for new slices and sets the `new_slice_event` if found.""" + while not stop_event.wait(poll_interval): + try: + if not self.inactive_slice_indices: + _logger.debug("No inactive slices to check.") + continue - try: - _logger.info("Cleaning up any ongoing traces") - jax.profiler.stop_trace() - except (RuntimeError, ValueError) as e: - _logger.info("No ongoing traces to clean up") - except Exception: - _logger.exception("Error cleaning up ongoing traces") - raise + _logger.debug( + "Checking inactive slices: %s", self.inactive_slice_indices + ) + inactive_slice_to_devices = { + i: self.slice_to_devices[i] for i in self.inactive_slice_indices + } + newly_active_indices = elastic.get_active_slice_indices( + inactive_slice_to_devices + ) - jax.clear_caches() - for array in jax.live_arrays(): - array.delete() - raise ElasticRuntimeError( - f"Elastic attempt {max_retries} out of {max_retries} failed." + if newly_active_indices: + _logger.info( + "New slices found: %s. Setting new slice event.", + newly_active_indices, + ) + self.new_slice_event.set() + return + + _logger.debug("No new slices found.") + except Exception: # pylint: disable=broad-exception-caught + _logger.exception("Error in monitor thread") + + def replica_resize( + self, + max_resizes: int, + poll_interval: float = 10, + pre_callback: Callable[..., Any] | None = None, + on_elastic_event_callback: Callable[..., Any] | None = None, + ) -> Callable[[_F], _F]: + """Retries a function with replica/resize fault tolerance. + + Args: + max_resizes: The maximum number of times to retry the function after + resizing the replica count. + poll_interval: The number of seconds to wait between active slice checks. + Defaults to 10 seconds. + pre_callback: A callback to call before the function is attempted. + on_elastic_event_callback: A callback to call after an elastic failure + occurs. + + Returns: + The result of the wrapped function. + + Raises: + ElasticRuntimeError: If all retry attempts fail. + Exception: Any other exception raised by the wrapped function that is not + due to a slice down event. + """ + + def internal_pre_callback(): + self.active_slice_indices = elastic.wait_for_slices( + slice_count=1, + slice_to_devices=self.slice_to_devices, + poll_interval=poll_interval, + ) + + if pre_callback is not None: + pre_callback() + + retry_decorator = self._elasticity_retry_decorator( + max_retries=max_resizes, + pre_callback=internal_pre_callback, + on_elastic_event_callback=on_elastic_event_callback, + ) + + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + self.new_slice_event.clear() + stop_event = threading.Event() + + monitor_thread = threading.Thread( + target=self._monitor_new_slices, + args=(stop_event, poll_interval), + daemon=True, ) + monitor_thread.start() + try: + return func(*args, **kwargs) + finally: + stop_event.set() + monitor_thread.join() + + return retry_decorator(wrapper) - return wrapper return decorator diff --git a/pathwaysutils/elastic/simulated_manager.py b/pathwaysutils/elastic/simulated_manager.py deleted file mode 100644 index ced6c79..0000000 --- a/pathwaysutils/elastic/simulated_manager.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""A simulated manager for elastic training. - -This module provides a simulated manager for elastic training. It can be used -to test elastic training without needing to actually trigger elastic events. -Instead, the user can control which slices are active at what times by -calling `update_active_slice_indices`. -""" - -import logging -from typing import Sequence - -import jax -from pathwaysutils.debug import timing -from pathwaysutils.elastic import manager - - -_logger = logging.getLogger(__name__) - - -class SimulatedManager(manager.Manager): - """An elastic manager with settable slice activity. - - This class can be used to modify which slices are marked as active by - overloading the `get_active_slice_indices` function. - """ - - _simulated_active_slice_indices: set[int] - - def __init__(self, devices: Sequence[jax.Device]) -> None: - """Initializes the simulated manager. - - Args: - devices: The devices to use. If None, jax.devices() is used. - """ - self._simulated_active_slice_indices = set(d.slice_index for d in devices) - - super().__init__(devices) - - def update_active_slice_indices(self, active_slice_indices: set[int]) -> None: - """Sets the active slice indices. - - Subsequent calls to `get_active_slice_indices` will return these indices. - - Args: - active_slice_indices: The simulated active slice indices. - """ - self._simulated_active_slice_indices = active_slice_indices - _logger.debug( - "Updated: simulated_active_slice_indices=%s", - self._simulated_active_slice_indices, - ) - - @timing.timeit - def get_active_slice_indices(self) -> set[int]: - """Returns the set of active slice indices. - - Returns: - The set of active slice indices from the last call to - update_active_slice_indices. Returns an empty set if - update_active_slice_indices has not been called. - """ - active_slice_indices = self._simulated_active_slice_indices - - _logger.debug("active_slice_indices=%s", active_slice_indices) - - return active_slice_indices