Skip to content
Merged

Dev #11

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/test.pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ on:

jobs:
build_package:
if: contains(github.event.head_commit.message, 'feat:') || contains(github.event.head_commit.message, 'fix:')
runs-on: ubuntu-latest

steps:
Expand All @@ -30,6 +31,7 @@ jobs:
path: dist/

test-pypi-publish:
if: contains(github.event.head_commit.message, 'feat') || contains(github.event.head_commit.message, 'fix')
runs-on: ubuntu-latest
needs: build_package
permissions:
Expand Down
14 changes: 14 additions & 0 deletions makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
lint:
@uv run ruff check

format:
@uv run ruff format

bump-major:
@uv version --bump major

bump-minor:
@uv version --bump minor

bump-patch:
@uv version --bump patch
8 changes: 7 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "sshsync"
version = "0.11.2"
version = "0.12.0"
description = "sshsync is a CLI tool to run shell commands across multiple servers via SSH, either on specific groups or all servers. It also supports pushing and pulling files to and from remote hosts."
readme = "README.md"
authors = [
Expand All @@ -10,6 +10,7 @@ license = { text = "MIT" }
requires-python = ">=3.10"
dependencies = [
"asyncssh>=2.20.0",
"keyring>=25.6.0",
"pyyaml>=6.0.2",
"rich>=14.0.0",
"sshconf>=0.2.7",
Expand Down Expand Up @@ -38,3 +39,8 @@ exclude = [
"pyrightconfig.json",
"uv.lock",
]

[dependency-groups]
dev = [
"ruff>=0.11.13",
]
16 changes: 16 additions & 0 deletions src/sshsync/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from sshsync.config import Config, ConfigError
from sshsync.schemas import FileTransferAction
from sshsync.utils import (
add_auth,
add_host,
add_hosts_to_group,
assign_groups_to_hosts,
Expand Down Expand Up @@ -167,6 +168,21 @@ def sync():
print_error(e, True)


@app.command(help="Set authentication method for one or more unconfigured hosts.")
def set_auth():
"""
Set authentication method for one or more unconfigured hosts.
"""
try:
config = Config()
hosts = config.get_unconfigured_hosts()
host_auth = add_auth(hosts)
config.save_host_auth(host_auth)
print_message("Authentication methods for hosts have been saved to config")
except Exception as e:
print_error(e, True)


@app.command(help="Push a file to remote hosts using SCP.")
def push(
local_path: str = typer.Argument(
Expand Down
34 changes: 34 additions & 0 deletions src/sshsync/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Literal

import asyncssh
import keyring
import structlog
from rich.progress import Progress

Expand Down Expand Up @@ -158,6 +159,9 @@ def _handle_ssh_exceptions(self, e: Exception, host: Host) -> SSHResult:
self.logger.exception(data["output"], **data)
return SSHResult(**data)

def get_host_pass(self, host: str) -> str | None:
return keyring.get_password("sshsync", host)

async def _execute_command(
self, host: Host, cmd: str, progress: Progress
) -> SSHResult:
Expand All @@ -179,6 +183,16 @@ async def _execute_command(
if host.identity_file and Path(host.identity_file).expanduser().is_file():
if not self._is_key_encrypted(host.identity_file):
conn_kwargs["client_keys"] = [host.identity_file]
else:
host_pass = self.get_host_pass(host.alias)
if host_pass:
host_auth = self.config.config.host_auth.get(host.alias, None)
if host_auth:
if host_auth.auth == "key":
conn_kwargs["client_keys"] = [host.identity_file]
conn_kwargs["passphrase"] = host_pass
else:
conn_kwargs["password"] = host_pass

async with asyncssh.connect(**conn_kwargs) as conn:
progress.update(
Expand Down Expand Up @@ -271,6 +285,16 @@ async def _push(
if host.identity_file and Path(host.identity_file).expanduser().is_file():
if not self._is_key_encrypted(host.identity_file):
conn_kwargs["client_keys"] = [host.identity_file]
else:
host_pass = self.get_host_pass(host.alias)
if host_pass:
host_auth = self.config.config.host_auth.get(host.alias, None)
if host_auth:
if host_auth.auth == "key":
conn_kwargs["client_keys"] = [host.identity_file]
conn_kwargs["passphrase"] = host_pass
else:
conn_kwargs["password"] = host_pass

try:
progress.start_task(task)
Expand Down Expand Up @@ -334,6 +358,16 @@ async def _pull(
if host.identity_file and Path(host.identity_file).expanduser().is_file():
if not self._is_key_encrypted(host.identity_file):
conn_kwargs["client_keys"] = [host.identity_file]
else:
host_pass = self.get_host_pass(host.alias)
if host_pass:
host_auth = self.config.config.host_auth.get(host.alias, None)
if host_auth:
if host_auth.auth == "key":
conn_kwargs["client_keys"] = [host.identity_file]
conn_kwargs["passphrase"] = host_pass
else:
conn_kwargs["password"] = host_pass

try:
progress.start_task(task)
Expand Down
60 changes: 55 additions & 5 deletions src/sshsync/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from sshconf import read_ssh_config

from sshsync.logging import setup_logging
from sshsync.schemas import Host, YamlConfig
from sshsync.schemas import Host, HostAuth, YamlConfig

setup_logging()

Expand Down Expand Up @@ -42,15 +42,27 @@ def __init__(self) -> None:

self.ensure_config_directory_exists()

self.config = self._load_groups()
self.config = self._load_config()

self.configure_ssh_hosts()

def configured_hosts(self):
"""
Return a list of all configured hosts except the default host.

Returns:
list[Host]: List of configured Host objects excluding the default.
"""
return list(filter(lambda x: x.alias != "default", self.hosts))

def _default_config(self) -> YamlConfig:
return YamlConfig(groups=dict())
"""
Return a default YamlConfig object with empty groups and host_auth.

Returns:
YamlConfig: Default configuration object.
"""
return YamlConfig(groups=dict(), host_auth=dict())

def ensure_config_directory_exists(self) -> None:
"""Ensures the config directory and file exist, creating them if necessary."""
Expand All @@ -62,6 +74,16 @@ def ensure_config_directory_exists(self) -> None:
def _resolve_ssh_value(
self, value: str | int | list[str | int] | None, default: str | int = ""
) -> str | int:
"""
Resolve a value from SSH config, handling lists and defaults.

Args:
value (str | int | list[str | int] | None): The value to resolve.
default (str | int): Default value if input is None or empty.

Returns:
str | int: The resolved value.
"""
if isinstance(value, list):
return value[0] if value else default
return value or default
Expand Down Expand Up @@ -125,7 +147,7 @@ def configure_ssh_hosts(self) -> None:

self.hosts = hosts

def _load_groups(self) -> YamlConfig:
def _load_config(self) -> YamlConfig:
"""
Loads configuration from the YAML.

Expand All @@ -142,8 +164,13 @@ def _load_groups(self) -> YamlConfig:
return self._default_config()

groups: dict[str, list[str]] = config.get("groups", dict())
host_auth_data: dict = config.get("host_auth", dict())
host_auth: dict[str, HostAuth] = dict()

for key, value in host_auth_data.items():
host_auth[key] = HostAuth(**value)

return YamlConfig(groups=groups)
return YamlConfig(groups=groups, host_auth=host_auth)

def _save_yaml(self) -> None:
"""Saves the current configuration to the YAML file."""
Expand Down Expand Up @@ -237,6 +264,19 @@ def get_ungrouped_hosts(self) -> list[str]:
if not host.groups and host.alias != "default"
]

def get_unconfigured_hosts(self) -> list[dict[str, str]]:
"""
Get a list of hosts that do not have authentication configured.

Returns:
list[dict[str, str]]: List of dicts with alias and identity_file for unconfigured hosts.
"""
return [
{"alias": host.alias, "identity_file": host.identity_file}
for host in self.hosts
if self.config.host_auth.get(host.alias) is None
]

def assign_groups_to_hosts(self, host_group_mapping: dict[str, list[str]]) -> None:
"""
Assign groups to hosts and update config.
Expand All @@ -253,6 +293,16 @@ def assign_groups_to_hosts(self, host_group_mapping: dict[str, list[str]]) -> No

self._save_yaml()

def save_host_auth(self, host_auth_details: dict[str, HostAuth]) -> None:
"""
Save authentication details for hosts and update the YAML config file.

Args:
host_auth_details (dict[str, HostAuth]): Mapping of host aliases to HostAuth objects.
"""
self.config.host_auth = host_auth_details
self._save_yaml()

def add_new_host(self, host: Host) -> None:
"""
Add or update a host in ~/.ssh/config.
Expand Down
2 changes: 2 additions & 0 deletions src/sshsync/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@


def get_log_path() -> Path:
"""Return the log directory path, creating it if needed."""
home = Path.home()
if sys.platform.startswith("win"):
log_dir = home.joinpath("AppData", "Local", "sshsync", "logs")
Expand All @@ -19,6 +20,7 @@ def get_log_path() -> Path:


def setup_logging():
"""Configure structlog for file logging."""
structlog.configure(
processors=[
structlog.processors.TimeStamper(fmt="iso"),
Expand Down
21 changes: 18 additions & 3 deletions src/sshsync/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,24 +40,39 @@ class Host:
groups: list[str]


@dataclass
class HostAuth:
"""
Authentication method and passphrase info for a host.

Attributes:
auth (Literal["key", "password"]): Authentication type.
key_has_passphrase (bool | None): Whether the key has a passphrase.
"""

auth: Literal["key", "password"]
key_has_passphrase: bool | None


@dataclass
class YamlConfig:
"""
Represents the YAML configuration containing a list of hosts and groups.

Attributes:
groups (dict[str, list[str]]): A mapping of group names to lists of host aliases.
host_auth (dict[str, HostAuth]): Mapping of host aliases to authentication info.
"""

groups: dict[str, list[str]]
host_auth: dict[str, HostAuth]

def as_dict(self) -> dict:
def as_dict(self) -> dict[str, object]:
"""
Converts the `YamlConfig` instance into a dictionary format.

Returns:
dict: A dictionary representation of the `YamlConfig` instance, where keys are attribute names
and values are the corresponding attribute values.
dict: A dictionary representation of the `YamlConfig` instance.
"""
return asdict(self)

Expand Down
Loading
Loading