diff --git a/cmd/upgrade/autodownload.go b/cmd/upgrade/autodownload.go new file mode 100644 index 0000000..f9cc968 --- /dev/null +++ b/cmd/upgrade/autodownload.go @@ -0,0 +1,196 @@ +package upgrade + +import ( + "context" + "errors" + "fmt" + "log/slog" + "os" + "path/filepath" + "strings" + "syscall" + "time" + + "hostlink/app/services/agentstate" + "hostlink/app/services/requestsigner" + "hostlink/app/services/updatecheck" + "hostlink/app/services/updatedownload" + "hostlink/config/appconf" + "hostlink/internal/httpclient" + "hostlink/internal/update" +) + +// Sentinel errors for auto-download failures +var ( + ErrUpdateCheckFailed = errors.New("update check failed") + ErrDownloadFailed = errors.New("download failed") + ErrExtractFailed = errors.New("extract failed") +) + +// UpdateCheckerInterface abstracts the update check client for testing. +type UpdateCheckerInterface interface { + Check() (*updatecheck.UpdateInfo, error) +} + +// DownloaderInterface abstracts the download functionality for testing. +type DownloaderInterface interface { + DownloadAndVerify(ctx context.Context, url, destPath, sha256 string) error +} + +// ExtractorInterface abstracts the tarball extraction for testing. +type ExtractorInterface interface { + Extract(tarPath, destPath string) error +} + +// AutoDownloader handles automatic download of the latest version. +type AutoDownloader struct { + UpdateChecker UpdateCheckerInterface + Downloader DownloaderInterface + Extractor ExtractorInterface + StagingDir string +} + +// DownloadLatestIfNeeded checks for updates and downloads the latest version if available. +// Returns the path to the staged binary, or empty string if no update is available. +func (ad *AutoDownloader) DownloadLatestIfNeeded(ctx context.Context) (string, error) { + // Check for updates + info, err := ad.UpdateChecker.Check() + if err != nil { + return "", fmt.Errorf("%w: %v", ErrUpdateCheckFailed, err) + } + + if !info.UpdateAvailable { + return "", nil + } + + // Create staging directory if it doesn't exist + if err := os.MkdirAll(ad.StagingDir, update.DirPermissions); err != nil { + return "", fmt.Errorf("failed to create staging directory: %w", err) + } + + // Download tarball + tarballPath := filepath.Join(ad.StagingDir, "hostlink.tar.gz") + if err := ad.Downloader.DownloadAndVerify(ctx, info.AgentURL, tarballPath, info.AgentSHA256); err != nil { + return "", fmt.Errorf("%w: %v", ErrDownloadFailed, err) + } + + // Extract binary + binaryPath := filepath.Join(ad.StagingDir, "hostlink") + if err := ad.Extractor.Extract(tarballPath, binaryPath); err != nil { + return "", fmt.Errorf("%w: %v", ErrExtractFailed, err) + } + + return binaryPath, nil +} + +// IsManualInvocation returns true if the upgrade command was invoked manually +// (e.g., from /usr/bin/hostlink) rather than spawned by selfupdatejob from staging. +func IsManualInvocation(selfPath, installPath, stagingDir string) bool { + // Normalize all paths to handle trailing slashes, .., etc. + selfPath = filepath.Clean(selfPath) + installPath = filepath.Clean(installPath) + stagingDir = filepath.Clean(stagingDir) + + // If running from the install path, it's manual + if selfPath == installPath { + return true + } + + // If running from staging directory, it's spawned by selfupdatejob + // Add separator to ensure directory boundary matching + // e.g., /var/lib/staging should NOT match /var/lib/staging-test + stagingDirWithSep := stagingDir + string(filepath.Separator) + if strings.HasPrefix(selfPath, stagingDirWithSep) { + return false + } + + // Any other path is considered manual (e.g., /tmp/hostlink for testing) + return true +} + +// realDownloader wraps updatedownload.Downloader to implement DownloaderInterface. +type realDownloader struct { + d *updatedownload.Downloader +} + +func (r *realDownloader) DownloadAndVerify(ctx context.Context, url, destPath, sha256 string) error { + _, err := r.d.DownloadAndVerify(ctx, url, destPath, sha256) + return err +} + +// realExtractor wraps update.InstallBinary to implement ExtractorInterface. +type realExtractor struct{} + +func (r *realExtractor) Extract(tarPath, destPath string) error { + return update.InstallBinary(tarPath, destPath) +} + +// NewAutoDownloaderConfig holds configuration for creating an AutoDownloader. +type NewAutoDownloaderConfig struct { + StagingDir string + Logger *slog.Logger +} + +// NewAutoDownloader creates an AutoDownloader with real dependencies. +// It loads agent ID and control plane URL from the environment/config files. +func NewAutoDownloader(cfg NewAutoDownloaderConfig) (*AutoDownloader, error) { + // Load agent ID from state file + state := agentstate.New(appconf.AgentStatePath()) + if err := state.Load(); err != nil { + return nil, fmt.Errorf("failed to load agent state: %w (is the agent registered?)", err) + } + agentID := state.GetAgentID() + if agentID == "" { + return nil, errors.New("agent not registered: run hostlink first to register") + } + + // Get control plane URL + controlPlaneURL := appconf.ControlPlaneURL() + if controlPlaneURL == "" { + return nil, errors.New("control plane URL not configured") + } + + // Create request signer + signer, err := requestsigner.New(appconf.AgentPrivateKeyPath(), agentID) + if err != nil { + return nil, fmt.Errorf("failed to create request signer: %w", err) + } + + // Create HTTP client with agent headers + client := httpclient.NewClient(30 * time.Second) + + // Create update checker + checker, err := updatecheck.New(client, controlPlaneURL, agentID, signer) + if err != nil { + return nil, fmt.Errorf("failed to create update checker: %w", err) + } + + // Create downloader + downloader := updatedownload.NewDownloader(updatedownload.DefaultDownloadConfig()) + + if cfg.Logger != nil { + cfg.Logger.Info("auto-downloader initialized", + "agent_id", agentID, + "control_plane_url", controlPlaneURL, + ) + } + + return &AutoDownloader{ + UpdateChecker: checker, + Downloader: &realDownloader{d: downloader}, + Extractor: &realExtractor{}, + StagingDir: cfg.StagingDir, + }, nil +} + +// ExecStagedBinary replaces the current process with the staged binary. +// This is used to hand off execution to the newly downloaded binary. +// The function never returns on success (process is replaced). +func ExecStagedBinary(stagedBinary string, args []string) error { + // Prepend the binary path as argv[0] + argv := append([]string{stagedBinary}, args...) + + // Replace current process with the staged binary + // This never returns on success + return syscall.Exec(stagedBinary, argv, os.Environ()) +} diff --git a/cmd/upgrade/autodownload_test.go b/cmd/upgrade/autodownload_test.go new file mode 100644 index 0000000..a92132c --- /dev/null +++ b/cmd/upgrade/autodownload_test.go @@ -0,0 +1,239 @@ +package upgrade + +import ( + "context" + "errors" + "os" + "path/filepath" + "testing" + + "hostlink/app/services/updatecheck" +) + +func TestIsManualInvocation(t *testing.T) { + tests := []struct { + name string + selfPath string + installPath string + stagingDir string + want bool + }{ + { + name: "running from install path is manual", + selfPath: "/usr/bin/hostlink", + installPath: "/usr/bin/hostlink", + stagingDir: "/var/lib/hostlink/updates/staging", + want: true, + }, + { + name: "running from staging dir is spawned", + selfPath: "/var/lib/hostlink/updates/staging/hostlink", + installPath: "/usr/bin/hostlink", + stagingDir: "/var/lib/hostlink/updates/staging", + want: false, + }, + { + name: "running from different path is manual", + selfPath: "/tmp/hostlink", + installPath: "/usr/bin/hostlink", + stagingDir: "/var/lib/hostlink/updates/staging", + want: true, + }, + { + name: "running from custom staging dir is spawned", + selfPath: "/custom/staging/hostlink", + installPath: "/usr/bin/hostlink", + stagingDir: "/custom/staging", + want: false, + }, + { + name: "staging dir with similar prefix is manual", + selfPath: "/var/lib/hostlink/updates/staging-test/hostlink", + installPath: "/usr/bin/hostlink", + stagingDir: "/var/lib/hostlink/updates/staging", + want: true, + }, + { + name: "staging dir with trailing slash normalizes", + selfPath: "/var/lib/hostlink/updates/staging/hostlink", + installPath: "/usr/bin/hostlink", + stagingDir: "/var/lib/hostlink/updates/staging/", + want: false, + }, + { + name: "install path with dot-dot normalizes", + selfPath: "/usr/bin/../bin/hostlink", + installPath: "/usr/bin/hostlink", + stagingDir: "/var/lib/hostlink/updates/staging", + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := IsManualInvocation(tt.selfPath, tt.installPath, tt.stagingDir) + if got != tt.want { + t.Errorf("IsManualInvocation(%q, %q, %q) = %v, want %v", + tt.selfPath, tt.installPath, tt.stagingDir, got, tt.want) + } + }) + } +} + +// Mock implementations for testing + +type mockUpdateChecker struct { + info *updatecheck.UpdateInfo + err error +} + +func (m *mockUpdateChecker) Check() (*updatecheck.UpdateInfo, error) { + return m.info, m.err +} + +type mockDownloader struct { + err error +} + +func (m *mockDownloader) DownloadAndVerify(ctx context.Context, url, destPath, sha256 string) error { + if m.err != nil { + return m.err + } + // Create a dummy file to simulate download + return os.WriteFile(destPath, []byte("dummy binary"), 0755) +} + +type mockExtractor struct { + err error +} + +func (m *mockExtractor) Extract(tarPath, destPath string) error { + if m.err != nil { + return m.err + } + // Create a dummy binary to simulate extraction + return os.WriteFile(destPath, []byte("extracted binary"), 0755) +} + +func TestAutoDownloader_NoUpdateAvailable(t *testing.T) { + tmpDir := t.TempDir() + stagingDir := filepath.Join(tmpDir, "staging") + + ad := &AutoDownloader{ + UpdateChecker: &mockUpdateChecker{ + info: &updatecheck.UpdateInfo{UpdateAvailable: false}, + }, + StagingDir: stagingDir, + } + + stagedPath, err := ad.DownloadLatestIfNeeded(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if stagedPath != "" { + t.Errorf("expected empty staged path when no update, got %q", stagedPath) + } +} + +func TestAutoDownloader_UpdateAvailable_Downloads(t *testing.T) { + tmpDir := t.TempDir() + stagingDir := filepath.Join(tmpDir, "staging") + + ad := &AutoDownloader{ + UpdateChecker: &mockUpdateChecker{ + info: &updatecheck.UpdateInfo{ + UpdateAvailable: true, + TargetVersion: "1.0.0", + AgentURL: "https://example.com/hostlink.tar.gz", + AgentSHA256: "abc123", + }, + }, + Downloader: &mockDownloader{}, + Extractor: &mockExtractor{}, + StagingDir: stagingDir, + } + + stagedPath, err := ad.DownloadLatestIfNeeded(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + expectedPath := filepath.Join(stagingDir, "hostlink") + if stagedPath != expectedPath { + t.Errorf("staged path = %q, want %q", stagedPath, expectedPath) + } + + // Verify the binary was created + if _, err := os.Stat(stagedPath); os.IsNotExist(err) { + t.Errorf("staged binary does not exist at %q", stagedPath) + } +} + +func TestAutoDownloader_CheckError(t *testing.T) { + ad := &AutoDownloader{ + UpdateChecker: &mockUpdateChecker{ + err: errors.New("network error"), + }, + StagingDir: t.TempDir(), + } + + _, err := ad.DownloadLatestIfNeeded(context.Background()) + if err == nil { + t.Fatal("expected error, got nil") + } + if !errors.Is(err, ErrUpdateCheckFailed) { + t.Errorf("expected ErrUpdateCheckFailed, got %v", err) + } +} + +func TestAutoDownloader_DownloadError(t *testing.T) { + tmpDir := t.TempDir() + stagingDir := filepath.Join(tmpDir, "staging") + + ad := &AutoDownloader{ + UpdateChecker: &mockUpdateChecker{ + info: &updatecheck.UpdateInfo{ + UpdateAvailable: true, + TargetVersion: "1.0.0", + AgentURL: "https://example.com/hostlink.tar.gz", + AgentSHA256: "abc123", + }, + }, + Downloader: &mockDownloader{err: errors.New("download failed")}, + StagingDir: stagingDir, + } + + _, err := ad.DownloadLatestIfNeeded(context.Background()) + if err == nil { + t.Fatal("expected error, got nil") + } + if !errors.Is(err, ErrDownloadFailed) { + t.Errorf("expected ErrDownloadFailed, got %v", err) + } +} + +func TestAutoDownloader_ExtractError(t *testing.T) { + tmpDir := t.TempDir() + stagingDir := filepath.Join(tmpDir, "staging") + + ad := &AutoDownloader{ + UpdateChecker: &mockUpdateChecker{ + info: &updatecheck.UpdateInfo{ + UpdateAvailable: true, + TargetVersion: "1.0.0", + AgentURL: "https://example.com/hostlink.tar.gz", + AgentSHA256: "abc123", + }, + }, + Downloader: &mockDownloader{}, + Extractor: &mockExtractor{err: errors.New("extract failed")}, + StagingDir: stagingDir, + } + + _, err := ad.DownloadLatestIfNeeded(context.Background()) + if err == nil { + t.Fatal("expected error, got nil") + } + if !errors.Is(err, ErrExtractFailed) { + t.Errorf("expected ErrExtractFailed, got %v", err) + } +} diff --git a/main.go b/main.go index a6b96d7..20708da 100644 --- a/main.go +++ b/main.go @@ -129,6 +129,54 @@ func runUpgrade(ctx context.Context, cmd *cli.Command) error { defer cleanup() } + // Skip auto-download in these cases: + // 1. --dry-run: only validates preconditions + // 2. --base-dir set: likely a test with custom paths + // 3. HOSTLINK_ENV=test: explicit test environment + // 4. Running from staging: spawned by selfupdatejob (handled by IsManualInvocation) + skipAutoDownload := dryRun || baseDir != "" || os.Getenv("HOSTLINK_ENV") == "test" + + if !skipAutoDownload && upgrade.IsManualInvocation(selfPath, installPath, paths.StagingDir) { + fmt.Fprintf(os.Stderr, "Manual upgrade detected, checking for latest version...\n") + + // Create auto-downloader to fetch latest version + ad, err := upgrade.NewAutoDownloader(upgrade.NewAutoDownloaderConfig{ + StagingDir: paths.StagingDir, + Logger: logger, + }) + if err != nil { + return fmt.Errorf("failed to initialize auto-downloader: %w", err) + } + + // Check and download latest version if available + stagedBinary, err := ad.DownloadLatestIfNeeded(ctx) + if err != nil { + return fmt.Errorf("failed to download latest version: %w", err) + } + + if stagedBinary == "" { + fmt.Fprintf(os.Stderr, "Already running the latest version (%s)\n", version.Version) + return nil + } + + fmt.Fprintf(os.Stderr, "Downloaded new version, executing upgrade...\n") + + // Hand off to the staged binary to complete the upgrade + // Pass through all original args + args := []string{"upgrade", "--install-path", installPath} + if baseDir != "" { + args = append(args, "--base-dir", baseDir) + } + if dryRun { + args = append(args, "--dry-run") + } + + // This replaces the current process - never returns on success + return upgrade.ExecStagedBinary(stagedBinary, args) + } + + // Normal upgrade flow (spawned by selfupdatejob or exec'd from auto-download) + // Build config cfg := &upgrade.Config{ InstallPath: installPath, diff --git a/test/integration/selfupdate_test.go b/test/integration/selfupdate_test.go index 25091f1..405cdd1 100644 --- a/test/integration/selfupdate_test.go +++ b/test/integration/selfupdate_test.go @@ -77,6 +77,12 @@ func createDummyBinary(t *testing.T, dir string) string { return path } +// testEnv returns the environment variables for running hostlink in test mode. +// This sets HOSTLINK_ENV=test to skip auto-download behavior. +func testEnv() []string { + return append(os.Environ(), "HOSTLINK_ENV=test") +} + func TestUpgrade_LockPreventsConcurrent(t *testing.T) { baseDir := setupUpdateDirs(t) paths := update.NewPaths(baseDir) @@ -99,6 +105,7 @@ func TestUpgrade_LockPreventsConcurrent(t *testing.T) { "--base-dir", baseDir, "--install-path", "/tmp/fake-hostlink", ) + cmd.Env = testEnv() output, err := cmd.CombinedOutput() require.Error(t, err, "upgrade should fail when lock is held") assert.Contains(t, string(output), "lock", "error should mention lock") @@ -132,6 +139,7 @@ func TestUpgrade_SignalHandling(t *testing.T) { "--base-dir", baseDir, "--install-path", installPath, ) + cmd.Env = testEnv() cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} err := cmd.Start() @@ -176,6 +184,7 @@ func TestUpgrade_MissingInstallPathFails(t *testing.T) { "--base-dir", baseDir, "--install-path", "/nonexistent/path/hostlink", ) + cmd.Env = testEnv() output, err := cmd.CombinedOutput() require.Error(t, err, "upgrade should fail with non-existent install-path") assert.Contains(t, string(output), "no such file", @@ -198,6 +207,7 @@ func TestUpgrade_DryRun(t *testing.T) { "--base-dir", baseDir, "--install-path", "/nonexistent/path/hostlink", ) + cmd.Env = testEnv() output, err := cmd.CombinedOutput() outputStr := string(output) @@ -229,6 +239,7 @@ func TestUpgrade_DryRunPassesWithValidPath(t *testing.T) { "--base-dir", baseDir, "--install-path", installPath, ) + cmd.Env = testEnv() output, err := cmd.CombinedOutput() outputStr := string(output)