From c50764b65b743d8db0c6b6482fbc5a79c8cfb48f Mon Sep 17 00:00:00 2001 From: Mohammad Aziz Date: Thu, 15 Jan 2026 23:36:53 +0530 Subject: [PATCH] feat: add agent self-update mechanism --- .goreleaser.yaml | 38 +- AGENTS.md | 4 + app/jobs/selfupdatejob/selfupdatejob.go | 224 ++++ app/jobs/selfupdatejob/selfupdatejob_test.go | 987 ++++++++++++++++++ app/jobs/selfupdatejob/trigger.go | 39 + app/services/updatecheck/updatecheck.go | 81 ++ app/services/updatecheck/updatecheck_test.go | 248 +++++ app/services/updatedownload/download.go | 281 +++++ app/services/updatedownload/download_test.go | 410 ++++++++ app/services/updatedownload/extract.go | 214 ++++ app/services/updatedownload/extract_test.go | 339 ++++++ app/services/updatedownload/staging.go | 86 ++ app/services/updatedownload/staging_test.go | 190 ++++ app/services/updatepreflight/preflight.go | 98 ++ .../updatepreflight/preflight_test.go | 183 ++++ cmd/updater/main.go | 102 ++ cmd/updater/signals.go | 27 + cmd/updater/signals_test.go | 37 + cmd/updater/updater.go | 306 ++++++ cmd/updater/updater_test.go | 644 ++++++++++++ config/appconf/appconf.go | 66 +- config/appconf/appconf_test.go | 78 ++ internal/update/binary.go | 246 +++++ internal/update/binary_test.go | 390 +++++++ internal/update/dirs.go | 69 ++ internal/update/dirs_test.go | 127 +++ internal/update/health.go | 147 +++ internal/update/health_test.go | 264 +++++ internal/update/lock.go | 267 +++++ internal/update/lock_test.go | 377 +++++++ internal/update/procutil_darwin.go | 52 + internal/update/procutil_linux.go | 73 ++ internal/update/procutil_test.go | 66 ++ internal/update/service.go | 92 ++ internal/update/service_test.go | 223 ++++ internal/update/spawn.go | 30 + internal/update/spawn_test.go | 44 + internal/update/state.go | 104 ++ internal/update/state_test.go | 258 +++++ internal/versionutil/version.go | 95 ++ internal/versionutil/version_test.go | 226 ++++ scripts/linux/hostlink.service | 3 +- test/integration/selfupdate_test.go | 260 +++++ 
43 files changed, 8088 insertions(+), 7 deletions(-) create mode 100644 AGENTS.md create mode 100644 app/jobs/selfupdatejob/selfupdatejob.go create mode 100644 app/jobs/selfupdatejob/selfupdatejob_test.go create mode 100644 app/jobs/selfupdatejob/trigger.go create mode 100644 app/services/updatecheck/updatecheck.go create mode 100644 app/services/updatecheck/updatecheck_test.go create mode 100644 app/services/updatedownload/download.go create mode 100644 app/services/updatedownload/download_test.go create mode 100644 app/services/updatedownload/extract.go create mode 100644 app/services/updatedownload/extract_test.go create mode 100644 app/services/updatedownload/staging.go create mode 100644 app/services/updatedownload/staging_test.go create mode 100644 app/services/updatepreflight/preflight.go create mode 100644 app/services/updatepreflight/preflight_test.go create mode 100644 cmd/updater/main.go create mode 100644 cmd/updater/signals.go create mode 100644 cmd/updater/signals_test.go create mode 100644 cmd/updater/updater.go create mode 100644 cmd/updater/updater_test.go create mode 100644 config/appconf/appconf_test.go create mode 100644 internal/update/binary.go create mode 100644 internal/update/binary_test.go create mode 100644 internal/update/dirs.go create mode 100644 internal/update/dirs_test.go create mode 100644 internal/update/health.go create mode 100644 internal/update/health_test.go create mode 100644 internal/update/lock.go create mode 100644 internal/update/lock_test.go create mode 100644 internal/update/procutil_darwin.go create mode 100644 internal/update/procutil_linux.go create mode 100644 internal/update/procutil_test.go create mode 100644 internal/update/service.go create mode 100644 internal/update/service_test.go create mode 100644 internal/update/spawn.go create mode 100644 internal/update/spawn_test.go create mode 100644 internal/update/state.go create mode 100644 internal/update/state_test.go create mode 100644 
internal/versionutil/version.go create mode 100644 internal/versionutil/version_test.go create mode 100644 test/integration/selfupdate_test.go diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 2911ea1..a37c0e4 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -19,24 +19,40 @@ before: - make test-it builds: - - env: + - id: hostlink + env: - CGO_ENABLED=0 goos: - linux + goarch: + - amd64 + - arm64 + ldflags: + - -s -w -X hostlink/version.Version={{.Version}} + + - id: hostlink-updater + main: ./cmd/updater + binary: hostlink-updater + env: + - CGO_ENABLED=0 + goos: + - linux + goarch: + - amd64 + - arm64 ldflags: - -s -w -X hostlink/version.Version={{.Version}} archives: - - formats: [tar.gz] - # this name template makes the OS and Arch compatible with the results of `uname`. + - id: hostlink-archive + builds: [hostlink] + formats: [tar.gz] name_template: >- {{ .ProjectName }}_ {{- title .Os }}_ {{- if eq .Arch "amd64" }}x86_64 - {{- else if eq .Arch "386" }}i386 {{- else }}{{ .Arch }}{{ end }} {{- if .Arm }}v{{ .Arm }}{{ end }} - # use zip for windows archives format_overrides: - goos: windows formats: [zip] @@ -45,6 +61,18 @@ archives: dst: scripts strip_parent: true + - id: updater-archive + builds: [hostlink-updater] + formats: [tar.gz] + name_template: >- + hostlink-updater_ + {{- title .Os }}_ + {{- if eq .Arch "amd64" }}x86_64 + {{- else }}{{ .Arch }}{{ end }} + {{- if .Arm }}v{{ .Arm }}{{ end }} + files: + - none* + changelog: sort: asc filters: diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..40b5085 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,4 @@ +- ALWAYS USE PARALLEL TASKS SUBAGENTS FOR CODE EXPLORATION, DEEP DIVES, AND SO ON +- I use jj instead of git +- ALWAYS FOLLOW TDD, red phase to green phase +- Use ripgrep instead of grep, use fd instead of find diff --git a/app/jobs/selfupdatejob/selfupdatejob.go b/app/jobs/selfupdatejob/selfupdatejob.go new file mode 100644 index 0000000..c3c6fcf --- /dev/null +++ 
b/app/jobs/selfupdatejob/selfupdatejob.go @@ -0,0 +1,224 @@ +package selfupdatejob + +import ( + "context" + "fmt" + "path/filepath" + "sync" + "time" + + log "github.com/sirupsen/logrus" + + "hostlink/app/services/updatecheck" + "hostlink/app/services/updatedownload" + "hostlink/app/services/updatepreflight" + "hostlink/internal/update" +) + +const ( + // defaultRequiredSpace is the fallback disk space requirement (50MB) when the + // control plane does not provide download sizes. + defaultRequiredSpace = 50 * 1024 * 1024 +) + +// TriggerFunc is the function type for the job's scheduling strategy. +type TriggerFunc func(context.Context, func() error) + +// UpdateCheckerInterface abstracts the update check client. +type UpdateCheckerInterface interface { + Check(currentVersion string) (*updatecheck.UpdateInfo, error) +} + +// DownloaderInterface abstracts the download and verify functionality. +type DownloaderInterface interface { + DownloadAndVerify(ctx context.Context, url, destPath, sha256 string) (*updatedownload.DownloadResult, error) +} + +// PreflightCheckerInterface abstracts pre-flight checks. +type PreflightCheckerInterface interface { + Check(requiredSpace int64) *updatepreflight.PreflightResult +} + +// LockManagerInterface abstracts the lock manager. +type LockManagerInterface interface { + TryLockWithRetry(expiration time.Duration, retries int, interval time.Duration) error + Unlock() error +} + +// StateWriterInterface abstracts the state writer. +type StateWriterInterface interface { + Write(data update.StateData) error +} + +// SpawnFunc is a function that spawns the updater binary. +type SpawnFunc func(updaterPath string, args []string) error + +// InstallUpdaterFunc is a function that extracts and installs the updater binary from a tarball. +type InstallUpdaterFunc func(tarPath, destPath string) error + +// SelfUpdateJobConfig holds the configuration for the SelfUpdateJob. 
+type SelfUpdateJobConfig struct { + Trigger TriggerFunc + UpdateChecker UpdateCheckerInterface + Downloader DownloaderInterface + PreflightChecker PreflightCheckerInterface + LockManager LockManagerInterface + StateWriter StateWriterInterface + Spawn SpawnFunc + InstallUpdater InstallUpdaterFunc + CurrentVersion string + UpdaterPath string // Where to install the extracted updater binary + StagingDir string // Where to download tarballs + BaseDir string // Base update directory (for -base-dir flag to updater) +} + +// SelfUpdateJob periodically checks for and applies updates. +type SelfUpdateJob struct { + config SelfUpdateJobConfig + cancel context.CancelFunc + wg sync.WaitGroup +} + +// New creates a SelfUpdateJob with default configuration. +func New() *SelfUpdateJob { + return &SelfUpdateJob{ + config: SelfUpdateJobConfig{ + Trigger: Trigger, + }, + } +} + +// NewWithConfig creates a SelfUpdateJob with the given configuration. +func NewWithConfig(cfg SelfUpdateJobConfig) *SelfUpdateJob { + if cfg.Trigger == nil { + cfg.Trigger = Trigger + } + return &SelfUpdateJob{ + config: cfg, + } +} + +// Register starts the job goroutine and returns a cancel function. +func (j *SelfUpdateJob) Register(ctx context.Context) context.CancelFunc { + ctx, cancel := context.WithCancel(ctx) + j.cancel = cancel + + j.wg.Add(1) + go func() { + defer j.wg.Done() + j.config.Trigger(ctx, func() error { + return j.runUpdate(ctx) + }) + }() + + return cancel +} + +// Shutdown cancels the job and waits for the goroutine to exit. +func (j *SelfUpdateJob) Shutdown() { + if j.cancel != nil { + j.cancel() + } + j.wg.Wait() +} + +// runUpdate performs a single update check and apply cycle. 
+func (j *SelfUpdateJob) runUpdate(ctx context.Context) error { + // Step 1: Check for updates + info, err := j.config.UpdateChecker.Check(j.config.CurrentVersion) + if err != nil { + return fmt.Errorf("update check failed: %w", err) + } + if !info.UpdateAvailable { + return nil + } + + log.Infof("update available: %s -> %s", j.config.CurrentVersion, info.TargetVersion) + + // Step 2: Pre-flight checks + requiredSpace := info.AgentSize + info.UpdaterSize + if requiredSpace == 0 { + requiredSpace = defaultRequiredSpace + } + result := j.config.PreflightChecker.Check(requiredSpace) + if !result.Passed { + return fmt.Errorf("preflight checks failed: %v", result.Errors) + } + + // Step 3: Acquire lock + if err := j.config.LockManager.TryLockWithRetry(5*time.Minute, 3, 5*time.Second); err != nil { + return fmt.Errorf("failed to acquire update lock: %w", err) + } + + // From here on, we must release the lock on any failure + locked := true + defer func() { + if locked { + j.config.LockManager.Unlock() + } + }() + + // Step 4: Write initialized state + if err := j.config.StateWriter.Write(update.StateData{ + State: update.StateInitialized, + SourceVersion: j.config.CurrentVersion, + TargetVersion: info.TargetVersion, + }); err != nil { + return fmt.Errorf("failed to write initialized state: %w", err) + } + + if err := ctx.Err(); err != nil { + return err + } + + // Step 5: Download agent tarball + agentDest := filepath.Join(j.config.StagingDir, updatedownload.AgentTarballName) + if _, err := j.config.Downloader.DownloadAndVerify(ctx, info.AgentURL, agentDest, info.AgentSHA256); err != nil { + return fmt.Errorf("failed to download agent: %w", err) + } + + if err := ctx.Err(); err != nil { + return err + } + + // Step 6: Download updater tarball + updaterDest := filepath.Join(j.config.StagingDir, updatedownload.UpdaterTarballName) + if _, err := j.config.Downloader.DownloadAndVerify(ctx, info.UpdaterURL, updaterDest, info.UpdaterSHA256); err != nil { + return 
fmt.Errorf("failed to download updater: %w", err) + } + + if err := ctx.Err(); err != nil { + return err + } + + // Step 7: Write staged state + if err := j.config.StateWriter.Write(update.StateData{ + State: update.StateStaged, + SourceVersion: j.config.CurrentVersion, + TargetVersion: info.TargetVersion, + }); err != nil { + return fmt.Errorf("failed to write staged state: %w", err) + } + + if err := ctx.Err(); err != nil { + return err + } + + // Step 8: Extract updater binary from tarball + if err := j.config.InstallUpdater(updaterDest, j.config.UpdaterPath); err != nil { + return fmt.Errorf("failed to install updater binary: %w", err) + } + + // Step 9: Release lock before spawning updater + j.config.LockManager.Unlock() + locked = false + + // Step 10: Spawn updater in its own process group + args := []string{"-version", info.TargetVersion, "-base-dir", j.config.BaseDir} + if err := j.config.Spawn(j.config.UpdaterPath, args); err != nil { + return fmt.Errorf("failed to spawn updater: %w", err) + } + + log.Infof("updater spawned for version %s", info.TargetVersion) + return nil +} diff --git a/app/jobs/selfupdatejob/selfupdatejob_test.go b/app/jobs/selfupdatejob/selfupdatejob_test.go new file mode 100644 index 0000000..efaeb3a --- /dev/null +++ b/app/jobs/selfupdatejob/selfupdatejob_test.go @@ -0,0 +1,987 @@ +package selfupdatejob + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + "time" + + "hostlink/app/services/updatecheck" + "hostlink/app/services/updatedownload" + "hostlink/app/services/updatepreflight" + "hostlink/internal/update" +) + +// --- Registration Tests --- + +func TestRegister_StartsGoroutineWithTrigger(t *testing.T) { + triggered := make(chan struct{}) + job := NewWithConfig(SelfUpdateJobConfig{ + Trigger: func(ctx context.Context, fn func() error) { + close(triggered) + <-ctx.Done() + }, + }) + + ctx := context.Background() + cancel := job.Register(ctx) + defer cancel() + + select { + case <-triggered: + // success + 
case <-time.After(time.Second): + t.Fatal("trigger was not called within timeout") + } +} + +func TestRegister_ReturnsCancelFunc(t *testing.T) { + var ctxCancelled atomic.Bool + job := NewWithConfig(SelfUpdateJobConfig{ + Trigger: func(ctx context.Context, fn func() error) { + <-ctx.Done() + ctxCancelled.Store(true) + }, + }) + + ctx := context.Background() + cancel := job.Register(ctx) + + cancel() + time.Sleep(50 * time.Millisecond) + + if !ctxCancelled.Load() { + t.Error("expected context to be cancelled after cancel() called") + } +} + +func TestShutdown_WaitsForGoroutine(t *testing.T) { + var goroutineExited atomic.Bool + job := NewWithConfig(SelfUpdateJobConfig{ + Trigger: func(ctx context.Context, fn func() error) { + <-ctx.Done() + time.Sleep(50 * time.Millisecond) // simulate cleanup + goroutineExited.Store(true) + }, + }) + + ctx := context.Background() + job.Register(ctx) + job.Shutdown() + + if !goroutineExited.Load() { + t.Error("Shutdown() returned before goroutine exited") + } +} + +func TestRegister_RespectsParentContextCancellation(t *testing.T) { + var ctxCancelled atomic.Bool + job := NewWithConfig(SelfUpdateJobConfig{ + Trigger: func(ctx context.Context, fn func() error) { + <-ctx.Done() + ctxCancelled.Store(true) + }, + }) + + ctx, parentCancel := context.WithCancel(context.Background()) + job.Register(ctx) + + parentCancel() + time.Sleep(50 * time.Millisecond) + + if !ctxCancelled.Load() { + t.Error("expected trigger context to be cancelled when parent context is cancelled") + } +} + +// --- Trigger Tests --- + +func TestDefaultTriggerConfig_OneHourInterval(t *testing.T) { + cfg := DefaultTriggerConfig() + if cfg.Interval != 1*time.Hour { + t.Errorf("expected default interval 1h, got %v", cfg.Interval) + } +} + +func TestTriggerWithConfig_CallsFnOnInterval(t *testing.T) { + var callCount atomic.Int32 + ctx, cancel := context.WithCancel(context.Background()) + + done := make(chan struct{}) + go func() { + TriggerWithConfig(ctx, func() error { + 
callCount.Add(1) + if callCount.Load() >= 3 { + cancel() + } + return nil + }, TriggerConfig{Interval: 10 * time.Millisecond}) + close(done) + }() + + select { + case <-done: + // success + case <-time.After(time.Second): + cancel() + t.Fatal("trigger did not call fn 3 times within timeout") + } + + if callCount.Load() < 3 { + t.Errorf("expected at least 3 calls, got %d", callCount.Load()) + } +} + +func TestTriggerWithConfig_StopsOnContextCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + done := make(chan struct{}) + go func() { + TriggerWithConfig(ctx, func() error { + return nil + }, TriggerConfig{Interval: 10 * time.Millisecond}) + close(done) + }() + + cancel() + + select { + case <-done: + // success - trigger exited + case <-time.After(time.Second): + t.Fatal("trigger did not stop after context cancel") + } +} + +func TestTriggerWithConfig_ContinuesOnError(t *testing.T) { + var callCount atomic.Int32 + ctx, cancel := context.WithCancel(context.Background()) + + done := make(chan struct{}) + go func() { + TriggerWithConfig(ctx, func() error { + callCount.Add(1) + if callCount.Load() >= 3 { + cancel() + } + return errors.New("some error") + }, TriggerConfig{Interval: 10 * time.Millisecond}) + close(done) + }() + + select { + case <-done: + // success + case <-time.After(time.Second): + cancel() + t.Fatal("trigger did not continue after errors") + } + + if callCount.Load() < 3 { + t.Errorf("expected at least 3 calls despite errors, got %d", callCount.Load()) + } +} + +// --- Update Flow Tests --- + +func TestUpdateFlow_SkipsWhenNoUpdate(t *testing.T) { + checker := &mockUpdateChecker{ + result: &updatecheck.UpdateInfo{UpdateAvailable: false}, + } + downloader := &mockDownloader{} + + job := NewWithConfig(SelfUpdateJobConfig{ + UpdateChecker: checker, + Downloader: downloader, + CurrentVersion: "1.0.0", + }) + + job.runUpdate(context.Background()) + + if downloader.callCount.Load() > 0 { + t.Error("downloader should not be 
called when no update available") + } +} + +func TestUpdateFlow_SkipsWhenPreflightFails(t *testing.T) { + checker := &mockUpdateChecker{ + result: &updatecheck.UpdateInfo{ + UpdateAvailable: true, + TargetVersion: "2.0.0", + AgentURL: "https://example.com/agent.tar.gz", + AgentSHA256: "abc", + }, + } + preflight := &mockPreflight{ + result: &updatepreflight.PreflightResult{ + Passed: false, + Errors: []string{"disk full"}, + }, + } + downloader := &mockDownloader{} + + job := NewWithConfig(SelfUpdateJobConfig{ + UpdateChecker: checker, + Downloader: downloader, + PreflightChecker: preflight, + CurrentVersion: "1.0.0", + }) + + job.runUpdate(context.Background()) + + if downloader.callCount.Load() > 0 { + t.Error("downloader should not be called when preflight fails") + } +} + +func TestUpdateFlow_FullFlow(t *testing.T) { + checker := &mockUpdateChecker{ + result: &updatecheck.UpdateInfo{ + UpdateAvailable: true, + TargetVersion: "2.0.0", + AgentURL: "https://example.com/agent.tar.gz", + AgentSHA256: "abc123", + UpdaterURL: "https://example.com/updater.tar.gz", + UpdaterSHA256: "def456", + }, + } + preflight := &mockPreflight{ + result: &updatepreflight.PreflightResult{Passed: true}, + } + lock := &mockLockManager{} + state := &mockStateWriter{} + downloader := &mockDownloader{} + spawner := &mockSpawner{} + + job := NewWithConfig(SelfUpdateJobConfig{ + UpdateChecker: checker, + Downloader: downloader, + PreflightChecker: preflight, + LockManager: lock, + StateWriter: state, + Spawn: spawner.spawn, + InstallUpdater: noopInstaller, + CurrentVersion: "1.0.0", + UpdaterPath: "/tmp/updater", + StagingDir: "/tmp/staging", + BaseDir: "/var/lib/hostlink/updates", + }) + + err := job.runUpdate(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Verify full flow: check → preflight → lock → state(init) → download agent → download updater → state(staged) → unlock → spawn + if checker.callCount.Load() != 1 { + t.Errorf("expected 1 check 
call, got %d", checker.callCount.Load()) + } + if preflight.callCount.Load() != 1 { + t.Errorf("expected 1 preflight call, got %d", preflight.callCount.Load()) + } + if lock.lockCount.Load() != 1 { + t.Errorf("expected 1 lock call, got %d", lock.lockCount.Load()) + } + if downloader.callCount.Load() != 2 { + t.Errorf("expected 2 download calls (agent + updater), got %d", downloader.callCount.Load()) + } + if lock.unlockCount.Load() != 1 { + t.Errorf("expected 1 unlock call, got %d", lock.unlockCount.Load()) + } + if spawner.callCount.Load() != 1 { + t.Errorf("expected 1 spawn call, got %d", spawner.callCount.Load()) + } + + // Verify state transitions + states := state.getStates() + if len(states) < 2 { + t.Fatalf("expected at least 2 state writes, got %d", len(states)) + } + if states[0] != update.StateInitialized { + t.Errorf("expected first state to be Initialized, got %s", states[0]) + } + if states[1] != update.StateStaged { + t.Errorf("expected second state to be Staged, got %s", states[1]) + } +} + +func TestUpdateFlow_UnlocksBeforeSpawn(t *testing.T) { + var sequence []string + var mu sync.Mutex + + checker := &mockUpdateChecker{ + result: &updatecheck.UpdateInfo{ + UpdateAvailable: true, + TargetVersion: "2.0.0", + AgentURL: "https://example.com/agent.tar.gz", + AgentSHA256: "abc", + UpdaterURL: "https://example.com/updater.tar.gz", + UpdaterSHA256: "def", + }, + } + preflight := &mockPreflight{ + result: &updatepreflight.PreflightResult{Passed: true}, + } + lock := &mockLockManager{ + onUnlock: func() { + mu.Lock() + sequence = append(sequence, "unlock") + mu.Unlock() + }, + } + state := &mockStateWriter{} + downloader := &mockDownloader{} + spawner := &mockSpawner{ + onSpawn: func() { + mu.Lock() + sequence = append(sequence, "spawn") + mu.Unlock() + }, + } + + job := NewWithConfig(SelfUpdateJobConfig{ + UpdateChecker: checker, + Downloader: downloader, + PreflightChecker: preflight, + LockManager: lock, + StateWriter: state, + Spawn: spawner.spawn, + 
InstallUpdater: noopInstaller, + CurrentVersion: "1.0.0", + UpdaterPath: "/tmp/updater", + StagingDir: "/tmp/staging", + BaseDir: "/var/lib/hostlink/updates", + }) + + err := job.runUpdate(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + mu.Lock() + defer mu.Unlock() + if len(sequence) < 2 { + t.Fatalf("expected at least 2 sequence entries, got %d: %v", len(sequence), sequence) + } + unlockIdx := -1 + spawnIdx := -1 + for i, s := range sequence { + if s == "unlock" { + unlockIdx = i + } + if s == "spawn" { + spawnIdx = i + } + } + if unlockIdx == -1 || spawnIdx == -1 { + t.Fatalf("missing unlock or spawn in sequence: %v", sequence) + } + if unlockIdx > spawnIdx { + t.Error("unlock must happen before spawn") + } +} + +func TestUpdateFlow_DownloadFailure(t *testing.T) { + checker := &mockUpdateChecker{ + result: &updatecheck.UpdateInfo{ + UpdateAvailable: true, + TargetVersion: "2.0.0", + AgentURL: "https://example.com/agent.tar.gz", + AgentSHA256: "abc", + UpdaterURL: "https://example.com/updater.tar.gz", + UpdaterSHA256: "def", + }, + } + preflight := &mockPreflight{ + result: &updatepreflight.PreflightResult{Passed: true}, + } + lock := &mockLockManager{} + state := &mockStateWriter{} + downloader := &mockDownloader{err: errors.New("download failed")} + spawner := &mockSpawner{} + + job := NewWithConfig(SelfUpdateJobConfig{ + UpdateChecker: checker, + Downloader: downloader, + PreflightChecker: preflight, + LockManager: lock, + StateWriter: state, + Spawn: spawner.spawn, + InstallUpdater: noopInstaller, + CurrentVersion: "1.0.0", + UpdaterPath: "/tmp/updater", + StagingDir: "/tmp/staging", + BaseDir: "/var/lib/hostlink/updates", + }) + + job.runUpdate(context.Background()) + + if spawner.callCount.Load() > 0 { + t.Error("spawner should not be called when download fails") + } + // Lock should still be released on failure + if lock.unlockCount.Load() != 1 { + t.Errorf("expected lock to be released on failure, unlock count: %d", 
lock.unlockCount.Load()) + } +} + +func TestUpdateFlow_ChecksumMismatch(t *testing.T) { + checker := &mockUpdateChecker{ + result: &updatecheck.UpdateInfo{ + UpdateAvailable: true, + TargetVersion: "2.0.0", + AgentURL: "https://example.com/agent.tar.gz", + AgentSHA256: "abc", + UpdaterURL: "https://example.com/updater.tar.gz", + UpdaterSHA256: "def", + }, + } + preflight := &mockPreflight{ + result: &updatepreflight.PreflightResult{Passed: true}, + } + lock := &mockLockManager{} + state := &mockStateWriter{} + // First call (agent) succeeds, second (updater) fails + downloader := &mockDownloader{failOnCall: 2, err: errors.New("checksum mismatch")} + spawner := &mockSpawner{} + + job := NewWithConfig(SelfUpdateJobConfig{ + UpdateChecker: checker, + Downloader: downloader, + PreflightChecker: preflight, + LockManager: lock, + StateWriter: state, + Spawn: spawner.spawn, + InstallUpdater: noopInstaller, + CurrentVersion: "1.0.0", + UpdaterPath: "/tmp/updater", + StagingDir: "/tmp/staging", + BaseDir: "/var/lib/hostlink/updates", + }) + + job.runUpdate(context.Background()) + + if spawner.callCount.Load() > 0 { + t.Error("spawner should not be called when checksum fails") + } + if lock.unlockCount.Load() != 1 { + t.Errorf("expected lock to be released on failure, unlock count: %d", lock.unlockCount.Load()) + } +} + +func TestUpdateFlow_SpawnArgs(t *testing.T) { + checker := &mockUpdateChecker{ + result: &updatecheck.UpdateInfo{ + UpdateAvailable: true, + TargetVersion: "2.0.0", + AgentURL: "https://example.com/agent.tar.gz", + AgentSHA256: "abc", + UpdaterURL: "https://example.com/updater.tar.gz", + UpdaterSHA256: "def", + }, + } + preflight := &mockPreflight{ + result: &updatepreflight.PreflightResult{Passed: true}, + } + lock := &mockLockManager{} + state := &mockStateWriter{} + downloader := &mockDownloader{} + spawner := &mockSpawner{} + + job := NewWithConfig(SelfUpdateJobConfig{ + UpdateChecker: checker, + Downloader: downloader, + PreflightChecker: preflight, + 
LockManager: lock, + StateWriter: state, + Spawn: spawner.spawn, + InstallUpdater: noopInstaller, + CurrentVersion: "1.0.0", + UpdaterPath: "/opt/updater/hostlink-updater", + StagingDir: "/tmp/staging", + BaseDir: "/var/lib/hostlink/updates", + }) + + err := job.runUpdate(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if spawner.lastPath != "/opt/updater/hostlink-updater" { + t.Errorf("expected updater path /opt/updater/hostlink-updater, got %s", spawner.lastPath) + } + // Verify args contain -version and -base-dir + foundVersion := false + foundDir := false + for i, arg := range spawner.lastArgs { + if arg == "-version" && i+1 < len(spawner.lastArgs) && spawner.lastArgs[i+1] == "2.0.0" { + foundVersion = true + } + if arg == "-base-dir" && i+1 < len(spawner.lastArgs) && spawner.lastArgs[i+1] == "/var/lib/hostlink/updates" { + foundDir = true + } + } + if !foundVersion { + t.Errorf("expected -version 2.0.0 in spawn args, got %v", spawner.lastArgs) + } + if !foundDir { + t.Errorf("expected -base-dir /var/lib/hostlink/updates in spawn args, got %v", spawner.lastArgs) + } +} + +func TestUpdateFlow_PassesDownloadSizeToPreflight(t *testing.T) { + checker := &mockUpdateChecker{ + result: &updatecheck.UpdateInfo{ + UpdateAvailable: true, + TargetVersion: "2.0.0", + AgentURL: "https://example.com/agent.tar.gz", + AgentSHA256: "abc", + AgentSize: 30 * 1024 * 1024, // 30MB + UpdaterURL: "https://example.com/updater.tar.gz", + UpdaterSHA256: "def", + UpdaterSize: 5 * 1024 * 1024, // 5MB + }, + } + preflight := &mockPreflight{ + result: &updatepreflight.PreflightResult{Passed: true}, + } + lock := &mockLockManager{} + state := &mockStateWriter{} + downloader := &mockDownloader{} + spawner := &mockSpawner{} + + job := NewWithConfig(SelfUpdateJobConfig{ + UpdateChecker: checker, + Downloader: downloader, + PreflightChecker: preflight, + LockManager: lock, + StateWriter: state, + Spawn: spawner.spawn, + InstallUpdater: noopInstaller, + 
CurrentVersion: "1.0.0", + UpdaterPath: "/tmp/updater", + StagingDir: "/tmp/staging", + BaseDir: "/var/lib/hostlink/updates", + }) + + err := job.runUpdate(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + expected := int64(35 * 1024 * 1024) // 30MB + 5MB + if preflight.getLastRequiredSpace() != expected { + t.Errorf("expected preflight requiredSpace %d, got %d", expected, preflight.getLastRequiredSpace()) + } +} + +func TestUpdateFlow_FallsBackTo50MB_WhenSizesZero(t *testing.T) { + checker := &mockUpdateChecker{ + result: &updatecheck.UpdateInfo{ + UpdateAvailable: true, + TargetVersion: "2.0.0", + AgentURL: "https://example.com/agent.tar.gz", + AgentSHA256: "abc", + UpdaterURL: "https://example.com/updater.tar.gz", + UpdaterSHA256: "def", + // AgentSize and UpdaterSize are zero (not provided by control plane) + }, + } + preflight := &mockPreflight{ + result: &updatepreflight.PreflightResult{Passed: true}, + } + lock := &mockLockManager{} + state := &mockStateWriter{} + downloader := &mockDownloader{} + spawner := &mockSpawner{} + + job := NewWithConfig(SelfUpdateJobConfig{ + UpdateChecker: checker, + Downloader: downloader, + PreflightChecker: preflight, + LockManager: lock, + StateWriter: state, + Spawn: spawner.spawn, + InstallUpdater: noopInstaller, + CurrentVersion: "1.0.0", + UpdaterPath: "/tmp/updater", + StagingDir: "/tmp/staging", + BaseDir: "/var/lib/hostlink/updates", + }) + + err := job.runUpdate(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + expected := int64(50 * 1024 * 1024) // 50MB fallback + if preflight.getLastRequiredSpace() != expected { + t.Errorf("expected preflight requiredSpace %d (50MB fallback), got %d", expected, preflight.getLastRequiredSpace()) + } +} + +func TestUpdateFlow_AgentDestUsesCanonicalTarballName(t *testing.T) { + checker := &mockUpdateChecker{ + result: &updatecheck.UpdateInfo{ + UpdateAvailable: true, + TargetVersion: "2.0.0", + AgentURL: 
"https://example.com/agent.tar.gz", + AgentSHA256: "abc", + UpdaterURL: "https://example.com/updater.tar.gz", + UpdaterSHA256: "def", + }, + } + preflight := &mockPreflight{ + result: &updatepreflight.PreflightResult{Passed: true}, + } + lock := &mockLockManager{} + state := &mockStateWriter{} + downloader := &mockDownloader{} + spawner := &mockSpawner{} + + job := NewWithConfig(SelfUpdateJobConfig{ + UpdateChecker: checker, + Downloader: downloader, + PreflightChecker: preflight, + LockManager: lock, + StateWriter: state, + Spawn: spawner.spawn, + InstallUpdater: noopInstaller, + CurrentVersion: "1.0.0", + UpdaterPath: "/tmp/updater", + StagingDir: "/tmp/staging", + BaseDir: "/var/lib/hostlink/updates", + }) + + err := job.runUpdate(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Agent tarball must use the canonical name expected by the updater + expectedAgentDest := "/tmp/staging/" + updatedownload.AgentTarballName + if len(downloader.destPaths) < 1 || downloader.destPaths[0] != expectedAgentDest { + t.Errorf("expected agent dest path %q, got %v", expectedAgentDest, downloader.destPaths) + } +} + +func TestUpdateFlow_ExtractsUpdaterBeforeSpawn(t *testing.T) { + checker := &mockUpdateChecker{ + result: &updatecheck.UpdateInfo{ + UpdateAvailable: true, + TargetVersion: "2.0.0", + AgentURL: "https://example.com/agent.tar.gz", + AgentSHA256: "abc", + UpdaterURL: "https://example.com/updater.tar.gz", + UpdaterSHA256: "def", + }, + } + preflight := &mockPreflight{ + result: &updatepreflight.PreflightResult{Passed: true}, + } + lock := &mockLockManager{} + state := &mockStateWriter{} + downloader := &mockDownloader{} + spawner := &mockSpawner{} + installer := &mockUpdaterInstaller{} + + job := NewWithConfig(SelfUpdateJobConfig{ + UpdateChecker: checker, + Downloader: downloader, + PreflightChecker: preflight, + LockManager: lock, + StateWriter: state, + Spawn: spawner.spawn, + InstallUpdater: installer.install, + 
CurrentVersion: "1.0.0", + UpdaterPath: "/opt/updater/hostlink-updater", + StagingDir: "/tmp/staging", + BaseDir: "/var/lib/hostlink/updates", + }) + + err := job.runUpdate(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Verify InstallUpdater was called with correct args + if installer.callCount.Load() != 1 { + t.Fatalf("expected 1 install call, got %d", installer.callCount.Load()) + } + expectedTarPath := "/tmp/staging/" + updatedownload.UpdaterTarballName + if installer.lastTarPath != expectedTarPath { + t.Errorf("expected tarPath %q, got %q", expectedTarPath, installer.lastTarPath) + } + if installer.lastDestPath != "/opt/updater/hostlink-updater" { + t.Errorf("expected destPath %q, got %q", "/opt/updater/hostlink-updater", installer.lastDestPath) + } + // Verify spawn was still called (extraction didn't block it) + if spawner.callCount.Load() != 1 { + t.Errorf("expected 1 spawn call, got %d", spawner.callCount.Load()) + } +} + +func TestUpdateFlow_ExtractFailure_PreventsSpawn(t *testing.T) { + checker := &mockUpdateChecker{ + result: &updatecheck.UpdateInfo{ + UpdateAvailable: true, + TargetVersion: "2.0.0", + AgentURL: "https://example.com/agent.tar.gz", + AgentSHA256: "abc", + UpdaterURL: "https://example.com/updater.tar.gz", + UpdaterSHA256: "def", + }, + } + preflight := &mockPreflight{ + result: &updatepreflight.PreflightResult{Passed: true}, + } + lock := &mockLockManager{} + state := &mockStateWriter{} + downloader := &mockDownloader{} + spawner := &mockSpawner{} + installer := &mockUpdaterInstaller{err: errors.New("extraction failed")} + + job := NewWithConfig(SelfUpdateJobConfig{ + UpdateChecker: checker, + Downloader: downloader, + PreflightChecker: preflight, + LockManager: lock, + StateWriter: state, + Spawn: spawner.spawn, + InstallUpdater: installer.install, + CurrentVersion: "1.0.0", + UpdaterPath: "/opt/updater/hostlink-updater", + StagingDir: "/tmp/staging", + BaseDir: "/var/lib/hostlink/updates", + }) 
+ + job.runUpdate(context.Background()) + + // Extraction failed, so spawn should NOT be called + if spawner.callCount.Load() != 0 { + t.Error("spawner should not be called when extraction fails") + } + // Lock should still be released + if lock.unlockCount.Load() != 1 { + t.Errorf("expected lock to be released on failure, unlock count: %d", lock.unlockCount.Load()) + } +} + +func TestUpdateFlow_ContextCancelledBetweenDownloads(t *testing.T) { + var cancelCtx context.CancelFunc + + checker := &mockUpdateChecker{ + result: &updatecheck.UpdateInfo{ + UpdateAvailable: true, + TargetVersion: "2.0.0", + AgentURL: "https://example.com/agent.tar.gz", + AgentSHA256: "abc", + UpdaterURL: "https://example.com/updater.tar.gz", + UpdaterSHA256: "def", + }, + } + preflight := &mockPreflight{ + result: &updatepreflight.PreflightResult{Passed: true}, + } + lock := &mockLockManager{} + state := &mockStateWriter{} + downloader := &mockDownloader{ + onCall: func(count int32) { + if count == 1 { + // Cancel context after first download (agent) completes + cancelCtx() + } + }, + } + spawner := &mockSpawner{} + + job := NewWithConfig(SelfUpdateJobConfig{ + UpdateChecker: checker, + Downloader: downloader, + PreflightChecker: preflight, + LockManager: lock, + StateWriter: state, + Spawn: spawner.spawn, + InstallUpdater: noopInstaller, + CurrentVersion: "1.0.0", + UpdaterPath: "/tmp/updater", + StagingDir: "/tmp/staging", + BaseDir: "/var/lib/hostlink/updates", + }) + + ctx, cancel := context.WithCancel(context.Background()) + cancelCtx = cancel + + err := job.runUpdate(ctx) + + // Should return context.Canceled error + if err == nil { + t.Fatal("expected error from cancelled context, got nil") + } + if !errors.Is(err, context.Canceled) { + t.Errorf("expected context.Canceled error, got: %v", err) + } + // Second download (updater) should NOT have been called + if downloader.callCount.Load() != 1 { + t.Errorf("expected only 1 download call (agent), got %d", downloader.callCount.Load()) + 
} + // Spawn should NOT have been called + if spawner.callCount.Load() != 0 { + t.Error("spawner should not be called when context is cancelled") + } +} + +// --- Helpers --- + +// noopInstaller is a no-op InstallUpdaterFunc for tests that don't care about extraction. +func noopInstaller(tarPath, destPath string) error { return nil } + +// immediateTrigger calls fn exactly n times synchronously then returns. +func immediateTrigger(n int) TriggerFunc { + return func(ctx context.Context, fn func() error) { + for i := 0; i < n; i++ { + fn() + } + } +} + +// --- Mocks --- + +type mockUpdateChecker struct { + result *updatecheck.UpdateInfo + err error + callCount atomic.Int32 +} + +func (m *mockUpdateChecker) Check(currentVersion string) (*updatecheck.UpdateInfo, error) { + m.callCount.Add(1) + return m.result, m.err +} + +type mockPreflight struct { + result *updatepreflight.PreflightResult + callCount atomic.Int32 + lastRequiredSpace int64 + mu sync.Mutex +} + +func (m *mockPreflight) Check(requiredSpace int64) *updatepreflight.PreflightResult { + m.callCount.Add(1) + m.mu.Lock() + m.lastRequiredSpace = requiredSpace + m.mu.Unlock() + return m.result +} + +func (m *mockPreflight) getLastRequiredSpace() int64 { + m.mu.Lock() + defer m.mu.Unlock() + return m.lastRequiredSpace +} + +type mockLockManager struct { + lockErr error + unlockErr error + lockCount atomic.Int32 + unlockCount atomic.Int32 + onUnlock func() +} + +func (m *mockLockManager) TryLockWithRetry(expiration time.Duration, retries int, interval time.Duration) error { + m.lockCount.Add(1) + return m.lockErr +} + +func (m *mockLockManager) Unlock() error { + m.unlockCount.Add(1) + if m.onUnlock != nil { + m.onUnlock() + } + return m.unlockErr +} + +type mockStateWriter struct { + mu sync.Mutex + states []update.State + err error +} + +func (m *mockStateWriter) Write(data update.StateData) error { + m.mu.Lock() + defer m.mu.Unlock() + m.states = append(m.states, data.State) + return m.err +} + +func (m 
*mockStateWriter) getStates() []update.State { + m.mu.Lock() + defer m.mu.Unlock() + result := make([]update.State, len(m.states)) + copy(result, m.states) + return result +} + +type mockDownloader struct { + err error + failOnCall int32 // fail on this call number (1-indexed), 0 means all fail if err is set + callCount atomic.Int32 + destPaths []string + onCall func(count int32) // called after each download with 1-indexed call number + mu sync.Mutex +} + +func (m *mockDownloader) DownloadAndVerify(ctx context.Context, url, destPath, sha256 string) (*updatedownload.DownloadResult, error) { + count := m.callCount.Add(1) + m.mu.Lock() + m.destPaths = append(m.destPaths, destPath) + m.mu.Unlock() + if m.onCall != nil { + m.onCall(count) + } + if m.err != nil { + if m.failOnCall == 0 || count == m.failOnCall { + return nil, m.err + } + } + return nil, nil +} + +type mockSpawner struct { + err error + callCount atomic.Int32 + lastPath string + lastArgs []string + onSpawn func() + mu sync.Mutex +} + +func (m *mockSpawner) spawn(updaterPath string, args []string) error { + m.callCount.Add(1) + m.mu.Lock() + m.lastPath = updaterPath + m.lastArgs = args + m.mu.Unlock() + if m.onSpawn != nil { + m.onSpawn() + } + return m.err +} + +type mockUpdaterInstaller struct { + err error + callCount atomic.Int32 + lastTarPath string + lastDestPath string + mu sync.Mutex +} + +func (m *mockUpdaterInstaller) install(tarPath, destPath string) error { + m.callCount.Add(1) + m.mu.Lock() + m.lastTarPath = tarPath + m.lastDestPath = destPath + m.mu.Unlock() + return m.err +} diff --git a/app/jobs/selfupdatejob/trigger.go b/app/jobs/selfupdatejob/trigger.go new file mode 100644 index 0000000..f91a237 --- /dev/null +++ b/app/jobs/selfupdatejob/trigger.go @@ -0,0 +1,39 @@ +package selfupdatejob + +import ( + "context" + "time" + + log "github.com/sirupsen/logrus" +) + +// TriggerConfig holds configuration for the trigger. 
+type TriggerConfig struct { + Interval time.Duration +} + +// DefaultTriggerConfig returns the default trigger configuration (1 hour interval). +func DefaultTriggerConfig() TriggerConfig { + return TriggerConfig{ + Interval: 1 * time.Hour, + } +} + +// TriggerWithConfig runs fn on the configured interval until ctx is cancelled. +func TriggerWithConfig(ctx context.Context, fn func() error, config TriggerConfig) { + for { + select { + case <-ctx.Done(): + return + case <-time.After(config.Interval): + if err := fn(); err != nil { + log.Errorf("self-update check failed: %s", err) + } + } + } +} + +// Trigger runs fn with the default configuration. +func Trigger(ctx context.Context, fn func() error) { + TriggerWithConfig(ctx, fn, DefaultTriggerConfig()) +} diff --git a/app/services/updatecheck/updatecheck.go b/app/services/updatecheck/updatecheck.go new file mode 100644 index 0000000..a509777 --- /dev/null +++ b/app/services/updatecheck/updatecheck.go @@ -0,0 +1,81 @@ +package updatecheck + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" +) + +// UpdateInfo represents the response from the update check endpoint. +type UpdateInfo struct { + UpdateAvailable bool `json:"update_available"` + TargetVersion string `json:"target_version"` + AgentURL string `json:"agent_url"` + AgentSHA256 string `json:"agent_sha256"` + AgentSize int64 `json:"agent_size"` + UpdaterURL string `json:"updater_url"` + UpdaterSHA256 string `json:"updater_sha256"` + UpdaterSize int64 `json:"updater_size"` +} + +// RequestSignerInterface abstracts request signing for testability. +type RequestSignerInterface interface { + SignRequest(req *http.Request) error +} + +// UpdateChecker checks for available updates from the control plane. +type UpdateChecker struct { + client *http.Client + controlPlaneURL string + agentID string + signer RequestSignerInterface +} + +// New creates a new UpdateChecker. Returns an error if agentID is empty. 
+func New(client *http.Client, controlPlaneURL, agentID string, signer RequestSignerInterface) (*UpdateChecker, error) { + if agentID == "" { + return nil, errors.New("agentID must not be empty") + } + return &UpdateChecker{ + client: client, + controlPlaneURL: controlPlaneURL, + agentID: agentID, + signer: signer, + }, nil +} + +// Check queries the control plane for available updates. +func (c *UpdateChecker) Check(currentVersion string) (*UpdateInfo, error) { + url := fmt.Sprintf("%s/api/v1/agents/%s/update", c.controlPlaneURL, c.agentID) + + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("X-Agent-Version", currentVersion) + + if c.signer != nil { + if err := c.signer.SignRequest(req); err != nil { + return nil, fmt.Errorf("failed to sign request: %w", err) + } + } + + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("update check request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("update check returned status %d", resp.StatusCode) + } + + var info UpdateInfo + if err := json.NewDecoder(resp.Body).Decode(&info); err != nil { + return nil, fmt.Errorf("failed to decode update check response: %w", err) + } + + return &info, nil +} diff --git a/app/services/updatecheck/updatecheck_test.go b/app/services/updatecheck/updatecheck_test.go new file mode 100644 index 0000000..35eab14 --- /dev/null +++ b/app/services/updatecheck/updatecheck_test.go @@ -0,0 +1,248 @@ +package updatecheck + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "hostlink/app/services/requestsigner" +) + +func TestCheck_UpdateAvailable(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := UpdateInfo{ + UpdateAvailable: true, + TargetVersion: "2.0.0", + AgentURL: "https://example.com/agent.tar.gz", + 
AgentSHA256: "abc123", + AgentSize: 52428800, + UpdaterURL: "https://example.com/updater.tar.gz", + UpdaterSHA256: "def456", + UpdaterSize: 10485760, + } + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + checker := newTestChecker(t, server.Client(), server.URL, nil) + info, err := checker.Check("1.0.0") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !info.UpdateAvailable { + t.Error("expected UpdateAvailable to be true") + } + if info.TargetVersion != "2.0.0" { + t.Errorf("expected TargetVersion 2.0.0, got %s", info.TargetVersion) + } + if info.AgentURL != "https://example.com/agent.tar.gz" { + t.Errorf("expected AgentURL https://example.com/agent.tar.gz, got %s", info.AgentURL) + } + if info.AgentSHA256 != "abc123" { + t.Errorf("expected AgentSHA256 abc123, got %s", info.AgentSHA256) + } + if info.UpdaterURL != "https://example.com/updater.tar.gz" { + t.Errorf("expected UpdaterURL https://example.com/updater.tar.gz, got %s", info.UpdaterURL) + } + if info.AgentSize != 52428800 { + t.Errorf("expected AgentSize 52428800, got %d", info.AgentSize) + } + if info.UpdaterSHA256 != "def456" { + t.Errorf("expected UpdaterSHA256 def456, got %s", info.UpdaterSHA256) + } + if info.UpdaterSize != 10485760 { + t.Errorf("expected UpdaterSize 10485760, got %d", info.UpdaterSize) + } +} + +func TestCheck_NoUpdateAvailable(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := UpdateInfo{ + UpdateAvailable: false, + } + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + checker := newTestChecker(t, server.Client(), server.URL, nil) + info, err := checker.Check("2.0.0") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if info.UpdateAvailable { + t.Error("expected UpdateAvailable to be false") + } +} + +func TestCheck_NetworkError(t *testing.T) { + checker := newTestChecker(t, http.DefaultClient, "http://localhost:1", nil) + _, err := 
checker.Check("1.0.0") + if err == nil { + t.Fatal("expected error for bad URL, got nil") + } +} + +func TestCheck_SignsRequest(t *testing.T) { + var receivedHeaders http.Header + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedHeaders = r.Header + resp := UpdateInfo{UpdateAvailable: false} + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + signer := &mockSigner{agentID: "agent-123"} + checker := newTestChecker(t, server.Client(), server.URL, signer) + _, err := checker.Check("1.0.0") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if receivedHeaders.Get("X-Agent-ID") != "agent-123" { + t.Errorf("expected X-Agent-ID header agent-123, got %s", receivedHeaders.Get("X-Agent-ID")) + } + if receivedHeaders.Get("X-Timestamp") == "" { + t.Error("expected X-Timestamp header to be set") + } + if receivedHeaders.Get("X-Nonce") == "" { + t.Error("expected X-Nonce header to be set") + } + if receivedHeaders.Get("X-Signature") == "" { + t.Error("expected X-Signature header to be set") + } +} + +func TestCheck_SendsCurrentVersionAsHeader(t *testing.T) { + var receivedVersion string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedVersion = r.Header.Get("X-Agent-Version") + resp := UpdateInfo{UpdateAvailable: false} + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + checker := newTestChecker(t, server.Client(), server.URL, nil) + _, err := checker.Check("1.5.3") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if receivedVersion != "1.5.3" { + t.Errorf("expected X-Agent-Version header 1.5.3, got %s", receivedVersion) + } +} + +func TestCheck_NoQueryParams(t *testing.T) { + var receivedRawQuery string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedRawQuery = r.URL.RawQuery + resp := UpdateInfo{UpdateAvailable: false} + 
json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + checker := newTestChecker(t, server.Client(), server.URL, nil) + checker.Check("1.5.3") + + if receivedRawQuery != "" { + t.Errorf("expected no query params, got %s", receivedRawQuery) + } +} + +func TestCheck_HTTPErrorStatus(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + checker := newTestChecker(t, server.Client(), server.URL, nil) + _, err := checker.Check("1.0.0") + if err == nil { + t.Fatal("expected error for 500 status, got nil") + } +} + +func TestCheck_InvalidJSON(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("not json")) + })) + defer server.Close() + + checker := newTestChecker(t, server.Client(), server.URL, nil) + _, err := checker.Check("1.0.0") + if err == nil { + t.Fatal("expected error for invalid JSON, got nil") + } +} + +func TestCheck_UsesGETMethod(t *testing.T) { + var receivedMethod string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedMethod = r.Method + resp := UpdateInfo{UpdateAvailable: false} + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + checker := newTestChecker(t, server.Client(), server.URL, nil) + checker.Check("1.0.0") + + if receivedMethod != http.MethodGet { + t.Errorf("expected GET method, got %s", receivedMethod) + } +} + +func TestCheck_UsesCorrectPath(t *testing.T) { + var receivedPath string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedPath = r.URL.Path + resp := UpdateInfo{UpdateAvailable: false} + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + checker, err := New(server.Client(), server.URL, "agent-123", nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + 
checker.Check("1.0.0") + + if receivedPath != "/api/v1/agents/agent-123/update" { + t.Errorf("expected path /api/v1/agents/agent-123/update, got %s", receivedPath) + } +} + +func TestNew_ReturnsErrorForEmptyAgentID(t *testing.T) { + _, err := New(http.DefaultClient, "http://example.com", "", nil) + if err == nil { + t.Fatal("expected error for empty agentID, got nil") + } +} + +// newTestChecker is a test helper that creates an UpdateChecker with a default agent ID. +func newTestChecker(t *testing.T, client *http.Client, url string, signer RequestSignerInterface) *UpdateChecker { + t.Helper() + checker, err := New(client, url, "agent-123", signer) + if err != nil { + t.Fatalf("failed to create UpdateChecker: %v", err) + } + return checker +} + +// mockSigner implements the RequestSigner interface for testing +type mockSigner struct { + agentID string +} + +func (m *mockSigner) SignRequest(req *http.Request) error { + req.Header.Set("X-Agent-ID", m.agentID) + req.Header.Set("X-Timestamp", "1234567890") + req.Header.Set("X-Nonce", "testnonce") + req.Header.Set("X-Signature", "testsignature") + return nil +} + +// Ensure mockSigner satisfies the interface +var _ RequestSignerInterface = (*mockSigner)(nil) +var _ RequestSignerInterface = (*requestsigner.RequestSigner)(nil) diff --git a/app/services/updatedownload/download.go b/app/services/updatedownload/download.go new file mode 100644 index 0000000..8169222 --- /dev/null +++ b/app/services/updatedownload/download.go @@ -0,0 +1,281 @@ +// Package updatedownload provides functionality for downloading and verifying update artifacts. +package updatedownload + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "time" +) + +var ( + // ErrDownloadFailed is returned when download fails after all retries. + ErrDownloadFailed = errors.New("download failed after retries") + // ErrChecksumMismatch is returned when SHA256 verification fails. 
+ ErrChecksumMismatch = errors.New("SHA256 checksum mismatch") +) + +// DownloadConfig configures the download behavior. +type DownloadConfig struct { + MaxRetries int + InitialBackoff time.Duration + MaxBackoff time.Duration + BackoffMultiplier int + Timeout time.Duration +} + +// DefaultDownloadConfig returns the default download configuration. +// Retries: 5, Backoff: 5s -> 10s -> 20s -> 40s -> 60s (capped) +func DefaultDownloadConfig() *DownloadConfig { + return &DownloadConfig{ + MaxRetries: 5, + InitialBackoff: 5 * time.Second, + MaxBackoff: 60 * time.Second, + BackoffMultiplier: 2, + Timeout: 5 * time.Minute, + } +} + +// Downloader downloads files with retry and exponential backoff. +type Downloader struct { + client *http.Client + config *DownloadConfig + sleepFunc func(time.Duration) +} + +// DownloadResult contains the result of a successful download. +type DownloadResult struct { + FilePath string + SHA256 string +} + +// NewDownloader creates a new Downloader with the given configuration. +func NewDownloader(config *DownloadConfig) *Downloader { + if config == nil { + config = DefaultDownloadConfig() + } + return &Downloader{ + client: &http.Client{Timeout: config.Timeout}, + config: config, + sleepFunc: time.Sleep, + } +} + +// NewDownloaderWithSleep creates a Downloader with a custom sleep function for testing. +func NewDownloaderWithSleep(config *DownloadConfig, sleepFunc func(time.Duration)) *Downloader { + d := NewDownloader(config) + d.sleepFunc = sleepFunc + return d +} + +// Download downloads a file from url to destPath with retry logic. +// Retries on network errors and 5xx responses. +// Does NOT retry on 4xx errors (returns immediately). 
+func (d *Downloader) Download(ctx context.Context, url, destPath string) error { + var lastErr error + backoff := d.config.InitialBackoff + + for attempt := 0; attempt <= d.config.MaxRetries; attempt++ { + // Check context before each attempt + if ctx.Err() != nil { + return ctx.Err() + } + + err := d.downloadOnce(ctx, url, destPath) + if err == nil { + return nil + } + + lastErr = err + + // Check if error is retryable + if !d.isRetryable(err) { + return err + } + + // Don't sleep after the last attempt + if attempt < d.config.MaxRetries { + d.sleepFunc(backoff) + backoff = d.nextBackoff(backoff) + + // Check context after sleep + if ctx.Err() != nil { + return ctx.Err() + } + } + } + + return fmt.Errorf("%w: %v", ErrDownloadFailed, lastErr) +} + +// downloadOnce performs a single download attempt. +func (d *Downloader) downloadOnce(ctx context.Context, url, destPath string) error { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + resp, err := d.client.Do(req) + if err != nil { + return &networkError{err: err} + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return &httpError{statusCode: resp.StatusCode} + } + + // Ensure parent directory exists + if err := os.MkdirAll(filepath.Dir(destPath), 0755); err != nil { + return fmt.Errorf("failed to create directory: %w", err) + } + + // Write to temp file first (atomic write) + tmpFile, err := os.CreateTemp(filepath.Dir(destPath), ".download-*") + if err != nil { + return fmt.Errorf("failed to create temp file: %w", err) + } + tmpPath := tmpFile.Name() + + // Clean up temp file on error + defer func() { + if tmpPath != "" { + os.Remove(tmpPath) + } + }() + + // Copy response body to file + // Use a network-error-tagging reader so we can distinguish network vs disk errors + _, err = io.Copy(tmpFile, &networkReader{r: resp.Body}) + if err != nil { + tmpFile.Close() + // Check if it's a 
network error (tagged by networkReader) or a disk write error + var netErr *networkError + if errors.As(err, &netErr) { + return err // Already tagged as network error, retryable + } + // Disk write error - not retryable + return fmt.Errorf("failed to write to disk: %w", err) + } + + if err := tmpFile.Close(); err != nil { + return fmt.Errorf("failed to close temp file: %w", err) + } + + // Atomic rename + if err := os.Rename(tmpPath, destPath); err != nil { + return fmt.Errorf("failed to rename temp file: %w", err) + } + + tmpPath = "" // Prevent cleanup since rename succeeded + return nil +} + +// isRetryable returns true if the error should trigger a retry. +func (d *Downloader) isRetryable(err error) bool { + // Network errors are retryable + var netErr *networkError + if errors.As(err, &netErr) { + return true + } + + // 5xx errors are retryable, 4xx are not + var httpErr *httpError + if errors.As(err, &httpErr) { + return httpErr.statusCode >= 500 + } + + return false +} + +// nextBackoff calculates the next backoff duration with exponential increase and cap. +func (d *Downloader) nextBackoff(current time.Duration) time.Duration { + next := current * time.Duration(d.config.BackoffMultiplier) + if next > d.config.MaxBackoff { + return d.config.MaxBackoff + } + return next +} + +// networkError wraps network-related errors for retry detection. +type networkError struct { + err error +} + +func (e *networkError) Error() string { + return e.err.Error() +} + +func (e *networkError) Unwrap() error { + return e.err +} + +// networkReader wraps an io.Reader and tags any read errors as network errors. +// This allows distinguishing network errors from disk write errors during io.Copy. +type networkReader struct { + r io.Reader +} + +func (nr *networkReader) Read(p []byte) (n int, err error) { + n, err = nr.r.Read(p) + if err != nil && err != io.EOF { + return n, &networkError{err: err} + } + return n, err +} + +// httpError represents an HTTP error response. 
+type httpError struct { + statusCode int +} + +func (e *httpError) Error() string { + return fmt.Sprintf("HTTP %d", e.statusCode) +} + +// VerifySHA256 verifies that the file at filePath has the expected SHA256 checksum. +// Returns ErrChecksumMismatch if the checksum doesn't match. +func VerifySHA256(filePath, expectedSHA256 string) error { + f, err := os.Open(filePath) + if err != nil { + return fmt.Errorf("failed to open file: %w", err) + } + defer f.Close() + + h := sha256.New() + if _, err := io.Copy(h, f); err != nil { + return fmt.Errorf("failed to hash file: %w", err) + } + + actualSHA256 := hex.EncodeToString(h.Sum(nil)) + if actualSHA256 != expectedSHA256 { + return fmt.Errorf("%w: expected %s, got %s", ErrChecksumMismatch, expectedSHA256, actualSHA256) + } + + return nil +} + +// DownloadAndVerify downloads a file and verifies its SHA256 checksum. +// If verification fails, the downloaded file is deleted. +func (d *Downloader) DownloadAndVerify(ctx context.Context, url, destPath, expectedSHA256 string) (*DownloadResult, error) { + if err := d.Download(ctx, url, destPath); err != nil { + return nil, err + } + + if err := VerifySHA256(destPath, expectedSHA256); err != nil { + // Delete file on checksum failure + os.Remove(destPath) + return nil, err + } + + return &DownloadResult{ + FilePath: destPath, + SHA256: expectedSHA256, + }, nil +} diff --git a/app/services/updatedownload/download_test.go b/app/services/updatedownload/download_test.go new file mode 100644 index 0000000..be50852 --- /dev/null +++ b/app/services/updatedownload/download_test.go @@ -0,0 +1,410 @@ +package updatedownload + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Helper to create a test server that returns given content +func createTestServer(t *testing.T, content []byte) 
*httptest.Server { + t.Helper() + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write(content) + })) +} + +// Helper to create a server that returns a specific status code +func createStatusServer(t *testing.T, statusCode int) *httptest.Server { + t.Helper() + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(statusCode) + })) +} + +// Helper to compute SHA256 of content +func computeSHA256(content []byte) string { + hash := sha256.Sum256(content) + return hex.EncodeToString(hash[:]) +} + +// noopSleep is a no-op sleep function for fast tests +func noopSleep(d time.Duration) {} + +// ============================================================================ +// Download Tests +// ============================================================================ + +func TestDownload_SuccessOnFirstAttempt(t *testing.T) { + content := []byte("test file content") + server := createTestServer(t, content) + defer server.Close() + + tmpDir := t.TempDir() + destPath := filepath.Join(tmpDir, "downloaded.bin") + + d := NewDownloaderWithSleep(DefaultDownloadConfig(), noopSleep) + err := d.Download(context.Background(), server.URL, destPath) + + require.NoError(t, err) + + // Verify file was created with correct content + data, err := os.ReadFile(destPath) + require.NoError(t, err) + assert.Equal(t, content, data) +} + +func TestDownload_RetriesOn5xxErrors(t *testing.T) { + var attemptCount atomic.Int32 + content := []byte("success after retries") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + count := attemptCount.Add(1) + if count < 3 { + w.WriteHeader(http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) + w.Write(content) + })) + defer server.Close() + + tmpDir := t.TempDir() + destPath := filepath.Join(tmpDir, "downloaded.bin") + + d := 
NewDownloaderWithSleep(DefaultDownloadConfig(), noopSleep) + err := d.Download(context.Background(), server.URL, destPath) + + require.NoError(t, err) + assert.Equal(t, int32(3), attemptCount.Load(), "should have made 3 attempts") + + // Verify file content + data, err := os.ReadFile(destPath) + require.NoError(t, err) + assert.Equal(t, content, data) +} + +func TestDownload_RetriesOnNetworkTimeout(t *testing.T) { + var attemptCount atomic.Int32 + content := []byte("success after timeout") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + count := attemptCount.Add(1) + if count < 3 { + // Simulate timeout by not responding (connection will timeout) + time.Sleep(200 * time.Millisecond) + return + } + w.WriteHeader(http.StatusOK) + w.Write(content) + })) + defer server.Close() + + tmpDir := t.TempDir() + destPath := filepath.Join(tmpDir, "downloaded.bin") + + config := &DownloadConfig{ + MaxRetries: 5, + InitialBackoff: 1 * time.Millisecond, + MaxBackoff: 10 * time.Millisecond, + BackoffMultiplier: 2, + Timeout: 50 * time.Millisecond, // Short timeout to trigger retries + } + d := NewDownloaderWithSleep(config, noopSleep) + err := d.Download(context.Background(), server.URL, destPath) + + require.NoError(t, err) + assert.GreaterOrEqual(t, attemptCount.Load(), int32(3), "should have retried on timeout") +} + +func TestDownload_NoRetryOn4xxErrors(t *testing.T) { + var attemptCount atomic.Int32 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attemptCount.Add(1) + w.WriteHeader(http.StatusNotFound) + })) + defer server.Close() + + tmpDir := t.TempDir() + destPath := filepath.Join(tmpDir, "downloaded.bin") + + d := NewDownloaderWithSleep(DefaultDownloadConfig(), noopSleep) + err := d.Download(context.Background(), server.URL, destPath) + + require.Error(t, err) + assert.Equal(t, int32(1), attemptCount.Load(), "should NOT retry on 4xx errors") +} + +func 
TestDownload_ExponentialBackoff(t *testing.T) { + var sleepDurations []time.Duration + + server := createStatusServer(t, http.StatusInternalServerError) + defer server.Close() + + tmpDir := t.TempDir() + destPath := filepath.Join(tmpDir, "downloaded.bin") + + config := &DownloadConfig{ + MaxRetries: 5, + InitialBackoff: 5 * time.Second, + MaxBackoff: 60 * time.Second, + BackoffMultiplier: 2, + Timeout: 1 * time.Second, + } + + trackingSleep := func(d time.Duration) { + sleepDurations = append(sleepDurations, d) + } + + d := NewDownloaderWithSleep(config, trackingSleep) + _ = d.Download(context.Background(), server.URL, destPath) + + // Expected backoff: 5s, 10s, 20s, 40s, 60s (capped) + expected := []time.Duration{ + 5 * time.Second, + 10 * time.Second, + 20 * time.Second, + 40 * time.Second, + 60 * time.Second, + } + + require.Equal(t, len(expected), len(sleepDurations), "should sleep between each retry") + for i, exp := range expected { + assert.Equal(t, exp, sleepDurations[i], "backoff at retry %d", i+1) + } +} + +func TestDownload_MaxRetriesExceeded(t *testing.T) { + var attemptCount atomic.Int32 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attemptCount.Add(1) + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + tmpDir := t.TempDir() + destPath := filepath.Join(tmpDir, "downloaded.bin") + + config := &DownloadConfig{ + MaxRetries: 3, + InitialBackoff: 1 * time.Millisecond, + MaxBackoff: 10 * time.Millisecond, + BackoffMultiplier: 2, + Timeout: 1 * time.Second, + } + + d := NewDownloaderWithSleep(config, noopSleep) + err := d.Download(context.Background(), server.URL, destPath) + + require.Error(t, err) + assert.ErrorIs(t, err, ErrDownloadFailed) + assert.Equal(t, int32(4), attemptCount.Load(), "should make 1 initial + 3 retries = 4 attempts") +} + +func TestDownload_ContextCancellation(t *testing.T) { + var attemptCount atomic.Int32 + + server := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attemptCount.Add(1) + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + tmpDir := t.TempDir() + destPath := filepath.Join(tmpDir, "downloaded.bin") + + ctx, cancel := context.WithCancel(context.Background()) + + // Cancel after first sleep + sleepWithCancel := func(d time.Duration) { + cancel() + } + + d := NewDownloaderWithSleep(DefaultDownloadConfig(), sleepWithCancel) + err := d.Download(ctx, server.URL, destPath) + + require.Error(t, err) + assert.ErrorIs(t, err, context.Canceled) + assert.LessOrEqual(t, attemptCount.Load(), int32(2), "should stop retrying after context cancelled") +} + +func TestDownload_AtomicWrite(t *testing.T) { + content := []byte("atomic write test content") + server := createTestServer(t, content) + defer server.Close() + + tmpDir := t.TempDir() + destPath := filepath.Join(tmpDir, "downloaded.bin") + + d := NewDownloaderWithSleep(DefaultDownloadConfig(), noopSleep) + err := d.Download(context.Background(), server.URL, destPath) + + require.NoError(t, err) + + // Verify no temp files left behind + entries, err := os.ReadDir(tmpDir) + require.NoError(t, err) + + for _, entry := range entries { + assert.Equal(t, "downloaded.bin", entry.Name(), "only destination file should exist") + } +} + +func TestDownload_NoRetryOnDiskWriteError(t *testing.T) { + var attemptCount atomic.Int32 + content := []byte("test content") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attemptCount.Add(1) + w.WriteHeader(http.StatusOK) + w.Write(content) + })) + defer server.Close() + + // Try to download to a read-only directory (will fail on write) + tmpDir := t.TempDir() + readOnlyDir := filepath.Join(tmpDir, "readonly") + err := os.MkdirAll(readOnlyDir, 0555) // read + execute only, no write + require.NoError(t, err) + + destPath := filepath.Join(readOnlyDir, "downloaded.bin") + + d := 
NewDownloaderWithSleep(DefaultDownloadConfig(), noopSleep) + err = d.Download(context.Background(), server.URL, destPath) + + require.Error(t, err) + assert.NotErrorIs(t, err, ErrDownloadFailed, "disk errors should not result in ErrDownloadFailed") + assert.Equal(t, int32(1), attemptCount.Load(), "should NOT retry on disk write errors") +} + +// ============================================================================ +// VerifySHA256 Tests +// ============================================================================ + +func TestVerifySHA256_Match(t *testing.T) { + content := []byte("test content for hashing") + expectedSHA := computeSHA256(content) + + tmpDir := t.TempDir() + filePath := filepath.Join(tmpDir, "test.bin") + err := os.WriteFile(filePath, content, 0644) + require.NoError(t, err) + + err = VerifySHA256(filePath, expectedSHA) + assert.NoError(t, err) +} + +func TestVerifySHA256_Mismatch(t *testing.T) { + content := []byte("test content for hashing") + wrongSHA := "0000000000000000000000000000000000000000000000000000000000000000" + + tmpDir := t.TempDir() + filePath := filepath.Join(tmpDir, "test.bin") + err := os.WriteFile(filePath, content, 0644) + require.NoError(t, err) + + err = VerifySHA256(filePath, wrongSHA) + assert.ErrorIs(t, err, ErrChecksumMismatch) +} + +func TestVerifySHA256_StreamingLargeFile(t *testing.T) { + // Create a "large" file (1MB) to ensure streaming works + size := 1024 * 1024 + content := make([]byte, size) + for i := range content { + content[i] = byte(i % 256) + } + expectedSHA := computeSHA256(content) + + tmpDir := t.TempDir() + filePath := filepath.Join(tmpDir, "large.bin") + err := os.WriteFile(filePath, content, 0644) + require.NoError(t, err) + + err = VerifySHA256(filePath, expectedSHA) + assert.NoError(t, err) +} + +// ============================================================================ +// DownloadAndVerify Tests +// ============================================================================ + +func 
TestDownloadAndVerify_Success(t *testing.T) { + content := []byte("verified content") + expectedSHA := computeSHA256(content) + + server := createTestServer(t, content) + defer server.Close() + + tmpDir := t.TempDir() + destPath := filepath.Join(tmpDir, "verified.bin") + + d := NewDownloaderWithSleep(DefaultDownloadConfig(), noopSleep) + result, err := d.DownloadAndVerify(context.Background(), server.URL, destPath, expectedSHA) + + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, destPath, result.FilePath) + assert.Equal(t, expectedSHA, result.SHA256) + + // Verify file exists with correct content + data, err := os.ReadFile(destPath) + require.NoError(t, err) + assert.Equal(t, content, data) +} + +func TestDownloadAndVerify_DeletesOnMismatch(t *testing.T) { + content := []byte("content with wrong checksum") + wrongSHA := "0000000000000000000000000000000000000000000000000000000000000000" + + server := createTestServer(t, content) + defer server.Close() + + tmpDir := t.TempDir() + destPath := filepath.Join(tmpDir, "bad.bin") + + d := NewDownloaderWithSleep(DefaultDownloadConfig(), noopSleep) + result, err := d.DownloadAndVerify(context.Background(), server.URL, destPath, wrongSHA) + + require.Error(t, err) + assert.ErrorIs(t, err, ErrChecksumMismatch) + assert.Nil(t, result) + + // File should be deleted on checksum failure + _, err = os.Stat(destPath) + assert.True(t, os.IsNotExist(err), "file should be deleted on checksum mismatch") +} + +func TestDownloadAndVerify_ReturnsResult(t *testing.T) { + content := []byte("result test content") + expectedSHA := computeSHA256(content) + + server := createTestServer(t, content) + defer server.Close() + + tmpDir := t.TempDir() + destPath := filepath.Join(tmpDir, "result.bin") + + d := NewDownloaderWithSleep(DefaultDownloadConfig(), noopSleep) + result, err := d.DownloadAndVerify(context.Background(), server.URL, destPath, expectedSHA) + + require.NoError(t, err) + require.NotNil(t, result) + + // Verify 
result struct is populated correctly + assert.Equal(t, destPath, result.FilePath) + assert.Equal(t, expectedSHA, result.SHA256) +} diff --git a/app/services/updatedownload/extract.go b/app/services/updatedownload/extract.go new file mode 100644 index 0000000..e2d9145 --- /dev/null +++ b/app/services/updatedownload/extract.go @@ -0,0 +1,214 @@ +package updatedownload + +import ( + "archive/tar" + "compress/gzip" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "strings" +) + +var ( + // ErrPathTraversal is returned when a tar entry contains path traversal. + ErrPathTraversal = errors.New("path traversal detected in archive") + // ErrFileNotFound is returned when the requested file is not in the archive. + ErrFileNotFound = errors.New("file not found in archive") +) + +const ( + // ExtractDirPermissions is the permission mode for extraction directories. + // Uses 0755 as a sensible default for general-purpose extraction (e.g., binaries + // to /usr/bin/ need to be world-executable). Callers needing different permissions + // should create the destination directory beforehand with desired permissions. + ExtractDirPermissions = 0755 +) + +// ExtractTarGz extracts all files from a .tar.gz archive to destDir. +// Creates destDir if it doesn't exist with 0755 permissions. +// Returns error on invalid archive or path traversal attempt. 
+func ExtractTarGz(tarPath, destDir string) error { + // Create destDir with correct permissions before extraction + if err := os.MkdirAll(destDir, ExtractDirPermissions); err != nil { + return fmt.Errorf("failed to create destination directory: %w", err) + } + // Ensure permissions are correct even if directory already exists + if err := os.Chmod(destDir, ExtractDirPermissions); err != nil { + return fmt.Errorf("failed to set destination directory permissions: %w", err) + } + + f, err := os.Open(tarPath) + if err != nil { + return fmt.Errorf("failed to open archive: %w", err) + } + defer f.Close() + + gr, err := gzip.NewReader(f) + if err != nil { + return fmt.Errorf("failed to create gzip reader: %w", err) + } + defer gr.Close() + + tr := tar.NewReader(gr) + + for { + hdr, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + return fmt.Errorf("failed to read tar entry: %w", err) + } + + // Security: check for path traversal + if err := validatePath(hdr.Name); err != nil { + return err + } + + targetPath := filepath.Join(destDir, hdr.Name) + + // Ensure the target is within destDir (defense in depth) + if !strings.HasPrefix(filepath.Clean(targetPath), filepath.Clean(destDir)) { + return fmt.Errorf("%w: %s", ErrPathTraversal, hdr.Name) + } + + switch hdr.Typeflag { + case tar.TypeDir: + if err := os.MkdirAll(targetPath, os.FileMode(hdr.Mode)); err != nil { + return fmt.Errorf("failed to create directory %s: %w", targetPath, err) + } + case tar.TypeReg: + if err := extractFile(tr, targetPath, os.FileMode(hdr.Mode)); err != nil { + return err + } + } + } + + return nil +} + +// validatePath checks if a path is safe (no traversal). +// Note: On Linux, backslashes are valid filename characters, not separators. +// The defense-in-depth check in ExtractTarGz (filepath.Join + HasPrefix) handles +// any edge cases. This function provides an early rejection for obvious attacks. 
+func validatePath(path string) error { + // Reject absolute paths (checks both / on Unix and drive letters on Windows) + if filepath.IsAbs(path) { + return fmt.Errorf("%w: absolute path %s", ErrPathTraversal, path) + } + + // Reject paths with .. components + // Using filepath.Clean normalizes the path and resolves .. where possible + cleanPath := filepath.Clean(path) + if strings.HasPrefix(cleanPath, "..") { + return fmt.Errorf("%w: %s", ErrPathTraversal, path) + } + + // Also check for .. anywhere in the cleaned path (e.g., "foo/../..") + if strings.Contains(cleanPath, string(filepath.Separator)+"..") { + return fmt.Errorf("%w: %s", ErrPathTraversal, path) + } + + return nil +} + +// extractFile extracts a single file from the tar reader to the target path. +func extractFile(tr *tar.Reader, targetPath string, mode os.FileMode) error { + // Create parent directory if needed + if err := os.MkdirAll(filepath.Dir(targetPath), ExtractDirPermissions); err != nil { + return fmt.Errorf("failed to create parent directory: %w", err) + } + + f, err := os.OpenFile(targetPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, mode) + if err != nil { + return fmt.Errorf("failed to create file %s: %w", targetPath, err) + } + defer f.Close() + + if _, err := io.Copy(f, tr); err != nil { + return fmt.Errorf("failed to write file %s: %w", targetPath, err) + } + + return nil +} + +// ExtractFile extracts a single file from a .tar.gz archive to destPath. +// Uses atomic write (temp file + rename). +// Returns error if the file is not found in the archive. 
+func ExtractFile(tarPath, fileName, destPath string) error { + f, err := os.Open(tarPath) + if err != nil { + return fmt.Errorf("failed to open archive: %w", err) + } + defer f.Close() + + gr, err := gzip.NewReader(f) + if err != nil { + return fmt.Errorf("failed to create gzip reader: %w", err) + } + defer gr.Close() + + tr := tar.NewReader(gr) + + for { + hdr, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + return fmt.Errorf("failed to read tar entry: %w", err) + } + + if hdr.Name == fileName || filepath.Base(hdr.Name) == fileName { + return extractFileAtomic(tr, destPath, os.FileMode(hdr.Mode)) + } + } + + return fmt.Errorf("%w: %s", ErrFileNotFound, fileName) +} + +// extractFileAtomic extracts content to a temp file then renames atomically. +func extractFileAtomic(r io.Reader, destPath string, mode os.FileMode) error { + // Ensure parent directory exists + if err := os.MkdirAll(filepath.Dir(destPath), ExtractDirPermissions); err != nil { + return fmt.Errorf("failed to create directory: %w", err) + } + + // Write to temp file + tmpFile, err := os.CreateTemp(filepath.Dir(destPath), ".extract-*") + if err != nil { + return fmt.Errorf("failed to create temp file: %w", err) + } + tmpPath := tmpFile.Name() + + // Clean up temp file on error + defer func() { + if tmpPath != "" { + os.Remove(tmpPath) + } + }() + + if _, err := io.Copy(tmpFile, r); err != nil { + tmpFile.Close() + return fmt.Errorf("failed to write file: %w", err) + } + + if err := tmpFile.Chmod(mode); err != nil { + tmpFile.Close() + return fmt.Errorf("failed to set permissions: %w", err) + } + + if err := tmpFile.Close(); err != nil { + return fmt.Errorf("failed to close temp file: %w", err) + } + + // Atomic rename + if err := os.Rename(tmpPath, destPath); err != nil { + return fmt.Errorf("failed to rename temp file: %w", err) + } + + tmpPath = "" // Prevent cleanup since rename succeeded + return nil +} diff --git a/app/services/updatedownload/extract_test.go 
b/app/services/updatedownload/extract_test.go new file mode 100644 index 0000000..b76dcc7 --- /dev/null +++ b/app/services/updatedownload/extract_test.go @@ -0,0 +1,339 @@ +package updatedownload + +import ( + "archive/tar" + "compress/gzip" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Helper to create a .tar.gz archive for testing +func createTestTarGz(t *testing.T, destPath string, files map[string]testFile) { + t.Helper() + + f, err := os.Create(destPath) + require.NoError(t, err) + defer f.Close() + + gw := gzip.NewWriter(f) + defer gw.Close() + + tw := tar.NewWriter(gw) + defer tw.Close() + + for name, tf := range files { + hdr := &tar.Header{ + Name: name, + Mode: tf.mode, + Size: int64(len(tf.content)), + } + err := tw.WriteHeader(hdr) + require.NoError(t, err) + + _, err = tw.Write([]byte(tf.content)) + require.NoError(t, err) + } +} + +type testFile struct { + content string + mode int64 +} + +// ============================================================================ +// ExtractTarGz Tests +// ============================================================================ + +func TestExtractTarGz_ValidArchive(t *testing.T) { + tmpDir := t.TempDir() + tarPath := filepath.Join(tmpDir, "test.tar.gz") + destDir := filepath.Join(tmpDir, "extracted") + + files := map[string]testFile{ + "file1.txt": {content: "content of file 1", mode: 0644}, + "file2.txt": {content: "content of file 2", mode: 0644}, + "subdir/file3.txt": {content: "content in subdir", mode: 0644}, + } + createTestTarGz(t, tarPath, files) + + err := ExtractTarGz(tarPath, destDir) + require.NoError(t, err) + + // Verify all files were extracted + for name, tf := range files { + path := filepath.Join(destDir, name) + content, err := os.ReadFile(path) + require.NoError(t, err, "file %s should exist", name) + assert.Equal(t, tf.content, string(content), "content of %s", name) + } +} + +func 
TestExtractTarGz_PreservesPermissions(t *testing.T) { + tmpDir := t.TempDir() + tarPath := filepath.Join(tmpDir, "test.tar.gz") + destDir := filepath.Join(tmpDir, "extracted") + + files := map[string]testFile{ + "readonly.txt": {content: "readonly", mode: 0444}, + "executable.sh": {content: "#!/bin/bash", mode: 0755}, + "normal.txt": {content: "normal", mode: 0644}, + } + createTestTarGz(t, tarPath, files) + + err := ExtractTarGz(tarPath, destDir) + require.NoError(t, err) + + // Verify permissions + for name, tf := range files { + path := filepath.Join(destDir, name) + info, err := os.Stat(path) + require.NoError(t, err) + assert.Equal(t, os.FileMode(tf.mode), info.Mode().Perm(), "permissions of %s", name) + } +} + +func TestExtractTarGz_CreatesDestDir(t *testing.T) { + tmpDir := t.TempDir() + tarPath := filepath.Join(tmpDir, "test.tar.gz") + destDir := filepath.Join(tmpDir, "nested", "dest", "dir") + + files := map[string]testFile{ + "file.txt": {content: "test", mode: 0644}, + } + createTestTarGz(t, tarPath, files) + + // destDir doesn't exist yet + _, err := os.Stat(destDir) + require.True(t, os.IsNotExist(err)) + + err = ExtractTarGz(tarPath, destDir) + require.NoError(t, err) + + // destDir should now exist + info, err := os.Stat(destDir) + require.NoError(t, err) + assert.True(t, info.IsDir()) +} + +func TestExtractTarGz_DestDirPermissions(t *testing.T) { + tmpDir := t.TempDir() + tarPath := filepath.Join(tmpDir, "test.tar.gz") + destDir := filepath.Join(tmpDir, "extracted") + + files := map[string]testFile{ + "file.txt": {content: "test", mode: 0644}, + } + createTestTarGz(t, tarPath, files) + + err := ExtractTarGz(tarPath, destDir) + require.NoError(t, err) + + // destDir should have 0755 permissions + info, err := os.Stat(destDir) + require.NoError(t, err) + assert.Equal(t, os.FileMode(0755), info.Mode().Perm(), "destDir should have 0755 permissions") +} + +func TestExtractTarGz_DestDirCreatedBeforeExtraction(t *testing.T) { + tmpDir := t.TempDir() + 
tarPath := filepath.Join(tmpDir, "test.tar.gz") + destDir := filepath.Join(tmpDir, "extracted") + + // Create archive with files in subdirectory only (not directly in destDir) + files := map[string]testFile{ + "subdir/file.txt": {content: "test", mode: 0644}, + } + createTestTarGz(t, tarPath, files) + + err := ExtractTarGz(tarPath, destDir) + require.NoError(t, err) + + // destDir should exist and have correct permissions + info, err := os.Stat(destDir) + require.NoError(t, err) + assert.True(t, info.IsDir()) + assert.Equal(t, os.FileMode(0755), info.Mode().Perm(), "destDir should have 0755 permissions") +} + +func TestExtractTarGz_FixesExistingDestDirPermissions(t *testing.T) { + tmpDir := t.TempDir() + tarPath := filepath.Join(tmpDir, "test.tar.gz") + destDir := filepath.Join(tmpDir, "extracted") + + files := map[string]testFile{ + "file.txt": {content: "test", mode: 0644}, + } + createTestTarGz(t, tarPath, files) + + // Create destDir with restrictive permissions + err := os.MkdirAll(destDir, 0700) + require.NoError(t, err) + + err = ExtractTarGz(tarPath, destDir) + require.NoError(t, err) + + // destDir should now have 0755 permissions (fixed) + info, err := os.Stat(destDir) + require.NoError(t, err) + assert.Equal(t, os.FileMode(0755), info.Mode().Perm(), "destDir permissions should be fixed to 0755") +} + +func TestExtractTarGz_InvalidArchive(t *testing.T) { + tmpDir := t.TempDir() + tarPath := filepath.Join(tmpDir, "invalid.tar.gz") + destDir := filepath.Join(tmpDir, "extracted") + + // Write invalid content + err := os.WriteFile(tarPath, []byte("not a valid tar.gz"), 0644) + require.NoError(t, err) + + err = ExtractTarGz(tarPath, destDir) + assert.Error(t, err) +} + +func TestExtractTarGz_PathTraversal(t *testing.T) { + testCases := []struct { + name string + fileName string + }{ + {"unix style", "../../../etc/passwd"}, + {"windows style backslash", "..\\..\\..\\etc\\passwd"}, + {"mixed slashes", "..\\../..\\etc/passwd"}, + {"hidden in middle", 
"foo/../../../etc/passwd"}, + {"absolute unix", "/etc/passwd"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tmpDir := t.TempDir() + tarPath := filepath.Join(tmpDir, "malicious.tar.gz") + destDir := filepath.Join(tmpDir, "extracted") + + // Create archive with path traversal attempt + f, err := os.Create(tarPath) + require.NoError(t, err) + + gw := gzip.NewWriter(f) + tw := tar.NewWriter(gw) + + // Malicious file with path traversal + hdr := &tar.Header{ + Name: tc.fileName, + Mode: 0644, + Size: int64(len("malicious")), + } + err = tw.WriteHeader(hdr) + require.NoError(t, err) + _, err = tw.Write([]byte("malicious")) + require.NoError(t, err) + + tw.Close() + gw.Close() + f.Close() + + err = ExtractTarGz(tarPath, destDir) + assert.Error(t, err, "should reject path traversal: %s", tc.fileName) + assert.ErrorIs(t, err, ErrPathTraversal) + }) + } +} + +// ============================================================================ +// ExtractFile Tests +// ============================================================================ + +func TestExtractFile_SingleFile(t *testing.T) { + tmpDir := t.TempDir() + tarPath := filepath.Join(tmpDir, "test.tar.gz") + destPath := filepath.Join(tmpDir, "extracted.txt") + + files := map[string]testFile{ + "file1.txt": {content: "content 1", mode: 0644}, + "file2.txt": {content: "content 2", mode: 0644}, + "file3.txt": {content: "content 3", mode: 0644}, + } + createTestTarGz(t, tarPath, files) + + err := ExtractFile(tarPath, "file2.txt", destPath) + require.NoError(t, err) + + content, err := os.ReadFile(destPath) + require.NoError(t, err) + assert.Equal(t, "content 2", string(content)) +} + +func TestExtractFile_PreservesPermissions(t *testing.T) { + testCases := []struct { + name string + mode int64 + }{ + {"executable", 0755}, + {"readonly", 0444}, + {"normal", 0644}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tmpDir := t.TempDir() + tarPath := 
filepath.Join(tmpDir, "test.tar.gz") + destPath := filepath.Join(tmpDir, "extracted") + + files := map[string]testFile{ + "file.bin": {content: "binary content", mode: tc.mode}, + } + createTestTarGz(t, tarPath, files) + + err := ExtractFile(tarPath, "file.bin", destPath) + require.NoError(t, err) + + info, err := os.Stat(destPath) + require.NoError(t, err) + assert.Equal(t, os.FileMode(tc.mode), info.Mode().Perm(), "should preserve %s permissions", tc.name) + }) + } +} + +func TestExtractFile_NotFound(t *testing.T) { + tmpDir := t.TempDir() + tarPath := filepath.Join(tmpDir, "test.tar.gz") + destPath := filepath.Join(tmpDir, "extracted.txt") + + files := map[string]testFile{ + "file1.txt": {content: "content 1", mode: 0644}, + } + createTestTarGz(t, tarPath, files) + + err := ExtractFile(tarPath, "nonexistent.txt", destPath) + assert.Error(t, err, "should error when file not found in archive") +} + +func TestExtractFile_AtomicWrite(t *testing.T) { + tmpDir := t.TempDir() + tarPath := filepath.Join(tmpDir, "test.tar.gz") + destPath := filepath.Join(tmpDir, "extracted.txt") + + files := map[string]testFile{ + "file.txt": {content: "atomic content", mode: 0644}, + } + createTestTarGz(t, tarPath, files) + + err := ExtractFile(tarPath, "file.txt", destPath) + require.NoError(t, err) + + // Verify no temp files left behind + entries, err := os.ReadDir(tmpDir) + require.NoError(t, err) + + fileNames := make([]string, 0, len(entries)) + for _, entry := range entries { + fileNames = append(fileNames, entry.Name()) + } + + assert.Contains(t, fileNames, "test.tar.gz") + assert.Contains(t, fileNames, "extracted.txt") + assert.Len(t, fileNames, 2, "should only have tar and extracted file, no temp files") +} diff --git a/app/services/updatedownload/staging.go b/app/services/updatedownload/staging.go new file mode 100644 index 0000000..480b815 --- /dev/null +++ b/app/services/updatedownload/staging.go @@ -0,0 +1,86 @@ +package updatedownload + +import ( + "context" + "fmt" + 
"os" + "path/filepath" +) + +const ( + // AgentTarballName is the filename for the staged agent tarball. + AgentTarballName = "hostlink.tar.gz" + // UpdaterTarballName is the filename for the staged updater tarball. + UpdaterTarballName = "updater.tar.gz" + // StagingDirPermissions is the permission mode for the staging directory. + StagingDirPermissions = 0700 +) + +// StagingManager manages the staging area for update artifacts. +type StagingManager struct { + basePath string + downloader *Downloader +} + +// NewStagingManager creates a new StagingManager. +func NewStagingManager(basePath string, downloader *Downloader) *StagingManager { + return &StagingManager{ + basePath: basePath, + downloader: downloader, + } +} + +// Prepare creates the staging directory with correct permissions (0700). +// This function is idempotent. +func (s *StagingManager) Prepare() error { + if err := os.MkdirAll(s.basePath, StagingDirPermissions); err != nil { + return fmt.Errorf("failed to create staging directory: %w", err) + } + + // Ensure permissions are correct even if directory already exists + if err := os.Chmod(s.basePath, StagingDirPermissions); err != nil { + return fmt.Errorf("failed to set staging directory permissions: %w", err) + } + + return nil +} + +// StageAgent downloads and verifies the agent tarball to the staging area. +func (s *StagingManager) StageAgent(ctx context.Context, url, sha256 string) error { + destPath := s.GetAgentPath() + _, err := s.downloader.DownloadAndVerify(ctx, url, destPath, sha256) + return err +} + +// StageUpdater downloads and verifies the updater tarball to the staging area. +func (s *StagingManager) StageUpdater(ctx context.Context, url, sha256 string) error { + destPath := s.GetUpdaterPath() + _, err := s.downloader.DownloadAndVerify(ctx, url, destPath, sha256) + return err +} + +// GetAgentPath returns the path to the staged agent tarball. +// Note: Returns tarball path, not extracted binary. Extraction happens in updater phase. 
+func (s *StagingManager) GetAgentPath() string { + return filepath.Join(s.basePath, AgentTarballName) +} + +// GetUpdaterPath returns the path to the staged updater tarball. +// Note: Returns tarball path, not extracted binary. Extraction happens in updater phase. +func (s *StagingManager) GetUpdaterPath() string { + return filepath.Join(s.basePath, UpdaterTarballName) +} + +// Cleanup removes the entire staging directory. +func (s *StagingManager) Cleanup() error { + // Check if directory exists + if _, err := os.Stat(s.basePath); os.IsNotExist(err) { + return nil // Nothing to clean up + } + + if err := os.RemoveAll(s.basePath); err != nil { + return fmt.Errorf("failed to remove staging directory: %w", err) + } + + return nil +} diff --git a/app/services/updatedownload/staging_test.go b/app/services/updatedownload/staging_test.go new file mode 100644 index 0000000..3e41496 --- /dev/null +++ b/app/services/updatedownload/staging_test.go @@ -0,0 +1,190 @@ +package updatedownload + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestStagingManager_Prepare(t *testing.T) { + tmpDir := t.TempDir() + stagingPath := filepath.Join(tmpDir, "staging") + + // Staging dir doesn't exist yet + _, err := os.Stat(stagingPath) + require.True(t, os.IsNotExist(err)) + + sm := NewStagingManager(stagingPath, nil) + err = sm.Prepare() + require.NoError(t, err) + + // Staging dir should now exist with 0700 permissions + info, err := os.Stat(stagingPath) + require.NoError(t, err) + assert.True(t, info.IsDir()) + assert.Equal(t, os.FileMode(0700), info.Mode().Perm()) +} + +func TestStagingManager_Prepare_Idempotent(t *testing.T) { + tmpDir := t.TempDir() + stagingPath := filepath.Join(tmpDir, "staging") + + sm := NewStagingManager(stagingPath, nil) + + // Call Prepare twice - should not error + err := sm.Prepare() + 
require.NoError(t, err) + + err = sm.Prepare() + require.NoError(t, err) + + // Should still have correct permissions + info, err := os.Stat(stagingPath) + require.NoError(t, err) + assert.Equal(t, os.FileMode(0700), info.Mode().Perm()) +} + +func TestStagingManager_StageAgent(t *testing.T) { + content := []byte("fake agent tarball content") + contentSHA := computeStagingSHA256(content) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write(content) + })) + defer server.Close() + + tmpDir := t.TempDir() + stagingPath := filepath.Join(tmpDir, "staging") + + downloader := NewDownloaderWithSleep(DefaultDownloadConfig(), noopSleep) + sm := NewStagingManager(stagingPath, downloader) + + err := sm.Prepare() + require.NoError(t, err) + + err = sm.StageAgent(context.Background(), server.URL, contentSHA) + require.NoError(t, err) + + // Verify file was downloaded + agentPath := sm.GetAgentPath() + data, err := os.ReadFile(agentPath) + require.NoError(t, err) + assert.Equal(t, content, data) +} + +func TestStagingManager_StageUpdater(t *testing.T) { + content := []byte("fake updater tarball content") + contentSHA := computeStagingSHA256(content) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write(content) + })) + defer server.Close() + + tmpDir := t.TempDir() + stagingPath := filepath.Join(tmpDir, "staging") + + downloader := NewDownloaderWithSleep(DefaultDownloadConfig(), noopSleep) + sm := NewStagingManager(stagingPath, downloader) + + err := sm.Prepare() + require.NoError(t, err) + + err = sm.StageUpdater(context.Background(), server.URL, contentSHA) + require.NoError(t, err) + + // Verify file was downloaded + updaterPath := sm.GetUpdaterPath() + data, err := os.ReadFile(updaterPath) + require.NoError(t, err) + assert.Equal(t, content, data) +} + +func TestStagingManager_GetPaths(t *testing.T) { + 
stagingPath := "/var/lib/hostlink/updates/staging" + sm := NewStagingManager(stagingPath, nil) + + assert.Equal(t, filepath.Join(stagingPath, "hostlink.tar.gz"), sm.GetAgentPath()) + assert.Equal(t, filepath.Join(stagingPath, UpdaterTarballName), sm.GetUpdaterPath()) +} + +func TestStagingManager_Cleanup(t *testing.T) { + tmpDir := t.TempDir() + stagingPath := filepath.Join(tmpDir, "staging") + + sm := NewStagingManager(stagingPath, nil) + + err := sm.Prepare() + require.NoError(t, err) + + // Create some files in staging + file1 := filepath.Join(stagingPath, "file1.txt") + file2 := filepath.Join(stagingPath, "file2.txt") + err = os.WriteFile(file1, []byte("content1"), 0644) + require.NoError(t, err) + err = os.WriteFile(file2, []byte("content2"), 0644) + require.NoError(t, err) + + // Cleanup + err = sm.Cleanup() + require.NoError(t, err) + + // Staging directory should be completely removed + _, err = os.Stat(stagingPath) + assert.True(t, os.IsNotExist(err), "staging directory should be removed after cleanup") +} + +func TestStagingManager_Cleanup_NonExistentDir(t *testing.T) { + tmpDir := t.TempDir() + stagingPath := filepath.Join(tmpDir, "nonexistent") + + sm := NewStagingManager(stagingPath, nil) + + // Cleanup on non-existent dir should not error + err := sm.Cleanup() + assert.NoError(t, err) +} + +func TestStagingManager_StageAgent_ChecksumMismatch(t *testing.T) { + content := []byte("agent content") + wrongSHA := "0000000000000000000000000000000000000000000000000000000000000000" + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write(content) + })) + defer server.Close() + + tmpDir := t.TempDir() + stagingPath := filepath.Join(tmpDir, "staging") + + downloader := NewDownloaderWithSleep(DefaultDownloadConfig(), noopSleep) + sm := NewStagingManager(stagingPath, downloader) + + err := sm.Prepare() + require.NoError(t, err) + + err = sm.StageAgent(context.Background(), server.URL, 
wrongSHA) + assert.ErrorIs(t, err, ErrChecksumMismatch) + + // File should not exist after checksum failure + agentPath := sm.GetAgentPath() + _, err = os.Stat(agentPath) + assert.True(t, os.IsNotExist(err), "file should be deleted on checksum mismatch") +} + +// Helper for staging tests +func computeStagingSHA256(content []byte) string { + hash := sha256.Sum256(content) + return hex.EncodeToString(hash[:]) +} diff --git a/app/services/updatepreflight/preflight.go b/app/services/updatepreflight/preflight.go new file mode 100644 index 0000000..3b7939d --- /dev/null +++ b/app/services/updatepreflight/preflight.go @@ -0,0 +1,98 @@ +package updatepreflight + +import ( + "fmt" + "os" +) + +const ( + // diskSpaceBuffer is an additional 10MB buffer required beyond the stated requiredSpace. + diskSpaceBuffer = 10 * 1024 * 1024 +) + +// StatFunc returns available bytes for a given path. +type StatFunc func(path string) (uint64, error) + +// PreflightResult holds the results of pre-flight checks. +type PreflightResult struct { + Passed bool + Errors []string +} + +// PreflightChecker runs pre-flight checks before an update. +type PreflightChecker struct { + agentBinaryPath string + updatesDir string + statFunc StatFunc +} + +// PreflightConfig holds configuration for the PreflightChecker. +type PreflightConfig struct { + AgentBinaryPath string + UpdatesDir string + StatFunc StatFunc +} + +// New creates a new PreflightChecker. +func New(cfg PreflightConfig) *PreflightChecker { + return &PreflightChecker{ + agentBinaryPath: cfg.AgentBinaryPath, + updatesDir: cfg.UpdatesDir, + statFunc: cfg.StatFunc, + } +} + +// Check runs all pre-flight checks. requiredSpace is in bytes. 
+func (p *PreflightChecker) Check(requiredSpace int64) *PreflightResult { + var errs []string + + if err := p.checkBinaryWritable(); err != nil { + errs = append(errs, err.Error()) + } + + if err := p.checkDirWritable(); err != nil { + errs = append(errs, err.Error()) + } + + if err := p.checkDiskSpace(requiredSpace); err != nil { + errs = append(errs, err.Error()) + } + + return &PreflightResult{ + Passed: len(errs) == 0, + Errors: errs, + } +} + +func (p *PreflightChecker) checkBinaryWritable() error { + f, err := os.OpenFile(p.agentBinaryPath, os.O_WRONLY, 0) + if err != nil { + return fmt.Errorf("agent binary %s is not writable: %w", p.agentBinaryPath, err) + } + f.Close() + return nil +} + +func (p *PreflightChecker) checkDirWritable() error { + tmpFile, err := os.CreateTemp(p.updatesDir, "preflight-*") + if err != nil { + return fmt.Errorf("updates directory %s is not writable: %w", p.updatesDir, err) + } + name := tmpFile.Name() + tmpFile.Close() + os.Remove(name) + return nil +} + +func (p *PreflightChecker) checkDiskSpace(requiredSpace int64) error { + available, err := p.statFunc(p.updatesDir) + if err != nil { + return fmt.Errorf("failed to check disk space: %w", err) + } + + needed := uint64(requiredSpace) + diskSpaceBuffer + if available < needed { + return fmt.Errorf("insufficient disk space: need %d bytes, have %d bytes", needed, available) + } + return nil +} diff --git a/app/services/updatepreflight/preflight_test.go b/app/services/updatepreflight/preflight_test.go new file mode 100644 index 0000000..e4a13da --- /dev/null +++ b/app/services/updatepreflight/preflight_test.go @@ -0,0 +1,183 @@ +package updatepreflight + +import ( + "errors" + "os" + "path/filepath" + "testing" +) + +func TestCheck_BinaryWritable(t *testing.T) { + dir := t.TempDir() + binaryPath := filepath.Join(dir, "hostlink") + os.WriteFile(binaryPath, []byte("binary"), 0444) + + checker := New(PreflightConfig{ + AgentBinaryPath: binaryPath, + UpdatesDir: dir, + StatFunc: func(path 
string) (uint64, error) { return 1 << 30, nil }, + }) + + result := checker.Check(10 * 1024 * 1024) // 10MB required + if result.Passed { + t.Error("expected Passed to be false when binary is not writable") + } + assertContainsError(t, result.Errors, "not writable") +} + +func TestCheck_BinaryWritable_Passes(t *testing.T) { + dir := t.TempDir() + binaryPath := filepath.Join(dir, "hostlink") + os.WriteFile(binaryPath, []byte("binary"), 0755) + + checker := New(PreflightConfig{ + AgentBinaryPath: binaryPath, + UpdatesDir: dir, + StatFunc: func(path string) (uint64, error) { return 1 << 30, nil }, + }) + + result := checker.Check(10 * 1024 * 1024) + if !result.Passed { + t.Errorf("expected Passed to be true, got errors: %v", result.Errors) + } +} + +func TestCheck_UpdatesDirWritable(t *testing.T) { + dir := t.TempDir() + readOnlyDir := filepath.Join(dir, "updates") + os.MkdirAll(readOnlyDir, 0555) + + binaryPath := filepath.Join(dir, "hostlink") + os.WriteFile(binaryPath, []byte("binary"), 0755) + + checker := New(PreflightConfig{ + AgentBinaryPath: binaryPath, + UpdatesDir: readOnlyDir, + StatFunc: func(path string) (uint64, error) { return 1 << 30, nil }, + }) + + result := checker.Check(10 * 1024 * 1024) + if result.Passed { + t.Error("expected Passed to be false when updates dir is not writable") + } + assertContainsError(t, result.Errors, "not writable") +} + +func TestCheck_DiskSpaceInsufficient(t *testing.T) { + dir := t.TempDir() + binaryPath := filepath.Join(dir, "hostlink") + os.WriteFile(binaryPath, []byte("binary"), 0755) + + checker := New(PreflightConfig{ + AgentBinaryPath: binaryPath, + UpdatesDir: dir, + StatFunc: func(path string) (uint64, error) { return 5 * 1024 * 1024, nil }, // 5MB available + }) + + // Require 10MB + 10MB buffer = 20MB, but only 5MB available + result := checker.Check(10 * 1024 * 1024) + if result.Passed { + t.Error("expected Passed to be false when disk space is insufficient") + } + assertContainsError(t, result.Errors, "disk 
space") +} + +func TestCheck_DiskSpaceSufficient(t *testing.T) { + dir := t.TempDir() + binaryPath := filepath.Join(dir, "hostlink") + os.WriteFile(binaryPath, []byte("binary"), 0755) + + checker := New(PreflightConfig{ + AgentBinaryPath: binaryPath, + UpdatesDir: dir, + StatFunc: func(path string) (uint64, error) { return 100 * 1024 * 1024, nil }, // 100MB + }) + + result := checker.Check(10 * 1024 * 1024) + if !result.Passed { + t.Errorf("expected Passed to be true, got errors: %v", result.Errors) + } +} + +func TestCheck_AllErrorsReported(t *testing.T) { + dir := t.TempDir() + binaryPath := filepath.Join(dir, "hostlink") + os.WriteFile(binaryPath, []byte("binary"), 0444) // not writable + + readOnlyDir := filepath.Join(dir, "updates") + os.MkdirAll(readOnlyDir, 0555) // not writable + + checker := New(PreflightConfig{ + AgentBinaryPath: binaryPath, + UpdatesDir: readOnlyDir, + StatFunc: func(path string) (uint64, error) { return 1024, nil }, // tiny + }) + + result := checker.Check(10 * 1024 * 1024) + if result.Passed { + t.Error("expected Passed to be false") + } + // Should have 3 errors: binary not writable, dir not writable, disk space + if len(result.Errors) < 3 { + t.Errorf("expected at least 3 errors, got %d: %v", len(result.Errors), result.Errors) + } +} + +func TestCheck_StatFuncError(t *testing.T) { + dir := t.TempDir() + binaryPath := filepath.Join(dir, "hostlink") + os.WriteFile(binaryPath, []byte("binary"), 0755) + + checker := New(PreflightConfig{ + AgentBinaryPath: binaryPath, + UpdatesDir: dir, + StatFunc: func(path string) (uint64, error) { return 0, errors.New("statfs failed") }, + }) + + result := checker.Check(10 * 1024 * 1024) + if result.Passed { + t.Error("expected Passed to be false when stat fails") + } + assertContainsError(t, result.Errors, "disk space") +} + +func TestCheck_BinaryNotExists(t *testing.T) { + dir := t.TempDir() + binaryPath := filepath.Join(dir, "nonexistent") + + checker := New(PreflightConfig{ + AgentBinaryPath: 
// assertContainsError checks that at least one error string contains the substring.
func assertContainsError(t *testing.T, errs []string, substr string) {
	t.Helper()
	for _, e := range errs {
		if contains(e, substr) {
			return
		}
	}
	t.Errorf("expected an error containing %q, got: %v", substr, errs)
}

// contains reports whether substr occurs within s.
// The original wrapped containsSubstring in a redundant boolean
// (length check, equality check, non-empty check) that containsSubstring
// already handles correctly, so it delegates directly now.
func contains(s, substr string) bool {
	return containsSubstring(s, substr)
}

// containsSubstring scans s for substr; an empty substr always matches.
// Kept dependency-free for this test file (equivalent to strings.Contains).
func containsSubstring(s, substr string) bool {
	for i := 0; i+len(substr) <= len(s); i++ {
		if s[i:i+len(substr)] == substr {
			return true
		}
	}
	return false
}
from state file + paths := update.NewPaths(*baseDir) + stateWriter := update.NewStateWriter(update.StateConfig{StatePath: paths.StateFile}) + state, err := stateWriter.Read() + if err != nil || state.TargetVersion == "" { + log.Fatal("target version is required: use -version flag or ensure state.json has target version") + } + *targetVersion = state.TargetVersion + } + + // Build paths + paths := update.NewPaths(*baseDir) + + // Create configuration + cfg := &UpdaterConfig{ + AgentBinaryPath: *binaryPath, + BackupDir: paths.BackupDir, + StagingDir: paths.StagingDir, + LockPath: paths.LockFile, + StatePath: paths.StateFile, + HealthURL: *healthURL, + TargetVersion: *targetVersion, + ServiceStopTimeout: 30 * time.Second, + ServiceStartTimeout: 30 * time.Second, + HealthCheckRetries: 5, + HealthCheckInterval: 5 * time.Second, + HealthInitialWait: 5 * time.Second, + LockRetries: 5, + LockRetryInterval: 1 * time.Second, + } + + // Create updater + updater := NewUpdater(cfg) + updater.onPhaseChange = func(phase Phase) { + log.Printf("Phase: %s", phase) + } + + // Create context with overall timeout (90s budget) + ctx, cancel := context.WithTimeout(context.Background(), 90*time.Second) + defer cancel() + + // Set up signal handler: SIGTERM/SIGINT simply cancel the context. + // Run() checks ctx.Err() between phases and does its own cleanup. + stopSignals := WatchSignals(cancel) + defer stopSignals() + + // Run the update - Run() owns all cleanup/rollback. 
// WatchSignals listens for SIGTERM/SIGINT and cancels the given context.
// Returns a cleanup function that stops signal watching.
func WatchSignals(cancel context.CancelFunc) (stop func()) {
	signals := make(chan os.Signal, 1)
	signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)

	go func() {
		// Receives exactly one delivered signal and cancels; if the
		// channel is closed by stop() instead, the loop exits without
		// cancelling.
		for range signals {
			cancel()
			return
		}
	}()

	return func() {
		signal.Stop(signals)
		close(signals)
	}
}
a/cmd/updater/updater.go b/cmd/updater/updater.go new file mode 100644 index 0000000..9599aa1 --- /dev/null +++ b/cmd/updater/updater.go @@ -0,0 +1,306 @@ +package main + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + "time" + + "hostlink/app/services/updatedownload" + "hostlink/internal/update" +) + +// Phase represents the current phase of the update process. +type Phase string + +const ( + PhaseAcquireLock Phase = "acquire_lock" + PhaseStopping Phase = "stopping" + PhaseBackup Phase = "backup" + PhaseInstalling Phase = "installing" + PhaseStarting Phase = "starting" + PhaseVerifying Phase = "verifying" + PhaseCompleted Phase = "completed" + PhaseRollback Phase = "rollback" +) + +// Default configuration values +const ( + DefaultLockRetries = 5 + DefaultLockRetryInterval = 1 * time.Second + DefaultLockExpiration = 5 * time.Minute +) + +// UpdaterConfig holds the configuration for the Updater. +type UpdaterConfig struct { + AgentBinaryPath string // /usr/bin/hostlink + BackupDir string // /var/lib/hostlink/updates/backup/ + StagingDir string // /var/lib/hostlink/updates/staging/ + LockPath string // /var/lib/hostlink/updates/update.lock + StatePath string // /var/lib/hostlink/updates/state.json + HealthURL string // http://localhost:8080/health + TargetVersion string // Version to verify after update + ServiceStopTimeout time.Duration // 30s + ServiceStartTimeout time.Duration // 30s + HealthCheckRetries int // 5 + HealthCheckInterval time.Duration // 5s + HealthInitialWait time.Duration // 5s + LockRetries int // 5 + LockRetryInterval time.Duration // 1s + SleepFunc func(time.Duration) // For testing +} + +// ServiceController interface for mocking in tests +type ServiceController interface { + Stop(ctx context.Context) error + Start(ctx context.Context) error +} + +// Updater orchestrates the update process. 
+type Updater struct { + config *UpdaterConfig + lock *update.LockManager + state *update.StateWriter + serviceController ServiceController + healthChecker *update.HealthChecker + currentPhase Phase + onPhaseChange func(Phase) // For testing +} + +// NewUpdater creates a new Updater with the given configuration. +func NewUpdater(cfg *UpdaterConfig) *Updater { + // Apply defaults + if cfg.LockRetries == 0 { + cfg.LockRetries = DefaultLockRetries + } + if cfg.LockRetryInterval == 0 { + cfg.LockRetryInterval = DefaultLockRetryInterval + } + if cfg.ServiceStopTimeout == 0 { + cfg.ServiceStopTimeout = update.DefaultStopTimeout + } + if cfg.ServiceStartTimeout == 0 { + cfg.ServiceStartTimeout = update.DefaultStartTimeout + } + if cfg.HealthCheckRetries == 0 { + cfg.HealthCheckRetries = update.DefaultHealthRetries + } + if cfg.HealthCheckInterval == 0 { + cfg.HealthCheckInterval = update.DefaultHealthInterval + } + if cfg.HealthInitialWait == 0 { + cfg.HealthInitialWait = update.DefaultInitialWait + } + + return &Updater{ + config: cfg, + lock: update.NewLockManager(update.LockConfig{ + LockPath: cfg.LockPath, + }), + state: update.NewStateWriter(update.StateConfig{ + StatePath: cfg.StatePath, + }), + serviceController: update.NewServiceController(update.ServiceConfig{ + ServiceName: "hostlink", + StopTimeout: cfg.ServiceStopTimeout, + StartTimeout: cfg.ServiceStartTimeout, + }), + healthChecker: update.NewHealthChecker(update.HealthConfig{ + URL: cfg.HealthURL, + TargetVersion: cfg.TargetVersion, + MaxRetries: cfg.HealthCheckRetries, + RetryInterval: cfg.HealthCheckInterval, + InitialWait: cfg.HealthInitialWait, + SleepFunc: cfg.SleepFunc, + }), + } +} + +// setPhase updates the current phase and calls the callback if set. 
// setPhase records the phase the updater is currently in and notifies the
// optional onPhaseChange hook (used by tests to observe phase ordering).
func (u *Updater) setPhase(phase Phase) {
	u.currentPhase = phase
	if u.onPhaseChange != nil {
		u.onPhaseChange(phase)
	}
}

// Run executes the full update process:
// lock → stop → backup → install → start → verify → cleanup → unlock
//
// Run owns all cleanup. If ctx is cancelled (e.g. by a signal), Run aborts
// between phases and ensures the service is left running. Cancellation is
// checked explicitly between every phase because each phase changes what
// cleanup is required (restart only, or full rollback).
func (u *Updater) Run(ctx context.Context) error {
	// Clean up any leftover temp files first
	u.cleanupTempFiles()

	// Phase 1: Acquire lock
	u.setPhase(PhaseAcquireLock)
	if err := u.lock.TryLockWithRetry(DefaultLockExpiration, u.config.LockRetries, u.config.LockRetryInterval); err != nil {
		return fmt.Errorf("failed to acquire lock: %w", err)
	}
	defer u.lock.Unlock()

	// serviceStopped tracks whether we've stopped the service.
	// If true, any abort path must restart it.
	serviceStopped := false

	// abort restarts the service (if stopped) using a background context
	// since the original ctx may be cancelled.
	abort := func(reason error) error {
		if serviceStopped {
			// Best-effort restart; the caller only ever sees `reason`.
			u.serviceController.Start(context.Background())
		}
		return reason
	}

	// Check for cancellation before stopping
	if ctx.Err() != nil {
		return abort(ctx.Err())
	}

	// Phase 2: Stop service
	u.setPhase(PhaseStopping)
	if err := u.serviceController.Stop(ctx); err != nil {
		if ctx.Err() != nil {
			return abort(ctx.Err())
		}
		return fmt.Errorf("failed to stop service: %w", err)
	}
	serviceStopped = true

	// Check for cancellation after stop
	if ctx.Err() != nil {
		return abort(ctx.Err())
	}

	// Phase 3: Backup current binary
	u.setPhase(PhaseBackup)
	if err := update.BackupBinary(u.config.AgentBinaryPath, u.config.BackupDir); err != nil {
		return abort(fmt.Errorf("failed to backup binary: %w", err))
	}

	// Check for cancellation after backup
	if ctx.Err() != nil {
		return abort(ctx.Err())
	}

	// Phase 4: Install new binary
	// From this point on a simple restart is not enough: the binary on disk
	// may be the new one, so failure paths must go through rollbackFrom.
	u.setPhase(PhaseInstalling)
	tarballPath := filepath.Join(u.config.StagingDir, updatedownload.AgentTarballName)
	if err := update.InstallBinary(tarballPath, u.config.AgentBinaryPath); err != nil {
		u.rollbackFrom(PhaseInstalling)
		return fmt.Errorf("failed to install binary: %w", err)
	}

	// Check for cancellation after install - rollback needed
	if ctx.Err() != nil {
		u.rollbackFrom(PhaseInstalling)
		return ctx.Err()
	}

	// Phase 5: Start service
	// Use background context: even if cancelled, we must start the service
	// since the binary is already installed.
	u.setPhase(PhaseStarting)
	if err := u.serviceController.Start(context.Background()); err != nil {
		u.rollbackFrom(PhaseStarting)
		return fmt.Errorf("failed to start service: %w", err)
	}
	serviceStopped = false

	// Check for cancellation after start - service is running,
	// skip verification and exit cleanly.
	if ctx.Err() != nil {
		return ctx.Err()
	}

	// Phase 6: Verify health
	u.setPhase(PhaseVerifying)
	if err := u.healthChecker.WaitForHealth(ctx); err != nil {
		if ctx.Err() != nil {
			// Cancelled during verification - service is running, just exit.
			return ctx.Err()
		}
		// Health check failed (not due to cancellation) - rollback
		u.rollbackFrom(PhaseVerifying)
		return fmt.Errorf("health check failed: %w", err)
	}

	// Phase 7: Success!
	u.setPhase(PhaseCompleted)

	// Update state to completed (best-effort; write errors are ignored here
	// since the update itself has already succeeded)
	u.state.Write(update.StateData{
		State:         update.StateCompleted,
		TargetVersion: u.config.TargetVersion,
		CompletedAt:   timePtr(time.Now()),
	})

	return nil
}

// Rollback restores the backup and starts the service.
// Uses a background context for all operations since this may be called
// after the original context is cancelled; the ctx parameter is currently
// unused for that reason.
func (u *Updater) Rollback(ctx context.Context) error {
	return u.rollbackFrom(PhaseVerifying)
}

// rollbackFrom restores the backup and starts the service.
// Uses context.Background() for all operations since this is cleanup
// that must complete regardless of cancellation state. failedPhase is
// informational only; the rollback steps are the same from any phase.
func (u *Updater) rollbackFrom(failedPhase Phase) error {
	u.setPhase(PhaseRollback)

	// Update state to rollback in progress (best-effort)
	u.state.Write(update.StateData{
		State:         update.StateRollback,
		TargetVersion: u.config.TargetVersion,
	})

	// Stop the service first (best-effort) - it may still be running the bad binary
	u.serviceController.Stop(context.Background())

	// Restore backup
	if err := update.RestoreBackup(u.config.BackupDir, u.config.AgentBinaryPath); err != nil {
		return fmt.Errorf("failed to restore backup: %w", err)
	}

	// Start service with old binary (use background context - must complete)
	if err := u.serviceController.Start(context.Background()); err != nil {
		return fmt.Errorf("failed to start service after rollback: %w", err)
	}

	// Update state to rolled back (best-effort)
	u.state.Write(update.StateData{
		State:         update.StateRolledBack,
		TargetVersion: u.config.TargetVersion,
		CompletedAt:   timePtr(time.Now()),
	})

	// Clean up temp files
	u.cleanupTempFiles()

	return nil
}

// cleanupTempFiles removes any leftover hostlink.tmp.* files.
+func (u *Updater) cleanupTempFiles() { + dir := filepath.Dir(u.config.AgentBinaryPath) + entries, err := os.ReadDir(dir) + if err != nil { + return + } + + for _, entry := range entries { + if strings.HasPrefix(entry.Name(), "hostlink.tmp.") { + os.Remove(filepath.Join(dir, entry.Name())) + } + } +} + +func timePtr(t time.Time) *time.Time { + return &t +} diff --git a/cmd/updater/updater_test.go b/cmd/updater/updater_test.go new file mode 100644 index 0000000..6c8aafc --- /dev/null +++ b/cmd/updater/updater_test.go @@ -0,0 +1,644 @@ +package main + +import ( + "archive/tar" + "compress/gzip" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "hostlink/internal/update" +) + +// Test helpers +func createTestBinary(t *testing.T, path string, content []byte) { + t.Helper() + dir := filepath.Dir(path) + require.NoError(t, os.MkdirAll(dir, 0755)) + require.NoError(t, os.WriteFile(path, content, 0755)) +} + +func createTestTarball(t *testing.T, path string, binaryContent []byte) { + t.Helper() + dir := filepath.Dir(path) + require.NoError(t, os.MkdirAll(dir, 0755)) + + // Create a tar.gz file using archive/tar and compress/gzip + f, err := os.Create(path) + require.NoError(t, err) + defer f.Close() + + gw := newGzipWriter(f) + defer gw.Close() + + tw := newTarWriter(gw) + defer tw.Close() + + require.NoError(t, tw.WriteHeader(&tarHeader{ + Name: "hostlink", + Mode: 0755, + Size: int64(len(binaryContent)), + })) + _, err = tw.Write(binaryContent) + require.NoError(t, err) +} + +func TestUpdater_Run_HappyPath(t *testing.T) { + tmpDir := t.TempDir() + + // Setup paths + binaryPath := filepath.Join(tmpDir, "usr", "bin", "hostlink") + backupDir := filepath.Join(tmpDir, "backup") + stagingDir := filepath.Join(tmpDir, "staging") + lockPath := filepath.Join(tmpDir, "update.lock") + statePath := filepath.Join(tmpDir, "state.json") + + // 
Create current binary + createTestBinary(t, binaryPath, []byte("old binary v1.0.0")) + + // Create staged tarball with new binary + tarballPath := filepath.Join(stagingDir, "hostlink.tar.gz") + createTestTarball(t, tarballPath, []byte("new binary v2.0.0")) + + // Mock health server + healthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(update.HealthResponse{Ok: true, Version: "v2.0.0"}) + })) + defer healthServer.Close() + + // Create updater + u := NewUpdater(&UpdaterConfig{ + AgentBinaryPath: binaryPath, + BackupDir: backupDir, + StagingDir: stagingDir, + LockPath: lockPath, + StatePath: statePath, + HealthURL: healthServer.URL, + TargetVersion: "v2.0.0", + ServiceStopTimeout: 100 * time.Millisecond, + ServiceStartTimeout: 100 * time.Millisecond, + HealthCheckRetries: 1, + HealthCheckInterval: 10 * time.Millisecond, + HealthInitialWait: 1 * time.Millisecond, + SleepFunc: func(d time.Duration) {}, // No-op for tests + }) + + // Mock service controller (no real systemctl) + u.serviceController = &mockServiceController{} + + err := u.Run(context.Background()) + require.NoError(t, err) + + // Verify new binary is installed + content, err := os.ReadFile(binaryPath) + require.NoError(t, err) + assert.Equal(t, []byte("new binary v2.0.0"), content) + + // Verify backup exists + backupContent, err := os.ReadFile(filepath.Join(backupDir, "hostlink")) + require.NoError(t, err) + assert.Equal(t, []byte("old binary v1.0.0"), backupContent) +} + +func TestUpdater_Run_RollbackOnHealthCheckFailure(t *testing.T) { + tmpDir := t.TempDir() + + // Setup paths + binaryPath := filepath.Join(tmpDir, "usr", "bin", "hostlink") + backupDir := filepath.Join(tmpDir, "backup") + stagingDir := filepath.Join(tmpDir, "staging") + lockPath := filepath.Join(tmpDir, "update.lock") + statePath := filepath.Join(tmpDir, "state.json") + + // Create current binary + oldContent 
:= []byte("old binary v1.0.0") + createTestBinary(t, binaryPath, oldContent) + + // Create staged tarball with new binary + tarballPath := filepath.Join(stagingDir, "hostlink.tar.gz") + createTestTarball(t, tarballPath, []byte("new binary v2.0.0")) + + // Mock health server that always returns unhealthy + healthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(update.HealthResponse{Ok: false, Version: "v2.0.0"}) + })) + defer healthServer.Close() + + u := NewUpdater(&UpdaterConfig{ + AgentBinaryPath: binaryPath, + BackupDir: backupDir, + StagingDir: stagingDir, + LockPath: lockPath, + StatePath: statePath, + HealthURL: healthServer.URL, + TargetVersion: "v2.0.0", + ServiceStopTimeout: 100 * time.Millisecond, + ServiceStartTimeout: 100 * time.Millisecond, + HealthCheckRetries: 1, + HealthCheckInterval: 10 * time.Millisecond, + HealthInitialWait: 1 * time.Millisecond, + SleepFunc: func(d time.Duration) {}, + }) + + u.serviceController = &mockServiceController{} + + err := u.Run(context.Background()) + assert.Error(t, err) + + // Verify rollback occurred - old binary should be restored + content, err := os.ReadFile(binaryPath) + require.NoError(t, err) + assert.Equal(t, oldContent, content) +} + +func TestUpdater_Run_LockAcquisitionFailure(t *testing.T) { + tmpDir := t.TempDir() + + lockPath := filepath.Join(tmpDir, "update.lock") + + // Acquire lock first with another lock manager to simulate contention + otherLock := update.NewLockManager(update.LockConfig{LockPath: lockPath}) + require.NoError(t, otherLock.TryLock(1*time.Hour)) + defer otherLock.Unlock() + + u := NewUpdater(&UpdaterConfig{ + AgentBinaryPath: filepath.Join(tmpDir, "hostlink"), + BackupDir: filepath.Join(tmpDir, "backup"), + StagingDir: filepath.Join(tmpDir, "staging"), + LockPath: lockPath, + StatePath: filepath.Join(tmpDir, "state.json"), + HealthURL: 
"http://localhost:8080/health", + TargetVersion: "v2.0.0", + LockRetries: 1, + LockRetryInterval: 10 * time.Millisecond, + }) + + err := u.Run(context.Background()) + assert.Error(t, err) + assert.ErrorIs(t, err, update.ErrLockAcquireFailed) +} + +func TestUpdater_Run_CleansUpTempFiles(t *testing.T) { + tmpDir := t.TempDir() + + binaryPath := filepath.Join(tmpDir, "usr", "bin", "hostlink") + backupDir := filepath.Join(tmpDir, "backup") + stagingDir := filepath.Join(tmpDir, "staging") + lockPath := filepath.Join(tmpDir, "update.lock") + statePath := filepath.Join(tmpDir, "state.json") + + // Create current binary + createTestBinary(t, binaryPath, []byte("old binary")) + + // Create leftover temp files + tempFile1 := filepath.Join(tmpDir, "usr", "bin", "hostlink.tmp.abc123") + tempFile2 := filepath.Join(tmpDir, "usr", "bin", "hostlink.tmp.def456") + require.NoError(t, os.WriteFile(tempFile1, []byte("temp"), 0755)) + require.NoError(t, os.WriteFile(tempFile2, []byte("temp"), 0755)) + + // Create staged tarball + tarballPath := filepath.Join(stagingDir, "hostlink.tar.gz") + createTestTarball(t, tarballPath, []byte("new binary v2.0.0")) + + // Mock health server + healthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(update.HealthResponse{Ok: true, Version: "v2.0.0"}) + })) + defer healthServer.Close() + + u := NewUpdater(&UpdaterConfig{ + AgentBinaryPath: binaryPath, + BackupDir: backupDir, + StagingDir: stagingDir, + LockPath: lockPath, + StatePath: statePath, + HealthURL: healthServer.URL, + TargetVersion: "v2.0.0", + ServiceStopTimeout: 100 * time.Millisecond, + ServiceStartTimeout: 100 * time.Millisecond, + HealthCheckRetries: 1, + HealthCheckInterval: 10 * time.Millisecond, + HealthInitialWait: 1 * time.Millisecond, + SleepFunc: func(d time.Duration) {}, + }) + + u.serviceController = &mockServiceController{} + + err := 
// TestUpdater_Rollback_RestoresAndStartsService verifies that Rollback
// restores the backed-up binary over the current one and stops then starts
// the service, in that order.
func TestUpdater_Rollback_RestoresAndStartsService(t *testing.T) {
	tmpDir := t.TempDir()

	binaryPath := filepath.Join(tmpDir, "hostlink")
	backupDir := filepath.Join(tmpDir, "backup")
	statePath := filepath.Join(tmpDir, "state.json")

	// Create backup
	backupContent := []byte("backup binary v1.0.0")
	require.NoError(t, os.MkdirAll(backupDir, 0755))
	require.NoError(t, os.WriteFile(filepath.Join(backupDir, "hostlink"), backupContent, 0755))

	// Create "broken" current binary
	require.NoError(t, os.WriteFile(binaryPath, []byte("broken"), 0755))

	// Track service call order
	var callOrder []string
	mockSvc := &mockServiceController{
		onStop:  func() { callOrder = append(callOrder, "stop") },
		onStart: func() { callOrder = append(callOrder, "start") },
	}

	u := NewUpdater(&UpdaterConfig{
		AgentBinaryPath:     binaryPath,
		BackupDir:           backupDir,
		StagingDir:          filepath.Join(tmpDir, "staging"),
		LockPath:            filepath.Join(tmpDir, "update.lock"),
		StatePath:           statePath,
		HealthURL:           "http://localhost:8080/health",
		TargetVersion:       "v2.0.0",
		ServiceStopTimeout:  100 * time.Millisecond,
		ServiceStartTimeout: 100 * time.Millisecond,
	})
	u.serviceController = mockSvc

	err := u.Rollback(context.Background())
	require.NoError(t, err)

	// Verify binary was restored
	content, err := os.ReadFile(binaryPath)
	require.NoError(t, err)
	assert.Equal(t, backupContent, content)

	// Verify service was stopped then started (in that order)
	assert.True(t, mockSvc.stopCalled)
	assert.True(t, mockSvc.startCalled)
	assert.Equal(t, []string{"stop", "start"}, callOrder)
}
t.TempDir() + + binaryPath := filepath.Join(tmpDir, "hostlink") + backupDir := filepath.Join(tmpDir, "backup") + statePath := filepath.Join(tmpDir, "state.json") + + // Create backup + require.NoError(t, os.MkdirAll(backupDir, 0755)) + require.NoError(t, os.WriteFile(filepath.Join(backupDir, "hostlink"), []byte("backup"), 0755)) + + // Create current binary + require.NoError(t, os.WriteFile(binaryPath, []byte("current"), 0755)) + + u := NewUpdater(&UpdaterConfig{ + AgentBinaryPath: binaryPath, + BackupDir: backupDir, + StagingDir: filepath.Join(tmpDir, "staging"), + LockPath: filepath.Join(tmpDir, "update.lock"), + StatePath: statePath, + HealthURL: "http://localhost:8080/health", + TargetVersion: "v2.0.0", + ServiceStopTimeout: 100 * time.Millisecond, + ServiceStartTimeout: 100 * time.Millisecond, + }) + u.serviceController = &mockServiceController{} + + err := u.Rollback(context.Background()) + require.NoError(t, err) + + // Verify state was updated + stateWriter := update.NewStateWriter(update.StateConfig{StatePath: statePath}) + state, err := stateWriter.Read() + require.NoError(t, err) + assert.Equal(t, update.StateRolledBack, state.State) +} + +func TestUpdater_UpdatePhases(t *testing.T) { + tmpDir := t.TempDir() + + binaryPath := filepath.Join(tmpDir, "hostlink") + backupDir := filepath.Join(tmpDir, "backup") + stagingDir := filepath.Join(tmpDir, "staging") + statePath := filepath.Join(tmpDir, "state.json") + + // Create current binary + createTestBinary(t, binaryPath, []byte("old binary")) + + // Create staged tarball + tarballPath := filepath.Join(stagingDir, "hostlink.tar.gz") + createTestTarball(t, tarballPath, []byte("new binary v2.0.0")) + + // Track phase transitions + var phases []string + + healthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(update.HealthResponse{Ok: true, Version: "v2.0.0"}) + })) + defer 
healthServer.Close() + + mockSvc := &mockServiceController{ + onStop: func() { phases = append(phases, "stop") }, + onStart: func() { phases = append(phases, "start") }, + } + + u := NewUpdater(&UpdaterConfig{ + AgentBinaryPath: binaryPath, + BackupDir: backupDir, + StagingDir: stagingDir, + LockPath: filepath.Join(tmpDir, "update.lock"), + StatePath: statePath, + HealthURL: healthServer.URL, + TargetVersion: "v2.0.0", + ServiceStopTimeout: 100 * time.Millisecond, + ServiceStartTimeout: 100 * time.Millisecond, + HealthCheckRetries: 1, + HealthCheckInterval: 10 * time.Millisecond, + HealthInitialWait: 1 * time.Millisecond, + SleepFunc: func(d time.Duration) {}, + }) + u.serviceController = mockSvc + u.onPhaseChange = func(phase Phase) { + phases = append(phases, string(phase)) + } + + err := u.Run(context.Background()) + require.NoError(t, err) + + // Verify phases executed in order + expectedPhases := []string{ + string(PhaseAcquireLock), + string(PhaseStopping), + "stop", + string(PhaseBackup), + string(PhaseInstalling), + string(PhaseStarting), + "start", + string(PhaseVerifying), + string(PhaseCompleted), + } + assert.Equal(t, expectedPhases, phases) +} + +func TestUpdater_Run_CancelledBeforeStop(t *testing.T) { + tmpDir := t.TempDir() + + binaryPath := filepath.Join(tmpDir, "hostlink") + createTestBinary(t, binaryPath, []byte("binary")) + + mockSvc := &mockServiceController{} + + u := NewUpdater(&UpdaterConfig{ + AgentBinaryPath: binaryPath, + BackupDir: filepath.Join(tmpDir, "backup"), + StagingDir: filepath.Join(tmpDir, "staging"), + LockPath: filepath.Join(tmpDir, "update.lock"), + StatePath: filepath.Join(tmpDir, "state.json"), + HealthURL: "http://localhost:8080/health", + TargetVersion: "v2.0.0", + ServiceStopTimeout: 100 * time.Millisecond, + ServiceStartTimeout: 100 * time.Millisecond, + HealthInitialWait: 1 * time.Millisecond, + SleepFunc: func(d time.Duration) {}, + }) + u.serviceController = mockSvc + + // Cancel context before calling Run + ctx, 
cancel := context.WithCancel(context.Background()) + cancel() + + err := u.Run(ctx) + + assert.ErrorIs(t, err, context.Canceled) + assert.False(t, mockSvc.stopCalled, "service should not have been stopped") + assert.False(t, mockSvc.startCalled, "service should not have been started") +} + +func TestUpdater_Run_CancelledAfterStop(t *testing.T) { + tmpDir := t.TempDir() + + binaryPath := filepath.Join(tmpDir, "hostlink") + createTestBinary(t, binaryPath, []byte("binary")) + stagingDir := filepath.Join(tmpDir, "staging") + createTestTarball(t, filepath.Join(stagingDir, "hostlink.tar.gz"), []byte("new")) + + ctx, cancel := context.WithCancel(context.Background()) + + mockSvc := &mockServiceController{ + onStop: func() { + cancel() // Cancel after stop completes + }, + } + + u := NewUpdater(&UpdaterConfig{ + AgentBinaryPath: binaryPath, + BackupDir: filepath.Join(tmpDir, "backup"), + StagingDir: stagingDir, + LockPath: filepath.Join(tmpDir, "update.lock"), + StatePath: filepath.Join(tmpDir, "state.json"), + HealthURL: "http://localhost:8080/health", + TargetVersion: "v2.0.0", + ServiceStopTimeout: 100 * time.Millisecond, + ServiceStartTimeout: 100 * time.Millisecond, + HealthInitialWait: 1 * time.Millisecond, + SleepFunc: func(d time.Duration) {}, + }) + u.serviceController = mockSvc + + err := u.Run(ctx) + + assert.ErrorIs(t, err, context.Canceled) + assert.True(t, mockSvc.stopCalled) + assert.True(t, mockSvc.startCalled, "service must be restarted after being stopped") +} + +func TestUpdater_Run_CancelledAfterInstall(t *testing.T) { + tmpDir := t.TempDir() + + binaryPath := filepath.Join(tmpDir, "hostlink") + oldContent := []byte("old binary v1.0.0") + createTestBinary(t, binaryPath, oldContent) + stagingDir := filepath.Join(tmpDir, "staging") + createTestTarball(t, filepath.Join(stagingDir, "hostlink.tar.gz"), []byte("new binary")) + + ctx, cancel := context.WithCancel(context.Background()) + + u := NewUpdater(&UpdaterConfig{ + AgentBinaryPath: binaryPath, + 
BackupDir: filepath.Join(tmpDir, "backup"), + StagingDir: stagingDir, + LockPath: filepath.Join(tmpDir, "update.lock"), + StatePath: filepath.Join(tmpDir, "state.json"), + HealthURL: "http://localhost:8080/health", + TargetVersion: "v2.0.0", + ServiceStopTimeout: 100 * time.Millisecond, + ServiceStartTimeout: 100 * time.Millisecond, + HealthInitialWait: 1 * time.Millisecond, + SleepFunc: func(d time.Duration) {}, + }) + + // Cancel right after install phase begins + u.serviceController = &mockServiceController{} + u.onPhaseChange = func(phase Phase) { + if phase == PhaseInstalling { + // Let install complete, cancel right after + // We use a goroutine to cancel after a tiny delay + } + if phase == PhaseStarting { + // Cancel before start completes its inner work + cancel() + } + } + + err := u.Run(ctx) + + // Context was cancelled after install, during start. + // Start uses context.Background() so it should still succeed. + // But since ctx is cancelled after start returns, verification is skipped. 
+ assert.ErrorIs(t, err, context.Canceled) + + // Service should have been started (start uses Background ctx) + svc := u.serviceController.(*mockServiceController) + assert.True(t, svc.startCalled) +} + +func TestUpdater_Run_CancelledDuringVerification(t *testing.T) { + tmpDir := t.TempDir() + + binaryPath := filepath.Join(tmpDir, "hostlink") + createTestBinary(t, binaryPath, []byte("old binary")) + stagingDir := filepath.Join(tmpDir, "staging") + createTestTarball(t, filepath.Join(stagingDir, "hostlink.tar.gz"), []byte("new binary")) + + ctx, cancel := context.WithCancel(context.Background()) + + // Health server that blocks until context is cancelled + healthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Wait for context cancellation to simulate slow health check + <-ctx.Done() + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer healthServer.Close() + + u := NewUpdater(&UpdaterConfig{ + AgentBinaryPath: binaryPath, + BackupDir: filepath.Join(tmpDir, "backup"), + StagingDir: stagingDir, + LockPath: filepath.Join(tmpDir, "update.lock"), + StatePath: filepath.Join(tmpDir, "state.json"), + HealthURL: healthServer.URL, + TargetVersion: "v2.0.0", + ServiceStopTimeout: 100 * time.Millisecond, + ServiceStartTimeout: 100 * time.Millisecond, + HealthCheckRetries: 1, + HealthCheckInterval: 10 * time.Millisecond, + HealthInitialWait: 1 * time.Millisecond, + SleepFunc: func(d time.Duration) {}, + }) + u.serviceController = &mockServiceController{} + + // Cancel during verification + go func() { + time.Sleep(50 * time.Millisecond) + cancel() + }() + + err := u.Run(ctx) + + // Should return context.Canceled, NOT trigger rollback + assert.ErrorIs(t, err, context.Canceled) + + // Binary should NOT be rolled back (service is running with new binary) + content, err := os.ReadFile(binaryPath) + require.NoError(t, err) + assert.Equal(t, []byte("new binary"), content) +} + +func TestUpdater_Run_NoDoubleUnlock(t *testing.T) { 
+ tmpDir := t.TempDir() + + binaryPath := filepath.Join(tmpDir, "hostlink") + createTestBinary(t, binaryPath, []byte("binary")) + + // Pre-cancelled context + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + u := NewUpdater(&UpdaterConfig{ + AgentBinaryPath: binaryPath, + BackupDir: filepath.Join(tmpDir, "backup"), + StagingDir: filepath.Join(tmpDir, "staging"), + LockPath: filepath.Join(tmpDir, "update.lock"), + StatePath: filepath.Join(tmpDir, "state.json"), + HealthURL: "http://localhost:8080/health", + TargetVersion: "v2.0.0", + ServiceStopTimeout: 100 * time.Millisecond, + ServiceStartTimeout: 100 * time.Millisecond, + HealthInitialWait: 1 * time.Millisecond, + SleepFunc: func(d time.Duration) {}, + }) + u.serviceController = &mockServiceController{} + + // This should not panic - only Run() unlocks via defer + err := u.Run(ctx) + assert.ErrorIs(t, err, context.Canceled) + + // Lock file should not exist (properly unlocked once) + _, statErr := os.Stat(filepath.Join(tmpDir, "update.lock")) + assert.True(t, os.IsNotExist(statErr), "lock should be released") +} + +// Mock service controller for testing +type mockServiceController struct { + stopCalled bool + startCalled bool + stopErr error + startErr error + onStop func() + onStart func() +} + +func (m *mockServiceController) Stop(ctx context.Context) error { + m.stopCalled = true + if m.onStop != nil { + m.onStop() + } + return m.stopErr +} + +func (m *mockServiceController) Start(ctx context.Context) error { + m.startCalled = true + if m.onStart != nil { + m.onStart() + } + return m.startErr +} + +func newGzipWriter(w *os.File) *gzip.Writer { + return gzip.NewWriter(w) +} + +func newTarWriter(gw *gzip.Writer) *tar.Writer { + return tar.NewWriter(gw) +} + +type tarHeader = tar.Header diff --git a/config/appconf/appconf.go b/config/appconf/appconf.go index 902c9a2..ac6b061 100644 --- a/config/appconf/appconf.go +++ b/config/appconf/appconf.go @@ -2,10 +2,15 @@ package appconf import ( + 
"os" + "strings" + "time" + + log "github.com/sirupsen/logrus" + "hostlink/config" devconf "hostlink/config/environments/development" prodconf "hostlink/config/environments/production" - "os" ) var appconf config.AppConfiger @@ -51,6 +56,65 @@ func AgentStatePath() string { return "/var/lib/hostlink" } +// SelfUpdateEnabled returns whether the self-update feature is enabled. +// Controlled by HOSTLINK_SELF_UPDATE_ENABLED (default: true). +func SelfUpdateEnabled() bool { + v := strings.TrimSpace(os.Getenv("HOSTLINK_SELF_UPDATE_ENABLED")) + if v == "" { + return true + } + switch strings.ToLower(v) { + case "false", "0", "no": + return false + default: + return true + } +} + +// UpdateCheckInterval returns the interval between update checks. +// Controlled by HOSTLINK_UPDATE_CHECK_INTERVAL (default: 1h, clamped to [1m, 24h]). +func UpdateCheckInterval() time.Duration { + const ( + defaultInterval = 1 * time.Hour + minInterval = 1 * time.Minute + maxInterval = 24 * time.Hour + ) + return parseDurationClamped("HOSTLINK_UPDATE_CHECK_INTERVAL", defaultInterval, minInterval, maxInterval) +} + +// UpdateLockTimeout returns the lock expiration duration for self-updates. +// Controlled by HOSTLINK_UPDATE_LOCK_TIMEOUT (default: 5m, clamped to [1m, 30m]). +func UpdateLockTimeout() time.Duration { + const ( + defaultTimeout = 5 * time.Minute + minTimeout = 1 * time.Minute + maxTimeout = 30 * time.Minute + ) + return parseDurationClamped("HOSTLINK_UPDATE_LOCK_TIMEOUT", defaultTimeout, minTimeout, maxTimeout) +} + +// parseDurationClamped reads a duration from an environment variable, clamping +// it to [min, max]. Returns defaultVal if the env var is empty or unparseable. 
+func parseDurationClamped(envVar string, defaultVal, min, max time.Duration) time.Duration { + v := strings.TrimSpace(os.Getenv(envVar)) + if v == "" { + return defaultVal + } + d, err := time.ParseDuration(v) + if err != nil { + log.Warnf("invalid %s value %q, using default %s", envVar, v, defaultVal) + return defaultVal + } + if d < min { + log.Warnf("%s value %s below minimum %s, clamping to %s", envVar, d, min, min) + return min + } + if d > max { + log.Warnf("%s value %s above maximum %s, clamping to %s", envVar, d, max, max) + return max + } + return d +} func init() { env := os.Getenv("APP_ENV") diff --git a/config/appconf/appconf_test.go b/config/appconf/appconf_test.go new file mode 100644 index 0000000..c6d3279 --- /dev/null +++ b/config/appconf/appconf_test.go @@ -0,0 +1,78 @@ +package appconf + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestSelfUpdateEnabled_DefaultTrue(t *testing.T) { + t.Setenv("HOSTLINK_SELF_UPDATE_ENABLED", "") + assert.True(t, SelfUpdateEnabled()) +} + +func TestSelfUpdateEnabled_ExplicitTrue(t *testing.T) { + t.Setenv("HOSTLINK_SELF_UPDATE_ENABLED", "true") + assert.True(t, SelfUpdateEnabled()) +} + +func TestSelfUpdateEnabled_ExplicitFalse(t *testing.T) { + t.Setenv("HOSTLINK_SELF_UPDATE_ENABLED", "false") + assert.False(t, SelfUpdateEnabled()) +} + +func TestSelfUpdateEnabled_InvalidFallsToDefault(t *testing.T) { + t.Setenv("HOSTLINK_SELF_UPDATE_ENABLED", "garbage") + assert.True(t, SelfUpdateEnabled()) +} + +func TestUpdateCheckInterval_Default1h(t *testing.T) { + t.Setenv("HOSTLINK_UPDATE_CHECK_INTERVAL", "") + assert.Equal(t, 1*time.Hour, UpdateCheckInterval()) +} + +func TestUpdateCheckInterval_CustomValue(t *testing.T) { + t.Setenv("HOSTLINK_UPDATE_CHECK_INTERVAL", "30m") + assert.Equal(t, 30*time.Minute, UpdateCheckInterval()) +} + +func TestUpdateCheckInterval_ClampedToMin(t *testing.T) { + t.Setenv("HOSTLINK_UPDATE_CHECK_INTERVAL", "10s") + assert.Equal(t, 1*time.Minute, 
UpdateCheckInterval()) +} + +func TestUpdateCheckInterval_ClampedToMax(t *testing.T) { + t.Setenv("HOSTLINK_UPDATE_CHECK_INTERVAL", "48h") + assert.Equal(t, 24*time.Hour, UpdateCheckInterval()) +} + +func TestUpdateCheckInterval_InvalidFallsToDefault(t *testing.T) { + t.Setenv("HOSTLINK_UPDATE_CHECK_INTERVAL", "garbage") + assert.Equal(t, 1*time.Hour, UpdateCheckInterval()) +} + +func TestUpdateLockTimeout_Default5m(t *testing.T) { + t.Setenv("HOSTLINK_UPDATE_LOCK_TIMEOUT", "") + assert.Equal(t, 5*time.Minute, UpdateLockTimeout()) +} + +func TestUpdateLockTimeout_CustomValue(t *testing.T) { + t.Setenv("HOSTLINK_UPDATE_LOCK_TIMEOUT", "10m") + assert.Equal(t, 10*time.Minute, UpdateLockTimeout()) +} + +func TestUpdateLockTimeout_ClampedToMin(t *testing.T) { + t.Setenv("HOSTLINK_UPDATE_LOCK_TIMEOUT", "10s") + assert.Equal(t, 1*time.Minute, UpdateLockTimeout()) +} + +func TestUpdateLockTimeout_ClampedToMax(t *testing.T) { + t.Setenv("HOSTLINK_UPDATE_LOCK_TIMEOUT", "2h") + assert.Equal(t, 30*time.Minute, UpdateLockTimeout()) +} + +func TestUpdateLockTimeout_InvalidFallsToDefault(t *testing.T) { + t.Setenv("HOSTLINK_UPDATE_LOCK_TIMEOUT", "garbage") + assert.Equal(t, 5*time.Minute, UpdateLockTimeout()) +} diff --git a/internal/update/binary.go b/internal/update/binary.go new file mode 100644 index 0000000..7728f17 --- /dev/null +++ b/internal/update/binary.go @@ -0,0 +1,246 @@ +package update + +import ( + "archive/tar" + "compress/gzip" + "crypto/rand" + "encoding/hex" + "fmt" + "io" + "os" + "path/filepath" +) + +const ( + // BinaryPermissions is the file permission for installed binaries. + BinaryPermissions = 0755 + // BackupFilename is the name of the backup file in the backup directory. + BackupFilename = "hostlink" + // AgentBinaryName is the binary name inside the agent tarball. + AgentBinaryName = "hostlink" + // UpdaterBinaryName is the binary name inside the updater tarball. 
+ UpdaterBinaryName = "hostlink-updater" + // MaxBinarySize is the maximum allowed size for an extracted binary (100MB). + MaxBinarySize = 100 * 1024 * 1024 +) + +// BackupBinary copies the binary at srcPath to the backup directory. +// It creates the backup directory if it doesn't exist. +// It overwrites any existing backup. +func BackupBinary(srcPath, backupDir string) error { + // Open source file + src, err := os.Open(srcPath) + if err != nil { + return err + } + defer src.Close() + + // Get source file info for permissions + srcInfo, err := src.Stat() + if err != nil { + return fmt.Errorf("failed to stat source file: %w", err) + } + + // Create backup directory + if err := os.MkdirAll(backupDir, DirPermissions); err != nil { + return fmt.Errorf("failed to create backup directory: %w", err) + } + + // Create backup file + backupPath := filepath.Join(backupDir, BackupFilename) + dst, err := os.OpenFile(backupPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, srcInfo.Mode().Perm()) + if err != nil { + return fmt.Errorf("failed to create backup file: %w", err) + } + defer dst.Close() + + // Copy content + if _, err := io.Copy(dst, src); err != nil { + return fmt.Errorf("failed to copy to backup: %w", err) + } + + return nil +} + +// InstallBinary extracts the binary from a tar.gz file and installs it atomically. +// It extracts the file named "hostlink" from the tarball to destPath. +// Uses atomic rename to ensure the install is atomic. +func InstallBinary(tarPath, destPath string) error { + return installBinaryFromTarGz(tarPath, AgentBinaryName, destPath) +} + +// InstallUpdaterBinary extracts the updater binary from a tar.gz file and installs it atomically. +// It extracts the file named "hostlink-updater" from the tarball to destPath. +// Uses atomic rename to ensure the install is atomic. 
+func InstallUpdaterBinary(tarPath, destPath string) error { + return installBinaryFromTarGz(tarPath, UpdaterBinaryName, destPath) +} + +// installBinaryFromTarGz extracts a named binary from a tar.gz and installs it atomically. +func installBinaryFromTarGz(tarPath, binaryName, destPath string) error { + // Create destination directory + destDir := filepath.Dir(destPath) + if err := os.MkdirAll(destDir, 0755); err != nil { + return fmt.Errorf("failed to create destination directory: %w", err) + } + + // Generate temp file path + randSuffix, err := randomHex(8) + if err != nil { + return fmt.Errorf("failed to generate random suffix: %w", err) + } + tmpPath := destPath + ".tmp." + randSuffix + + // Clean up temp file on error + defer func() { + if tmpPath != "" { + os.Remove(tmpPath) + } + }() + + // Extract binary to temp path + if err := extractBinaryFromTarGz(tarPath, binaryName, tmpPath); err != nil { + return fmt.Errorf("failed to extract binary: %w", err) + } + + // Set permissions + if err := os.Chmod(tmpPath, BinaryPermissions); err != nil { + return fmt.Errorf("failed to set permissions: %w", err) + } + + // Atomic rename + if err := os.Rename(tmpPath, destPath); err != nil { + return fmt.Errorf("failed to install binary: %w", err) + } + + // Success - don't clean up the temp file (it's been renamed) + tmpPath = "" + return nil +} + +// RestoreBackup restores the binary from backup to destPath. +// Uses atomic rename for safe restoration. 
+func RestoreBackup(backupDir, destPath string) error { + backupPath := filepath.Join(backupDir, BackupFilename) + + // Check backup exists + srcInfo, err := os.Stat(backupPath) + if err != nil { + return fmt.Errorf("failed to stat backup: %w", err) + } + + // Open backup file + src, err := os.Open(backupPath) + if err != nil { + return fmt.Errorf("failed to open backup: %w", err) + } + defer src.Close() + + // Generate temp file path + randSuffix, err := randomHex(8) + if err != nil { + return fmt.Errorf("failed to generate random suffix: %w", err) + } + tmpPath := destPath + ".tmp." + randSuffix + + // Clean up temp file on error + defer func() { + if tmpPath != "" { + os.Remove(tmpPath) + } + }() + + // Create temp file + dst, err := os.OpenFile(tmpPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, srcInfo.Mode().Perm()) + if err != nil { + return fmt.Errorf("failed to create temp file: %w", err) + } + + // Copy content + if _, err := io.Copy(dst, src); err != nil { + dst.Close() + return fmt.Errorf("failed to copy backup: %w", err) + } + + // Close before rename + if err := dst.Close(); err != nil { + return fmt.Errorf("failed to close temp file: %w", err) + } + + // Atomic rename + if err := os.Rename(tmpPath, destPath); err != nil { + return fmt.Errorf("failed to restore backup: %w", err) + } + + // Success + tmpPath = "" + return nil +} + +// extractBinaryFromTarGz extracts the named binary from a tar.gz file. 
+func extractBinaryFromTarGz(tarPath, binaryName, destPath string) error { + // Open tarball + f, err := os.Open(tarPath) + if err != nil { + return err + } + defer f.Close() + + // Create gzip reader + gr, err := gzip.NewReader(f) + if err != nil { + return fmt.Errorf("failed to create gzip reader: %w", err) + } + defer gr.Close() + + // Create tar reader + tr := tar.NewReader(gr) + + // Find and extract the named binary + for { + header, err := tr.Next() + if err == io.EOF { + return fmt.Errorf("%s binary not found in tarball", binaryName) + } + if err != nil { + return fmt.Errorf("failed to read tar header: %w", err) + } + + // Look for the named file (might be at root or in a subdirectory) + baseName := filepath.Base(header.Name) + if baseName == binaryName && header.Typeflag == tar.TypeReg { + // Reject if declared size exceeds maximum + if header.Size > MaxBinarySize { + return fmt.Errorf("binary size %d exceeds maximum allowed size %d", header.Size, MaxBinarySize) + } + + // Create destination file + dst, err := os.OpenFile(destPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.FileMode(header.Mode)) + if err != nil { + return fmt.Errorf("failed to create destination file: %w", err) + } + defer dst.Close() + + // Copy content with size limit as safety net (header.Size could lie) + limited := io.LimitReader(tr, MaxBinarySize+1) + n, err := io.Copy(dst, limited) + if err != nil { + return fmt.Errorf("failed to extract binary: %w", err) + } + if n > MaxBinarySize { + return fmt.Errorf("binary size %d exceeds maximum allowed size %d", n, MaxBinarySize) + } + + return nil + } + } +} + +// randomHex generates a random hex string of the given length. 
+func randomHex(n int) (string, error) { + bytes := make([]byte, n) + if _, err := rand.Read(bytes); err != nil { + return "", err + } + return hex.EncodeToString(bytes), nil +} diff --git a/internal/update/binary_test.go b/internal/update/binary_test.go new file mode 100644 index 0000000..212e1c9 --- /dev/null +++ b/internal/update/binary_test.go @@ -0,0 +1,390 @@ +package update + +import ( + "archive/tar" + "compress/gzip" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBackupBinary_CopiesFile(t *testing.T) { + tmpDir := t.TempDir() + + // Create a source binary + srcPath := filepath.Join(tmpDir, "hostlink") + srcContent := []byte("binary content v1.0.0") + err := os.WriteFile(srcPath, srcContent, 0755) + require.NoError(t, err) + + // Backup to a subdirectory + backupDir := filepath.Join(tmpDir, "backup") + + err = BackupBinary(srcPath, backupDir) + require.NoError(t, err) + + // Verify backup exists with same content + backupPath := filepath.Join(backupDir, "hostlink") + backupContent, err := os.ReadFile(backupPath) + require.NoError(t, err) + assert.Equal(t, srcContent, backupContent) +} + +func TestBackupBinary_CreatesBackupDirectory(t *testing.T) { + tmpDir := t.TempDir() + + srcPath := filepath.Join(tmpDir, "hostlink") + err := os.WriteFile(srcPath, []byte("content"), 0755) + require.NoError(t, err) + + // Nested backup directory that doesn't exist + backupDir := filepath.Join(tmpDir, "deep", "nested", "backup") + + err = BackupBinary(srcPath, backupDir) + require.NoError(t, err) + + // Verify directory was created + info, err := os.Stat(backupDir) + require.NoError(t, err) + assert.True(t, info.IsDir()) +} + +func TestBackupBinary_OverwritesExistingBackup(t *testing.T) { + tmpDir := t.TempDir() + backupDir := filepath.Join(tmpDir, "backup") + err := os.MkdirAll(backupDir, 0755) + require.NoError(t, err) + + // Create existing backup + backupPath := 
filepath.Join(backupDir, "hostlink") + err = os.WriteFile(backupPath, []byte("old backup"), 0755) + require.NoError(t, err) + + // Create new source + srcPath := filepath.Join(tmpDir, "hostlink") + newContent := []byte("new binary content") + err = os.WriteFile(srcPath, newContent, 0755) + require.NoError(t, err) + + err = BackupBinary(srcPath, backupDir) + require.NoError(t, err) + + // Verify backup was overwritten + backupContent, err := os.ReadFile(backupPath) + require.NoError(t, err) + assert.Equal(t, newContent, backupContent) +} + +func TestBackupBinary_PreservesPermissions(t *testing.T) { + tmpDir := t.TempDir() + + srcPath := filepath.Join(tmpDir, "hostlink") + err := os.WriteFile(srcPath, []byte("content"), 0755) + require.NoError(t, err) + + backupDir := filepath.Join(tmpDir, "backup") + + err = BackupBinary(srcPath, backupDir) + require.NoError(t, err) + + // Verify permissions + backupPath := filepath.Join(backupDir, "hostlink") + info, err := os.Stat(backupPath) + require.NoError(t, err) + assert.Equal(t, os.FileMode(0755), info.Mode().Perm()) +} + +func TestBackupBinary_ReturnsErrorIfSourceMissing(t *testing.T) { + tmpDir := t.TempDir() + + srcPath := filepath.Join(tmpDir, "nonexistent") + backupDir := filepath.Join(tmpDir, "backup") + + err := BackupBinary(srcPath, backupDir) + assert.Error(t, err) + assert.True(t, os.IsNotExist(err) || os.IsNotExist(unwrapErr(err))) +} + +func TestInstallBinary_ExtractsAndInstalls(t *testing.T) { + tmpDir := t.TempDir() + + // Create a tarball with a binary + tarPath := filepath.Join(tmpDir, "hostlink.tar.gz") + binaryContent := []byte("new binary v2.0.0") + createTestTarGz(t, tarPath, "hostlink", binaryContent, 0755) + + destPath := filepath.Join(tmpDir, "installed", "hostlink") + + err := InstallBinary(tarPath, destPath) + require.NoError(t, err) + + // Verify installed binary + installedContent, err := os.ReadFile(destPath) + require.NoError(t, err) + assert.Equal(t, binaryContent, installedContent) +} + +func 
TestInstallBinary_SetsPermissions(t *testing.T) { + tmpDir := t.TempDir() + + tarPath := filepath.Join(tmpDir, "hostlink.tar.gz") + createTestTarGz(t, tarPath, "hostlink", []byte("binary"), 0755) + + destPath := filepath.Join(tmpDir, "hostlink") + + err := InstallBinary(tarPath, destPath) + require.NoError(t, err) + + info, err := os.Stat(destPath) + require.NoError(t, err) + assert.Equal(t, os.FileMode(0755), info.Mode().Perm()) +} + +func TestInstallBinary_AtomicRename(t *testing.T) { + tmpDir := t.TempDir() + + // Create existing binary + destPath := filepath.Join(tmpDir, "hostlink") + err := os.WriteFile(destPath, []byte("old binary"), 0755) + require.NoError(t, err) + + // Create tarball with new binary + tarPath := filepath.Join(tmpDir, "hostlink.tar.gz") + newContent := []byte("new binary") + createTestTarGz(t, tarPath, "hostlink", newContent, 0755) + + err = InstallBinary(tarPath, destPath) + require.NoError(t, err) + + // Verify new binary is in place + installedContent, err := os.ReadFile(destPath) + require.NoError(t, err) + assert.Equal(t, newContent, installedContent) + + // Verify no temp files left behind + entries, err := os.ReadDir(tmpDir) + require.NoError(t, err) + for _, entry := range entries { + assert.NotContains(t, entry.Name(), ".tmp.", "temp file should be cleaned up") + } +} + +func TestInstallBinary_CleansUpTempOnError(t *testing.T) { + tmpDir := t.TempDir() + + // Create an invalid tarball + tarPath := filepath.Join(tmpDir, "invalid.tar.gz") + err := os.WriteFile(tarPath, []byte("not a tarball"), 0644) + require.NoError(t, err) + + destPath := filepath.Join(tmpDir, "hostlink") + + err = InstallBinary(tarPath, destPath) + assert.Error(t, err) + + // Verify no temp files left behind + entries, err := os.ReadDir(tmpDir) + require.NoError(t, err) + for _, entry := range entries { + assert.NotContains(t, entry.Name(), ".tmp.", "temp file should be cleaned up on error") + } +} + +func TestInstallBinary_CreatesDestinationDirectory(t 
*testing.T) { + tmpDir := t.TempDir() + + tarPath := filepath.Join(tmpDir, "hostlink.tar.gz") + createTestTarGz(t, tarPath, "hostlink", []byte("binary"), 0755) + + // Nested destination that doesn't exist + destPath := filepath.Join(tmpDir, "usr", "bin", "hostlink") + + err := InstallBinary(tarPath, destPath) + require.NoError(t, err) + + // Verify file exists + _, err = os.Stat(destPath) + require.NoError(t, err) +} + +func TestRestoreBackup_RestoresFile(t *testing.T) { + tmpDir := t.TempDir() + backupDir := filepath.Join(tmpDir, "backup") + err := os.MkdirAll(backupDir, 0755) + require.NoError(t, err) + + // Create backup + backupContent := []byte("backup binary v1.0.0") + backupPath := filepath.Join(backupDir, "hostlink") + err = os.WriteFile(backupPath, backupContent, 0755) + require.NoError(t, err) + + // Create destination with different content + destPath := filepath.Join(tmpDir, "hostlink") + err = os.WriteFile(destPath, []byte("broken binary"), 0755) + require.NoError(t, err) + + err = RestoreBackup(backupDir, destPath) + require.NoError(t, err) + + // Verify restored content + restoredContent, err := os.ReadFile(destPath) + require.NoError(t, err) + assert.Equal(t, backupContent, restoredContent) +} + +func TestRestoreBackup_ReturnsErrorIfBackupMissing(t *testing.T) { + tmpDir := t.TempDir() + + backupDir := filepath.Join(tmpDir, "backup") // Doesn't exist + destPath := filepath.Join(tmpDir, "hostlink") + + err := RestoreBackup(backupDir, destPath) + assert.Error(t, err) +} + +func TestRestoreBackup_AtomicReplace(t *testing.T) { + tmpDir := t.TempDir() + backupDir := filepath.Join(tmpDir, "backup") + err := os.MkdirAll(backupDir, 0755) + require.NoError(t, err) + + backupPath := filepath.Join(backupDir, "hostlink") + err = os.WriteFile(backupPath, []byte("backup"), 0755) + require.NoError(t, err) + + destPath := filepath.Join(tmpDir, "hostlink") + err = os.WriteFile(destPath, []byte("current"), 0755) + require.NoError(t, err) + + err = 
RestoreBackup(backupDir, destPath) + require.NoError(t, err) + + // Verify no temp files left + entries, err := os.ReadDir(tmpDir) + require.NoError(t, err) + for _, entry := range entries { + if entry.Name() != "backup" && entry.Name() != "hostlink" { + t.Errorf("unexpected file left behind: %s", entry.Name()) + } + } +} + +func TestInstallBinary_RejectsBinaryExceedingMaxSize(t *testing.T) { + tmpDir := t.TempDir() + + // Create a tarball with header.Size just over MaxBinarySize (100MB + 1 byte) + tarPath := filepath.Join(tmpDir, "hostlink.tar.gz") + var oversizeBytes int64 = MaxBinarySize + 1 + + f, err := os.Create(tarPath) + require.NoError(t, err) + + gw := gzip.NewWriter(f) + tw := tar.NewWriter(gw) + + // Write header claiming a size of 100MB+1 + err = tw.WriteHeader(&tar.Header{ + Name: "hostlink", + Mode: 0755, + Size: oversizeBytes, + Typeflag: tar.TypeReg, + }) + require.NoError(t, err) + + // Write just a small amount of actual data (header lies about size, but + // the check should reject based on header.Size before copying) + smallData := make([]byte, 1024) + _, err = tw.Write(smallData) + // tar writer may error because we declared more bytes than we wrote - that's fine + // Close writers regardless + tw.Close() + gw.Close() + f.Close() + + destPath := filepath.Join(tmpDir, "hostlink") + + err = InstallBinary(tarPath, destPath) + assert.Error(t, err) + assert.Contains(t, err.Error(), "exceeds maximum allowed size") +} + +func TestInstallUpdaterBinary_ExtractsHostlinkUpdater(t *testing.T) { + tmpDir := t.TempDir() + + // Create a tarball containing "hostlink-updater" + tarPath := filepath.Join(tmpDir, "updater.tar.gz") + binaryContent := []byte("updater binary v2.0.0") + createTestTarGz(t, tarPath, "hostlink-updater", binaryContent, 0755) + + destPath := filepath.Join(tmpDir, "installed", "hostlink-updater") + + err := InstallUpdaterBinary(tarPath, destPath) + require.NoError(t, err) + + // Verify installed binary content + installedContent, err := 
os.ReadFile(destPath) + require.NoError(t, err) + assert.Equal(t, binaryContent, installedContent) + + // Verify permissions + info, err := os.Stat(destPath) + require.NoError(t, err) + assert.Equal(t, os.FileMode(0755), info.Mode().Perm()) +} + +func TestInstallUpdaterBinary_IgnoresHostlinkBinary(t *testing.T) { + tmpDir := t.TempDir() + + // Create a tarball containing "hostlink" (NOT "hostlink-updater") + tarPath := filepath.Join(tmpDir, "updater.tar.gz") + createTestTarGz(t, tarPath, "hostlink", []byte("wrong binary"), 0755) + + destPath := filepath.Join(tmpDir, "hostlink-updater") + + err := InstallUpdaterBinary(tarPath, destPath) + assert.Error(t, err) + assert.Contains(t, err.Error(), "not found in tarball") +} + +// Helper function to unwrap errors +func unwrapErr(err error) error { + for { + unwrapped := err + if u, ok := err.(interface{ Unwrap() error }); ok { + unwrapped = u.Unwrap() + } + if unwrapped == err { + return err + } + err = unwrapped + } +} + +// createTestTarGz creates a tar.gz file containing a single file +func createTestTarGz(t *testing.T, tarPath, filename string, content []byte, mode os.FileMode) { + t.Helper() + + f, err := os.Create(tarPath) + require.NoError(t, err) + defer f.Close() + + gw := gzip.NewWriter(f) + defer gw.Close() + + tw := tar.NewWriter(gw) + defer tw.Close() + + err = tw.WriteHeader(&tar.Header{ + Name: filename, + Mode: int64(mode), + Size: int64(len(content)), + }) + require.NoError(t, err) + + _, err = tw.Write(content) + require.NoError(t, err) +} diff --git a/internal/update/dirs.go b/internal/update/dirs.go new file mode 100644 index 0000000..db8129f --- /dev/null +++ b/internal/update/dirs.go @@ -0,0 +1,69 @@ +package update + +import ( + "fmt" + "os" + "path/filepath" +) + +const ( + // DefaultBaseDir is the default base directory for update files. + DefaultBaseDir = "/var/lib/hostlink/updates" + + // DirPermissions is the permission mode for update directories (owner rwx only). 
+ DirPermissions = 0700 +) + +// Paths holds all the paths used by the update system. +type Paths struct { + BaseDir string // /var/lib/hostlink/updates + BackupDir string // /var/lib/hostlink/updates/backup + StagingDir string // /var/lib/hostlink/updates/staging + UpdaterDir string // /var/lib/hostlink/updates/updater + LockFile string // /var/lib/hostlink/updates/update.lock + StateFile string // /var/lib/hostlink/updates/state.json +} + +// DefaultPaths returns the default paths for the update system. +func DefaultPaths() Paths { + return NewPaths(DefaultBaseDir) +} + +// NewPaths creates a Paths struct with the given base directory. +func NewPaths(baseDir string) Paths { + return Paths{ + BaseDir: baseDir, + BackupDir: filepath.Join(baseDir, "backup"), + StagingDir: filepath.Join(baseDir, "staging"), + UpdaterDir: filepath.Join(baseDir, "updater"), + LockFile: filepath.Join(baseDir, "update.lock"), + StateFile: filepath.Join(baseDir, "state.json"), + } +} + +// InitDirectories creates all required directories for the update system +// with correct permissions (0700). This function is idempotent. 
+func InitDirectories(baseDir string) error { + paths := NewPaths(baseDir) + + dirs := []string{ + paths.BaseDir, + paths.BackupDir, + paths.StagingDir, + paths.UpdaterDir, + } + + for _, dir := range dirs { + if err := os.MkdirAll(dir, DirPermissions); err != nil { + return fmt.Errorf("failed to create directory %s: %w", dir, err) + } + + // Ensure permissions are correct even if directory already exists + // (MkdirAll doesn't change permissions of existing directories) + if err := os.Chmod(dir, DirPermissions); err != nil { + return fmt.Errorf("failed to set permissions on %s: %w", dir, err) + } + } + + return nil +} diff --git a/internal/update/dirs_test.go b/internal/update/dirs_test.go new file mode 100644 index 0000000..c834a27 --- /dev/null +++ b/internal/update/dirs_test.go @@ -0,0 +1,127 @@ +package update + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestInitDirectories_CreatesAllDirs(t *testing.T) { + tmpDir := t.TempDir() + basePath := filepath.Join(tmpDir, "updates") + + err := InitDirectories(basePath) + require.NoError(t, err) + + // Check all required directories exist + expectedDirs := []string{ + basePath, + filepath.Join(basePath, "backup"), + filepath.Join(basePath, "staging"), + filepath.Join(basePath, "updater"), + } + + for _, dir := range expectedDirs { + info, err := os.Stat(dir) + require.NoError(t, err, "directory should exist: %s", dir) + assert.True(t, info.IsDir(), "%s should be a directory", dir) + } +} + +func TestInitDirectories_CorrectPermissions(t *testing.T) { + tmpDir := t.TempDir() + basePath := filepath.Join(tmpDir, "updates") + + err := InitDirectories(basePath) + require.NoError(t, err) + + // Check permissions are 0700 + dirs := []string{ + basePath, + filepath.Join(basePath, "backup"), + filepath.Join(basePath, "staging"), + filepath.Join(basePath, "updater"), + } + + for _, dir := range dirs { + info, err := os.Stat(dir) + 
require.NoError(t, err) + assert.Equal(t, os.FileMode(0700), info.Mode().Perm(), "directory %s should have 0700 permissions", dir) + } +} + +func TestInitDirectories_Idempotent(t *testing.T) { + tmpDir := t.TempDir() + basePath := filepath.Join(tmpDir, "updates") + + // Call twice - should not error + err := InitDirectories(basePath) + require.NoError(t, err) + + err = InitDirectories(basePath) + require.NoError(t, err) + + // Directories should still exist with correct permissions + info, err := os.Stat(basePath) + require.NoError(t, err) + assert.Equal(t, os.FileMode(0700), info.Mode().Perm()) +} + +func TestInitDirectories_CreatesNestedPath(t *testing.T) { + tmpDir := t.TempDir() + basePath := filepath.Join(tmpDir, "var", "lib", "hostlink", "updates") + + err := InitDirectories(basePath) + require.NoError(t, err) + + // All directories should exist + _, err = os.Stat(filepath.Join(basePath, "backup")) + assert.NoError(t, err) +} + +func TestInitDirectories_PreservesExistingFiles(t *testing.T) { + tmpDir := t.TempDir() + basePath := filepath.Join(tmpDir, "updates") + + // Create base dir and a file inside + err := os.MkdirAll(basePath, 0700) + require.NoError(t, err) + + testFile := filepath.Join(basePath, "state.json") + err = os.WriteFile(testFile, []byte(`{"test": true}`), 0600) + require.NoError(t, err) + + // Init should not remove existing files + err = InitDirectories(basePath) + require.NoError(t, err) + + // File should still exist + content, err := os.ReadFile(testFile) + require.NoError(t, err) + assert.Equal(t, `{"test": true}`, string(content)) +} + +func TestDefaultPaths(t *testing.T) { + paths := DefaultPaths() + + assert.Equal(t, "/var/lib/hostlink/updates", paths.BaseDir) + assert.Equal(t, "/var/lib/hostlink/updates/backup", paths.BackupDir) + assert.Equal(t, "/var/lib/hostlink/updates/staging", paths.StagingDir) + assert.Equal(t, "/var/lib/hostlink/updates/updater", paths.UpdaterDir) + assert.Equal(t, "/var/lib/hostlink/updates/update.lock", 
paths.LockFile) + assert.Equal(t, "/var/lib/hostlink/updates/state.json", paths.StateFile) +} + +func TestNewPaths(t *testing.T) { + paths := NewPaths("/custom/path") + + assert.Equal(t, "/custom/path", paths.BaseDir) + assert.Equal(t, "/custom/path/backup", paths.BackupDir) + assert.Equal(t, "/custom/path/staging", paths.StagingDir) + assert.Equal(t, "/custom/path/updater", paths.UpdaterDir) + assert.Equal(t, "/custom/path/update.lock", paths.LockFile) + assert.Equal(t, "/custom/path/state.json", paths.StateFile) +} diff --git a/internal/update/health.go b/internal/update/health.go new file mode 100644 index 0000000..982d05e --- /dev/null +++ b/internal/update/health.go @@ -0,0 +1,147 @@ +package update + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "time" +) + +var ( + // ErrHealthCheckFailed is returned when health check fails after all retries. + ErrHealthCheckFailed = errors.New("health check failed after retries") + // ErrVersionMismatch is returned when the running version doesn't match expected. + ErrVersionMismatch = errors.New("version mismatch") +) + +const ( + // DefaultHealthRetries is the default number of health check retries. + DefaultHealthRetries = 5 + // DefaultHealthInterval is the default interval between health checks. + DefaultHealthInterval = 5 * time.Second + // DefaultInitialWait is the default wait before first health check. + DefaultInitialWait = 5 * time.Second +) + +// HealthResponse represents the response from the health endpoint. +type HealthResponse struct { + Ok bool `json:"ok"` + Version string `json:"version"` +} + +// HealthConfig configures the HealthChecker. 
+type HealthConfig struct { + URL string // Health check URL (e.g., http://localhost:8080/health) + TargetVersion string // Expected version after update + MaxRetries int // Max number of retries (default: 5) + RetryInterval time.Duration // Time between retries (default: 5s) + InitialWait time.Duration // Initial wait before first check (default: 5s) + SleepFunc func(time.Duration) // For testing + HTTPClient *http.Client // Optional custom HTTP client +} + +// HealthChecker verifies that the service is healthy after an update. +type HealthChecker struct { + config HealthConfig + client *http.Client +} + +// NewHealthChecker creates a new HealthChecker with the given configuration. +func NewHealthChecker(cfg HealthConfig) *HealthChecker { + // Apply defaults + if cfg.MaxRetries == 0 { + cfg.MaxRetries = DefaultHealthRetries + } + if cfg.RetryInterval == 0 { + cfg.RetryInterval = DefaultHealthInterval + } + if cfg.InitialWait == 0 { + cfg.InitialWait = DefaultInitialWait + } + if cfg.SleepFunc == nil { + cfg.SleepFunc = time.Sleep + } + + client := cfg.HTTPClient + if client == nil { + client = &http.Client{ + Timeout: 10 * time.Second, + } + } + + return &HealthChecker{ + config: cfg, + client: client, + } +} + +// WaitForHealth waits for the service to be healthy with the expected version. +// It performs an initial wait, then retries up to MaxRetries times. 
+func (h *HealthChecker) WaitForHealth(ctx context.Context) error { + // Initial wait before first check + if h.config.InitialWait > 0 { + h.config.SleepFunc(h.config.InitialWait) + if ctx.Err() != nil { + return ctx.Err() + } + } + + var lastErr error + // Initial attempt + retries + totalAttempts := h.config.MaxRetries + 1 + + for attempt := 0; attempt < totalAttempts; attempt++ { + if ctx.Err() != nil { + return ctx.Err() + } + + healthy, version, err := h.checkHealth(ctx) + switch { + case err != nil: + lastErr = err + case !healthy: + lastErr = errors.New("health check returned ok: false") + case version != h.config.TargetVersion: + lastErr = fmt.Errorf("%w: expected %s, got %s", ErrVersionMismatch, h.config.TargetVersion, version) + default: + return nil + } + + if attempt < totalAttempts-1 { + h.config.SleepFunc(h.config.RetryInterval) + if ctx.Err() != nil { + return ctx.Err() + } + } + } + + return fmt.Errorf("%w: %v", ErrHealthCheckFailed, lastErr) +} + +// checkHealth performs a single health check request. +// Returns (healthy, version, error). 
+func (h *HealthChecker) checkHealth(ctx context.Context) (bool, string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, h.config.URL, nil) + if err != nil { + return false, "", fmt.Errorf("failed to create request: %w", err) + } + + resp, err := h.client.Do(req) + if err != nil { + return false, "", err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return false, "", fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + var healthResp HealthResponse + if err := json.NewDecoder(resp.Body).Decode(&healthResp); err != nil { + return false, "", fmt.Errorf("failed to decode response: %w", err) + } + + return healthResp.Ok, healthResp.Version, nil +} diff --git a/internal/update/health_test.go b/internal/update/health_test.go new file mode 100644 index 0000000..e8bb412 --- /dev/null +++ b/internal/update/health_test.go @@ -0,0 +1,264 @@ +package update + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestHealthChecker_WaitForHealth_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(HealthResponse{Ok: true, Version: "v1.0.0"}) + })) + defer server.Close() + + hc := NewHealthChecker(HealthConfig{ + URL: server.URL, + TargetVersion: "v1.0.0", + MaxRetries: 5, + RetryInterval: 10 * time.Millisecond, + InitialWait: 0, + SleepFunc: func(d time.Duration) {}, + }) + + err := hc.WaitForHealth(context.Background()) + require.NoError(t, err) +} + +func TestHealthChecker_WaitForHealth_RetriesOnHttpError(t *testing.T) { + var attempts atomic.Int32 + + // Server returns 503 initially, then succeeds + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + count := attempts.Add(1) + if 
count < 3 { + w.WriteHeader(http.StatusServiceUnavailable) + return + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(HealthResponse{Ok: true, Version: "v1.0.0"}) + })) + defer server.Close() + + hc := NewHealthChecker(HealthConfig{ + URL: server.URL, + TargetVersion: "v1.0.0", + MaxRetries: 5, + RetryInterval: 10 * time.Millisecond, + InitialWait: 0, + SleepFunc: func(d time.Duration) {}, + }) + + err := hc.WaitForHealth(context.Background()) + require.NoError(t, err) + assert.Equal(t, int32(3), attempts.Load()) +} + +func TestHealthChecker_WaitForHealth_RetriesOnOkFalse(t *testing.T) { + var attempts atomic.Int32 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + count := attempts.Add(1) + w.Header().Set("Content-Type", "application/json") + if count < 3 { + json.NewEncoder(w).Encode(HealthResponse{Ok: false, Version: "v1.0.0"}) + } else { + json.NewEncoder(w).Encode(HealthResponse{Ok: true, Version: "v1.0.0"}) + } + })) + defer server.Close() + + hc := NewHealthChecker(HealthConfig{ + URL: server.URL, + TargetVersion: "v1.0.0", + MaxRetries: 5, + RetryInterval: 10 * time.Millisecond, + InitialWait: 0, + SleepFunc: func(d time.Duration) {}, + }) + + err := hc.WaitForHealth(context.Background()) + require.NoError(t, err) + assert.Equal(t, int32(3), attempts.Load()) +} + +func TestHealthChecker_WaitForHealth_FailsAfterMaxRetries(t *testing.T) { + var attempts atomic.Int32 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts.Add(1) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(HealthResponse{Ok: false, Version: "v1.0.0"}) + })) + defer server.Close() + + hc := NewHealthChecker(HealthConfig{ + URL: server.URL, + TargetVersion: "v1.0.0", + MaxRetries: 3, + RetryInterval: 10 * time.Millisecond, + InitialWait: 0, + SleepFunc: func(d time.Duration) {}, + }) + + err := 
hc.WaitForHealth(context.Background()) + assert.Error(t, err) + assert.ErrorIs(t, err, ErrHealthCheckFailed) + // Should have tried 4 times (initial + 3 retries) + assert.Equal(t, int32(4), attempts.Load()) +} + +func TestHealthChecker_WaitForHealth_RetriesOnVersionMismatch(t *testing.T) { + var attempts atomic.Int32 + + // Server returns old version for 2 attempts, then correct version + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + count := attempts.Add(1) + w.Header().Set("Content-Type", "application/json") + if count < 3 { + json.NewEncoder(w).Encode(HealthResponse{Ok: true, Version: "v1.0.0"}) + } else { + json.NewEncoder(w).Encode(HealthResponse{Ok: true, Version: "v2.0.0"}) + } + })) + defer server.Close() + + hc := NewHealthChecker(HealthConfig{ + URL: server.URL, + TargetVersion: "v2.0.0", + MaxRetries: 5, + RetryInterval: 10 * time.Millisecond, + InitialWait: 0, + SleepFunc: func(d time.Duration) {}, + }) + + err := hc.WaitForHealth(context.Background()) + require.NoError(t, err) + assert.Equal(t, int32(3), attempts.Load()) +} + +func TestHealthChecker_WaitForHealth_FailsOnVersionMismatch(t *testing.T) { + var attempts atomic.Int32 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts.Add(1) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(HealthResponse{Ok: true, Version: "v1.0.0"}) + })) + defer server.Close() + + hc := NewHealthChecker(HealthConfig{ + URL: server.URL, + TargetVersion: "v2.0.0", // Different from what server returns + MaxRetries: 3, + RetryInterval: 10 * time.Millisecond, + InitialWait: 0, + SleepFunc: func(d time.Duration) {}, + }) + + err := hc.WaitForHealth(context.Background()) + assert.Error(t, err) + assert.ErrorIs(t, err, ErrHealthCheckFailed) + // Should have exhausted all retries (initial + 3 retries = 4 attempts) + assert.Equal(t, int32(4), attempts.Load()) +} + +func 
TestHealthChecker_WaitForHealth_RespectsContext(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(HealthResponse{Ok: false, Version: "v1.0.0"}) + })) + defer server.Close() + + ctx, cancel := context.WithCancel(context.Background()) + + hc := NewHealthChecker(HealthConfig{ + URL: server.URL, + TargetVersion: "v1.0.0", + MaxRetries: 10, + RetryInterval: 10 * time.Millisecond, + InitialWait: 0, + SleepFunc: func(d time.Duration) { + cancel() // Cancel context during sleep + }, + }) + + err := hc.WaitForHealth(ctx) + assert.ErrorIs(t, err, context.Canceled) +} + +func TestHealthChecker_WaitForHealth_InitialWait(t *testing.T) { + var sleepDurations []time.Duration + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(HealthResponse{Ok: true, Version: "v1.0.0"}) + })) + defer server.Close() + + hc := NewHealthChecker(HealthConfig{ + URL: server.URL, + TargetVersion: "v1.0.0", + MaxRetries: 5, + RetryInterval: 100 * time.Millisecond, + InitialWait: 500 * time.Millisecond, + SleepFunc: func(d time.Duration) { + sleepDurations = append(sleepDurations, d) + }, + }) + + err := hc.WaitForHealth(context.Background()) + require.NoError(t, err) + + // First sleep should be the initial wait + require.GreaterOrEqual(t, len(sleepDurations), 1) + assert.Equal(t, 500*time.Millisecond, sleepDurations[0]) +} + +func TestHealthChecker_DefaultConfig(t *testing.T) { + hc := NewHealthChecker(HealthConfig{ + URL: "http://localhost:8080/health", + TargetVersion: "v1.0.0", + }) + + // Verify defaults + assert.Equal(t, 5, hc.config.MaxRetries) + assert.Equal(t, 5*time.Second, hc.config.RetryInterval) + assert.Equal(t, 5*time.Second, hc.config.InitialWait) +} + +func TestHealthChecker_WaitForHealth_HandlesInvalidJSON(t *testing.T) { 
+ var attempts atomic.Int32 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + count := attempts.Add(1) + w.Header().Set("Content-Type", "application/json") + if count < 2 { + w.Write([]byte("not json")) + } else { + json.NewEncoder(w).Encode(HealthResponse{Ok: true, Version: "v1.0.0"}) + } + })) + defer server.Close() + + hc := NewHealthChecker(HealthConfig{ + URL: server.URL, + TargetVersion: "v1.0.0", + MaxRetries: 5, + RetryInterval: 10 * time.Millisecond, + InitialWait: 0, + SleepFunc: func(d time.Duration) {}, + }) + + err := hc.WaitForHealth(context.Background()) + require.NoError(t, err) + assert.Equal(t, int32(2), attempts.Load()) +} diff --git a/internal/update/lock.go b/internal/update/lock.go new file mode 100644 index 0000000..0364db9 --- /dev/null +++ b/internal/update/lock.go @@ -0,0 +1,267 @@ +package update + +import ( + "crypto/rand" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "time" +) + +var ( + // ErrLockBusy is returned when the lock is held by another live process. + ErrLockBusy = errors.New("lock held by another process") + + // ErrLockInvalid is returned when the lock file is corrupted or has invalid format. + ErrLockInvalid = errors.New("lock file corrupted or invalid format") + + // ErrLockNotOwned is returned when trying to release a lock not owned by this process. + ErrLockNotOwned = errors.New("cannot release lock not owned by this process") + + // ErrLockAcquireFailed is returned when lock acquisition fails after all retries. + ErrLockAcquireFailed = errors.New("could not acquire lock after retries") +) + +// LockData represents the JSON structure stored in the lock file. +type LockData struct { + PID int `json:"pid"` + ExpireAt int64 `json:"expire_at"` // Unix timestamp when lock expires + OwnerStartTime int64 `json:"owner_start_time"` // Process start time in clock ticks +} + +// LockManager manages a file-based lock for coordinating updates. 
+type LockManager struct { + lockPath string + sleepFunc func(time.Duration) +} + +// LockConfig holds configuration for creating a LockManager. +type LockConfig struct { + LockPath string + SleepFunc func(time.Duration) // Optional: for testing +} + +// NewLockManager creates a new LockManager with the given configuration. +func NewLockManager(cfg LockConfig) *LockManager { + sleepFunc := cfg.SleepFunc + if sleepFunc == nil { + sleepFunc = time.Sleep + } + return &LockManager{ + lockPath: cfg.LockPath, + sleepFunc: sleepFunc, + } +} + +// TryLock attempts to acquire the lock with the given expiration duration. +// Returns nil on success, ErrLockBusy if held by another live process, +// or another error if something else fails. +func (l *LockManager) TryLock(expiration time.Duration) error { + // Ensure parent directory exists + dir := filepath.Dir(l.lockPath) + if err := os.MkdirAll(dir, 0700); err != nil { + return fmt.Errorf("failed to create lock directory: %w", err) + } + + // Get current process info + pid := os.Getpid() + startTime, err := getCurrentProcessStartTime() + if err != nil { + return fmt.Errorf("failed to get process start time: %w", err) + } + + // Create lock data + lockData := LockData{ + PID: pid, + ExpireAt: time.Now().Add(expiration).Unix(), + OwnerStartTime: startTime, + } + + content, err := json.Marshal(lockData) + if err != nil { + return fmt.Errorf("failed to marshal lock data: %w", err) + } + + // Generate unique temp filename + randSuffix, err := randomString(8) + if err != nil { + return err + } + tmpFile := l.lockPath + "." 
+ randSuffix + + // Write to temp file + if err := os.WriteFile(tmpFile, content, 0600); err != nil { + return fmt.Errorf("failed to write temp lock file: %w", err) + } + + // Try to hard link temp file to lock file (atomic operation) + err = os.Link(tmpFile, l.lockPath) + + // Always clean up temp file + os.Remove(tmpFile) + + if err == nil { + // Lock acquired successfully + return nil + } + + // Link failed - check if existing lock is stale + if !os.IsExist(err) { + return fmt.Errorf("failed to create lock: %w", err) + } + + // Read existing lock to check if it's stale + isStale, err := l.isLockStale() + if err != nil { + // Corrupted or unreadable lock - treat as stale + isStale = true + } + + if !isStale { + return ErrLockBusy + } + + // Lock is stale - remove it and try again + if err := os.Remove(l.lockPath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove stale lock: %w", err) + } + + // Retry acquisition with new temp file + randSuffix, err = randomString(8) + if err != nil { + return err + } + tmpFile = l.lockPath + "." + randSuffix + if err := os.WriteFile(tmpFile, content, 0600); err != nil { + return fmt.Errorf("failed to write temp lock file: %w", err) + } + + err = os.Link(tmpFile, l.lockPath) + os.Remove(tmpFile) + + if err != nil { + if os.IsExist(err) { + // Another process grabbed it between our delete and link + return ErrLockBusy + } + return fmt.Errorf("failed to create lock: %w", err) + } + + return nil +} + +// TryLockWithRetry attempts to acquire the lock with retries. +// It will try up to 'retries' additional times (total attempts = retries + 1), +// waiting 'interval' between each attempt. 
+func (l *LockManager) TryLockWithRetry(expiration time.Duration, retries int, interval time.Duration) error { + var lastErr error + + for attempt := 0; attempt <= retries; attempt++ { + if attempt > 0 { + l.sleepFunc(interval) + } + + err := l.TryLock(expiration) + if err == nil { + return nil + } + + lastErr = err + if !errors.Is(err, ErrLockBusy) { + // Non-retryable error + return err + } + } + + return fmt.Errorf("%w: %v", ErrLockAcquireFailed, lastErr) +} + +// Unlock releases the lock if owned by the current process. +// Returns ErrLockNotOwned if the lock is held by another process. +// Returns nil if the lock file doesn't exist (idempotent). +func (l *LockManager) Unlock() error { + content, err := os.ReadFile(l.lockPath) + if err != nil { + if os.IsNotExist(err) { + // No lock to release - that's fine + return nil + } + return fmt.Errorf("failed to read lock file: %w", err) + } + + var lockData LockData + if err := json.Unmarshal(content, &lockData); err != nil { + // Corrupted lock file - just delete it + if err := os.Remove(l.lockPath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove corrupted lock: %w", err) + } + return nil + } + + // Verify we own the lock + if lockData.PID != os.Getpid() { + return ErrLockNotOwned + } + + // Delete the lock file + if err := os.Remove(l.lockPath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove lock: %w", err) + } + + return nil +} + +// isLockStale checks if the current lock file represents a stale lock. 
+// A lock is stale if: +// - It has expired (expire_at < now) +// - The owner process is dead +// - The owner PID was reused (start_time doesn't match) +func (l *LockManager) isLockStale() (bool, error) { + content, err := os.ReadFile(l.lockPath) + if err != nil { + return false, err + } + + var lockData LockData + if err := json.Unmarshal(content, &lockData); err != nil { + // Corrupted lock is considered stale + return true, ErrLockInvalid + } + + // Check if expired + if lockData.ExpireAt < time.Now().Unix() { + return true, nil + } + + // Check if owner process is alive + if !isProcessAlive(lockData.PID) { + return true, nil + } + + // Check for PID reuse by comparing start times + ownerStartTime, err := getProcessStartTime(lockData.PID) + if err != nil { + // Can't determine start time - process might have just died + return true, nil + } + + if ownerStartTime != lockData.OwnerStartTime { + // PID was reused by a different process + return true, nil + } + + return false, nil +} + +// randomString generates a random hex string of the given length. +// Returns an error if crypto/rand fails (rare but possible in constrained environments). 
+func randomString(length int) (string, error) { + bytes := make([]byte, length/2+1) + if _, err := rand.Read(bytes); err != nil { + return "", fmt.Errorf("failed to generate random string: %w", err) + } + return hex.EncodeToString(bytes)[:length], nil +} diff --git a/internal/update/lock_test.go b/internal/update/lock_test.go new file mode 100644 index 0000000..f58f0bf --- /dev/null +++ b/internal/update/lock_test.go @@ -0,0 +1,377 @@ +package update + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTryLock_Success_NoExistingLock(t *testing.T) { + tmpDir := t.TempDir() + lockPath := filepath.Join(tmpDir, "update.lock") + + lm := NewLockManager(LockConfig{LockPath: lockPath}) + + err := lm.TryLock(time.Hour) + require.NoError(t, err) + + // Verify lock file exists + _, err = os.Stat(lockPath) + assert.NoError(t, err, "lock file should exist") +} + +func TestTryLock_Success_CorrectFormat(t *testing.T) { + tmpDir := t.TempDir() + lockPath := filepath.Join(tmpDir, "update.lock") + + lm := NewLockManager(LockConfig{LockPath: lockPath}) + + err := lm.TryLock(time.Hour) + require.NoError(t, err) + + // Read and verify lock file content + content, err := os.ReadFile(lockPath) + require.NoError(t, err) + + var lockData LockData + err = json.Unmarshal(content, &lockData) + require.NoError(t, err) + + assert.Equal(t, os.Getpid(), lockData.PID, "PID should match current process") + assert.Greater(t, lockData.ExpireAt, time.Now().Unix(), "expire_at should be in the future") + assert.Greater(t, lockData.OwnerStartTime, int64(0), "owner_start_time should be positive") +} + +func TestTryLock_Fail_LiveProcess(t *testing.T) { + tmpDir := t.TempDir() + lockPath := filepath.Join(tmpDir, "update.lock") + + // Create a lock held by current process (simulating another instance) + startTime, err := getCurrentProcessStartTime() + require.NoError(t, err) + + lockData := 
LockData{ + PID: os.Getpid(), + ExpireAt: time.Now().Add(time.Hour).Unix(), + OwnerStartTime: startTime, + } + content, err := json.Marshal(lockData) + require.NoError(t, err) + err = os.WriteFile(lockPath, content, 0600) + require.NoError(t, err) + + lm := NewLockManager(LockConfig{LockPath: lockPath}) + + err = lm.TryLock(time.Hour) + assert.ErrorIs(t, err, ErrLockBusy) +} + +func TestTryLock_Success_ExpiredLock(t *testing.T) { + tmpDir := t.TempDir() + lockPath := filepath.Join(tmpDir, "update.lock") + + // Create an expired lock + startTime, err := getCurrentProcessStartTime() + require.NoError(t, err) + + lockData := LockData{ + PID: os.Getpid(), + ExpireAt: time.Now().Add(-time.Hour).Unix(), // Expired 1 hour ago + OwnerStartTime: startTime, + } + content, err := json.Marshal(lockData) + require.NoError(t, err) + err = os.WriteFile(lockPath, content, 0600) + require.NoError(t, err) + + lm := NewLockManager(LockConfig{LockPath: lockPath}) + + err = lm.TryLock(time.Hour) + assert.NoError(t, err, "should acquire lock when existing lock is expired") +} + +func TestTryLock_Success_DeadPID(t *testing.T) { + tmpDir := t.TempDir() + lockPath := filepath.Join(tmpDir, "update.lock") + + // Create a lock held by a non-existent process + lockData := LockData{ + PID: 99999999, // Very unlikely to exist + ExpireAt: time.Now().Add(time.Hour).Unix(), + OwnerStartTime: 12345, + } + content, err := json.Marshal(lockData) + require.NoError(t, err) + err = os.WriteFile(lockPath, content, 0600) + require.NoError(t, err) + + lm := NewLockManager(LockConfig{LockPath: lockPath}) + + err = lm.TryLock(time.Hour) + assert.NoError(t, err, "should acquire lock when owner PID is dead") +} + +func TestTryLock_Success_PIDReuse(t *testing.T) { + tmpDir := t.TempDir() + lockPath := filepath.Join(tmpDir, "update.lock") + + // Create a lock with current PID but wrong start time (simulating PID reuse) + lockData := LockData{ + PID: os.Getpid(), + ExpireAt: time.Now().Add(time.Hour).Unix(), + 
OwnerStartTime: 1, // Wrong start time - PID was reused + } + content, err := json.Marshal(lockData) + require.NoError(t, err) + err = os.WriteFile(lockPath, content, 0600) + require.NoError(t, err) + + lm := NewLockManager(LockConfig{LockPath: lockPath}) + + err = lm.TryLock(time.Hour) + assert.NoError(t, err, "should acquire lock when PID was reused (start_time mismatch)") +} + +func TestTryLock_Success_CorruptedLock(t *testing.T) { + tmpDir := t.TempDir() + lockPath := filepath.Join(tmpDir, "update.lock") + + // Create a corrupted lock file + err := os.WriteFile(lockPath, []byte("not valid json"), 0600) + require.NoError(t, err) + + lm := NewLockManager(LockConfig{LockPath: lockPath}) + + err = lm.TryLock(time.Hour) + assert.NoError(t, err, "should acquire lock when existing lock is corrupted") +} + +func TestTryLock_TempFileCleanup(t *testing.T) { + tmpDir := t.TempDir() + lockPath := filepath.Join(tmpDir, "update.lock") + + lm := NewLockManager(LockConfig{LockPath: lockPath}) + + err := lm.TryLock(time.Hour) + require.NoError(t, err) + + // Check for any leftover temp files + entries, err := os.ReadDir(tmpDir) + require.NoError(t, err) + + for _, entry := range entries { + assert.Equal(t, "update.lock", entry.Name(), "only update.lock should remain, found: %s", entry.Name()) + } +} + +func TestTryLock_CreateParentDir(t *testing.T) { + tmpDir := t.TempDir() + lockPath := filepath.Join(tmpDir, "nested", "dir", "update.lock") + + lm := NewLockManager(LockConfig{LockPath: lockPath}) + + err := lm.TryLock(time.Hour) + require.NoError(t, err) + + // Verify lock file exists + _, err = os.Stat(lockPath) + assert.NoError(t, err, "lock file should exist in nested directory") +} + +// TryLockWithRetry tests + +func TestTryLockWithRetry_Success_FirstAttempt(t *testing.T) { + tmpDir := t.TempDir() + lockPath := filepath.Join(tmpDir, "update.lock") + + sleepCalled := 0 + mockSleep := func(d time.Duration) { + sleepCalled++ + } + + lm := NewLockManager(LockConfig{ + 
LockPath: lockPath, + SleepFunc: mockSleep, + }) + + err := lm.TryLockWithRetry(time.Hour, 3, time.Second) + require.NoError(t, err) + assert.Equal(t, 0, sleepCalled, "should not sleep when lock acquired on first attempt") +} + +func TestTryLockWithRetry_Success_ThirdAttempt(t *testing.T) { + tmpDir := t.TempDir() + lockPath := filepath.Join(tmpDir, "update.lock") + + attempts := 0 + sleepCalled := 0 + mockSleep := func(d time.Duration) { + sleepCalled++ + // After 2 failed attempts, release the lock + if sleepCalled == 2 { + os.Remove(lockPath) + } + } + + // Create initial lock held by "another" process (dead PID won't work, use current PID) + startTime, _ := getCurrentProcessStartTime() + lockData := LockData{ + PID: os.Getpid(), + ExpireAt: time.Now().Add(time.Hour).Unix(), + OwnerStartTime: startTime, + } + content, _ := json.Marshal(lockData) + os.WriteFile(lockPath, content, 0600) + + lm := NewLockManager(LockConfig{ + LockPath: lockPath, + SleepFunc: mockSleep, + }) + + // Track actual attempts via the mock + _ = attempts + + err := lm.TryLockWithRetry(time.Hour, 5, time.Second) + require.NoError(t, err) + assert.Equal(t, 2, sleepCalled, "should sleep between retries until lock acquired") +} + +func TestTryLockWithRetry_Fail_MaxRetries(t *testing.T) { + tmpDir := t.TempDir() + lockPath := filepath.Join(tmpDir, "update.lock") + + sleepCalled := 0 + mockSleep := func(d time.Duration) { + sleepCalled++ + } + + // Create a lock that will never be released (held by current process) + startTime, _ := getCurrentProcessStartTime() + lockData := LockData{ + PID: os.Getpid(), + ExpireAt: time.Now().Add(time.Hour).Unix(), + OwnerStartTime: startTime, + } + content, _ := json.Marshal(lockData) + os.WriteFile(lockPath, content, 0600) + + lm := NewLockManager(LockConfig{ + LockPath: lockPath, + SleepFunc: mockSleep, + }) + + err := lm.TryLockWithRetry(time.Hour, 3, time.Second) + assert.ErrorIs(t, err, ErrLockAcquireFailed) + assert.Equal(t, 3, sleepCalled, "should sleep 
between each retry") +} + +func TestTryLockWithRetry_RetryInterval(t *testing.T) { + tmpDir := t.TempDir() + lockPath := filepath.Join(tmpDir, "update.lock") + + var sleepDurations []time.Duration + mockSleep := func(d time.Duration) { + sleepDurations = append(sleepDurations, d) + } + + // Create a lock that will never be released + startTime, _ := getCurrentProcessStartTime() + lockData := LockData{ + PID: os.Getpid(), + ExpireAt: time.Now().Add(time.Hour).Unix(), + OwnerStartTime: startTime, + } + content, _ := json.Marshal(lockData) + os.WriteFile(lockPath, content, 0600) + + lm := NewLockManager(LockConfig{ + LockPath: lockPath, + SleepFunc: mockSleep, + }) + + expectedInterval := 500 * time.Millisecond + _ = lm.TryLockWithRetry(time.Hour, 2, expectedInterval) + + for i, d := range sleepDurations { + assert.Equal(t, expectedInterval, d, "sleep duration %d should match interval", i) + } +} + +// Unlock tests + +func TestUnlock_Success_Owner(t *testing.T) { + tmpDir := t.TempDir() + lockPath := filepath.Join(tmpDir, "update.lock") + + lm := NewLockManager(LockConfig{LockPath: lockPath}) + + // Acquire lock first + err := lm.TryLock(time.Hour) + require.NoError(t, err) + + // Release it + err = lm.Unlock() + require.NoError(t, err) + + // Verify lock file is gone + _, err = os.Stat(lockPath) + assert.True(t, os.IsNotExist(err), "lock file should be deleted") +} + +func TestUnlock_Fail_NotOwner(t *testing.T) { + tmpDir := t.TempDir() + lockPath := filepath.Join(tmpDir, "update.lock") + + // Create a lock held by another PID + lockData := LockData{ + PID: 99999999, + ExpireAt: time.Now().Add(time.Hour).Unix(), + OwnerStartTime: 12345, + } + content, _ := json.Marshal(lockData) + os.WriteFile(lockPath, content, 0600) + + lm := NewLockManager(LockConfig{LockPath: lockPath}) + + err := lm.Unlock() + assert.ErrorIs(t, err, ErrLockNotOwned) + + // Verify lock file still exists + _, err = os.Stat(lockPath) + assert.NoError(t, err, "lock file should still exist") +} + 
+func TestUnlock_Success_NoLock(t *testing.T) { + tmpDir := t.TempDir() + lockPath := filepath.Join(tmpDir, "update.lock") + + lm := NewLockManager(LockConfig{LockPath: lockPath}) + + // Unlock when no lock exists should be idempotent + err := lm.Unlock() + assert.NoError(t, err, "unlocking non-existent lock should succeed") +} + +func TestUnlock_Success_CorruptedLock(t *testing.T) { + tmpDir := t.TempDir() + lockPath := filepath.Join(tmpDir, "update.lock") + + // Create a corrupted lock file + err := os.WriteFile(lockPath, []byte("corrupted"), 0600) + require.NoError(t, err) + + lm := NewLockManager(LockConfig{LockPath: lockPath}) + + // Should delete corrupted lock file + err = lm.Unlock() + assert.NoError(t, err, "should handle corrupted lock gracefully") + + // Verify lock file is gone + _, err = os.Stat(lockPath) + assert.True(t, os.IsNotExist(err), "corrupted lock file should be deleted") +} diff --git a/internal/update/procutil_darwin.go b/internal/update/procutil_darwin.go new file mode 100644 index 0000000..abcc0fa --- /dev/null +++ b/internal/update/procutil_darwin.go @@ -0,0 +1,52 @@ +//go:build darwin + +package update + +import ( + "os" + "syscall" +) + +// isProcessAlive checks if a process with the given PID is alive. +// It sends signal 0 to the process - this doesn't actually send a signal +// but performs error checking to determine if the process exists. +func isProcessAlive(pid int) bool { + if pid <= 0 { + return false + } + + process, err := os.FindProcess(pid) + if err != nil { + return false + } + + // Signal 0 is a special signal that performs error checking without + // actually sending a signal. If the process exists and we have + // permission to signal it, err will be nil. + err = process.Signal(syscall.Signal(0)) + return err == nil +} + +// getProcessStartTime returns the process start time. +// On Darwin, we use a fixed value based on PID for testing purposes. 
+// The actual implementation will only run on Linux where /proc is available. +// This stub allows tests that don't depend on real start times to pass on macOS. +func getProcessStartTime(pid int) (int64, error) { + if pid <= 0 { + return 0, syscall.ESRCH + } + + // Check if process exists first + if !isProcessAlive(pid) { + return 0, syscall.ESRCH + } + + // Return a deterministic value based on PID for testing + // In production, this only runs on Linux + return int64(pid) * 1000, nil +} + +// getCurrentProcessStartTime returns the current process start time. +func getCurrentProcessStartTime() (int64, error) { + return getProcessStartTime(os.Getpid()) +} diff --git a/internal/update/procutil_linux.go b/internal/update/procutil_linux.go new file mode 100644 index 0000000..25fa0ba --- /dev/null +++ b/internal/update/procutil_linux.go @@ -0,0 +1,73 @@ +//go:build linux + +package update + +import ( + "fmt" + "os" + "strconv" + "strings" + "syscall" +) + +// isProcessAlive checks if a process with the given PID is alive. +// It sends signal 0 to the process - this doesn't actually send a signal +// but performs error checking to determine if the process exists. +func isProcessAlive(pid int) bool { + if pid <= 0 { + return false + } + + process, err := os.FindProcess(pid) + if err != nil { + return false + } + + // Signal 0 is a special signal that performs error checking without + // actually sending a signal. If the process exists and we have + // permission to signal it, err will be nil. + err = process.Signal(syscall.Signal(0)) + return err == nil +} + +// getProcessStartTime returns the process start time in clock ticks since boot. +// Reads from /proc/[pid]/stat field 22 (starttime). 
+func getProcessStartTime(pid int) (int64, error) { + if pid <= 0 { + return 0, fmt.Errorf("invalid PID: %d", pid) + } + + statPath := fmt.Sprintf("/proc/%d/stat", pid) + content, err := os.ReadFile(statPath) + if err != nil { + return 0, fmt.Errorf("failed to read %s: %w", statPath, err) + } + + // /proc/[pid]/stat format has comm (field 2) in parentheses which may contain spaces + // Find the last ')' to reliably parse fields after comm + data := string(content) + closeParen := strings.LastIndex(data, ")") + if closeParen == -1 { + return 0, fmt.Errorf("invalid format in %s: no closing parenthesis", statPath) + } + + // Fields after ')' are space-separated, starting at field 3 + // Field 22 is starttime, so we need field index 19 (22 - 3 = 19) after the ')' + fieldsAfterComm := strings.Fields(data[closeParen+1:]) + if len(fieldsAfterComm) < 20 { + return 0, fmt.Errorf("invalid format in %s: not enough fields (got %d)", statPath, len(fieldsAfterComm)) + } + + // starttime is field 22, which is index 19 in fieldsAfterComm (0-indexed, starting from field 3) + starttime, err := strconv.ParseInt(fieldsAfterComm[19], 10, 64) + if err != nil { + return 0, fmt.Errorf("failed to parse starttime: %w", err) + } + + return starttime, nil +} + +// getCurrentProcessStartTime returns the current process start time. 
+func getCurrentProcessStartTime() (int64, error) { + return getProcessStartTime(os.Getpid()) +} diff --git a/internal/update/procutil_test.go b/internal/update/procutil_test.go new file mode 100644 index 0000000..89c1e72 --- /dev/null +++ b/internal/update/procutil_test.go @@ -0,0 +1,66 @@ +package update + +import ( + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestIsProcessAlive_CurrentProcess(t *testing.T) { + // Current process should always be alive + alive := isProcessAlive(os.Getpid()) + assert.True(t, alive, "current process should be alive") +} + +func TestIsProcessAlive_InvalidPID(t *testing.T) { + // Negative PID should not be alive + alive := isProcessAlive(-1) + assert.False(t, alive, "negative PID should not be alive") +} + +func TestIsProcessAlive_NonExistentPID(t *testing.T) { + // Very large PID that almost certainly doesn't exist + alive := isProcessAlive(99999999) + assert.False(t, alive, "non-existent PID should not be alive") +} + +func TestGetProcessStartTime_CurrentProcess(t *testing.T) { + startTime, err := getProcessStartTime(os.Getpid()) + require.NoError(t, err) + assert.Greater(t, startTime, int64(0), "start time should be positive") +} + +func TestGetProcessStartTime_InvalidPID(t *testing.T) { + _, err := getProcessStartTime(-1) + assert.Error(t, err, "should error for invalid PID") +} + +func TestGetProcessStartTime_NonExistentPID(t *testing.T) { + _, err := getProcessStartTime(99999999) + assert.Error(t, err, "should error for non-existent PID") +} + +func TestGetProcessStartTime_Consistency(t *testing.T) { + // Calling twice should return the same value + pid := os.Getpid() + startTime1, err := getProcessStartTime(pid) + require.NoError(t, err) + + startTime2, err := getProcessStartTime(pid) + require.NoError(t, err) + + assert.Equal(t, startTime1, startTime2, "start time should be consistent") +} + +func TestGetCurrentProcessStartTime(t *testing.T) { + startTime, err 
:= getCurrentProcessStartTime() + require.NoError(t, err) + assert.Greater(t, startTime, int64(0), "start time should be positive") + + // Should match getProcessStartTime for current PID + expected, err := getProcessStartTime(os.Getpid()) + require.NoError(t, err) + assert.Equal(t, expected, startTime) +} diff --git a/internal/update/service.go b/internal/update/service.go new file mode 100644 index 0000000..098b4f0 --- /dev/null +++ b/internal/update/service.go @@ -0,0 +1,92 @@ +package update + +import ( + "context" + "fmt" + "os/exec" + "time" +) + +const ( + // DefaultStopTimeout is the default timeout for stopping the service. + DefaultStopTimeout = 30 * time.Second + // DefaultStartTimeout is the default timeout for starting the service. + DefaultStartTimeout = 30 * time.Second +) + +// ExecFunc is a function type for executing external commands. +// It allows injecting mock implementations for testing. +type ExecFunc func(ctx context.Context, name string, args ...string) ([]byte, error) + +// DefaultExecFunc executes commands using os/exec. +func DefaultExecFunc(ctx context.Context, name string, args ...string) ([]byte, error) { + cmd := exec.CommandContext(ctx, name, args...) + return cmd.CombinedOutput() +} + +// ServiceConfig configures the ServiceController. +type ServiceConfig struct { + ServiceName string // Name of the systemd service (e.g., "hostlink") + StopTimeout time.Duration // Timeout for stop operation (default: 30s) + StartTimeout time.Duration // Timeout for start operation (default: 30s) + ExecFunc ExecFunc // Function to execute commands (for testing) +} + +// ServiceController manages systemd service operations. +type ServiceController struct { + config ServiceConfig +} + +// NewServiceController creates a new ServiceController with the given configuration. 
+func NewServiceController(cfg ServiceConfig) *ServiceController { + // Apply defaults + if cfg.StopTimeout == 0 { + cfg.StopTimeout = DefaultStopTimeout + } + if cfg.StartTimeout == 0 { + cfg.StartTimeout = DefaultStartTimeout + } + if cfg.ExecFunc == nil { + cfg.ExecFunc = DefaultExecFunc + } + + return &ServiceController{config: cfg} +} + +// Stop stops the systemd service. +// It respects the configured timeout and the parent context. +func (s *ServiceController) Stop(ctx context.Context) error { + // Create context with timeout + ctx, cancel := context.WithTimeout(ctx, s.config.StopTimeout) + defer cancel() + + output, err := s.config.ExecFunc(ctx, "systemctl", "stop", s.config.ServiceName) + if err != nil { + // Check if context was cancelled/timed out + if ctx.Err() != nil { + return ctx.Err() + } + return fmt.Errorf("failed to stop service %s: %w (output: %s)", s.config.ServiceName, err, string(output)) + } + + return nil +} + +// Start starts the systemd service. +// It respects the configured timeout and the parent context. 
+func (s *ServiceController) Start(ctx context.Context) error { + // Create context with timeout + ctx, cancel := context.WithTimeout(ctx, s.config.StartTimeout) + defer cancel() + + output, err := s.config.ExecFunc(ctx, "systemctl", "start", s.config.ServiceName) + if err != nil { + // Check if context was cancelled/timed out + if ctx.Err() != nil { + return ctx.Err() + } + return fmt.Errorf("failed to start service %s: %w (output: %s)", s.config.ServiceName, err, string(output)) + } + + return nil +} diff --git a/internal/update/service_test.go b/internal/update/service_test.go new file mode 100644 index 0000000..a32c74a --- /dev/null +++ b/internal/update/service_test.go @@ -0,0 +1,223 @@ +package update + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockExecResult holds the result for a single command execution +type mockExecResult struct { + output string + err error +} + +// mockExecFunc creates an exec function that returns predefined results +func mockExecFunc(results ...mockExecResult) func(ctx context.Context, name string, args ...string) ([]byte, error) { + idx := 0 + return func(ctx context.Context, name string, args ...string) ([]byte, error) { + if idx >= len(results) { + return nil, errors.New("unexpected call to exec") + } + result := results[idx] + idx++ + return []byte(result.output), result.err + } +} + +// recordingExecFunc records all calls made to the exec function +type recordingExec struct { + calls []execCall + results []mockExecResult + idx int +} + +type execCall struct { + name string + args []string +} + +func newRecordingExec(results ...mockExecResult) *recordingExec { + return &recordingExec{results: results} +} + +func (r *recordingExec) exec(ctx context.Context, name string, args ...string) ([]byte, error) { + r.calls = append(r.calls, execCall{name: name, args: args}) + if r.idx >= len(r.results) { + return nil, errors.New("unexpected 
call to exec") + } + result := r.results[r.idx] + r.idx++ + return []byte(result.output), result.err +} + +func TestServiceController_Stop_CallsSystemctl(t *testing.T) { + recorder := newRecordingExec(mockExecResult{output: "", err: nil}) + sc := NewServiceController(ServiceConfig{ + ServiceName: "hostlink", + ExecFunc: recorder.exec, + }) + + err := sc.Stop(context.Background()) + + require.NoError(t, err) + require.Len(t, recorder.calls, 1) + assert.Equal(t, "systemctl", recorder.calls[0].name) + assert.Equal(t, []string{"stop", "hostlink"}, recorder.calls[0].args) +} + +func TestServiceController_Stop_ReturnsErrorOnFailure(t *testing.T) { + recorder := newRecordingExec(mockExecResult{ + output: "Failed to stop hostlink.service: Unit hostlink.service not loaded.", + err: errors.New("exit status 5"), + }) + sc := NewServiceController(ServiceConfig{ + ServiceName: "hostlink", + ExecFunc: recorder.exec, + }) + + err := sc.Stop(context.Background()) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to stop service") +} + +func TestServiceController_Stop_RespectsTimeout(t *testing.T) { + // Create a context that's already cancelled + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + execCalled := false + sc := NewServiceController(ServiceConfig{ + ServiceName: "hostlink", + ExecFunc: func(ctx context.Context, name string, args ...string) ([]byte, error) { + execCalled = true + // Check if context is done + if ctx.Err() != nil { + return nil, ctx.Err() + } + return nil, nil + }, + }) + + err := sc.Stop(ctx) + + assert.True(t, execCalled) + assert.ErrorIs(t, err, context.Canceled) +} + +func TestServiceController_Stop_HandlesAlreadyStopped(t *testing.T) { + // When service is already stopped, systemctl stop still returns success + recorder := newRecordingExec(mockExecResult{output: "", err: nil}) + sc := NewServiceController(ServiceConfig{ + ServiceName: "hostlink", + ExecFunc: recorder.exec, + }) + + err := 
sc.Stop(context.Background()) + + require.NoError(t, err) +} + +func TestServiceController_Stop_UsesConfiguredTimeout(t *testing.T) { + var capturedCtx context.Context + sc := NewServiceController(ServiceConfig{ + ServiceName: "hostlink", + StopTimeout: 15 * time.Second, + ExecFunc: func(ctx context.Context, name string, args ...string) ([]byte, error) { + capturedCtx = ctx + return nil, nil + }, + }) + + _ = sc.Stop(context.Background()) + + // Verify context has a deadline + deadline, ok := capturedCtx.Deadline() + require.True(t, ok, "context should have a deadline") + // Deadline should be roughly 15 seconds from now (give some margin) + assert.WithinDuration(t, time.Now().Add(15*time.Second), deadline, 2*time.Second) +} + +func TestServiceController_Start_CallsSystemctl(t *testing.T) { + recorder := newRecordingExec(mockExecResult{output: "", err: nil}) + sc := NewServiceController(ServiceConfig{ + ServiceName: "hostlink", + ExecFunc: recorder.exec, + }) + + err := sc.Start(context.Background()) + + require.NoError(t, err) + require.Len(t, recorder.calls, 1) + assert.Equal(t, "systemctl", recorder.calls[0].name) + assert.Equal(t, []string{"start", "hostlink"}, recorder.calls[0].args) +} + +func TestServiceController_Start_ReturnsErrorOnFailure(t *testing.T) { + recorder := newRecordingExec(mockExecResult{ + output: "Failed to start hostlink.service: Unit hostlink.service not found.", + err: errors.New("exit status 5"), + }) + sc := NewServiceController(ServiceConfig{ + ServiceName: "hostlink", + ExecFunc: recorder.exec, + }) + + err := sc.Start(context.Background()) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to start service") +} + +func TestServiceController_Start_RespectsTimeout(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + sc := NewServiceController(ServiceConfig{ + ServiceName: "hostlink", + ExecFunc: func(ctx context.Context, name string, args ...string) ([]byte, error) { + if ctx.Err() != 
nil { + return nil, ctx.Err() + } + return nil, nil + }, + }) + + err := sc.Start(ctx) + + assert.ErrorIs(t, err, context.Canceled) +} + +func TestServiceController_Start_UsesConfiguredTimeout(t *testing.T) { + var capturedCtx context.Context + sc := NewServiceController(ServiceConfig{ + ServiceName: "hostlink", + StartTimeout: 20 * time.Second, + ExecFunc: func(ctx context.Context, name string, args ...string) ([]byte, error) { + capturedCtx = ctx + return nil, nil + }, + }) + + _ = sc.Start(context.Background()) + + deadline, ok := capturedCtx.Deadline() + require.True(t, ok, "context should have a deadline") + assert.WithinDuration(t, time.Now().Add(20*time.Second), deadline, 2*time.Second) +} + +func TestServiceController_DefaultTimeouts(t *testing.T) { + sc := NewServiceController(ServiceConfig{ + ServiceName: "hostlink", + ExecFunc: func(ctx context.Context, name string, args ...string) ([]byte, error) { return nil, nil }, + }) + + // Verify defaults are applied + assert.Equal(t, 30*time.Second, sc.config.StopTimeout) + assert.Equal(t, 30*time.Second, sc.config.StartTimeout) +} diff --git a/internal/update/spawn.go b/internal/update/spawn.go new file mode 100644 index 0000000..9a9ed04 --- /dev/null +++ b/internal/update/spawn.go @@ -0,0 +1,30 @@ +package update + +import ( + "os/exec" + "syscall" +) + +// SpawnUpdater starts the updater binary in its own process group. +// The updater survives the agent's shutdown because Setpgid: true +// places it in a new process group that systemd won't kill. +// This is fire-and-forget: the caller does not wait for the process to exit. +func SpawnUpdater(updaterPath string, args []string) error { + cmd, err := spawnWithCmd(updaterPath, args) + if err != nil { + return err + } + _ = cmd // Process started; caller does not manage it. + return nil +} + +// spawnWithCmd is the internal implementation that returns the exec.Cmd +// for testing purposes (to inspect the child PID/PGID). 
+func spawnWithCmd(updaterPath string, args []string) (*exec.Cmd, error) { + cmd := exec.Command(updaterPath, args...) + cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} + if err := cmd.Start(); err != nil { + return nil, err + } + return cmd, nil +} diff --git a/internal/update/spawn_test.go b/internal/update/spawn_test.go new file mode 100644 index 0000000..fa59f79 --- /dev/null +++ b/internal/update/spawn_test.go @@ -0,0 +1,44 @@ +package update + +import ( + "os" + "syscall" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSpawnUpdater_StartsProcess(t *testing.T) { + err := SpawnUpdater("/bin/sleep", []string{"0.1"}) + require.NoError(t, err) +} + +func TestSpawnUpdater_SetpgidTrue(t *testing.T) { + // Use a process that prints its own PGID so we can verify + // We spawn "sleep 2" and check its PGID differs from ours + cmd, err := spawnWithCmd("/bin/sleep", []string{"0.5"}) + require.NoError(t, err) + require.NotNil(t, cmd) + require.NotNil(t, cmd.Process) + + defer cmd.Process.Kill() + defer cmd.Wait() + + childPID := cmd.Process.Pid + childPGID, err := syscall.Getpgid(childPID) + require.NoError(t, err) + + parentPGID, err := syscall.Getpgid(os.Getpid()) + require.NoError(t, err) + + assert.NotEqual(t, parentPGID, childPGID, + "child PGID should differ from parent PGID (Setpgid: true)") + assert.Equal(t, childPID, childPGID, + "child should be its own process group leader (PGID == PID)") +} + +func TestSpawnUpdater_ReturnsErrorForInvalidBinary(t *testing.T) { + err := SpawnUpdater("/nonexistent/binary", []string{}) + assert.Error(t, err) +} diff --git a/internal/update/state.go b/internal/update/state.go new file mode 100644 index 0000000..abd96cc --- /dev/null +++ b/internal/update/state.go @@ -0,0 +1,104 @@ +package update + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "time" +) + +// State represents the current state of an update operation. 
+type State string + +const ( + StateNotStarted State = "NotStarted" + StateInitialized State = "Initialized" + StateStaged State = "Staged" + StateInstalled State = "Installed" + StateCompleted State = "Completed" + StateRollback State = "Rollback" + StateRolledBack State = "RolledBack" +) + +// StateData represents the update state persisted to disk. +type StateData struct { + UpdateID string `json:"update_id"` + State State `json:"state"` + SourceVersion string `json:"source_version"` + TargetVersion string `json:"target_version"` + StartedAt time.Time `json:"started_at"` + CompletedAt *time.Time `json:"completed_at,omitempty"` + Error *string `json:"error,omitempty"` +} + +// StateWriter manages the update state file for observability. +type StateWriter struct { + statePath string +} + +// StateConfig holds configuration for creating a StateWriter. +type StateConfig struct { + StatePath string // e.g., /var/lib/hostlink/updates/state.json +} + +// NewStateWriter creates a new StateWriter with the given configuration. +func NewStateWriter(cfg StateConfig) *StateWriter { + return &StateWriter{ + statePath: cfg.StatePath, + } +} + +// Write persists the state data to disk atomically. +// Uses temp file + rename for atomic write. +func (s *StateWriter) Write(data StateData) error { + // Ensure parent directory exists + dir := filepath.Dir(s.statePath) + if err := os.MkdirAll(dir, 0700); err != nil { + return fmt.Errorf("failed to create state directory: %w", err) + } + + // Marshal to human-readable JSON + content, err := json.MarshalIndent(data, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal state data: %w", err) + } + + // Write to temp file first + randSuffix, err := randomString(8) + if err != nil { + return err + } + tmpFile := s.statePath + ".tmp." 
+ randSuffix + if err := os.WriteFile(tmpFile, content, 0600); err != nil { + return fmt.Errorf("failed to write temp state file: %w", err) + } + + // Atomic rename + if err := os.Rename(tmpFile, s.statePath); err != nil { + os.Remove(tmpFile) // Clean up on failure + return fmt.Errorf("failed to rename state file: %w", err) + } + + return nil +} + +// Read loads the state data from disk. +// Returns zero-value StateData and no error if file doesn't exist. +// Returns error if file exists but is corrupted. +func (s *StateWriter) Read() (StateData, error) { + content, err := os.ReadFile(s.statePath) + if err != nil { + if os.IsNotExist(err) { + return StateData{}, nil + } + return StateData{}, fmt.Errorf("failed to read state file: %w", err) + } + + var data StateData + if err := json.Unmarshal(content, &data); err != nil { + return StateData{}, fmt.Errorf("failed to parse state file: %w", err) + } + + return data, nil +} diff --git a/internal/update/state_test.go b/internal/update/state_test.go new file mode 100644 index 0000000..23fcf4d --- /dev/null +++ b/internal/update/state_test.go @@ -0,0 +1,258 @@ +package update + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestStateWriter_Write_ValidState(t *testing.T) { + tmpDir := t.TempDir() + statePath := filepath.Join(tmpDir, "state.json") + + sw := NewStateWriter(StateConfig{StatePath: statePath}) + + now := time.Now().Truncate(time.Second) + data := StateData{ + UpdateID: "test-update-123", + State: StateInitialized, + SourceVersion: "v0.5.5", + TargetVersion: "v0.6.0", + StartedAt: now, + } + + err := sw.Write(data) + require.NoError(t, err) + + // Verify file exists and has correct content + content, err := os.ReadFile(statePath) + require.NoError(t, err) + + var readData StateData + err = json.Unmarshal(content, &readData) + require.NoError(t, err) + + assert.Equal(t, data.UpdateID, 
readData.UpdateID) + assert.Equal(t, data.State, readData.State) + assert.Equal(t, data.SourceVersion, readData.SourceVersion) + assert.Equal(t, data.TargetVersion, readData.TargetVersion) +} + +func TestStateWriter_Write_AtomicWrite(t *testing.T) { + tmpDir := t.TempDir() + statePath := filepath.Join(tmpDir, "state.json") + + sw := NewStateWriter(StateConfig{StatePath: statePath}) + + data := StateData{ + UpdateID: "test-123", + State: StateInitialized, + SourceVersion: "v0.5.5", + TargetVersion: "v0.6.0", + StartedAt: time.Now(), + } + + err := sw.Write(data) + require.NoError(t, err) + + // Check no temp files left behind + entries, err := os.ReadDir(tmpDir) + require.NoError(t, err) + + for _, entry := range entries { + assert.Equal(t, "state.json", entry.Name(), "only state.json should exist") + } +} + +func TestStateWriter_Write_OverwritesExisting(t *testing.T) { + tmpDir := t.TempDir() + statePath := filepath.Join(tmpDir, "state.json") + + sw := NewStateWriter(StateConfig{StatePath: statePath}) + + // Write first state + data1 := StateData{ + UpdateID: "first-update", + State: StateInitialized, + SourceVersion: "v0.5.5", + TargetVersion: "v0.6.0", + StartedAt: time.Now(), + } + err := sw.Write(data1) + require.NoError(t, err) + + // Write second state + completedAt := time.Now() + data2 := StateData{ + UpdateID: "first-update", + State: StateCompleted, + SourceVersion: "v0.5.5", + TargetVersion: "v0.6.0", + StartedAt: data1.StartedAt, + CompletedAt: &completedAt, + } + err = sw.Write(data2) + require.NoError(t, err) + + // Read and verify second state + readData, err := sw.Read() + require.NoError(t, err) + assert.Equal(t, StateCompleted, readData.State) + assert.NotNil(t, readData.CompletedAt) +} + +func TestStateWriter_Write_CreatesParentDir(t *testing.T) { + tmpDir := t.TempDir() + statePath := filepath.Join(tmpDir, "nested", "dir", "state.json") + + sw := NewStateWriter(StateConfig{StatePath: statePath}) + + data := StateData{ + UpdateID: "test-123", + 
State: StateInitialized, + SourceVersion: "v0.5.5", + TargetVersion: "v0.6.0", + StartedAt: time.Now(), + } + + err := sw.Write(data) + require.NoError(t, err) + + // Verify file exists + _, err = os.Stat(statePath) + assert.NoError(t, err) +} + +func TestStateWriter_Write_CorrectPermissions(t *testing.T) { + tmpDir := t.TempDir() + statePath := filepath.Join(tmpDir, "state.json") + + sw := NewStateWriter(StateConfig{StatePath: statePath}) + + data := StateData{ + UpdateID: "test-123", + State: StateInitialized, + SourceVersion: "v0.5.5", + TargetVersion: "v0.6.0", + StartedAt: time.Now(), + } + + err := sw.Write(data) + require.NoError(t, err) + + info, err := os.Stat(statePath) + require.NoError(t, err) + assert.Equal(t, os.FileMode(0600), info.Mode().Perm(), "file should have 0600 permissions") +} + +func TestStateWriter_Write_WithError(t *testing.T) { + tmpDir := t.TempDir() + statePath := filepath.Join(tmpDir, "state.json") + + sw := NewStateWriter(StateConfig{StatePath: statePath}) + + errMsg := "health check failed: version mismatch" + data := StateData{ + UpdateID: "test-123", + State: StateRolledBack, + SourceVersion: "v0.5.5", + TargetVersion: "v0.6.0", + StartedAt: time.Now(), + Error: &errMsg, + } + + err := sw.Write(data) + require.NoError(t, err) + + readData, err := sw.Read() + require.NoError(t, err) + require.NotNil(t, readData.Error) + assert.Equal(t, errMsg, *readData.Error) +} + +func TestStateWriter_Read_ExistingFile(t *testing.T) { + tmpDir := t.TempDir() + statePath := filepath.Join(tmpDir, "state.json") + + sw := NewStateWriter(StateConfig{StatePath: statePath}) + + // Write state first + now := time.Now().Truncate(time.Second) + data := StateData{ + UpdateID: "test-123", + State: StateCompleted, + SourceVersion: "v0.5.5", + TargetVersion: "v0.6.0", + StartedAt: now, + } + err := sw.Write(data) + require.NoError(t, err) + + // Read it back + readData, err := sw.Read() + require.NoError(t, err) + + assert.Equal(t, data.UpdateID, 
readData.UpdateID) + assert.Equal(t, data.State, readData.State) + assert.Equal(t, data.SourceVersion, readData.SourceVersion) + assert.Equal(t, data.TargetVersion, readData.TargetVersion) +} + +func TestStateWriter_Read_NonExistentFile(t *testing.T) { + tmpDir := t.TempDir() + statePath := filepath.Join(tmpDir, "state.json") + + sw := NewStateWriter(StateConfig{StatePath: statePath}) + + // Read non-existent file should return zero value + data, err := sw.Read() + require.NoError(t, err) + + assert.Equal(t, StateData{}, data) +} + +func TestStateWriter_Read_CorruptedFile(t *testing.T) { + tmpDir := t.TempDir() + statePath := filepath.Join(tmpDir, "state.json") + + // Write corrupted content + err := os.WriteFile(statePath, []byte("not valid json"), 0600) + require.NoError(t, err) + + sw := NewStateWriter(StateConfig{StatePath: statePath}) + + // Should return error for corrupted file + _, err = sw.Read() + assert.Error(t, err) +} + +func TestStateWriter_HumanReadableJSON(t *testing.T) { + tmpDir := t.TempDir() + statePath := filepath.Join(tmpDir, "state.json") + + sw := NewStateWriter(StateConfig{StatePath: statePath}) + + data := StateData{ + UpdateID: "test-123", + State: StateInitialized, + SourceVersion: "v0.5.5", + TargetVersion: "v0.6.0", + StartedAt: time.Now(), + } + + err := sw.Write(data) + require.NoError(t, err) + + // Read raw content and check it's indented (human-readable) + content, err := os.ReadFile(statePath) + require.NoError(t, err) + + // MarshalIndent produces newlines and spaces + assert.Contains(t, string(content), "\n", "JSON should be formatted with newlines") + assert.Contains(t, string(content), " ", "JSON should be indented") +} diff --git a/internal/versionutil/version.go b/internal/versionutil/version.go new file mode 100644 index 0000000..e18372e --- /dev/null +++ b/internal/versionutil/version.go @@ -0,0 +1,95 @@ +// Package versionutil provides utilities for parsing and comparing semantic versions. 
// Parse parses a semantic version string (e.g., "v1.2.3" or "1.2.3")
// and returns major, minor, patch components.
// Each component must be an unsigned decimal integer: signed forms such
// as "v1.-2.3" or "v1.+2.3" are rejected (strconv.Atoi alone would accept
// them), as are prerelease/build suffixes like "v1.2.3-beta".
func Parse(version string) (major, minor, patch int, err error) {
	if version == "" {
		return 0, 0, 0, fmt.Errorf("empty version string")
	}

	// Strip optional 'v' prefix.
	v := strings.TrimPrefix(version, "v")
	if v == "" {
		return 0, 0, 0, fmt.Errorf("invalid version format: %q", version)
	}

	parts := strings.Split(v, ".")
	if len(parts) != 3 {
		return 0, 0, 0, fmt.Errorf("invalid version format: %q (expected major.minor.patch)", version)
	}

	names := [3]string{"major", "minor", "patch"}
	var nums [3]int
	for i, part := range parts {
		n, perr := parseComponent(part)
		if perr != nil {
			return 0, 0, 0, fmt.Errorf("invalid %s version in %q: %w", names[i], version, perr)
		}
		nums[i] = n
	}
	return nums[0], nums[1], nums[2], nil
}

// parseComponent parses a single version component, accepting only strings
// made entirely of ASCII digits. This is stricter than strconv.Atoi, which
// also accepts a leading '+' or '-'.
func parseComponent(s string) (int, error) {
	if s == "" {
		return 0, fmt.Errorf("empty component")
	}
	for _, r := range s {
		if r < '0' || r > '9' {
			return 0, fmt.Errorf("non-numeric component %q", s)
		}
	}
	return strconv.Atoi(s)
}

// Compare compares two semantic version strings.
// Returns -1 if v1 < v2, 0 if v1 == v2, 1 if v1 > v2.
// Returns an error if either version string is invalid.
func Compare(v1, v2 string) (int, error) {
	major1, minor1, patch1, err := Parse(v1)
	if err != nil {
		return 0, fmt.Errorf("invalid version v1: %w", err)
	}

	major2, minor2, patch2, err := Parse(v2)
	if err != nil {
		return 0, fmt.Errorf("invalid version v2: %w", err)
	}

	// Compare component-wise, most significant first.
	a := [3]int{major1, minor1, patch1}
	b := [3]int{major2, minor2, patch2}
	for i := 0; i < 3; i++ {
		if a[i] != b[i] {
			if a[i] < b[i] {
				return -1, nil
			}
			return 1, nil
		}
	}
	return 0, nil
}

// IsNewer returns true if candidate version is newer than current version.
+// Returns an error if either version string is invalid. +func IsNewer(candidate, current string) (bool, error) { + cmp, err := Compare(candidate, current) + if err != nil { + return false, err + } + return cmp > 0, nil +} diff --git a/internal/versionutil/version_test.go b/internal/versionutil/version_test.go new file mode 100644 index 0000000..93bfece --- /dev/null +++ b/internal/versionutil/version_test.go @@ -0,0 +1,226 @@ +package versionutil + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParse_WithVPrefix(t *testing.T) { + major, minor, patch, err := Parse("v1.2.3") + require.NoError(t, err) + assert.Equal(t, 1, major) + assert.Equal(t, 2, minor) + assert.Equal(t, 3, patch) +} + +func TestParse_WithoutVPrefix(t *testing.T) { + major, minor, patch, err := Parse("1.2.3") + require.NoError(t, err) + assert.Equal(t, 1, major) + assert.Equal(t, 2, minor) + assert.Equal(t, 3, patch) +} + +func TestParse_ZeroVersion(t *testing.T) { + major, minor, patch, err := Parse("v0.6.0") + require.NoError(t, err) + assert.Equal(t, 0, major) + assert.Equal(t, 6, minor) + assert.Equal(t, 0, patch) +} + +func TestParse_LargeNumbers(t *testing.T) { + major, minor, patch, err := Parse("v10.20.30") + require.NoError(t, err) + assert.Equal(t, 10, major) + assert.Equal(t, 20, minor) + assert.Equal(t, 30, patch) +} + +func TestParse_InvalidFormat_TwoParts(t *testing.T) { + _, _, _, err := Parse("v1.2") + assert.Error(t, err) +} + +func TestParse_InvalidFormat_FourParts(t *testing.T) { + _, _, _, err := Parse("v1.2.3.4") + assert.Error(t, err) +} + +func TestParse_InvalidFormat_NonNumeric(t *testing.T) { + _, _, _, err := Parse("v1.2.abc") + assert.Error(t, err) +} + +func TestParse_InvalidFormat_Empty(t *testing.T) { + _, _, _, err := Parse("") + assert.Error(t, err) +} + +func TestParse_InvalidFormat_JustV(t *testing.T) { + _, _, _, err := Parse("v") + assert.Error(t, err) +} + +func 
TestCompare_FirstLessThanSecond(t *testing.T) { + tests := []struct { + v1, v2 string + }{ + {"v1.0.0", "v2.0.0"}, + {"v1.1.0", "v1.2.0"}, + {"v1.1.1", "v1.1.2"}, + {"v0.5.5", "v0.6.0"}, + {"v0.0.1", "v0.0.2"}, + } + + for _, tt := range tests { + t.Run(tt.v1+"_vs_"+tt.v2, func(t *testing.T) { + result, err := Compare(tt.v1, tt.v2) + require.NoError(t, err) + assert.Equal(t, -1, result, "%s should be less than %s", tt.v1, tt.v2) + }) + } +} + +func TestCompare_FirstGreaterThanSecond(t *testing.T) { + tests := []struct { + v1, v2 string + }{ + {"v2.0.0", "v1.0.0"}, + {"v1.2.0", "v1.1.0"}, + {"v1.1.2", "v1.1.1"}, + {"v0.6.0", "v0.5.5"}, + {"v1.0.0", "v0.99.99"}, + } + + for _, tt := range tests { + t.Run(tt.v1+"_vs_"+tt.v2, func(t *testing.T) { + result, err := Compare(tt.v1, tt.v2) + require.NoError(t, err) + assert.Equal(t, 1, result, "%s should be greater than %s", tt.v1, tt.v2) + }) + } +} + +func TestCompare_Equal(t *testing.T) { + tests := []struct { + v1, v2 string + }{ + {"v1.0.0", "v1.0.0"}, + {"v0.6.0", "v0.6.0"}, + {"1.2.3", "v1.2.3"}, // With and without v prefix + {"v1.2.3", "1.2.3"}, + } + + for _, tt := range tests { + t.Run(tt.v1+"_vs_"+tt.v2, func(t *testing.T) { + result, err := Compare(tt.v1, tt.v2) + require.NoError(t, err) + assert.Equal(t, 0, result, "%s should equal %s", tt.v1, tt.v2) + }) + } +} + +func TestCompare_HandlesMissingVPrefix(t *testing.T) { + result, err := Compare("1.2.3", "1.2.4") + require.NoError(t, err) + assert.Equal(t, -1, result) + + result, err = Compare("1.2.4", "1.2.3") + require.NoError(t, err) + assert.Equal(t, 1, result) +} + +func TestCompare_ErrorOnInvalidV1(t *testing.T) { + _, err := Compare("invalid", "v1.0.0") + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid version v1") +} + +func TestCompare_ErrorOnInvalidV2(t *testing.T) { + _, err := Compare("v1.0.0", "invalid") + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid version v2") +} + +func TestCompare_ErrorOnMalformedVersions(t 
*testing.T) { + tests := []struct { + name string + v1, v2 string + }{ + {"two parts", "v1.2", "v1.0.0"}, + {"prerelease suffix", "v1.2.3-beta", "v1.0.0"}, + {"build metadata", "v1.2.3+build", "v1.0.0"}, + {"empty string", "", "v1.0.0"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := Compare(tt.v1, tt.v2) + assert.Error(t, err, "Compare(%q, %q) should return error", tt.v1, tt.v2) + }) + } +} + +func TestIsNewer_True(t *testing.T) { + tests := []struct { + candidate, current string + }{ + {"v0.6.0", "v0.5.5"}, + {"v1.0.0", "v0.9.9"}, + {"v2.0.0", "v1.99.99"}, + {"v1.1.0", "v1.0.0"}, + {"v1.0.1", "v1.0.0"}, + } + + for _, tt := range tests { + t.Run(tt.candidate+"_newer_than_"+tt.current, func(t *testing.T) { + result, err := IsNewer(tt.candidate, tt.current) + require.NoError(t, err) + assert.True(t, result, "%s should be newer than %s", tt.candidate, tt.current) + }) + } +} + +func TestIsNewer_False_Older(t *testing.T) { + tests := []struct { + candidate, current string + }{ + {"v0.5.5", "v0.6.0"}, + {"v0.9.9", "v1.0.0"}, + {"v1.0.0", "v1.0.1"}, + } + + for _, tt := range tests { + t.Run(tt.candidate+"_not_newer_than_"+tt.current, func(t *testing.T) { + result, err := IsNewer(tt.candidate, tt.current) + require.NoError(t, err) + assert.False(t, result, "%s should not be newer than %s", tt.candidate, tt.current) + }) + } +} + +func TestIsNewer_False_Equal(t *testing.T) { + result, err := IsNewer("v0.6.0", "v0.6.0") + require.NoError(t, err) + assert.False(t, result, "same version should not be newer") +} + +func TestIsNewer_ErrorOnInvalidCandidate(t *testing.T) { + _, err := IsNewer("v1.2", "v1.0.0") + assert.Error(t, err) +} + +func TestIsNewer_ErrorOnInvalidCurrent(t *testing.T) { + _, err := IsNewer("v1.0.0", "v1.2") + assert.Error(t, err) +} + +func TestIsNewer_ErrorOnPrereleaseVersion(t *testing.T) { + // Prerelease versions like "v1.2.3-beta" should return error + // because our simple parser doesn't support them + _, err 
:= IsNewer("v1.2.3-beta", "v1.0.0") + assert.Error(t, err, "prerelease versions should return error") +} diff --git a/scripts/linux/hostlink.service b/scripts/linux/hostlink.service index 5dcfefd..ab70a7a 100644 --- a/scripts/linux/hostlink.service +++ b/scripts/linux/hostlink.service @@ -4,11 +4,12 @@ After=network-online.target [Service] Type=simple +KillMode=process WorkingDirectory=/usr/bin/ EnvironmentFile=/etc/hostlink/hostlink.env ExecStart=/usr/bin/hostlink Restart=always -RestartSec=60 +RestartSec=90 [Install] WantedBy=multi-user.target diff --git a/test/integration/selfupdate_test.go b/test/integration/selfupdate_test.go new file mode 100644 index 0000000..4e29d0b --- /dev/null +++ b/test/integration/selfupdate_test.go @@ -0,0 +1,260 @@ +//go:build integration +// +build integration + +package integration + +import ( + "context" + "encoding/json" + "os" + "os/exec" + "path/filepath" + "syscall" + "testing" + "time" + + "hostlink/internal/update" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// buildUpdaterBinary compiles the hostlink-updater binary into the given directory. +// Returns the path to the compiled binary. +func buildUpdaterBinary(t *testing.T, outputDir string) string { + t.Helper() + binaryPath := filepath.Join(outputDir, "hostlink-updater") + cmd := exec.Command("go", "build", "-o", binaryPath, "./cmd/updater") + cmd.Dir = findProjectRoot(t) + output, err := cmd.CombinedOutput() + require.NoError(t, err, "failed to build hostlink-updater: %s", string(output)) + return binaryPath +} + +// findProjectRoot returns the project root directory by looking for go.mod. 
+func findProjectRoot(t *testing.T) string { + t.Helper() + dir, err := os.Getwd() + require.NoError(t, err) + for { + if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil { + return dir + } + parent := filepath.Dir(dir) + if parent == dir { + t.Fatal("could not find project root (no go.mod found)") + } + dir = parent + } +} + +// setupUpdateDirs creates the update directory structure in a temp dir. +// Returns the base directory path. +func setupUpdateDirs(t *testing.T) string { + t.Helper() + baseDir := t.TempDir() + err := update.InitDirectories(baseDir) + require.NoError(t, err) + return baseDir +} + +// writeStateFile writes a state.json file into the base directory. +func writeStateFile(t *testing.T, baseDir string, data update.StateData) { + t.Helper() + paths := update.NewPaths(baseDir) + sw := update.NewStateWriter(update.StateConfig{StatePath: paths.StateFile}) + err := sw.Write(data) + require.NoError(t, err) +} + +// readStateFile reads and returns the state.json contents from the base directory. 
+func readStateFile(t *testing.T, baseDir string) update.StateData { + t.Helper() + paths := update.NewPaths(baseDir) + sw := update.NewStateWriter(update.StateConfig{StatePath: paths.StateFile}) + data, err := sw.Read() + require.NoError(t, err) + return data +} + +func TestSelfUpdate_LockPreventsConcurrentUpdates(t *testing.T) { + baseDir := setupUpdateDirs(t) + paths := update.NewPaths(baseDir) + + // Acquire lock from this process + lock := update.NewLockManager(update.LockConfig{LockPath: paths.LockFile}) + err := lock.TryLock(5 * time.Minute) + require.NoError(t, err) + defer lock.Unlock() + + // Build the updater binary + binDir := t.TempDir() + updaterBin := buildUpdaterBinary(t, binDir) + + // Write state file so updater can read target version + writeStateFile(t, baseDir, update.StateData{ + State: update.StateStaged, + TargetVersion: "1.2.3", + }) + + // Attempt to run the updater — it should fail because the lock is held + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + cmd := exec.CommandContext(ctx, updaterBin, + "-base-dir", baseDir, + "-version", "1.2.3", + ) + output, err := cmd.CombinedOutput() + require.Error(t, err, "updater should fail when lock is held") + assert.Contains(t, string(output), "lock", "error should mention lock") +} + +func TestSelfUpdate_SignalHandlingDuringUpdate(t *testing.T) { + baseDir := setupUpdateDirs(t) + + // Build the updater binary + binDir := t.TempDir() + updaterBin := buildUpdaterBinary(t, binDir) + + // Write state file with target version + writeStateFile(t, baseDir, update.StateData{ + State: update.StateStaged, + TargetVersion: "1.2.3", + }) + + // Start the updater process — it will try to stop the service via systemctl + // which will take time / fail, giving us time to send a signal + cmd := exec.Command(updaterBin, + "-base-dir", baseDir, + "-version", "1.2.3", + "-binary", "/nonexistent/hostlink", // non-existent binary, will fail at stop + ) + cmd.SysProcAttr 
= &syscall.SysProcAttr{Setpgid: true} + + err := cmd.Start() + require.NoError(t, err, "should start updater process") + + // Give it a moment to start + time.Sleep(100 * time.Millisecond) + + // Send SIGTERM + err = cmd.Process.Signal(syscall.SIGTERM) + require.NoError(t, err, "should send SIGTERM") + + // Wait for exit — should exit within a few seconds + done := make(chan error, 1) + go func() { done <- cmd.Wait() }() + + select { + case err := <-done: + // Process exited (may be non-zero exit code due to cancellation, that's ok) + if err != nil { + // Verify it's an exit error, not something unexpected + _, ok := err.(*exec.ExitError) + assert.True(t, ok, "expected ExitError after signal, got: %v", err) + } + case <-time.After(10 * time.Second): + cmd.Process.Kill() + t.Fatal("updater did not exit within 10 seconds after SIGTERM") + } +} + +func TestSelfUpdate_UpdaterWritesStateOnLockFailure(t *testing.T) { + baseDir := setupUpdateDirs(t) + paths := update.NewPaths(baseDir) + + // Acquire lock from this process to block the updater + lock := update.NewLockManager(update.LockConfig{LockPath: paths.LockFile}) + err := lock.TryLock(5 * time.Minute) + require.NoError(t, err) + defer lock.Unlock() + + // Build the updater + binDir := t.TempDir() + updaterBin := buildUpdaterBinary(t, binDir) + + // Run the updater — should fail on lock acquisition + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + cmd := exec.CommandContext(ctx, updaterBin, + "-base-dir", baseDir, + "-version", "2.0.0", + ) + _, err = cmd.CombinedOutput() + require.Error(t, err, "updater should fail when lock is held") + + // The lock file should still be held by us (not stolen) + lockContent, err := os.ReadFile(paths.LockFile) + require.NoError(t, err) + + var lockData struct { + PID int `json:"pid"` + } + err = json.Unmarshal(lockContent, &lockData) + require.NoError(t, err) + assert.Equal(t, os.Getpid(), lockData.PID, "lock should still be held by test 
process") +} + +func TestSelfUpdate_UpdaterExitsWithErrorForMissingVersion(t *testing.T) { + baseDir := setupUpdateDirs(t) + + // Build the updater + binDir := t.TempDir() + updaterBin := buildUpdaterBinary(t, binDir) + + // Run without -version and without state file — should fail + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + cmd := exec.CommandContext(ctx, updaterBin, + "-base-dir", baseDir, + ) + output, err := cmd.CombinedOutput() + require.Error(t, err, "updater should fail without version") + assert.Contains(t, string(output), "version", "error should mention version") +} + +func TestSelfUpdate_UpdaterReadsVersionFromState(t *testing.T) { + baseDir := setupUpdateDirs(t) + + // Write state file with target version + writeStateFile(t, baseDir, update.StateData{ + State: update.StateStaged, + TargetVersion: "3.0.0", + }) + + // Build the updater + binDir := t.TempDir() + updaterBin := buildUpdaterBinary(t, binDir) + + // Run without -version flag but with state file containing version + // It should read the version from state and proceed (then fail at systemctl stop) + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + cmd := exec.CommandContext(ctx, updaterBin, + "-base-dir", baseDir, + "-binary", "/nonexistent/hostlink", + ) + output, err := cmd.CombinedOutput() + // Should fail (no systemctl), but NOT because of missing version + require.Error(t, err) + assert.NotContains(t, string(output), "target version is required", + "should have read version from state file") +} + +func TestSelfUpdate_UpdaterPrintVersion(t *testing.T) { + // Build the updater + binDir := t.TempDir() + updaterBin := buildUpdaterBinary(t, binDir) + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + cmd := exec.CommandContext(ctx, updaterBin, "-v") + output, err := cmd.CombinedOutput() + require.NoError(t, err, "version flag should not fail") + 
assert.Contains(t, string(output), "hostlink-updater", "should print version info") +}