Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions kernels/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,13 @@ dev = [
"pytest-benchmark",
"torch>=2.5",
"types-pyyaml",
"types-requests",
"types-tabulate",
]

[project.optional-dependencies]
abi-check = ["kernel-abi-check>=0.6.2,<0.7.0"]
benchmark = [
"numpy>=2.0.2",
"requests>=2.32.5",
"tabulate>=0.9.0",
"torch",
]
Expand Down
15 changes: 3 additions & 12 deletions kernels/src/kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

__version__ = importlib.metadata.version("kernels")

from kernels._windows import _add_additional_dll_paths
from kernels.benchmark import Benchmark
from kernels.layer import (
CUDAProperties,
Device,
Expand All @@ -19,18 +21,7 @@
use_kernel_func_from_hub,
use_kernel_mapping,
)
from kernels.utils import (
get_kernel,
get_local_kernel,
get_locked_kernel,
has_kernel,
install_kernel,
load_kernel,
)
from kernels.benchmark import Benchmark


from kernels._windows import _add_additional_dll_paths
from kernels.utils import get_kernel, get_local_kernel, get_locked_kernel, has_kernel, install_kernel, load_kernel

_add_additional_dll_paths()

Expand Down
9 changes: 6 additions & 3 deletions kernels/src/kernels/_versions.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import warnings

from huggingface_hub import HfApi
from huggingface_hub.hf_api import GitRefInfo
from packaging.specifiers import SpecifierSet
from packaging.version import InvalidVersion, Version


def _get_available_versions(repo_id: str) -> dict[int, GitRefInfo]:
"""Get kernel versions that are available in the repository."""
from kernels.utils import _get_hf_api

versions = {}
for branch in HfApi().list_repo_refs(repo_id).branches:
for branch in _get_hf_api().list_repo_refs(repo_id).branches:
if not branch.name.startswith("v"):
continue
try:
Expand All @@ -26,8 +27,10 @@ def _get_available_versions_old(repo_id: str) -> dict[Version, GitRefInfo]:

This is for the old tag-based versioning scheme.
"""
from kernels.utils import _get_hf_api

versions = {}
for tag in HfApi().list_repo_refs(repo_id).tags:
for tag in _get_hf_api().list_repo_refs(repo_id).tags:
if not tag.name.startswith("v"):
continue
try:
Expand Down
33 changes: 8 additions & 25 deletions kernels/src/kernels/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,12 @@
from pathlib import Path
from typing import Any

from huggingface_hub import get_token, snapshot_download
from huggingface_hub.utils import disable_progress_bars
from huggingface_hub.utils import build_hf_headers, disable_progress_bars, get_session, hf_raise_for_status

from kernels.utils import backend
from kernels.utils import _get_hf_api, backend

MISSING_DEPS: list[str] = []

try:
import requests
except ImportError:
requests = None # type: ignore[assignment]
MISSING_DEPS.append("requests")

try:
import torch

Expand Down Expand Up @@ -692,25 +685,15 @@ def submit_benchmark(
repo_id: str,
result: BenchmarkResult,
) -> None:
token = get_token()
if token is None:
raise ValueError(
"No HuggingFace token. Run `huggingface-cli login` or set HF_TOKEN"
)

# TODO: follow up on API design for benchmark submission
endpoint = f"https://huggingface.co/api/kernels/{repo_id}/benchmarks"
response = requests.post(
endpoint,
response = get_session().post(
f"https://huggingface.co/api/kernels/{repo_id}/benchmarks",
json=result.to_payload(),
headers={
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
},
headers=build_hf_headers(headers={"Content-Type": "application/json"}),
)
if not response.ok:
if response.status_code != 200:
print(f"Error {response.status_code}: {response.text}", file=sys.stderr)
response.raise_for_status()
hf_raise_for_status(response)


def run_benchmark(
Expand Down Expand Up @@ -756,7 +739,7 @@ def run_benchmark(
assert revision is not None # Guaranteed by parsing logic above

print(f"Downloading {repo_id}@{revision}...", file=sys.stderr)
repo_path = Path(snapshot_download(repo_id=repo_id, revision=revision))
repo_path = Path(str(_get_hf_api().snapshot_download(repo_id=repo_id, revision=revision)))

scripts = discover_benchmark_scripts(repo_id, repo_path)

Expand Down
6 changes: 1 addition & 5 deletions kernels/src/kernels/benchmarks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
from .activation import SiluAndMulBenchmark
from .attention import (
FlashAttentionBenchmark,
FlashAttentionCausalBenchmark,
FlashAttentionVarlenBenchmark,
)
from .attention import FlashAttentionBenchmark, FlashAttentionCausalBenchmark, FlashAttentionVarlenBenchmark
from .layer_norm import LayerNormBenchmark, RMSNormBenchmark

__all__ = [
Expand Down
15 changes: 8 additions & 7 deletions kernels/src/kernels/check.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import sys
from pathlib import Path

from huggingface_hub import snapshot_download
from kernel_abi_check import ( # type: ignore[import-not-found]
BinaryFormat,
IncompatibleAbi3Symbol,
Expand All @@ -12,19 +11,21 @@
ObjectFile,
)

from kernels.utils import CACHE_DIR
from kernels.utils import CACHE_DIR, _get_hf_api


def check_kernel(
*, macos: str, manylinux: str, python_abi: str, repo_id: str, revision: str
):
variants_path = (
Path(
snapshot_download(
repo_id,
allow_patterns=["build/*"],
cache_dir=CACHE_DIR,
revision=revision,
str(
_get_hf_api().snapshot_download(
repo_id,
allow_patterns=["build/*"],
cache_dir=CACHE_DIR,
revision=revision,
)
)
)
/ "build"
Expand Down
13 changes: 2 additions & 11 deletions kernels/src/kernels/layer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,6 @@
from .device import CUDAProperties, Device
from .func import (
FuncRepository,
LocalFuncRepository,
LockedFuncRepository,
use_kernel_func_from_hub,
)
from .kernelize import (
kernelize,
register_kernel_mapping,
use_kernel_mapping,
)
from .func import FuncRepository, LocalFuncRepository, LockedFuncRepository, use_kernel_func_from_hub
from .kernelize import kernelize, register_kernel_mapping, use_kernel_mapping
from .layer import (
LayerRepository,
LocalLayerRepository,
Expand Down
10 changes: 3 additions & 7 deletions kernels/src/kernels/layer/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,16 @@
from kernels.layer.repos import RepositoryProtocol

from .._versions import select_revision_or_version
from ..utils import (
_get_caller_locked_kernel,
_get_locked_kernel,
get_kernel,
get_local_kernel,
)
from ..utils import _get_caller_locked_kernel, _get_locked_kernel, get_kernel, get_local_kernel

if TYPE_CHECKING:
from torch import nn


class FuncRepositoryProtocol(RepositoryProtocol, Protocol):
@property
def func_name(self) -> str: ...
def func_name(self) -> str:
...


class FuncRepository:
Expand Down
5 changes: 2 additions & 3 deletions kernels/src/kernels/layer/kernelize.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@
from copy import deepcopy
from typing import TYPE_CHECKING

from .repos import DeviceRepos
from .device import Device
from .globals import _KERNEL_MAPPING
from .layer import kernelize_layer
from .repos import RepositoryProtocol
from .mode import Mode
from .device import Device
from .repos import DeviceRepos, RepositoryProtocol

if TYPE_CHECKING:
import torch
Expand Down
10 changes: 3 additions & 7 deletions kernels/src/kernels/layer/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,7 @@
from typing import TYPE_CHECKING, Protocol, Type

from .._versions import select_revision_or_version
from ..utils import (
_get_caller_locked_kernel,
_get_locked_kernel,
get_kernel,
get_local_kernel,
)
from ..utils import _get_caller_locked_kernel, _get_locked_kernel, get_kernel, get_local_kernel
from .device import Device
from .globals import _DISABLE_KERNEL_MAPPING, _KERNEL_MAPPING
from .mode import Mode
Expand All @@ -26,7 +21,8 @@

class LayerRepositoryProtocol(RepositoryProtocol, Protocol):
@property
def layer_name(self) -> str: ...
def layer_name(self) -> str:
...


class LayerRepository:
Expand Down
15 changes: 8 additions & 7 deletions kernels/src/kernels/layer/repos.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Protocol, Type
import sys
from abc import ABC, abstractmethod
from functools import lru_cache
from typing import TYPE_CHECKING, Protocol, Type

from .device import Device
from .mode import Mode
from ._interval_tree import IntervalTree
from .device import CUDAProperties, ROCMProperties
from .device import CUDAProperties, Device, ROCMProperties
from .mode import Mode

if TYPE_CHECKING:
from torch import nn


class RepositoryProtocol(Protocol):
def load(self) -> Type["nn.Module"]: ...
def load(self) -> Type["nn.Module"]:
...


class DeviceRepos(ABC):
Expand Down Expand Up @@ -43,7 +43,8 @@ def create_repo(device: Device) -> "DeviceRepos":
@abstractmethod
def repos(
self,
) -> dict[Mode, RepositoryProtocol] | None: ...
) -> dict[Mode, RepositoryProtocol] | None:
...

@abstractmethod
def insert(self, device: Device, repos: dict[Mode, RepositoryProtocol]):
Expand Down
6 changes: 3 additions & 3 deletions kernels/src/kernels/lockfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
from dataclasses import dataclass
from pathlib import Path

from huggingface_hub import HfApi

from kernels._versions import resolve_version_spec_as_ref
from kernels.compat import tomllib

Expand Down Expand Up @@ -35,9 +33,11 @@ def get_kernel_locks(repo_id: str, version_spec: int | str) -> KernelLock:
The version specifier can be any valid Python version specifier:
https://packaging.python.org/en/latest/specifications/version-specifiers/#version-specifiers
"""
from kernels.utils import _get_hf_api

tag_for_newest = resolve_version_spec_as_ref(repo_id, version_spec)

r = HfApi().repo_info(
r = _get_hf_api().repo_info(
repo_id=repo_id, revision=tag_for_newest.target_commit, files_metadata=True
)
if r.sha is None:
Expand Down
12 changes: 6 additions & 6 deletions kernels/src/kernels/upload.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from pathlib import Path

from huggingface_hub import create_branch, create_repo, upload_folder

from kernels.metadata import Metadata
from kernels.utils import _get_hf_api
from kernels.variants import BUILD_VARIANT_REGEX


Expand All @@ -13,6 +12,7 @@ def upload_kernels_dir(
branch: str | None,
private: bool,
):
api = _get_hf_api()
kernel_dir = Path(kernel_dir).resolve()

build_dir = None
Expand Down Expand Up @@ -48,10 +48,10 @@ def upload_kernels_dir(
if version is not None:
branch = f"v{version}"

repo_id = create_repo(repo_id=repo_id, private=private, exist_ok=True).repo_id
repo_id = api.create_repo(repo_id=repo_id, private=private, exist_ok=True).repo_id

if branch is not None:
create_branch(repo_id=repo_id, branch=branch, exist_ok=True)
api.create_branch(repo_id=repo_id, branch=branch, exist_ok=True)

delete_patterns: set[str] = set()
for build_variant in build_dir.iterdir():
Expand All @@ -60,7 +60,7 @@ def upload_kernels_dir(

# in the case we have variants, upload to the same as the kernel_dir
if (kernel_dir / "benchmarks").is_dir():
upload_folder(
api.upload_folder(
repo_id=repo_id,
folder_path=kernel_dir / "benchmarks",
revision=branch,
Expand All @@ -70,7 +70,7 @@ def upload_kernels_dir(
allow_patterns=["benchmark*.py"],
)

upload_folder(
api.upload_folder(
repo_id=repo_id,
folder_path=build_dir,
revision=branch,
Expand Down
Loading
Loading