Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 16 additions & 13 deletions src/tinker/cli/commands/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,6 @@ def _export_checkpoint_to_hub(
create_pr: bool,
exist_ok: bool,
allow_patterns: list[str] | None,
ignore_patterns: list[str] | None,
add_model_card: bool,
) -> str:
# Lazy imports to keep CLI startup fast
Expand Down Expand Up @@ -495,6 +494,19 @@ def _sanitize_repo_name(value: str) -> str:

api.create_repo(repo_id=repo_id, private=private, exist_ok=exist_ok)

# Create the revision/branch if specified and it doesn't exist
if revision:
try:
refs = api.list_repo_refs(repo_id=repo_id)
branch_exists = any(ref.name == revision for ref in refs.branches)
if not branch_exists:
api.create_branch(repo_id=repo_id, branch=revision, exist_ok=True)
except Exception as e:
raise TinkerCliError(
f"Failed to create branch {revision} in repo {repo_id}",
f"Error: {e}",
) from e

def _readme_tinker_path() -> str | None:
try:
readme_file = hf_hub_download(
Expand All @@ -519,10 +531,10 @@ def _readme_tinker_path() -> str | None:
f"Found {existing_tinker_path}, expected {tinker_path}.",
)

# Remove checkpoint_complete file before upload if no allow_patterns specified
if allow_patterns is None:
ignore_patterns = list(ignore_patterns) if ignore_patterns else []
if "checkpoint_complete" not in ignore_patterns:
ignore_patterns.append("checkpoint_complete")
checkpoint_complete_file = extract_dir / "checkpoint_complete"
checkpoint_complete_file.unlink(missing_ok=True)
Comment on lines +536 to +537
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
checkpoint_complete_file = extract_dir / "checkpoint_complete"
checkpoint_complete_file.unlink(missing_ok=True)
checkpoint_complete.unlink(missing_ok=True)


api.upload_folder(
folder_path=os.fspath(extract_dir),
Expand All @@ -532,7 +544,6 @@ def _readme_tinker_path() -> str | None:
commit_message=commit_message,
create_pr=create_pr,
allow_patterns=list(allow_patterns) if allow_patterns else None,
ignore_patterns=list(ignore_patterns) if ignore_patterns else None,
)

return repo_id
Expand Down Expand Up @@ -1003,12 +1014,6 @@ def download(
multiple=True,
help="Only upload files matching this pattern (can be repeated).",
)
@click.option(
"--ignore-pattern",
"ignore_patterns",
multiple=True,
help="Skip files matching this pattern (can be repeated).",
)
@click.option(
"--no-model-card",
is_flag=True,
Expand All @@ -1025,7 +1030,6 @@ def push_hf(
commit_message: str | None,
create_pr: bool,
allow_patterns: tuple[str, ...],
ignore_patterns: tuple[str, ...],
no_model_card: bool,
) -> None:
"""Upload a checkpoint to the Hugging Face Hub as a PEFT adapter.
Expand All @@ -1050,7 +1054,6 @@ def push_hf(
create_pr=create_pr,
exist_ok=True,
allow_patterns=list(allow_patterns) if allow_patterns else None,
ignore_patterns=list(ignore_patterns) if ignore_patterns else None,
add_model_card=not no_model_card,
)

Expand Down