-
Notifications
You must be signed in to change notification settings - Fork 5
To build and install jax, jaxlib, pjrt and plugin at once #244
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -14,7 +14,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||
| XLA_REPL_URL = "https://github.com/rocm/xla" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| DEFAULT_XLA_DIR = "../xla" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| DEFAULT_KERNELS_JAX_DIR = "../jax" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| DEFAULT_JAX_DIR = "../jax" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| MAKE_TEMPLATE = r""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # gfx targets for which XLA and jax custom call kernels are built for | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -50,12 +50,13 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ### | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| .PHONY: test clean install dist | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| .PHONY: test clean install dist all_wheels | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| .default: dist | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dist: jax_rocm_plugin jax_rocm_pjrt | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| all_wheels: clean dist jaxlib_clean jaxlib jaxlib_install install | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| jax_rocm_plugin: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -163,7 +164,7 @@ def find_clang(): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _resolve_relative_paths(xla_dir: str, kernels_jax_dir: str) -> tuple[str, str, str]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _resolve_relative_paths(xla_dir: str, jax_dir: str) -> tuple[str, str, str]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Transforms relative to absolute paths. This is needed to properly support | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| symbolic information remapping""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| this_repo_root = os.path.dirname(os.path.realpath(__file__)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -177,16 +178,16 @@ def _resolve_relative_paths(xla_dir: str, kernels_jax_dir: str) -> tuple[str, st | |||||||||||||||||||||||||||||||||||||||||||||||||||||||
| xla_path | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ), f"XLA path (specified as '{xla_dir}') doesn't resolve to existing directory at '{xla_path}'" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if kernels_jax_dir: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if jax_dir: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| kernels_jax_path = ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| kernels_jax_dir | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if os.path.isabs(kernels_jax_dir) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else os.path.abspath(f"{this_repo_root}/jax_rocm_plugin/{kernels_jax_dir}") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| jax_dir | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if os.path.isabs(jax_dir) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else os.path.abspath(f"{this_repo_root}/jax_rocm_plugin/{jax_dir}") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # pylint: disable=line-too-long | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert os.path.isdir( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| kernels_jax_path | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ), f"XLA path (specified as '{kernels_jax_dir}') doesn't resolve to existing directory at '{kernels_jax_path}'" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ), f"XLA path (specified as '{jax_dir}') doesn't resolve to existing directory at '{kernels_jax_path}'" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| kernels_jax_path = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return this_repo_root, xla_path, kernels_jax_path | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -262,7 +263,7 @@ def setup_development( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||
| xla_ref: str, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| xla_dir: str, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| test_jax_ref: str, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| kernels_jax_dir: str, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| jax_dir: str, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rebuild_makefile: bool = False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| fix_bazel_symbols: bool = False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rocm_path: str = "/opt/rocm", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -288,7 +289,7 @@ def setup_development( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||
| makefile_path = "./jax_rocm_plugin/Makefile" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if rebuild_makefile or not os.path.exists(makefile_path) or fix_bazel_symbols: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| this_repo_root, xla_path, kernels_jax_path = _resolve_relative_paths( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| xla_dir, kernels_jax_dir | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| xla_dir, jax_dir | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if fix_bazel_symbols: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| plugin_bazel_options = "${PLUGIN_SYMBOLS}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -346,6 +347,59 @@ def setup_development( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mf.write(makefile_content) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def build_and_install( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| xla_ref: str, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| xla_dir: str, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| test_jax_ref: str, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| jax_dir: str, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rebuild_makefile: bool = False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| fix_bazel_symbols: bool = False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rocm_path: str = "/opt/rocm", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Run develop setup, then build all wheels and install jax""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Uninstall existing packages first | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| print("Uninstalling existing JAX packages...") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| subprocess.run(["python3", "-m", "pip", "uninstall", "jax", "-y"], check=False) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| subprocess.run( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| [ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "python3", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "-m", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "pip", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "uninstall", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "jaxlib", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "jax-rocm-pjrt", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "jax-rocm-plugin", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "jax-plugin", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+370
to
+372
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @i-chaochen please explain, how did you test the PR changes? pjrt and plugin wheel are Also what is the point of having 2 separate
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I tested my local jax and xla installation and build on 0.7.1 and 0.8.0 w/o any issue. The reason for two separate calls is due to pylint check for the length, I can add as one if you think it's better.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Of course there were no issues, because your install just didn't have any of
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm sorry I'm not sure I got your point. Just to be clear, this part Lines 350 to 376 in ccbdc3b
is a "sanity check" before the build and install, so uninstall the existing ones (if it has) and rebuild and install new ones, what kind of test do you expect? |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "-y", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| check=False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| setup_development( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| xla_ref=xla_ref, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| xla_dir=xla_dir, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| test_jax_ref=test_jax_ref, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| jax_dir=jax_dir, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rebuild_makefile=rebuild_makefile, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| fix_bazel_symbols=fix_bazel_symbols, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rocm_path=rocm_path, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| this_repo_root = os.path.dirname(os.path.realpath(__file__)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _, _, kernels_jax_path = _resolve_relative_paths(xla_dir, jax_dir) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why do you use
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure I understood your comment.. If it doesn't make sense for you, why did you use that name? |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # 1. Run make all_wheels in jax_rocm_plugin | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| print("Building all wheels...") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| subprocess.check_call( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ["make", "all_wheels"], cwd=os.path.join(this_repo_root, "jax_rocm_plugin") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # 2. Run pip install . in jax repo | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if kernels_jax_path: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| print(f"Installing JAX from {kernels_jax_path}...") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| subprocess.check_call(["pip", "install", "."], cwd=kernels_jax_path) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please treat this as suggestion since my understanding of JAX and XLA is not at reviewing code level yet.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Correct, thank you for noticing. @i-chaochen also for dev setup installing jax in
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @i-chaochen so at first you should install jax in editable mode. This will also install upstream
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. wrt
I'm not sure this would work if the What could be beneficial here is
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. May I ask what's
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The same as for any other python package: it allows |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def dev_docker(rm): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Start a docker container for local plugin development""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| cur_abs_path = os.path.abspath(os.curdir) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -401,55 +455,59 @@ def parse_args(): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| subp = p.add_subparsers(dest="action", required=True) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dev = subp.add_parser("develop") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dev.add_argument( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Common arguments for develop and build | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| common = argparse.ArgumentParser(add_help=False) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why do you instantiate a different top-level
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. since right now one is for |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| common.add_argument( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "--rebuild-makefile", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| help="Force rebuild of Makefile from template.", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| action="store_true", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dev.add_argument( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| common.add_argument( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "--xla-ref", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| help="XLA commit reference to checkout on clone", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| default=XLA_REPO_REF, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dev.add_argument( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| common.add_argument( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "--xla-dir", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| help=( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "Set the XLA path in the Makefile. This must either be a path " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "relative to jax_rocm_plugin or an absolute path." | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| default=DEFAULT_XLA_DIR, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dev.add_argument( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| common.add_argument( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "--jax-ref", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| help="JAX commit reference to checkout on clone", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| default=TEST_JAX_REPO_REF, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dev.add_argument( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "--kernel-jax-dir", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| common.add_argument( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "--jax-dir", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| help=( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "If you want to use a local JAX directory for building the " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "plugin kernels wheel (jax_rocm7_plugin), the path to the " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "directory of repo. Defaults to %s" % DEFAULT_KERNELS_JAX_DIR | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "directory of repo. Defaults to %s" % DEFAULT_JAX_DIR | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| default=DEFAULT_KERNELS_JAX_DIR, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| default=DEFAULT_JAX_DIR, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dev.add_argument( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| common.add_argument( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "--fix-bazel-symbols", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| help="When this option is enabled, the script assumes you need to build " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "code in a release with symbolic info configuration to alleviate debugging. " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "The script enables respective bazel options and adds 'external' symbolic " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "links to corresponding workspaces pointing to bazel's dependencies storage.", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| action="store_true", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dev.add_argument( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| common.add_argument( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "--rocm-path", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| help="Location of the ROCm to use for building Jax", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| default="/opt/rocm", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| subp.add_parser("develop", parents=[common], help="Setup development environment") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| subp.add_parser( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "build", parents=[common], help="Setup, build all wheels, and install JAX" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| doc_parser = subp.add_parser("docker") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| doc_parser.add_argument( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "--rm", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -469,11 +527,21 @@ def main(): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||
| xla_ref=args.xla_ref, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| xla_dir=args.xla_dir, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| test_jax_ref=args.jax_ref, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| kernels_jax_dir=args.kernel_jax_dir, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| jax_dir=args.jax_dir, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rebuild_makefile=args.rebuild_makefile, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| fix_bazel_symbols=args.fix_bazel_symbols, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rocm_path=args.rocm_path, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elif args.action == "build": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| build_and_install( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| xla_ref=args.xla_ref, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| xla_dir=args.xla_dir, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| test_jax_ref=args.jax_ref, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| jax_dir=args.jax_dir, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rebuild_makefile=True, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| fix_bazel_symbols=args.fix_bazel_symbols, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rocm_path=args.rocm_path, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if __name__ == "__main__": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This only for my info: does the build order matter here?
example : you make edit to XLA and you build jaxlib first then plugins ? or it doesn't affect?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@magaonka-amd the order specifies the build sequence. Though I also not sure why it's made the way it's made.
@i-chaochen , why don't you use high level rules
refresh refresh_jaxlibhere instead?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
where is
refresh refresh_jaxlib? it's in the bottom and never mentioned in https://github.com/ROCm/rocm-jax/blob/master/BUILDING.md so you mean thisrefresh refresh_jaxlibcan replace thisall_wheels?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You've modified the stack.py which contains the makefile template. How did you do this without reading it?
https://github.com/ROCm/rocm-jax/blob/master/DEVSETUP.md#the-makefile
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
:) I used Cursor to make this PR
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Chao, to be frank, by doing that you have willingly risked breaking the code base and making everyone life's worse, and wasted a couple of hours of my time in reviewing what you've done and answering to your questions (that were already answered over 9000 times in the chat previously as well). But hey, you've spared some time for yourself, didn't you?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So if I want to build all targets, I just need to use
refresh refresh_jaxlib?TBH, it's obscure and when I first to create this PR I asked in the team and no one answered neither, seems it's only you know how to use this? I don't think it's a good thing and I would think the proper way is to indicate clearer with use cases in the docs.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes,
make refresh refresh_jaxlibwill do the same as your proposedmake all_wheelsdo.This question was answered in details so many times (including direct replies to you personally - that's easily searchable in the chat, I shouldn't post links here) that I suspect that everyone is just tired to sing that nice song again one more time. I can relate to that. Wrt to the documentation, - rebukes aren't accepted. Charles was writing and rewriting it multiple times and if you were explicitly mandated to participate in that by asking your questions.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Again, I'm not the only person confused the build stuff and I'm trying my best to fill this gap. I'm totally fine with
make refresh refresh_jaxliband rmall_wheels.Also, I have zero attention to "rebuke" anyone and what I said that I would like to add this
make refresh refresh_jaxlibuse case into https://github.com/ROCm/rocm-jax/blob/master/DEVSETUP.md in PR to be more straightforward to those folks like me cannot read proper docs :)