Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 92 additions & 24 deletions stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
XLA_REPL_URL = "https://github.com/rocm/xla"

DEFAULT_XLA_DIR = "../xla"
DEFAULT_KERNELS_JAX_DIR = "../jax"
DEFAULT_JAX_DIR = "../jax"

MAKE_TEMPLATE = r"""
# gfx targets for which XLA and jax custom call kernels are built for
Expand Down Expand Up @@ -50,12 +50,13 @@
###


.PHONY: test clean install dist
.PHONY: test clean install dist all_wheels

.default: dist


dist: jax_rocm_plugin jax_rocm_pjrt
all_wheels: clean dist jaxlib_clean jaxlib jaxlib_install install


jax_rocm_plugin:
Expand Down Expand Up @@ -163,7 +164,7 @@ def find_clang():
return None


def _resolve_relative_paths(xla_dir: str, kernels_jax_dir: str) -> tuple[str, str, str]:
def _resolve_relative_paths(xla_dir: str, jax_dir: str) -> tuple[str, str, str]:
"""Transforms relative to absolute paths. This is needed to properly support
symbolic information remapping"""
this_repo_root = os.path.dirname(os.path.realpath(__file__))
Expand All @@ -177,16 +178,16 @@ def _resolve_relative_paths(xla_dir: str, kernels_jax_dir: str) -> tuple[str, st
xla_path
), f"XLA path (specified as '{xla_dir}') doesn't resolve to existing directory at '{xla_path}'"

if kernels_jax_dir:
if jax_dir:
kernels_jax_path = (
kernels_jax_dir
if os.path.isabs(kernels_jax_dir)
else os.path.abspath(f"{this_repo_root}/jax_rocm_plugin/{kernels_jax_dir}")
jax_dir
if os.path.isabs(jax_dir)
else os.path.abspath(f"{this_repo_root}/jax_rocm_plugin/{jax_dir}")
)
# pylint: disable=line-too-long
assert os.path.isdir(
kernels_jax_path
), f"XLA path (specified as '{kernels_jax_dir}') doesn't resolve to existing directory at '{kernels_jax_path}'"
), f"XLA path (specified as '{jax_dir}') doesn't resolve to existing directory at '{kernels_jax_path}'"
else:
kernels_jax_path = None
return this_repo_root, xla_path, kernels_jax_path
Expand Down Expand Up @@ -262,7 +263,7 @@ def setup_development(
xla_ref: str,
xla_dir: str,
test_jax_ref: str,
kernels_jax_dir: str,
jax_dir: str,
rebuild_makefile: bool = False,
fix_bazel_symbols: bool = False,
rocm_path: str = "/opt/rocm",
Expand All @@ -288,7 +289,7 @@ def setup_development(
makefile_path = "./jax_rocm_plugin/Makefile"
if rebuild_makefile or not os.path.exists(makefile_path) or fix_bazel_symbols:
this_repo_root, xla_path, kernels_jax_path = _resolve_relative_paths(
xla_dir, kernels_jax_dir
xla_dir, jax_dir
)
if fix_bazel_symbols:
plugin_bazel_options = "${PLUGIN_SYMBOLS}"
Expand Down Expand Up @@ -346,6 +347,59 @@ def setup_development(
mf.write(makefile_content)


def build_and_install(
xla_ref: str,
xla_dir: str,
test_jax_ref: str,
jax_dir: str,
rebuild_makefile: bool = False,
fix_bazel_symbols: bool = False,
rocm_path: str = "/opt/rocm",
):
"""Run develop setup, then build all wheels and install jax"""
# Uninstall existing packages first
print("Uninstalling existing JAX packages...")
subprocess.run(["python3", "-m", "pip", "uninstall", "jax", "-y"], check=False)
subprocess.run(
[
"python3",
"-m",
"pip",
"uninstall",
"jaxlib",
"jax-rocm-pjrt",
"jax-rocm-plugin",
"jax-plugin",
"-y",
],
check=False,
)

setup_development(
xla_ref=xla_ref,
xla_dir=xla_dir,
test_jax_ref=test_jax_ref,
jax_dir=jax_dir,
rebuild_makefile=rebuild_makefile,
fix_bazel_symbols=fix_bazel_symbols,
rocm_path=rocm_path,
)

this_repo_root = os.path.dirname(os.path.realpath(__file__))
_, _, kernels_jax_path = _resolve_relative_paths(xla_dir, jax_dir)

# 1. Run make all_wheels in jax_rocm_plugin
print("Building all wheels...")
subprocess.check_call(
["make", "all_wheels"], cwd=os.path.join(this_repo_root, "jax_rocm_plugin")
)

# 2. Run pip install . in jax repo
if kernels_jax_path:
print(f"Installing JAX from {kernels_jax_path}...")
subprocess.check_call(["pip", "install", "."], cwd=kernels_jax_path)


def dev_docker(rm):
"""Start a docker container for local plugin development"""
cur_abs_path = os.path.abspath(os.curdir)
Expand Down Expand Up @@ -401,55 +455,59 @@ def parse_args():

subp = p.add_subparsers(dest="action", required=True)

dev = subp.add_parser("develop")
dev.add_argument(
# Common arguments for develop and build
common = argparse.ArgumentParser(add_help=False)
common.add_argument(
"--rebuild-makefile",
help="Force rebuild of Makefile from template.",
action="store_true",
)
dev.add_argument(
common.add_argument(
"--xla-ref",
help="XLA commit reference to checkout on clone",
default=XLA_REPO_REF,
)
dev.add_argument(
common.add_argument(
"--xla-dir",
help=(
"Set the XLA path in the Makefile. This must either be a path "
"relative to jax_rocm_plugin or an absolute path."
),
default=DEFAULT_XLA_DIR,
)
dev.add_argument(
common.add_argument(
"--jax-ref",
help="JAX commit reference to checkout on clone",
default=TEST_JAX_REPO_REF,
)
dev.add_argument(
"--kernel-jax-dir",
common.add_argument(
"--jax-dir",
help=(
"If you want to use a local JAX directory for building the "
"plugin kernels wheel (jax_rocm7_plugin), the path to the "
"directory of repo. Defaults to %s" % DEFAULT_KERNELS_JAX_DIR
"directory of repo. Defaults to %s" % DEFAULT_JAX_DIR
),
default=DEFAULT_KERNELS_JAX_DIR,
default=DEFAULT_JAX_DIR,
)

dev.add_argument(
common.add_argument(
"--fix-bazel-symbols",
help="When this option is enabled, the script assumes you need to build "
"code in a release with symbolic info configuration to alleviate debugging. "
"The script enables respective bazel options and adds 'external' symbolic "
"links to corresponding workspaces pointing to bazel's dependencies storage.",
action="store_true",
)

dev.add_argument(
common.add_argument(
"--rocm-path",
help="Location of the ROCm to use for building Jax",
default="/opt/rocm",
)

subp.add_parser("develop", parents=[common], help="Setup development environment")
subp.add_parser(
"build", parents=[common], help="Setup, build all wheels, and install JAX"
)

doc_parser = subp.add_parser("docker")
doc_parser.add_argument(
"--rm",
Expand All @@ -469,11 +527,21 @@ def main():
xla_ref=args.xla_ref,
xla_dir=args.xla_dir,
test_jax_ref=args.jax_ref,
kernels_jax_dir=args.kernel_jax_dir,
jax_dir=args.jax_dir,
rebuild_makefile=args.rebuild_makefile,
fix_bazel_symbols=args.fix_bazel_symbols,
rocm_path=args.rocm_path,
)
elif args.action == "build":
build_and_install(
xla_ref=args.xla_ref,
xla_dir=args.xla_dir,
test_jax_ref=args.jax_ref,
jax_dir=args.jax_dir,
rebuild_makefile=True,
fix_bazel_symbols=args.fix_bazel_symbols,
rocm_path=args.rocm_path,
)


if __name__ == "__main__":
Expand Down