116 changes: 92 additions & 24 deletions stack.py
@@ -14,7 +14,7 @@
 XLA_REPL_URL = "https://github.com/rocm/xla"

 DEFAULT_XLA_DIR = "../xla"
-DEFAULT_KERNELS_JAX_DIR = "../jax"
+DEFAULT_JAX_DIR = "../jax"

 MAKE_TEMPLATE = r"""
 # gfx targets for which XLA and jax custom call kernels are built for
@@ -50,12 +50,13 @@
 ###


-.PHONY: test clean install dist
+.PHONY: test clean install dist all_wheels

 .default: dist


 dist: jax_rocm_plugin jax_rocm_pjrt
+all_wheels: clean dist jaxlib_clean jaxlib jaxlib_install install

This is only for my own information: does the build order matter here?
For example, if you edit XLA and build jaxlib first and then the plugins, does that change the result, or does it make no difference?

Contributor

@Arech8 Arech8 Jan 12, 2026

@magaonka-amd the order specifies the build sequence, though I'm also not sure why it's ordered the way it is.

@i-chaochen, why don't you use the high-level rules `refresh` and `refresh_jaxlib` here instead?

Author

Where are `refresh` and `refresh_jaxlib`? They're at the bottom and are never mentioned in https://github.com/ROCm/rocm-jax/blob/master/BUILDING.md. So you mean `refresh refresh_jaxlib` can replace this `all_wheels`?

Contributor

You've modified stack.py, which contains the Makefile template. How did you do this without reading it?

https://github.com/ROCm/rocm-jax/blob/master/DEVSETUP.md#the-makefile

Author

:) I used Cursor to make this PR

Contributor

Chao, to be frank, by doing that you have willingly risked breaking the code base and making everyone's life worse, and wasted a couple of hours of my time reviewing what you've done and answering your questions (which had already been answered over 9000 times in the chat as well). But hey, you've saved yourself some time, haven't you?

Author

So if I want to build all targets, I just need to use `refresh refresh_jaxlib`?

TBH, it's obscure. When I first created this PR I asked in the team and no one answered either; it seems only you know how to use this? I don't think that's a good thing, and I think the proper way is to describe it more clearly, with use cases, in the docs.

Contributor

@Arech8 Arech8 Jan 12, 2026

> So if I want to build all targets, I just need to use `refresh refresh_jaxlib`?

Yes, `make refresh refresh_jaxlib` does the same as your proposed `make all_wheels`.
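
As a rough illustration, here is a minimal sketch of how a Python driver could invoke the existing high-level rules instead of a new `all_wheels` target. The helper name `make_command` and the checkout location are hypothetical; only the target names `refresh` and `refresh_jaxlib` and the `jax_rocm_plugin` directory come from this thread.

```python
import os


def make_command(targets):
    """Build the argv for one `make` invocation; goals keep the order given."""
    return ["make", *targets]


# GNU make processes command-line goals in the order they are named.
cmd = make_command(["refresh", "refresh_jaxlib"])
plugin_dir = os.path.join("rocm-jax", "jax_rocm_plugin")  # hypothetical checkout location
print(cmd, "in", plugin_dir)
# To actually run it: subprocess.check_call(cmd, cwd=plugin_dir)
```

Passing both goals to a single `make` call keeps the sequencing in the Makefile, so the Python side does not need to know which lower-level rules are involved.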

> TBH, it's obscure. When I first created this PR I asked in the team and no one answered either; it seems only you know how to use this? I don't think that's a good thing, and I think the proper way is to describe it more clearly, with use cases, in the docs.

This question has been answered in detail so many times (including direct replies to you personally; that's easily searchable in the chat, and I shouldn't post links here) that I suspect everyone is just tired of singing that nice song one more time. I can relate to that. As for the documentation, rebukes aren't accepted. Charles wrote and rewrote it multiple times, and you were explicitly mandated to participate in that by asking your questions.

Author

@i-chaochen i-chaochen Jan 12, 2026

Again, I'm not the only person confused by the build process, and I'm trying my best to fill this gap. I'm totally fine with using `make refresh refresh_jaxlib` and removing `all_wheels`.

Also, I have zero intention of "rebuking" anyone. What I said is that I would like to add this `make refresh refresh_jaxlib` use case to https://github.com/ROCm/rocm-jax/blob/master/DEVSETUP.md in this PR, to make it more straightforward for folks like me who cannot read the docs properly :)



jax_rocm_plugin:
@@ -163,7 +164,7 @@ def find_clang():
     return None


-def _resolve_relative_paths(xla_dir: str, kernels_jax_dir: str) -> tuple[str, str, str]:
+def _resolve_relative_paths(xla_dir: str, jax_dir: str) -> tuple[str, str, str]:
     """Transforms relative to absolute paths. This is needed to properly support
     symbolic information remapping"""
     this_repo_root = os.path.dirname(os.path.realpath(__file__))
@@ -177,16 +178,16 @@ def _resolve_relative_paths(xla_dir: str, kernels_jax_dir: str) -> tuple[str, st
         xla_path
     ), f"XLA path (specified as '{xla_dir}') doesn't resolve to existing directory at '{xla_path}'"

-    if kernels_jax_dir:
+    if jax_dir:
         kernels_jax_path = (
-            kernels_jax_dir
-            if os.path.isabs(kernels_jax_dir)
-            else os.path.abspath(f"{this_repo_root}/jax_rocm_plugin/{kernels_jax_dir}")
+            jax_dir
+            if os.path.isabs(jax_dir)
+            else os.path.abspath(f"{this_repo_root}/jax_rocm_plugin/{jax_dir}")
         )
         # pylint: disable=line-too-long
         assert os.path.isdir(
             kernels_jax_path
-        ), f"XLA path (specified as '{kernels_jax_dir}') doesn't resolve to existing directory at '{kernels_jax_path}'"
+        ), f"XLA path (specified as '{jax_dir}') doesn't resolve to existing directory at '{kernels_jax_path}'"
     else:
         kernels_jax_path = None
     return this_repo_root, xla_path, kernels_jax_path
@@ -262,7 +263,7 @@ def setup_development(
     xla_ref: str,
     xla_dir: str,
     test_jax_ref: str,
-    kernels_jax_dir: str,
+    jax_dir: str,
     rebuild_makefile: bool = False,
     fix_bazel_symbols: bool = False,
     rocm_path: str = "/opt/rocm",
@@ -288,7 +289,7 @@ def setup_development(
     makefile_path = "./jax_rocm_plugin/Makefile"
     if rebuild_makefile or not os.path.exists(makefile_path) or fix_bazel_symbols:
         this_repo_root, xla_path, kernels_jax_path = _resolve_relative_paths(
-            xla_dir, kernels_jax_dir
+            xla_dir, jax_dir
         )
         if fix_bazel_symbols:
             plugin_bazel_options = "${PLUGIN_SYMBOLS}"
@@ -346,6 +347,59 @@ def setup_development(
             mf.write(makefile_content)


+def build_and_install(
+    xla_ref: str,
+    xla_dir: str,
+    test_jax_ref: str,
+    jax_dir: str,
+    rebuild_makefile: bool = False,
+    fix_bazel_symbols: bool = False,
+    rocm_path: str = "/opt/rocm",
+):
+    """Run develop setup, then build all wheels and install jax"""
+    # Uninstall existing packages first
+    print("Uninstalling existing JAX packages...")
+    subprocess.run(["python3", "-m", "pip", "uninstall", "jax", "-y"], check=False)
+    subprocess.run(
+        [
+            "python3",
+            "-m",
+            "pip",
+            "uninstall",
+            "jaxlib",
+            "jax-rocm-pjrt",
+            "jax-rocm-plugin",
+            "jax-plugin",
Comment on lines +370 to +372
Contributor

@i-chaochen please explain: how did you test the PR changes?

The pjrt and plugin wheels are `jax-rocm7-pjrt` and `jax-rocm7-plugin` respectively for ROCm 7+, and were `jax-rocm60-pjrt` and `jax-rocm60-plugin` for the previous generation. I don't even know whether wheels named simply `jax-rocm-pjrt`, `jax-rocm-plugin`, or `jax-plugin` ever existed.

Also, what is the point of having two separate uninstall calls instead of a single one?
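
For reference, a minimal sketch of a single uninstall call that still passes a line-length check, because the package list is spread over several lines. The helper name `uninstall_command` is hypothetical, and the wheel names are the ROCm 7+ ones mentioned above; treat the exact list as an assumption to be checked against what the build actually produces.

```python
import sys

# Wheel names mentioned in this thread for ROCm 7+ (an assumption, not verified here).
PACKAGES_TO_UNINSTALL = [
    "jax",
    "jaxlib",
    "jax-rocm7-pjrt",
    "jax-rocm7-plugin",
]


def uninstall_command(packages):
    """Return the argv for one `pip uninstall -y` call covering all packages."""
    return [sys.executable, "-m", "pip", "uninstall", "-y", *packages]


print(uninstall_command(PACKAGES_TO_UNINSTALL))
# To run it: subprocess.run(uninstall_command(PACKAGES_TO_UNINSTALL), check=False)
```

Using `sys.executable` rather than a literal `python3` targets the interpreter that is running the script, which matters inside a virtual environment.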

Author

I tested with my local jax and xla installation and built on 0.7.1 and 0.8.0 without any issues. The two separate calls are there because of the pylint line-length check; I can merge them into one if you think that's better.

Contributor

Of course there were no issues, because your install simply didn't have any of the "jax-rocm-pjrt", "jax-rocm-plugin", or "jax-plugin" wheels! You could have used any other nonexistent package names there as well, so what's the point of doing that?
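
One way to make this visible before uninstalling is to check which of the candidate names correspond to an installed distribution. The sketch below uses the standard-library `importlib.metadata`; the helper name `installed_subset` and the candidate list are illustrative assumptions.

```python
from importlib import metadata


def installed_subset(names):
    """Return the names that correspond to an installed distribution."""
    found = []
    for name in names:
        try:
            metadata.distribution(name)
        except metadata.PackageNotFoundError:
            continue  # nothing installed under this name
        found.append(name)
    return found


candidates = ["jax-rocm7-pjrt", "jax-rocm7-plugin", "jax-rocm-pjrt", "jax-plugin"]
print("installed:", installed_subset(candidates))
```

Names missing from the result are ones an uninstall call would skip without removing anything, so a clean run says nothing about whether they were right.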

Author

@i-chaochen i-chaochen Jan 12, 2026

I'm sorry, I'm not sure I got your point. Just to be clear, this part

rocm-jax/stack.py

Lines 350 to 376 in ccbdc3b

def build_and_install(
    xla_ref: str,
    xla_dir: str,
    test_jax_ref: str,
    jax_dir: str,
    rebuild_makefile: bool = False,
    fix_bazel_symbols: bool = False,
    rocm_path: str = "/opt/rocm",
):
    """Run develop setup, then build all wheels and install jax"""
    # Uninstall existing packages first
    print("Uninstalling existing JAX packages...")
    subprocess.run(["python3", "-m", "pip", "uninstall", "jax", "-y"], check=False)
    subprocess.run(
        [
            "python3",
            "-m",
            "pip",
            "uninstall",
            "jaxlib",
            "jax-rocm-pjrt",
            "jax-rocm-plugin",
            "jax-plugin",
            "-y",
        ],
        check=False,
    )

is a "sanity check" before the build and install: it uninstalls the existing packages (if any) so that new ones are rebuilt and installed. What kind of test do you expect?

+            "-y",
+        ],
+        check=False,
+    )
+
+    setup_development(
+        xla_ref=xla_ref,
+        xla_dir=xla_dir,
+        test_jax_ref=test_jax_ref,
+        jax_dir=jax_dir,
+        rebuild_makefile=rebuild_makefile,
+        fix_bazel_symbols=fix_bazel_symbols,
+        rocm_path=rocm_path,
+    )
+
+    this_repo_root = os.path.dirname(os.path.realpath(__file__))
+    _, _, kernels_jax_path = _resolve_relative_paths(xla_dir, jax_dir)
Contributor

Why do you use the variable name `kernels_jax_path` instead of `jax_dir` here?

Author

`kernels_jax` doesn't make sense at all; it should be `jax_dir`, as it's for the installation of the jax repo.

Contributor

Not sure I understood your comment. If it doesn't make sense to you, why did you use that name?
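
To make the naming question concrete, here is a minimal sketch that keeps `jax_dir` for the value the user passed (possibly relative) and `jax_path` for the resolved absolute path. The helper `resolve_dir` and the base directory are hypothetical and not part of stack.py.

```python
import os


def resolve_dir(path, base_dir):
    """Return `path` as an absolute path; relative paths are taken against base_dir."""
    if os.path.isabs(path):
        return path
    return os.path.abspath(os.path.join(base_dir, path))


jax_dir = "../jax"  # what the user passed on the command line
jax_path = resolve_dir(jax_dir, "/repo/jax_rocm_plugin")  # hypothetical base directory
print(jax_path)
```

With one name per role, the resolved value no longer carries the old `kernels_` prefix after the argument itself has been renamed.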


+    # 1. Run make all_wheels in jax_rocm_plugin
+    print("Building all wheels...")
+    subprocess.check_call(
+        ["make", "all_wheels"], cwd=os.path.join(this_repo_root, "jax_rocm_plugin")
+    )
+
+    # 2. Run pip install . in jax repo
+    if kernels_jax_path:
+        print(f"Installing JAX from {kernels_jax_path}...")
+        subprocess.check_call(["pip", "install", "."], cwd=kernels_jax_path)

Please treat this as a suggestion, since my understanding of JAX and XLA is not yet at code-review level.
But as I understand it, `pip install .` in the jax repo installs jax and the stable jaxlib, which defeats the purpose of building jaxlib in the previous steps, right?
In my opinion it should be `pip install --no-deps .`

Contributor

Correct, thank you for noticing.

@i-chaochen also, for a dev setup, installing jax in editable mode is more appropriate. Can you please fix that too?

Contributor

@i-chaochen so first you should install jax in editable mode. This will also install the upstream jaxlib, but you can ignore that and just call `make all_wheels` to build and reinstall pjrt/plugin/jaxlib.

Contributor

wrt

> In my opinion it should be pip install --no-deps

I'm not sure this would work if `stack.py build` is called in a fresh environment where jax was never installed. jax requires some dependencies, which would be absent in that case.

What could be beneficial here, however, is `--force-reinstall`, just to make sure. cc: @i-chaochen

Author

May I ask what editable mode is in a jax install?

Contributor

The same as for any other Python package: it lets you `pip install` a package from a directory, then modify the files in that directory and immediately see the results when importing the installed package. That's what we have to do when working on Python test failures, for example: modify JAX Python tests and sometimes core JAX Python files.
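
The install options discussed in this thread (`-e` for editable mode, `--no-deps`, `--force-reinstall`) are all standard pip flags. A minimal sketch of how the install step could assemble them follows; the helper name `jax_install_command` and its defaults are assumptions, not the PR's actual code.

```python
import sys


def jax_install_command(editable=True, no_deps=False, force_reinstall=False):
    """Build the argv for installing JAX from the current directory."""
    argv = [sys.executable, "-m", "pip", "install"]
    if no_deps:
        argv.append("--no-deps")  # skips JAX's dependencies, including upstream jaxlib
    if force_reinstall:
        argv.append("--force-reinstall")
    # Editable mode makes imports resolve to the source tree itself.
    argv += ["-e", "."] if editable else ["."]
    return argv


print(jax_install_command())
# To run it: subprocess.check_call(jax_install_command(), cwd=jax_path)
```

Keeping the flags as keyword arguments leaves the `--no-deps` decision, which depends on whether the environment is fresh, up to the caller.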



 def dev_docker(rm):
     """Start a docker container for local plugin development"""
     cur_abs_path = os.path.abspath(os.curdir)
@@ -401,55 +455,59 @@ def parse_args():

     subp = p.add_subparsers(dest="action", required=True)

-    dev = subp.add_parser("develop")
-    dev.add_argument(
+    # Common arguments for develop and build
+    common = argparse.ArgumentParser(add_help=False)
Contributor

Why do you instantiate a separate top-level `argparse.ArgumentParser` object here? I think the usual way is to instantiate a new subparser, as was done here previously...

Author

Since there is now one subcommand for develop and another for build, we can define the arguments common to both in one place.
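
For context, sharing options between subcommands through a parent parser is a documented `argparse` pattern: the parent is created with `add_help=False` so that `-h` is not defined twice, and each subparser lists it in `parents=`. A reduced sketch with only two of the options follows; the `build_parser` wrapper and the `prog` name are illustrative.

```python
import argparse


def build_parser():
    # Parent parser: holds only the shared options and is never used on its own.
    common = argparse.ArgumentParser(add_help=False)
    common.add_argument("--jax-dir", default="../jax")
    common.add_argument("--rebuild-makefile", action="store_true")

    parser = argparse.ArgumentParser(prog="stack.py")
    subp = parser.add_subparsers(dest="action", required=True)
    subp.add_parser("develop", parents=[common])
    subp.add_parser("build", parents=[common])
    return parser


args = build_parser().parse_args(["build", "--jax-dir", "/tmp/jax"])
print(args.action, args.jax_dir, args.rebuild_makefile)
```

Both subcommands accept the same flags with the same defaults, without repeating each `add_argument` call per subparser.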

+    common.add_argument(
         "--rebuild-makefile",
         help="Force rebuild of Makefile from template.",
         action="store_true",
     )
-    dev.add_argument(
+    common.add_argument(
         "--xla-ref",
         help="XLA commit reference to checkout on clone",
         default=XLA_REPO_REF,
     )
-    dev.add_argument(
+    common.add_argument(
         "--xla-dir",
         help=(
             "Set the XLA path in the Makefile. This must either be a path "
             "relative to jax_rocm_plugin or an absolute path."
         ),
         default=DEFAULT_XLA_DIR,
     )
-    dev.add_argument(
+    common.add_argument(
         "--jax-ref",
         help="JAX commit reference to checkout on clone",
         default=TEST_JAX_REPO_REF,
     )
-    dev.add_argument(
-        "--kernel-jax-dir",
+    common.add_argument(
+        "--jax-dir",
         help=(
             "If you want to use a local JAX directory for building the "
             "plugin kernels wheel (jax_rocm7_plugin), the path to the "
-            "directory of repo. Defaults to %s" % DEFAULT_KERNELS_JAX_DIR
+            "directory of repo. Defaults to %s" % DEFAULT_JAX_DIR
         ),
-        default=DEFAULT_KERNELS_JAX_DIR,
+        default=DEFAULT_JAX_DIR,
     )

-    dev.add_argument(
+    common.add_argument(
         "--fix-bazel-symbols",
         help="When this option is enabled, the script assumes you need to build "
         "code in a release with symbolic info configuration to alleviate debugging. "
         "The script enables respective bazel options and adds 'external' symbolic "
         "links to corresponding workspaces pointing to bazel's dependencies storage.",
         action="store_true",
     )

-    dev.add_argument(
+    common.add_argument(
         "--rocm-path",
         help="Location of the ROCm to use for building Jax",
         default="/opt/rocm",
     )

+    subp.add_parser("develop", parents=[common], help="Setup development environment")
+    subp.add_parser(
+        "build", parents=[common], help="Setup, build all wheels, and install JAX"
+    )
+
     doc_parser = subp.add_parser("docker")
     doc_parser.add_argument(
         "--rm",
@@ -469,11 +527,21 @@ def main():
             xla_ref=args.xla_ref,
             xla_dir=args.xla_dir,
             test_jax_ref=args.jax_ref,
-            kernels_jax_dir=args.kernel_jax_dir,
+            jax_dir=args.jax_dir,
             rebuild_makefile=args.rebuild_makefile,
             fix_bazel_symbols=args.fix_bazel_symbols,
             rocm_path=args.rocm_path,
         )
+    elif args.action == "build":
+        build_and_install(
+            xla_ref=args.xla_ref,
+            xla_dir=args.xla_dir,
+            test_jax_ref=args.jax_ref,
+            jax_dir=args.jax_dir,
+            rebuild_makefile=True,
+            fix_bazel_symbols=args.fix_bazel_symbols,
+            rocm_path=args.rocm_path,
+        )


 if __name__ == "__main__":