diff --git a/stack.py b/stack.py index f9b15fc8a9..9ffc3ba054 100644 --- a/stack.py +++ b/stack.py @@ -14,7 +14,7 @@ XLA_REPL_URL = "https://github.com/rocm/xla" DEFAULT_XLA_DIR = "../xla" -DEFAULT_KERNELS_JAX_DIR = "../jax" +DEFAULT_JAX_DIR = "../jax" MAKE_TEMPLATE = r""" # gfx targets for which XLA and jax custom call kernels are built for @@ -50,12 +50,13 @@ ### -.PHONY: test clean install dist +.PHONY: test clean install dist all_wheels .default: dist dist: jax_rocm_plugin jax_rocm_pjrt +all_wheels: clean dist jaxlib_clean jaxlib jaxlib_install install jax_rocm_plugin: @@ -163,7 +164,7 @@ def find_clang(): return None -def _resolve_relative_paths(xla_dir: str, kernels_jax_dir: str) -> tuple[str, str, str]: +def _resolve_relative_paths(xla_dir: str, jax_dir: str) -> tuple[str, str, str]: """Transforms relative to absolute paths. This is needed to properly support symbolic information remapping""" this_repo_root = os.path.dirname(os.path.realpath(__file__)) @@ -177,16 +178,16 @@ def _resolve_relative_paths(xla_dir: str, kernels_jax_dir: str) -> tuple[str, st xla_path ), f"XLA path (specified as '{xla_dir}') doesn't resolve to existing directory at '{xla_path}'" - if kernels_jax_dir: + if jax_dir: kernels_jax_path = ( - kernels_jax_dir - if os.path.isabs(kernels_jax_dir) - else os.path.abspath(f"{this_repo_root}/jax_rocm_plugin/{kernels_jax_dir}") + jax_dir + if os.path.isabs(jax_dir) + else os.path.abspath(f"{this_repo_root}/jax_rocm_plugin/{jax_dir}") ) # pylint: disable=line-too-long assert os.path.isdir( kernels_jax_path - ), f"XLA path (specified as '{kernels_jax_dir}') doesn't resolve to existing directory at '{kernels_jax_path}'" + ), f"XLA path (specified as '{jax_dir}') doesn't resolve to existing directory at '{kernels_jax_path}'" else: kernels_jax_path = None return this_repo_root, xla_path, kernels_jax_path @@ -262,7 +263,7 @@ def setup_development( xla_ref: str, xla_dir: str, test_jax_ref: str, - kernels_jax_dir: str, + jax_dir: str, rebuild_makefile: bool = False, fix_bazel_symbols: bool = False, rocm_path: str = "/opt/rocm", @@ -288,7 +289,7 @@ def setup_development( makefile_path = "./jax_rocm_plugin/Makefile" if rebuild_makefile or not os.path.exists(makefile_path) or fix_bazel_symbols: this_repo_root, xla_path, kernels_jax_path = _resolve_relative_paths( - xla_dir, kernels_jax_dir + xla_dir, jax_dir ) if fix_bazel_symbols: plugin_bazel_options = "${PLUGIN_SYMBOLS}" @@ -346,6 +347,59 @@ def setup_development( mf.write(makefile_content) +def build_and_install( + xla_ref: str, + xla_dir: str, + test_jax_ref: str, + jax_dir: str, + rebuild_makefile: bool = False, + fix_bazel_symbols: bool = False, + rocm_path: str = "/opt/rocm", +): + """Run develop setup, then build all wheels and install jax""" + # Uninstall existing packages first + print("Uninstalling existing JAX packages...") + subprocess.run(["python3", "-m", "pip", "uninstall", "jax", "-y"], check=False) + subprocess.run( + [ + "python3", + "-m", + "pip", + "uninstall", + "jaxlib", + "jax-rocm-pjrt", + "jax-rocm-plugin", + "jax-plugin", + "-y", + ], + check=False, + ) + + setup_development( + xla_ref=xla_ref, + xla_dir=xla_dir, + test_jax_ref=test_jax_ref, + jax_dir=jax_dir, + rebuild_makefile=rebuild_makefile, + fix_bazel_symbols=fix_bazel_symbols, + rocm_path=rocm_path, + ) + + this_repo_root = os.path.dirname(os.path.realpath(__file__)) + _, _, kernels_jax_path = _resolve_relative_paths(xla_dir, jax_dir) + + # 1. Run make all_wheels in jax_rocm_plugin + print("Building all wheels...") + subprocess.check_call( + ["make", "all_wheels"], cwd=os.path.join(this_repo_root, "jax_rocm_plugin") + ) + + # 2. Run pip install . in jax repo + if kernels_jax_path: + print(f"Installing JAX from {kernels_jax_path}...") + subprocess.check_call(["pip", "install", "."], cwd=kernels_jax_path) + + def dev_docker(rm): """Start a docker container for local plugin development""" cur_abs_path = os.path.abspath(os.curdir) @@ -401,18 +455,19 @@ def parse_args(): subp = p.add_subparsers(dest="action", required=True) - dev = subp.add_parser("develop") - dev.add_argument( + # Common arguments for develop and build + common = argparse.ArgumentParser(add_help=False) + common.add_argument( "--rebuild-makefile", help="Force rebuild of Makefile from template.", action="store_true", ) - dev.add_argument( + common.add_argument( "--xla-ref", help="XLA commit reference to checkout on clone", default=XLA_REPO_REF, ) - dev.add_argument( + common.add_argument( "--xla-dir", help=( "Set the XLA path in the Makefile. This must either be a path " @@ -420,22 +475,21 @@ def parse_args(): ), default=DEFAULT_XLA_DIR, ) - dev.add_argument( + common.add_argument( "--jax-ref", help="JAX commit reference to checkout on clone", default=TEST_JAX_REPO_REF, ) - dev.add_argument( - "--kernel-jax-dir", + common.add_argument( + "--jax-dir", help=( "If you want to use a local JAX directory for building the " "plugin kernels wheel (jax_rocm7_plugin), the path to the " - "directory of repo. Defaults to %s" % DEFAULT_KERNELS_JAX_DIR + "directory of repo. Defaults to %s" % DEFAULT_JAX_DIR ), - default=DEFAULT_KERNELS_JAX_DIR, + default=DEFAULT_JAX_DIR, ) - - dev.add_argument( + common.add_argument( "--fix-bazel-symbols", help="When this option is enabled, the script assumes you need to build " "code in a release with symbolic info configuration to alleviate debugging. " @@ -443,13 +497,17 @@ def parse_args(): "links to corresponding workspaces pointing to bazel's dependencies storage.", action="store_true", ) - - dev.add_argument( + common.add_argument( "--rocm-path", help="Location of the ROCm to use for building Jax", default="/opt/rocm", ) + subp.add_parser("develop", parents=[common], help="Setup development environment") + subp.add_parser( + "build", parents=[common], help="Setup, build all wheels, and install JAX" + ) + doc_parser = subp.add_parser("docker") doc_parser.add_argument( "--rm", @@ -469,11 +527,21 @@ def main(): xla_ref=args.xla_ref, xla_dir=args.xla_dir, test_jax_ref=args.jax_ref, - kernels_jax_dir=args.kernel_jax_dir, + jax_dir=args.jax_dir, rebuild_makefile=args.rebuild_makefile, fix_bazel_symbols=args.fix_bazel_symbols, rocm_path=args.rocm_path, ) + elif args.action == "build": + build_and_install( + xla_ref=args.xla_ref, + xla_dir=args.xla_dir, + test_jax_ref=args.jax_ref, + jax_dir=args.jax_dir, + rebuild_makefile=True, + fix_bazel_symbols=args.fix_bazel_symbols, + rocm_path=args.rocm_path, + ) if __name__ == "__main__":