Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 9 additions & 67 deletions comfy_cli/command/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,27 +412,6 @@ def handle_github_rate_limit(response):
raise GitHubRateLimitError(message)


def fetch_github_releases(repo_owner: str, repo_name: str) -> list[dict[str, str]]:
"""
Fetch the list of releases from the GitHub API.
Handles rate limiting by logging the wait time.
"""
url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases"

headers = {}
if github_token := os.getenv("GITHUB_TOKEN"):
headers["Authorization"] = f"Bearer {github_token}"

response = requests.get(url, headers=headers, timeout=5)

# Handle rate limiting
if response.status_code in (403, 429):
handle_github_rate_limit(response)

response.raise_for_status()
return response.json()


class GithubRelease(TypedDict):
"""
A dictionary representing a GitHub release.
Expand All @@ -448,41 +427,6 @@ class GithubRelease(TypedDict):
download_url: str


def parse_releases(releases: list[dict[str, str]]) -> list[GithubRelease]:
"""
Parse the list of releases fetched from the GitHub API into a list of GithubRelease objects.
"""
parsed_releases: list[GithubRelease] = []
for release in releases:
tag = release["tag_name"]
if tag.lower() in ["latest", "nightly"]:
parsed_releases.append({"version": None, "download_url": release["zipball_url"], "tag": tag})
else:
version = semver.VersionInfo.parse(tag.lstrip("v"))
parsed_releases.append({"version": version, "download_url": release["zipball_url"], "tag": tag})

return parsed_releases


def select_version(releases: list[GithubRelease], version: str) -> GithubRelease | None:
"""
Given a list of Github releases, select the release that matches the specified version.
"""
if version.lower() == "latest":
return next((r for r in releases if r["tag"].lower() == version.lower()), None)

version = version.lstrip("v")

try:
requested_version = semver.VersionInfo.parse(version)
return next(
(r for r in releases if isinstance(r["version"], semver.VersionInfo) and r["version"] == requested_version),
None,
)
except ValueError:
return None


def clone_comfyui(url: str, repo_dir: str):
"""
Clone the ComfyUI repository from the specified URL.
Expand All @@ -500,22 +444,19 @@ def checkout_stable_comfyui(version: str, repo_dir: str):
Supports installing stable releases of Comfy (semantic versioning) or the 'latest' version.
"""
rprint(f"Looking for ComfyUI version '{version}'...")
selected_release = None
if version == "latest":
selected_release = get_latest_release("comfyanonymous", "ComfyUI")
if selected_release is None:
rprint(f"Error: No release found for version '{version}'.")
sys.exit(1)
tag = str(selected_release["tag"])
else:
releases = fetch_github_releases("comfyanonymous", "ComfyUI")
parsed_releases = parse_releases(releases)
selected_release = select_version(parsed_releases, version)

if selected_release is None:
rprint(f"Error: No release found for version '{version}'.")
sys.exit(1)
# For specific versions, directly construct the tag (add 'v' prefix if needed)
tag = f"v{version}" if not version.startswith("v") else version

tag = str(selected_release["tag"])
console.print(
Panel(
f"Checking out ComfyUI version: [bold cyan]{selected_release['tag']}[/bold cyan]",
f"Checking out ComfyUI version: [bold cyan]{tag}[/bold cyan]",
title="[yellow]ComfyUI Checkout[/yellow]",
border_style="green",
expand=False,
Expand All @@ -525,7 +466,8 @@ def checkout_stable_comfyui(version: str, repo_dir: str):
with console.status("[bold green]Checking out tag...", spinner="dots"):
success = git_checkout_tag(repo_dir, tag)
if not success:
console.print("\n[bold red]Failed to checkout tag![/bold red]")
console.print(f"\n[bold red]Failed to checkout tag '{tag}'![/bold red]")
console.print("[yellow]The version may not exist. Please check available versions.[/yellow]")
sys.exit(1)


Expand Down
172 changes: 1 addition & 171 deletions tests/comfy_cli/test_install.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,6 @@
from unittest.mock import MagicMock, patch

import pytest
import requests
import semver

from comfy_cli.command.install import (
GithubRelease,
fetch_github_releases,
parse_releases,
select_version,
validate_version,
)
from comfy_cli.command.install import validate_version


def test_validate_version_nightly():
Expand Down Expand Up @@ -39,166 +29,6 @@ def test_validate_version_empty():
validate_version("")


# Tests for fetch_github_releases function
@patch("requests.get")
def test_fetch_releases_success(mock_get):
# Mock the response
mock_response = MagicMock()
mock_response.json.return_value = [{"id": 1, "tag_name": "v1.0.0"}, {"id": 2, "tag_name": "v1.1.0"}]
mock_get.return_value = mock_response

releases = fetch_github_releases("owner", "repo")

assert len(releases) == 2
assert releases[0]["tag_name"] == "v1.0.0"
assert releases[1]["tag_name"] == "v1.1.0"
mock_get.assert_called_once_with("https://api.github.com/repos/owner/repo/releases", headers={}, timeout=5)


@patch("requests.get")
def test_fetch_releases_empty(mock_get):
# Mock an empty response
mock_response = MagicMock()
mock_response.json.return_value = []
mock_get.return_value = mock_response

releases = fetch_github_releases("owner", "repo")

assert len(releases) == 0


@patch("requests.get")
def test_fetch_releases_error(mock_get):
# Mock a request exception
mock_get.side_effect = requests.RequestException("API error")

with pytest.raises(requests.RequestException):
fetch_github_releases("owner", "repo")


def test_parse_releases_with_semver():
input_releases = [
{"tag_name": "v1.2.3", "zipball_url": "https://api.github.com/repos/owner/repo/zipball/v1.2.3"},
{"tag_name": "2.0.0", "zipball_url": "https://api.github.com/repos/owner/repo/zipball/2.0.0"},
]

result = parse_releases(input_releases)

assert len(result) == 2
assert result[0]["version"] == semver.VersionInfo.parse("1.2.3")
assert result[0]["tag"] == "v1.2.3"
assert result[0]["download_url"] == "https://api.github.com/repos/owner/repo/zipball/v1.2.3"
assert result[1]["version"] == semver.VersionInfo.parse("2.0.0")
assert result[1]["tag"] == "2.0.0"


def test_parse_releases_with_special_tags():
input_releases = [
{"tag_name": "latest", "zipball_url": "https://api.github.com/repos/owner/repo/zipball/latest"},
{"tag_name": "nightly", "zipball_url": "https://api.github.com/repos/owner/repo/zipball/nightly"},
]

result = parse_releases(input_releases)

assert len(result) == 2
assert result[0]["version"] is None
assert result[0]["tag"] == "latest"
assert result[1]["version"] is None
assert result[1]["tag"] == "nightly"


def test_parse_releases_mixed():
input_releases = [
{"tag_name": "v1.0.0", "zipball_url": "https://api.github.com/repos/owner/repo/zipball/v1.0.0"},
{"tag_name": "latest", "zipball_url": "https://api.github.com/repos/owner/repo/zipball/latest"},
{"tag_name": "2.0.0-beta", "zipball_url": "https://api.github.com/repos/owner/repo/zipball/2.0.0-beta"},
]

result = parse_releases(input_releases)

assert len(result) == 3
assert result[0]["version"] == semver.VersionInfo.parse("1.0.0")
assert result[1]["version"] is None
assert result[1]["tag"] == "latest"
assert result[2]["version"] == semver.VersionInfo.parse("2.0.0-beta")


def test_parse_releases_empty_list():
input_releases: list[dict[str, str]] = []

result = parse_releases(input_releases)

assert len(result) == 0


def test_parse_releases_invalid_semver():
input_releases = [
{"tag_name": "invalid", "zipball_url": "https://api.github.com/repos/owner/repo/zipball/invalid"},
]

with pytest.raises(ValueError):
parse_releases(input_releases)


# Sample data for tests
sample_releases: list[GithubRelease] = [
{"version": semver.VersionInfo.parse("1.0.0"), "tag": "v1.0.0", "download_url": "url1"},
{"version": semver.VersionInfo.parse("1.1.0"), "tag": "v1.1.0", "download_url": "url2"},
{"version": semver.VersionInfo.parse("2.0.0"), "tag": "v2.0.0", "download_url": "url3"},
{"version": None, "tag": "latest", "download_url": "url_latest"},
{"version": None, "tag": "nightly", "download_url": "url_nightly"},
]


def test_select_version_latest():
result = select_version(sample_releases, "latest")
assert result is not None
assert result["tag"] == "latest"
assert result["download_url"] == "url_latest"


def test_select_version_specific():
result = select_version(sample_releases, "1.1.0")
assert result is not None
assert result["version"] == semver.VersionInfo.parse("1.1.0")
assert result["tag"] == "v1.1.0"


def test_select_version_with_v_prefix():
result = select_version(sample_releases, "v2.0.0")
assert result is not None
assert result["version"] == semver.VersionInfo.parse("2.0.0")
assert result["tag"] == "v2.0.0"


def test_select_version_nonexistent():
result = select_version(sample_releases, "3.0.0")
assert result is None


def test_select_version_invalid():
result = select_version(sample_releases, "invalid_version")
assert result is None


def test_select_version_case_insensitive_latest():
result = select_version(sample_releases, "LATEST")
assert result is not None
assert result["tag"] == "latest"


def test_select_version_nightly():
# Note: This test will fail with the current implementation
# as it doesn't handle "nightly" specifically
result = select_version(sample_releases, "nightly")
assert result is None # or assert result is not None if you want to handle nightly


def test_select_version_empty_list():
result = select_version([], "1.0.0")
assert result is None


# Run the tests
if __name__ == "__main__":
pytest.main([__file__])