From 83bc1f74479ac9ff805461579597355b7645abc8 Mon Sep 17 00:00:00 2001 From: Mohammad Aziz Date: Sat, 24 Jan 2026 11:29:01 +0530 Subject: [PATCH] fix: address code review findings for upgrade subcommand - Add writeErrorState call on spawn failure (selfupdatejob.go:232-234) - Handle rollbackFrom return value with errors.Join (upgrade.go:250,270,295) - Combine original error with rollback error when both fail --- .goreleaser.yaml | 25 - AGENTS.md | 22 +- app/jobs/selfupdatejob/selfupdatejob.go | 65 +- app/jobs/selfupdatejob/selfupdatejob_test.go | 525 +++++++-- app/services/updatecheck/updatecheck.go | 3 - app/services/updatecheck/updatecheck_test.go | 12 - app/services/updatedownload/staging.go | 16 - app/services/updatedownload/staging_test.go | 32 +- cli_test.go | 275 +++++ cmd/updater/main.go | 102 -- cmd/updater/signals_test.go | 37 - cmd/updater/updater.go | 306 ----- cmd/updater/updater_test.go | 644 ----------- cmd/upgrade/dryrun.go | 166 +++ cmd/upgrade/dryrun_test.go | 277 +++++ cmd/upgrade/logger.go | 50 + cmd/upgrade/logger_test.go | 267 +++++ cmd/{updater => upgrade}/signals.go | 2 +- cmd/upgrade/signals_test.go | 67 ++ cmd/upgrade/upgrade.go | 418 +++++++ cmd/upgrade/upgrade_test.go | 1075 ++++++++++++++++++ config/appconf/appconf.go | 9 + config/appconf/appconf_test.go | 10 + internal/update/binary.go | 109 +- internal/update/binary_test.go | 280 ++++- internal/update/dirs.go | 3 - internal/update/dirs_test.go | 4 - internal/update/health.go | 37 +- internal/update/health_test.go | 56 +- internal/update/service.go | 18 + internal/update/service_test.go | 42 + internal/update/spawn.go | 12 +- internal/update/spawn_test.go | 10 +- internal/update/state.go | 8 + main.go | 270 ++++- test/integration/selfupdate_test.go | 224 ++-- 36 files changed, 3948 insertions(+), 1530 deletions(-) create mode 100644 cli_test.go delete mode 100644 cmd/updater/main.go delete mode 100644 cmd/updater/signals_test.go delete mode 100644 cmd/updater/updater.go delete mode 
100644 cmd/updater/updater_test.go create mode 100644 cmd/upgrade/dryrun.go create mode 100644 cmd/upgrade/dryrun_test.go create mode 100644 cmd/upgrade/logger.go create mode 100644 cmd/upgrade/logger_test.go rename cmd/{updater => upgrade}/signals.go (96%) create mode 100644 cmd/upgrade/signals_test.go create mode 100644 cmd/upgrade/upgrade.go create mode 100644 cmd/upgrade/upgrade_test.go diff --git a/.goreleaser.yaml b/.goreleaser.yaml index a37c0e4..3fc04b6 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -30,19 +30,6 @@ builds: ldflags: - -s -w -X hostlink/version.Version={{.Version}} - - id: hostlink-updater - main: ./cmd/updater - binary: hostlink-updater - env: - - CGO_ENABLED=0 - goos: - - linux - goarch: - - amd64 - - arm64 - ldflags: - - -s -w -X hostlink/version.Version={{.Version}} - archives: - id: hostlink-archive builds: [hostlink] @@ -61,18 +48,6 @@ archives: dst: scripts strip_parent: true - - id: updater-archive - builds: [hostlink-updater] - formats: [tar.gz] - name_template: >- - hostlink-updater_ - {{- title .Os }}_ - {{- if eq .Arch "amd64" }}x86_64 - {{- else }}{{ .Arch }}{{ end }} - {{- if .Arm }}v{{ .Arm }}{{ end }} - files: - - none* - changelog: sort: asc filters: diff --git a/AGENTS.md b/AGENTS.md index 40b5085..c078ce5 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,4 +1,22 @@ -- ALWAYS USE PARALLEL TASKS SUBAGENTS FOR CODE EXPLORATION, DEEP DIVES, AND SO ON -- I use jj instead of git +- ALWAYS USE PARALLEL TASKS SUBAGENTS FOR CODE EXPLORATION, INVESTIGATION, DEEP DIVES +- Use all tools available to keep current context window as small as possible +- When reading files, DELEGATE to subagents, if possible +- In plan mode, be biased toward delegating to subagents +- Use the question tool more frequently +- Use jj instead of git - ALWAYS FOLLOW TDD, red phase to green phase - Use ripgrep instead of grep, use fd instead of find + +## Usage of question tool + +Before any kind of implementation, interview me in detail using the question tool. 
+ +Ask about technical implementation, UI/UX, edge cases, concerns, and tradeoffs. +Don't ask obvious questions, dig into the hard parts I might not have considered. + +Keep interviewing until we've covered everything. + +## Tests + +- Test actual behavior, not the implementation +- Only test implementation when there is a technical limit to simulating the behavior diff --git a/app/jobs/selfupdatejob/selfupdatejob.go b/app/jobs/selfupdatejob/selfupdatejob.go index c3c6fcf..e565a3f 100644 --- a/app/jobs/selfupdatejob/selfupdatejob.go +++ b/app/jobs/selfupdatejob/selfupdatejob.go @@ -7,6 +7,7 @@ import ( "sync" "time" + "github.com/google/uuid" log "github.com/sirupsen/logrus" "hostlink/app/services/updatecheck" @@ -50,11 +51,11 @@ type StateWriterInterface interface { Write(data update.StateData) error } -// SpawnFunc is a function that spawns the updater binary. -type SpawnFunc func(updaterPath string, args []string) error +// SpawnFunc is a function that spawns a binary with the given args. +type SpawnFunc func(binaryPath string, args []string) error -// InstallUpdaterFunc is a function that extracts and installs the updater binary from a tarball. -type InstallUpdaterFunc func(tarPath, destPath string) error +// InstallBinaryFunc extracts a binary from a tarball to a destination path. +type InstallBinaryFunc func(tarPath, destPath string) error // SelfUpdateJobConfig holds the configuration for the SelfUpdateJob. 
type SelfUpdateJobConfig struct { @@ -65,11 +66,10 @@ type SelfUpdateJobConfig struct { LockManager LockManagerInterface StateWriter StateWriterInterface Spawn SpawnFunc - InstallUpdater InstallUpdaterFunc + InstallBinary InstallBinaryFunc CurrentVersion string - UpdaterPath string // Where to install the extracted updater binary - StagingDir string // Where to download tarballs - BaseDir string // Base update directory (for -base-dir flag to updater) + InstallPath string // Target install path (e.g., /usr/bin/hostlink) + StagingDir string // Where to download tarballs and extract binary } // SelfUpdateJob periodically checks for and applies updates. @@ -133,10 +133,11 @@ func (j *SelfUpdateJob) runUpdate(ctx context.Context) error { return nil } - log.Infof("update available: %s -> %s", j.config.CurrentVersion, info.TargetVersion) + updateID := uuid.NewString() + log.Infof("update available: %s -> %s (update_id=%s)", j.config.CurrentVersion, info.TargetVersion, updateID) // Step 2: Pre-flight checks - requiredSpace := info.AgentSize + info.UpdaterSize + requiredSpace := info.AgentSize if requiredSpace == 0 { requiredSpace = defaultRequiredSpace } @@ -161,12 +162,24 @@ func (j *SelfUpdateJob) runUpdate(ctx context.Context) error { // Step 4: Write initialized state if err := j.config.StateWriter.Write(update.StateData{ State: update.StateInitialized, + UpdateID: updateID, SourceVersion: j.config.CurrentVersion, TargetVersion: info.TargetVersion, }); err != nil { return fmt.Errorf("failed to write initialized state: %w", err) } + // Helper to write error state (best-effort, errors ignored) + writeErrorState := func(errMsg string) { + j.config.StateWriter.Write(update.StateData{ + State: update.StateInitialized, + UpdateID: updateID, + SourceVersion: j.config.CurrentVersion, + TargetVersion: info.TargetVersion, + Error: &errMsg, + }) + } + if err := ctx.Err(); err != nil { return err } @@ -174,26 +187,31 @@ func (j *SelfUpdateJob) runUpdate(ctx context.Context) error 
{ // Step 5: Download agent tarball agentDest := filepath.Join(j.config.StagingDir, updatedownload.AgentTarballName) if _, err := j.config.Downloader.DownloadAndVerify(ctx, info.AgentURL, agentDest, info.AgentSHA256); err != nil { + writeErrorState(fmt.Sprintf("failed to download agent: %s", err)) return fmt.Errorf("failed to download agent: %w", err) } if err := ctx.Err(); err != nil { + writeErrorState(err.Error()) return err } - // Step 6: Download updater tarball - updaterDest := filepath.Join(j.config.StagingDir, updatedownload.UpdaterTarballName) - if _, err := j.config.Downloader.DownloadAndVerify(ctx, info.UpdaterURL, updaterDest, info.UpdaterSHA256); err != nil { - return fmt.Errorf("failed to download updater: %w", err) + // Step 6: Extract hostlink binary from tarball to staging dir + stagedBinary := filepath.Join(j.config.StagingDir, "hostlink") + if err := j.config.InstallBinary(agentDest, stagedBinary); err != nil { + writeErrorState(fmt.Sprintf("failed to extract binary from tarball: %s", err)) + return fmt.Errorf("failed to extract binary from tarball: %w", err) } if err := ctx.Err(); err != nil { + writeErrorState(err.Error()) return err } // Step 7: Write staged state if err := j.config.StateWriter.Write(update.StateData{ State: update.StateStaged, + UpdateID: updateID, SourceVersion: j.config.CurrentVersion, TargetVersion: info.TargetVersion, }); err != nil { @@ -201,24 +219,21 @@ func (j *SelfUpdateJob) runUpdate(ctx context.Context) error { } if err := ctx.Err(); err != nil { + writeErrorState(err.Error()) return err } - // Step 8: Extract updater binary from tarball - if err := j.config.InstallUpdater(updaterDest, j.config.UpdaterPath); err != nil { - return fmt.Errorf("failed to install updater binary: %w", err) - } - - // Step 9: Release lock before spawning updater + // Step 8: Release lock before spawning upgrade j.config.LockManager.Unlock() locked = false - // Step 10: Spawn updater in its own process group - args := []string{"-version", 
info.TargetVersion, "-base-dir", j.config.BaseDir} - if err := j.config.Spawn(j.config.UpdaterPath, args); err != nil { - return fmt.Errorf("failed to spawn updater: %w", err) + // Step 9: Spawn staged binary with upgrade subcommand + args := []string{"upgrade", "--install-path", j.config.InstallPath, "--update-id", updateID, "--source-version", j.config.CurrentVersion} + if err := j.config.Spawn(stagedBinary, args); err != nil { + writeErrorState(err.Error()) + return fmt.Errorf("failed to spawn upgrade: %w", err) } - log.Infof("updater spawned for version %s", info.TargetVersion) + log.Infof("upgrade spawned for version %s", info.TargetVersion) return nil } diff --git a/app/jobs/selfupdatejob/selfupdatejob_test.go b/app/jobs/selfupdatejob/selfupdatejob_test.go index efaeb3a..fc697fc 100644 --- a/app/jobs/selfupdatejob/selfupdatejob_test.go +++ b/app/jobs/selfupdatejob/selfupdatejob_test.go @@ -3,6 +3,7 @@ package selfupdatejob import ( "context" "errors" + "strings" "sync" "sync/atomic" "testing" @@ -243,8 +244,6 @@ func TestUpdateFlow_FullFlow(t *testing.T) { TargetVersion: "2.0.0", AgentURL: "https://example.com/agent.tar.gz", AgentSHA256: "abc123", - UpdaterURL: "https://example.com/updater.tar.gz", - UpdaterSHA256: "def456", }, } preflight := &mockPreflight{ @@ -254,6 +253,7 @@ func TestUpdateFlow_FullFlow(t *testing.T) { state := &mockStateWriter{} downloader := &mockDownloader{} spawner := &mockSpawner{} + installer := &mockBinaryInstaller{} job := NewWithConfig(SelfUpdateJobConfig{ UpdateChecker: checker, @@ -262,11 +262,10 @@ func TestUpdateFlow_FullFlow(t *testing.T) { LockManager: lock, StateWriter: state, Spawn: spawner.spawn, - InstallUpdater: noopInstaller, + InstallBinary: installer.install, CurrentVersion: "1.0.0", - UpdaterPath: "/tmp/updater", + InstallPath: "/usr/bin/hostlink", StagingDir: "/tmp/staging", - BaseDir: "/var/lib/hostlink/updates", }) err := job.runUpdate(context.Background()) @@ -274,7 +273,7 @@ func TestUpdateFlow_FullFlow(t 
*testing.T) { t.Fatalf("unexpected error: %v", err) } - // Verify full flow: check → preflight → lock → state(init) → download agent → download updater → state(staged) → unlock → spawn + // Verify full flow: check → preflight → lock → state(init) → download agent → extract binary → state(staged) → unlock → spawn if checker.callCount.Load() != 1 { t.Errorf("expected 1 check call, got %d", checker.callCount.Load()) } @@ -284,8 +283,11 @@ func TestUpdateFlow_FullFlow(t *testing.T) { if lock.lockCount.Load() != 1 { t.Errorf("expected 1 lock call, got %d", lock.lockCount.Load()) } - if downloader.callCount.Load() != 2 { - t.Errorf("expected 2 download calls (agent + updater), got %d", downloader.callCount.Load()) + if downloader.callCount.Load() != 1 { + t.Errorf("expected 1 download call (agent only), got %d", downloader.callCount.Load()) + } + if installer.callCount.Load() != 1 { + t.Errorf("expected 1 install call (extract from tarball), got %d", installer.callCount.Load()) } if lock.unlockCount.Load() != 1 { t.Errorf("expected 1 unlock call, got %d", lock.unlockCount.Load()) @@ -317,8 +319,6 @@ func TestUpdateFlow_UnlocksBeforeSpawn(t *testing.T) { TargetVersion: "2.0.0", AgentURL: "https://example.com/agent.tar.gz", AgentSHA256: "abc", - UpdaterURL: "https://example.com/updater.tar.gz", - UpdaterSHA256: "def", }, } preflight := &mockPreflight{ @@ -348,11 +348,10 @@ func TestUpdateFlow_UnlocksBeforeSpawn(t *testing.T) { LockManager: lock, StateWriter: state, Spawn: spawner.spawn, - InstallUpdater: noopInstaller, + InstallBinary: noopInstaller, CurrentVersion: "1.0.0", - UpdaterPath: "/tmp/updater", + InstallPath: "/usr/bin/hostlink", StagingDir: "/tmp/staging", - BaseDir: "/var/lib/hostlink/updates", }) err := job.runUpdate(context.Background()) @@ -390,8 +389,6 @@ func TestUpdateFlow_DownloadFailure(t *testing.T) { TargetVersion: "2.0.0", AgentURL: "https://example.com/agent.tar.gz", AgentSHA256: "abc", - UpdaterURL: "https://example.com/updater.tar.gz", - 
UpdaterSHA256: "def", }, } preflight := &mockPreflight{ @@ -409,11 +406,10 @@ func TestUpdateFlow_DownloadFailure(t *testing.T) { LockManager: lock, StateWriter: state, Spawn: spawner.spawn, - InstallUpdater: noopInstaller, + InstallBinary: noopInstaller, CurrentVersion: "1.0.0", - UpdaterPath: "/tmp/updater", + InstallPath: "/usr/bin/hostlink", StagingDir: "/tmp/staging", - BaseDir: "/var/lib/hostlink/updates", }) job.runUpdate(context.Background()) @@ -427,15 +423,13 @@ func TestUpdateFlow_DownloadFailure(t *testing.T) { } } -func TestUpdateFlow_ChecksumMismatch(t *testing.T) { +func TestUpdateFlow_ExtractFailure_PreventsSpawn(t *testing.T) { checker := &mockUpdateChecker{ result: &updatecheck.UpdateInfo{ UpdateAvailable: true, TargetVersion: "2.0.0", AgentURL: "https://example.com/agent.tar.gz", AgentSHA256: "abc", - UpdaterURL: "https://example.com/updater.tar.gz", - UpdaterSHA256: "def", }, } preflight := &mockPreflight{ @@ -443,9 +437,9 @@ func TestUpdateFlow_ChecksumMismatch(t *testing.T) { } lock := &mockLockManager{} state := &mockStateWriter{} - // First call (agent) succeeds, second (updater) fails - downloader := &mockDownloader{failOnCall: 2, err: errors.New("checksum mismatch")} + downloader := &mockDownloader{} spawner := &mockSpawner{} + installer := &mockBinaryInstaller{err: errors.New("extraction failed")} job := NewWithConfig(SelfUpdateJobConfig{ UpdateChecker: checker, @@ -454,17 +448,16 @@ func TestUpdateFlow_ChecksumMismatch(t *testing.T) { LockManager: lock, StateWriter: state, Spawn: spawner.spawn, - InstallUpdater: noopInstaller, + InstallBinary: installer.install, CurrentVersion: "1.0.0", - UpdaterPath: "/tmp/updater", + InstallPath: "/usr/bin/hostlink", StagingDir: "/tmp/staging", - BaseDir: "/var/lib/hostlink/updates", }) job.runUpdate(context.Background()) if spawner.callCount.Load() > 0 { - t.Error("spawner should not be called when checksum fails") + t.Error("spawner should not be called when extraction fails") } if 
lock.unlockCount.Load() != 1 { t.Errorf("expected lock to be released on failure, unlock count: %d", lock.unlockCount.Load()) @@ -478,8 +471,6 @@ func TestUpdateFlow_SpawnArgs(t *testing.T) { TargetVersion: "2.0.0", AgentURL: "https://example.com/agent.tar.gz", AgentSHA256: "abc", - UpdaterURL: "https://example.com/updater.tar.gz", - UpdaterSHA256: "def", }, } preflight := &mockPreflight{ @@ -497,11 +488,10 @@ func TestUpdateFlow_SpawnArgs(t *testing.T) { LockManager: lock, StateWriter: state, Spawn: spawner.spawn, - InstallUpdater: noopInstaller, + InstallBinary: noopInstaller, CurrentVersion: "1.0.0", - UpdaterPath: "/opt/updater/hostlink-updater", + InstallPath: "/usr/bin/hostlink", StagingDir: "/tmp/staging", - BaseDir: "/var/lib/hostlink/updates", }) err := job.runUpdate(context.Background()) @@ -509,25 +499,42 @@ func TestUpdateFlow_SpawnArgs(t *testing.T) { t.Fatalf("unexpected error: %v", err) } - if spawner.lastPath != "/opt/updater/hostlink-updater" { - t.Errorf("expected updater path /opt/updater/hostlink-updater, got %s", spawner.lastPath) + // Spawn path should be the extracted binary in staging dir + expectedSpawnPath := "/tmp/staging/hostlink" + if spawner.lastPath != expectedSpawnPath { + t.Errorf("expected spawn path %s, got %s", expectedSpawnPath, spawner.lastPath) } - // Verify args contain -version and -base-dir - foundVersion := false - foundDir := false - for i, arg := range spawner.lastArgs { - if arg == "-version" && i+1 < len(spawner.lastArgs) && spawner.lastArgs[i+1] == "2.0.0" { - foundVersion = true - } - if arg == "-base-dir" && i+1 < len(spawner.lastArgs) && spawner.lastArgs[i+1] == "/var/lib/hostlink/updates" { - foundDir = true - } + // Verify args are: upgrade --install-path --update-id --source-version + if len(spawner.lastArgs) != 7 { + t.Fatalf("expected 7 args, got %v", spawner.lastArgs) + } + if spawner.lastArgs[0] != "upgrade" { + t.Errorf("arg[0]: expected 'upgrade', got %q", spawner.lastArgs[0]) + } + if spawner.lastArgs[1] 
!= "--install-path" || spawner.lastArgs[2] != "/usr/bin/hostlink" { + t.Errorf("expected --install-path /usr/bin/hostlink, got %v", spawner.lastArgs[1:3]) + } + if spawner.lastArgs[3] != "--update-id" { + t.Errorf("arg[3]: expected '--update-id', got %q", spawner.lastArgs[3]) + } + if spawner.lastArgs[4] == "" { + t.Error("update-id should not be empty") + } + if spawner.lastArgs[5] != "--source-version" || spawner.lastArgs[6] != "1.0.0" { + t.Errorf("expected --source-version 1.0.0, got %v", spawner.lastArgs[5:7]) + } + + // Verify UpdateID is consistent across state writes and spawn args + writes := state.getWrites() + if len(writes) < 2 { + t.Fatalf("expected at least 2 state writes, got %d", len(writes)) } - if !foundVersion { - t.Errorf("expected -version 2.0.0 in spawn args, got %v", spawner.lastArgs) + updateID := spawner.lastArgs[4] + if writes[0].UpdateID != updateID { + t.Errorf("Initialized state UpdateID: expected %q, got %q", updateID, writes[0].UpdateID) } - if !foundDir { - t.Errorf("expected -base-dir /var/lib/hostlink/updates in spawn args, got %v", spawner.lastArgs) + if writes[1].UpdateID != updateID { + t.Errorf("Staged state UpdateID: expected %q, got %q", updateID, writes[1].UpdateID) } } @@ -539,9 +546,6 @@ func TestUpdateFlow_PassesDownloadSizeToPreflight(t *testing.T) { AgentURL: "https://example.com/agent.tar.gz", AgentSHA256: "abc", AgentSize: 30 * 1024 * 1024, // 30MB - UpdaterURL: "https://example.com/updater.tar.gz", - UpdaterSHA256: "def", - UpdaterSize: 5 * 1024 * 1024, // 5MB }, } preflight := &mockPreflight{ @@ -559,11 +563,10 @@ func TestUpdateFlow_PassesDownloadSizeToPreflight(t *testing.T) { LockManager: lock, StateWriter: state, Spawn: spawner.spawn, - InstallUpdater: noopInstaller, + InstallBinary: noopInstaller, CurrentVersion: "1.0.0", - UpdaterPath: "/tmp/updater", + InstallPath: "/usr/bin/hostlink", StagingDir: "/tmp/staging", - BaseDir: "/var/lib/hostlink/updates", }) err := job.runUpdate(context.Background()) @@ -571,22 
+574,20 @@ func TestUpdateFlow_PassesDownloadSizeToPreflight(t *testing.T) { t.Fatalf("unexpected error: %v", err) } - expected := int64(35 * 1024 * 1024) // 30MB + 5MB + expected := int64(30 * 1024 * 1024) // 30MB (agent only) if preflight.getLastRequiredSpace() != expected { t.Errorf("expected preflight requiredSpace %d, got %d", expected, preflight.getLastRequiredSpace()) } } -func TestUpdateFlow_FallsBackTo50MB_WhenSizesZero(t *testing.T) { +func TestUpdateFlow_FallsBackTo50MB_WhenSizeZero(t *testing.T) { checker := &mockUpdateChecker{ result: &updatecheck.UpdateInfo{ UpdateAvailable: true, TargetVersion: "2.0.0", AgentURL: "https://example.com/agent.tar.gz", AgentSHA256: "abc", - UpdaterURL: "https://example.com/updater.tar.gz", - UpdaterSHA256: "def", - // AgentSize and UpdaterSize are zero (not provided by control plane) + // AgentSize is zero (not provided by control plane) }, } preflight := &mockPreflight{ @@ -604,11 +605,10 @@ func TestUpdateFlow_FallsBackTo50MB_WhenSizesZero(t *testing.T) { LockManager: lock, StateWriter: state, Spawn: spawner.spawn, - InstallUpdater: noopInstaller, + InstallBinary: noopInstaller, CurrentVersion: "1.0.0", - UpdaterPath: "/tmp/updater", + InstallPath: "/usr/bin/hostlink", StagingDir: "/tmp/staging", - BaseDir: "/var/lib/hostlink/updates", }) err := job.runUpdate(context.Background()) @@ -629,8 +629,6 @@ func TestUpdateFlow_AgentDestUsesCanonicalTarballName(t *testing.T) { TargetVersion: "2.0.0", AgentURL: "https://example.com/agent.tar.gz", AgentSHA256: "abc", - UpdaterURL: "https://example.com/updater.tar.gz", - UpdaterSHA256: "def", }, } preflight := &mockPreflight{ @@ -648,11 +646,10 @@ func TestUpdateFlow_AgentDestUsesCanonicalTarballName(t *testing.T) { LockManager: lock, StateWriter: state, Spawn: spawner.spawn, - InstallUpdater: noopInstaller, + InstallBinary: noopInstaller, CurrentVersion: "1.0.0", - UpdaterPath: "/tmp/updater", + InstallPath: "/usr/bin/hostlink", StagingDir: "/tmp/staging", - BaseDir: 
"/var/lib/hostlink/updates", }) err := job.runUpdate(context.Background()) @@ -660,22 +657,20 @@ func TestUpdateFlow_AgentDestUsesCanonicalTarballName(t *testing.T) { t.Fatalf("unexpected error: %v", err) } - // Agent tarball must use the canonical name expected by the updater + // Agent tarball must use the canonical name expectedAgentDest := "/tmp/staging/" + updatedownload.AgentTarballName if len(downloader.destPaths) < 1 || downloader.destPaths[0] != expectedAgentDest { t.Errorf("expected agent dest path %q, got %v", expectedAgentDest, downloader.destPaths) } } -func TestUpdateFlow_ExtractsUpdaterBeforeSpawn(t *testing.T) { +func TestUpdateFlow_InstallBinaryArgs(t *testing.T) { checker := &mockUpdateChecker{ result: &updatecheck.UpdateInfo{ UpdateAvailable: true, TargetVersion: "2.0.0", AgentURL: "https://example.com/agent.tar.gz", AgentSHA256: "abc", - UpdaterURL: "https://example.com/updater.tar.gz", - UpdaterSHA256: "def", }, } preflight := &mockPreflight{ @@ -685,7 +680,7 @@ func TestUpdateFlow_ExtractsUpdaterBeforeSpawn(t *testing.T) { state := &mockStateWriter{} downloader := &mockDownloader{} spawner := &mockSpawner{} - installer := &mockUpdaterInstaller{} + installer := &mockBinaryInstaller{} job := NewWithConfig(SelfUpdateJobConfig{ UpdateChecker: checker, @@ -694,11 +689,10 @@ func TestUpdateFlow_ExtractsUpdaterBeforeSpawn(t *testing.T) { LockManager: lock, StateWriter: state, Spawn: spawner.spawn, - InstallUpdater: installer.install, + InstallBinary: installer.install, CurrentVersion: "1.0.0", - UpdaterPath: "/opt/updater/hostlink-updater", + InstallPath: "/usr/bin/hostlink", StagingDir: "/tmp/staging", - BaseDir: "/var/lib/hostlink/updates", }) err := job.runUpdate(context.Background()) @@ -706,32 +700,30 @@ func TestUpdateFlow_ExtractsUpdaterBeforeSpawn(t *testing.T) { t.Fatalf("unexpected error: %v", err) } - // Verify InstallUpdater was called with correct args + // Verify InstallBinary was called with tarball path and staging dest if 
installer.callCount.Load() != 1 { t.Fatalf("expected 1 install call, got %d", installer.callCount.Load()) } - expectedTarPath := "/tmp/staging/" + updatedownload.UpdaterTarballName + expectedTarPath := "/tmp/staging/" + updatedownload.AgentTarballName if installer.lastTarPath != expectedTarPath { t.Errorf("expected tarPath %q, got %q", expectedTarPath, installer.lastTarPath) } - if installer.lastDestPath != "/opt/updater/hostlink-updater" { - t.Errorf("expected destPath %q, got %q", "/opt/updater/hostlink-updater", installer.lastDestPath) - } - // Verify spawn was still called (extraction didn't block it) - if spawner.callCount.Load() != 1 { - t.Errorf("expected 1 spawn call, got %d", spawner.callCount.Load()) + // destPath should be staging/hostlink (the extracted binary to spawn) + expectedDestPath := "/tmp/staging/hostlink" + if installer.lastDestPath != expectedDestPath { + t.Errorf("expected destPath %q, got %q", expectedDestPath, installer.lastDestPath) } } -func TestUpdateFlow_ExtractFailure_PreventsSpawn(t *testing.T) { +func TestUpdateFlow_ContextCancelledAfterDownload(t *testing.T) { + var cancelCtx context.CancelFunc + checker := &mockUpdateChecker{ result: &updatecheck.UpdateInfo{ UpdateAvailable: true, TargetVersion: "2.0.0", AgentURL: "https://example.com/agent.tar.gz", AgentSHA256: "abc", - UpdaterURL: "https://example.com/updater.tar.gz", - UpdaterSHA256: "def", }, } preflight := &mockPreflight{ @@ -739,9 +731,15 @@ func TestUpdateFlow_ExtractFailure_PreventsSpawn(t *testing.T) { } lock := &mockLockManager{} state := &mockStateWriter{} - downloader := &mockDownloader{} + downloader := &mockDownloader{ + onCall: func(count int32) { + if count == 1 { + // Cancel context after download completes + cancelCtx() + } + }, + } spawner := &mockSpawner{} - installer := &mockUpdaterInstaller{err: errors.New("extraction failed")} job := NewWithConfig(SelfUpdateJobConfig{ UpdateChecker: checker, @@ -750,26 +748,31 @@ func 
TestUpdateFlow_ExtractFailure_PreventsSpawn(t *testing.T) { LockManager: lock, StateWriter: state, Spawn: spawner.spawn, - InstallUpdater: installer.install, + InstallBinary: noopInstaller, CurrentVersion: "1.0.0", - UpdaterPath: "/opt/updater/hostlink-updater", + InstallPath: "/usr/bin/hostlink", StagingDir: "/tmp/staging", - BaseDir: "/var/lib/hostlink/updates", }) - job.runUpdate(context.Background()) + ctx, cancel := context.WithCancel(context.Background()) + cancelCtx = cancel - // Extraction failed, so spawn should NOT be called - if spawner.callCount.Load() != 0 { - t.Error("spawner should not be called when extraction fails") + err := job.runUpdate(ctx) + + // Should return context.Canceled error + if err == nil { + t.Fatal("expected error from cancelled context, got nil") } - // Lock should still be released - if lock.unlockCount.Load() != 1 { - t.Errorf("expected lock to be released on failure, unlock count: %d", lock.unlockCount.Load()) + if !errors.Is(err, context.Canceled) { + t.Errorf("expected context.Canceled error, got: %v", err) + } + // Spawn should NOT have been called + if spawner.callCount.Load() != 0 { + t.Error("spawner should not be called when context is cancelled") } } -func TestUpdateFlow_ContextCancelledBetweenDownloads(t *testing.T) { +func TestUpdateFlow_ContextCancelledAfterDownload_WritesErrorState(t *testing.T) { var cancelCtx context.CancelFunc checker := &mockUpdateChecker{ @@ -778,8 +781,6 @@ func TestUpdateFlow_ContextCancelledBetweenDownloads(t *testing.T) { TargetVersion: "2.0.0", AgentURL: "https://example.com/agent.tar.gz", AgentSHA256: "abc", - UpdaterURL: "https://example.com/updater.tar.gz", - UpdaterSHA256: "def", }, } preflight := &mockPreflight{ @@ -790,12 +791,10 @@ func TestUpdateFlow_ContextCancelledBetweenDownloads(t *testing.T) { downloader := &mockDownloader{ onCall: func(count int32) { if count == 1 { - // Cancel context after first download (agent) completes cancelCtx() } }, } - spawner := &mockSpawner{} job 
:= NewWithConfig(SelfUpdateJobConfig{ UpdateChecker: checker, @@ -803,39 +802,297 @@ func TestUpdateFlow_ContextCancelledBetweenDownloads(t *testing.T) { PreflightChecker: preflight, LockManager: lock, StateWriter: state, - Spawn: spawner.spawn, - InstallUpdater: noopInstaller, + Spawn: func(string, []string) error { return nil }, + InstallBinary: noopInstaller, CurrentVersion: "1.0.0", - UpdaterPath: "/tmp/updater", + InstallPath: "/usr/bin/hostlink", StagingDir: "/tmp/staging", - BaseDir: "/var/lib/hostlink/updates", }) ctx, cancel := context.WithCancel(context.Background()) cancelCtx = cancel - err := job.runUpdate(ctx) + job.runUpdate(ctx) - // Should return context.Canceled error - if err == nil { - t.Fatal("expected error from cancelled context, got nil") + // Verify error state was written + writes := state.getWrites() + if len(writes) < 2 { + t.Fatalf("expected at least 2 state writes (init + error), got %d", len(writes)) } - if !errors.Is(err, context.Canceled) { - t.Errorf("expected context.Canceled error, got: %v", err) + lastWrite := writes[len(writes)-1] + if lastWrite.State != update.StateInitialized { + t.Errorf("expected last state to be Initialized, got %s", lastWrite.State) } - // Second download (updater) should NOT have been called - if downloader.callCount.Load() != 1 { - t.Errorf("expected only 1 download call (agent), got %d", downloader.callCount.Load()) + if lastWrite.Error == nil { + t.Error("expected Error field to be set on last state write") + } else if *lastWrite.Error != context.Canceled.Error() { + t.Errorf("expected error message %q, got %q", context.Canceled.Error(), *lastWrite.Error) } - // Spawn should NOT have been called - if spawner.callCount.Load() != 0 { - t.Error("spawner should not be called when context is cancelled") +} + +func TestUpdateFlow_ExtractFailure_WritesErrorState(t *testing.T) { + checker := &mockUpdateChecker{ + result: &updatecheck.UpdateInfo{ + UpdateAvailable: true, + TargetVersion: "2.0.0", + AgentURL: 
"https://example.com/agent.tar.gz", + AgentSHA256: "abc", + }, + } + preflight := &mockPreflight{ + result: &updatepreflight.PreflightResult{Passed: true}, + } + lock := &mockLockManager{} + state := &mockStateWriter{} + downloader := &mockDownloader{} + installer := &mockBinaryInstaller{err: errors.New("extraction failed: corrupt tarball")} + + job := NewWithConfig(SelfUpdateJobConfig{ + UpdateChecker: checker, + Downloader: downloader, + PreflightChecker: preflight, + LockManager: lock, + StateWriter: state, + Spawn: func(string, []string) error { return nil }, + InstallBinary: installer.install, + CurrentVersion: "1.0.0", + InstallPath: "/usr/bin/hostlink", + StagingDir: "/tmp/staging", + }) + + job.runUpdate(context.Background()) + + // Verify error state was written + writes := state.getWrites() + if len(writes) < 2 { + t.Fatalf("expected at least 2 state writes (init + error), got %d", len(writes)) + } + lastWrite := writes[len(writes)-1] + if lastWrite.State != update.StateInitialized { + t.Errorf("expected last state to be Initialized, got %s", lastWrite.State) + } + if lastWrite.Error == nil { + t.Error("expected Error field to be set on last state write") + } else if !strings.Contains(*lastWrite.Error, "extraction failed") { + t.Errorf("expected error message to contain 'extraction failed', got %q", *lastWrite.Error) + } +} + +func TestUpdateFlow_DownloadFailure_WritesErrorState(t *testing.T) { + checker := &mockUpdateChecker{ + result: &updatecheck.UpdateInfo{ + UpdateAvailable: true, + TargetVersion: "2.0.0", + AgentURL: "https://example.com/agent.tar.gz", + AgentSHA256: "abc", + }, + } + preflight := &mockPreflight{ + result: &updatepreflight.PreflightResult{Passed: true}, + } + lock := &mockLockManager{} + state := &mockStateWriter{} + downloader := &mockDownloader{err: errors.New("download failed: connection timeout")} + + job := NewWithConfig(SelfUpdateJobConfig{ + UpdateChecker: checker, + Downloader: downloader, + PreflightChecker: preflight, + 
LockManager: lock, + StateWriter: state, + Spawn: func(string, []string) error { return nil }, + InstallBinary: noopInstaller, + CurrentVersion: "1.0.0", + InstallPath: "/usr/bin/hostlink", + StagingDir: "/tmp/staging", + }) + + job.runUpdate(context.Background()) + + // Verify error state was written + writes := state.getWrites() + if len(writes) < 2 { + t.Fatalf("expected at least 2 state writes (init + error), got %d", len(writes)) + } + lastWrite := writes[len(writes)-1] + if lastWrite.State != update.StateInitialized { + t.Errorf("expected last state to be Initialized, got %s", lastWrite.State) + } + if lastWrite.Error == nil { + t.Error("expected Error field to be set on last state write") + } else if !strings.Contains(*lastWrite.Error, "download") { + t.Errorf("expected error message to contain 'download', got %q", *lastWrite.Error) + } +} + +func TestUpdateFlow_ContextCancelledAfterExtract_WritesErrorState(t *testing.T) { + checker := &mockUpdateChecker{ + result: &updatecheck.UpdateInfo{ + UpdateAvailable: true, + TargetVersion: "2.0.0", + AgentURL: "https://example.com/agent.tar.gz", + AgentSHA256: "abc", + }, + } + preflight := &mockPreflight{ + result: &updatepreflight.PreflightResult{Passed: true}, + } + lock := &mockLockManager{} + state := &mockStateWriter{} + downloader := &mockDownloader{} + + var cancelCtx context.CancelFunc + installer := &mockBinaryInstaller{ + onInstall: func() { + // Cancel context after extraction succeeds + cancelCtx() + }, + } + + job := NewWithConfig(SelfUpdateJobConfig{ + UpdateChecker: checker, + Downloader: downloader, + PreflightChecker: preflight, + LockManager: lock, + StateWriter: state, + Spawn: func(string, []string) error { return nil }, + InstallBinary: installer.install, + CurrentVersion: "1.0.0", + InstallPath: "/usr/bin/hostlink", + StagingDir: "/tmp/staging", + }) + + ctx, cancel := context.WithCancel(context.Background()) + cancelCtx = cancel + + job.runUpdate(ctx) + + // Verify error state was written + 
writes := state.getWrites() + if len(writes) < 2 { + t.Fatalf("expected at least 2 state writes (init + error), got %d", len(writes)) + } + lastWrite := writes[len(writes)-1] + if lastWrite.State != update.StateInitialized { + t.Errorf("expected last state to be Initialized, got %s", lastWrite.State) + } + if lastWrite.Error == nil { + t.Error("expected Error field to be set on last state write") + } else if *lastWrite.Error != context.Canceled.Error() { + t.Errorf("expected error message %q, got %q", context.Canceled.Error(), *lastWrite.Error) + } +} + +func TestUpdateFlow_ContextCancelledAfterStagedStateWrite_WritesErrorState(t *testing.T) { + checker := &mockUpdateChecker{ + result: &updatecheck.UpdateInfo{ + UpdateAvailable: true, + TargetVersion: "2.0.0", + AgentURL: "https://example.com/agent.tar.gz", + AgentSHA256: "abc", + }, + } + preflight := &mockPreflight{ + result: &updatepreflight.PreflightResult{Passed: true}, + } + lock := &mockLockManager{} + downloader := &mockDownloader{} + + var cancelCtx context.CancelFunc + state := &mockStateWriter{ + onWrite: func(data update.StateData) { + // Cancel context after staged state is written + if data.State == update.StateStaged { + cancelCtx() + } + }, + } + + job := NewWithConfig(SelfUpdateJobConfig{ + UpdateChecker: checker, + Downloader: downloader, + PreflightChecker: preflight, + LockManager: lock, + StateWriter: state, + Spawn: func(string, []string) error { return nil }, + InstallBinary: noopInstaller, + CurrentVersion: "1.0.0", + InstallPath: "/usr/bin/hostlink", + StagingDir: "/tmp/staging", + }) + + ctx, cancel := context.WithCancel(context.Background()) + cancelCtx = cancel + + job.runUpdate(ctx) + + // Verify error state was written after staged state + writes := state.getWrites() + if len(writes) < 3 { + t.Fatalf("expected at least 3 state writes (init + staged + error), got %d", len(writes)) + } + lastWrite := writes[len(writes)-1] + if lastWrite.State != update.StateInitialized { + 
t.Errorf("expected last state to be Initialized (error state), got %s", lastWrite.State) + } + if lastWrite.Error == nil { + t.Error("expected Error field to be set on last state write") + } else if *lastWrite.Error != context.Canceled.Error() { + t.Errorf("expected error message %q, got %q", context.Canceled.Error(), *lastWrite.Error) + } +} + +func TestUpdateFlow_SpawnFailure_WritesErrorState(t *testing.T) { + checker := &mockUpdateChecker{ + result: &updatecheck.UpdateInfo{ + UpdateAvailable: true, + TargetVersion: "2.0.0", + AgentURL: "https://example.com/agent.tar.gz", + AgentSHA256: "abc", + }, + } + preflight := &mockPreflight{ + result: &updatepreflight.PreflightResult{Passed: true}, + } + lock := &mockLockManager{} + state := &mockStateWriter{} + downloader := &mockDownloader{} + spawner := &mockSpawner{err: errors.New("spawn failed: exec format error")} + + job := NewWithConfig(SelfUpdateJobConfig{ + UpdateChecker: checker, + Downloader: downloader, + PreflightChecker: preflight, + LockManager: lock, + StateWriter: state, + Spawn: spawner.spawn, + InstallBinary: noopInstaller, + CurrentVersion: "1.0.0", + InstallPath: "/usr/bin/hostlink", + StagingDir: "/tmp/staging", + }) + + job.runUpdate(context.Background()) + + // Verify error state was written after spawn failure + writes := state.getWrites() + if len(writes) < 3 { + t.Fatalf("expected at least 3 state writes (init + staged + error), got %d", len(writes)) + } + lastWrite := writes[len(writes)-1] + if lastWrite.State != update.StateInitialized { + t.Errorf("expected last state to be Initialized (error state), got %s", lastWrite.State) + } + if lastWrite.Error == nil { + t.Error("expected Error field to be set on last state write") + } else if !strings.Contains(*lastWrite.Error, "spawn") { + t.Errorf("expected error message to contain 'spawn', got %q", *lastWrite.Error) } } // --- Helpers --- -// noopInstaller is a no-op InstallUpdaterFunc for tests that don't care about extraction. 
+// noopInstaller is a no-op InstallBinaryFunc for tests that don't care about extraction. func noopInstaller(tarPath, destPath string) error { return nil } // immediateTrigger calls fn exactly n times synchronously then returns. @@ -903,15 +1160,21 @@ func (m *mockLockManager) Unlock() error { } type mockStateWriter struct { - mu sync.Mutex - states []update.State - err error + mu sync.Mutex + states []update.State + writes []update.StateData + err error + onWrite func(data update.StateData) // Called after each write } func (m *mockStateWriter) Write(data update.StateData) error { m.mu.Lock() defer m.mu.Unlock() m.states = append(m.states, data.State) + m.writes = append(m.writes, data) + if m.onWrite != nil { + m.onWrite(data) + } return m.err } @@ -923,6 +1186,14 @@ func (m *mockStateWriter) getStates() []update.State { return result } +func (m *mockStateWriter) getWrites() []update.StateData { + m.mu.Lock() + defer m.mu.Unlock() + result := make([]update.StateData, len(m.writes)) + copy(result, m.writes) + return result +} + type mockDownloader struct { err error failOnCall int32 // fail on this call number (1-indexed), 0 means all fail if err is set @@ -969,19 +1240,23 @@ func (m *mockSpawner) spawn(updaterPath string, args []string) error { return m.err } -type mockUpdaterInstaller struct { +type mockBinaryInstaller struct { err error callCount atomic.Int32 lastTarPath string lastDestPath string + onInstall func() mu sync.Mutex } -func (m *mockUpdaterInstaller) install(tarPath, destPath string) error { +func (m *mockBinaryInstaller) install(tarPath, destPath string) error { m.callCount.Add(1) m.mu.Lock() m.lastTarPath = tarPath m.lastDestPath = destPath m.mu.Unlock() + if m.onInstall != nil { + m.onInstall() + } return m.err } diff --git a/app/services/updatecheck/updatecheck.go b/app/services/updatecheck/updatecheck.go index a509777..247dbf2 100644 --- a/app/services/updatecheck/updatecheck.go +++ b/app/services/updatecheck/updatecheck.go @@ -14,9 +14,6 @@ 
type UpdateInfo struct { AgentURL string `json:"agent_url"` AgentSHA256 string `json:"agent_sha256"` AgentSize int64 `json:"agent_size"` - UpdaterURL string `json:"updater_url"` - UpdaterSHA256 string `json:"updater_sha256"` - UpdaterSize int64 `json:"updater_size"` } // RequestSignerInterface abstracts request signing for testability. diff --git a/app/services/updatecheck/updatecheck_test.go b/app/services/updatecheck/updatecheck_test.go index 35eab14..6ca0bff 100644 --- a/app/services/updatecheck/updatecheck_test.go +++ b/app/services/updatecheck/updatecheck_test.go @@ -17,9 +17,6 @@ func TestCheck_UpdateAvailable(t *testing.T) { AgentURL: "https://example.com/agent.tar.gz", AgentSHA256: "abc123", AgentSize: 52428800, - UpdaterURL: "https://example.com/updater.tar.gz", - UpdaterSHA256: "def456", - UpdaterSize: 10485760, } json.NewEncoder(w).Encode(resp) })) @@ -42,18 +39,9 @@ func TestCheck_UpdateAvailable(t *testing.T) { if info.AgentSHA256 != "abc123" { t.Errorf("expected AgentSHA256 abc123, got %s", info.AgentSHA256) } - if info.UpdaterURL != "https://example.com/updater.tar.gz" { - t.Errorf("expected UpdaterURL https://example.com/updater.tar.gz, got %s", info.UpdaterURL) - } if info.AgentSize != 52428800 { t.Errorf("expected AgentSize 52428800, got %d", info.AgentSize) } - if info.UpdaterSHA256 != "def456" { - t.Errorf("expected UpdaterSHA256 def456, got %s", info.UpdaterSHA256) - } - if info.UpdaterSize != 10485760 { - t.Errorf("expected UpdaterSize 10485760, got %d", info.UpdaterSize) - } } func TestCheck_NoUpdateAvailable(t *testing.T) { diff --git a/app/services/updatedownload/staging.go b/app/services/updatedownload/staging.go index 480b815..fab8f50 100644 --- a/app/services/updatedownload/staging.go +++ b/app/services/updatedownload/staging.go @@ -10,8 +10,6 @@ import ( const ( // AgentTarballName is the filename for the staged agent tarball. AgentTarballName = "hostlink.tar.gz" - // UpdaterTarballName is the filename for the staged updater tarball. 
- UpdaterTarballName = "updater.tar.gz" // StagingDirPermissions is the permission mode for the staging directory. StagingDirPermissions = 0700 ) @@ -52,25 +50,11 @@ func (s *StagingManager) StageAgent(ctx context.Context, url, sha256 string) err return err } -// StageUpdater downloads and verifies the updater tarball to the staging area. -func (s *StagingManager) StageUpdater(ctx context.Context, url, sha256 string) error { - destPath := s.GetUpdaterPath() - _, err := s.downloader.DownloadAndVerify(ctx, url, destPath, sha256) - return err -} - // GetAgentPath returns the path to the staged agent tarball. -// Note: Returns tarball path, not extracted binary. Extraction happens in updater phase. func (s *StagingManager) GetAgentPath() string { return filepath.Join(s.basePath, AgentTarballName) } -// GetUpdaterPath returns the path to the staged updater tarball. -// Note: Returns tarball path, not extracted binary. Extraction happens in updater phase. -func (s *StagingManager) GetUpdaterPath() string { - return filepath.Join(s.basePath, UpdaterTarballName) -} - // Cleanup removes the entire staging directory. 
func (s *StagingManager) Cleanup() error { // Check if directory exists diff --git a/app/services/updatedownload/staging_test.go b/app/services/updatedownload/staging_test.go index 3e41496..c52b268 100644 --- a/app/services/updatedownload/staging_test.go +++ b/app/services/updatedownload/staging_test.go @@ -81,41 +81,11 @@ func TestStagingManager_StageAgent(t *testing.T) { assert.Equal(t, content, data) } -func TestStagingManager_StageUpdater(t *testing.T) { - content := []byte("fake updater tarball content") - contentSHA := computeStagingSHA256(content) - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - w.Write(content) - })) - defer server.Close() - - tmpDir := t.TempDir() - stagingPath := filepath.Join(tmpDir, "staging") - - downloader := NewDownloaderWithSleep(DefaultDownloadConfig(), noopSleep) - sm := NewStagingManager(stagingPath, downloader) - - err := sm.Prepare() - require.NoError(t, err) - - err = sm.StageUpdater(context.Background(), server.URL, contentSHA) - require.NoError(t, err) - - // Verify file was downloaded - updaterPath := sm.GetUpdaterPath() - data, err := os.ReadFile(updaterPath) - require.NoError(t, err) - assert.Equal(t, content, data) -} - -func TestStagingManager_GetPaths(t *testing.T) { +func TestStagingManager_GetAgentPath(t *testing.T) { stagingPath := "/var/lib/hostlink/updates/staging" sm := NewStagingManager(stagingPath, nil) assert.Equal(t, filepath.Join(stagingPath, "hostlink.tar.gz"), sm.GetAgentPath()) - assert.Equal(t, filepath.Join(stagingPath, UpdaterTarballName), sm.GetUpdaterPath()) } func TestStagingManager_Cleanup(t *testing.T) { diff --git a/cli_test.go b/cli_test.go new file mode 100644 index 0000000..508c604 --- /dev/null +++ b/cli_test.go @@ -0,0 +1,275 @@ +package main + +import ( + "context" + "testing" + + "github.com/urfave/cli/v3" + "hostlink/version" +) + +func TestCLIApp_VersionFlag(t *testing.T) { + app := newApp() + + if 
app.Version != version.Version { + t.Errorf("expected version %q, got %q", version.Version, app.Version) + } +} + +func TestCLIApp_HasVersionCommand(t *testing.T) { + app := newApp() + + var found bool + for _, cmd := range app.Commands { + if cmd.Name == "version" { + found = true + break + } + } + if !found { + t.Error("expected 'version' subcommand to exist") + } +} + +func TestCLIApp_HasUpgradeCommand(t *testing.T) { + app := newApp() + + var found bool + for _, cmd := range app.Commands { + if cmd.Name == "upgrade" { + found = true + break + } + } + if !found { + t.Error("expected 'upgrade' subcommand to exist") + } +} + +func TestCLIApp_UpgradeHasInstallPathFlag(t *testing.T) { + app := newApp() + + var upgradeCmd *cli.Command + for _, cmd := range app.Commands { + if cmd.Name == "upgrade" { + upgradeCmd = cmd + break + } + } + if upgradeCmd == nil { + t.Fatal("upgrade command not found") + } + + var found bool + for _, f := range upgradeCmd.Flags { + if hasName(f, "install-path") { + found = true + break + } + } + if !found { + t.Error("expected 'upgrade' to have --install-path flag") + } +} + +func TestCLIApp_UpgradeInstallPathDefaultValue(t *testing.T) { + app := newApp() + + var upgradeCmd *cli.Command + for _, cmd := range app.Commands { + if cmd.Name == "upgrade" { + upgradeCmd = cmd + break + } + } + if upgradeCmd == nil { + t.Fatal("upgrade command not found") + } + + for _, f := range upgradeCmd.Flags { + if hasName(f, "install-path") { + sf, ok := f.(*cli.StringFlag) + if !ok { + t.Fatal("install-path is not a StringFlag") + } + if sf.Value != "/usr/bin/hostlink" { + t.Errorf("expected default '/usr/bin/hostlink', got %q", sf.Value) + } + return + } + } + t.Error("install-path flag not found") +} + +func TestCLIApp_UpgradeActionIsWired(t *testing.T) { + app := newApp() + + var upgradeCmd *cli.Command + for _, cmd := range app.Commands { + if cmd.Name == "upgrade" { + upgradeCmd = cmd + break + } + } + if upgradeCmd == nil { + t.Fatal("upgrade 
command not found") + } + + if upgradeCmd.Action == nil { + t.Error("upgrade action should be wired (not nil)") + } +} + +func TestCLIApp_UpgradeHasDryRunFlag(t *testing.T) { + app := newApp() + + var upgradeCmd *cli.Command + for _, cmd := range app.Commands { + if cmd.Name == "upgrade" { + upgradeCmd = cmd + break + } + } + if upgradeCmd == nil { + t.Fatal("upgrade command not found") + } + + var found bool + for _, f := range upgradeCmd.Flags { + if hasName(f, "dry-run") { + found = true + break + } + } + if !found { + t.Error("expected 'upgrade' to have --dry-run flag") + } +} + +func TestCLIApp_UpgradeHasUpdateIDFlag(t *testing.T) { + app := newApp() + + var upgradeCmd *cli.Command + for _, cmd := range app.Commands { + if cmd.Name == "upgrade" { + upgradeCmd = cmd + break + } + } + if upgradeCmd == nil { + t.Fatal("upgrade command not found") + } + + var found bool + for _, f := range upgradeCmd.Flags { + if hasName(f, "update-id") { + sf, ok := f.(*cli.StringFlag) + if !ok { + t.Fatal("update-id is not a StringFlag") + } + if !sf.Hidden { + t.Error("update-id flag should be hidden") + } + found = true + break + } + } + if !found { + t.Error("expected 'upgrade' to have --update-id flag") + } +} + +func TestCLIApp_UpgradeHasSourceVersionFlag(t *testing.T) { + app := newApp() + + var upgradeCmd *cli.Command + for _, cmd := range app.Commands { + if cmd.Name == "upgrade" { + upgradeCmd = cmd + break + } + } + if upgradeCmd == nil { + t.Fatal("upgrade command not found") + } + + var found bool + for _, f := range upgradeCmd.Flags { + if hasName(f, "source-version") { + sf, ok := f.(*cli.StringFlag) + if !ok { + t.Fatal("source-version is not a StringFlag") + } + if !sf.Hidden { + t.Error("source-version flag should be hidden") + } + found = true + break + } + } + if !found { + t.Error("expected 'upgrade' to have --source-version flag") + } +} + +func TestCLIApp_UpgradeHasBaseDirFlag(t *testing.T) { + app := newApp() + + var upgradeCmd *cli.Command + for _, cmd := 
range app.Commands { + if cmd.Name == "upgrade" { + upgradeCmd = cmd + break + } + } + if upgradeCmd == nil { + t.Fatal("upgrade command not found") + } + + var found bool + for _, f := range upgradeCmd.Flags { + if hasName(f, "base-dir") { + sf, ok := f.(*cli.StringFlag) + if !ok { + t.Fatal("base-dir is not a StringFlag") + } + if !sf.Hidden { + t.Error("base-dir flag should be hidden") + } + found = true + break + } + } + if !found { + t.Error("expected 'upgrade' to have --base-dir flag") + } +} + +func TestCLIApp_DefaultActionExists(t *testing.T) { + app := newApp() + + // The default action (no subcommand) should be set + if app.Action == nil { + t.Error("expected default action to be set (starts Echo server)") + } +} + +func TestCLIApp_Name(t *testing.T) { + app := newApp() + + if app.Name != "hostlink" { + t.Errorf("expected app name 'hostlink', got %q", app.Name) + } +} + +func hasName(f cli.Flag, name string) bool { + for _, n := range f.Names() { + if n == name { + return true + } + } + return false +} + +// Suppress unused import +var _ = context.Background diff --git a/cmd/updater/main.go b/cmd/updater/main.go deleted file mode 100644 index 8d2ac8e..0000000 --- a/cmd/updater/main.go +++ /dev/null @@ -1,102 +0,0 @@ -package main - -import ( - "context" - "flag" - "fmt" - "log" - "os" - "time" - - "hostlink/internal/update" -) - -const ( - // Default paths - DefaultBinaryPath = "/usr/bin/hostlink" - DefaultBaseDir = "/var/lib/hostlink/updates" - DefaultHealthURL = "http://localhost:8080/health" - DefaultServiceName = "hostlink" -) - -func main() { - // Parse command line flags - var ( - binaryPath = flag.String("binary", DefaultBinaryPath, "Path to the agent binary") - baseDir = flag.String("base-dir", DefaultBaseDir, "Base directory for update files") - healthURL = flag.String("health-url", DefaultHealthURL, "Health check URL") - targetVersion = flag.String("version", "", "Target version to verify after update (required)") - showVersion = flag.Bool("v", 
false, "Print version and exit") - ) - flag.Parse() - - if *showVersion { - printVersion() - os.Exit(0) - } - - if *targetVersion == "" { - // Try to read from state file - paths := update.NewPaths(*baseDir) - stateWriter := update.NewStateWriter(update.StateConfig{StatePath: paths.StateFile}) - state, err := stateWriter.Read() - if err != nil || state.TargetVersion == "" { - log.Fatal("target version is required: use -version flag or ensure state.json has target version") - } - *targetVersion = state.TargetVersion - } - - // Build paths - paths := update.NewPaths(*baseDir) - - // Create configuration - cfg := &UpdaterConfig{ - AgentBinaryPath: *binaryPath, - BackupDir: paths.BackupDir, - StagingDir: paths.StagingDir, - LockPath: paths.LockFile, - StatePath: paths.StateFile, - HealthURL: *healthURL, - TargetVersion: *targetVersion, - ServiceStopTimeout: 30 * time.Second, - ServiceStartTimeout: 30 * time.Second, - HealthCheckRetries: 5, - HealthCheckInterval: 5 * time.Second, - HealthInitialWait: 5 * time.Second, - LockRetries: 5, - LockRetryInterval: 1 * time.Second, - } - - // Create updater - updater := NewUpdater(cfg) - updater.onPhaseChange = func(phase Phase) { - log.Printf("Phase: %s", phase) - } - - // Create context with overall timeout (90s budget) - ctx, cancel := context.WithTimeout(context.Background(), 90*time.Second) - defer cancel() - - // Set up signal handler: SIGTERM/SIGINT simply cancel the context. - // Run() checks ctx.Err() between phases and does its own cleanup. - stopSignals := WatchSignals(cancel) - defer stopSignals() - - // Run the update - Run() owns all cleanup/rollback. 
- if err := updater.Run(ctx); err != nil { - log.Printf("Update failed: %v", err) - os.Exit(1) - } - - log.Println("Update completed successfully") -} - -// Version information (set via ldflags) -var ( - version = "dev" - commit = "unknown" -) - -func printVersion() { - fmt.Printf("hostlink-updater %s (%s)\n", version, commit) -} diff --git a/cmd/updater/signals_test.go b/cmd/updater/signals_test.go deleted file mode 100644 index b7d60ff..0000000 --- a/cmd/updater/signals_test.go +++ /dev/null @@ -1,37 +0,0 @@ -package main - -import ( - "context" - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -func TestWatchSignals_CancelsContextOnSignal(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - stop := WatchSignals(cancel) - defer stop() - - // Cancel directly (simulating signal effect) to verify wiring - cancel() - - assert.Error(t, ctx.Err()) - assert.ErrorIs(t, ctx.Err(), context.Canceled) -} - -func TestWatchSignals_StopPreventsCancel(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - stop := WatchSignals(cancel) - stop() // Stop before any signal - - // Give a moment for any goroutine to fire - time.Sleep(10 * time.Millisecond) - - // Context should not be cancelled - assert.NoError(t, ctx.Err()) -} diff --git a/cmd/updater/updater.go b/cmd/updater/updater.go deleted file mode 100644 index 9599aa1..0000000 --- a/cmd/updater/updater.go +++ /dev/null @@ -1,306 +0,0 @@ -package main - -import ( - "context" - "fmt" - "os" - "path/filepath" - "strings" - "time" - - "hostlink/app/services/updatedownload" - "hostlink/internal/update" -) - -// Phase represents the current phase of the update process. 
-type Phase string - -const ( - PhaseAcquireLock Phase = "acquire_lock" - PhaseStopping Phase = "stopping" - PhaseBackup Phase = "backup" - PhaseInstalling Phase = "installing" - PhaseStarting Phase = "starting" - PhaseVerifying Phase = "verifying" - PhaseCompleted Phase = "completed" - PhaseRollback Phase = "rollback" -) - -// Default configuration values -const ( - DefaultLockRetries = 5 - DefaultLockRetryInterval = 1 * time.Second - DefaultLockExpiration = 5 * time.Minute -) - -// UpdaterConfig holds the configuration for the Updater. -type UpdaterConfig struct { - AgentBinaryPath string // /usr/bin/hostlink - BackupDir string // /var/lib/hostlink/updates/backup/ - StagingDir string // /var/lib/hostlink/updates/staging/ - LockPath string // /var/lib/hostlink/updates/update.lock - StatePath string // /var/lib/hostlink/updates/state.json - HealthURL string // http://localhost:8080/health - TargetVersion string // Version to verify after update - ServiceStopTimeout time.Duration // 30s - ServiceStartTimeout time.Duration // 30s - HealthCheckRetries int // 5 - HealthCheckInterval time.Duration // 5s - HealthInitialWait time.Duration // 5s - LockRetries int // 5 - LockRetryInterval time.Duration // 1s - SleepFunc func(time.Duration) // For testing -} - -// ServiceController interface for mocking in tests -type ServiceController interface { - Stop(ctx context.Context) error - Start(ctx context.Context) error -} - -// Updater orchestrates the update process. -type Updater struct { - config *UpdaterConfig - lock *update.LockManager - state *update.StateWriter - serviceController ServiceController - healthChecker *update.HealthChecker - currentPhase Phase - onPhaseChange func(Phase) // For testing -} - -// NewUpdater creates a new Updater with the given configuration. 
-func NewUpdater(cfg *UpdaterConfig) *Updater { - // Apply defaults - if cfg.LockRetries == 0 { - cfg.LockRetries = DefaultLockRetries - } - if cfg.LockRetryInterval == 0 { - cfg.LockRetryInterval = DefaultLockRetryInterval - } - if cfg.ServiceStopTimeout == 0 { - cfg.ServiceStopTimeout = update.DefaultStopTimeout - } - if cfg.ServiceStartTimeout == 0 { - cfg.ServiceStartTimeout = update.DefaultStartTimeout - } - if cfg.HealthCheckRetries == 0 { - cfg.HealthCheckRetries = update.DefaultHealthRetries - } - if cfg.HealthCheckInterval == 0 { - cfg.HealthCheckInterval = update.DefaultHealthInterval - } - if cfg.HealthInitialWait == 0 { - cfg.HealthInitialWait = update.DefaultInitialWait - } - - return &Updater{ - config: cfg, - lock: update.NewLockManager(update.LockConfig{ - LockPath: cfg.LockPath, - }), - state: update.NewStateWriter(update.StateConfig{ - StatePath: cfg.StatePath, - }), - serviceController: update.NewServiceController(update.ServiceConfig{ - ServiceName: "hostlink", - StopTimeout: cfg.ServiceStopTimeout, - StartTimeout: cfg.ServiceStartTimeout, - }), - healthChecker: update.NewHealthChecker(update.HealthConfig{ - URL: cfg.HealthURL, - TargetVersion: cfg.TargetVersion, - MaxRetries: cfg.HealthCheckRetries, - RetryInterval: cfg.HealthCheckInterval, - InitialWait: cfg.HealthInitialWait, - SleepFunc: cfg.SleepFunc, - }), - } -} - -// setPhase updates the current phase and calls the callback if set. -func (u *Updater) setPhase(phase Phase) { - u.currentPhase = phase - if u.onPhaseChange != nil { - u.onPhaseChange(phase) - } -} - -// Run executes the full update process: -// lock → stop → backup → install → start → verify → cleanup → unlock -// -// Run owns all cleanup. If ctx is cancelled (e.g. by a signal), Run aborts -// between phases and ensures the service is left running. 
-func (u *Updater) Run(ctx context.Context) error { - // Clean up any leftover temp files first - u.cleanupTempFiles() - - // Phase 1: Acquire lock - u.setPhase(PhaseAcquireLock) - if err := u.lock.TryLockWithRetry(DefaultLockExpiration, u.config.LockRetries, u.config.LockRetryInterval); err != nil { - return fmt.Errorf("failed to acquire lock: %w", err) - } - defer u.lock.Unlock() - - // serviceStopped tracks whether we've stopped the service. - // If true, any abort path must restart it. - serviceStopped := false - - // abort restarts the service (if stopped) using a background context - // since the original ctx may be cancelled. - abort := func(reason error) error { - if serviceStopped { - u.serviceController.Start(context.Background()) - } - return reason - } - - // Check for cancellation before stopping - if ctx.Err() != nil { - return abort(ctx.Err()) - } - - // Phase 2: Stop service - u.setPhase(PhaseStopping) - if err := u.serviceController.Stop(ctx); err != nil { - if ctx.Err() != nil { - return abort(ctx.Err()) - } - return fmt.Errorf("failed to stop service: %w", err) - } - serviceStopped = true - - // Check for cancellation after stop - if ctx.Err() != nil { - return abort(ctx.Err()) - } - - // Phase 3: Backup current binary - u.setPhase(PhaseBackup) - if err := update.BackupBinary(u.config.AgentBinaryPath, u.config.BackupDir); err != nil { - return abort(fmt.Errorf("failed to backup binary: %w", err)) - } - - // Check for cancellation after backup - if ctx.Err() != nil { - return abort(ctx.Err()) - } - - // Phase 4: Install new binary - u.setPhase(PhaseInstalling) - tarballPath := filepath.Join(u.config.StagingDir, updatedownload.AgentTarballName) - if err := update.InstallBinary(tarballPath, u.config.AgentBinaryPath); err != nil { - u.rollbackFrom(PhaseInstalling) - return fmt.Errorf("failed to install binary: %w", err) - } - - // Check for cancellation after install - rollback needed - if ctx.Err() != nil { - u.rollbackFrom(PhaseInstalling) - return 
ctx.Err() - } - - // Phase 5: Start service - // Use background context: even if cancelled, we must start the service - // since the binary is already installed. - u.setPhase(PhaseStarting) - if err := u.serviceController.Start(context.Background()); err != nil { - u.rollbackFrom(PhaseStarting) - return fmt.Errorf("failed to start service: %w", err) - } - serviceStopped = false - - // Check for cancellation after start - service is running, - // skip verification and exit cleanly. - if ctx.Err() != nil { - return ctx.Err() - } - - // Phase 6: Verify health - u.setPhase(PhaseVerifying) - if err := u.healthChecker.WaitForHealth(ctx); err != nil { - if ctx.Err() != nil { - // Cancelled during verification - service is running, just exit. - return ctx.Err() - } - // Health check failed (not due to cancellation) - rollback - u.rollbackFrom(PhaseVerifying) - return fmt.Errorf("health check failed: %w", err) - } - - // Phase 7: Success! - u.setPhase(PhaseCompleted) - - // Update state to completed - u.state.Write(update.StateData{ - State: update.StateCompleted, - TargetVersion: u.config.TargetVersion, - CompletedAt: timePtr(time.Now()), - }) - - return nil -} - -// Rollback restores the backup and starts the service. -// Uses a background context for all operations since this may be called -// after the original context is cancelled. -func (u *Updater) Rollback(ctx context.Context) error { - return u.rollbackFrom(PhaseVerifying) -} - -// rollbackFrom restores the backup and starts the service. -// Uses context.Background() for all operations since this is cleanup -// that must complete regardless of cancellation state. 
-func (u *Updater) rollbackFrom(failedPhase Phase) error { - u.setPhase(PhaseRollback) - - // Update state to rollback in progress - u.state.Write(update.StateData{ - State: update.StateRollback, - TargetVersion: u.config.TargetVersion, - }) - - // Stop the service first (best-effort) - it may still be running the bad binary - u.serviceController.Stop(context.Background()) - - // Restore backup - if err := update.RestoreBackup(u.config.BackupDir, u.config.AgentBinaryPath); err != nil { - return fmt.Errorf("failed to restore backup: %w", err) - } - - // Start service with old binary (use background context - must complete) - if err := u.serviceController.Start(context.Background()); err != nil { - return fmt.Errorf("failed to start service after rollback: %w", err) - } - - // Update state to rolled back - u.state.Write(update.StateData{ - State: update.StateRolledBack, - TargetVersion: u.config.TargetVersion, - CompletedAt: timePtr(time.Now()), - }) - - // Clean up temp files - u.cleanupTempFiles() - - return nil -} - -// cleanupTempFiles removes any leftover hostlink.tmp.* files. 
-func (u *Updater) cleanupTempFiles() { - dir := filepath.Dir(u.config.AgentBinaryPath) - entries, err := os.ReadDir(dir) - if err != nil { - return - } - - for _, entry := range entries { - if strings.HasPrefix(entry.Name(), "hostlink.tmp.") { - os.Remove(filepath.Join(dir, entry.Name())) - } - } -} - -func timePtr(t time.Time) *time.Time { - return &t -} diff --git a/cmd/updater/updater_test.go b/cmd/updater/updater_test.go deleted file mode 100644 index 6c8aafc..0000000 --- a/cmd/updater/updater_test.go +++ /dev/null @@ -1,644 +0,0 @@ -package main - -import ( - "archive/tar" - "compress/gzip" - "context" - "encoding/json" - "net/http" - "net/http/httptest" - "os" - "path/filepath" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "hostlink/internal/update" -) - -// Test helpers -func createTestBinary(t *testing.T, path string, content []byte) { - t.Helper() - dir := filepath.Dir(path) - require.NoError(t, os.MkdirAll(dir, 0755)) - require.NoError(t, os.WriteFile(path, content, 0755)) -} - -func createTestTarball(t *testing.T, path string, binaryContent []byte) { - t.Helper() - dir := filepath.Dir(path) - require.NoError(t, os.MkdirAll(dir, 0755)) - - // Create a tar.gz file using archive/tar and compress/gzip - f, err := os.Create(path) - require.NoError(t, err) - defer f.Close() - - gw := newGzipWriter(f) - defer gw.Close() - - tw := newTarWriter(gw) - defer tw.Close() - - require.NoError(t, tw.WriteHeader(&tarHeader{ - Name: "hostlink", - Mode: 0755, - Size: int64(len(binaryContent)), - })) - _, err = tw.Write(binaryContent) - require.NoError(t, err) -} - -func TestUpdater_Run_HappyPath(t *testing.T) { - tmpDir := t.TempDir() - - // Setup paths - binaryPath := filepath.Join(tmpDir, "usr", "bin", "hostlink") - backupDir := filepath.Join(tmpDir, "backup") - stagingDir := filepath.Join(tmpDir, "staging") - lockPath := filepath.Join(tmpDir, "update.lock") - statePath := filepath.Join(tmpDir, "state.json") - 
- // Create current binary - createTestBinary(t, binaryPath, []byte("old binary v1.0.0")) - - // Create staged tarball with new binary - tarballPath := filepath.Join(stagingDir, "hostlink.tar.gz") - createTestTarball(t, tarballPath, []byte("new binary v2.0.0")) - - // Mock health server - healthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(update.HealthResponse{Ok: true, Version: "v2.0.0"}) - })) - defer healthServer.Close() - - // Create updater - u := NewUpdater(&UpdaterConfig{ - AgentBinaryPath: binaryPath, - BackupDir: backupDir, - StagingDir: stagingDir, - LockPath: lockPath, - StatePath: statePath, - HealthURL: healthServer.URL, - TargetVersion: "v2.0.0", - ServiceStopTimeout: 100 * time.Millisecond, - ServiceStartTimeout: 100 * time.Millisecond, - HealthCheckRetries: 1, - HealthCheckInterval: 10 * time.Millisecond, - HealthInitialWait: 1 * time.Millisecond, - SleepFunc: func(d time.Duration) {}, // No-op for tests - }) - - // Mock service controller (no real systemctl) - u.serviceController = &mockServiceController{} - - err := u.Run(context.Background()) - require.NoError(t, err) - - // Verify new binary is installed - content, err := os.ReadFile(binaryPath) - require.NoError(t, err) - assert.Equal(t, []byte("new binary v2.0.0"), content) - - // Verify backup exists - backupContent, err := os.ReadFile(filepath.Join(backupDir, "hostlink")) - require.NoError(t, err) - assert.Equal(t, []byte("old binary v1.0.0"), backupContent) -} - -func TestUpdater_Run_RollbackOnHealthCheckFailure(t *testing.T) { - tmpDir := t.TempDir() - - // Setup paths - binaryPath := filepath.Join(tmpDir, "usr", "bin", "hostlink") - backupDir := filepath.Join(tmpDir, "backup") - stagingDir := filepath.Join(tmpDir, "staging") - lockPath := filepath.Join(tmpDir, "update.lock") - statePath := filepath.Join(tmpDir, "state.json") - - // Create current binary - 
oldContent := []byte("old binary v1.0.0") - createTestBinary(t, binaryPath, oldContent) - - // Create staged tarball with new binary - tarballPath := filepath.Join(stagingDir, "hostlink.tar.gz") - createTestTarball(t, tarballPath, []byte("new binary v2.0.0")) - - // Mock health server that always returns unhealthy - healthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(update.HealthResponse{Ok: false, Version: "v2.0.0"}) - })) - defer healthServer.Close() - - u := NewUpdater(&UpdaterConfig{ - AgentBinaryPath: binaryPath, - BackupDir: backupDir, - StagingDir: stagingDir, - LockPath: lockPath, - StatePath: statePath, - HealthURL: healthServer.URL, - TargetVersion: "v2.0.0", - ServiceStopTimeout: 100 * time.Millisecond, - ServiceStartTimeout: 100 * time.Millisecond, - HealthCheckRetries: 1, - HealthCheckInterval: 10 * time.Millisecond, - HealthInitialWait: 1 * time.Millisecond, - SleepFunc: func(d time.Duration) {}, - }) - - u.serviceController = &mockServiceController{} - - err := u.Run(context.Background()) - assert.Error(t, err) - - // Verify rollback occurred - old binary should be restored - content, err := os.ReadFile(binaryPath) - require.NoError(t, err) - assert.Equal(t, oldContent, content) -} - -func TestUpdater_Run_LockAcquisitionFailure(t *testing.T) { - tmpDir := t.TempDir() - - lockPath := filepath.Join(tmpDir, "update.lock") - - // Acquire lock first with another lock manager to simulate contention - otherLock := update.NewLockManager(update.LockConfig{LockPath: lockPath}) - require.NoError(t, otherLock.TryLock(1*time.Hour)) - defer otherLock.Unlock() - - u := NewUpdater(&UpdaterConfig{ - AgentBinaryPath: filepath.Join(tmpDir, "hostlink"), - BackupDir: filepath.Join(tmpDir, "backup"), - StagingDir: filepath.Join(tmpDir, "staging"), - LockPath: lockPath, - StatePath: filepath.Join(tmpDir, "state.json"), - HealthURL: 
"http://localhost:8080/health", - TargetVersion: "v2.0.0", - LockRetries: 1, - LockRetryInterval: 10 * time.Millisecond, - }) - - err := u.Run(context.Background()) - assert.Error(t, err) - assert.ErrorIs(t, err, update.ErrLockAcquireFailed) -} - -func TestUpdater_Run_CleansUpTempFiles(t *testing.T) { - tmpDir := t.TempDir() - - binaryPath := filepath.Join(tmpDir, "usr", "bin", "hostlink") - backupDir := filepath.Join(tmpDir, "backup") - stagingDir := filepath.Join(tmpDir, "staging") - lockPath := filepath.Join(tmpDir, "update.lock") - statePath := filepath.Join(tmpDir, "state.json") - - // Create current binary - createTestBinary(t, binaryPath, []byte("old binary")) - - // Create leftover temp files - tempFile1 := filepath.Join(tmpDir, "usr", "bin", "hostlink.tmp.abc123") - tempFile2 := filepath.Join(tmpDir, "usr", "bin", "hostlink.tmp.def456") - require.NoError(t, os.WriteFile(tempFile1, []byte("temp"), 0755)) - require.NoError(t, os.WriteFile(tempFile2, []byte("temp"), 0755)) - - // Create staged tarball - tarballPath := filepath.Join(stagingDir, "hostlink.tar.gz") - createTestTarball(t, tarballPath, []byte("new binary v2.0.0")) - - // Mock health server - healthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(update.HealthResponse{Ok: true, Version: "v2.0.0"}) - })) - defer healthServer.Close() - - u := NewUpdater(&UpdaterConfig{ - AgentBinaryPath: binaryPath, - BackupDir: backupDir, - StagingDir: stagingDir, - LockPath: lockPath, - StatePath: statePath, - HealthURL: healthServer.URL, - TargetVersion: "v2.0.0", - ServiceStopTimeout: 100 * time.Millisecond, - ServiceStartTimeout: 100 * time.Millisecond, - HealthCheckRetries: 1, - HealthCheckInterval: 10 * time.Millisecond, - HealthInitialWait: 1 * time.Millisecond, - SleepFunc: func(d time.Duration) {}, - }) - - u.serviceController = &mockServiceController{} - - err := 
u.Run(context.Background()) - require.NoError(t, err) - - // Verify temp files were cleaned up - entries, err := os.ReadDir(filepath.Dir(binaryPath)) - require.NoError(t, err) - for _, entry := range entries { - assert.NotContains(t, entry.Name(), ".tmp.", "temp files should be cleaned up") - } -} - -func TestUpdater_Rollback_RestoresAndStartsService(t *testing.T) { - tmpDir := t.TempDir() - - binaryPath := filepath.Join(tmpDir, "hostlink") - backupDir := filepath.Join(tmpDir, "backup") - statePath := filepath.Join(tmpDir, "state.json") - - // Create backup - backupContent := []byte("backup binary v1.0.0") - require.NoError(t, os.MkdirAll(backupDir, 0755)) - require.NoError(t, os.WriteFile(filepath.Join(backupDir, "hostlink"), backupContent, 0755)) - - // Create "broken" current binary - require.NoError(t, os.WriteFile(binaryPath, []byte("broken"), 0755)) - - // Track service call order - var callOrder []string - mockSvc := &mockServiceController{ - onStop: func() { callOrder = append(callOrder, "stop") }, - onStart: func() { callOrder = append(callOrder, "start") }, - } - - u := NewUpdater(&UpdaterConfig{ - AgentBinaryPath: binaryPath, - BackupDir: backupDir, - StagingDir: filepath.Join(tmpDir, "staging"), - LockPath: filepath.Join(tmpDir, "update.lock"), - StatePath: statePath, - HealthURL: "http://localhost:8080/health", - TargetVersion: "v2.0.0", - ServiceStopTimeout: 100 * time.Millisecond, - ServiceStartTimeout: 100 * time.Millisecond, - }) - u.serviceController = mockSvc - - err := u.Rollback(context.Background()) - require.NoError(t, err) - - // Verify binary was restored - content, err := os.ReadFile(binaryPath) - require.NoError(t, err) - assert.Equal(t, backupContent, content) - - // Verify service was stopped then started (in that order) - assert.True(t, mockSvc.stopCalled) - assert.True(t, mockSvc.startCalled) - assert.Equal(t, []string{"stop", "start"}, callOrder) -} - -func TestUpdater_Rollback_UpdatesStateToRolledBack(t *testing.T) { - tmpDir := 
t.TempDir() - - binaryPath := filepath.Join(tmpDir, "hostlink") - backupDir := filepath.Join(tmpDir, "backup") - statePath := filepath.Join(tmpDir, "state.json") - - // Create backup - require.NoError(t, os.MkdirAll(backupDir, 0755)) - require.NoError(t, os.WriteFile(filepath.Join(backupDir, "hostlink"), []byte("backup"), 0755)) - - // Create current binary - require.NoError(t, os.WriteFile(binaryPath, []byte("current"), 0755)) - - u := NewUpdater(&UpdaterConfig{ - AgentBinaryPath: binaryPath, - BackupDir: backupDir, - StagingDir: filepath.Join(tmpDir, "staging"), - LockPath: filepath.Join(tmpDir, "update.lock"), - StatePath: statePath, - HealthURL: "http://localhost:8080/health", - TargetVersion: "v2.0.0", - ServiceStopTimeout: 100 * time.Millisecond, - ServiceStartTimeout: 100 * time.Millisecond, - }) - u.serviceController = &mockServiceController{} - - err := u.Rollback(context.Background()) - require.NoError(t, err) - - // Verify state was updated - stateWriter := update.NewStateWriter(update.StateConfig{StatePath: statePath}) - state, err := stateWriter.Read() - require.NoError(t, err) - assert.Equal(t, update.StateRolledBack, state.State) -} - -func TestUpdater_UpdatePhases(t *testing.T) { - tmpDir := t.TempDir() - - binaryPath := filepath.Join(tmpDir, "hostlink") - backupDir := filepath.Join(tmpDir, "backup") - stagingDir := filepath.Join(tmpDir, "staging") - statePath := filepath.Join(tmpDir, "state.json") - - // Create current binary - createTestBinary(t, binaryPath, []byte("old binary")) - - // Create staged tarball - tarballPath := filepath.Join(stagingDir, "hostlink.tar.gz") - createTestTarball(t, tarballPath, []byte("new binary v2.0.0")) - - // Track phase transitions - var phases []string - - healthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(update.HealthResponse{Ok: true, Version: "v2.0.0"}) - })) - defer 
healthServer.Close() - - mockSvc := &mockServiceController{ - onStop: func() { phases = append(phases, "stop") }, - onStart: func() { phases = append(phases, "start") }, - } - - u := NewUpdater(&UpdaterConfig{ - AgentBinaryPath: binaryPath, - BackupDir: backupDir, - StagingDir: stagingDir, - LockPath: filepath.Join(tmpDir, "update.lock"), - StatePath: statePath, - HealthURL: healthServer.URL, - TargetVersion: "v2.0.0", - ServiceStopTimeout: 100 * time.Millisecond, - ServiceStartTimeout: 100 * time.Millisecond, - HealthCheckRetries: 1, - HealthCheckInterval: 10 * time.Millisecond, - HealthInitialWait: 1 * time.Millisecond, - SleepFunc: func(d time.Duration) {}, - }) - u.serviceController = mockSvc - u.onPhaseChange = func(phase Phase) { - phases = append(phases, string(phase)) - } - - err := u.Run(context.Background()) - require.NoError(t, err) - - // Verify phases executed in order - expectedPhases := []string{ - string(PhaseAcquireLock), - string(PhaseStopping), - "stop", - string(PhaseBackup), - string(PhaseInstalling), - string(PhaseStarting), - "start", - string(PhaseVerifying), - string(PhaseCompleted), - } - assert.Equal(t, expectedPhases, phases) -} - -func TestUpdater_Run_CancelledBeforeStop(t *testing.T) { - tmpDir := t.TempDir() - - binaryPath := filepath.Join(tmpDir, "hostlink") - createTestBinary(t, binaryPath, []byte("binary")) - - mockSvc := &mockServiceController{} - - u := NewUpdater(&UpdaterConfig{ - AgentBinaryPath: binaryPath, - BackupDir: filepath.Join(tmpDir, "backup"), - StagingDir: filepath.Join(tmpDir, "staging"), - LockPath: filepath.Join(tmpDir, "update.lock"), - StatePath: filepath.Join(tmpDir, "state.json"), - HealthURL: "http://localhost:8080/health", - TargetVersion: "v2.0.0", - ServiceStopTimeout: 100 * time.Millisecond, - ServiceStartTimeout: 100 * time.Millisecond, - HealthInitialWait: 1 * time.Millisecond, - SleepFunc: func(d time.Duration) {}, - }) - u.serviceController = mockSvc - - // Cancel context before calling Run - ctx, 
cancel := context.WithCancel(context.Background()) - cancel() - - err := u.Run(ctx) - - assert.ErrorIs(t, err, context.Canceled) - assert.False(t, mockSvc.stopCalled, "service should not have been stopped") - assert.False(t, mockSvc.startCalled, "service should not have been started") -} - -func TestUpdater_Run_CancelledAfterStop(t *testing.T) { - tmpDir := t.TempDir() - - binaryPath := filepath.Join(tmpDir, "hostlink") - createTestBinary(t, binaryPath, []byte("binary")) - stagingDir := filepath.Join(tmpDir, "staging") - createTestTarball(t, filepath.Join(stagingDir, "hostlink.tar.gz"), []byte("new")) - - ctx, cancel := context.WithCancel(context.Background()) - - mockSvc := &mockServiceController{ - onStop: func() { - cancel() // Cancel after stop completes - }, - } - - u := NewUpdater(&UpdaterConfig{ - AgentBinaryPath: binaryPath, - BackupDir: filepath.Join(tmpDir, "backup"), - StagingDir: stagingDir, - LockPath: filepath.Join(tmpDir, "update.lock"), - StatePath: filepath.Join(tmpDir, "state.json"), - HealthURL: "http://localhost:8080/health", - TargetVersion: "v2.0.0", - ServiceStopTimeout: 100 * time.Millisecond, - ServiceStartTimeout: 100 * time.Millisecond, - HealthInitialWait: 1 * time.Millisecond, - SleepFunc: func(d time.Duration) {}, - }) - u.serviceController = mockSvc - - err := u.Run(ctx) - - assert.ErrorIs(t, err, context.Canceled) - assert.True(t, mockSvc.stopCalled) - assert.True(t, mockSvc.startCalled, "service must be restarted after being stopped") -} - -func TestUpdater_Run_CancelledAfterInstall(t *testing.T) { - tmpDir := t.TempDir() - - binaryPath := filepath.Join(tmpDir, "hostlink") - oldContent := []byte("old binary v1.0.0") - createTestBinary(t, binaryPath, oldContent) - stagingDir := filepath.Join(tmpDir, "staging") - createTestTarball(t, filepath.Join(stagingDir, "hostlink.tar.gz"), []byte("new binary")) - - ctx, cancel := context.WithCancel(context.Background()) - - u := NewUpdater(&UpdaterConfig{ - AgentBinaryPath: binaryPath, - 
BackupDir: filepath.Join(tmpDir, "backup"), - StagingDir: stagingDir, - LockPath: filepath.Join(tmpDir, "update.lock"), - StatePath: filepath.Join(tmpDir, "state.json"), - HealthURL: "http://localhost:8080/health", - TargetVersion: "v2.0.0", - ServiceStopTimeout: 100 * time.Millisecond, - ServiceStartTimeout: 100 * time.Millisecond, - HealthInitialWait: 1 * time.Millisecond, - SleepFunc: func(d time.Duration) {}, - }) - - // Cancel right after install phase begins - u.serviceController = &mockServiceController{} - u.onPhaseChange = func(phase Phase) { - if phase == PhaseInstalling { - // Let install complete, cancel right after - // We use a goroutine to cancel after a tiny delay - } - if phase == PhaseStarting { - // Cancel before start completes its inner work - cancel() - } - } - - err := u.Run(ctx) - - // Context was cancelled after install, during start. - // Start uses context.Background() so it should still succeed. - // But since ctx is cancelled after start returns, verification is skipped. 
- assert.ErrorIs(t, err, context.Canceled) - - // Service should have been started (start uses Background ctx) - svc := u.serviceController.(*mockServiceController) - assert.True(t, svc.startCalled) -} - -func TestUpdater_Run_CancelledDuringVerification(t *testing.T) { - tmpDir := t.TempDir() - - binaryPath := filepath.Join(tmpDir, "hostlink") - createTestBinary(t, binaryPath, []byte("old binary")) - stagingDir := filepath.Join(tmpDir, "staging") - createTestTarball(t, filepath.Join(stagingDir, "hostlink.tar.gz"), []byte("new binary")) - - ctx, cancel := context.WithCancel(context.Background()) - - // Health server that blocks until context is cancelled - healthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Wait for context cancellation to simulate slow health check - <-ctx.Done() - w.WriteHeader(http.StatusServiceUnavailable) - })) - defer healthServer.Close() - - u := NewUpdater(&UpdaterConfig{ - AgentBinaryPath: binaryPath, - BackupDir: filepath.Join(tmpDir, "backup"), - StagingDir: stagingDir, - LockPath: filepath.Join(tmpDir, "update.lock"), - StatePath: filepath.Join(tmpDir, "state.json"), - HealthURL: healthServer.URL, - TargetVersion: "v2.0.0", - ServiceStopTimeout: 100 * time.Millisecond, - ServiceStartTimeout: 100 * time.Millisecond, - HealthCheckRetries: 1, - HealthCheckInterval: 10 * time.Millisecond, - HealthInitialWait: 1 * time.Millisecond, - SleepFunc: func(d time.Duration) {}, - }) - u.serviceController = &mockServiceController{} - - // Cancel during verification - go func() { - time.Sleep(50 * time.Millisecond) - cancel() - }() - - err := u.Run(ctx) - - // Should return context.Canceled, NOT trigger rollback - assert.ErrorIs(t, err, context.Canceled) - - // Binary should NOT be rolled back (service is running with new binary) - content, err := os.ReadFile(binaryPath) - require.NoError(t, err) - assert.Equal(t, []byte("new binary"), content) -} - -func TestUpdater_Run_NoDoubleUnlock(t *testing.T) { 
- tmpDir := t.TempDir() - - binaryPath := filepath.Join(tmpDir, "hostlink") - createTestBinary(t, binaryPath, []byte("binary")) - - // Pre-cancelled context - ctx, cancel := context.WithCancel(context.Background()) - cancel() - - u := NewUpdater(&UpdaterConfig{ - AgentBinaryPath: binaryPath, - BackupDir: filepath.Join(tmpDir, "backup"), - StagingDir: filepath.Join(tmpDir, "staging"), - LockPath: filepath.Join(tmpDir, "update.lock"), - StatePath: filepath.Join(tmpDir, "state.json"), - HealthURL: "http://localhost:8080/health", - TargetVersion: "v2.0.0", - ServiceStopTimeout: 100 * time.Millisecond, - ServiceStartTimeout: 100 * time.Millisecond, - HealthInitialWait: 1 * time.Millisecond, - SleepFunc: func(d time.Duration) {}, - }) - u.serviceController = &mockServiceController{} - - // This should not panic - only Run() unlocks via defer - err := u.Run(ctx) - assert.ErrorIs(t, err, context.Canceled) - - // Lock file should not exist (properly unlocked once) - _, statErr := os.Stat(filepath.Join(tmpDir, "update.lock")) - assert.True(t, os.IsNotExist(statErr), "lock should be released") -} - -// Mock service controller for testing -type mockServiceController struct { - stopCalled bool - startCalled bool - stopErr error - startErr error - onStop func() - onStart func() -} - -func (m *mockServiceController) Stop(ctx context.Context) error { - m.stopCalled = true - if m.onStop != nil { - m.onStop() - } - return m.stopErr -} - -func (m *mockServiceController) Start(ctx context.Context) error { - m.startCalled = true - if m.onStart != nil { - m.onStart() - } - return m.startErr -} - -func newGzipWriter(w *os.File) *gzip.Writer { - return gzip.NewWriter(w) -} - -func newTarWriter(gw *gzip.Writer) *tar.Writer { - return tar.NewWriter(gw) -} - -type tarHeader = tar.Header diff --git a/cmd/upgrade/dryrun.go b/cmd/upgrade/dryrun.go new file mode 100644 index 0000000..8cca5cc --- /dev/null +++ b/cmd/upgrade/dryrun.go @@ -0,0 +1,166 @@ +package upgrade + +import ( + "context" + 
"fmt" + "os" + "path/filepath" + + "golang.org/x/sys/unix" +) + +// CheckName identifies a dry-run precondition check. +type CheckName string + +const ( + CheckLock CheckName = "lock_acquirable" + CheckBinary CheckName = "binary_writable" + CheckBackupDir CheckName = "backup_dir_writable" + CheckSelf CheckName = "self_executable" + CheckService CheckName = "service_exists" + CheckDiskSpace CheckName = "disk_space" +) + +// CheckResult holds the outcome of a single dry-run check. +type CheckResult struct { + Name CheckName + Passed bool + Detail string // Human-readable detail (error message if failed) +} + +// DryRun validates all upgrade preconditions without modifying state. +// It returns one CheckResult per check, in a fixed order; a failed check is +// reported through its Passed and Detail fields, never as a returned error. +func (u *Upgrader) DryRun(ctx context.Context) []CheckResult { + results := make([]CheckResult, 0, 6) + + results = append(results, u.checkLock()) + results = append(results, u.checkBinaryWritable()) + results = append(results, u.checkBackupDir()) + results = append(results, u.checkSelfExecutable()) + results = append(results, u.checkServiceExists(ctx)) + results = append(results, u.checkDiskSpace()) + + return results +} + +// checkLock verifies the upgrade lock can be acquired (then immediately releases it). +func (u *Upgrader) checkLock() CheckResult { + err := u.lock.TryLock(DefaultLockExpiration) + if err != nil { + return CheckResult{Name: CheckLock, Passed: false, Detail: err.Error()} + } + u.lock.Unlock() + return CheckResult{Name: CheckLock, Passed: true, Detail: "lock is available"} +} + +// checkBinaryWritable verifies the install path is an existing non-directory whose parent directory is writable. 
+func (u *Upgrader) checkBinaryWritable() CheckResult { + info, err := os.Stat(u.config.InstallPath) + if err != nil { + return CheckResult{Name: CheckBinary, Passed: false, Detail: fmt.Sprintf("cannot stat: %v", err)} + } + if info.IsDir() { + return CheckResult{Name: CheckBinary, Passed: false, Detail: "path is a directory"} + } + + // Check write access to the directory (needed for atomic rename) + dir := filepath.Dir(u.config.InstallPath) + if err := unix.Access(dir, unix.W_OK); err != nil { + return CheckResult{Name: CheckBinary, Passed: false, Detail: fmt.Sprintf("directory not writable: %v", err)} + } + + return CheckResult{Name: CheckBinary, Passed: true, Detail: "binary exists and directory is writable"} +} + +// checkBackupDir verifies the backup directory is writable or can be created. +func (u *Upgrader) checkBackupDir() CheckResult { + info, err := os.Stat(u.config.BackupDir) + + // Directory doesn't exist — verify parent is writable (could create it) + if os.IsNotExist(err) { + parent := filepath.Dir(u.config.BackupDir) + if err := unix.Access(parent, unix.W_OK); err != nil { + return CheckResult{Name: CheckBackupDir, Passed: false, Detail: fmt.Sprintf("parent not writable: %v", err)} + } + return CheckResult{Name: CheckBackupDir, Passed: true, Detail: "backup directory can be created"} + } + + // Stat failed for another reason + if err != nil { + return CheckResult{Name: CheckBackupDir, Passed: false, Detail: fmt.Sprintf("cannot stat: %v", err)} + } + + // Path exists but is not a directory + if !info.IsDir() { + return CheckResult{Name: CheckBackupDir, Passed: false, Detail: "path is not a directory"} + } + + // Directory exists — verify writable + if err := unix.Access(u.config.BackupDir, unix.W_OK); err != nil { + return CheckResult{Name: CheckBackupDir, Passed: false, Detail: fmt.Sprintf("not writable: %v", err)} + } + + return CheckResult{Name: CheckBackupDir, Passed: true, Detail: "backup directory is writable"} +} + +// checkSelfExecutable 
verifies the staged binary (self) exists and is executable. +func (u *Upgrader) checkSelfExecutable() CheckResult { + info, err := os.Stat(u.config.SelfPath) + if err != nil { + return CheckResult{Name: CheckSelf, Passed: false, Detail: fmt.Sprintf("cannot stat: %v", err)} + } + if info.IsDir() { + return CheckResult{Name: CheckSelf, Passed: false, Detail: "path is a directory"} + } + + if err := unix.Access(u.config.SelfPath, unix.X_OK); err != nil { + return CheckResult{Name: CheckSelf, Passed: false, Detail: fmt.Sprintf("not executable: %v", err)} + } + + return CheckResult{Name: CheckSelf, Passed: true, Detail: "staged binary is executable"} +} + +// checkServiceExists verifies the systemd service unit is loaded. +func (u *Upgrader) checkServiceExists(ctx context.Context) CheckResult { + exists, err := u.serviceController.Exists(ctx) + if err != nil { + return CheckResult{Name: CheckService, Passed: false, Detail: fmt.Sprintf("check failed: %v", err)} + } + if !exists { + return CheckResult{Name: CheckService, Passed: false, Detail: "service unit not found"} + } + return CheckResult{Name: CheckService, Passed: true, Detail: "service unit is loaded"} +} + +// checkDiskSpace verifies there is enough space for the backup. 
+func (u *Upgrader) checkDiskSpace() CheckResult { + info, err := os.Stat(u.config.InstallPath) + if err != nil { + return CheckResult{Name: CheckDiskSpace, Passed: false, Detail: fmt.Sprintf("cannot stat binary: %v", err)} + } + binarySize := info.Size() + + // Check available space in backup directory (or its parent if it doesn't exist yet) + statDir := u.config.BackupDir + if _, err := os.Stat(statDir); os.IsNotExist(err) { + statDir = filepath.Dir(statDir) + } + var stat unix.Statfs_t + if err := unix.Statfs(statDir, &stat); err != nil { + return CheckResult{Name: CheckDiskSpace, Passed: false, Detail: fmt.Sprintf("cannot check disk space: %v", err)} + } + + available := int64(stat.Bavail) * int64(stat.Bsize) + // Need at least 2x binary size (backup + new binary during atomic rename) + required := binarySize * 2 + if available < required { + return CheckResult{ + Name: CheckDiskSpace, + Passed: false, + Detail: fmt.Sprintf("insufficient space: need %d bytes, have %d", required, available), + } + } + + return CheckResult{Name: CheckDiskSpace, Passed: true, Detail: fmt.Sprintf("sufficient space available (%d bytes free)", available)} +} diff --git a/cmd/upgrade/dryrun_test.go b/cmd/upgrade/dryrun_test.go new file mode 100644 index 0000000..6b9d7cf --- /dev/null +++ b/cmd/upgrade/dryrun_test.go @@ -0,0 +1,277 @@ +package upgrade + +import ( + "context" + "errors" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "hostlink/internal/update" +) + +func newDryRunUpgrader(t *testing.T, tmpDir string) (*Upgrader, *mockServiceController) { + t.Helper() + + installPath := filepath.Join(tmpDir, "usr", "bin", "hostlink") + selfPath := filepath.Join(tmpDir, "staging", "hostlink") + backupDir := filepath.Join(tmpDir, "backup") + lockPath := filepath.Join(tmpDir, "update.lock") + + // Create the install binary + createTestBinary(t, installPath, []byte("current binary")) + // Create the self 
binary + createTestBinary(t, selfPath, []byte("staged binary")) + + mockSvc := &mockServiceController{existsVal: true} + + u, err := NewUpgrader(&Config{ + InstallPath: installPath, + SelfPath: selfPath, + BackupDir: backupDir, + LockPath: lockPath, + StatePath: filepath.Join(tmpDir, "state.json"), + HealthURL: "http://localhost:8080/health", + TargetVersion: "v2.0.0", + LockRetries: 1, + LockRetryInterval: 10 * time.Millisecond, + SleepFunc: func(_ context.Context, _ time.Duration) error { return nil }, + }) + require.NoError(t, err) + u.serviceController = mockSvc + + return u, mockSvc +} + +func findCheck(results []CheckResult, name CheckName) *CheckResult { + for i := range results { + if results[i].Name == name { + return &results[i] + } + } + return nil +} + +func TestDryRun_AllPassWhenPreConditionsMet(t *testing.T) { + tmpDir := t.TempDir() + u, _ := newDryRunUpgrader(t, tmpDir) + + results := u.DryRun(context.Background()) + + require.Len(t, results, 6) + for _, r := range results { + assert.True(t, r.Passed, "check %s should pass: %s", r.Name, r.Detail) + } +} + +func TestDryRun_LockCheck_FailsWhenLocked(t *testing.T) { + tmpDir := t.TempDir() + u, _ := newDryRunUpgrader(t, tmpDir) + + // Hold the lock + lockPath := filepath.Join(tmpDir, "update.lock") + otherLock := update.NewLockManager(update.LockConfig{LockPath: lockPath}) + require.NoError(t, otherLock.TryLock(1*time.Hour)) + defer otherLock.Unlock() + + results := u.DryRun(context.Background()) + + check := findCheck(results, CheckLock) + require.NotNil(t, check) + assert.False(t, check.Passed) +} + +func TestDryRun_LockCheck_ReleasesAfterCheck(t *testing.T) { + tmpDir := t.TempDir() + u, _ := newDryRunUpgrader(t, tmpDir) + + results := u.DryRun(context.Background()) + + // Lock should pass + check := findCheck(results, CheckLock) + require.NotNil(t, check) + assert.True(t, check.Passed) + + // Lock should be released - try again + lockPath := filepath.Join(tmpDir, "update.lock") + otherLock := 
update.NewLockManager(update.LockConfig{LockPath: lockPath}) + err := otherLock.TryLock(1 * time.Hour) + assert.NoError(t, err, "lock should have been released after dry-run check") + otherLock.Unlock() +} + +func TestDryRun_BinaryCheck_FailsWhenMissing(t *testing.T) { + tmpDir := t.TempDir() + u, _ := newDryRunUpgrader(t, tmpDir) + + // Remove the install binary + os.Remove(u.config.InstallPath) + + results := u.DryRun(context.Background()) + + check := findCheck(results, CheckBinary) + require.NotNil(t, check) + assert.False(t, check.Passed) + assert.Contains(t, check.Detail, "cannot stat") +} + +func TestDryRun_BinaryCheck_FailsWhenDirectory(t *testing.T) { + tmpDir := t.TempDir() + u, _ := newDryRunUpgrader(t, tmpDir) + + // Replace binary with a directory + os.Remove(u.config.InstallPath) + os.MkdirAll(u.config.InstallPath, 0755) + + results := u.DryRun(context.Background()) + + check := findCheck(results, CheckBinary) + require.NotNil(t, check) + assert.False(t, check.Passed) + assert.Contains(t, check.Detail, "directory") +} + +func TestDryRun_BackupDirCheck_PassesWhenParentWritable(t *testing.T) { + tmpDir := t.TempDir() + u, _ := newDryRunUpgrader(t, tmpDir) + + // Backup dir doesn't exist yet + _, err := os.Stat(u.config.BackupDir) + require.True(t, os.IsNotExist(err)) + + results := u.DryRun(context.Background()) + + check := findCheck(results, CheckBackupDir) + require.NotNil(t, check) + assert.True(t, check.Passed) + assert.Contains(t, check.Detail, "can be created") + + // Directory should NOT have been created (no side effects) + _, err = os.Stat(u.config.BackupDir) + assert.True(t, os.IsNotExist(err), "dry-run should not create backup directory") +} + +func TestDryRun_SelfCheck_FailsWhenMissing(t *testing.T) { + tmpDir := t.TempDir() + u, _ := newDryRunUpgrader(t, tmpDir) + + os.Remove(u.config.SelfPath) + + results := u.DryRun(context.Background()) + + check := findCheck(results, CheckSelf) + require.NotNil(t, check) + assert.False(t, check.Passed) 
+ assert.Contains(t, check.Detail, "cannot stat") +} + +func TestDryRun_SelfCheck_FailsWhenNotExecutable(t *testing.T) { + tmpDir := t.TempDir() + u, _ := newDryRunUpgrader(t, tmpDir) + + // Remove execute permission + os.Chmod(u.config.SelfPath, 0644) + + results := u.DryRun(context.Background()) + + check := findCheck(results, CheckSelf) + require.NotNil(t, check) + assert.False(t, check.Passed) + assert.Contains(t, check.Detail, "not executable") +} + +func TestDryRun_ServiceCheck_FailsWhenNotFound(t *testing.T) { + tmpDir := t.TempDir() + u, mockSvc := newDryRunUpgrader(t, tmpDir) + + mockSvc.existsVal = false + + results := u.DryRun(context.Background()) + + check := findCheck(results, CheckService) + require.NotNil(t, check) + assert.False(t, check.Passed) + assert.Contains(t, check.Detail, "not found") +} + +func TestDryRun_ServiceCheck_FailsOnError(t *testing.T) { + tmpDir := t.TempDir() + u, mockSvc := newDryRunUpgrader(t, tmpDir) + + mockSvc.existsVal = false + mockSvc.existsErr = errors.New("systemctl not available") + + results := u.DryRun(context.Background()) + + check := findCheck(results, CheckService) + require.NotNil(t, check) + assert.False(t, check.Passed) + assert.Contains(t, check.Detail, "check failed") +} + +func TestDryRun_DiskSpaceCheck_Passes(t *testing.T) { + tmpDir := t.TempDir() + u, _ := newDryRunUpgrader(t, tmpDir) + + // Ensure backup dir exists for statfs + os.MkdirAll(u.config.BackupDir, 0755) + + results := u.DryRun(context.Background()) + + check := findCheck(results, CheckDiskSpace) + require.NotNil(t, check) + assert.True(t, check.Passed) + assert.Contains(t, check.Detail, "sufficient space") +} + +func TestDryRun_DoesNotModifyBinary(t *testing.T) { + tmpDir := t.TempDir() + u, _ := newDryRunUpgrader(t, tmpDir) + + originalContent, err := os.ReadFile(u.config.InstallPath) + require.NoError(t, err) + + u.DryRun(context.Background()) + + afterContent, err := os.ReadFile(u.config.InstallPath) + require.NoError(t, err) + 
assert.Equal(t, originalContent, afterContent, "dry-run should not modify the binary") +} + +func TestDryRun_DoesNotStopService(t *testing.T) { + tmpDir := t.TempDir() + u, mockSvc := newDryRunUpgrader(t, tmpDir) + + u.DryRun(context.Background()) + + assert.False(t, mockSvc.stopCalled, "dry-run should not stop the service") + assert.False(t, mockSvc.startCalled, "dry-run should not start the service") +} + +func TestDryRun_ReturnsAllChecksEvenOnFailure(t *testing.T) { + tmpDir := t.TempDir() + u, mockSvc := newDryRunUpgrader(t, tmpDir) + + // Make multiple checks fail + os.Remove(u.config.InstallPath) + os.Remove(u.config.SelfPath) + mockSvc.existsVal = false + + results := u.DryRun(context.Background()) + + // Should still get all 6 checks + require.Len(t, results, 6) + + // Verify failed ones + assert.False(t, findCheck(results, CheckBinary).Passed) + assert.False(t, findCheck(results, CheckSelf).Passed) + assert.False(t, findCheck(results, CheckService).Passed) + + // Lock and backup dir should still pass + assert.True(t, findCheck(results, CheckLock).Passed) + assert.True(t, findCheck(results, CheckBackupDir).Passed) +} diff --git a/cmd/upgrade/logger.go b/cmd/upgrade/logger.go new file mode 100644 index 0000000..897f900 --- /dev/null +++ b/cmd/upgrade/logger.go @@ -0,0 +1,50 @@ +package upgrade + +import ( + "io" + "log/slog" + "os" + "path/filepath" +) + +const ( + // DefaultLogPath is the default path for upgrade logs. + DefaultLogPath = "/var/log/hostlink/upgrade.log" +) + +// NewLogger creates a structured JSON logger that writes to both the given +// log file (append mode) and stderr. Returns the logger and a cleanup function +// that closes the log file. 
+func NewLogger(logPath string) (*slog.Logger, func(), error) { + // Ensure log directory exists + dir := filepath.Dir(logPath) + if err := os.MkdirAll(dir, 0755); err != nil { + return nil, nil, err + } + + // Open log file in append mode + f, err := os.OpenFile(logPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) + if err != nil { + return nil, nil, err + } + + // Multi-writer: file + stderr + w := io.MultiWriter(f, os.Stderr) + + handler := slog.NewJSONHandler(w, &slog.HandlerOptions{ + Level: slog.LevelInfo, + }) + + logger := slog.New(handler) + + cleanup := func() { + f.Close() + } + + return logger, cleanup, nil +} + +// discardLogger returns a logger that writes nothing (for testing). +func discardLogger() *slog.Logger { + return slog.New(slog.NewJSONHandler(io.Discard, nil)) +} diff --git a/cmd/upgrade/logger_test.go b/cmd/upgrade/logger_test.go new file mode 100644 index 0000000..fba097b --- /dev/null +++ b/cmd/upgrade/logger_test.go @@ -0,0 +1,267 @@ +package upgrade + +import ( + "bytes" + "context" + "encoding/json" + "log/slog" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "hostlink/internal/update" +) + +func TestNewLogger_CreatesLogFile(t *testing.T) { + tmpDir := t.TempDir() + logPath := filepath.Join(tmpDir, "sub", "upgrade.log") + + logger, cleanup, err := NewLogger(logPath) + require.NoError(t, err) + defer cleanup() + + logger.Info("test message") + + // Verify file was created + _, err = os.Stat(logPath) + assert.NoError(t, err) +} + +func TestNewLogger_AppendsToExistingFile(t *testing.T) { + tmpDir := t.TempDir() + logPath := filepath.Join(tmpDir, "upgrade.log") + + // Write existing content + require.NoError(t, os.WriteFile(logPath, []byte("existing\n"), 0644)) + + logger, cleanup, err := NewLogger(logPath) + require.NoError(t, err) + defer cleanup() + + logger.Info("new message") + + content, err := 
os.ReadFile(logPath) + require.NoError(t, err) + assert.True(t, strings.HasPrefix(string(content), "existing\n"), "should preserve existing content") + assert.Contains(t, string(content), "new message") +} + +func TestNewLogger_WritesJSON(t *testing.T) { + tmpDir := t.TempDir() + logPath := filepath.Join(tmpDir, "upgrade.log") + + logger, cleanup, err := NewLogger(logPath) + require.NoError(t, err) + defer cleanup() + + logger.Info("structured log", "key", "value") + + content, err := os.ReadFile(logPath) + require.NoError(t, err) + + // Parse as JSON + var entry map[string]interface{} + err = json.Unmarshal(content, &entry) + require.NoError(t, err, "log entry should be valid JSON") + assert.Equal(t, "structured log", entry["msg"]) + assert.Equal(t, "value", entry["key"]) + assert.Equal(t, "INFO", entry["level"]) +} + +func TestNewLogger_ErrorsOnInvalidPath(t *testing.T) { + // Path under a file (not a directory) + tmpDir := t.TempDir() + filePath := filepath.Join(tmpDir, "file") + require.NoError(t, os.WriteFile(filePath, []byte("x"), 0644)) + + logPath := filepath.Join(filePath, "sub", "upgrade.log") + + _, _, err := NewLogger(logPath) + assert.Error(t, err) +} + +func testLogger(t *testing.T) (*slog.Logger, *bytes.Buffer) { + t.Helper() + buf := &bytes.Buffer{} + handler := slog.NewJSONHandler(buf, &slog.HandlerOptions{Level: slog.LevelInfo}) + return slog.New(handler), buf +} + +func TestUpgrader_Run_LogsPhaseTransitions(t *testing.T) { + tmpDir := t.TempDir() + + installPath := filepath.Join(tmpDir, "hostlink") + selfPath := filepath.Join(tmpDir, "staging", "hostlink") + + createTestBinary(t, installPath, []byte("old binary")) + createTestBinary(t, selfPath, []byte("new binary v2.0.0")) + + healthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(update.HealthResponse{Ok: true, Version: "v2.0.0"}) + })) + defer healthServer.Close() + + 
logger, buf := testLogger(t) + + u, err := NewUpgrader(&Config{ + InstallPath: installPath, + SelfPath: selfPath, + BackupDir: filepath.Join(tmpDir, "backup"), + LockPath: filepath.Join(tmpDir, "update.lock"), + StatePath: filepath.Join(tmpDir, "state.json"), + HealthURL: healthServer.URL, + TargetVersion: "v2.0.0", + ServiceStopTimeout: 100 * time.Millisecond, + ServiceStartTimeout: 100 * time.Millisecond, + HealthCheckRetries: 1, + HealthCheckInterval: 10 * time.Millisecond, + HealthInitialWait: 1 * time.Millisecond, + Logger: logger, + SleepFunc: func(_ context.Context, _ time.Duration) error { return nil }, + }) + require.NoError(t, err) + u.serviceController = &mockServiceController{existsVal: true} + + err = u.Run(context.Background()) + require.NoError(t, err) + + output := buf.String() + assert.Contains(t, output, "upgrade started") + assert.Contains(t, output, "lock acquired") + assert.Contains(t, output, "backup created") + assert.Contains(t, output, "service stopped") + assert.Contains(t, output, "binary installed") + assert.Contains(t, output, "service started") + assert.Contains(t, output, "health check passed") + assert.Contains(t, output, "upgrade completed successfully") +} + +func TestUpgrader_Run_LogsRollbackOnHealthFailure(t *testing.T) { + tmpDir := t.TempDir() + + installPath := filepath.Join(tmpDir, "hostlink") + selfPath := filepath.Join(tmpDir, "staging", "hostlink") + + createTestBinary(t, installPath, []byte("old binary")) + createTestBinary(t, selfPath, []byte("new binary v2.0.0")) + + healthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(update.HealthResponse{Ok: false, Version: "v2.0.0"}) + })) + defer healthServer.Close() + + logger, buf := testLogger(t) + + u, err := NewUpgrader(&Config{ + InstallPath: installPath, + SelfPath: selfPath, + BackupDir: filepath.Join(tmpDir, "backup"), + LockPath: 
filepath.Join(tmpDir, "update.lock"), + StatePath: filepath.Join(tmpDir, "state.json"), + HealthURL: healthServer.URL, + TargetVersion: "v2.0.0", + ServiceStopTimeout: 100 * time.Millisecond, + ServiceStartTimeout: 100 * time.Millisecond, + HealthCheckRetries: 1, + HealthCheckInterval: 10 * time.Millisecond, + HealthInitialWait: 1 * time.Millisecond, + Logger: logger, + SleepFunc: func(_ context.Context, _ time.Duration) error { return nil }, + }) + require.NoError(t, err) + u.serviceController = &mockServiceController{existsVal: true} + + err = u.Run(context.Background()) + assert.Error(t, err) + + output := buf.String() + assert.Contains(t, output, "health check failed, rolling back") + assert.Contains(t, output, "rollback initiated") + assert.Contains(t, output, "backup restored") + assert.Contains(t, output, "service restarted after rollback") + assert.Contains(t, output, "rollback completed") +} + +func TestUpgrader_Run_LogsLockFailure(t *testing.T) { + tmpDir := t.TempDir() + + lockPath := filepath.Join(tmpDir, "update.lock") + otherLock := update.NewLockManager(update.LockConfig{LockPath: lockPath}) + require.NoError(t, otherLock.TryLock(1*time.Hour)) + defer otherLock.Unlock() + + logger, buf := testLogger(t) + + u, err := NewUpgrader(&Config{ + InstallPath: filepath.Join(tmpDir, "hostlink"), + SelfPath: filepath.Join(tmpDir, "staging", "hostlink"), + BackupDir: filepath.Join(tmpDir, "backup"), + LockPath: lockPath, + StatePath: filepath.Join(tmpDir, "state.json"), + HealthURL: "http://localhost:8080/health", + TargetVersion: "v2.0.0", + LockRetries: 1, + LockRetryInterval: 10 * time.Millisecond, + Logger: logger, + SleepFunc: func(_ context.Context, _ time.Duration) error { return nil }, + }) + require.NoError(t, err) + + err = u.Run(context.Background()) + assert.Error(t, err) + + output := buf.String() + assert.Contains(t, output, "failed to acquire lock") +} + +func TestUpgrader_Run_LogsCancellation(t *testing.T) { + tmpDir := t.TempDir() + + 
installPath := filepath.Join(tmpDir, "hostlink") + selfPath := filepath.Join(tmpDir, "staging", "hostlink") + createTestBinary(t, installPath, []byte("binary")) + createTestBinary(t, selfPath, []byte("new")) + + logger, buf := testLogger(t) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Already cancelled + + u, err := NewUpgrader(&Config{ + InstallPath: installPath, + SelfPath: selfPath, + BackupDir: filepath.Join(tmpDir, "backup"), + LockPath: filepath.Join(tmpDir, "update.lock"), + StatePath: filepath.Join(tmpDir, "state.json"), + HealthURL: "http://localhost:8080/health", + TargetVersion: "v2.0.0", + ServiceStopTimeout: 100 * time.Millisecond, + ServiceStartTimeout: 100 * time.Millisecond, + Logger: logger, + SleepFunc: func(_ context.Context, _ time.Duration) error { return nil }, + }) + require.NoError(t, err) + u.serviceController = &mockServiceController{existsVal: true} + + err = u.Run(ctx) + assert.ErrorIs(t, err, context.Canceled) + + output := buf.String() + assert.Contains(t, output, "cancelled") +} + +func TestDiscardLogger_DoesNotPanic(t *testing.T) { + logger := discardLogger() + // Should not panic + logger.Info("message", "key", "value") + logger.Error("error", "err", "something") +} diff --git a/cmd/updater/signals.go b/cmd/upgrade/signals.go similarity index 96% rename from cmd/updater/signals.go rename to cmd/upgrade/signals.go index a58a053..8690196 100644 --- a/cmd/updater/signals.go +++ b/cmd/upgrade/signals.go @@ -1,4 +1,4 @@ -package main +package upgrade import ( "context" diff --git a/cmd/upgrade/signals_test.go b/cmd/upgrade/signals_test.go new file mode 100644 index 0000000..4a43033 --- /dev/null +++ b/cmd/upgrade/signals_test.go @@ -0,0 +1,67 @@ +package upgrade + +import ( + "context" + "os" + "syscall" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestWatchSignals_CancelsContextOnSignal(t *testing.T) { + ctx, cancel := 
context.WithCancel(context.Background()) + defer cancel() + + stop := WatchSignals(cancel) + defer stop() + + // Send SIGINT to ourselves + proc, err := os.FindProcess(os.Getpid()) + require.NoError(t, err) + require.NoError(t, proc.Signal(syscall.SIGINT)) + + // Wait for cancellation + select { + case <-ctx.Done(): + assert.ErrorIs(t, ctx.Err(), context.Canceled) + case <-time.After(2 * time.Second): + t.Fatal("context was not cancelled within timeout") + } +} + +func TestWatchSignals_StopPreventsCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + stop := WatchSignals(cancel) + stop() // Stop before any signal + + // Give goroutine time to exit + time.Sleep(10 * time.Millisecond) + + // Context should not be cancelled + assert.NoError(t, ctx.Err()) +} + +func TestWatchSignals_MultipleCallsAreIndependent(t *testing.T) { + ctx1, cancel1 := context.WithCancel(context.Background()) + defer cancel1() + ctx2, cancel2 := context.WithCancel(context.Background()) + defer cancel2() + + stop1 := WatchSignals(cancel1) + stop2 := WatchSignals(cancel2) + defer stop2() + + // Stop only the first watcher + stop1() + + time.Sleep(10 * time.Millisecond) + + // Neither context should be cancelled + assert.NoError(t, ctx1.Err()) + assert.NoError(t, ctx2.Err()) +} diff --git a/cmd/upgrade/upgrade.go b/cmd/upgrade/upgrade.go new file mode 100644 index 0000000..8aad516 --- /dev/null +++ b/cmd/upgrade/upgrade.go @@ -0,0 +1,418 @@ +// Package upgrade implements the hostlink upgrade subcommand. +// It orchestrates the in-place upgrade of the hostlink binary: +// lock → backup → stop → install (self) → start → verify → cleanup. +package upgrade + +import ( + "context" + "errors" + "fmt" + "log/slog" + "os" + "path/filepath" + "strings" + "time" + + "hostlink/internal/update" +) + +// Phase represents the current phase of the upgrade process. 
+type Phase string + +const ( + PhaseAcquireLock Phase = "acquire_lock" + PhaseBackup Phase = "backup" + PhaseStopping Phase = "stopping" + PhaseInstalling Phase = "installing" + PhaseStarting Phase = "starting" + PhaseVerifying Phase = "verifying" + PhaseCompleted Phase = "completed" + PhaseRollback Phase = "rollback" +) + +// Default configuration values +const ( + DefaultLockRetries = 5 + DefaultLockRetryInterval = 1 * time.Second + DefaultLockExpiration = 5 * time.Minute +) + +// ServiceController interface for mocking in tests. +type ServiceController interface { + Stop(ctx context.Context) error + Start(ctx context.Context) error + Exists(ctx context.Context) (bool, error) +} + +// Config holds the configuration for the Upgrader. +type Config struct { + InstallPath string // Target path (e.g. /usr/bin/hostlink) + SelfPath string // Path to the staged binary (os.Executable()) + BackupDir string // Backup directory + LockPath string // Lock file path + StatePath string // State file path + HealthURL string // Health check URL + TargetVersion string // Version to verify after upgrade + UpdateID string // Unique ID for this update operation + SourceVersion string // Version being upgraded from + ServiceStopTimeout time.Duration // 30s + ServiceStartTimeout time.Duration // 30s + HealthCheckRetries int // 5 + HealthCheckInterval time.Duration // 5s + HealthInitialWait time.Duration // 5s + LockRetries int // 5 + LockRetryInterval time.Duration // 1s + Logger *slog.Logger // Structured logger (nil = discard) + RollbackStopRetries int // Retries for Stop during rollback (default: 3) + RollbackStopRetryInterval time.Duration // Interval between stop retries (default: 1s) + RollbackHealthCheckFunc func(ctx context.Context) error // Health check after rollback restart (nil = skip) + SleepFunc func(context.Context, time.Duration) error // For testing; context-aware sleep (health checker) + LockSleepFunc func(time.Duration) // For testing; simple sleep (lock manager) + 
InstallFunc func(srcPath, destPath string) error // Defaults to update.InstallSelf +} + +// Upgrader orchestrates the upgrade process. +type Upgrader struct { + config *Config + lock *update.LockManager + state update.StateWriterInterface + serviceController ServiceController + healthChecker *update.HealthChecker + logger *slog.Logger + currentPhase Phase + startedAt time.Time + onPhaseChange func(Phase) // For testing +} + +// NewUpgrader creates a new Upgrader with the given configuration. +// Returns an error if required configuration is missing or invalid. +func NewUpgrader(cfg *Config) (*Upgrader, error) { + if cfg.InstallPath == "" { + return nil, fmt.Errorf("install-path cannot be empty") + } + + if cfg.LockRetries == 0 { + cfg.LockRetries = DefaultLockRetries + } + if cfg.LockRetryInterval == 0 { + cfg.LockRetryInterval = DefaultLockRetryInterval + } + if cfg.ServiceStopTimeout == 0 { + cfg.ServiceStopTimeout = update.DefaultStopTimeout + } + if cfg.ServiceStartTimeout == 0 { + cfg.ServiceStartTimeout = update.DefaultStartTimeout + } + if cfg.HealthCheckRetries == 0 { + cfg.HealthCheckRetries = update.DefaultHealthRetries + } + if cfg.HealthCheckInterval == 0 { + cfg.HealthCheckInterval = update.DefaultHealthInterval + } + if cfg.HealthInitialWait == 0 { + cfg.HealthInitialWait = update.DefaultInitialWait + } + if cfg.RollbackStopRetries == 0 { + cfg.RollbackStopRetries = 3 + } + if cfg.RollbackStopRetryInterval == 0 { + cfg.RollbackStopRetryInterval = 1 * time.Second + } + if cfg.InstallFunc == nil { + cfg.InstallFunc = update.InstallSelf + } + + logger := cfg.Logger + if logger == nil { + logger = discardLogger() + } + + return &Upgrader{ + config: cfg, + logger: logger, + lock: update.NewLockManager(update.LockConfig{ + LockPath: cfg.LockPath, + SleepFunc: cfg.LockSleepFunc, + }), + state: update.NewStateWriter(update.StateConfig{ + StatePath: cfg.StatePath, + }), + serviceController: update.NewServiceController(update.ServiceConfig{ + ServiceName: 
"hostlink", + StopTimeout: cfg.ServiceStopTimeout, + StartTimeout: cfg.ServiceStartTimeout, + }), + healthChecker: update.NewHealthChecker(update.HealthConfig{ + URL: cfg.HealthURL, + TargetVersion: cfg.TargetVersion, + MaxRetries: cfg.HealthCheckRetries, + RetryInterval: cfg.HealthCheckInterval, + InitialWait: cfg.HealthInitialWait, + SleepFunc: cfg.SleepFunc, + }), + }, nil +} + +// setPhase updates the current phase and calls the callback if set. +func (u *Upgrader) setPhase(phase Phase) { + u.currentPhase = phase + if u.onPhaseChange != nil { + u.onPhaseChange(phase) + } +} + +// writeState writes state data to disk and logs a warning on error. +// State file is for observability only; write failures should not fail the upgrade. +func (u *Upgrader) writeState(data update.StateData) { + if err := u.state.Write(data); err != nil { + u.logger.Warn("failed to write state file", "error", err, "state", data.State) + } +} + +// Run executes the full upgrade process: +// lock → backup → stop → install (self) → start → verify → cleanup → unlock +// +// The key difference from the old updater: instead of extracting a binary from +// a tarball, it copies itself (the staged binary) to the install path. +func (u *Upgrader) Run(ctx context.Context) error { + u.startedAt = time.Now() + u.logger.Info("upgrade started", + "target_version", u.config.TargetVersion, + "install_path", u.config.InstallPath, + "self_path", u.config.SelfPath, + ) + + // Clean up any leftover temp files first + u.cleanupTempFiles() + + // Phase 1: Acquire lock + u.setPhase(PhaseAcquireLock) + if err := u.lock.TryLockWithRetry(DefaultLockExpiration, u.config.LockRetries, u.config.LockRetryInterval); err != nil { + u.logger.Error("failed to acquire lock", "error", err) + return fmt.Errorf("failed to acquire lock: %w", err) + } + defer u.lock.Unlock() + u.logger.Info("lock acquired") + + // serviceStopped tracks whether we've stopped the service. 
+ serviceStopped := false + + // abort restarts the service (if stopped) using a background context. + abort := func(reason error) error { + if serviceStopped { + u.serviceController.Start(context.Background()) + } + return reason + } + + // Check for cancellation before backup + if ctx.Err() != nil { + u.logger.Warn("cancelled before backup", "error", ctx.Err()) + return abort(ctx.Err()) + } + + // Phase 2: Backup current binary + u.setPhase(PhaseBackup) + if err := update.BackupBinary(u.config.InstallPath, u.config.BackupDir); err != nil { + u.logger.Error("failed to backup binary", "error", err) + return abort(fmt.Errorf("failed to backup binary: %w", err)) + } + u.logger.Info("backup created", "backup_dir", u.config.BackupDir) + + // Check for cancellation before stop + if ctx.Err() != nil { + u.logger.Warn("cancelled before stop", "error", ctx.Err()) + return abort(ctx.Err()) + } + + // Phase 3: Stop service + u.setPhase(PhaseStopping) + if err := u.serviceController.Stop(ctx); err != nil { + if ctx.Err() != nil { + u.logger.Warn("cancelled during stop", "error", ctx.Err()) + return abort(ctx.Err()) + } + u.logger.Error("failed to stop service", "error", err) + return fmt.Errorf("failed to stop service: %w", err) + } + serviceStopped = true + u.logger.Info("service stopped") + + // Check for cancellation after stop + if ctx.Err() != nil { + u.logger.Warn("cancelled after stop", "error", ctx.Err()) + return abort(ctx.Err()) + } + + // Phase 4: Install new binary (copy self to install path) + u.setPhase(PhaseInstalling) + if err := u.config.InstallFunc(u.config.SelfPath, u.config.InstallPath); err != nil { + u.logger.Error("failed to install binary, rolling back", "error", err) + installErr := fmt.Errorf("failed to install binary: %w", err) + if rollbackErr := u.rollbackFrom(PhaseInstalling); rollbackErr != nil { + return errors.Join(installErr, rollbackErr) + } + return installErr + } + u.logger.Info("binary installed", "install_path", u.config.InstallPath) + + 
// After install, the new binary is in place. Start it and exit. + if ctx.Err() != nil { + u.logger.Warn("cancelled after install, starting new service", "error", ctx.Err()) + u.serviceController.Start(context.Background()) + return ctx.Err() + } + + // Phase 5: Start service + // Use background context: even if cancelled, we must start the service. + u.setPhase(PhaseStarting) + if err := u.serviceController.Start(context.Background()); err != nil { + u.logger.Error("failed to start service, rolling back", "error", err) + startErr := fmt.Errorf("failed to start service: %w", err) + if rollbackErr := u.rollbackFrom(PhaseStarting); rollbackErr != nil { + return errors.Join(startErr, rollbackErr) + } + return startErr + } + serviceStopped = false + u.logger.Info("service started") + + // Check for cancellation after start - service is running, skip verification. + if ctx.Err() != nil { + u.logger.Warn("cancelled after start, skipping verification", "error", ctx.Err()) + return ctx.Err() + } + + // Phase 6: Verify health + u.setPhase(PhaseVerifying) + if err := u.healthChecker.WaitForHealth(ctx); err != nil { + if ctx.Err() != nil { + // Cancelled during verification - service is running, just exit. 
+ u.logger.Warn("cancelled during verification", "error", ctx.Err()) + return ctx.Err() + } + // Health check failed (not cancellation) - rollback + u.logger.Error("health check failed, rolling back", "error", err) + healthErr := fmt.Errorf("health check failed: %w", err) + if rollbackErr := u.rollbackFrom(PhaseVerifying); rollbackErr != nil { + return errors.Join(healthErr, rollbackErr) + } + return healthErr + } + u.logger.Info("health check passed") + + // Phase 7: Success + u.setPhase(PhaseCompleted) + u.writeState(update.StateData{ + State: update.StateCompleted, + UpdateID: u.config.UpdateID, + SourceVersion: u.config.SourceVersion, + TargetVersion: u.config.TargetVersion, + StartedAt: u.startedAt, + CompletedAt: timePtr(time.Now()), + }) + + u.logger.Info("upgrade completed successfully", "target_version", u.config.TargetVersion) + return nil +} + +// rollbackFrom restores the backup and starts the service. +func (u *Upgrader) rollbackFrom(failedPhase Phase) error { + u.setPhase(PhaseRollback) + u.logger.Warn("rollback initiated", "failed_phase", string(failedPhase)) + + u.writeState(update.StateData{ + State: update.StateRollback, + UpdateID: u.config.UpdateID, + SourceVersion: u.config.SourceVersion, + TargetVersion: u.config.TargetVersion, + StartedAt: u.startedAt, + }) + + // Stop the service with retries + var stopErr error + for i := 0; i < u.config.RollbackStopRetries; i++ { + if stopErr = u.serviceController.Stop(context.Background()); stopErr == nil { + break + } + u.logger.Warn("rollback stop attempt failed", "attempt", i+1, "error", stopErr) + if i < u.config.RollbackStopRetries-1 { + if u.config.LockSleepFunc != nil { + u.config.LockSleepFunc(u.config.RollbackStopRetryInterval) + } else { + time.Sleep(u.config.RollbackStopRetryInterval) + } + } + } + if stopErr != nil { + u.logger.Warn("all rollback stop attempts failed, proceeding with restore", "error", stopErr) + } + + // Restore backup + if err := update.RestoreBackup(u.config.BackupDir, 
u.config.InstallPath); err != nil { + u.logger.Error("failed to restore backup during rollback", "error", err) + return fmt.Errorf("failed to restore backup: %w", err) + } + u.logger.Info("backup restored", "install_path", u.config.InstallPath) + + // Start service with old binary + if err := u.serviceController.Start(context.Background()); err != nil { + u.logger.Error("failed to start service after rollback", "error", err) + return fmt.Errorf("failed to start service after rollback: %w", err) + } + u.logger.Info("service restarted after rollback") + + // Health check after rollback restart + if u.config.RollbackHealthCheckFunc != nil { + if err := u.config.RollbackHealthCheckFunc(context.Background()); err != nil { + u.logger.Error("health check failed after rollback restart", "error", err) + // Still write RolledBack state (binary was restored, service started) + u.writeState(update.StateData{ + State: update.StateRolledBack, + UpdateID: u.config.UpdateID, + SourceVersion: u.config.SourceVersion, + TargetVersion: u.config.TargetVersion, + StartedAt: u.startedAt, + CompletedAt: timePtr(time.Now()), + }) + return fmt.Errorf("rollback health check failed: %w", err) + } + u.logger.Info("health check passed after rollback") + } + + // Update state to rolled back + u.writeState(update.StateData{ + State: update.StateRolledBack, + UpdateID: u.config.UpdateID, + SourceVersion: u.config.SourceVersion, + TargetVersion: u.config.TargetVersion, + StartedAt: u.startedAt, + CompletedAt: timePtr(time.Now()), + }) + + u.logger.Info("rollback completed", "restored_version", u.config.SourceVersion) + + // Clean up temp files + u.cleanupTempFiles() + + return nil +} + +// cleanupTempFiles removes any leftover hostlink.tmp.* files. 
+func (u *Upgrader) cleanupTempFiles() { + dir := filepath.Dir(u.config.InstallPath) + entries, err := os.ReadDir(dir) + if err != nil { + return + } + + for _, entry := range entries { + if strings.HasPrefix(entry.Name(), "hostlink.tmp.") { + os.Remove(filepath.Join(dir, entry.Name())) + } + } +} + +func timePtr(t time.Time) *time.Time { + return &t +} diff --git a/cmd/upgrade/upgrade_test.go b/cmd/upgrade/upgrade_test.go new file mode 100644 index 0000000..f4d29ed --- /dev/null +++ b/cmd/upgrade/upgrade_test.go @@ -0,0 +1,1075 @@ +package upgrade + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "log/slog" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "hostlink/internal/update" +) + +// Test helpers +func createTestBinary(t *testing.T, path string, content []byte) { + t.Helper() + dir := filepath.Dir(path) + require.NoError(t, os.MkdirAll(dir, 0755)) + require.NoError(t, os.WriteFile(path, content, 0755)) +} + +func TestUpgrader_Run_HappyPath(t *testing.T) { + tmpDir := t.TempDir() + + // Setup paths + installPath := filepath.Join(tmpDir, "usr", "bin", "hostlink") + selfPath := filepath.Join(tmpDir, "staging", "hostlink") + backupDir := filepath.Join(tmpDir, "backup") + lockPath := filepath.Join(tmpDir, "update.lock") + statePath := filepath.Join(tmpDir, "state.json") + + // Create current binary at install path + createTestBinary(t, installPath, []byte("old binary v1.0.0")) + + // Create "self" binary (the staged new binary) + createTestBinary(t, selfPath, []byte("new binary v2.0.0")) + + // Mock health server + healthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(update.HealthResponse{Ok: true, Version: "v2.0.0"}) + })) + defer healthServer.Close() + + u, err := NewUpgrader(&Config{ + InstallPath: 
installPath, + SelfPath: selfPath, + BackupDir: backupDir, + LockPath: lockPath, + StatePath: statePath, + HealthURL: healthServer.URL, + TargetVersion: "v2.0.0", + ServiceStopTimeout: 100 * time.Millisecond, + ServiceStartTimeout: 100 * time.Millisecond, + HealthCheckRetries: 1, + HealthCheckInterval: 10 * time.Millisecond, + HealthInitialWait: 1 * time.Millisecond, + SleepFunc: func(_ context.Context, _ time.Duration) error { return nil }, + }) + require.NoError(t, err) + + // Mock service controller (no real systemctl) + u.serviceController = &mockServiceController{} + + err = u.Run(context.Background()) + require.NoError(t, err) + + // Verify new binary is installed (should be a copy of selfPath) + content, err := os.ReadFile(installPath) + require.NoError(t, err) + assert.Equal(t, []byte("new binary v2.0.0"), content) + + // Verify backup exists + backupContent, err := os.ReadFile(filepath.Join(backupDir, "hostlink")) + require.NoError(t, err) + assert.Equal(t, []byte("old binary v1.0.0"), backupContent) +} + +func TestUpgrader_Run_RollbackOnHealthCheckFailure(t *testing.T) { + tmpDir := t.TempDir() + + installPath := filepath.Join(tmpDir, "usr", "bin", "hostlink") + selfPath := filepath.Join(tmpDir, "staging", "hostlink") + backupDir := filepath.Join(tmpDir, "backup") + lockPath := filepath.Join(tmpDir, "update.lock") + statePath := filepath.Join(tmpDir, "state.json") + + oldContent := []byte("old binary v1.0.0") + createTestBinary(t, installPath, oldContent) + createTestBinary(t, selfPath, []byte("new binary v2.0.0")) + + // Mock health server that always returns unhealthy + healthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(update.HealthResponse{Ok: false, Version: "v2.0.0"}) + })) + defer healthServer.Close() + + u, err := NewUpgrader(&Config{ + InstallPath: installPath, + SelfPath: selfPath, + BackupDir: backupDir, + LockPath: 
lockPath, + StatePath: statePath, + HealthURL: healthServer.URL, + TargetVersion: "v2.0.0", + ServiceStopTimeout: 100 * time.Millisecond, + ServiceStartTimeout: 100 * time.Millisecond, + HealthCheckRetries: 1, + HealthCheckInterval: 10 * time.Millisecond, + HealthInitialWait: 1 * time.Millisecond, + SleepFunc: func(_ context.Context, _ time.Duration) error { return nil }, + }) + require.NoError(t, err) + u.serviceController = &mockServiceController{} + + err = u.Run(context.Background()) + assert.Error(t, err) + + // Verify rollback occurred - old binary should be restored + content, err := os.ReadFile(installPath) + require.NoError(t, err) + assert.Equal(t, oldContent, content) +} + +func TestUpgrader_Run_LockAcquisitionFailure(t *testing.T) { + tmpDir := t.TempDir() + + lockPath := filepath.Join(tmpDir, "update.lock") + + // Acquire lock first to simulate contention + otherLock := update.NewLockManager(update.LockConfig{LockPath: lockPath}) + require.NoError(t, otherLock.TryLock(1*time.Hour)) + defer otherLock.Unlock() + + u, err := NewUpgrader(&Config{ + InstallPath: filepath.Join(tmpDir, "hostlink"), + SelfPath: filepath.Join(tmpDir, "staging", "hostlink"), + BackupDir: filepath.Join(tmpDir, "backup"), + LockPath: lockPath, + StatePath: filepath.Join(tmpDir, "state.json"), + HealthURL: "http://localhost:8080/health", + TargetVersion: "v2.0.0", + LockRetries: 1, + LockRetryInterval: 10 * time.Millisecond, + SleepFunc: func(_ context.Context, _ time.Duration) error { return nil }, + }) + require.NoError(t, err) + + err = u.Run(context.Background()) + assert.Error(t, err) + assert.ErrorIs(t, err, update.ErrLockAcquireFailed) +} + +func TestUpgrader_Run_CleansUpTempFiles(t *testing.T) { + tmpDir := t.TempDir() + + installPath := filepath.Join(tmpDir, "usr", "bin", "hostlink") + selfPath := filepath.Join(tmpDir, "staging", "hostlink") + backupDir := filepath.Join(tmpDir, "backup") + lockPath := filepath.Join(tmpDir, "update.lock") + statePath := 
filepath.Join(tmpDir, "state.json") + + createTestBinary(t, installPath, []byte("old binary")) + createTestBinary(t, selfPath, []byte("new binary v2.0.0")) + + // Create leftover temp files + binDir := filepath.Dir(installPath) + require.NoError(t, os.WriteFile(filepath.Join(binDir, "hostlink.tmp.abc123"), []byte("temp"), 0755)) + require.NoError(t, os.WriteFile(filepath.Join(binDir, "hostlink.tmp.def456"), []byte("temp"), 0755)) + + healthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(update.HealthResponse{Ok: true, Version: "v2.0.0"}) + })) + defer healthServer.Close() + + u, err := NewUpgrader(&Config{ + InstallPath: installPath, + SelfPath: selfPath, + BackupDir: backupDir, + LockPath: lockPath, + StatePath: statePath, + HealthURL: healthServer.URL, + TargetVersion: "v2.0.0", + ServiceStopTimeout: 100 * time.Millisecond, + ServiceStartTimeout: 100 * time.Millisecond, + HealthCheckRetries: 1, + HealthCheckInterval: 10 * time.Millisecond, + HealthInitialWait: 1 * time.Millisecond, + SleepFunc: func(_ context.Context, _ time.Duration) error { return nil }, + }) + require.NoError(t, err) + u.serviceController = &mockServiceController{} + + err = u.Run(context.Background()) + require.NoError(t, err) + + // Verify temp files were cleaned up + entries, err := os.ReadDir(binDir) + require.NoError(t, err) + for _, entry := range entries { + assert.NotContains(t, entry.Name(), ".tmp.", "temp files should be cleaned up") + } +} + +func TestUpgrader_PhaseOrder(t *testing.T) { + tmpDir := t.TempDir() + + installPath := filepath.Join(tmpDir, "hostlink") + selfPath := filepath.Join(tmpDir, "staging", "hostlink") + backupDir := filepath.Join(tmpDir, "backup") + statePath := filepath.Join(tmpDir, "state.json") + + createTestBinary(t, installPath, []byte("old binary")) + createTestBinary(t, selfPath, []byte("new binary v2.0.0")) + + var phases []string + 
+ healthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(update.HealthResponse{Ok: true, Version: "v2.0.0"}) + })) + defer healthServer.Close() + + mockSvc := &mockServiceController{ + onStop: func() { phases = append(phases, "svc_stop") }, + onStart: func() { phases = append(phases, "svc_start") }, + } + + u, err := NewUpgrader(&Config{ + InstallPath: installPath, + SelfPath: selfPath, + BackupDir: backupDir, + LockPath: filepath.Join(tmpDir, "update.lock"), + StatePath: statePath, + HealthURL: healthServer.URL, + TargetVersion: "v2.0.0", + ServiceStopTimeout: 100 * time.Millisecond, + ServiceStartTimeout: 100 * time.Millisecond, + HealthCheckRetries: 1, + HealthCheckInterval: 10 * time.Millisecond, + HealthInitialWait: 1 * time.Millisecond, + SleepFunc: func(_ context.Context, _ time.Duration) error { return nil }, + }) + require.NoError(t, err) + u.serviceController = mockSvc + u.onPhaseChange = func(phase Phase) { + phases = append(phases, string(phase)) + } + + err = u.Run(context.Background()) + require.NoError(t, err) + + expectedPhases := []string{ + string(PhaseAcquireLock), + string(PhaseBackup), + string(PhaseStopping), + "svc_stop", + string(PhaseInstalling), + string(PhaseStarting), + "svc_start", + string(PhaseVerifying), + string(PhaseCompleted), + } + assert.Equal(t, expectedPhases, phases) +} + +func TestUpgrader_Run_CancelledBeforeStop(t *testing.T) { + tmpDir := t.TempDir() + + installPath := filepath.Join(tmpDir, "hostlink") + selfPath := filepath.Join(tmpDir, "staging", "hostlink") + createTestBinary(t, installPath, []byte("binary")) + createTestBinary(t, selfPath, []byte("new")) + + mockSvc := &mockServiceController{} + + u, err := NewUpgrader(&Config{ + InstallPath: installPath, + SelfPath: selfPath, + BackupDir: filepath.Join(tmpDir, "backup"), + LockPath: filepath.Join(tmpDir, "update.lock"), + StatePath: 
filepath.Join(tmpDir, "state.json"), + HealthURL: "http://localhost:8080/health", + TargetVersion: "v2.0.0", + ServiceStopTimeout: 100 * time.Millisecond, + ServiceStartTimeout: 100 * time.Millisecond, + HealthInitialWait: 1 * time.Millisecond, + SleepFunc: func(_ context.Context, _ time.Duration) error { return nil }, + }) + require.NoError(t, err) + u.serviceController = mockSvc + + // Cancel context before calling Run + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err = u.Run(ctx) + + assert.ErrorIs(t, err, context.Canceled) + assert.False(t, mockSvc.stopCalled, "service should not have been stopped") + assert.False(t, mockSvc.startCalled, "service should not have been started") +} + +func TestUpgrader_Run_CancelledAfterStop(t *testing.T) { + tmpDir := t.TempDir() + + installPath := filepath.Join(tmpDir, "hostlink") + selfPath := filepath.Join(tmpDir, "staging", "hostlink") + createTestBinary(t, installPath, []byte("binary")) + createTestBinary(t, selfPath, []byte("new")) + + ctx, cancel := context.WithCancel(context.Background()) + + mockSvc := &mockServiceController{ + onStop: func() { + cancel() // Cancel after stop completes + }, + } + + u, err := NewUpgrader(&Config{ + InstallPath: installPath, + SelfPath: selfPath, + BackupDir: filepath.Join(tmpDir, "backup"), + LockPath: filepath.Join(tmpDir, "update.lock"), + StatePath: filepath.Join(tmpDir, "state.json"), + HealthURL: "http://localhost:8080/health", + TargetVersion: "v2.0.0", + ServiceStopTimeout: 100 * time.Millisecond, + ServiceStartTimeout: 100 * time.Millisecond, + HealthInitialWait: 1 * time.Millisecond, + SleepFunc: func(_ context.Context, _ time.Duration) error { return nil }, + }) + require.NoError(t, err) + u.serviceController = mockSvc + + err = u.Run(ctx) + + assert.ErrorIs(t, err, context.Canceled) + assert.True(t, mockSvc.stopCalled) + assert.True(t, mockSvc.startCalled, "service must be restarted after being stopped") +} + +func 
TestUpgrader_Run_CancelledAfterInstall(t *testing.T) { + tmpDir := t.TempDir() + + installPath := filepath.Join(tmpDir, "hostlink") + selfPath := filepath.Join(tmpDir, "staging", "hostlink") + createTestBinary(t, installPath, []byte("old binary")) + createTestBinary(t, selfPath, []byte("new binary")) + + ctx, cancel := context.WithCancel(context.Background()) + + mockSvc := &mockServiceController{} + + u, err := NewUpgrader(&Config{ + InstallPath: installPath, + SelfPath: selfPath, + BackupDir: filepath.Join(tmpDir, "backup"), + LockPath: filepath.Join(tmpDir, "update.lock"), + StatePath: filepath.Join(tmpDir, "state.json"), + HealthURL: "http://localhost:8080/health", + TargetVersion: "v2.0.0", + ServiceStopTimeout: 100 * time.Millisecond, + ServiceStartTimeout: 100 * time.Millisecond, + HealthInitialWait: 1 * time.Millisecond, + SleepFunc: func(_ context.Context, _ time.Duration) error { return nil }, + InstallFunc: func(srcPath, destPath string) error { + // Do the real install, then cancel + err := update.InstallSelf(srcPath, destPath) + cancel() + return err + }, + }) + require.NoError(t, err) + u.serviceController = mockSvc + + err = u.Run(ctx) + + assert.ErrorIs(t, err, context.Canceled) + // Per spec: after install, start the new service (not rollback) + assert.True(t, mockSvc.startCalled, "new service must be started after install") + // Install path should contain the new binary (not rolled back) + content, readErr := os.ReadFile(installPath) + require.NoError(t, readErr) + assert.Equal(t, []byte("new binary"), content, "should not roll back after install") +} + +func TestUpgrader_Run_CancelledDuringVerification(t *testing.T) { + tmpDir := t.TempDir() + + installPath := filepath.Join(tmpDir, "hostlink") + selfPath := filepath.Join(tmpDir, "staging", "hostlink") + createTestBinary(t, installPath, []byte("old binary")) + createTestBinary(t, selfPath, []byte("new binary")) + + ctx, cancel := context.WithCancel(context.Background()) + + // Health server that 
blocks until context is cancelled + healthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + <-ctx.Done() + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer healthServer.Close() + + u, err := NewUpgrader(&Config{ + InstallPath: installPath, + SelfPath: selfPath, + BackupDir: filepath.Join(tmpDir, "backup"), + LockPath: filepath.Join(tmpDir, "update.lock"), + StatePath: filepath.Join(tmpDir, "state.json"), + HealthURL: healthServer.URL, + TargetVersion: "v2.0.0", + ServiceStopTimeout: 100 * time.Millisecond, + ServiceStartTimeout: 100 * time.Millisecond, + HealthCheckRetries: 1, + HealthCheckInterval: 10 * time.Millisecond, + HealthInitialWait: 1 * time.Millisecond, + SleepFunc: func(_ context.Context, _ time.Duration) error { return nil }, + }) + require.NoError(t, err) + u.serviceController = &mockServiceController{} + + // Cancel during verification + go func() { + time.Sleep(50 * time.Millisecond) + cancel() + }() + + err = u.Run(ctx) + + // Should return context.Canceled, NOT trigger rollback + assert.ErrorIs(t, err, context.Canceled) + + // Binary should NOT be rolled back (service is running with new binary) + content, err := os.ReadFile(installPath) + require.NoError(t, err) + assert.Equal(t, []byte("new binary"), content) +} + +func TestUpgrader_Rollback_RestoresAndStartsService(t *testing.T) { + tmpDir := t.TempDir() + + installPath := filepath.Join(tmpDir, "hostlink") + backupDir := filepath.Join(tmpDir, "backup") + statePath := filepath.Join(tmpDir, "state.json") + + // Create backup + backupContent := []byte("backup binary v1.0.0") + require.NoError(t, os.MkdirAll(backupDir, 0755)) + require.NoError(t, os.WriteFile(filepath.Join(backupDir, "hostlink"), backupContent, 0755)) + + // Create "broken" current binary + require.NoError(t, os.WriteFile(installPath, []byte("broken"), 0755)) + + var callOrder []string + mockSvc := &mockServiceController{ + onStop: func() { callOrder = append(callOrder, "stop") 
}, + onStart: func() { callOrder = append(callOrder, "start") }, + } + + u, err := NewUpgrader(&Config{ + InstallPath: installPath, + SelfPath: filepath.Join(tmpDir, "staging", "hostlink"), + BackupDir: backupDir, + LockPath: filepath.Join(tmpDir, "update.lock"), + StatePath: statePath, + HealthURL: "http://localhost:8080/health", + TargetVersion: "v2.0.0", + ServiceStopTimeout: 100 * time.Millisecond, + ServiceStartTimeout: 100 * time.Millisecond, + }) + require.NoError(t, err) + u.serviceController = mockSvc + + err = u.rollbackFrom(PhaseVerifying) + require.NoError(t, err) + + // Verify binary was restored + content, err := os.ReadFile(installPath) + require.NoError(t, err) + assert.Equal(t, backupContent, content) + + // Verify service was stopped then started + assert.Equal(t, []string{"stop", "start"}, callOrder) +} + +func TestUpgrader_Rollback_WritesRolledBackState(t *testing.T) { + tmpDir := t.TempDir() + + installPath := filepath.Join(tmpDir, "hostlink") + backupDir := filepath.Join(tmpDir, "backup") + statePath := filepath.Join(tmpDir, "state.json") + + require.NoError(t, os.MkdirAll(backupDir, 0755)) + require.NoError(t, os.WriteFile(filepath.Join(backupDir, "hostlink"), []byte("backup"), 0755)) + require.NoError(t, os.WriteFile(installPath, []byte("current"), 0755)) + + u, err := NewUpgrader(&Config{ + InstallPath: installPath, + SelfPath: filepath.Join(tmpDir, "staging", "hostlink"), + BackupDir: backupDir, + LockPath: filepath.Join(tmpDir, "update.lock"), + StatePath: statePath, + HealthURL: "http://localhost:8080/health", + TargetVersion: "v2.0.0", + UpdateID: "rollback-update-id", + SourceVersion: "v1.0.0", + ServiceStopTimeout: 100 * time.Millisecond, + ServiceStartTimeout: 100 * time.Millisecond, + }) + require.NoError(t, err) + u.serviceController = &mockServiceController{} + u.startedAt = time.Now() + + err = u.rollbackFrom(PhaseVerifying) + require.NoError(t, err) + + // Verify state was updated with full context + stateWriter := 
update.NewStateWriter(update.StateConfig{StatePath: statePath}) + state, err := stateWriter.Read() + require.NoError(t, err) + assert.Equal(t, update.StateRolledBack, state.State) + assert.Equal(t, "rollback-update-id", state.UpdateID) + assert.Equal(t, "v1.0.0", state.SourceVersion) + assert.Equal(t, "v2.0.0", state.TargetVersion) + assert.False(t, state.StartedAt.IsZero(), "StartedAt should be set") + require.NotNil(t, state.CompletedAt, "CompletedAt should be set") +} + +func TestUpgrader_Run_WritesCompletedState(t *testing.T) { + tmpDir := t.TempDir() + + installPath := filepath.Join(tmpDir, "hostlink") + selfPath := filepath.Join(tmpDir, "staging", "hostlink") + statePath := filepath.Join(tmpDir, "state.json") + + createTestBinary(t, installPath, []byte("old binary")) + createTestBinary(t, selfPath, []byte("new binary v2.0.0")) + + healthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(update.HealthResponse{Ok: true, Version: "v2.0.0"}) + })) + defer healthServer.Close() + + u, err := NewUpgrader(&Config{ + InstallPath: installPath, + SelfPath: selfPath, + BackupDir: filepath.Join(tmpDir, "backup"), + LockPath: filepath.Join(tmpDir, "update.lock"), + StatePath: statePath, + HealthURL: healthServer.URL, + TargetVersion: "v2.0.0", + UpdateID: "test-update-id", + SourceVersion: "v1.0.0", + ServiceStopTimeout: 100 * time.Millisecond, + ServiceStartTimeout: 100 * time.Millisecond, + HealthCheckRetries: 1, + HealthCheckInterval: 10 * time.Millisecond, + HealthInitialWait: 1 * time.Millisecond, + SleepFunc: func(_ context.Context, _ time.Duration) error { return nil }, + }) + require.NoError(t, err) + u.serviceController = &mockServiceController{} + + beforeRun := time.Now() + err = u.Run(context.Background()) + require.NoError(t, err) + + // Verify state was written with full context + stateWriter := 
update.NewStateWriter(update.StateConfig{StatePath: statePath}) + state, err := stateWriter.Read() + require.NoError(t, err) + assert.Equal(t, update.StateCompleted, state.State) + assert.Equal(t, "v2.0.0", state.TargetVersion) + assert.Equal(t, "test-update-id", state.UpdateID) + assert.Equal(t, "v1.0.0", state.SourceVersion) + assert.False(t, state.StartedAt.IsZero(), "StartedAt should be set") + assert.True(t, !state.StartedAt.Before(beforeRun), "StartedAt should be >= test start") + require.NotNil(t, state.CompletedAt, "CompletedAt should be set") +} + +func TestUpgrader_Rollback_RetriesStop(t *testing.T) { + tmpDir := t.TempDir() + + installPath := filepath.Join(tmpDir, "hostlink") + backupDir := filepath.Join(tmpDir, "backup") + statePath := filepath.Join(tmpDir, "state.json") + + // Create backup and current binary + backupContent := []byte("backup binary v1.0.0") + require.NoError(t, os.MkdirAll(backupDir, 0755)) + require.NoError(t, os.WriteFile(filepath.Join(backupDir, "hostlink"), backupContent, 0755)) + require.NoError(t, os.WriteFile(installPath, []byte("broken"), 0755)) + + // Stop fails first 2 times, succeeds on 3rd + mockSvc := &mockServiceController{ + stopErrs: []error{ + fmt.Errorf("stop failed attempt 1"), + fmt.Errorf("stop failed attempt 2"), + nil, // succeeds on 3rd attempt + }, + } + + u, err := NewUpgrader(&Config{ + InstallPath: installPath, + SelfPath: filepath.Join(tmpDir, "staging", "hostlink"), + BackupDir: backupDir, + LockPath: filepath.Join(tmpDir, "update.lock"), + StatePath: statePath, + HealthURL: "http://localhost:8080/health", + TargetVersion: "v2.0.0", + ServiceStopTimeout: 100 * time.Millisecond, + ServiceStartTimeout: 100 * time.Millisecond, + RollbackStopRetries: 3, + RollbackStopRetryInterval: 10 * time.Millisecond, + SleepFunc: func(_ context.Context, _ time.Duration) error { return nil }, + }) + require.NoError(t, err) + u.serviceController = mockSvc + + err = u.rollbackFrom(PhaseVerifying) + require.NoError(t, err) + 
+ // Verify Stop was called 3 times (retried) + assert.Equal(t, 3, mockSvc.stopCallCount) + + // Verify binary was restored + content, err := os.ReadFile(installPath) + require.NoError(t, err) + assert.Equal(t, backupContent, content) +} + +func TestUpgrader_Rollback_ProceedsAfterStopExhausted(t *testing.T) { + tmpDir := t.TempDir() + + installPath := filepath.Join(tmpDir, "hostlink") + backupDir := filepath.Join(tmpDir, "backup") + statePath := filepath.Join(tmpDir, "state.json") + + // Create backup and current binary + backupContent := []byte("backup binary v1.0.0") + require.NoError(t, os.MkdirAll(backupDir, 0755)) + require.NoError(t, os.WriteFile(filepath.Join(backupDir, "hostlink"), backupContent, 0755)) + require.NoError(t, os.WriteFile(installPath, []byte("broken"), 0755)) + + // Stop always fails + mockSvc := &mockServiceController{ + stopErr: fmt.Errorf("service stop permanently failing"), + } + + u, err := NewUpgrader(&Config{ + InstallPath: installPath, + SelfPath: filepath.Join(tmpDir, "staging", "hostlink"), + BackupDir: backupDir, + LockPath: filepath.Join(tmpDir, "update.lock"), + StatePath: statePath, + HealthURL: "http://localhost:8080/health", + TargetVersion: "v2.0.0", + ServiceStopTimeout: 100 * time.Millisecond, + ServiceStartTimeout: 100 * time.Millisecond, + RollbackStopRetries: 3, + RollbackStopRetryInterval: 10 * time.Millisecond, + SleepFunc: func(_ context.Context, _ time.Duration) error { return nil }, + }) + require.NoError(t, err) + u.serviceController = mockSvc + + err = u.rollbackFrom(PhaseVerifying) + require.NoError(t, err) + + // Verify Stop was retried the configured number of times + assert.Equal(t, 3, mockSvc.stopCallCount) + + // Verify binary was still restored (proceeded despite stop failure) + content, err := os.ReadFile(installPath) + require.NoError(t, err) + assert.Equal(t, backupContent, content) + + // Verify service start was attempted + assert.True(t, mockSvc.startCalled) +} + +func 
TestUpgrader_Rollback_HealthCheckAfterRestart(t *testing.T) { + tmpDir := t.TempDir() + + installPath := filepath.Join(tmpDir, "hostlink") + backupDir := filepath.Join(tmpDir, "backup") + statePath := filepath.Join(tmpDir, "state.json") + + // Create backup and current binary + require.NoError(t, os.MkdirAll(backupDir, 0755)) + require.NoError(t, os.WriteFile(filepath.Join(backupDir, "hostlink"), []byte("old binary"), 0755)) + require.NoError(t, os.WriteFile(installPath, []byte("broken"), 0755)) + + healthCheckCalled := false + u, err := NewUpgrader(&Config{ + InstallPath: installPath, + SelfPath: filepath.Join(tmpDir, "staging", "hostlink"), + BackupDir: backupDir, + LockPath: filepath.Join(tmpDir, "update.lock"), + StatePath: statePath, + HealthURL: "http://localhost:8080/health", + TargetVersion: "v2.0.0", + SourceVersion: "v1.0.0", + ServiceStopTimeout: 100 * time.Millisecond, + ServiceStartTimeout: 100 * time.Millisecond, + RollbackStopRetries: 1, + RollbackStopRetryInterval: 10 * time.Millisecond, + RollbackHealthCheckFunc: func(ctx context.Context) error { + healthCheckCalled = true + return nil // healthy + }, + SleepFunc: func(_ context.Context, _ time.Duration) error { return nil }, + }) + require.NoError(t, err) + u.serviceController = &mockServiceController{} + u.startedAt = time.Now() + + err = u.rollbackFrom(PhaseVerifying) + require.NoError(t, err) + + // Verify health check was called after restart + assert.True(t, healthCheckCalled, "health check should be called after rollback restart") +} + +func TestUpgrader_Rollback_HealthCheckFailure(t *testing.T) { + tmpDir := t.TempDir() + + installPath := filepath.Join(tmpDir, "hostlink") + backupDir := filepath.Join(tmpDir, "backup") + statePath := filepath.Join(tmpDir, "state.json") + + // Create backup and current binary + require.NoError(t, os.MkdirAll(backupDir, 0755)) + require.NoError(t, os.WriteFile(filepath.Join(backupDir, "hostlink"), []byte("old binary"), 0755)) + require.NoError(t, 
os.WriteFile(installPath, []byte("broken"), 0755)) + + u, err := NewUpgrader(&Config{ + InstallPath: installPath, + SelfPath: filepath.Join(tmpDir, "staging", "hostlink"), + BackupDir: backupDir, + LockPath: filepath.Join(tmpDir, "update.lock"), + StatePath: statePath, + HealthURL: "http://localhost:8080/health", + TargetVersion: "v2.0.0", + SourceVersion: "v1.0.0", + ServiceStopTimeout: 100 * time.Millisecond, + ServiceStartTimeout: 100 * time.Millisecond, + RollbackStopRetries: 1, + RollbackStopRetryInterval: 10 * time.Millisecond, + RollbackHealthCheckFunc: func(ctx context.Context) error { + return fmt.Errorf("old binary unhealthy") + }, + SleepFunc: func(_ context.Context, _ time.Duration) error { return nil }, + }) + require.NoError(t, err) + u.serviceController = &mockServiceController{} + u.startedAt = time.Now() + + err = u.rollbackFrom(PhaseVerifying) + + // Rollback should return an error when health check fails + require.Error(t, err) + assert.Contains(t, err.Error(), "unhealthy") + + // But state should still be written as RolledBack (binary was restored, service started) + stateWriter := update.NewStateWriter(update.StateConfig{StatePath: statePath}) + state, stateErr := stateWriter.Read() + require.NoError(t, stateErr) + assert.Equal(t, update.StateRolledBack, state.State) +} + +func TestNewUpgrader_RejectsEmptyInstallPath(t *testing.T) { + tmpDir := t.TempDir() + + cfg := &Config{ + InstallPath: "", // Empty - should be rejected + SelfPath: filepath.Join(tmpDir, "self"), + BackupDir: filepath.Join(tmpDir, "backup"), + LockPath: filepath.Join(tmpDir, "lock"), + StatePath: filepath.Join(tmpDir, "state.json"), + } + + _, err := NewUpgrader(cfg) + require.Error(t, err) + assert.Contains(t, err.Error(), "install-path") +} + +// Mock service controller for testing +type mockServiceController struct { + stopCalled bool + startCalled bool + stopErr error + startErr error + existsVal bool + existsErr error + onStop func() + onStart func() + stopCallCount int 
+ stopErrs []error // If set, returns errors sequentially (overrides stopErr) +} + +func (m *mockServiceController) Stop(ctx context.Context) error { + m.stopCalled = true + m.stopCallCount++ + if m.onStop != nil { + m.onStop() + } + if m.stopErrs != nil { + idx := m.stopCallCount - 1 + if idx < len(m.stopErrs) { + return m.stopErrs[idx] + } + return m.stopErrs[len(m.stopErrs)-1] + } + return m.stopErr +} + +func (m *mockServiceController) Start(ctx context.Context) error { + m.startCalled = true + if m.onStart != nil { + m.onStart() + } + return m.startErr +} + +func (m *mockServiceController) Exists(ctx context.Context) (bool, error) { + return m.existsVal, m.existsErr +} + +// Mock state writer for testing +type mockStateWriter struct { + writeErr error + readErr error + lastState update.StateData + writeCount int +} + +func (m *mockStateWriter) Write(data update.StateData) error { + m.writeCount++ + m.lastState = data + return m.writeErr +} + +func (m *mockStateWriter) Read() (update.StateData, error) { + return m.lastState, m.readErr +} + +func TestUpgrader_Run_StateWriteFailure_LogsWarningAndContinues(t *testing.T) { + tmpDir := t.TempDir() + + installPath := filepath.Join(tmpDir, "usr", "bin", "hostlink") + selfPath := filepath.Join(tmpDir, "staging", "hostlink") + backupDir := filepath.Join(tmpDir, "backup") + lockPath := filepath.Join(tmpDir, "update.lock") + statePath := filepath.Join(tmpDir, "state.json") + + createTestBinary(t, installPath, []byte("old binary v1.0.0")) + createTestBinary(t, selfPath, []byte("new binary v2.0.0")) + + // Mock health server that returns healthy + healthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(update.HealthResponse{Ok: true, Version: "v2.0.0"}) + })) + defer healthServer.Close() + + // Capture logs + var logBuf bytes.Buffer + logger := slog.New(slog.NewJSONHandler(&logBuf, nil)) + + u, err := 
NewUpgrader(&Config{ + InstallPath: installPath, + SelfPath: selfPath, + BackupDir: backupDir, + LockPath: lockPath, + StatePath: statePath, + HealthURL: healthServer.URL, + TargetVersion: "v2.0.0", + ServiceStopTimeout: 100 * time.Millisecond, + ServiceStartTimeout: 100 * time.Millisecond, + HealthCheckRetries: 1, + HealthCheckInterval: 10 * time.Millisecond, + HealthInitialWait: 1 * time.Millisecond, + Logger: logger, + SleepFunc: func(_ context.Context, _ time.Duration) error { return nil }, + }) + require.NoError(t, err) + + // Inject mock state writer that fails + mockState := &mockStateWriter{writeErr: fmt.Errorf("disk full")} + u.state = mockState + u.serviceController = &mockServiceController{} + + // Run should succeed despite state write failure + err = u.Run(context.Background()) + require.NoError(t, err) + + // Verify state write was attempted + assert.GreaterOrEqual(t, mockState.writeCount, 1, "state.Write should have been called") + + // Verify warning was logged + logOutput := logBuf.String() + assert.Contains(t, logOutput, "disk full", "should log the state write error") + assert.Contains(t, logOutput, "WARN", "should log at WARN level") +} + +func TestUpgrader_Run_InstallFailure_RollbackFails_ReturnsBothErrors(t *testing.T) { + tmpDir := t.TempDir() + + installPath := filepath.Join(tmpDir, "hostlink") + selfPath := filepath.Join(tmpDir, "staging", "hostlink") + backupDir := filepath.Join(tmpDir, "backup") + statePath := filepath.Join(tmpDir, "state.json") + + // Create current binary + createTestBinary(t, installPath, []byte("old binary")) + // Create staging binary + createTestBinary(t, selfPath, []byte("new binary")) + + u, err := NewUpgrader(&Config{ + InstallPath: installPath, + SelfPath: selfPath, + BackupDir: backupDir, + LockPath: filepath.Join(tmpDir, "update.lock"), + StatePath: statePath, + HealthURL: "http://localhost:8080/health", + TargetVersion: "v2.0.0", + ServiceStopTimeout: 100 * time.Millisecond, + ServiceStartTimeout: 100 * 
time.Millisecond, + }) + require.NoError(t, err) + + // Mock service controller - start fails (rollback will fail when trying to restart) + mockSvc := &mockServiceController{ + startErr: fmt.Errorf("start failed: systemd error"), + } + u.serviceController = mockSvc + + // Make install fail by making InstallFunc return an error + u.config.InstallFunc = func(src, dst string) error { + return fmt.Errorf("install failed: permission denied") + } + + err = u.Run(context.Background()) + require.Error(t, err) + + // Should contain both install error and rollback error + errStr := err.Error() + assert.Contains(t, errStr, "install", "error should mention install failure") + assert.Contains(t, errStr, "rollback", "error should mention rollback failure") +} + +func TestUpgrader_Run_StartFailure_RollbackFails_ReturnsBothErrors(t *testing.T) { + tmpDir := t.TempDir() + + installPath := filepath.Join(tmpDir, "hostlink") + selfPath := filepath.Join(tmpDir, "staging", "hostlink") + backupDir := filepath.Join(tmpDir, "backup") + statePath := filepath.Join(tmpDir, "state.json") + + // Create current binary + createTestBinary(t, installPath, []byte("old binary")) + // Create staging binary + createTestBinary(t, selfPath, []byte("new binary")) + + u, err := NewUpgrader(&Config{ + InstallPath: installPath, + SelfPath: selfPath, + BackupDir: backupDir, + LockPath: filepath.Join(tmpDir, "update.lock"), + StatePath: statePath, + HealthURL: "http://localhost:8080/health", + TargetVersion: "v2.0.0", + ServiceStopTimeout: 100 * time.Millisecond, + ServiceStartTimeout: 100 * time.Millisecond, + }) + require.NoError(t, err) + + // Mock service controller - start always fails (both in Run and in rollback) + mockSvc := &mockServiceController{ + startErr: fmt.Errorf("start failed: systemd error"), + } + u.serviceController = mockSvc + + err = u.Run(context.Background()) + require.Error(t, err) + + // Should contain both start error and rollback error + errStr := err.Error() + assert.Contains(t, 
errStr, "failed to start service:", "error should mention original start failure") + assert.Contains(t, errStr, "failed to start service after rollback", "error should mention rollback start failure") +} + +func TestUpgrader_Run_HealthCheckFailure_RollbackFails_ReturnsBothErrors(t *testing.T) { + tmpDir := t.TempDir() + + installPath := filepath.Join(tmpDir, "hostlink") + selfPath := filepath.Join(tmpDir, "staging", "hostlink") + backupDir := filepath.Join(tmpDir, "backup") + statePath := filepath.Join(tmpDir, "state.json") + + // Create current binary + createTestBinary(t, installPath, []byte("old binary")) + // Create staging binary + createTestBinary(t, selfPath, []byte("new binary")) + + // Health server that always returns unhealthy + healthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(update.HealthResponse{Ok: false, Version: ""}) + })) + defer healthServer.Close() + + u, err := NewUpgrader(&Config{ + InstallPath: installPath, + SelfPath: selfPath, + BackupDir: backupDir, + LockPath: filepath.Join(tmpDir, "update.lock"), + StatePath: statePath, + HealthURL: healthServer.URL, + TargetVersion: "v2.0.0", + ServiceStopTimeout: 100 * time.Millisecond, + ServiceStartTimeout: 100 * time.Millisecond, + HealthCheckRetries: 1, + HealthCheckInterval: 10 * time.Millisecond, + HealthInitialWait: 1 * time.Millisecond, + SleepFunc: func(_ context.Context, _ time.Duration) error { return nil }, + }) + require.NoError(t, err) + + // Mock service controller - start succeeds first time (in Run), fails second time (in rollback) + mockSvc := &startFailsOnSecondCallController{} + u.serviceController = mockSvc + + err = u.Run(context.Background()) + require.Error(t, err) + + // Should contain both health check error and rollback error + errStr := err.Error() + assert.Contains(t, errStr, "health check failed", "error should mention health check failure") + 
assert.Contains(t, errStr, "start failed during rollback", "error should mention rollback failure") +} + +// startFailsOnSecondCallController is a mock that succeeds on first Start, fails on second +type startFailsOnSecondCallController struct { + startCallCount int +} + +func (m *startFailsOnSecondCallController) Stop(ctx context.Context) error { + return nil +} + +func (m *startFailsOnSecondCallController) Start(ctx context.Context) error { + m.startCallCount++ + if m.startCallCount > 1 { + return fmt.Errorf("start failed during rollback: systemd error") + } + return nil +} + +func (m *startFailsOnSecondCallController) Exists(ctx context.Context) (bool, error) { + return true, nil +} diff --git a/config/appconf/appconf.go b/config/appconf/appconf.go index ac6b061..fb164c1 100644 --- a/config/appconf/appconf.go +++ b/config/appconf/appconf.go @@ -56,6 +56,15 @@ func AgentStatePath() string { return "/var/lib/hostlink" } +// InstallPath returns the target install path for the hostlink binary. +// Controlled by HOSTLINK_INSTALL_PATH (default: /usr/bin/hostlink). +func InstallPath() string { + if path := os.Getenv("HOSTLINK_INSTALL_PATH"); path != "" { + return path + } + return "/usr/bin/hostlink" +} + // SelfUpdateEnabled returns whether the self-update feature is enabled. // Controlled by HOSTLINK_SELF_UPDATE_ENABLED (default: true). 
func SelfUpdateEnabled() bool { diff --git a/config/appconf/appconf_test.go b/config/appconf/appconf_test.go index c6d3279..ef5d8e9 100644 --- a/config/appconf/appconf_test.go +++ b/config/appconf/appconf_test.go @@ -76,3 +76,13 @@ func TestUpdateLockTimeout_InvalidFallsToDefault(t *testing.T) { t.Setenv("HOSTLINK_UPDATE_LOCK_TIMEOUT", "garbage") assert.Equal(t, 5*time.Minute, UpdateLockTimeout()) } + +func TestInstallPath_Default(t *testing.T) { + t.Setenv("HOSTLINK_INSTALL_PATH", "") + assert.Equal(t, "/usr/bin/hostlink", InstallPath()) +} + +func TestInstallPath_CustomValue(t *testing.T) { + t.Setenv("HOSTLINK_INSTALL_PATH", "/opt/hostlink/bin/hostlink") + assert.Equal(t, "/opt/hostlink/bin/hostlink", InstallPath()) +} diff --git a/internal/update/binary.go b/internal/update/binary.go index 7728f17..51af7e2 100644 --- a/internal/update/binary.go +++ b/internal/update/binary.go @@ -18,15 +18,14 @@ const ( BackupFilename = "hostlink" // AgentBinaryName is the binary name inside the agent tarball. AgentBinaryName = "hostlink" - // UpdaterBinaryName is the binary name inside the updater tarball. - UpdaterBinaryName = "hostlink-updater" // MaxBinarySize is the maximum allowed size for an extracted binary (100MB). MaxBinarySize = 100 * 1024 * 1024 ) // BackupBinary copies the binary at srcPath to the backup directory. // It creates the backup directory if it doesn't exist. -// It overwrites any existing backup. +// It overwrites any existing backup using atomic rename to ensure +// the backup is never corrupted even if the process crashes mid-write. 
func BackupBinary(srcPath, backupDir string) error { // Open source file src, err := os.Open(srcPath) @@ -46,19 +45,45 @@ func BackupBinary(srcPath, backupDir string) error { return fmt.Errorf("failed to create backup directory: %w", err) } - // Create backup file + // Generate temp file path for atomic write backupPath := filepath.Join(backupDir, BackupFilename) - dst, err := os.OpenFile(backupPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, srcInfo.Mode().Perm()) + randSuffix, err := randomHex(8) if err != nil { - return fmt.Errorf("failed to create backup file: %w", err) + return fmt.Errorf("failed to generate random suffix: %w", err) + } + tmpPath := backupPath + ".tmp." + randSuffix + + // Clean up temp file on error + defer func() { + if tmpPath != "" { + os.Remove(tmpPath) + } + }() + + // Create temp file + dst, err := os.OpenFile(tmpPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, srcInfo.Mode().Perm()) + if err != nil { + return fmt.Errorf("failed to create temp backup file: %w", err) } - defer dst.Close() // Copy content if _, err := io.Copy(dst, src); err != nil { + dst.Close() return fmt.Errorf("failed to copy to backup: %w", err) } + // Close before rename + if err := dst.Close(); err != nil { + return fmt.Errorf("failed to close temp backup file: %w", err) + } + + // Atomic rename - replaces existing backup atomically + if err := os.Rename(tmpPath, backupPath); err != nil { + return fmt.Errorf("failed to finalize backup: %w", err) + } + + // Success - don't clean up the temp file (it's been renamed) + tmpPath = "" return nil } @@ -69,13 +94,6 @@ func InstallBinary(tarPath, destPath string) error { return installBinaryFromTarGz(tarPath, AgentBinaryName, destPath) } -// InstallUpdaterBinary extracts the updater binary from a tar.gz file and installs it atomically. -// It extracts the file named "hostlink-updater" from the tarball to destPath. -// Uses atomic rename to ensure the install is atomic. 
-func InstallUpdaterBinary(tarPath, destPath string) error { - return installBinaryFromTarGz(tarPath, UpdaterBinaryName, destPath) -} - // installBinaryFromTarGz extracts a named binary from a tar.gz and installs it atomically. func installBinaryFromTarGz(tarPath, binaryName, destPath string) error { // Create destination directory @@ -118,6 +136,69 @@ func installBinaryFromTarGz(tarPath, binaryName, destPath string) error { return nil } +// InstallSelf copies the binary at srcPath to destPath atomically. +// srcPath is typically os.Executable() — the staged binary that is currently running. +// It writes to a temp file first, sets permissions to 0755, then does an atomic rename. +func InstallSelf(srcPath, destPath string) error { + // Create destination directory + destDir := filepath.Dir(destPath) + if err := os.MkdirAll(destDir, 0755); err != nil { + return fmt.Errorf("failed to create destination directory: %w", err) + } + + // Open source + src, err := os.Open(srcPath) + if err != nil { + return fmt.Errorf("failed to open source binary: %w", err) + } + defer src.Close() + + // Generate temp file path + randSuffix, err := randomHex(8) + if err != nil { + return fmt.Errorf("failed to generate random suffix: %w", err) + } + tmpPath := destPath + ".tmp." 
+ randSuffix + + // Clean up temp file on error + defer func() { + if tmpPath != "" { + os.Remove(tmpPath) + } + }() + + // Create temp file + dst, err := os.OpenFile(tmpPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, BinaryPermissions) + if err != nil { + return fmt.Errorf("failed to create temp file: %w", err) + } + + // Copy content + if _, err := io.Copy(dst, src); err != nil { + dst.Close() + return fmt.Errorf("failed to copy binary: %w", err) + } + + // Close before rename + if err := dst.Close(); err != nil { + return fmt.Errorf("failed to close temp file: %w", err) + } + + // Set permissions + if err := os.Chmod(tmpPath, BinaryPermissions); err != nil { + return fmt.Errorf("failed to set permissions: %w", err) + } + + // Atomic rename + if err := os.Rename(tmpPath, destPath); err != nil { + return fmt.Errorf("failed to install binary: %w", err) + } + + // Success - don't clean up the temp file (it's been renamed) + tmpPath = "" + return nil +} + // RestoreBackup restores the binary from backup to destPath. // Uses atomic rename for safe restoration. 
func RestoreBackup(backupDir, destPath string) error { diff --git a/internal/update/binary_test.go b/internal/update/binary_test.go index 212e1c9..9a4d0b2 100644 --- a/internal/update/binary_test.go +++ b/internal/update/binary_test.go @@ -5,6 +5,7 @@ import ( "compress/gzip" "os" "path/filepath" + "syscall" "testing" "github.com/stretchr/testify/assert" @@ -108,6 +109,164 @@ func TestBackupBinary_ReturnsErrorIfSourceMissing(t *testing.T) { assert.True(t, os.IsNotExist(err) || os.IsNotExist(unwrapErr(err))) } +func TestBackupBinary_Atomic_PreservesExistingOnError(t *testing.T) { + // Skip on systems where we can't test permission-based failures (e.g., root) + if os.Getuid() == 0 { + t.Skip("test requires non-root user to test permission failures") + } + + tmpDir := t.TempDir() + backupDir := filepath.Join(tmpDir, "backup") + err := os.MkdirAll(backupDir, 0755) + require.NoError(t, err) + + // Create existing backup with known content + existingBackupContent := []byte("existing backup - must survive") + backupPath := filepath.Join(backupDir, "hostlink") + err = os.WriteFile(backupPath, existingBackupContent, 0755) + require.NoError(t, err) + + // Create a source file, then make the SOURCE DIRECTORY unreadable after opening + // This simulates a failure that happens AFTER the backup file is created but DURING copy + // Actually, we need a different approach: use a FIFO or make backup dir unwritable + + // Better approach: make the backup directory read-only AFTER it exists + // This will cause the backup file creation to succeed (if overwriting) but... + // Actually no, O_TRUNC on existing file works even if dir is read-only. + + // Simplest reliable test: make source file readable but remove it mid-operation + // That's racy. Let's use a different approach: + // Create source as readable, existing backup, then make BACKUP FILE immutable + // Actually on most systems we can't easily test this. 
+ + // Pragmatic test: verify that with atomic implementation, the temp file pattern is used. + // We'll test that when copy fails (unreadable source), no partial backup is left. + + // Create a source file that can be opened but fails on read + // We can do this with a named pipe (FIFO) that we don't write to, causing read to block/fail + // But that's complex. Let's use a simpler approach: + + // Create source file, make it unreadable AFTER we know BackupBinary will try to read it + // This is inherently racy, so instead let's verify the behavior we care about: + // After atomic implementation, the function should use temp files. + + // For now, test with unreadable source - this at least verifies existing backup survives + // when error happens early (before backup file is touched). + srcPath := filepath.Join(tmpDir, "unreadable") + err = os.WriteFile(srcPath, []byte("content"), 0000) // No read permission + require.NoError(t, err) + + // Attempt backup - should fail + err = BackupBinary(srcPath, backupDir) + assert.Error(t, err) + + // Verify existing backup is UNCHANGED + content, err := os.ReadFile(backupPath) + require.NoError(t, err) + assert.Equal(t, existingBackupContent, content, "existing backup should be preserved on error") +} + +func TestBackupBinary_Atomic_NoTempFilesOnSuccess(t *testing.T) { + tmpDir := t.TempDir() + + srcPath := filepath.Join(tmpDir, "hostlink") + err := os.WriteFile(srcPath, []byte("binary content"), 0755) + require.NoError(t, err) + + backupDir := filepath.Join(tmpDir, "backup") + + err = BackupBinary(srcPath, backupDir) + require.NoError(t, err) + + // Verify no temp files left behind in backup directory + entries, err := os.ReadDir(backupDir) + require.NoError(t, err) + for _, entry := range entries { + assert.NotContains(t, entry.Name(), ".tmp.", "temp file should be cleaned up on success") + } + // Should only have the backup file + assert.Len(t, entries, 1) + assert.Equal(t, "hostlink", entries[0].Name()) +} + +func 
TestBackupBinary_Atomic_UsesAtomicRename(t *testing.T) { + // This test verifies the atomic write pattern is used by checking that + // the inode changes after backup (atomic rename creates a new file). + // + // With non-atomic code (O_TRUNC), the same file is overwritten, inode unchanged. + // With atomic code (temp + rename), a new file replaces the old, inode changes. + + tmpDir := t.TempDir() + backupDir := filepath.Join(tmpDir, "backup") + err := os.MkdirAll(backupDir, 0755) + require.NoError(t, err) + + // Create existing backup + existingContent := []byte("old backup content - v1.0.0") + backupPath := filepath.Join(backupDir, "hostlink") + err = os.WriteFile(backupPath, existingContent, 0755) + require.NoError(t, err) + + // Get inode of existing backup + existingInfo, err := os.Stat(backupPath) + require.NoError(t, err) + existingStat, ok := existingInfo.Sys().(*syscall.Stat_t) + if !ok { + t.Skip("cannot get inode on this platform") + } + existingInode := existingStat.Ino + + // Create new source + newContent := []byte("new binary content - v2.0.0 - this is longer than the old one") + srcPath := filepath.Join(tmpDir, "hostlink") + err = os.WriteFile(srcPath, newContent, 0755) + require.NoError(t, err) + + // Perform backup + err = BackupBinary(srcPath, backupDir) + require.NoError(t, err) + + // Verify backup has new content + backupContent, err := os.ReadFile(backupPath) + require.NoError(t, err) + assert.Equal(t, newContent, backupContent) + + // With atomic rename, the inode should have changed (new file) + // With O_TRUNC, the inode would be the same (same file, overwritten) + newInfo, err := os.Stat(backupPath) + require.NoError(t, err) + newStat, ok := newInfo.Sys().(*syscall.Stat_t) + require.True(t, ok) + newInode := newStat.Ino + + // Inode should change with atomic rename (new file replaces old) + assert.NotEqual(t, existingInode, newInode, + "inode should change with atomic rename; got same inode %d (non-atomic O_TRUNC used)", existingInode) +} 
+ +func TestBackupBinary_Atomic_CleansTempOnError(t *testing.T) { + tmpDir := t.TempDir() + backupDir := filepath.Join(tmpDir, "backup") + err := os.MkdirAll(backupDir, 0755) + require.NoError(t, err) + + // Create a source file that will fail during read + srcPath := filepath.Join(tmpDir, "unreadable") + err = os.WriteFile(srcPath, []byte("content"), 0000) // No read permission + require.NoError(t, err) + + // Attempt backup - should fail + err = BackupBinary(srcPath, backupDir) + assert.Error(t, err) + + // Verify no temp files left behind + entries, err := os.ReadDir(backupDir) + require.NoError(t, err) + for _, entry := range entries { + assert.NotContains(t, entry.Name(), ".tmp.", "temp file should be cleaned up on error") + } +} + func TestInstallBinary_ExtractsAndInstalls(t *testing.T) { tmpDir := t.TempDir() @@ -312,58 +471,121 @@ func TestInstallBinary_RejectsBinaryExceedingMaxSize(t *testing.T) { assert.Contains(t, err.Error(), "exceeds maximum allowed size") } -func TestInstallUpdaterBinary_ExtractsHostlinkUpdater(t *testing.T) { +// Helper function to unwrap errors +func unwrapErr(err error) error { + for { + unwrapped := err + if u, ok := err.(interface{ Unwrap() error }); ok { + unwrapped = u.Unwrap() + } + if unwrapped == err { + return err + } + err = unwrapped + } +} + +func TestInstallSelf_CopiesSelfToDest(t *testing.T) { tmpDir := t.TempDir() - // Create a tarball containing "hostlink-updater" - tarPath := filepath.Join(tmpDir, "updater.tar.gz") - binaryContent := []byte("updater binary v2.0.0") - createTestTarGz(t, tarPath, "hostlink-updater", binaryContent, 0755) + // Create a fake "self" binary (simulating os.Executable()) + selfPath := filepath.Join(tmpDir, "staged-hostlink") + selfContent := []byte("new binary v3.0.0") + err := os.WriteFile(selfPath, selfContent, 0755) + require.NoError(t, err) - destPath := filepath.Join(tmpDir, "installed", "hostlink-updater") + destPath := filepath.Join(tmpDir, "installed", "hostlink") - err := 
InstallUpdaterBinary(tarPath, destPath) + err = InstallSelf(selfPath, destPath) require.NoError(t, err) - // Verify installed binary content + // Verify installed content installedContent, err := os.ReadFile(destPath) require.NoError(t, err) - assert.Equal(t, binaryContent, installedContent) + assert.Equal(t, selfContent, installedContent) +} + +func TestInstallSelf_SetsPermissions(t *testing.T) { + tmpDir := t.TempDir() + + selfPath := filepath.Join(tmpDir, "staged-hostlink") + err := os.WriteFile(selfPath, []byte("binary"), 0755) + require.NoError(t, err) + + destPath := filepath.Join(tmpDir, "hostlink") + + err = InstallSelf(selfPath, destPath) + require.NoError(t, err) - // Verify permissions info, err := os.Stat(destPath) require.NoError(t, err) assert.Equal(t, os.FileMode(0755), info.Mode().Perm()) } -func TestInstallUpdaterBinary_IgnoresHostlinkBinary(t *testing.T) { +func TestInstallSelf_AtomicRename(t *testing.T) { tmpDir := t.TempDir() - // Create a tarball containing "hostlink" (NOT "hostlink-updater") - tarPath := filepath.Join(tmpDir, "updater.tar.gz") - createTestTarGz(t, tarPath, "hostlink", []byte("wrong binary"), 0755) + // Create existing binary at destination + destPath := filepath.Join(tmpDir, "hostlink") + err := os.WriteFile(destPath, []byte("old binary"), 0755) + require.NoError(t, err) - destPath := filepath.Join(tmpDir, "hostlink-updater") + selfPath := filepath.Join(tmpDir, "staged-hostlink") + newContent := []byte("new binary") + err = os.WriteFile(selfPath, newContent, 0755) + require.NoError(t, err) - err := InstallUpdaterBinary(tarPath, destPath) - assert.Error(t, err) - assert.Contains(t, err.Error(), "not found in tarball") + err = InstallSelf(selfPath, destPath) + require.NoError(t, err) + + // Verify new binary is in place + installedContent, err := os.ReadFile(destPath) + require.NoError(t, err) + assert.Equal(t, newContent, installedContent) + + // Verify no temp files left behind + entries, err := os.ReadDir(tmpDir) + 
require.NoError(t, err) + for _, entry := range entries { + assert.NotContains(t, entry.Name(), ".tmp.", "temp file should be cleaned up") + } } -// Helper function to unwrap errors -func unwrapErr(err error) error { - for { - unwrapped := err - if u, ok := err.(interface{ Unwrap() error }); ok { - unwrapped = u.Unwrap() - } - if unwrapped == err { - return err - } - err = unwrapped +func TestInstallSelf_CleansUpTempOnError(t *testing.T) { + tmpDir := t.TempDir() + + // Source doesn't exist + selfPath := filepath.Join(tmpDir, "nonexistent") + destPath := filepath.Join(tmpDir, "hostlink") + + err := InstallSelf(selfPath, destPath) + assert.Error(t, err) + + // Verify no temp files left behind + entries, err := os.ReadDir(tmpDir) + require.NoError(t, err) + for _, entry := range entries { + assert.NotContains(t, entry.Name(), ".tmp.", "temp file should be cleaned up on error") } } +func TestInstallSelf_CreatesDestinationDirectory(t *testing.T) { + tmpDir := t.TempDir() + + selfPath := filepath.Join(tmpDir, "staged-hostlink") + err := os.WriteFile(selfPath, []byte("binary"), 0755) + require.NoError(t, err) + + // Nested destination that doesn't exist + destPath := filepath.Join(tmpDir, "usr", "bin", "hostlink") + + err = InstallSelf(selfPath, destPath) + require.NoError(t, err) + + _, err = os.Stat(destPath) + require.NoError(t, err) +} + // createTestTarGz creates a tar.gz file containing a single file func createTestTarGz(t *testing.T, tarPath, filename string, content []byte, mode os.FileMode) { t.Helper() diff --git a/internal/update/dirs.go b/internal/update/dirs.go index db8129f..a64e55a 100644 --- a/internal/update/dirs.go +++ b/internal/update/dirs.go @@ -19,7 +19,6 @@ type Paths struct { BaseDir string // /var/lib/hostlink/updates BackupDir string // /var/lib/hostlink/updates/backup StagingDir string // /var/lib/hostlink/updates/staging - UpdaterDir string // /var/lib/hostlink/updates/updater LockFile string // /var/lib/hostlink/updates/update.lock StateFile 
string // /var/lib/hostlink/updates/state.json } @@ -35,7 +34,6 @@ func NewPaths(baseDir string) Paths { BaseDir: baseDir, BackupDir: filepath.Join(baseDir, "backup"), StagingDir: filepath.Join(baseDir, "staging"), - UpdaterDir: filepath.Join(baseDir, "updater"), LockFile: filepath.Join(baseDir, "update.lock"), StateFile: filepath.Join(baseDir, "state.json"), } @@ -50,7 +48,6 @@ func InitDirectories(baseDir string) error { paths.BaseDir, paths.BackupDir, paths.StagingDir, - paths.UpdaterDir, } for _, dir := range dirs { diff --git a/internal/update/dirs_test.go b/internal/update/dirs_test.go index c834a27..9fae28b 100644 --- a/internal/update/dirs_test.go +++ b/internal/update/dirs_test.go @@ -21,7 +21,6 @@ func TestInitDirectories_CreatesAllDirs(t *testing.T) { basePath, filepath.Join(basePath, "backup"), filepath.Join(basePath, "staging"), - filepath.Join(basePath, "updater"), } for _, dir := range expectedDirs { @@ -43,7 +42,6 @@ func TestInitDirectories_CorrectPermissions(t *testing.T) { basePath, filepath.Join(basePath, "backup"), filepath.Join(basePath, "staging"), - filepath.Join(basePath, "updater"), } for _, dir := range dirs { @@ -110,7 +108,6 @@ func TestDefaultPaths(t *testing.T) { assert.Equal(t, "/var/lib/hostlink/updates", paths.BaseDir) assert.Equal(t, "/var/lib/hostlink/updates/backup", paths.BackupDir) assert.Equal(t, "/var/lib/hostlink/updates/staging", paths.StagingDir) - assert.Equal(t, "/var/lib/hostlink/updates/updater", paths.UpdaterDir) assert.Equal(t, "/var/lib/hostlink/updates/update.lock", paths.LockFile) assert.Equal(t, "/var/lib/hostlink/updates/state.json", paths.StateFile) } @@ -121,7 +118,6 @@ func TestNewPaths(t *testing.T) { assert.Equal(t, "/custom/path", paths.BaseDir) assert.Equal(t, "/custom/path/backup", paths.BackupDir) assert.Equal(t, "/custom/path/staging", paths.StagingDir) - assert.Equal(t, "/custom/path/updater", paths.UpdaterDir) assert.Equal(t, "/custom/path/update.lock", paths.LockFile) assert.Equal(t, 
"/custom/path/state.json", paths.StateFile) } diff --git a/internal/update/health.go b/internal/update/health.go index 982d05e..94c2daf 100644 --- a/internal/update/health.go +++ b/internal/update/health.go @@ -25,6 +25,17 @@ const ( DefaultInitialWait = 5 * time.Second ) +// sleepWithContext sleeps for the given duration or until context is cancelled. +// Returns nil on normal completion, or ctx.Err() if cancelled. +func sleepWithContext(ctx context.Context, d time.Duration) error { + select { + case <-time.After(d): + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + // HealthResponse represents the response from the health endpoint. type HealthResponse struct { Ok bool `json:"ok"` @@ -33,13 +44,13 @@ type HealthResponse struct { // HealthConfig configures the HealthChecker. type HealthConfig struct { - URL string // Health check URL (e.g., http://localhost:8080/health) - TargetVersion string // Expected version after update - MaxRetries int // Max number of retries (default: 5) - RetryInterval time.Duration // Time between retries (default: 5s) - InitialWait time.Duration // Initial wait before first check (default: 5s) - SleepFunc func(time.Duration) // For testing - HTTPClient *http.Client // Optional custom HTTP client + URL string // Health check URL (e.g., http://localhost:8080/health) + TargetVersion string // Expected version after update + MaxRetries int // Max number of retries (default: 5) + RetryInterval time.Duration // Time between retries (default: 5s) + InitialWait time.Duration // Initial wait before first check (default: 5s) + SleepFunc func(context.Context, time.Duration) error // For testing; returns ctx.Err() if cancelled + HTTPClient *http.Client // Optional custom HTTP client } // HealthChecker verifies that the service is healthy after an update. 
@@ -61,7 +72,7 @@ func NewHealthChecker(cfg HealthConfig) *HealthChecker { cfg.InitialWait = DefaultInitialWait } if cfg.SleepFunc == nil { - cfg.SleepFunc = time.Sleep + cfg.SleepFunc = sleepWithContext } client := cfg.HTTPClient @@ -82,9 +93,8 @@ func NewHealthChecker(cfg HealthConfig) *HealthChecker { func (h *HealthChecker) WaitForHealth(ctx context.Context) error { // Initial wait before first check if h.config.InitialWait > 0 { - h.config.SleepFunc(h.config.InitialWait) - if ctx.Err() != nil { - return ctx.Err() + if err := h.config.SleepFunc(ctx, h.config.InitialWait); err != nil { + return err } } @@ -110,9 +120,8 @@ func (h *HealthChecker) WaitForHealth(ctx context.Context) error { } if attempt < totalAttempts-1 { - h.config.SleepFunc(h.config.RetryInterval) - if ctx.Err() != nil { - return ctx.Err() + if err := h.config.SleepFunc(ctx, h.config.RetryInterval); err != nil { + return err } } } diff --git a/internal/update/health_test.go b/internal/update/health_test.go index e8bb412..99500ee 100644 --- a/internal/update/health_test.go +++ b/internal/update/health_test.go @@ -26,7 +26,7 @@ func TestHealthChecker_WaitForHealth_Success(t *testing.T) { MaxRetries: 5, RetryInterval: 10 * time.Millisecond, InitialWait: 0, - SleepFunc: func(d time.Duration) {}, + SleepFunc: func(_ context.Context, _ time.Duration) error { return nil }, }) err := hc.WaitForHealth(context.Background()) @@ -54,7 +54,7 @@ func TestHealthChecker_WaitForHealth_RetriesOnHttpError(t *testing.T) { MaxRetries: 5, RetryInterval: 10 * time.Millisecond, InitialWait: 0, - SleepFunc: func(d time.Duration) {}, + SleepFunc: func(_ context.Context, _ time.Duration) error { return nil }, }) err := hc.WaitForHealth(context.Background()) @@ -82,7 +82,7 @@ func TestHealthChecker_WaitForHealth_RetriesOnOkFalse(t *testing.T) { MaxRetries: 5, RetryInterval: 10 * time.Millisecond, InitialWait: 0, - SleepFunc: func(d time.Duration) {}, + SleepFunc: func(_ context.Context, _ time.Duration) error { return nil 
}, }) err := hc.WaitForHealth(context.Background()) @@ -106,7 +106,7 @@ func TestHealthChecker_WaitForHealth_FailsAfterMaxRetries(t *testing.T) { MaxRetries: 3, RetryInterval: 10 * time.Millisecond, InitialWait: 0, - SleepFunc: func(d time.Duration) {}, + SleepFunc: func(_ context.Context, _ time.Duration) error { return nil }, }) err := hc.WaitForHealth(context.Background()) @@ -137,7 +137,7 @@ func TestHealthChecker_WaitForHealth_RetriesOnVersionMismatch(t *testing.T) { MaxRetries: 5, RetryInterval: 10 * time.Millisecond, InitialWait: 0, - SleepFunc: func(d time.Duration) {}, + SleepFunc: func(_ context.Context, _ time.Duration) error { return nil }, }) err := hc.WaitForHealth(context.Background()) @@ -161,7 +161,7 @@ func TestHealthChecker_WaitForHealth_FailsOnVersionMismatch(t *testing.T) { MaxRetries: 3, RetryInterval: 10 * time.Millisecond, InitialWait: 0, - SleepFunc: func(d time.Duration) {}, + SleepFunc: func(_ context.Context, _ time.Duration) error { return nil }, }) err := hc.WaitForHealth(context.Background()) @@ -186,8 +186,9 @@ func TestHealthChecker_WaitForHealth_RespectsContext(t *testing.T) { MaxRetries: 10, RetryInterval: 10 * time.Millisecond, InitialWait: 0, - SleepFunc: func(d time.Duration) { + SleepFunc: func(_ context.Context, _ time.Duration) error { cancel() // Cancel context during sleep + return nil }, }) @@ -210,8 +211,9 @@ func TestHealthChecker_WaitForHealth_InitialWait(t *testing.T) { MaxRetries: 5, RetryInterval: 100 * time.Millisecond, InitialWait: 500 * time.Millisecond, - SleepFunc: func(d time.Duration) { + SleepFunc: func(_ context.Context, d time.Duration) error { sleepDurations = append(sleepDurations, d) + return nil }, }) @@ -255,10 +257,46 @@ func TestHealthChecker_WaitForHealth_HandlesInvalidJSON(t *testing.T) { MaxRetries: 5, RetryInterval: 10 * time.Millisecond, InitialWait: 0, - SleepFunc: func(d time.Duration) {}, + SleepFunc: func(_ context.Context, _ time.Duration) error { return nil }, }) err := 
hc.WaitForHealth(context.Background()) require.NoError(t, err) assert.Equal(t, int32(2), attempts.Load()) } + +func TestHealthChecker_WaitForHealth_ContextCancelledDuringSleep_ReturnsImmediately(t *testing.T) { + // This test verifies that when context is cancelled during a sleep, + // WaitForHealth returns immediately without waiting for the full sleep duration. + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(HealthResponse{Ok: false, Version: "v1.0.0"}) // Always fail to trigger retry sleep + })) + defer server.Close() + + ctx, cancel := context.WithCancel(context.Background()) + + // Use the real sleepWithContext (default) to test actual behavior + hc := NewHealthChecker(HealthConfig{ + URL: server.URL, + TargetVersion: "v1.0.0", + MaxRetries: 10, + RetryInterval: 10 * time.Second, // Very long sleep - would timeout test if not cancelled + InitialWait: 0, + // SleepFunc not set - uses default sleepWithContext + }) + + // Cancel context after a short delay + go func() { + time.Sleep(50 * time.Millisecond) + cancel() + }() + + start := time.Now() + err := hc.WaitForHealth(ctx) + elapsed := time.Since(start) + + assert.ErrorIs(t, err, context.Canceled) + // Should return quickly (< 1s), not wait for the 10s retry interval + assert.Less(t, elapsed, 1*time.Second, "should return immediately on context cancel, not wait for sleep") +} diff --git a/internal/update/service.go b/internal/update/service.go index 098b4f0..1a719f4 100644 --- a/internal/update/service.go +++ b/internal/update/service.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "os/exec" + "strings" "time" ) @@ -72,6 +73,23 @@ func (s *ServiceController) Stop(ctx context.Context) error { return nil } +// Exists checks whether the systemd service unit is loaded. +// It runs "systemctl show --property=LoadState " and returns true if +// the LoadState is "loaded". 
+func (s *ServiceController) Exists(ctx context.Context) (bool, error) { + output, err := s.config.ExecFunc(ctx, "systemctl", "show", "--property=LoadState", s.config.ServiceName) + if err != nil { + if ctx.Err() != nil { + return false, ctx.Err() + } + return false, fmt.Errorf("failed to check service %s: %w (output: %s)", s.config.ServiceName, err, string(output)) + } + + // Parse output: "LoadState=loaded\n" or "LoadState=not-found\n" + line := strings.TrimSpace(string(output)) + return line == "LoadState=loaded", nil +} + // Start starts the systemd service. // It respects the configured timeout and the parent context. func (s *ServiceController) Start(ctx context.Context) error { diff --git a/internal/update/service_test.go b/internal/update/service_test.go index a32c74a..bf3114c 100644 --- a/internal/update/service_test.go +++ b/internal/update/service_test.go @@ -211,6 +211,48 @@ func TestServiceController_Start_UsesConfiguredTimeout(t *testing.T) { assert.WithinDuration(t, time.Now().Add(20*time.Second), deadline, 2*time.Second) } +func TestServiceController_Exists_ReturnsTrueWhenLoaded(t *testing.T) { + recorder := newRecordingExec(mockExecResult{output: "LoadState=loaded\n", err: nil}) + sc := NewServiceController(ServiceConfig{ + ServiceName: "hostlink", + ExecFunc: recorder.exec, + }) + + exists, err := sc.Exists(context.Background()) + + require.NoError(t, err) + assert.True(t, exists) + require.Len(t, recorder.calls, 1) + assert.Equal(t, "systemctl", recorder.calls[0].name) + assert.Equal(t, []string{"show", "--property=LoadState", "hostlink"}, recorder.calls[0].args) +} + +func TestServiceController_Exists_ReturnsFalseWhenNotFound(t *testing.T) { + recorder := newRecordingExec(mockExecResult{output: "LoadState=not-found\n", err: nil}) + sc := NewServiceController(ServiceConfig{ + ServiceName: "hostlink", + ExecFunc: recorder.exec, + }) + + exists, err := sc.Exists(context.Background()) + + require.NoError(t, err) + assert.False(t, exists) +} + +func 
TestServiceController_Exists_ReturnsErrorOnExecFailure(t *testing.T) { + recorder := newRecordingExec(mockExecResult{output: "", err: errors.New("exec failed")}) + sc := NewServiceController(ServiceConfig{ + ServiceName: "hostlink", + ExecFunc: recorder.exec, + }) + + _, err := sc.Exists(context.Background()) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to check service") +} + func TestServiceController_DefaultTimeouts(t *testing.T) { sc := NewServiceController(ServiceConfig{ ServiceName: "hostlink", diff --git a/internal/update/spawn.go b/internal/update/spawn.go index 9a9ed04..b499843 100644 --- a/internal/update/spawn.go +++ b/internal/update/spawn.go @@ -5,12 +5,12 @@ import ( "syscall" ) -// SpawnUpdater starts the updater binary in its own process group. -// The updater survives the agent's shutdown because Setpgid: true +// SpawnUpgrade starts the staged binary in its own process group. +// The upgrade process survives the agent's shutdown because Setpgid: true // places it in a new process group that systemd won't kill. // This is fire-and-forget: the caller does not wait for the process to exit. -func SpawnUpdater(updaterPath string, args []string) error { - cmd, err := spawnWithCmd(updaterPath, args) +func SpawnUpgrade(binaryPath string, args []string) error { + cmd, err := spawnWithCmd(binaryPath, args) if err != nil { return err } @@ -20,8 +20,8 @@ func SpawnUpdater(updaterPath string, args []string) error { // spawnWithCmd is the internal implementation that returns the exec.Cmd // for testing purposes (to inspect the child PID/PGID). -func spawnWithCmd(updaterPath string, args []string) (*exec.Cmd, error) { - cmd := exec.Command(updaterPath, args...) +func spawnWithCmd(binaryPath string, args []string) (*exec.Cmd, error) { + cmd := exec.Command(binaryPath, args...) 
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} if err := cmd.Start(); err != nil { return nil, err diff --git a/internal/update/spawn_test.go b/internal/update/spawn_test.go index fa59f79..51eb7e3 100644 --- a/internal/update/spawn_test.go +++ b/internal/update/spawn_test.go @@ -9,12 +9,12 @@ import ( "github.com/stretchr/testify/require" ) -func TestSpawnUpdater_StartsProcess(t *testing.T) { - err := SpawnUpdater("/bin/sleep", []string{"0.1"}) +func TestSpawnUpgrade_StartsProcess(t *testing.T) { + err := SpawnUpgrade("/bin/sleep", []string{"0.1"}) require.NoError(t, err) } -func TestSpawnUpdater_SetpgidTrue(t *testing.T) { +func TestSpawnUpgrade_SetpgidTrue(t *testing.T) { // Use a process that prints its own PGID so we can verify // We spawn "sleep 2" and check its PGID differs from ours cmd, err := spawnWithCmd("/bin/sleep", []string{"0.5"}) @@ -38,7 +38,7 @@ func TestSpawnUpdater_SetpgidTrue(t *testing.T) { "child should be its own process group leader (PGID == PID)") } -func TestSpawnUpdater_ReturnsErrorForInvalidBinary(t *testing.T) { - err := SpawnUpdater("/nonexistent/binary", []string{}) +func TestSpawnUpgrade_ReturnsErrorForInvalidBinary(t *testing.T) { + err := SpawnUpgrade("/nonexistent/binary", []string{}) assert.Error(t, err) } diff --git a/internal/update/state.go b/internal/update/state.go index abd96cc..ca2d413 100644 --- a/internal/update/state.go +++ b/internal/update/state.go @@ -32,7 +32,15 @@ type StateData struct { Error *string `json:"error,omitempty"` } +// StateWriterInterface defines the interface for state persistence. +// This interface allows mocking in tests. +type StateWriterInterface interface { + Write(data StateData) error + Read() (StateData, error) +} + // StateWriter manages the update state file for observability. +// Implements StateWriterInterface. 
type StateWriter struct { statePath string } diff --git a/main.go b/main.go index e0a9644..5f1596b 100644 --- a/main.go +++ b/main.go @@ -7,23 +7,175 @@ import ( "hostlink/app/jobs/heartbeatjob" "hostlink/app/jobs/metricsjob" "hostlink/app/jobs/registrationjob" + "hostlink/app/jobs/selfupdatejob" "hostlink/app/jobs/taskjob" + "hostlink/app/services/agentstate" "hostlink/app/services/heartbeat" "hostlink/app/services/metrics" + "hostlink/app/services/requestsigner" "hostlink/app/services/taskfetcher" "hostlink/app/services/taskreporter" + "hostlink/app/services/updatecheck" + "hostlink/app/services/updatedownload" + "hostlink/app/services/updatepreflight" + "hostlink/cmd/upgrade" "hostlink/config" "hostlink/config/appconf" "hostlink/internal/dbconn" + "hostlink/internal/update" "hostlink/internal/validator" + "hostlink/version" "log" + "net/http" + "os" + "syscall" + "time" "github.com/joho/godotenv" "github.com/labstack/echo/v4" "github.com/labstack/echo/v4/middleware" + "github.com/urfave/cli/v3" ) -func main() { +func init() { + _ = godotenv.Load() +} + +func newApp() *cli.Command { + return &cli.Command{ + Name: "hostlink", + Usage: "Hostlink agent", + Version: version.Version, + Action: runServer, + Commands: []*cli.Command{ + { + Name: "version", + Usage: "Print the version", + Action: func(ctx context.Context, cmd *cli.Command) error { + fmt.Println(version.Version) + return nil + }, + }, + { + Name: "upgrade", + Usage: "Upgrade the hostlink binary in-place", + Flags: []cli.Flag{ + &cli.StringFlag{ + Name: "install-path", + Usage: "Target path to install the binary", + Value: "/usr/bin/hostlink", + Sources: cli.EnvVars("HOSTLINK_INSTALL_PATH"), + }, + &cli.BoolFlag{ + Name: "dry-run", + Usage: "Validate preconditions without performing the upgrade", + }, + &cli.StringFlag{ + Name: "base-dir", + Usage: "Override update base directory (for testing)", + Hidden: true, + }, + &cli.StringFlag{ + Name: "update-id", + Usage: "Unique ID for this update operation", + 
Hidden: true, + }, + &cli.StringFlag{ + Name: "source-version", + Usage: "Version being upgraded from", + Hidden: true, + }, + }, + Action: runUpgrade, + }, + }, + } +} + +const upgradeTimeout = 90 * time.Second + +func runUpgrade(ctx context.Context, cmd *cli.Command) error { + installPath := cmd.String("install-path") + if installPath == "" { + return fmt.Errorf("--install-path cannot be empty") + } + + dryRun := cmd.Bool("dry-run") + baseDir := cmd.String("base-dir") + updateID := cmd.String("update-id") + sourceVersion := cmd.String("source-version") + + // Resolve self path (the staged binary that was executed) + selfPath, err := os.Executable() + if err != nil { + return fmt.Errorf("cannot determine self path: %w", err) + } + + // Set up paths (use custom base-dir if provided, otherwise defaults) + var paths update.Paths + if baseDir != "" { + paths = update.NewPaths(baseDir) + } else { + paths = update.DefaultPaths() + } + + // Set up logger + logger, cleanup, err := upgrade.NewLogger(upgrade.DefaultLogPath) + if err != nil { + // Fall back to stderr-only logging if we can't write the log file + fmt.Fprintf(os.Stderr, "warning: cannot open log file: %v\n", err) + logger = nil // Upgrader will use discard logger + } else { + defer cleanup() + } + + // Build config + cfg := &upgrade.Config{ + InstallPath: installPath, + SelfPath: selfPath, + BackupDir: paths.BackupDir, + LockPath: paths.LockFile, + StatePath: paths.StateFile, + HealthURL: "http://127.0.0.1:" + appconf.Port() + "/health", + TargetVersion: version.Version, + UpdateID: updateID, + SourceVersion: sourceVersion, + Logger: logger, + } + + u, err := upgrade.NewUpgrader(cfg) + if err != nil { + return err + } + + if dryRun { + results := u.DryRun(ctx) + allPassed := true + for _, r := range results { + status := "PASS" + if !r.Passed { + status = "FAIL" + allPassed = false + } + fmt.Fprintf(os.Stderr, "[%s] %s: %s\n", status, r.Name, r.Detail) + } + if !allPassed { + return fmt.Errorf("dry-run: one or 
more checks failed") + } + return nil + } + + // Set up timeout and signal handling + ctx, cancel := context.WithTimeout(ctx, upgradeTimeout) + defer cancel() + + stop := upgrade.WatchSignals(cancel) + defer stop() + + return u.Run(ctx) +} + +func runServer(ctx context.Context, cmd *cli.Command) error { db, err := dbconn.GetConn( dbconn.WithURL(appconf.DBURL()), ) @@ -47,7 +199,6 @@ func main() { config.AddRoutesV2(e, container) - // TODO(iAziz786): check if we can move this cron in app // Agent-related jobs run in goroutine after registration go func() { ctx := context.Background() @@ -88,11 +239,120 @@ func main() { } heartbeatJob := heartbeatjob.New() heartbeatJob.Register(ctx, heartbeatSvc) + + // Self-update job (gated by config) + if appconf.SelfUpdateEnabled() { + startSelfUpdateJob(ctx) + } }() - log.Fatal(e.Start(fmt.Sprintf(":%s", appconf.Port()))) + return e.Start(fmt.Sprintf(":%s", appconf.Port())) } -func init() { - _ = godotenv.Load() +func startSelfUpdateJob(ctx context.Context) { + paths := update.DefaultPaths() + + // Ensure update directories exist with correct permissions + if err := update.InitDirectories(paths.BaseDir); err != nil { + log.Printf("failed to initialize update directories: %v", err) + return + } + + // Clean staging dir on boot and ensure it's ready for use + stagingMgr := updatedownload.NewStagingManager(paths.StagingDir, nil) + if err := stagingMgr.Cleanup(); err != nil { + log.Printf("failed to clean staging dir on boot: %v", err) + } + if err := stagingMgr.Prepare(); err != nil { + log.Printf("failed to prepare staging dir: %v", err) + return + } + + // Load agent state for ID and signer + state := agentstate.New(appconf.AgentStatePath()) + if err := state.Load(); err != nil { + log.Printf("failed to load agent state for self-update: %v", err) + return + } + agentID := state.GetAgentID() + if agentID == "" { + log.Printf("self-update: agent ID not available, skipping") + return + } + + // Create request signer + signer, err := 
requestsigner.New(appconf.AgentPrivateKeyPath(), agentID) + if err != nil { + log.Printf("failed to create request signer for self-update: %v", err) + return + } + + // Create update checker + checker, err := updatecheck.New( + &http.Client{Timeout: 30 * time.Second}, + appconf.ControlPlaneURL(), + agentID, + signer, + ) + if err != nil { + log.Printf("failed to create update checker: %v", err) + return + } + + // Create downloader + downloader := updatedownload.NewDownloader(updatedownload.DefaultDownloadConfig()) + + // Create preflight checker + preflight := updatepreflight.New(updatepreflight.PreflightConfig{ + AgentBinaryPath: appconf.InstallPath(), + UpdatesDir: paths.BaseDir, + StatFunc: func(path string) (uint64, error) { + var stat syscall.Statfs_t + if err := syscall.Statfs(path, &stat); err != nil { + return 0, err + } + return stat.Bavail * uint64(stat.Bsize), nil + }, + }) + + // Create lock manager + lockMgr := update.NewLockManager(update.LockConfig{ + LockPath: paths.LockFile, + }) + + // Create state writer + stateWriter := update.NewStateWriter(update.StateConfig{ + StatePath: paths.StateFile, + }) + + // Configure trigger with update check interval + triggerCfg := selfupdatejob.TriggerConfig{ + Interval: appconf.UpdateCheckInterval(), + } + + job := selfupdatejob.NewWithConfig(selfupdatejob.SelfUpdateJobConfig{ + Trigger: func(ctx context.Context, fn func() error) { + selfupdatejob.TriggerWithConfig(ctx, fn, triggerCfg) + }, + UpdateChecker: checker, + Downloader: downloader, + PreflightChecker: preflight, + LockManager: lockMgr, + StateWriter: stateWriter, + Spawn: update.SpawnUpgrade, + InstallBinary: update.InstallBinary, + CurrentVersion: version.Version, + InstallPath: appconf.InstallPath(), + StagingDir: paths.StagingDir, + }) + + job.Register(ctx) + log.Printf("self-update job started (interval: %s)", appconf.UpdateCheckInterval()) +} + +func main() { + app := newApp() + if err := app.Run(context.Background(), os.Args); err != nil { + 
log.Fatal(err) + } } diff --git a/test/integration/selfupdate_test.go b/test/integration/selfupdate_test.go index 4e29d0b..25091f1 100644 --- a/test/integration/selfupdate_test.go +++ b/test/integration/selfupdate_test.go @@ -1,5 +1,4 @@ -//go:build integration -// +build integration +//go:build integration && linux package integration @@ -19,15 +18,15 @@ import ( "github.com/stretchr/testify/require" ) -// buildUpdaterBinary compiles the hostlink-updater binary into the given directory. +// buildHostlinkBinary compiles the hostlink binary into the given directory. // Returns the path to the compiled binary. -func buildUpdaterBinary(t *testing.T, outputDir string) string { +func buildHostlinkBinary(t *testing.T, outputDir string) string { t.Helper() - binaryPath := filepath.Join(outputDir, "hostlink-updater") - cmd := exec.Command("go", "build", "-o", binaryPath, "./cmd/updater") + binaryPath := filepath.Join(outputDir, "hostlink") + cmd := exec.Command("go", "build", "-o", binaryPath, ".") cmd.Dir = findProjectRoot(t) output, err := cmd.CombinedOutput() - require.NoError(t, err, "failed to build hostlink-updater: %s", string(output)) + require.NoError(t, err, "failed to build hostlink: %s", string(output)) return binaryPath } @@ -58,15 +57,6 @@ func setupUpdateDirs(t *testing.T) string { return baseDir } -// writeStateFile writes a state.json file into the base directory. -func writeStateFile(t *testing.T, baseDir string, data update.StateData) { - t.Helper() - paths := update.NewPaths(baseDir) - sw := update.NewStateWriter(update.StateConfig{StatePath: paths.StateFile}) - err := sw.Write(data) - require.NoError(t, err) -} - // readStateFile reads and returns the state.json contents from the base directory. 
func readStateFile(t *testing.T, baseDir string) update.StateData { t.Helper() @@ -77,7 +67,17 @@ func readStateFile(t *testing.T, baseDir string) update.StateData { return data } -func TestSelfUpdate_LockPreventsConcurrentUpdates(t *testing.T) { +// createDummyBinary creates a dummy executable file at the given path. +// Returns the path to the created file. +func createDummyBinary(t *testing.T, dir string) string { + t.Helper() + path := filepath.Join(dir, "hostlink") + err := os.WriteFile(path, []byte("#!/bin/sh\nexit 0\n"), 0755) + require.NoError(t, err) + return path +} + +func TestUpgrade_LockPreventsConcurrent(t *testing.T) { baseDir := setupUpdateDirs(t) paths := update.NewPaths(baseDir) @@ -87,174 +87,174 @@ func TestSelfUpdate_LockPreventsConcurrentUpdates(t *testing.T) { require.NoError(t, err) defer lock.Unlock() - // Build the updater binary + // Build the hostlink binary binDir := t.TempDir() - updaterBin := buildUpdaterBinary(t, binDir) - - // Write state file so updater can read target version - writeStateFile(t, baseDir, update.StateData{ - State: update.StateStaged, - TargetVersion: "1.2.3", - }) + hostlinkBin := buildHostlinkBinary(t, binDir) - // Attempt to run the updater — it should fail because the lock is held + // Attempt to run upgrade — it should fail because the lock is held ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - cmd := exec.CommandContext(ctx, updaterBin, - "-base-dir", baseDir, - "-version", "1.2.3", + cmd := exec.CommandContext(ctx, hostlinkBin, "upgrade", + "--base-dir", baseDir, + "--install-path", "/tmp/fake-hostlink", ) output, err := cmd.CombinedOutput() - require.Error(t, err, "updater should fail when lock is held") + require.Error(t, err, "upgrade should fail when lock is held") assert.Contains(t, string(output), "lock", "error should mention lock") + + // The lock file should still be held by us (not stolen) + lockContent, err := os.ReadFile(paths.LockFile) + 
require.NoError(t, err) + + var lockData struct { + PID int `json:"pid"` + } + err = json.Unmarshal(lockContent, &lockData) + require.NoError(t, err) + assert.Equal(t, os.Getpid(), lockData.PID, "lock should still be held by test process") } -func TestSelfUpdate_SignalHandlingDuringUpdate(t *testing.T) { +func TestUpgrade_SignalHandling(t *testing.T) { baseDir := setupUpdateDirs(t) - // Build the updater binary + // Build the hostlink binary binDir := t.TempDir() - updaterBin := buildUpdaterBinary(t, binDir) - - // Write state file with target version - writeStateFile(t, baseDir, update.StateData{ - State: update.StateStaged, - TargetVersion: "1.2.3", - }) - - // Start the updater process — it will try to stop the service via systemctl - // which will take time / fail, giving us time to send a signal - cmd := exec.Command(updaterBin, - "-base-dir", baseDir, - "-version", "1.2.3", - "-binary", "/nonexistent/hostlink", // non-existent binary, will fail at stop + hostlinkBin := buildHostlinkBinary(t, binDir) + + // Create a dummy binary at install-path so backup succeeds, + // then systemctl stop will stall/fail giving us time to send a signal + installDir := t.TempDir() + installPath := createDummyBinary(t, installDir) + + // Start the upgrade process + cmd := exec.Command(hostlinkBin, "upgrade", + "--base-dir", baseDir, + "--install-path", installPath, ) cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} err := cmd.Start() - require.NoError(t, err, "should start updater process") + require.NoError(t, err, "should start upgrade process") - // Give it a moment to start - time.Sleep(100 * time.Millisecond) + // Give it time to start and reach the systemctl stop phase + time.Sleep(500 * time.Millisecond) // Send SIGTERM err = cmd.Process.Signal(syscall.SIGTERM) require.NoError(t, err, "should send SIGTERM") - // Wait for exit — should exit within a few seconds + // Wait for exit — should exit within 10 seconds done := make(chan error, 1) go func() { done <- 
cmd.Wait() }() select { case err := <-done: - // Process exited (may be non-zero exit code due to cancellation, that's ok) + // Process exited (non-zero exit due to cancellation is expected) if err != nil { - // Verify it's an exit error, not something unexpected _, ok := err.(*exec.ExitError) assert.True(t, ok, "expected ExitError after signal, got: %v", err) } case <-time.After(10 * time.Second): cmd.Process.Kill() - t.Fatal("updater did not exit within 10 seconds after SIGTERM") + t.Fatal("upgrade did not exit within 10 seconds after SIGTERM") } } -func TestSelfUpdate_UpdaterWritesStateOnLockFailure(t *testing.T) { +func TestUpgrade_MissingInstallPathFails(t *testing.T) { baseDir := setupUpdateDirs(t) - paths := update.NewPaths(baseDir) - - // Acquire lock from this process to block the updater - lock := update.NewLockManager(update.LockConfig{LockPath: paths.LockFile}) - err := lock.TryLock(5 * time.Minute) - require.NoError(t, err) - defer lock.Unlock() - // Build the updater + // Build the hostlink binary binDir := t.TempDir() - updaterBin := buildUpdaterBinary(t, binDir) + hostlinkBin := buildHostlinkBinary(t, binDir) - // Run the updater — should fail on lock acquisition + // Run upgrade with a non-existent install-path — backup phase should fail ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - cmd := exec.CommandContext(ctx, updaterBin, - "-base-dir", baseDir, - "-version", "2.0.0", + cmd := exec.CommandContext(ctx, hostlinkBin, "upgrade", + "--base-dir", baseDir, + "--install-path", "/nonexistent/path/hostlink", ) - _, err = cmd.CombinedOutput() - require.Error(t, err, "updater should fail when lock is held") - - // The lock file should still be held by us (not stolen) - lockContent, err := os.ReadFile(paths.LockFile) - require.NoError(t, err) - - var lockData struct { - PID int `json:"pid"` - } - err = json.Unmarshal(lockContent, &lockData) - require.NoError(t, err) - assert.Equal(t, os.Getpid(), lockData.PID, 
"lock should still be held by test process") + output, err := cmd.CombinedOutput() + require.Error(t, err, "upgrade should fail with non-existent install-path") + assert.Contains(t, string(output), "no such file", + "error should indicate file not found") } -func TestSelfUpdate_UpdaterExitsWithErrorForMissingVersion(t *testing.T) { +func TestUpgrade_DryRun(t *testing.T) { baseDir := setupUpdateDirs(t) - // Build the updater + // Build the hostlink binary binDir := t.TempDir() - updaterBin := buildUpdaterBinary(t, binDir) + hostlinkBin := buildHostlinkBinary(t, binDir) - // Run without -version and without state file — should fail + // Run dry-run with a non-existent install-path — some checks should fail ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - cmd := exec.CommandContext(ctx, updaterBin, - "-base-dir", baseDir, + cmd := exec.CommandContext(ctx, hostlinkBin, "upgrade", + "--dry-run", + "--base-dir", baseDir, + "--install-path", "/nonexistent/path/hostlink", ) output, err := cmd.CombinedOutput() - require.Error(t, err, "updater should fail without version") - assert.Contains(t, string(output), "version", "error should mention version") + outputStr := string(output) + + // Should exit with error because checks fail + require.Error(t, err, "dry-run should fail when checks don't pass") + + // Should contain check result output format + assert.Contains(t, outputStr, "[FAIL]", "should report failed checks") + assert.Contains(t, outputStr, "binary_writable", "should check binary_writable") } -func TestSelfUpdate_UpdaterReadsVersionFromState(t *testing.T) { +func TestUpgrade_DryRunPassesWithValidPath(t *testing.T) { baseDir := setupUpdateDirs(t) - // Write state file with target version - writeStateFile(t, baseDir, update.StateData{ - State: update.StateStaged, - TargetVersion: "3.0.0", - }) - - // Build the updater + // Build the hostlink binary binDir := t.TempDir() - updaterBin := buildUpdaterBinary(t, binDir) + 
hostlinkBin := buildHostlinkBinary(t, binDir) - // Run without -version flag but with state file containing version - // It should read the version from state and proceed (then fail at systemctl stop) + // Create a writable dummy binary at a temp path + installDir := t.TempDir() + installPath := createDummyBinary(t, installDir) + + // Run dry-run with a valid install-path ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - cmd := exec.CommandContext(ctx, updaterBin, - "-base-dir", baseDir, - "-binary", "/nonexistent/hostlink", + cmd := exec.CommandContext(ctx, hostlinkBin, "upgrade", + "--dry-run", + "--base-dir", baseDir, + "--install-path", installPath, ) output, err := cmd.CombinedOutput() - // Should fail (no systemctl), but NOT because of missing version - require.Error(t, err) - assert.NotContains(t, string(output), "target version is required", - "should have read version from state file") + outputStr := string(output) + + // Some checks should pass (lock, binary_writable, backup_dir) + assert.Contains(t, outputStr, "[PASS]", "should report passing checks") + assert.Contains(t, outputStr, "lock_acquirable", "should check lock_acquirable") + assert.Contains(t, outputStr, "binary_writable", "should check binary_writable") + + // service_exists will fail on non-hostlink machines, so overall exits with error + // but the path-related checks should pass + if err != nil { + assert.Contains(t, outputStr, "[FAIL]", + "if error, should be because service_exists or similar check failed") + } } -func TestSelfUpdate_UpdaterPrintVersion(t *testing.T) { - // Build the updater +func TestUpgrade_VersionSubcommand(t *testing.T) { + // Build the hostlink binary binDir := t.TempDir() - updaterBin := buildUpdaterBinary(t, binDir) + hostlinkBin := buildHostlinkBinary(t, binDir) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - cmd := exec.CommandContext(ctx, updaterBin, "-v") + cmd := 
exec.CommandContext(ctx, hostlinkBin, "version") output, err := cmd.CombinedOutput() - require.NoError(t, err, "version flag should not fail") - assert.Contains(t, string(output), "hostlink-updater", "should print version info") + require.NoError(t, err, "version subcommand should not fail") + assert.Contains(t, string(output), "dev", "should print version info") }