[fix] Apply with_sharding_constraint recursively to pytrees#4
Open
dvruette wants to merge 9 commits intoerfanzar:mainfrom
Open
[fix] Apply with_sharding_constraint recursively to pytrees#4dvruette wants to merge 9 commits intoerfanzar:mainfrom
with_sharding_constraint recursively to pytrees#4dvruette wants to merge 9 commits intoerfanzar:mainfrom
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
[generated by copilot]
This pull request refactors and enhances the handling of sharding constraints in the
eformer/escale/partition/constraints.pyfile. The most notable changes include renaming and improving an existing function for applying sharding constraints and introducing a new function to handle PyTrees of JAX arrays with enhanced validation and correction logic.Function renaming and improvement:
with_sharding_constrainttoarray_with_sharding_constraintto clarify its purpose as operating on a single JAX array. Updated the function's type annotations to usejax.Arrayfor improved clarity and consistency.New functionality for PyTrees:
with_sharding_constraintfunction to apply sharding constraints to PyTrees of JAX arrays. This function validates the compatibility of the input PyTree structure and sharding specification, ensures all elements in the sharding specification are valid types, and applies corrections to incompatible sharding axes.