An upcoming MLIR change is going to remove `.isinstance`. This change is necessary to keep the code functioning. It also fixes a couple more cases that were missed earlier and caught by integration testing. PiperOrigin-RevId: 853383203
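The `.isinstance` migration replaces the MLIR Python bindings' type-specific `SomeType.isinstance(t)` classmethod with Python's builtin `isinstance(t, SomeType)`. A minimal sketch of the two idioms, using simplified stand-in classes rather than the real `mlir.ir` bindings:

```python
# Stand-in classes mimicking the MLIR Python binding type hierarchy;
# the real classes live in mlir.ir and are not used here.
class Type:
    pass

class IntegerType(Type):
    @classmethod
    def isinstance(cls, t):
        # Old-style per-type check, slated for removal upstream.
        return issubclass(type(t), cls)

t = IntegerType()

# Old idiom (being removed by the upcoming MLIR change):
old_result = IntegerType.isinstance(t)

# New idiom (the builtin works because the bindings expose a
# proper Python class hierarchy):
new_result = isinstance(t, IntegerType)

assert old_result == new_result == True
```

The fix in this PR is mechanical: each call site of the old form is rewritten to the builtin form.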
PiperOrigin-RevId: 853413470
PiperOrigin-RevId: 853419648
An upcoming MLIR change is going to remove `.isinstance`. This change is necessary to keep the code functioning. It also fixes a couple more cases that were missed earlier and caught by integration testing. PiperOrigin-RevId: 853420916
Co-authored-by: Yash Katariya <yashkatariya@google.com>
Co-authored-by: Parker Schuh <parkers@google.com> Co-authored-by: Yash Katariya <yashkatariya@google.com> PiperOrigin-RevId: 853442821
PiperOrigin-RevId: 853444130
PiperOrigin-RevId: 853448754
Co-authored-by: Matthew Johnson <mattjj@google.com> PiperOrigin-RevId: 853464510
…ommit/76db112da7c2c66afeb550fc1089e6bec297bd4d PiperOrigin-RevId: 853591014
These tests pass just fine now. PiperOrigin-RevId: 853694590
…ut inference. Use this `bitwidth` field to reject more relayouts that would fail in lowering. Prior to this change, we would sometimes erroneously choose such combinations of layouts. A concrete example is the following:

```
x: vector<?xi4>
x_cast = layout_cast(x, WGMMA_LAYOUT_UPCAST_4X)
y = convert(x): vector<?xbf16>
y_cast = layout_cast(y, WGMMA_LAYOUT)
```

Clearly, the `layout_cast`s here force us to pick a point at which to relayout the `vector` from the `WGMMA_LAYOUT_UPCAST_4X` layout to `WGMMA_LAYOUT`. Without filtering for bitwidth, there are two choices: we can relayout either before the `convert` or after it. However, we do not support this relayout for 16-bit values, so choosing to relayout after the `convert` will fail in lowering. PiperOrigin-RevId: 853729640
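The filtering described above can be modeled as rejecting candidate relayout points whose element bitwidth is unsupported for that layout pair. The sketch below is illustrative only: the supported-bitwidth table and function names are assumptions, not Mosaic GPU's actual API.

```python
# Hypothetical rule for this sketch: the UPCAST_4X -> WGMMA relayout
# is only supported for narrow element types.
SUPPORTED_RELAYOUT_BITWIDTHS = {
    ("WGMMA_LAYOUT_UPCAST_4X", "WGMMA_LAYOUT"): {4, 8},
}

def relayout_is_supported(src_layout, dst_layout, bitwidth):
    return bitwidth in SUPPORTED_RELAYOUT_BITWIDTHS.get(
        (src_layout, dst_layout), set())

# The two candidate relayout points from the example above: before the
# convert (operating on i4 values) or after it (operating on bf16 values).
candidates = [
    ("before_convert", 4),   # x is vector<?xi4>
    ("after_convert", 16),   # y is vector<?xbf16>
]
viable = [
    name for name, bw in candidates
    if relayout_is_supported("WGMMA_LAYOUT_UPCAST_4X", "WGMMA_LAYOUT", bw)
]
assert viable == ["before_convert"]  # the 16-bit relayout is filtered out
```

With the `bitwidth` field available during layout inference, the 16-bit option is never offered to the solver, so lowering cannot be handed an unsupported relayout.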
The tiling is global and applies to all refs passed to the pipeline function. This change is necessary to support using `pltpu.emit_pipeline` on SC where the tiling can either be (8, 128) or (8,). PiperOrigin-RevId: 853747075
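The generalization this implies can be sketched as follows: rather than one global tiling applied to every ref, each ref carries its own tiling, which may have fewer dimensions (e.g. `(8,)` on SC). The `RefSpec` class and `validate_tilings` helper below are hypothetical, not `pltpu` API.

```python
from dataclasses import dataclass

@dataclass
class RefSpec:
    shape: tuple   # logical shape of the ref
    tiling: tuple  # per-ref tiling, e.g. (8, 128), or (8,) on SC

def validate_tilings(specs):
    # A tiling applies to the trailing dims of the ref, so it must not
    # have more dimensions than the ref itself.
    for s in specs:
        if len(s.tiling) > len(s.shape):
            raise ValueError(
                f"tiling {s.tiling} has more dims than shape {s.shape}")
    return True

# Mixed tilings across refs in one pipeline, as the change enables:
specs = [RefSpec((128, 512), (8, 128)), RefSpec((1024,), (8,))]
assert validate_tilings(specs)
```

Under the previous global-tiling scheme, mixing `(8, 128)` and `(8,)` refs in one `pltpu.emit_pipeline` call was not expressible.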
PiperOrigin-RevId: 853768795
The thread guard prohibits execution of multi-process JAX operations on threads other than the owning one. This helps detect when McJAX operations are launched in different orders on different hosts, leading to intermittent crashes. PiperOrigin-RevId: 853785440
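A minimal sketch of such a thread guard (names are illustrative, not JAX's internal API): the owning thread is recorded once, and multi-process operations invoked from any other thread raise immediately instead of silently reordering collectives across hosts.

```python
import threading

class ThreadGuard:
    """Records the creating thread and rejects calls from others."""

    def __init__(self):
        self._owner = threading.get_ident()

    def check(self):
        if threading.get_ident() != self._owner:
            raise RuntimeError(
                "multi-process operation invoked from a non-owning thread; "
                "this can launch collectives in different orders on "
                "different hosts")

guard = ThreadGuard()
guard.check()  # fine on the owning thread

errors = []
def worker():
    try:
        guard.check()  # runs on a different thread: must raise
    except RuntimeError as e:
        errors.append(str(e))

t = threading.Thread(target=worker)
t.start()
t.join()
assert len(errors) == 1
```

Failing fast at the call site turns an intermittent cross-host crash into a deterministic, debuggable error.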
PiperOrigin-RevId: 853801923
PiperOrigin-RevId: 853806284
…`//tests/pallas:tpu_tests`. PiperOrigin-RevId: 853827799
Co-authored-by: Roy Frostig <frostig@google.com>
See https://github.com/jax-ml/jax/actions/runs/20830797751/job/59844186018. There were some new deps added in the recent MLIR update that were not propagated through. PiperOrigin-RevId: 853864868
Co-authored-by: Yash Katariya <yashkatariya@google.com>
…refs PiperOrigin-RevId: 853931279
…ckend. `del wrapper.object` ensures that the colocated Python code at the backend does not retain any remaining references to the object, irrespective of whether the backend code has any (accidental) leftover references to the wrapper itself. PiperOrigin-RevId: 853946673
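The point of deleting the attribute can be shown with a small stand-in (illustrative, not the colocated Python implementation): even if backend code accidentally keeps the wrapper alive, `del wrapper.object` drops the wrapper's reference to the wrapped object, so the object itself becomes collectable.

```python
import gc
import weakref

class Wrapper:
    def __init__(self, obj):
        self.object = obj  # the only strong reference we manage

class Payload:
    pass

payload = Payload()
ref = weakref.ref(payload)      # observe the payload's lifetime
wrapper = Wrapper(payload)
del payload                     # wrapper.object is now the only strong ref

leaked = wrapper                # simulate an accidental leftover reference
del wrapper.object              # drop the wrapped object regardless
gc.collect()

# The payload was collected even though the wrapper itself lives on.
assert ref() is None
assert leaked is not None
```

This is why the cleanup targets the attribute rather than the wrapper: the wrapper's lifetime is out of our control, but its reference to the object is not.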
PiperOrigin-RevId: 853966274
PiperOrigin-RevId: 853977644
…ommit/9ae3d6dab2c10c8195c8d9862f475904c7cdca91 PiperOrigin-RevId: 854059415
### Problem

Pallas tests fail on AMD ROCm GPUs with the error:

```
ValueError: invalid literal for int() with base 10: 'gfx950'
```

This occurs because ROCm devices report `compute_capability` as arch strings like `'gfx950'`, while the code expects `"major.minor"` (e.g., `"9.0"`).

### Solution

This PR fixes the issue with three changes:

1. `jax/_src/pallas/pallas_call.py`: Route ROCm devices to the Triton backend. Adds an `is_rocm` check in `gpu_lowering()`: with `backend=None` on ROCm, automatically use Triton instead of Mosaic GPU, and handle an explicit `backend='mosaic_gpu'` on ROCm.
2. `jax/experimental/mosaic/gpu/core.py`: Add a safety check in `_infer_arch()`.
3. `tests/pallas/ops_test.py`: Fix the `test_delay` skip logic. The `delay` primitive is only implemented in Mosaic GPU, not Triton, so skip when `jtu.is_device_rocm() or not use_mosaic_gpu`.

### Testing

- `test_delay` properly skips on ROCm since `delay` is MGPU-only.

### Submission Checklist
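The backend routing described in the PR can be sketched as below. The function name and the boolean `is_rocm` parameter are illustrative; the real check lives inside `gpu_lowering()` in `pallas_call.py`.

```python
def select_gpu_backend(backend, is_rocm):
    # On ROCm, default to Triton and reject an explicit Mosaic GPU
    # request: Mosaic GPU's arch inference expects CUDA-style
    # "major.minor" compute capabilities and fails on ROCm arch
    # strings like "gfx950".
    if is_rocm:
        if backend == "mosaic_gpu":
            raise ValueError("mosaic_gpu backend is not supported on ROCm")
        return "triton"
    return backend or "mosaic_gpu"

assert select_gpu_backend(None, is_rocm=True) == "triton"
assert select_gpu_backend(None, is_rocm=False) == "mosaic_gpu"
```

Defaulting rather than erroring in the `backend=None` case keeps existing test suites runnable on ROCm without per-test changes.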