To build and install jax, jaxlib, pjrt and plugin at once#244
To build and install jax, jaxlib, pjrt and plugin at once#244i-chaochen wants to merge 2 commits intomasterfrom
Conversation
…ith given xla and jax repo
| # 2. Run pip install . in jax repo | ||
| if kernels_jax_path: | ||
| print(f"Installing JAX from {kernels_jax_path}...") | ||
| subprocess.check_call(["pip", "install", "."], cwd=kernels_jax_path) |
There was a problem hiding this comment.
Please treat this as suggestion since my understanding of JAX and XLA is not at reviewing code level yet.
But In my understanding pip install . in jax repo installs jax and stable jaxlib , that defeats the purpose of building jaxlib in previous steps right?
In my opinion it should be pip install --no-deps .
There was a problem hiding this comment.
Correct, thank you for noticing.
@i-chaochen also for dev setup installing jax in editable mode is more appropriate. Can you please fix it too?
There was a problem hiding this comment.
@i-chaochen so at first you should install jax in editable mode. This will also install upstream jaxlib, but you can ignore that and just call make all_wheels to build and reinstall pjrt/plugin/jaxlib.
There was a problem hiding this comment.
wrt
In my opinion it should be pip install --no-deps
I'm not sure this would work if the stack.py build is called in a fresh environment where no jax was installed previously. jax requires some dependencies which will be absent in that case.
What could be beneficial here is --force-reinstall however, just to make sure. cc: @i-chaochen
There was a problem hiding this comment.
May I ask what's editable mode in jax install?
There was a problem hiding this comment.
The same as for any other python package: it allows pip install a package from a directory and then modify the files in directory and immediately see results when importing the installed package. That's what we have to do when working on python test failures for example - modify JAX python test and sometimes core JAX python files.
|
|
||
|
|
||
| dist: jax_rocm_plugin jax_rocm_pjrt | ||
| all_wheels: clean dist jaxlib_clean jaxlib jaxlib_install install |
There was a problem hiding this comment.
This only for my info: does the build order matter here?
example : you make edit to XLA and you build jaxlib first then plugins ? or it doesn't affect?
There was a problem hiding this comment.
@magaonka-amd the order specifies the build sequence. Though I also not sure why it's made the way it's made.
@i-chaochen , why don't you use high level rules refresh refresh_jaxlib here instead?
There was a problem hiding this comment.
where is refresh refresh_jaxlib ? it's in the bottom and never mentioned in https://github.com/ROCm/rocm-jax/blob/master/BUILDING.md so you mean this refresh refresh_jaxlib can replace this all_wheels ?
There was a problem hiding this comment.
You've modified the stack.py which contains the makefile template. How did you do this without reading it?
https://github.com/ROCm/rocm-jax/blob/master/DEVSETUP.md#the-makefile
There was a problem hiding this comment.
:) I used Cursor to make this PR
There was a problem hiding this comment.
Chao, to be frank, by doing that you have willingly risked breaking the code base and making everyone life's worse, and wasted a couple of hours of my time in reviewing what you've done and answering to your questions (that were already answered over 9000 times in the chat previously as well). But hey, you've spared some time for yourself, didn't you?
There was a problem hiding this comment.
So if I want to build all targets, I just need to use refresh refresh_jaxlib ?
TBH, it's obscure and when I first to create this PR I asked in the team and no one answered neither, seems it's only you know how to use this? I don't think it's a good thing and I would think the proper way is to indicate clearer with use cases in the docs.
There was a problem hiding this comment.
So if I want to build all targets, I just need to use refresh refresh_jaxlib ?
Yes, make refresh refresh_jaxlib will do the same as your proposed make all_wheels do.
TBH, it's obscure and when I first to create this PR I asked in the team and no one answered neither, seems it's only you know how to use this? I don't think it's a good thing and I would think the proper way is to indicate clearer with use cases in the docs.
This question was answered in details so many times (including direct replies to you personally - that's easily searchable in the chat, I shouldn't post links here) that I suspect that everyone is just tired to sing that nice song again one more time. I can relate to that. Wrt to the documentation, - rebukes aren't accepted. Charles was writing and rewriting it multiple times and if you were explicitly mandated to participate in that by asking your questions.
There was a problem hiding this comment.
Again, I'm not the only person confused the build stuff and I'm trying my best to fill this gap. I'm totally fine with make refresh refresh_jaxlib and rm all_wheels .
Also, I have zero attention to "rebuke" anyone and what I said that I would like to add this make refresh refresh_jaxlib use case into https://github.com/ROCm/rocm-jax/blob/master/DEVSETUP.md in PR to be more straightforward to those folks like me cannot read proper docs :)
Arech8
left a comment
There was a problem hiding this comment.
Please address the requests and add a section in the PR description describing your changes testing procedure.
| "jax-rocm-pjrt", | ||
| "jax-rocm-plugin", | ||
| "jax-plugin", |
There was a problem hiding this comment.
@i-chaochen please explain, how did you test the PR changes?
pjrt and plugin wheel are jax-rocm7-pjrt and jax-rocm7-plugin respectively for ROCm v.7+ and were jax-rocm60-pjrt and jax-rocm60-plugin for prev generation. I don't even know if there were such wheels as simply jax-rocm-pjrt /plugin or jax-plugin.
Also what is the point of having 2 separate uninstall calls instead of just a single one?
There was a problem hiding this comment.
I tested my local jax and xla installation and build on 0.7.1 and 0.8.0 w/o any issue. The reason for two separate calls is due to pylint check for the length, I can add as one if you think it's better.
There was a problem hiding this comment.
Of course there were no issues, because your install just didn't have any of "jax-rocm-pjrt", "jax-rocm-plugin", "jax-plugin" wheels! You could have used there any other non-existant package names as well, just what's the point of doing that?
There was a problem hiding this comment.
I'm sorry I'm not sure I got your point. Just to be clear, this part
Lines 350 to 376 in ccbdc3b
is a "sanity check" before the build and install, so uninstall the existing ones (if it has) and rebuild and install new ones, what kind of test do you expect?
| ) | ||
|
|
||
| this_repo_root = os.path.dirname(os.path.realpath(__file__)) | ||
| _, _, kernels_jax_path = _resolve_relative_paths(xla_dir, jax_dir) |
There was a problem hiding this comment.
why do you use kernels_jax_path variable name instead of jax_dir here?
There was a problem hiding this comment.
kernels_jax doesn't make sense at all, it should be jax_dir as it's for installation of jax repo
There was a problem hiding this comment.
Not sure I understood your comment.. If it doesn't make sense for you, why did you use that name?
| # 2. Run pip install . in jax repo | ||
| if kernels_jax_path: | ||
| print(f"Installing JAX from {kernels_jax_path}...") | ||
| subprocess.check_call(["pip", "install", "."], cwd=kernels_jax_path) |
There was a problem hiding this comment.
@i-chaochen so at first you should install jax in editable mode. This will also install upstream jaxlib, but you can ignore that and just call make all_wheels to build and reinstall pjrt/plugin/jaxlib.
| dev = subp.add_parser("develop") | ||
| dev.add_argument( | ||
| # Common arguments for develop and build | ||
| common = argparse.ArgumentParser(add_help=False) |
There was a problem hiding this comment.
why do you instantiate a different top-level argparse.ArgumentParser object here? I think, the usual way is to instantiate a new subparser as it was done here previously...
There was a problem hiding this comment.
since right now one is for develop and another one is for build, we could take the common part from both of them as one place.
|
What kind of tests do you want? I tested on my local to build 0.71 and 0.8.0. |
Motivation
0.8.0: #243
Now we can just use one line to build everything (jax, jaxlib, pjrt and plugin) as we're used to.