Skip to content

Rocm jaxlib v0.9.0#671

Open
Ruturaj4 wants to merge 6583 commits intomainfrom
rocm-jaxlib-v0.9.0
Open

Rocm jaxlib v0.9.0#671
Ruturaj4 wants to merge 6583 commits intomainfrom
rocm-jaxlib-v0.9.0

Conversation

@Ruturaj4
Copy link

### Problem

Pallas tests fail on AMD ROCm GPUs with the error:

ValueError: invalid literal for int() with base 10: 'gfx950'

This occurs because:

  1. The Mosaic GPU backend (NVIDIA-specific) attempts to parse the GPU architecture string
  2. NVIDIA GPUs return compute_capability as "major.minor" (e.g., "9.0")
  3. AMD ROCm GPUs return architecture identifiers like "gfx950"
  4. The code tries to parse "gfx950" as an integer, causing the failure

Solution

This PR fixes the issue with three changes:

  1. jax/_src/pallas/pallas_call.py: Route ROCm devices to Triton backend

    • Added is_rocm check in gpu_lowering()
    • When backend=None and running on ROCm, automatically use Triton instead of Mosaic GPU
    • Added clear error if user explicitly requests backend='mosaic_gpu' on ROCm
  2. jax/experimental/mosaic/gpu/core.py: Add safety check in _infer_arch()

    • Detect ROCm architecture strings (starting with "gfx") and raise a descriptive error
  3. tests/pallas/ops_test.py: Fix test_delay skip logic

    • The delay primitive is only implemented in Mosaic GPU, not Triton
    • Updated skip condition to include ROCm devices: jtu.is_device_rocm() or not use_mosaic_gpu

Testing

  • tests now pass (previously failing with gfx950 error)
  • test_delay properly skips on ROCm since delay is MGPU-only

Submission Checklist

boomanaiden154 and others added 30 commits January 7, 2026 13:10
An upcoming MLIR change is going to remove .isinstance. This is necessary to keep the code functioning. Fixes a couple more cases that I did not find earlier and found with integration testing.

PiperOrigin-RevId: 853383203
An upcoming MLIR change is going to remove .isinstance. This is necessary to keep the code functioning. Fixes a couple more cases that I did not find earlier and found with integration testing.

PiperOrigin-RevId: 853420916
Co-authored-by: Yash Katariya <yashkatariya@google.com>
Co-authored-by: Parker Schuh <parkers@google.com>
Co-authored-by: Yash Katariya <yashkatariya@google.com>
PiperOrigin-RevId: 853442821
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 853464510
These tests pass just fine, now.

PiperOrigin-RevId: 853694590
…ut inference.

Use this `bitwidth` field to reject more relayouts that would fail in lowering.
Previously to this change, we would sometimes erroneously choose such
combinations of layouts. A concrete example is the following:

```
x: vector<?xi4>
x_cast = layout_cast(x, WGMMA_LAYOUT_UPCAST_4X)
y = convert(x): vector<?xbf16>
y_cast = layout_cast(x, WGMMA_LAYOUT)
```

Clearly, the `layout_cast`s here force us to pick a point at which to relayout
the `vector` from the `WGMMA_LAYOUT_UPCAST_4X` layout to the `WGMMA_LAYOUT`.

Without filtering for bitwidth, there are two choices. We can either relayout
before the `convert`, or after. However, we do not support this relayout for
`16`-bit values---and choosing to relayout after the `convert` will therefore
fail in lowering.

PiperOrigin-RevId: 853729640
The tiling is global and applies to all refs passed to the pipeline function.

This change is necessary to support using `pltpu.emit_pipeline` on SC where
the tiling can either be (8, 128) or (8,).

PiperOrigin-RevId: 853747075
PiperOrigin-RevId: 853768795
The thread guard prohibits execution of multi-process JAX operations on threads other than the owning one. This helps detect when McJAX operations are launched in different orders on different hosts, leading to intermittent crashes.

PiperOrigin-RevId: 853785440
PiperOrigin-RevId: 853801923
…`//tests/pallas:tpu_tests`.

PiperOrigin-RevId: 853827799
Co-authored-by: Roy Frostig <frostig@google.com>
See https://github.com/jax-ml/jax/actions/runs/20830797751/job/59844186018. There were some new deps added in the recent MLIR update that were not propagated through.

PiperOrigin-RevId: 853864868
Co-authored-by: Yash Katariya <yashkatariya@google.com>
…ckend.

`del wrapper.object` ensures that the colocated python code at the backend does not have any remaining references on the object, irrespective of whether the backend code has any (accidental) references left over for the wrapper itself.

PiperOrigin-RevId: 853946673
@Ruturaj4 Ruturaj4 requested a review from a team as a code owner January 28, 2026 23:56
Ruturaj4 and others added 24 commits January 28, 2026 18:10
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.