Skip to content

Comments

To build and install jax, jaxlib, pjrt and plugin at once#244

Open
i-chaochen wants to merge 2 commits intomasterfrom
chao/all_wheels_build_master
Open

To build and install jax, jaxlib, pjrt and plugin at once#244
i-chaochen wants to merge 2 commits intomasterfrom
chao/all_wheels_build_master

Conversation

@i-chaochen
Copy link

@i-chaochen i-chaochen commented Jan 9, 2026

Motivation

0.8.0: #243

Now we can just use one line to build everything (jax, jaxlib, pjrt and plugin) as we're used to.

python3 stack.py build --xla-dir=/my/own/xla/path --jax-dir=/my/own/jax/path

Copy link

@magaonka-amd magaonka-amd left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

left two minor comments.

# 2. Run pip install . in jax repo
if kernels_jax_path:
print(f"Installing JAX from {kernels_jax_path}...")
subprocess.check_call(["pip", "install", "."], cwd=kernels_jax_path)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please treat this as suggestion since my understanding of JAX and XLA is not at reviewing code level yet.
But In my understanding pip install . in jax repo installs jax and stable jaxlib , that defeats the purpose of building jaxlib in previous steps right?
In my opinion it should be pip install --no-deps .

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct, thank you for noticing.

@i-chaochen also for dev setup installing jax in editable mode is more appropriate. Can you please fix it too?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@i-chaochen so at first you should install jax in editable mode. This will also install upstream jaxlib, but you can ignore that and just call make all_wheels to build and reinstall pjrt/plugin/jaxlib.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wrt

In my opinion it should be pip install --no-deps

I'm not sure this would work if the stack.py build is called in a fresh environment where no jax was installed previously. jax requires some dependencies which will be absent in that case.

What could be beneficial here is --force-reinstall however, just to make sure. cc: @i-chaochen

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May I ask what's editable mode in jax install?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same as for any other python package: it allows pip install a package from a directory and then modify the files in directory and immediately see results when importing the installed package. That's what we have to do when working on python test failures for example - modify JAX python test and sometimes core JAX python files.



dist: jax_rocm_plugin jax_rocm_pjrt
all_wheels: clean dist jaxlib_clean jaxlib jaxlib_install install

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This only for my info: does the build order matter here?
example : you make edit to XLA and you build jaxlib first then plugins ? or it doesn't affect?

Copy link
Contributor

@Arech8 Arech8 Jan 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@magaonka-amd the order specifies the build sequence. Though I also not sure why it's made the way it's made.

@i-chaochen , why don't you use high level rules refresh refresh_jaxlib here instead?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where is refresh refresh_jaxlib ? it's in the bottom and never mentioned in https://github.com/ROCm/rocm-jax/blob/master/BUILDING.md so you mean this refresh refresh_jaxlib can replace this all_wheels ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You've modified the stack.py which contains the makefile template. How did you do this without reading it?

https://github.com/ROCm/rocm-jax/blob/master/DEVSETUP.md#the-makefile

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:) I used Cursor to make this PR

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Chao, to be frank, by doing that you have willingly risked breaking the code base and making everyone life's worse, and wasted a couple of hours of my time in reviewing what you've done and answering to your questions (that were already answered over 9000 times in the chat previously as well). But hey, you've spared some time for yourself, didn't you?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So if I want to build all targets, I just need to use refresh refresh_jaxlib ?

TBH, it's obscure and when I first to create this PR I asked in the team and no one answered neither, seems it's only you know how to use this? I don't think it's a good thing and I would think the proper way is to indicate clearer with use cases in the docs.

Copy link
Contributor

@Arech8 Arech8 Jan 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So if I want to build all targets, I just need to use refresh refresh_jaxlib ?

Yes, make refresh refresh_jaxlib will do the same as your proposed make all_wheels do.

TBH, it's obscure and when I first to create this PR I asked in the team and no one answered neither, seems it's only you know how to use this? I don't think it's a good thing and I would think the proper way is to indicate clearer with use cases in the docs.

This question was answered in details so many times (including direct replies to you personally - that's easily searchable in the chat, I shouldn't post links here) that I suspect that everyone is just tired to sing that nice song again one more time. I can relate to that. Wrt to the documentation, - rebukes aren't accepted. Charles was writing and rewriting it multiple times and if you were explicitly mandated to participate in that by asking your questions.

Copy link
Author

@i-chaochen i-chaochen Jan 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, I'm not the only person confused the build stuff and I'm trying my best to fill this gap. I'm totally fine with make refresh refresh_jaxlib and rm all_wheels .

Also, I have zero attention to "rebuke" anyone and what I said that I would like to add this make refresh refresh_jaxlib use case into https://github.com/ROCm/rocm-jax/blob/master/DEVSETUP.md in PR to be more straightforward to those folks like me cannot read proper docs :)

Copy link
Contributor

@Arech8 Arech8 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please address the requests and add a section in the PR description describing your changes testing procedure.

Comment on lines +370 to +372
"jax-rocm-pjrt",
"jax-rocm-plugin",
"jax-plugin",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@i-chaochen please explain, how did you test the PR changes?

pjrt and plugin wheel are jax-rocm7-pjrt and jax-rocm7-plugin respectively for ROCm v.7+ and were jax-rocm60-pjrt and jax-rocm60-plugin for prev generation. I don't even know if there were such wheels as simply jax-rocm-pjrt /plugin or jax-plugin.

Also what is the point of having 2 separate uninstall calls instead of just a single one?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tested my local jax and xla installation and build on 0.7.1 and 0.8.0 w/o any issue. The reason for two separate calls is due to pylint check for the length, I can add as one if you think it's better.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Of course there were no issues, because your install just didn't have any of "jax-rocm-pjrt", "jax-rocm-plugin", "jax-plugin" wheels! You could have used there any other non-existant package names as well, just what's the point of doing that?

Copy link
Author

@i-chaochen i-chaochen Jan 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm sorry I'm not sure I got your point. Just to be clear, this part

rocm-jax/stack.py

Lines 350 to 376 in ccbdc3b

def build_and_install(
xla_ref: str,
xla_dir: str,
test_jax_ref: str,
jax_dir: str,
rebuild_makefile: bool = False,
fix_bazel_symbols: bool = False,
rocm_path: str = "/opt/rocm",
):
"""Run develop setup, then build all wheels and install jax"""
# Uninstall existing packages first
print("Uninstalling existing JAX packages...")
subprocess.run(["python3", "-m", "pip", "uninstall", "jax", "-y"], check=False)
subprocess.run(
[
"python3",
"-m",
"pip",
"uninstall",
"jaxlib",
"jax-rocm-pjrt",
"jax-rocm-plugin",
"jax-plugin",
"-y",
],
check=False,
)

is a "sanity check" before the build and install, so uninstall the existing ones (if it has) and rebuild and install new ones, what kind of test do you expect?

)

this_repo_root = os.path.dirname(os.path.realpath(__file__))
_, _, kernels_jax_path = _resolve_relative_paths(xla_dir, jax_dir)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do you use kernels_jax_path variable name instead of jax_dir here?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

kernels_jax doesn't make sense at all, it should be jax_dir as it's for installation of jax repo

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure I understood your comment.. If it doesn't make sense for you, why did you use that name?

# 2. Run pip install . in jax repo
if kernels_jax_path:
print(f"Installing JAX from {kernels_jax_path}...")
subprocess.check_call(["pip", "install", "."], cwd=kernels_jax_path)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@i-chaochen so at first you should install jax in editable mode. This will also install upstream jaxlib, but you can ignore that and just call make all_wheels to build and reinstall pjrt/plugin/jaxlib.

dev = subp.add_parser("develop")
dev.add_argument(
# Common arguments for develop and build
common = argparse.ArgumentParser(add_help=False)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do you instantiate a different top-level argparse.ArgumentParser object here? I think, the usual way is to instantiate a new subparser as it was done here previously...

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since right now one is for develop and another one is for build, we could take the common part from both of them as one place.

@i-chaochen
Copy link
Author

i-chaochen commented Jan 12, 2026

What kind of tests do you want? I tested on my local to build 0.71 and 0.8.0.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants