diff --git a/.gitignore b/.gitignore index 1a433828..e73e13ba 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +# Claude Code +.claude/ + # Byte-compiled / optimized / DLL files __pycache__/ text_generation_server/__pycache__/ diff --git a/build2cmake/src/config/mod.rs b/build2cmake/src/config/mod.rs index 8d69276a..126a0cb7 100644 --- a/build2cmake/src/config/mod.rs +++ b/build2cmake/src/config/mod.rs @@ -165,6 +165,7 @@ pub enum Kernel { cxx_flags: Option<Vec<String>>, depends: Vec<Dependencies>, include: Option<Vec<String>>, + metal_std_version: Option<String>, src: Vec<String>, }, Rocm { @@ -234,6 +235,15 @@ impl Kernel { | Kernel::Xpu { src, .. } => src, } } + + pub fn metal_std_version(&self) -> Option<&str> { + match self { + Kernel::Metal { + metal_std_version, .. + } => metal_std_version.as_deref(), + _ => None, + } + } } #[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd, Serialize)] diff --git a/build2cmake/src/config/v2.rs b/build2cmake/src/config/v2.rs index 871cb093..22f4ada7 100644 --- a/build2cmake/src/config/v2.rs +++ b/build2cmake/src/config/v2.rs @@ -96,6 +96,7 @@ pub enum Kernel { cxx_flags: Option<Vec<String>>, depends: Vec<Dependencies>, include: Option<Vec<String>>, + metal_std_version: Option<String>, src: Vec<String>, }, #[serde(rename_all = "kebab-case")] @@ -232,11 +233,13 @@ impl From<Kernel> for super::Kernel { cxx_flags, depends, include, + metal_std_version, src, } => super::Kernel::Metal { cxx_flags, depends, include, + metal_std_version, src, }, Kernel::Rocm { diff --git a/build2cmake/src/config/v3.rs b/build2cmake/src/config/v3.rs index 6592c29d..1204f918 100644 --- a/build2cmake/src/config/v3.rs +++ b/build2cmake/src/config/v3.rs @@ -102,6 +102,7 @@ pub enum Kernel { cxx_flags: Option<Vec<String>>, depends: Vec<Dependencies>, include: Option<Vec<String>>, + metal_std_version: Option<String>, src: Vec<String>, }, #[serde(rename_all = "kebab-case")] @@ -261,11 +262,13 @@ impl From<Kernel> for super::Kernel { cxx_flags, depends, include, + metal_std_version, src, } => super::Kernel::Metal { cxx_flags, depends, include, + metal_std_version, src, }, Kernel::Rocm { @@ -425,11 
+428,13 @@ impl From<super::Kernel> for Kernel { cxx_flags, depends, include, + metal_std_version, src, } => Kernel::Metal { cxx_flags, depends, include, + metal_std_version, src, }, super::Kernel::Rocm { diff --git a/build2cmake/src/templates/kernel.cmake b/build2cmake/src/templates/kernel.cmake index 454f3e04..2d2d6985 100644 --- a/build2cmake/src/templates/kernel.cmake +++ b/build2cmake/src/templates/kernel.cmake @@ -231,7 +231,7 @@ endfunction() function(metal_kernel_component SRC_VAR) set(options) - set(oneValueArgs) + set(oneValueArgs METAL_STD_VERSION) set(multiValueArgs SOURCES INCLUDES CXX_FLAGS) cmake_parse_arguments(KERNEL "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) @@ -293,4 +293,9 @@ function(metal_kernel_component SRC_VAR) list(APPEND _TMP_METAL_INCLUDES ${KERNEL_INCLUDES}) set(METAL_INCLUDE_DIRS ${_TMP_METAL_INCLUDES} PARENT_SCOPE) endif() + + # Propagate Metal std version to parent scope for compile_metal_shaders + if(KERNEL_METAL_STD_VERSION) + set(METAL_STD_VERSION ${KERNEL_METAL_STD_VERSION} PARENT_SCOPE) + endif() endfunction() diff --git a/build2cmake/src/templates/metal/compile-metal.cmake b/build2cmake/src/templates/metal/compile-metal.cmake index 50d44a2d..2021515c 100644 --- a/build2cmake/src/templates/metal/compile-metal.cmake +++ b/build2cmake/src/templates/metal/compile-metal.cmake @@ -17,8 +17,14 @@ function(compile_metal_shaders TARGET_NAME METAL_SOURCES EXTRA_INCLUDE_DIRS) set(METAL_TOOLCHAIN "${CMAKE_MATCH_1}/Metal.xctoolchain") endif() - # Set Metal compiler flags - set(METAL_FLAGS "-std=metal4.0" "-O2") + # Set Metal compiler flags. 
+ # metal3.1 → air64_v26, macOS 14+ + # metal3.2 → air64_v27, macOS 15+ + # metal4.0 → air64_v28, macOS 26+ + if(NOT DEFINED METAL_STD_VERSION) + set(METAL_STD_VERSION "metal4.0") + endif() + set(METAL_FLAGS "-std=${METAL_STD_VERSION}" "-O2") # Output directory for compiled metallib set(METALLIB_OUTPUT_DIR "${CMAKE_BINARY_DIR}/metallib") diff --git a/build2cmake/src/templates/metal/kernel.cmake b/build2cmake/src/templates/metal/kernel.cmake index 6c198311..98139401 100644 --- a/build2cmake/src/templates/metal/kernel.cmake +++ b/build2cmake/src/templates/metal/kernel.cmake @@ -3,5 +3,6 @@ if(GPU_LANG STREQUAL "METAL") SOURCES {{ sources }} {% if includes %}INCLUDES "{{ includes }}"{% endif %} {% if cxx_flags %}CXX_FLAGS "{{ cxx_flags }}"{% endif %} + {% if metal_std_version %}METAL_STD_VERSION "{{ metal_std_version }}"{% endif %} ) endif() diff --git a/build2cmake/src/torch/kernel.rs b/build2cmake/src/torch/kernel.rs index 7a872bfa..aeae991f 100644 --- a/build2cmake/src/torch/kernel.rs +++ b/build2cmake/src/torch/kernel.rs @@ -148,6 +148,7 @@ fn render_kernel_component_metal( cxx_flags => kernel.cxx_flags().map(|flags| flags.join(";")), includes => kernel.include().map(prefix_and_join_includes), kernel_name => kernel_name, + metal_std_version => kernel.metal_std_version(), sources => sources, }, &mut *write, diff --git a/builder/examples/relu-metal-cpp/build.toml b/builder/examples/relu-metal-cpp/build.toml index 8cf012bc..9e3b5326 100644 --- a/builder/examples/relu-metal-cpp/build.toml +++ b/builder/examples/relu-metal-cpp/build.toml @@ -11,6 +11,7 @@ src = [ [kernel.relu_metal] backend = "metal" +metal-std-version = "metal3.1" src = [ "relu/relu.cpp", "relu/metallib_loader.mm", diff --git a/flake.nix b/flake.nix index 3496b514..9a8db90f 100644 --- a/flake.nix +++ b/flake.nix @@ -90,7 +90,7 @@ # fail in a GPU-less sandbox. Even in that case, it's better to lazily # load the part with this functionality. doGetKernelCheck ? true, - pythonCheckInputs ? 
pkgs: [ ], + pythonCheckInputs ? pkgs: [ pkgs.kernels-test-utils ], pythonNativeCheckInputs ? pkgs: [ ], torchVersions ? _: torchVersions', }: diff --git a/kernels-test-utils/pyproject.toml b/kernels-test-utils/pyproject.toml new file mode 100644 index 00000000..56c72f56 --- /dev/null +++ b/kernels-test-utils/pyproject.toml @@ -0,0 +1,9 @@ +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +[project] +name = "kernels-test-utils" +version = "0.1.0" +requires-python = ">=3.10" +dependencies = ["pytest", "torch"] diff --git a/kernels-test-utils/src/kernels_test_utils/__init__.py b/kernels-test-utils/src/kernels_test_utils/__init__.py new file mode 100644 index 00000000..68e0ecb2 --- /dev/null +++ b/kernels-test-utils/src/kernels_test_utils/__init__.py @@ -0,0 +1,14 @@ +"""Shared test utilities for kernel repos.""" + +from kernels_test_utils.allclose import fp8_allclose +from kernels_test_utils.device import get_available_devices, get_device, skip_if_no_gpu +from kernels_test_utils.tolerances import DEFAULT_TOLERANCES, get_tolerances + +__all__ = [ + "fp8_allclose", + "get_available_devices", + "get_device", + "get_tolerances", + "skip_if_no_gpu", + "DEFAULT_TOLERANCES", +] diff --git a/kernels-test-utils/src/kernels_test_utils/allclose.py b/kernels-test-utils/src/kernels_test_utils/allclose.py new file mode 100644 index 00000000..932301dc --- /dev/null +++ b/kernels-test-utils/src/kernels_test_utils/allclose.py @@ -0,0 +1,32 @@ +"""Allclose variants that work around device limitations.""" + +import torch +from torch._prims_common import TensorLikeType + + +def fp8_allclose( + a: TensorLikeType, + b: TensorLikeType, + rtol: float = 1e-05, + atol: float = 1e-08, + equal_nan: bool = False, +) -> bool: + """``torch.allclose`` replacement that handles FP8 types and MPS. + + On MPS (which lacks float64) the comparison is done in float32. + Everywhere else the tensors are promoted to float64. 
+ """ + torch._refs._check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol) + + if a.device.type == "mps" or b.device.type == "mps": + a_cmp = a.float() + b_cmp = b.float() + else: + a_cmp = a.double() + b_cmp = b.double() + + return bool( + torch.all( + torch.isclose(a_cmp, b_cmp, rtol=rtol, atol=atol, equal_nan=equal_nan) + ).item() + ) diff --git a/kernels-test-utils/src/kernels_test_utils/device.py b/kernels-test-utils/src/kernels_test_utils/device.py new file mode 100644 index 00000000..3a244b8b --- /dev/null +++ b/kernels-test-utils/src/kernels_test_utils/device.py @@ -0,0 +1,41 @@ +"""Device detection utilities for kernel tests.""" + +from typing import List + +import pytest +import torch + + +def get_device() -> torch.device: + """Return the best available compute device (MPS > CUDA > XPU > CPU).""" + if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + return torch.device("mps") + if torch.cuda.is_available(): + return torch.device("cuda") + if hasattr(torch, "xpu") and torch.xpu.is_available(): + return torch.device("xpu") + return torch.device("cpu") + + +def get_available_devices() -> List[str]: + """Return device strings suitable for pytest parametrization. + + On MPS: ``["mps"]`` + On CUDA: ``["cuda:0", "cuda:1", ...]`` for each visible GPU. + On XPU: ``["xpu:0", "xpu:1", ...]`` for each visible accelerator. 
+ Fallback: ``["cpu"]`` + """ + if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + return ["mps"] + if torch.cuda.is_available(): + return [f"cuda:{i}" for i in range(max(1, torch.cuda.device_count()))] + if hasattr(torch, "xpu") and torch.xpu.is_available(): + return [f"xpu:{i}" for i in range(max(1, torch.xpu.device_count()))] + return ["cpu"] + + +def skip_if_no_gpu() -> None: + """Call inside a test to skip when no GPU is available.""" + dev = get_device() + if dev.type == "cpu": + pytest.skip("No GPU device available") diff --git a/kernels-test-utils/src/kernels_test_utils/tolerances.py b/kernels-test-utils/src/kernels_test_utils/tolerances.py new file mode 100644 index 00000000..780f1820 --- /dev/null +++ b/kernels-test-utils/src/kernels_test_utils/tolerances.py @@ -0,0 +1,19 @@ +"""Default tolerance tables for kernel tests.""" + +from typing import Dict + +import torch + +DEFAULT_TOLERANCES: Dict[torch.dtype, Dict[str, float]] = { + torch.float32: {"atol": 1e-5, "rtol": 1e-5}, + torch.float16: {"atol": 1e-3, "rtol": 1e-3}, + torch.bfloat16: {"atol": 1e-2, "rtol": 1.6e-2}, +} + + +def get_tolerances(dtype: torch.dtype) -> Dict[str, float]: + """Return ``{"atol": ..., "rtol": ...}`` for *dtype*. + + Falls back to ``atol=0.1, rtol=0.1`` for unknown dtypes. 
+ """ + return DEFAULT_TOLERANCES.get(dtype, {"atol": 0.1, "rtol": 0.1}) diff --git a/nix/overlay.nix b/nix/overlay.nix index 8beb2971..ff76aa1a 100644 --- a/nix/overlay.nix +++ b/nix/overlay.nix @@ -83,6 +83,8 @@ in kernels = callPackage ./pkgs/python-modules/kernels { }; + kernels-test-utils = callPackage ./pkgs/python-modules/kernels-test-utils { }; + pyclibrary = python-self.callPackage ./pkgs/python-modules/pyclibrary { }; mkTorch = callPackage ./pkgs/python-modules/torch/binary { }; diff --git a/nix/pkgs/python-modules/kernels-test-utils/default.nix b/nix/pkgs/python-modules/kernels-test-utils/default.nix new file mode 100644 index 00000000..7332e0ec --- /dev/null +++ b/nix/pkgs/python-modules/kernels-test-utils/default.nix @@ -0,0 +1,42 @@ +{ + lib, + buildPythonPackage, + setuptools, + + pytest, + torch, +}: + +let + version = + (builtins.fromTOML (builtins.readFile ../../../../kernels-test-utils/pyproject.toml)).project.version; +in +buildPythonPackage { + pname = "kernels-test-utils"; + inherit version; + pyproject = true; + + src = + let + sourceFiles = file: file.hasExt "toml" || file.hasExt "py"; + in + lib.fileset.toSource { + root = ../../../../kernels-test-utils; + fileset = lib.fileset.fileFilter sourceFiles ../../../../kernels-test-utils; + }; + + build-system = [ setuptools ]; + + dependencies = [ + pytest + torch + ]; + + pythonImportsCheck = [ + "kernels_test_utils" + ]; + + meta = with lib; { + description = "Shared test utilities for kernel repos"; + }; +} diff --git a/template/tests/test___KERNEL_NAME_NORMALIZED__.py b/template/tests/test___KERNEL_NAME_NORMALIZED__.py index d7d02e3e..702949a2 100644 --- a/template/tests/test___KERNEL_NAME_NORMALIZED__.py +++ b/template/tests/test___KERNEL_NAME_NORMALIZED__.py @@ -1,19 +1,12 @@ -import platform - import torch +from kernels_test_utils import get_device + import __KERNEL_NAME_NORMALIZED__ def test___KERNEL_NAME_NORMALIZED__(): - if platform.system() == "Darwin": - device = torch.device("mps") - 
elif hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device("xpu") - elif torch.version.cuda is not None and torch.cuda.is_available(): - device = torch.device("cuda") - else: - device = torch.device("cpu") + device = get_device() x = torch.randn(1024, 1024, dtype=torch.float32, device=device) expected = x + 1.0