From a6b13fc3ac851733ce1e036701a0c328ef6dae11 Mon Sep 17 00:00:00 2001 From: Qingqing Zheng Date: Thu, 19 Feb 2026 13:41:21 -0800 Subject: [PATCH 1/5] add k8s version drift detection and remediation --- commands.go | 269 +++++++++++++----- docs/usage.md | 19 +- .../kube_binaries/kube_binaries_installer.go | 13 +- pkg/components/kubelet/kubelet_installer.go | 20 +- .../services/kubelet_only_installer.go | 58 ++++ .../services/kubelet_only_uninstaller.go | 49 ++++ pkg/config/agent_flags_test.go | 30 ++ pkg/config/config.go | 4 + pkg/config/config_test.go | 16 ++ pkg/config/copy.go | 52 ++++ pkg/config/copy_test.go | 123 ++++++++ pkg/config/structs.go | 18 ++ pkg/drift/defaults.go | 11 + pkg/drift/defaults_test.go | 15 + pkg/drift/detector.go | 80 ++++++ pkg/drift/detector_test.go | 49 ++++ pkg/drift/kubernetes_version.go | 83 ++++++ pkg/drift/kubernetes_version_test.go | 122 ++++++++ pkg/drift/remediation.go | 236 +++++++++++++++ pkg/drift/remediation_test.go | 177 ++++++++++++ pkg/drift/version.go | 55 ++++ pkg/drift/version_test.go | 74 +++++ pkg/spec/kubernetes_version_override.go | 57 ++++ pkg/spec/kubernetes_version_override_test.go | 87 ++++++ pkg/spec/loader.go | 31 ++ pkg/spec/paths.go | 7 +- pkg/spec/paths_test.go | 14 + pkg/spec/remove.go | 26 ++ pkg/spec/remove_test.go | 36 +++ pkg/status/health.go | 53 ++++ pkg/status/health_test.go | 43 +++ pkg/status/loader.go | 31 ++ pkg/status/loader_writer_test.go | 80 ++++++ pkg/status/remove.go | 37 +++ pkg/status/remove_test.go | 41 +++ pkg/status/types.go | 22 +- pkg/status/writer.go | 37 +++ 37 files changed, 2092 insertions(+), 83 deletions(-) create mode 100644 pkg/components/services/kubelet_only_installer.go create mode 100644 pkg/components/services/kubelet_only_uninstaller.go create mode 100644 pkg/config/agent_flags_test.go create mode 100644 pkg/config/copy.go create mode 100644 pkg/config/copy_test.go create mode 100644 pkg/drift/defaults.go create mode 100644 pkg/drift/defaults_test.go create mode 100644 pkg/drift/detector.go create mode 100644 pkg/drift/detector_test.go create mode 100644 pkg/drift/kubernetes_version.go create mode 100644 pkg/drift/kubernetes_version_test.go create mode 100644 pkg/drift/remediation.go create mode 100644 pkg/drift/remediation_test.go create mode 100644 pkg/drift/version.go create mode 100644 pkg/drift/version_test.go create mode 100644 pkg/spec/kubernetes_version_override.go create mode 100644 pkg/spec/kubernetes_version_override_test.go create mode 100644 pkg/spec/loader.go create mode 100644 pkg/spec/paths_test.go create mode 100644 pkg/spec/remove.go create mode 100644 pkg/spec/remove_test.go create mode 100644 pkg/status/health.go create mode 100644 pkg/status/health_test.go create mode 100644 pkg/status/loader.go create mode 100644 pkg/status/loader_writer_test.go create mode 100644 pkg/status/remove.go create mode 100644 pkg/status/remove_test.go create mode 100644 pkg/status/writer.go diff --git a/commands.go b/commands.go index bf8d798..4933ca3 100644 --- a/commands.go +++ b/commands.go @@ -2,10 +2,11 @@ package main import ( "context" - "encoding/json" "fmt" "os" "path/filepath" + "sync" + "sync/atomic" "time" "github.com/sirupsen/logrus" @@ -13,6 +14,7 @@ import ( "go.goms.io/aks/AKSFlexNode/pkg/bootstrapper" "go.goms.io/aks/AKSFlexNode/pkg/config" + "go.goms.io/aks/AKSFlexNode/pkg/drift" "go.goms.io/aks/AKSFlexNode/pkg/logger" "go.goms.io/aks/AKSFlexNode/pkg/spec" "go.goms.io/aks/AKSFlexNode/pkg/status" @@ -132,64 +134,205 @@ func runDaemonLoop(ctx context.Context, cfg 
*config.Config) error {
 	// Clean up any stale status file on daemon startup
 	if _, err := os.Stat(statusFilePath); err == nil {
 		logger.Info("Removing stale status file from previous daemon session...")
-		if err := os.Remove(statusFilePath); err != nil {
-			logger.Warnf("Failed to remove stale status file: %v", err)
-		} else {
-			logger.Info("Stale status file removed successfully")
-		}
+		status.RemoveStatusFileBestEffortAtPath(logger, statusFilePath)
+	}
+
+	// Always clean up any stale managed cluster spec snapshot on daemon startup.
+	// The snapshot is best-effort and should not be relied upon across sessions.
+	removed, err := spec.RemoveManagedClusterSpecSnapshot()
+	if err != nil {
+		logger.Warnf("Failed to remove stale managed cluster spec snapshot: %v", err)
+	} else if removed {
+		logger.Info("Removed stale managed cluster spec snapshot successfully")
 	}
 
-	logger.Info("Starting periodic status collection daemon (status: 1 minutes, bootstrap check: 2 minute)")
+	logger.Info("Starting periodic status collection daemon (status: 1 minute, bootstrap check: 2 minutes, spec collection + drift: 5 minutes)...")
+
+	// Protect cfg reads/writes across concurrent loops. This avoids data races when we
+	// temporarily update cfg.Kubernetes.Version to trigger drift remediation bootstrap.
+	var cfgMu sync.RWMutex
 
-	// Create tickers for different intervals
-	statusTicker := time.NewTicker(1 * time.Minute)
-	bootstrapTicker := time.NewTicker(2 * time.Minute)
-	specTicker := time.NewTicker(30 * time.Minute)
-	defer statusTicker.Stop()
-	defer bootstrapTicker.Stop()
-	defer specTicker.Stop()
+	// Guard to prevent overlapping bootstrap runs across loops.
+	var bootstrapInProgress int32
 
 	// Collect status immediately on start
 	if err := collectAndWriteStatus(ctx, cfg, statusFilePath); err != nil {
 		logger.Errorf("Failed to collect initial status: %v", err)
 	}
 
-	// Collect managed cluster spec once on daemon startup.
-	if err := collectAndWriteManagedClusterSpec(ctx, cfg); err != nil {
-		logger.Warnf("Failed to collect initial managed cluster spec: %v", err)
+	driftEnabled := cfg != nil && cfg.IsDriftDetectionAndRemediationEnabled()
+	if !driftEnabled {
+		logger.Info("Drift detection and remediation is disabled by config")
+	}
+
+	var detectors []drift.Detector
+	if driftEnabled {
+		// Initialize drift detectors and collect the initial managed cluster spec
+		// before starting the loops, so the drift loop can run on schedule without
+		// waiting for the first spec collection interval.
+		detectors = drift.DefaultDetectors()
+		// Collect managed cluster spec once on daemon startup.
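+		// The snapshot collected here feeds the immediate drift pass below, so a
+		// version drift that existed before the daemon started is remediated right
+		// away rather than on the first periodic tick.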
+ if err := collectAndWriteManagedClusterSpec(ctx, cfg); err != nil { + logger.Warnf("Failed to collect initial managed cluster spec: %v", err) + } else { + cfgSnap := snapshotConfig(cfg, &cfgMu) + if err := drift.DetectAndRemediateFromFiles(ctx, cfgSnap, logger, &bootstrapInProgress, detectors); err != nil { + logger.Warnf("Initial drift detection after spec collection failed: %v", err) + } + } + } + + var wg sync.WaitGroup + startDaemonLoops(ctx, cfg, statusFilePath, logger, &cfgMu, &bootstrapInProgress, detectors, driftEnabled, &wg) + + <-ctx.Done() + logger.Info("Daemon shutting down due to context cancellation") + wg.Wait() + return ctx.Err() +} + +func startDaemonLoops( + ctx context.Context, + cfg *config.Config, + statusFilePath string, + logger *logrus.Logger, + cfgMu *sync.RWMutex, + bootstrapInProgress *int32, + detectors []drift.Detector, + driftEnabled bool, + wg *sync.WaitGroup, +) { + if wg == nil { + return + } + if driftEnabled { + wg.Add(3) + } else { + wg.Add(2) + } + startStatusCollectionLoop(ctx, cfg, statusFilePath, logger, cfgMu, wg) + startBootstrapHealthCheckLoop(ctx, cfg, logger, cfgMu, bootstrapInProgress, wg) + if driftEnabled { + startNodeDriftDetectionAndRemediationLoop(ctx, cfg, logger, cfgMu, bootstrapInProgress, detectors, wg) + } +} + +func snapshotConfig(cfg *config.Config, cfgMu *sync.RWMutex) *config.Config { + if cfg == nil { + return nil + } + if cfgMu != nil { + cfgMu.RLock() + defer cfgMu.RUnlock() } + return cfg.DeepCopy() +} - // Run the periodic collection and monitoring loop - for { - select { - case <-ctx.Done(): - logger.Info("Daemon shutting down due to context cancellation") - return ctx.Err() - case <-statusTicker.C: - logger.Infof("Starting periodic status collection at %s...", time.Now().Format("2006-01-02 15:04:05")) - if err := collectAndWriteStatus(ctx, cfg, statusFilePath); err != nil { - logger.Errorf("Failed to collect status at %s: %v", time.Now().Format("2006-01-02 15:04:05"), err) - // Continue running even if status collection fails - } else { +func startStatusCollectionLoop( + ctx context.Context, + cfg *config.Config, + statusFilePath string, + logger *logrus.Logger, + cfgMu *sync.RWMutex, + wg *sync.WaitGroup, +) { + go func() { + defer wg.Done() + ticker := time.NewTicker(1 * time.Minute) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + now := time.Now() + logger.Infof("Starting periodic status collection at %s...", now.Format("2006-01-02 15:04:05")) + cfgSnap := snapshotConfig(cfg, cfgMu) + err := collectAndWriteStatus(ctx, cfgSnap, statusFilePath) + if err != nil { + logger.Errorf("Failed to collect status at %s: %v", now.Format("2006-01-02 15:04:05"), err) + continue + } logger.Infof("Status collection completed successfully at %s", time.Now().Format("2006-01-02 15:04:05")) } - case <-bootstrapTicker.C: - logger.Infof("Starting bootstrap health check at %s...", time.Now().Format("2006-01-02 15:04:05")) - if err := checkAndBootstrap(ctx, cfg); err != nil { - logger.Errorf("Auto-bootstrap check failed at %s: %v", time.Now().Format("2006-01-02 15:04:05"), err) - // Continue running even if bootstrap check fails - } else { - logger.Infof("Bootstrap health check completed at %s", time.Now().Format("2006-01-02 15:04:05")) + } + }() +} + +func startBootstrapHealthCheckLoop( + ctx context.Context, + cfg *config.Config, + logger *logrus.Logger, + cfgMu *sync.RWMutex, + bootstrapInProgress *int32, + wg *sync.WaitGroup, +) { + go func() { + defer wg.Done() + ticker := time.NewTicker(2 * 
time.Minute) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + now := time.Now() + logger.Infof("Starting bootstrap health check at %s...", now.Format("2006-01-02 15:04:05")) + + if !atomic.CompareAndSwapInt32(bootstrapInProgress, 0, 1) { + logger.Warn("Bootstrap already in progress, skipping this interval") + continue + } + func() { + defer atomic.StoreInt32(bootstrapInProgress, 0) + cfgSnap := snapshotConfig(cfg, cfgMu) + err := checkAndBootstrap(ctx, cfgSnap) + if err != nil { + logger.Errorf("Auto-bootstrap check failed at %s: %v", now.Format("2006-01-02 15:04:05"), err) + return + } + logger.Infof("Bootstrap health check completed at %s", time.Now().Format("2006-01-02 15:04:05")) + }() } - case <-specTicker.C: - logger.Infof("Starting periodic managed cluster spec collection at %s...", time.Now().Format("2006-01-02 15:04:05")) - if err := collectAndWriteManagedClusterSpec(ctx, cfg); err != nil { - logger.Warnf("Failed to collect managed cluster spec at %s: %v", time.Now().Format("2006-01-02 15:04:05"), err) - } else { + } + }() +} + +func startNodeDriftDetectionAndRemediationLoop( + ctx context.Context, + cfg *config.Config, + logger *logrus.Logger, + cfgMu *sync.RWMutex, + bootstrapInProgress *int32, + detectors []drift.Detector, + wg *sync.WaitGroup, +) { + go func() { + defer wg.Done() + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + now := time.Now() + logger.Infof("Starting periodic managed cluster spec collection at %s...", now.Format("2006-01-02 15:04:05")) + cfgSnap := snapshotConfig(cfg, cfgMu) + err := collectAndWriteManagedClusterSpec(ctx, cfgSnap) + if err != nil { + logger.Warnf("Failed to collect managed cluster spec at %s: %v", now.Format("2006-01-02 15:04:05"), err) + continue + } logger.Infof("Managed cluster spec collection completed at %s", time.Now().Format("2006-01-02 15:04:05")) + + // Run drift detection immediately after spec is updated so we don't wait. + if err := drift.DetectAndRemediateFromFiles(ctx, cfgSnap, logger, bootstrapInProgress, detectors); err != nil { + logger.Warnf("Drift detection after spec collection failed at %s: %v", time.Now().Format("2006-01-02 15:04:05"), err) + } else { + logger.Infof("Drift detection after spec collection completed at %s", time.Now().Format("2006-01-02 15:04:05")) + } } } - } + }() } func collectAndWriteManagedClusterSpec(ctx context.Context, cfg *config.Config) error { @@ -213,19 +356,28 @@ func checkAndBootstrap(ctx context.Context, cfg *config.Config) error { logger.Info("Node requires re-bootstrapping, initiating auto-bootstrap...") + if cfg != nil && cfg.IsDriftDetectionAndRemediationEnabled() { + // Best-effort: prefer Kubernetes version from the persisted managed cluster spec snapshot. + // This keeps auto-bootstrap aligned with the cluster desired version even if the static + // config has an older value. 
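+		// For example (versions illustrative): with config.json pinned to 1.29.0 and
+		// a snapshot reporting 1.30.7, the bootstrap below proceeds with 1.30.7.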
+		if changed, oldV, newV, err := spec.OverrideKubernetesVersionFromManagedClusterSpec(cfg); err == nil && changed {
+			logger.Infof("Overriding Kubernetes version from managed cluster spec: %q -> %q", oldV, newV)
+		}
+	}
+
 	// Perform bootstrap
 	bootstrapExecutor := bootstrapper.New(cfg, logger)
 	result, err := bootstrapExecutor.Bootstrap(ctx)
 	if err != nil {
 		// Bootstrap failed - remove status file so next check will detect the problem
-		removeStatusFile(ctx)
+		status.RemoveStatusFileBestEffort(logger)
 		return fmt.Errorf("auto-bootstrap failed: %s", err)
 	}
 
 	// Handle and log the bootstrap result
 	if err := handleExecutionResult(result, "auto-bootstrap", logger); err != nil {
 		// Bootstrap execution failed - remove status file so next check will detect the problem
-		removeStatusFile(ctx)
+		status.RemoveStatusFileBestEffort(logger)
 		return fmt.Errorf("auto-bootstrap execution failed: %s", err)
 	}
@@ -233,16 +385,6 @@ func checkAndBootstrap(ctx context.Context, cfg *config.Config) error {
 	return nil
 }
 
-func removeStatusFile(ctx context.Context) {
-	logger := logger.GetLoggerFromContext(ctx)
-	statusFilePath := status.GetStatusFilePath()
-	if removeErr := os.Remove(statusFilePath); removeErr != nil {
-		logger.Debugf("Failed to remove status file: %s", removeErr)
-	} else {
-		logger.Debug("Removed status file successfully")
-	}
-}
-
 // collectAndWriteStatus collects current node status and writes it to the status file
 func collectAndWriteStatus(ctx context.Context, cfg *config.Config, statusFilePath string) error {
 	logger := logger.GetLoggerFromContext(ctx)
@@ -255,23 +397,16 @@ func collectAndWriteStatus(ctx context.Context, cfg *config.Config, statusFilePa
 	if err != nil {
 		return fmt.Errorf("failed to collect node status: %w", err)
 	}
+	if nodeStatus != nil {
+		nodeStatus.LastUpdatedBy = status.LastUpdatedByStatusCollectionLoop
+		nodeStatus.LastUpdatedReason = status.LastUpdatedReasonPeriodicStatusLoop
+	}
 
 	// Write status to JSON file
-	statusData, err := json.MarshalIndent(nodeStatus, "", "  ")
+	err = status.WriteStatusToFile(statusFilePath, nodeStatus)
 	if err != nil {
-		return fmt.Errorf("failed to marshal status to JSON: %w", err)
-	}
-
-	// Write to temporary file first, then rename (atomic operation)
-	tempFile := statusFilePath + ".tmp"
-	if err := os.WriteFile(tempFile, statusData, 0o600); err != nil {
-		return fmt.Errorf("failed to write status to temp file: %w", err)
+		return fmt.Errorf("failed to write status to file: %w", err)
 	}
-
-	if err := os.Rename(tempFile, statusFilePath); err != nil {
-		return fmt.Errorf("failed to rename temp status file: %w", err)
-	}
-
 	logger.Debugf("Status written to %s", statusFilePath)
 	return nil
 }
diff --git a/docs/usage.md b/docs/usage.md
index 257813c..6b0bb54 100644
--- a/docs/usage.md
+++ b/docs/usage.md
@@ -131,7 +131,8 @@ sudo tee /etc/aks-flex-node/config.json > /dev/null << 'EOF'
   },
   "agent": {
     "logLevel": "info",
-    "logDir": "/var/log/aks-flex-node"
+    "logDir": "/var/log/aks-flex-node",
+    "enableDriftDetectionAndRemediation": true
   }
 }
 EOF
@@ -325,7 +326,8 @@ sudo tee /etc/aks-flex-node/config.json > /dev/null << 'EOF'
   },
   "agent": {
     "logLevel": "info",
-    "logDir": "/var/log/aks-flex-node"
+    "logDir": "/var/log/aks-flex-node",
+    "enableDriftDetectionAndRemediation": true
   }
 }
 EOF
diff --git a/pkg/config/agent_flags_test.go b/pkg/config/agent_flags_test.go
new file mode 100644
--- /dev/null
+++ b/pkg/config/agent_flags_test.go
@@ -0,0 +1,30 @@
+package config
+
+import "testing"
+
+func boolPtr(v bool) *bool { return &v }
+
+func TestIsDriftDetectionAndRemediationEnabled(t *testing.T) {
+	t.Parallel()
+
+	// nil flag -> enabled.
+ cfg := &Config{} + if !cfg.IsDriftDetectionAndRemediationEnabled() { + t.Fatalf("nil EnableDriftDetectionAndRemediation should default to true") + } + + cfg.Agent.EnableDriftDetectionAndRemediation = boolPtr(true) + if !cfg.IsDriftDetectionAndRemediationEnabled() { + t.Fatalf("flag true should return true") + } + + cfg.Agent.EnableDriftDetectionAndRemediation = boolPtr(false) + if cfg.IsDriftDetectionAndRemediationEnabled() { + t.Fatalf("flag false should return false") + } +} diff --git a/pkg/config/config.go b/pkg/config/config.go index 498d91e..b2a207c 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -111,6 +111,10 @@ func (c *Config) setAgentDefaults() { if c.Agent.LogDir == "" { c.Agent.LogDir = defaultLogDir } + if c.Agent.EnableDriftDetectionAndRemediation == nil { + enabled := true + c.Agent.EnableDriftDetectionAndRemediation = &enabled + } } func (c *Config) setPathDefaults() { diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 321ef76..c995338 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -17,9 +17,11 @@ func TestSetDefaults(t *testing.T) { name: "empty config gets all defaults", config: &Config{}, want: func(c *Config) bool { + driftEnabled := c.IsDriftDetectionAndRemediationEnabled() return c.Azure.Cloud == "AzurePublicCloud" && c.Agent.LogLevel == "info" && c.Agent.LogDir == "/var/log/aks-flex-node" && + driftEnabled && c.Paths.Kubernetes.ConfigDir == "/etc/kubernetes" && c.Node.MaxPods == 110 && c.Runc.Version == "1.1.12" @@ -41,6 +43,20 @@ func TestSetDefaults(t *testing.T) { c.Agent.LogDir == "/custom/log/dir" }, }, + { + name: "drift detection default is enabled", + config: &Config{}, + want: func(c *Config) bool { + return c.IsDriftDetectionAndRemediationEnabled() + }, + }, + { + name: "drift detection can be disabled", + config: &Config{Agent: AgentConfig{EnableDriftDetectionAndRemediation: func() *bool { v := false; return &v }()}}, + want: func(c *Config) bool { + return !c.IsDriftDetectionAndRemediationEnabled() + }, + }, { name: "node kubelet defaults are set correctly", config: &Config{ diff --git a/pkg/config/copy.go b/pkg/config/copy.go new file mode 100644 index 0000000..67ba158 --- /dev/null +++ b/pkg/config/copy.go @@ -0,0 +1,52 @@ +package config + +func cloneStringMap(in map[string]string) map[string]string { + if in == nil { + return nil + } + out := make(map[string]string, len(in)) + for k, v := range in { + out[k] = v + } + return out +} + +// DeepCopy returns a copy of the config that does not share mutable sub-objects (maps/pointers) +// with the original. +func (cfg *Config) DeepCopy() *Config { + if cfg == nil { + return nil + } + + out := *cfg + + // Copy pointer sub-structs under Azure. + if cfg.Azure.ServicePrincipal != nil { + sp := *cfg.Azure.ServicePrincipal + out.Azure.ServicePrincipal = &sp + } + if cfg.Azure.ManagedIdentity != nil { + mi := *cfg.Azure.ManagedIdentity + out.Azure.ManagedIdentity = &mi + } + if cfg.Azure.BootstrapToken != nil { + bt := *cfg.Azure.BootstrapToken + out.Azure.BootstrapToken = &bt + } + if cfg.Azure.TargetCluster != nil { + tc := *cfg.Azure.TargetCluster + out.Azure.TargetCluster = &tc + } + if cfg.Azure.Arc != nil { + arc := *cfg.Azure.Arc + arc.Tags = cloneStringMap(cfg.Azure.Arc.Tags) + out.Azure.Arc = &arc + } + + // Copy node-level maps. 
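+	// (If Config ever grows slice-typed fields they will need the same explicit
+	// cloning here; a plain struct copy shares slice backing arrays.)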
+ out.Node.Labels = cloneStringMap(cfg.Node.Labels) + out.Node.Kubelet.KubeReserved = cloneStringMap(cfg.Node.Kubelet.KubeReserved) + out.Node.Kubelet.EvictionHard = cloneStringMap(cfg.Node.Kubelet.EvictionHard) + + return &out +} diff --git a/pkg/config/copy_test.go b/pkg/config/copy_test.go new file mode 100644 index 0000000..ec4999b --- /dev/null +++ b/pkg/config/copy_test.go @@ -0,0 +1,123 @@ +package config + +import "testing" + +func TestCloneStringMap(t *testing.T) { + t.Parallel() + + if got := cloneStringMap(nil); got != nil { + t.Fatalf("cloneStringMap(nil)=%v, want nil", got) + } + + in := map[string]string{"a": "1"} + out := cloneStringMap(in) + if out["a"] != "1" { + t.Fatalf("cloneStringMap value=%q, want %q", out["a"], "1") + } + + // Mutate input; output should not change. + in["a"] = "2" + if out["a"] != "1" { + t.Fatalf("cloneStringMap shares backing map; out[a]=%q, want %q", out["a"], "1") + } + + // Mutate output; input should not change. + out["a"] = "3" + if in["a"] != "2" { + t.Fatalf("cloneStringMap shares backing map; in[a]=%q, want %q", in["a"], "2") + } +} + +func TestConfigDeepCopy_Nil(t *testing.T) { + t.Parallel() + + var cfg *Config + if got := cfg.DeepCopy(); got != nil { + t.Fatalf("DeepCopy()=%v, want nil", got) + } +} + +func TestConfigDeepCopy_DoesNotSharePointersOrMaps(t *testing.T) { + t.Parallel() + + falseVal := false + cfg := &Config{ + Azure: AzureConfig{ + ServicePrincipal: &ServicePrincipalConfig{TenantID: "t", ClientID: "c", ClientSecret: "s"}, + ManagedIdentity: &ManagedIdentityConfig{ClientID: "mi"}, + BootstrapToken: &BootstrapTokenConfig{Token: "abcdef.0123456789abcdef"}, + TargetCluster: &TargetClusterConfig{ResourceID: "/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/rg/providers/Microsoft.ContainerService/managedClusters/cluster", Location: "eastus"}, + Arc: &ArcConfig{Enabled: true, MachineName: "m", Location: "eastus", ResourceGroup: "rg", Tags: map[string]string{"k": "v"}}, + }, + Agent: AgentConfig{EnableDriftDetectionAndRemediation: &falseVal}, + Node: NodeConfig{ + Labels: map[string]string{"l": "1"}, + Kubelet: KubeletConfig{ + KubeReserved: map[string]string{"cpu": "100m"}, + EvictionHard: map[string]string{"memory.available": "200Mi"}, + }, + }, + } + + copy := cfg.DeepCopy() + if copy == nil { + t.Fatalf("DeepCopy()=nil") + } + if copy == cfg { + t.Fatalf("DeepCopy returned same pointer") + } + + // Pointer sub-objects should not be shared. + if cfg.Azure.ServicePrincipal == nil || copy.Azure.ServicePrincipal == nil || cfg.Azure.ServicePrincipal == copy.Azure.ServicePrincipal { + t.Fatalf("ServicePrincipal pointer shared or nil") + } + if cfg.Azure.ManagedIdentity == nil || copy.Azure.ManagedIdentity == nil || cfg.Azure.ManagedIdentity == copy.Azure.ManagedIdentity { + t.Fatalf("ManagedIdentity pointer shared or nil") + } + if cfg.Azure.BootstrapToken == nil || copy.Azure.BootstrapToken == nil || cfg.Azure.BootstrapToken == copy.Azure.BootstrapToken { + t.Fatalf("BootstrapToken pointer shared or nil") + } + if cfg.Azure.TargetCluster == nil || copy.Azure.TargetCluster == nil || cfg.Azure.TargetCluster == copy.Azure.TargetCluster { + t.Fatalf("TargetCluster pointer shared or nil") + } + if cfg.Azure.Arc == nil || copy.Azure.Arc == nil || cfg.Azure.Arc == copy.Azure.Arc { + t.Fatalf("Arc pointer shared or nil") + } + + // Maps should not be shared (validate via independent mutation behavior). 
+ cfg.Azure.Arc.Tags["k"] = "orig" + if copy.Azure.Arc.Tags["k"] != "v" { + t.Fatalf("Arc.Tags shared; copy=%q, want %q", copy.Azure.Arc.Tags["k"], "v") + } + copy.Azure.Arc.Tags["k"] = "copy" + if cfg.Azure.Arc.Tags["k"] != "orig" { + t.Fatalf("Arc.Tags shared; orig=%q, want %q", cfg.Azure.Arc.Tags["k"], "orig") + } + + cfg.Node.Labels["l"] = "orig" + if copy.Node.Labels["l"] != "1" { + t.Fatalf("Node.Labels shared; copy=%q, want %q", copy.Node.Labels["l"], "1") + } + copy.Node.Labels["l"] = "copy" + if cfg.Node.Labels["l"] != "orig" { + t.Fatalf("Node.Labels shared; orig=%q, want %q", cfg.Node.Labels["l"], "orig") + } + + cfg.Node.Kubelet.KubeReserved["cpu"] = "200m" + if copy.Node.Kubelet.KubeReserved["cpu"] != "100m" { + t.Fatalf("KubeReserved shared; copy=%q, want %q", copy.Node.Kubelet.KubeReserved["cpu"], "100m") + } + copy.Node.Kubelet.KubeReserved["cpu"] = "300m" + if cfg.Node.Kubelet.KubeReserved["cpu"] != "200m" { + t.Fatalf("KubeReserved shared; orig=%q, want %q", cfg.Node.Kubelet.KubeReserved["cpu"], "200m") + } + + cfg.Node.Kubelet.EvictionHard["memory.available"] = "150Mi" + if copy.Node.Kubelet.EvictionHard["memory.available"] != "200Mi" { + t.Fatalf("EvictionHard shared; copy=%q, want %q", copy.Node.Kubelet.EvictionHard["memory.available"], "200Mi") + } + copy.Node.Kubelet.EvictionHard["memory.available"] = "250Mi" + if cfg.Node.Kubelet.EvictionHard["memory.available"] != "150Mi" { + t.Fatalf("EvictionHard shared; orig=%q, want %q", cfg.Node.Kubelet.EvictionHard["memory.available"], "150Mi") + } +} diff --git a/pkg/config/structs.go b/pkg/config/structs.go index a17e1f8..e2b2190 100644 --- a/pkg/config/structs.go +++ b/pkg/config/structs.go @@ -76,6 +76,24 @@ type ArcConfig struct { type AgentConfig struct { LogLevel string `json:"logLevel"` // Logging level: debug, info, warning, error LogDir string `json:"logDir"` // Directory for log files + + // EnableDriftDetectionAndRemediation controls whether the agent performs drift detection + // and automated remediation (e.g., Kubernetes version drift upgrades). + // + // When omitted from config, the default is true. + EnableDriftDetectionAndRemediation *bool `json:"enableDriftDetectionAndRemediation,omitempty"` +} + +// IsDriftDetectionAndRemediationEnabled returns whether automated drift detection/remediation +// is enabled. For backward compatibility, a nil setting is treated as enabled. +func (cfg *Config) IsDriftDetectionAndRemediationEnabled() bool { + if cfg == nil { + return false + } + if cfg.Agent.EnableDriftDetectionAndRemediation == nil { + return true + } + return *cfg.Agent.EnableDriftDetectionAndRemediation } // KubernetesConfig holds configuration settings for Kubernetes components. diff --git a/pkg/drift/defaults.go b/pkg/drift/defaults.go new file mode 100644 index 0000000..b4f82be --- /dev/null +++ b/pkg/drift/defaults.go @@ -0,0 +1,11 @@ +package drift + +// DefaultDetectors returns the default set of drift detectors enabled by the agent. +// +// The daemon can choose to use this helper to avoid wiring each detector manually, +// while still keeping detector selection injectable for tests and future customization. 
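+//
+// Typical wiring (a sketch; error handling elided):
+//
+//	detectors := drift.DefaultDetectors()
+//	findings, _ := drift.DetectAll(ctx, detectors, cfg, specSnap, statusSnap)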
+func DefaultDetectors() []Detector { + return []Detector{ + NewKubernetesVersionDetector(), + } +} diff --git a/pkg/drift/defaults_test.go b/pkg/drift/defaults_test.go new file mode 100644 index 0000000..134da0c --- /dev/null +++ b/pkg/drift/defaults_test.go @@ -0,0 +1,15 @@ +package drift + +import "testing" + +func TestDefaultDetectors(t *testing.T) { + t.Parallel() + + d := DefaultDetectors() + if len(d) == 0 { + t.Fatalf("DefaultDetectors returned empty") + } + if d[0] == nil { + t.Fatalf("DefaultDetectors[0] is nil") + } +} diff --git a/pkg/drift/detector.go b/pkg/drift/detector.go new file mode 100644 index 0000000..fff9407 --- /dev/null +++ b/pkg/drift/detector.go @@ -0,0 +1,80 @@ +package drift + +import ( + "context" + "errors" + + "go.goms.io/aks/AKSFlexNode/pkg/config" + "go.goms.io/aks/AKSFlexNode/pkg/spec" + "go.goms.io/aks/AKSFlexNode/pkg/status" +) + +// Finding represents a detected drift between desired spec and current node state. +// Findings should be small and composable so multiple detectors can return multiple findings. +type Finding struct { + ID string + Title string + Details string + Remediation Remediation +} + +// RemediationAction indicates what kind of action should be taken to remediate a drift. +// Empty value means "unspecified" and will fall back to legacy behavior. +type RemediationAction string + +const ( + RemediationActionUnspecified RemediationAction = "" + RemediationActionKubernetesUpgrade RemediationAction = "kubernetes-upgrade" +) + +// Remediation describes what the agent should do to address a drift. +// This is intentionally a minimal set of knobs; it can be extended as new remediation +// types are needed (e.g., restart services, rewrite config files, run targeted installers). +type Remediation struct { + // Action indicates the remediation strategy. + // If unset, the finding is informational and won't trigger remediation. + Action RemediationAction + + // KubernetesVersion, when set, indicates the desired Kubernetes version that should be + // used during bootstrap (e.g., to trigger kubelet upgrade). + KubernetesVersion string +} + +// Detector compares desired spec and current status and returns any drift findings. +// Detectors should be pure (no side effects) and fast. +type Detector interface { + Name() string + Detect(ctx context.Context, cfg *config.Config, specSnap *spec.ManagedClusterSpec, statusSnap *status.NodeStatus) ([]Finding, error) +} + +// DetectAll runs all detectors, returning aggregated findings. +// If some detectors error, the error is returned (joined) along with any findings. +func DetectAll( + ctx context.Context, + detectors []Detector, + cfg *config.Config, + specSnap *spec.ManagedClusterSpec, + statusSnap *status.NodeStatus, +) ([]Finding, error) { + var findings []Finding + var errs []error + + for _, d := range detectors { + if d == nil { + continue + } + f, err := d.Detect(ctx, cfg, specSnap, statusSnap) + if err != nil { + errs = append(errs, err) + continue + } + if len(f) > 0 { + findings = append(findings, f...) + } + } + + if len(errs) > 0 { + return findings, errors.Join(errs...) 
+ } + return findings, nil +} diff --git a/pkg/drift/detector_test.go b/pkg/drift/detector_test.go new file mode 100644 index 0000000..7fbac96 --- /dev/null +++ b/pkg/drift/detector_test.go @@ -0,0 +1,49 @@ +package drift + +import ( + "context" + "errors" + "testing" + + "go.goms.io/aks/AKSFlexNode/pkg/config" + "go.goms.io/aks/AKSFlexNode/pkg/spec" + "go.goms.io/aks/AKSFlexNode/pkg/status" +) + +type detectorFunc struct { + name string + fn func(ctx context.Context, cfg *config.Config, specSnap *spec.ManagedClusterSpec, statusSnap *status.NodeStatus) ([]Finding, error) +} + +func (d detectorFunc) Name() string { return d.name } + +func (d detectorFunc) Detect(ctx context.Context, cfg *config.Config, specSnap *spec.ManagedClusterSpec, statusSnap *status.NodeStatus) ([]Finding, error) { + if d.fn == nil { + return nil, nil + } + return d.fn(ctx, cfg, specSnap, statusSnap) +} + +func TestDetectAllAggregatesFindingsAndErrors(t *testing.T) { + t.Parallel() + + wantErr := errors.New("boom") + + d1 := detectorFunc{name: "d1", fn: func(context.Context, *config.Config, *spec.ManagedClusterSpec, *status.NodeStatus) ([]Finding, error) { + return []Finding{{ID: "f1"}}, nil + }} + d2 := detectorFunc{name: "d2", fn: func(context.Context, *config.Config, *spec.ManagedClusterSpec, *status.NodeStatus) ([]Finding, error) { + return nil, wantErr + }} + + findings, err := DetectAll(context.Background(), []Detector{nil, d1, d2}, nil, nil, nil) + if len(findings) != 1 { + t.Fatalf("findings len=%d, want 1", len(findings)) + } + if err == nil { + t.Fatalf("err=nil, want non-nil") + } + if !errors.Is(err, wantErr) { + t.Fatalf("err=%v, want to contain %v", err, wantErr) + } +} diff --git a/pkg/drift/kubernetes_version.go b/pkg/drift/kubernetes_version.go new file mode 100644 index 0000000..e23c560 --- /dev/null +++ b/pkg/drift/kubernetes_version.go @@ -0,0 +1,83 @@ +package drift + +import ( + "context" + "fmt" + "strings" + + "go.goms.io/aks/AKSFlexNode/pkg/config" + "go.goms.io/aks/AKSFlexNode/pkg/spec" + "go.goms.io/aks/AKSFlexNode/pkg/status" +) + +const KubernetesVersionFindingID = "kubernetes-version" + +type KubernetesVersionDetector struct{} + +func NewKubernetesVersionDetector() *KubernetesVersionDetector { + return &KubernetesVersionDetector{} +} + +func (d *KubernetesVersionDetector) Name() string { + return "KubernetesVersionDetector" +} + +func (d *KubernetesVersionDetector) Detect( + ctx context.Context, + _ *config.Config, + specSnap *spec.ManagedClusterSpec, + statusSnap *status.NodeStatus, +) ([]Finding, error) { + if ctx != nil { + if err := ctx.Err(); err != nil { + return nil, err + } + } + + if specSnap == nil || statusSnap == nil { + return nil, nil + } + + desired := strings.TrimSpace(specSnap.CurrentKubernetesVersion) + if desired == "" { + desired = strings.TrimSpace(specSnap.KubernetesVersion) + } + if desired == "" { + return nil, nil + } + + current := strings.TrimSpace(statusSnap.KubeletVersion) + if current == "" || current == "unknown" { + return nil, nil + } + + cmp, ok := compareMajorMinor(current, desired) + if ok { + // Never downgrade via drift remediation. If the node is already newer than the desired + // version, treat it as non-actionable drift. + if cmp >= 0 { + return nil, nil + } + } else { + // If we can't parse versions, fall back to string major.minor comparison. + // This keeps the detector safe (won't trigger if they look equal) while still working + // for common version formats. 
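+		// e.g. a vendor-suffixed pair like "1.30-custom" on both sides compares
+		// equal here and raises no finding.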
+ if majorMinor(current) == majorMinor(desired) { + return nil, nil + } + // If we can't compare ordering, don't remediate automatically (avoid accidental downgrade). + return nil, nil + } + + return []Finding{ + { + ID: KubernetesVersionFindingID, + Title: "Kubernetes version drift", + Details: fmt.Sprintf("kubelet=%q desired=%q", current, desired), + Remediation: Remediation{ + Action: RemediationActionKubernetesUpgrade, + KubernetesVersion: desired, + }, + }, + }, nil +} diff --git a/pkg/drift/kubernetes_version_test.go b/pkg/drift/kubernetes_version_test.go new file mode 100644 index 0000000..b6e2cc7 --- /dev/null +++ b/pkg/drift/kubernetes_version_test.go @@ -0,0 +1,122 @@ +package drift + +import ( + "context" + "testing" + "time" + + "go.goms.io/aks/AKSFlexNode/pkg/spec" + "go.goms.io/aks/AKSFlexNode/pkg/status" +) + +func TestKubernetesVersionDetector_Detect_RespectsContextCancellation(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + d := NewKubernetesVersionDetector() + _, err := d.Detect(ctx, nil, &spec.ManagedClusterSpec{KubernetesVersion: "1.30.0", CollectedAt: time.Now()}, &status.NodeStatus{KubeletVersion: "1.29.0"}) + if err == nil { + t.Fatalf("err=nil, want context cancellation error") + } +} + +func TestKubernetesVersionDetector_Detect_NoDesiredVersion_NoFinding(t *testing.T) { + t.Parallel() + + d := NewKubernetesVersionDetector() + findings, err := d.Detect(context.Background(), nil, &spec.ManagedClusterSpec{CollectedAt: time.Now()}, &status.NodeStatus{KubeletVersion: "1.29.0"}) + if err != nil { + t.Fatalf("err=%v, want nil", err) + } + if len(findings) != 0 { + t.Fatalf("findings len=%d, want 0", len(findings)) + } +} + +func TestKubernetesVersionDetector_Detect_UnknownCurrent_NoFinding(t *testing.T) { + t.Parallel() + + d := NewKubernetesVersionDetector() + findings, err := d.Detect(context.Background(), nil, &spec.ManagedClusterSpec{KubernetesVersion: "1.30.0", CollectedAt: time.Now()}, &status.NodeStatus{KubeletVersion: "unknown"}) + if err != nil { + t.Fatalf("err=%v, want nil", err) + } + if len(findings) != 0 { + t.Fatalf("findings len=%d, want 0", len(findings)) + } +} + +func TestKubernetesVersionDetector_Detect_UpgradeOnly_Finding(t *testing.T) { + t.Parallel() + + d := NewKubernetesVersionDetector() + specSnap := &spec.ManagedClusterSpec{CurrentKubernetesVersion: "1.30.7", CollectedAt: time.Now()} + statusSnap := &status.NodeStatus{KubeletVersion: "1.29.8"} + + findings, err := d.Detect(context.Background(), nil, specSnap, statusSnap) + if err != nil { + t.Fatalf("err=%v, want nil", err) + } + if len(findings) != 1 { + t.Fatalf("findings len=%d, want 1", len(findings)) + } + if findings[0].ID != KubernetesVersionFindingID { + t.Fatalf("finding ID=%q, want %q", findings[0].ID, KubernetesVersionFindingID) + } + if findings[0].Remediation.Action != RemediationActionKubernetesUpgrade { + t.Fatalf("action=%q, want %q", findings[0].Remediation.Action, RemediationActionKubernetesUpgrade) + } + if findings[0].Remediation.KubernetesVersion != "1.30.7" { + t.Fatalf("kubernetesVersion=%q, want %q", findings[0].Remediation.KubernetesVersion, "1.30.7") + } +} + +func TestKubernetesVersionDetector_Detect_NoDowngrade_NoFinding(t *testing.T) { + t.Parallel() + + d := NewKubernetesVersionDetector() + specSnap := &spec.ManagedClusterSpec{KubernetesVersion: "1.29.0", CollectedAt: time.Now()} + statusSnap := &status.NodeStatus{KubeletVersion: "1.30.1"} + + findings, err := d.Detect(context.Background(), nil, specSnap, 
statusSnap)
+	if err != nil {
+		t.Fatalf("err=%v, want nil", err)
+	}
+	if len(findings) != 0 {
+		t.Fatalf("findings len=%d, want 0", len(findings))
+	}
+}
+
+func TestKubernetesVersionDetector_Detect_SameMajorMinor_NoFinding(t *testing.T) {
+	t.Parallel()
+
+	d := NewKubernetesVersionDetector()
+	specSnap := &spec.ManagedClusterSpec{CurrentKubernetesVersion: "1.30.7", CollectedAt: time.Now()}
+	statusSnap := &status.NodeStatus{KubeletVersion: "1.30.1"}
+
+	findings, err := d.Detect(context.Background(), nil, specSnap, statusSnap)
+	if err != nil {
+		t.Fatalf("err=%v, want nil", err)
+	}
+	if len(findings) != 0 {
+		t.Fatalf("findings len=%d, want 0", len(findings))
+	}
+}
+
+func TestKubernetesVersionDetector_Detect_UnparseableVersions_NoFinding(t *testing.T) {
+	t.Parallel()
+
+	d := NewKubernetesVersionDetector()
+	specSnap := &spec.ManagedClusterSpec{CurrentKubernetesVersion: "1.30.7", CollectedAt: time.Now()}
+	statusSnap := &status.NodeStatus{KubeletVersion: "v1.x"}
+
+	findings, err := d.Detect(context.Background(), nil, specSnap, statusSnap)
+	if err != nil {
+		t.Fatalf("err=%v, want nil", err)
+	}
+	if len(findings) != 0 {
+		t.Fatalf("findings len=%d, want 0", len(findings))
+	}
+}
diff --git a/pkg/drift/remediation.go b/pkg/drift/remediation.go
new file mode 100644
index 0000000..eecfeb1
--- /dev/null
+++ b/pkg/drift/remediation.go
@@ -0,0 +1,236 @@
+package drift
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"sync/atomic"
+	"time"
+
+	"github.com/sirupsen/logrus"
+
+	"go.goms.io/aks/AKSFlexNode/pkg/bootstrapper"
+	"go.goms.io/aks/AKSFlexNode/pkg/components/kube_binaries"
+	"go.goms.io/aks/AKSFlexNode/pkg/components/kubelet"
+	"go.goms.io/aks/AKSFlexNode/pkg/components/services"
+	"go.goms.io/aks/AKSFlexNode/pkg/config"
+	"go.goms.io/aks/AKSFlexNode/pkg/spec"
+	"go.goms.io/aks/AKSFlexNode/pkg/status"
+)
+
+const driftKubernetesUpgradeOperation = "drift-kubernetes-upgrade"
+
+// maxManagedClusterSpecAge is a safety guard to avoid acting on very stale spec snapshots.
+// In normal operation we run drift immediately after a successful spec collection, so this
+// should rarely block remediation.
+const maxManagedClusterSpecAge = 2 * time.Hour
+
+// DetectAndRemediateFromFiles loads spec/status snapshots from disk, runs all detectors,
+// and (if needed) performs remediation.
+//
+// Remediation attempts are guarded by bootstrapInProgress to avoid concurrent executions.
+func DetectAndRemediateFromFiles(
+	ctx context.Context,
+	// cfg must be a private snapshot (e.g., from Config.DeepCopy) that no other goroutine
+	// reads or writes for the duration of this call: DetectAndRemediateFromFiles may mutate
+	// cfg (e.g., to apply the desired KubernetesVersion) as part of remediation.
+	cfg *config.Config,
+	logger *logrus.Logger,
+	bootstrapInProgress *int32,
+	detectors []Detector,
+) error {
+	if logger == nil {
+		logger = logrus.New()
+	}
+
+	specSnap, err := spec.LoadManagedClusterSpec()
+	if err != nil {
+		// Spec may not exist yet.
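+		// Callers treat this as non-fatal: the daemon loops only log a warning
+		// until the first successful spec collection writes a snapshot.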
+ return err + } + + nodeStatus, err := status.LoadStatus() + if err != nil { + return err + } + + return detectAndRemediate(ctx, cfg, logger, bootstrapInProgress, detectors, specSnap, nodeStatus) +} + +func detectAndRemediate( + ctx context.Context, + cfg *config.Config, + logger *logrus.Logger, + bootstrapInProgress *int32, + detectors []Detector, + specSnap *spec.ManagedClusterSpec, + statusSnap *status.NodeStatus, +) error { + if specSnap == nil || statusSnap == nil { + return nil + } + if isManagedClusterSpecStale(specSnap, time.Now()) { + logger.Warnf("Managed cluster spec snapshot is stale (collectedAt=%s); skipping drift remediation", specSnap.CollectedAt.Format(time.RFC3339)) + return nil + } + + var findings []Finding + var detectErr error + findings, detectErr = DetectAll(ctx, detectors, cfg, specSnap, statusSnap) + if detectErr != nil { + // Don't immediately fail; if some detectors produced findings we can still act. + logger.Warnf("One or more drift detectors failed: %v", detectErr) + } + if len(findings) == 0 { + return detectErr + } + + for _, f := range findings { + logger.Warnf("Drift detected: id=%s title=%s details=%s", f.ID, f.Title, f.Details) + } + + plan, requiresRemediation, err := resolveRemediationPlan(findings) + if err != nil { + return err + } + if !requiresRemediation { + return detectErr + } + + // Prevent overlapping remediation runs. + if bootstrapInProgress != nil { + if !atomic.CompareAndSwapInt32(bootstrapInProgress, 0, 1) { + logger.Warn("Bootstrap already in progress, skipping drift remediation") + return nil + } + defer atomic.StoreInt32(bootstrapInProgress, 0) + } + + if plan.DesiredKubernetesVersion != "" { + // Apply desired version to the snapshot so remediation uses the expected kube binaries. + if cfg != nil { + cfg.Kubernetes.Version = plan.DesiredKubernetesVersion + } + } + + // Run remediation. + switch plan.Action { + case RemediationActionKubernetesUpgrade: + result, upgradeErr := runKubernetesUpgradeRemediation(ctx, cfg, logger) + if upgradeErr != nil { + status.MarkKubeletUnhealthyBestEffort(logger) + return fmt.Errorf("kubernetes upgrade remediation failed: %w", upgradeErr) + } + if err := handleExecutionResult(result, driftKubernetesUpgradeOperation, logger); err != nil { + status.MarkKubeletUnhealthyBestEffort(logger) + return fmt.Errorf("kubernetes upgrade remediation execution failed: %w", err) + } + logger.Info("Kubernetes upgrade remediation completed successfully") + return detectErr + + default: + return fmt.Errorf("unsupported drift remediation action: %q", plan.Action) + } +} + +func isManagedClusterSpecStale(specSnap *spec.ManagedClusterSpec, now time.Time) bool { + if specSnap == nil { + return true + } + if specSnap.CollectedAt.IsZero() { + return true + } + if now.IsZero() { + now = time.Now() + } + return now.Sub(specSnap.CollectedAt) > maxManagedClusterSpecAge +} + +type remediationPlan struct { + Action RemediationAction + DesiredKubernetesVersion string +} + +// resolveRemediationPlan collapses potentially many drift findings into a single remediation plan. +// +// Today the remediation runner supports executing only one remediation action per pass. +// As more detectors are added, it's possible to receive multiple findings at once. This helper +// performs two tasks: +// 1. Dedup: pick a single action and a single set of parameters (e.g., Kubernetes version). +// 2. Consistency check: if findings disagree (different actions or different desired versions), +// fail fast rather than guessing. 
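+//
+// For example (versions illustrative): two findings that both request a
+// kubernetes-upgrade to "1.30.7" collapse into a single plan, while "1.30.7"
+// versus "1.31.0" fails with a conflict error.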
+func resolveRemediationPlan(findings []Finding) (remediationPlan, bool, error) {
+	plan := remediationPlan{Action: RemediationActionUnspecified}
+	requiresRemediation := false
+
+	for _, f := range findings {
+		action := f.Remediation.Action
+		if action == RemediationActionUnspecified {
+			continue
+		}
+
+		requiresRemediation = true
+		if plan.Action == RemediationActionUnspecified {
+			plan.Action = action
+		} else if plan.Action != action {
+			return remediationPlan{}, false, errors.New("conflicting drift remediation: multiple remediation actions")
+		}
+
+		version := f.Remediation.KubernetesVersion
+		if version == "" {
+			continue
+		}
+		if plan.DesiredKubernetesVersion == "" {
+			plan.DesiredKubernetesVersion = version
+			continue
+		}
+		if plan.DesiredKubernetesVersion != version {
+			return remediationPlan{}, false, errors.New("conflicting drift remediation: multiple desired Kubernetes versions")
+		}
+	}
+
+	return plan, requiresRemediation, nil
+}
+
+// runKubernetesUpgradeRemediation performs a targeted Kubernetes upgrade with minimal disruption.
+//
+// Key design points:
+//   - Stop/start kubelet around the upgrade so we don't run kubelet against partially-updated
+//     binaries or config (avoids flapping, crash loops, and nondeterministic behavior).
+//   - Do not stop/restart containerd to keep disruption lower and avoid impacting running pods
+//     more than necessary.
+func runKubernetesUpgradeRemediation(
+	ctx context.Context,
+	cfg *config.Config,
+	logger *logrus.Logger,
+) (*bootstrapper.ExecutionResult, error) {
+	steps := []bootstrapper.Executor{
+		// Stop/disable kubelet only so it cannot restart mid-upgrade.
+		services.NewKubeletOnlyUnInstaller(logger),
+		// Install the desired kube binaries version.
+		kube_binaries.NewInstallerWithConfig(cfg, logger),
+		// Reconfigure kubelet to match the upgraded bits.
+		kubelet.NewInstallerWithConfig(cfg, logger),
+		// Enable/start kubelet only and wait for it to be active.
+		services.NewKubeletOnlyInstaller(logger),
+	}
+
+	be := bootstrapper.NewBaseExecutor(cfg, logger)
+	return be.ExecuteSteps(ctx, steps, driftKubernetesUpgradeOperation)
+}
+
+// handleExecutionResult mirrors main's handleExecutionResult but lives in drift so remediation
+// can share the same logging and error semantics.
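+// (A local copy is needed because Go code cannot import package main; keeping the
+// two in sync preserves identical log output for both paths.)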
+func handleExecutionResult(result *bootstrapper.ExecutionResult, operation string, logger *logrus.Logger) error { + if result == nil { + return fmt.Errorf("%s result is nil", operation) + } + + if result.Success { + logger.Infof("%s completed successfully (duration: %v, steps: %d)", + operation, result.Duration, result.StepCount) + return nil + } + + return fmt.Errorf("%s failed: %s", operation, result.Error) +} diff --git a/pkg/drift/remediation_test.go b/pkg/drift/remediation_test.go new file mode 100644 index 0000000..9a56f7e --- /dev/null +++ b/pkg/drift/remediation_test.go @@ -0,0 +1,177 @@ +package drift + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/sirupsen/logrus" + + "go.goms.io/aks/AKSFlexNode/pkg/config" + "go.goms.io/aks/AKSFlexNode/pkg/spec" + "go.goms.io/aks/AKSFlexNode/pkg/status" +) + +type countingDetector struct { + called int32 + fn func() ([]Finding, error) +} + +func (d *countingDetector) Name() string { return "counting" } + +func (d *countingDetector) Detect(ctx context.Context, _ *config.Config, _ *spec.ManagedClusterSpec, _ *status.NodeStatus) ([]Finding, error) { + _ = ctx + atomic.AddInt32(&d.called, 1) + if d.fn == nil { + return nil, nil + } + return d.fn() +} + +func TestIsManagedClusterSpecStale(t *testing.T) { + t.Parallel() + + if !isManagedClusterSpecStale(nil, time.Now()) { + t.Fatalf("nil spec should be stale") + } + if !isManagedClusterSpecStale(&spec.ManagedClusterSpec{}, time.Now()) { + t.Fatalf("zero CollectedAt should be stale") + } + if isManagedClusterSpecStale(&spec.ManagedClusterSpec{CollectedAt: time.Now()}, time.Now()) { + t.Fatalf("fresh spec should not be stale") + } + old := time.Now().Add(-maxManagedClusterSpecAge - time.Minute) + if !isManagedClusterSpecStale(&spec.ManagedClusterSpec{CollectedAt: old}, time.Now()) { + t.Fatalf("old spec should be stale") + } +} + +func TestResolveRemediationPlan(t *testing.T) { + t.Parallel() + + plan, requires, err := resolveRemediationPlan(nil) + if err != nil { + t.Fatalf("err=%v, want nil", err) + } + if requires { + t.Fatalf("requiresRemediation=true, want false") + } + if plan.Action != RemediationActionUnspecified { + t.Fatalf("plan.Action=%q, want %q", plan.Action, RemediationActionUnspecified) + } + + plan, requires, err = resolveRemediationPlan([]Finding{{ + ID: "f1", + Remediation: Remediation{Action: RemediationActionKubernetesUpgrade, KubernetesVersion: "1.30.7"}, + }}) + if err != nil { + t.Fatalf("err=%v, want nil", err) + } + if !requires { + t.Fatalf("requiresRemediation=false, want true") + } + if plan.Action != RemediationActionKubernetesUpgrade { + t.Fatalf("plan.Action=%q, want %q", plan.Action, RemediationActionKubernetesUpgrade) + } + if plan.DesiredKubernetesVersion != "1.30.7" { + t.Fatalf("DesiredKubernetesVersion=%q, want %q", plan.DesiredKubernetesVersion, "1.30.7") + } + + _, _, err = resolveRemediationPlan([]Finding{ + {ID: "a", Remediation: Remediation{Action: RemediationActionKubernetesUpgrade, KubernetesVersion: "1.30.7"}}, + {ID: "b", Remediation: Remediation{Action: RemediationActionKubernetesUpgrade, KubernetesVersion: "1.31.0"}}, + }) + if err == nil { + t.Fatalf("err=nil, want conflict error") + } + + _, _, err = resolveRemediationPlan([]Finding{ + {ID: "a", Remediation: Remediation{Action: RemediationActionKubernetesUpgrade}}, + {ID: "b", Remediation: Remediation{Action: RemediationActionUnspecified}}, + }) + if err != nil { + t.Fatalf("err=%v, want nil", err) + } + + _, _, err = resolveRemediationPlan([]Finding{ + {ID: "a", 
Remediation: Remediation{Action: RemediationActionKubernetesUpgrade}}, + {ID: "b", Remediation: Remediation{Action: "something-else"}}, + }) + if err == nil { + t.Fatalf("err=nil, want action conflict error") + } +} + +func TestDetectAndRemediate_SkipsStaleSpec_DoesNotCallDetectors(t *testing.T) { + t.Parallel() + + logger := logrus.New() + d := &countingDetector{fn: func() ([]Finding, error) { + return []Finding{{ + ID: "f1", + Remediation: Remediation{Action: RemediationActionKubernetesUpgrade, KubernetesVersion: "1.30.0"}, + }}, nil + }} + + staleCollectedAt := time.Now().Add(-maxManagedClusterSpecAge - time.Minute) + specSnap := &spec.ManagedClusterSpec{CurrentKubernetesVersion: "1.30.0", CollectedAt: staleCollectedAt} + statusSnap := &status.NodeStatus{KubeletVersion: "1.29.0"} + + err := detectAndRemediate(context.Background(), nil, logger, nil, []Detector{d}, specSnap, statusSnap) + if err != nil { + t.Fatalf("err=%v, want nil", err) + } + if got := atomic.LoadInt32(&d.called); got != 0 { + t.Fatalf("detector called %d times, want 0", got) + } +} + +func TestDetectAndRemediate_BootstrapGuard_SkipsWhenInProgress(t *testing.T) { + t.Parallel() + + logger := logrus.New() + d := &countingDetector{fn: func() ([]Finding, error) { + return []Finding{{ + ID: "f1", + Remediation: Remediation{Action: RemediationActionKubernetesUpgrade, KubernetesVersion: "1.31.0"}, + }}, nil + }} + + specSnap := &spec.ManagedClusterSpec{CurrentKubernetesVersion: "1.31.0", CollectedAt: time.Now()} + statusSnap := &status.NodeStatus{KubeletVersion: "1.30.0"} + + var bootstrapInProgress int32 = 1 + err := detectAndRemediate(context.Background(), nil, logger, &bootstrapInProgress, []Detector{d}, specSnap, statusSnap) + if err != nil { + t.Fatalf("err=%v, want nil", err) + } + if got := atomic.LoadInt32(&d.called); got != 1 { + t.Fatalf("detector called %d times, want 1", got) + } + if got := atomic.LoadInt32(&bootstrapInProgress); got != 1 { + t.Fatalf("bootstrapInProgress=%d, want 1", got) + } +} + +func TestDetectAndRemediate_ReturnsDetectErrorIfNoFindings(t *testing.T) { + t.Parallel() + + logger := logrus.New() + wantErr := errors.New("detect failed") + d := &countingDetector{fn: func() ([]Finding, error) { + return nil, wantErr + }} + + specSnap := &spec.ManagedClusterSpec{CurrentKubernetesVersion: "1.31.0", CollectedAt: time.Now()} + statusSnap := &status.NodeStatus{KubeletVersion: "1.30.0"} + + err := detectAndRemediate(context.Background(), nil, logger, nil, []Detector{d}, specSnap, statusSnap) + if err == nil { + t.Fatalf("err=nil, want %v", wantErr) + } + if !errors.Is(err, wantErr) { + t.Fatalf("err=%v, want to contain %v", err, wantErr) + } +} diff --git a/pkg/drift/version.go b/pkg/drift/version.go new file mode 100644 index 0000000..e3a7103 --- /dev/null +++ b/pkg/drift/version.go @@ -0,0 +1,55 @@ +package drift + +import ( + "strconv" + "strings" +) + +func majorMinor(version string) string { + v := strings.TrimPrefix(strings.TrimSpace(version), "v") + parts := strings.Split(v, ".") + if len(parts) < 2 { + return v + } + return parts[0] + "." 
+ parts[1] +} + +func parseMajorMinor(version string) (major int, minor int, ok bool) { + v := strings.TrimPrefix(strings.TrimSpace(version), "v") + parts := strings.Split(v, ".") + if len(parts) < 2 { + return 0, 0, false + } + maj, err := strconv.Atoi(parts[0]) + if err != nil { + return 0, 0, false + } + min, err := strconv.Atoi(parts[1]) + if err != nil { + return 0, 0, false + } + return maj, min, true +} + +// compareMajorMinor compares versions by major.minor only. +// Returns -1 if current < desired, 0 if equal, +1 if current > desired, and ok=false if parsing fails. +func compareMajorMinor(current, desired string) (cmp int, ok bool) { + cMaj, cMin, ok1 := parseMajorMinor(current) + dMaj, dMin, ok2 := parseMajorMinor(desired) + if !ok1 || !ok2 { + return 0, false + } + if cMaj != dMaj { + if cMaj < dMaj { + return -1, true + } + return 1, true + } + if cMin != dMin { + if cMin < dMin { + return -1, true + } + return 1, true + } + return 0, true +} diff --git a/pkg/drift/version_test.go b/pkg/drift/version_test.go new file mode 100644 index 0000000..3d9b862 --- /dev/null +++ b/pkg/drift/version_test.go @@ -0,0 +1,74 @@ +package drift + +import "testing" + +func TestMajorMinor(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + version string + want string + }{ + {name: "trim-v", version: "v1.30.2", want: "1.30"}, + {name: "already-major-minor", version: "1.30", want: "1.30"}, + {name: "only-major", version: "1", want: "1"}, + {name: "spaces", version: " v1.31.7 ", want: "1.31"}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := majorMinor(tt.version); got != tt.want { + t.Fatalf("majorMinor(%q)=%q, want %q", tt.version, got, tt.want) + } + }) + } +} + +func TestParseMajorMinor(t *testing.T) { + t.Parallel() + + maj, min, ok := parseMajorMinor("v1.31.7") + if !ok || maj != 1 || min != 31 { + t.Fatalf("parseMajorMinor(v1.31.7)=(%d,%d,%v), want (1,31,true)", maj, min, ok) + } + + _, _, ok = parseMajorMinor("1") + if ok { + t.Fatalf("parseMajorMinor(1) ok=true, want false") + } + _, _, ok = parseMajorMinor("foo.1") + if ok { + t.Fatalf("parseMajorMinor(foo.1) ok=true, want false") + } + _, _, ok = parseMajorMinor("1.bar") + if ok { + t.Fatalf("parseMajorMinor(1.bar) ok=true, want false") + } +} + +func TestCompareMajorMinor(t *testing.T) { + t.Parallel() + + cmp, ok := compareMajorMinor("1.29.5", "1.30.0") + if !ok || cmp != -1 { + t.Fatalf("compareMajorMinor(1.29.5,1.30.0)=(%d,%v), want (-1,true)", cmp, ok) + } + + cmp, ok = compareMajorMinor("1.30.0", "1.30.7") + if !ok || cmp != 0 { + t.Fatalf("compareMajorMinor(1.30.0,1.30.7)=(%d,%v), want (0,true)", cmp, ok) + } + + cmp, ok = compareMajorMinor("1.31.0", "1.30.9") + if !ok || cmp != 1 { + t.Fatalf("compareMajorMinor(1.31.0,1.30.9)=(%d,%v), want (1,true)", cmp, ok) + } + + _, ok = compareMajorMinor("1.x", "1.30.0") + if ok { + t.Fatalf("compareMajorMinor(1.x,1.30.0) ok=true, want false") + } +} diff --git a/pkg/spec/kubernetes_version_override.go b/pkg/spec/kubernetes_version_override.go new file mode 100644 index 0000000..0dfd68f --- /dev/null +++ b/pkg/spec/kubernetes_version_override.go @@ -0,0 +1,57 @@ +package spec + +import ( + "strings" + + "go.goms.io/aks/AKSFlexNode/pkg/config" +) + +// DesiredKubernetesVersion returns the preferred Kubernetes version from a managed cluster spec. +// +// Priority: +// 1. CurrentKubernetesVersion (typically a full patch version) +// 2. 
KubernetesVersion (typically major.minor or a less specific version) +func DesiredKubernetesVersion(specSnap *ManagedClusterSpec) string { + if specSnap == nil { + return "" + } + desired := strings.TrimSpace(specSnap.CurrentKubernetesVersion) + if desired == "" { + desired = strings.TrimSpace(specSnap.KubernetesVersion) + } + return desired +} + +// OverrideKubernetesVersionFromManagedClusterSpec loads the managed cluster spec snapshot from disk +// and, if it contains a Kubernetes version, overwrites cfg.Kubernetes.Version. +// +// This is best-effort: callers may choose to ignore errors if the spec file doesn't exist yet. +func OverrideKubernetesVersionFromManagedClusterSpec(cfg *config.Config) (changed bool, oldVersion, newVersion string, err error) { + return OverrideKubernetesVersionFromManagedClusterSpecFile(cfg, GetManagedClusterSpecFilePath()) +} + +// OverrideKubernetesVersionFromManagedClusterSpecFile loads the spec snapshot from the given file path +// and overwrites cfg.Kubernetes.Version if a desired Kubernetes version is present. +func OverrideKubernetesVersionFromManagedClusterSpecFile(cfg *config.Config, path string) (changed bool, oldVersion, newVersion string, err error) { + if cfg == nil { + return false, "", "", nil + } + + specSnap, err := LoadManagedClusterSpecFromFile(path) + if err != nil || specSnap == nil { + return false, "", "", err + } + + desired := DesiredKubernetesVersion(specSnap) + if desired == "" { + return false, "", "", nil + } + + old := cfg.Kubernetes.Version + if old == desired { + return false, old, desired, nil + } + + cfg.Kubernetes.Version = desired + return true, old, desired, nil +} diff --git a/pkg/spec/kubernetes_version_override_test.go b/pkg/spec/kubernetes_version_override_test.go new file mode 100644 index 0000000..3b4e1d6 --- /dev/null +++ b/pkg/spec/kubernetes_version_override_test.go @@ -0,0 +1,87 @@ +package spec + +import ( + "encoding/json" + "os" + "testing" + "time" + + "go.goms.io/aks/AKSFlexNode/pkg/config" +) + +func TestDesiredKubernetesVersion(t *testing.T) { + t.Parallel() + + if got := DesiredKubernetesVersion(nil); got != "" { + t.Fatalf("DesiredKubernetesVersion(nil)=%q, want empty", got) + } + + if got := DesiredKubernetesVersion(&ManagedClusterSpec{CurrentKubernetesVersion: " 1.30.7 ", KubernetesVersion: "1.30"}); got != "1.30.7" { + t.Fatalf("DesiredKubernetesVersion(current+fallback)=%q, want %q", got, "1.30.7") + } + if got := DesiredKubernetesVersion(&ManagedClusterSpec{CurrentKubernetesVersion: " ", KubernetesVersion: " 1.31 "}); got != "1.31" { + t.Fatalf("DesiredKubernetesVersion(fallback)=%q, want %q", got, "1.31") + } + if got := DesiredKubernetesVersion(&ManagedClusterSpec{CurrentKubernetesVersion: " ", KubernetesVersion: " "}); got != "" { + t.Fatalf("DesiredKubernetesVersion(empty)=%q, want empty", got) + } +} + +func TestOverrideKubernetesVersionFromManagedClusterSpec(t *testing.T) { + dir := t.TempDir() + path := ManagedClusterSpecFilePath(dir) + + // Missing spec file: should return an error from LoadManagedClusterSpec. + cfg := &config.Config{Kubernetes: config.KubernetesConfig{Version: "1.29.0"}} + changed, oldV, newV, err := OverrideKubernetesVersionFromManagedClusterSpecFile(cfg, path) + if err == nil { + t.Fatalf("expected error when spec file missing") + } + if changed { + t.Fatalf("changed=true, want false") + } + if oldV != "" || newV != "" { + t.Fatalf("old/new versions should be empty on error; old=%q new=%q", oldV, newV) + } + + // Write a spec snapshot file with desired version. 
+ snap := &ManagedClusterSpec{ + SchemaVersion: ManagedClusterSpecSchemaVersion, + KubernetesVersion: "1.30", + CurrentKubernetesVersion: "1.30.7", + CollectedAt: time.Now(), + } + b, err := json.Marshal(snap) + if err != nil { + t.Fatalf("Marshal() error = %v", err) + } + if err := os.WriteFile(path, b, 0o600); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + + changed, oldV, newV, err = OverrideKubernetesVersionFromManagedClusterSpecFile(cfg, path) + if err != nil { + t.Fatalf("OverrideKubernetesVersionFromManagedClusterSpec() err=%v, want nil", err) + } + if !changed { + t.Fatalf("changed=false, want true") + } + if oldV != "1.29.0" || newV != "1.30.7" { + t.Fatalf("old/new mismatch: old=%q new=%q", oldV, newV) + } + if cfg.Kubernetes.Version != "1.30.7" { + t.Fatalf("cfg.Kubernetes.Version=%q, want %q", cfg.Kubernetes.Version, "1.30.7") + } + + // Second call should be a no-op. + changed, oldV, newV, err = OverrideKubernetesVersionFromManagedClusterSpecFile(cfg, path) + if err != nil { + t.Fatalf("second override err=%v, want nil", err) + } + if changed { + t.Fatalf("second override changed=true, want false") + } + if oldV != "1.30.7" || newV != "1.30.7" { + t.Fatalf("old/new mismatch on no-op: old=%q new=%q", oldV, newV) + } +} diff --git a/pkg/spec/loader.go b/pkg/spec/loader.go new file mode 100644 index 0000000..901309b --- /dev/null +++ b/pkg/spec/loader.go @@ -0,0 +1,31 @@ +package spec + +import ( + "encoding/json" + "fmt" + "os" +) + +// LoadManagedClusterSpec loads the managed cluster spec snapshot from the default path. +func LoadManagedClusterSpec() (*ManagedClusterSpec, error) { + return LoadManagedClusterSpecFromFile(GetManagedClusterSpecFilePath()) +} + +// LoadManagedClusterSpecFromFile loads the managed cluster spec snapshot from a JSON file. +func LoadManagedClusterSpecFromFile(path string) (*ManagedClusterSpec, error) { + if path == "" { + return nil, fmt.Errorf("spec path is empty") + } + + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + + var s ManagedClusterSpec + if err := json.Unmarshal(data, &s); err != nil { + return nil, fmt.Errorf("failed to unmarshal managed cluster spec: %w", err) + } + + return &s, nil +} diff --git a/pkg/spec/paths.go b/pkg/spec/paths.go index f8cbddc..e5c91d6 100644 --- a/pkg/spec/paths.go +++ b/pkg/spec/paths.go @@ -5,6 +5,11 @@ import ( "path/filepath" ) +// ManagedClusterSpecFilePath returns the managed cluster spec snapshot file path under a provided directory. +func ManagedClusterSpecFilePath(specDir string) string { + return filepath.Join(specDir, "managedcluster-spec.json") +} + // GetSpecDir returns the appropriate directory for spec artifacts. // Uses /run/aks-flex-node when running as aks-flex-node user (systemd service) // Uses /tmp/aks-flex-node for direct user execution (testing/development) @@ -19,5 +24,5 @@ func GetSpecDir() string { // GetManagedClusterSpecFilePath returns the path where the managed cluster spec snapshot is stored. 
func GetManagedClusterSpecFilePath() string { - return filepath.Join(GetSpecDir(), "managedcluster-spec.json") + return ManagedClusterSpecFilePath(GetSpecDir()) } diff --git a/pkg/spec/paths_test.go b/pkg/spec/paths_test.go new file mode 100644 index 0000000..e74b72c --- /dev/null +++ b/pkg/spec/paths_test.go @@ -0,0 +1,14 @@ +package spec + +import "testing" + +func TestManagedClusterSpecFilePath(t *testing.T) { + t.Parallel() + + dir := "/some/dir" + got := ManagedClusterSpecFilePath(dir) + want := dir + "/managedcluster-spec.json" + if got != want { + t.Fatalf("ManagedClusterSpecFilePath()=%q, want %q", got, want) + } +} diff --git a/pkg/spec/remove.go b/pkg/spec/remove.go new file mode 100644 index 0000000..86e7b10 --- /dev/null +++ b/pkg/spec/remove.go @@ -0,0 +1,26 @@ +package spec + +import "os" + +// RemoveManagedClusterSpecSnapshot removes the managed cluster spec snapshot file. +// +// It returns (removed=false, err=nil) when the file does not exist. +func RemoveManagedClusterSpecSnapshot() (removed bool, err error) { + return RemoveManagedClusterSpecSnapshotAtPath(GetManagedClusterSpecFilePath()) +} + +// RemoveManagedClusterSpecSnapshotAtPath removes the managed cluster spec snapshot file at the given path. +// +// It returns (removed=false, err=nil) when the file does not exist. +func RemoveManagedClusterSpecSnapshotAtPath(path string) (removed bool, err error) { + if path == "" { + return false, nil + } + if err := os.Remove(path); err != nil { + if os.IsNotExist(err) { + return false, nil + } + return false, err + } + return true, nil +} diff --git a/pkg/spec/remove_test.go b/pkg/spec/remove_test.go new file mode 100644 index 0000000..3271965 --- /dev/null +++ b/pkg/spec/remove_test.go @@ -0,0 +1,36 @@ +package spec + +import ( + "os" + "testing" +) + +func TestRemoveManagedClusterSpecSnapshot(t *testing.T) { + dir := t.TempDir() + path := ManagedClusterSpecFilePath(dir) + + // No file: should be (removed=false, err=nil). + removed, err := RemoveManagedClusterSpecSnapshotAtPath(path) + if err != nil { + t.Fatalf("RemoveManagedClusterSpecSnapshot() err=%v, want nil", err) + } + if removed { + t.Fatalf("RemoveManagedClusterSpecSnapshot() removed=true, want false") + } + + // Create the file and remove it. + if err := os.WriteFile(path, []byte("{}"), 0o600); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + + removed, err = RemoveManagedClusterSpecSnapshotAtPath(path) + if err != nil { + t.Fatalf("RemoveManagedClusterSpecSnapshot() err=%v, want nil", err) + } + if !removed { + t.Fatalf("RemoveManagedClusterSpecSnapshot() removed=false, want true") + } + if _, statErr := os.Stat(path); !os.IsNotExist(statErr) { + t.Fatalf("spec file still exists; statErr=%v", statErr) + } +} diff --git a/pkg/status/health.go b/pkg/status/health.go new file mode 100644 index 0000000..e257756 --- /dev/null +++ b/pkg/status/health.go @@ -0,0 +1,53 @@ +package status + +import ( + "time" + + "github.com/sirupsen/logrus" +) + +// MarkKubeletUnhealthyBestEffort updates the existing status snapshot (or creates a minimal one) +// to clearly indicate the kubelet is unhealthy. +// +// This is intended to influence NeedsBootstrap() without deleting the entire status file. 
+func MarkKubeletUnhealthyBestEffort(logger *logrus.Logger) { + if logger == nil { + logger = logrus.New() + } + + statusFilePath := GetStatusFilePath() + MarkKubeletUnhealthyBestEffortAtPath(logger, statusFilePath, time.Time{}) +} + +// MarkKubeletUnhealthyBestEffortAtPath is the path-based variant used by tests and any callers +// that want to control where the status snapshot is written. +// +// If now is zero, time.Now() is used. +func MarkKubeletUnhealthyBestEffortAtPath(logger *logrus.Logger, statusFilePath string, now time.Time) { + if logger == nil { + logger = logrus.New() + } + if statusFilePath == "" { + return + } + if now.IsZero() { + now = time.Now() + } + + snap, err := LoadStatusFromFile(statusFilePath) + if err != nil || snap == nil { + snap = &NodeStatus{} + } + + // Make the status clearly unhealthy so NeedsBootstrap() will trigger. + snap.KubeletRunning = false + snap.KubeletReady = "Unknown" + snap.KubeletVersion = "unknown" + snap.LastUpdatedBy = LastUpdatedByDriftDetectionAndRemediation + snap.LastUpdatedReason = LastUpdatedReasonKubernetesVersionDrift + snap.LastUpdated = now + + if err := WriteStatusToFile(statusFilePath, snap); err != nil { + logger.Debugf("Failed to mark status unhealthy at %s: %v", statusFilePath, err) + } +} diff --git a/pkg/status/health_test.go b/pkg/status/health_test.go new file mode 100644 index 0000000..6ff39ff --- /dev/null +++ b/pkg/status/health_test.go @@ -0,0 +1,43 @@ +package status + +import ( + "path/filepath" + "testing" + "time" + + "github.com/sirupsen/logrus" +) + +func TestMarkKubeletUnhealthyBestEffortAtPath_CreatesOrUpdatesSnapshot(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + path := filepath.Join(dir, "status.json") + logger := logrus.New() + + now := time.Date(2026, 2, 13, 12, 0, 0, 0, time.UTC) + MarkKubeletUnhealthyBestEffortAtPath(logger, path, now) + + snap, err := LoadStatusFromFile(path) + if err != nil { + t.Fatalf("LoadStatusFromFile() err=%v", err) + } + if snap.KubeletRunning != false { + t.Fatalf("KubeletRunning=%v, want false", snap.KubeletRunning) + } + if snap.KubeletReady != "Unknown" { + t.Fatalf("KubeletReady=%q, want %q", snap.KubeletReady, "Unknown") + } + if snap.KubeletVersion != "unknown" { + t.Fatalf("KubeletVersion=%q, want %q", snap.KubeletVersion, "unknown") + } + if snap.LastUpdatedBy != LastUpdatedByDriftDetectionAndRemediation { + t.Fatalf("LastUpdatedBy=%q, want %q", snap.LastUpdatedBy, LastUpdatedByDriftDetectionAndRemediation) + } + if snap.LastUpdatedReason != LastUpdatedReasonKubernetesVersionDrift { + t.Fatalf("LastUpdatedReason=%q, want %q", snap.LastUpdatedReason, LastUpdatedReasonKubernetesVersionDrift) + } + if !snap.LastUpdated.Equal(now) { + t.Fatalf("LastUpdated=%s, want %s", snap.LastUpdated.Format(time.RFC3339Nano), now.Format(time.RFC3339Nano)) + } +} diff --git a/pkg/status/loader.go b/pkg/status/loader.go new file mode 100644 index 0000000..4608d22 --- /dev/null +++ b/pkg/status/loader.go @@ -0,0 +1,31 @@ +package status + +import ( + "encoding/json" + "fmt" + "os" +) + +// LoadStatus loads the node status snapshot from the default path. +func LoadStatus() (*NodeStatus, error) { + return LoadStatusFromFile(GetStatusFilePath()) +} + +// LoadStatusFromFile loads the node status snapshot from a JSON file. 
+func LoadStatusFromFile(path string) (*NodeStatus, error) {
+	if path == "" {
+		return nil, fmt.Errorf("status path is empty")
+	}
+
+	data, err := os.ReadFile(path)
+	if err != nil {
+		return nil, err
+	}
+
+	var s NodeStatus
+	if err := json.Unmarshal(data, &s); err != nil {
+		return nil, fmt.Errorf("failed to unmarshal node status: %w", err)
+	}
+
+	return &s, nil
+}
diff --git a/pkg/status/loader_writer_test.go b/pkg/status/loader_writer_test.go
new file mode 100644
index 0000000..4dc4dd2
--- /dev/null
+++ b/pkg/status/loader_writer_test.go
@@ -0,0 +1,80 @@
+package status
+
+import (
+	"os"
+	"path/filepath"
+	"testing"
+	"time"
+)
+
+func TestWriteStatusToFileAndLoadStatusFromFile_RoundTrip(t *testing.T) {
+	t.Parallel()
+
+	dir := t.TempDir()
+	path := filepath.Join(dir, "status.json")
+
+	in := &NodeStatus{
+		KubeletVersion:    "1.30.7",
+		RuncVersion:       "1.1.12",
+		ContainerdVersion: "1.7.20",
+		KubeletRunning:    true,
+		KubeletReady:      "Ready",
+		ContainerdRunning: true,
+		LastUpdated:       time.Now().UTC().Truncate(time.Second),
+		LastUpdatedBy:     LastUpdatedByStatusCollectionLoop,
+		LastUpdatedReason: LastUpdatedReasonPeriodicStatusLoop,
+		AgentVersion:      "dev",
+		ArcStatus:         ArcStatus{Connected: true, Registered: true, MachineName: "m"},
+	}
+
+	if err := WriteStatusToFile(path, in); err != nil {
+		t.Fatalf("WriteStatusToFile() err=%v", err)
+	}
+
+	// Ensure we didn't leave a temp file around.
+	if _, err := os.Stat(path + ".tmp"); err == nil {
+		t.Fatalf("temp file still exists")
+	}
+
+	out, err := LoadStatusFromFile(path)
+	if err != nil {
+		t.Fatalf("LoadStatusFromFile() err=%v", err)
+	}
+	if out == nil {
+		t.Fatalf("LoadStatusFromFile() out=nil")
+	}
+	if out.KubeletVersion != in.KubeletVersion || out.RuncVersion != in.RuncVersion {
+		t.Fatalf("roundtrip mismatch: got kubelet=%q runc=%q", out.KubeletVersion, out.RuncVersion)
+	}
+	if out.LastUpdatedBy != in.LastUpdatedBy || out.LastUpdatedReason != in.LastUpdatedReason {
+		t.Fatalf("metadata mismatch: got by=%q reason=%q", out.LastUpdatedBy, out.LastUpdatedReason)
+	}
+}
+
+func TestWriteStatusToFile_ValidationErrors(t *testing.T) {
+	t.Parallel()
+
+	if err := WriteStatusToFile("", &NodeStatus{}); err == nil {
+		t.Fatalf("expected error for empty path")
+	}
+	if err := WriteStatusToFile("/tmp/does-not-matter.json", nil); err == nil {
+		t.Fatalf("expected error for nil status")
+	}
+}
+
+func TestLoadStatusFromFile_Errors(t *testing.T) {
+	t.Parallel()
+
+	if _, err := LoadStatusFromFile(""); err == nil {
+		t.Fatalf("expected error for empty path")
+	}
+
+	dir := t.TempDir()
+	path := filepath.Join(dir, "status.json")
+	if err := os.WriteFile(path, []byte("not-json"), 0o600); err != nil {
+		t.Fatalf("WriteFile() err=%v", err)
+	}
+	if _, err := LoadStatusFromFile(path); err == nil {
+		t.Fatalf("expected unmarshal error")
+	}
+}
diff --git a/pkg/status/remove.go b/pkg/status/remove.go
new file mode 100644
index 0000000..cefe128
--- /dev/null
+++ b/pkg/status/remove.go
@@ -0,0 +1,38 @@
+package status
+
+import (
+	"errors"
+	"os"
+
+	"github.com/sirupsen/logrus"
+)
+
+// RemoveStatusFileBestEffort removes the current node status file.
+//
+// It is intentionally best-effort: removing the file ensures subsequent health checks
+// re-collect status from scratch, and a failure to remove it should not crash the agent.
+func RemoveStatusFileBestEffort(logger *logrus.Logger) {
+	RemoveStatusFileBestEffortAtPath(logger, GetStatusFilePath())
+}
+
+// RemoveStatusFileBestEffortAtPath removes the status file at the given path; failures are logged at debug level, never returned.
+func RemoveStatusFileBestEffortAtPath(logger *logrus.Logger, statusFilePath string) {
+	if logger == nil {
+		logger = logrus.New()
+	}
+	if statusFilePath == "" {
+		logger.Debug("Failed to remove status file: empty path")
+		return
+	}
+
+	if err := os.Remove(statusFilePath); err != nil {
+		if errors.Is(err, os.ErrNotExist) {
+			logger.Debug("Status file already removed")
+			return
+		}
+		logger.Debugf("Failed to remove status file: %v", err)
+		return
+	}
+
+	logger.Debug("Removed status file successfully")
+}
diff --git a/pkg/status/remove_test.go b/pkg/status/remove_test.go
new file mode 100644
index 0000000..c930e71
--- /dev/null
+++ b/pkg/status/remove_test.go
@@ -0,0 +1,41 @@
+package status
+
+import (
+	"os"
+	"path/filepath"
+	"testing"
+
+	"github.com/sirupsen/logrus"
+)
+
+func TestRemoveStatusFileBestEffortAtPath_RemovesFile(t *testing.T) {
+	t.Parallel()
+
+	tmpDir := t.TempDir()
+	p := filepath.Join(tmpDir, "status.json")
+	if err := os.WriteFile(p, []byte(`{"foo":"bar"}`), 0o600); err != nil {
+		t.Fatalf("write temp status: %v", err)
+	}
+
+	logger := logrus.New()
+	RemoveStatusFileBestEffortAtPath(logger, p)
+
+	if _, err := os.Stat(p); !os.IsNotExist(err) {
+		t.Fatalf("expected file removed, stat err=%v", err)
+	}
+}
+
+func TestRemoveStatusFileBestEffortAtPath_MissingFileNoError(t *testing.T) {
+	t.Parallel()
+
+	tmpDir := t.TempDir()
+	p := filepath.Join(tmpDir, "status.json")
+
+	logger := logrus.New()
+	RemoveStatusFileBestEffortAtPath(logger, p)
+
+	// Should remain missing.
+	if _, err := os.Stat(p); !os.IsNotExist(err) {
+		t.Fatalf("expected file still missing, stat err=%v", err)
+	}
+}
diff --git a/pkg/status/types.go b/pkg/status/types.go
index 9876f2e..ea32d81 100644
--- a/pkg/status/types.go
+++ b/pkg/status/types.go
@@ -4,6 +4,22 @@ import (
 	"time"
 )
 
+type LastUpdatedBy string
+
+const (
+	LastUpdatedByUnspecified                  LastUpdatedBy = ""
+	LastUpdatedByStatusCollectionLoop         LastUpdatedBy = "StatusCollectionLoop"
+	LastUpdatedByDriftDetectionAndRemediation LastUpdatedBy = "DriftDetectionAndRemediation"
+)
+
+type LastUpdatedReason string
+
+const (
+	LastUpdatedReasonUnspecified            LastUpdatedReason = ""
+	LastUpdatedReasonPeriodicStatusLoop     LastUpdatedReason = "periodicStatusLoop"
+	LastUpdatedReasonKubernetesVersionDrift LastUpdatedReason = "kubernetesVersionDrift"
+)
+
 // NodeStatus represents the current status and health information of the AKS edge node
 type NodeStatus struct {
 	// Component versions
@@ -21,8 +37,10 @@ type NodeStatus struct {
 	ArcStatus ArcStatus `json:"arcStatus"`
 
 	// Metadata
-	LastUpdated  time.Time `json:"lastUpdated"`
-	AgentVersion string    `json:"agentVersion"`
+	LastUpdated       time.Time         `json:"lastUpdated"`
+	LastUpdatedBy     LastUpdatedBy     `json:"lastUpdatedBy,omitempty"`
+	LastUpdatedReason LastUpdatedReason `json:"lastUpdatedReason,omitempty"`
+	AgentVersion      string            `json:"agentVersion"`
 }
 
 // ArcStatus contains Azure Arc machine registration and connection status
diff --git a/pkg/status/writer.go b/pkg/status/writer.go
new file mode 100644
index 0000000..243bb42
--- /dev/null
+++ b/pkg/status/writer.go
@@ -0,0 +1,37 @@
+package status
+
+import (
+	"encoding/json"
+	"fmt"
+	"os"
+	"path/filepath"
+)
+
+// WriteStatusToFile persists the node status snapshot to a JSON file.
+// It writes to a temporary file and renames it for atomicity.
+func WriteStatusToFile(path string, nodeStatus *NodeStatus) error { + if path == "" { + return fmt.Errorf("status path is empty") + } + if nodeStatus == nil { + return fmt.Errorf("node status is nil") + } + + if err := os.MkdirAll(filepath.Dir(path), 0o750); err != nil { + return fmt.Errorf("failed to create status directory: %w", err) + } + + statusData, err := json.MarshalIndent(nodeStatus, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal status to JSON: %w", err) + } + + tempFile := path + ".tmp" + if err := os.WriteFile(tempFile, statusData, 0o600); err != nil { + return fmt.Errorf("failed to write status to temp file: %w", err) + } + if err := os.Rename(tempFile, path); err != nil { + return fmt.Errorf("failed to rename temp status file: %w", err) + } + return nil +} From 39466ef3b5a0333dd13859b2a81972e743baa991 Mon Sep 17 00:00:00 2001 From: Qingqing Zheng Date: Thu, 19 Feb 2026 13:53:20 -0800 Subject: [PATCH 2/5] format --- pkg/drift/detector.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/drift/detector.go b/pkg/drift/detector.go index fff9407..6892ce0 100644 --- a/pkg/drift/detector.go +++ b/pkg/drift/detector.go @@ -23,7 +23,7 @@ type Finding struct { type RemediationAction string const ( - RemediationActionUnspecified RemediationAction = "" + RemediationActionUnspecified RemediationAction = "" RemediationActionKubernetesUpgrade RemediationAction = "kubernetes-upgrade" ) From 857979d098cc68be1c1b5481568daf049d3cca86 Mon Sep 17 00:00:00 2001 From: Qingqing Zheng Date: Thu, 19 Feb 2026 14:00:01 -0800 Subject: [PATCH 3/5] add nosec annotation --- pkg/spec/loader.go | 1 + 1 file changed, 1 insertion(+) diff --git a/pkg/spec/loader.go b/pkg/spec/loader.go index 901309b..045b5a0 100644 --- a/pkg/spec/loader.go +++ b/pkg/spec/loader.go @@ -17,6 +17,7 @@ func LoadManagedClusterSpecFromFile(path string) (*ManagedClusterSpec, error) { return nil, fmt.Errorf("spec path is empty") } + // #nosec G304 -- reading a local snapshot file path controlled by the agent (runtime/temp dir), not user input. data, err := os.ReadFile(path) if err != nil { return nil, err From ab0bf6a306b906c4796ace007a5ce4ec93e443cf Mon Sep 17 00:00:00 2001 From: Qingqing Zheng Date: Thu, 19 Feb 2026 14:01:10 -0800 Subject: [PATCH 4/5] add nosec annotation --- pkg/spec/collector_test.go | 1 + pkg/status/collector.go | 1 + pkg/status/loader.go | 1 + 3 files changed, 3 insertions(+) diff --git a/pkg/spec/collector_test.go b/pkg/spec/collector_test.go index a61bf88..028a724 100644 --- a/pkg/spec/collector_test.go +++ b/pkg/spec/collector_test.go @@ -55,6 +55,7 @@ func TestManagedClusterSpecCollector_Collect_WritesFile(t *testing.T) { t.Fatalf("Collect() error = %v", err) } + // #nosec G304 -- test reads a temp file path created by the test harness. b, err := os.ReadFile(outPath) if err != nil { t.Fatalf("ReadFile() error = %v", err) diff --git a/pkg/status/collector.go b/pkg/status/collector.go index c776b22..5ef5dcd 100644 --- a/pkg/status/collector.go +++ b/pkg/status/collector.go @@ -233,6 +233,7 @@ func (c *Collector) isKubeletReady(ctx context.Context) string { func (c *Collector) NeedsBootstrap(ctx context.Context) bool { statusFilePath := GetStatusFilePath() // Try to read the status file + // #nosec G304 -- reading a local status snapshot path controlled by the agent (runtime/temp dir), not user input. 
 	statusData, err := os.ReadFile(statusFilePath)
 	if err != nil {
 		c.logger.Info("Status file not found - bootstrap needed")
diff --git a/pkg/status/loader.go b/pkg/status/loader.go
index 4608d22..fb4a00a 100644
--- a/pkg/status/loader.go
+++ b/pkg/status/loader.go
@@ -17,6 +17,7 @@ func LoadStatusFromFile(path string) (*NodeStatus, error) {
 		return nil, fmt.Errorf("status path is empty")
 	}
 
+	// #nosec G304 -- reading a local status snapshot path controlled by the agent (runtime/temp dir), not user input.
 	data, err := os.ReadFile(path)
 	if err != nil {
 		return nil, err

From ee25ef774ecf611cb6078e62ca69f94f88791386 Mon Sep 17 00:00:00 2001
From: Qingqing Zheng
Date: Fri, 20 Feb 2026 22:04:07 -0800
Subject: [PATCH 5/5] fix drift remediation: refresh managed cluster spec
 before auto-bootstrap and pass config to installers

---
 commands.go                      | 18 ++++++++++++++----
 pkg/bootstrapper/bootstrapper.go | 20 ++++++++++----------
 2 files changed, 24 insertions(+), 14 deletions(-)

diff --git a/commands.go b/commands.go
index 4933ca3..2e22ddd 100644
--- a/commands.go
+++ b/commands.go
@@ -137,8 +137,8 @@ func runDaemonLoop(ctx context.Context, cfg *config.Config) error {
 		status.RemoveStatusFileBestEffortAtPath(logger, statusFilePath)
 	}
 
-	// Always clean up any stale managed cluster spec snapshot on daemon startup.
-	// The snapshot is best-effort and should not be relied upon across sessions.
+	// Always remove the managed cluster spec snapshot on daemon startup.
+	// We'll re-collect it shortly after startup and on a schedule.
 	removed, err := spec.RemoveManagedClusterSpecSnapshot()
 	if err != nil {
 		logger.Warnf("Failed to remove stale managed cluster spec snapshot: %v", err)
@@ -146,7 +146,7 @@ func runDaemonLoop(ctx context.Context, cfg *config.Config) error {
 		logger.Info("Removed stale managed cluster spec snapshot successfully")
 	}
 
-	logger.Info("Starting periodic status collection daemon (status: 1 minutes, bootstrap check: 2 minute, spec collection: 30 minutes)...")
+	logger.Info("Starting periodic status collection daemon (status: 1 minute, bootstrap check: 2 minutes, spec collection: 10 minutes)...")
 
 	// Protect cfg reads/writes across concurrent loops. This avoids data races when we
 	// temporarily update cfg.Kubernetes.Version to trigger drift remediation bootstrap.
@@ -176,7 +176,10 @@ func runDaemonLoop(ctx context.Context, cfg *config.Config) error {
 		cfgSnap := snapshotConfig(cfg, &cfgMu)
 		if err := drift.DetectAndRemediateFromFiles(ctx, cfgSnap, logger, &bootstrapInProgress, detectors); err != nil {
 			logger.Warnf("Initial drift detection after spec collection failed: %v", err)
+		} else {
+			logger.Info("Initial drift detection after spec collection completed successfully")
 		}
+
 	}
 }
 
@@ -307,7 +310,7 @@ func startNodeDriftDetectionAndRemediationLoop(
 ) {
 	go func() {
 		defer wg.Done()
-		ticker := time.NewTicker(5 * time.Minute)
+		ticker := time.NewTicker(10 * time.Minute)
 		defer ticker.Stop()
 		for {
 			select {
@@ -357,6 +360,13 @@ func checkAndBootstrap(ctx context.Context, cfg *config.Config) error {
 	logger.Info("Node requires re-bootstrapping, initiating auto-bootstrap...")
 
 	if cfg != nil && cfg.IsDriftDetectionAndRemediationEnabled() {
+		// Best-effort: refresh the managed cluster spec snapshot before attempting to
+		// override the Kubernetes version. This avoids falling back to an old static version
+		// right after reboot (we delete the snapshot at daemon startup).
+ if err := collectAndWriteManagedClusterSpec(ctx, cfg); err != nil { + logger.Warnf("Failed to refresh managed cluster spec before auto-bootstrap: %v", err) + } + // Best-effort: prefer Kubernetes version from the persisted managed cluster spec snapshot. // This keeps auto-bootstrap aligned with the cluster desired version even if the static // config has an older value. diff --git a/pkg/bootstrapper/bootstrapper.go b/pkg/bootstrapper/bootstrapper.go index d719a4c..8e8e4c5 100644 --- a/pkg/bootstrapper/bootstrapper.go +++ b/pkg/bootstrapper/bootstrapper.go @@ -33,16 +33,16 @@ func New(cfg *config.Config, logger *logrus.Logger) *Bootstrapper { func (b *Bootstrapper) Bootstrap(ctx context.Context) (*ExecutionResult, error) { // Define the bootstrap steps in order - using modules directly steps := []Executor{ - arc.NewInstaller(b.logger), // Setup Arc - services.NewUnInstaller(b.logger), // Stop kubelet before setup - system_configuration.NewInstaller(b.logger), // Configure system (early) - runc.NewInstaller(b.logger), // Install runc - containerd.NewInstaller(b.logger), // Install containerd - kube_binaries.NewInstaller(b.logger), // Install k8s binaries - cni.NewInstaller(b.logger), // Setup CNI (after container runtime) - kubelet.NewInstaller(b.logger), // Configure kubelet service with Arc MSI auth - npd.NewInstaller(b.logger), // Install Node Problem Detector - services.NewInstaller(b.logger), // Start services + arc.NewInstaller(b.logger), // Setup Arc + services.NewUnInstaller(b.logger), // Stop kubelet before setup + system_configuration.NewInstaller(b.logger), // Configure system (early) + runc.NewInstaller(b.logger), // Install runc + containerd.NewInstaller(b.logger), // Install containerd + kube_binaries.NewInstallerWithConfig(b.config, b.logger), // Install k8s binaries + cni.NewInstaller(b.logger), // Setup CNI (after container runtime) + kubelet.NewInstallerWithConfig(b.config, b.logger), // Configure kubelet service with Arc MSI auth + npd.NewInstaller(b.logger), // Install Node Problem Detector + services.NewInstaller(b.logger), // Start services } return b.ExecuteSteps(ctx, steps, "bootstrap")
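
A note for reviewers on how the pieces in this series compose. The sketch below is illustrative only: the function name exampleDetectKubernetesVersionDrift is hypothetical, and it is written as if it lived in package drift so it can reach the unexported compareMajorMinor. The shipped wiring in pkg/drift/detector.go and pkg/drift/remediation.go is not shown in full here, so treat this as an assumption-level summary rather than the actual implementation.

package drift

import (
	"fmt"

	"go.goms.io/aks/AKSFlexNode/pkg/spec"
	"go.goms.io/aks/AKSFlexNode/pkg/status"
)

// exampleDetectKubernetesVersionDrift (hypothetical) loads the two local
// snapshots and reports whether the node's kubelet trails the cluster's
// desired Kubernetes version at major.minor granularity.
func exampleDetectKubernetesVersionDrift(specPath, statusPath string) (RemediationAction, error) {
	mcSpec, err := spec.LoadManagedClusterSpecFromFile(specPath)
	if err != nil {
		return RemediationActionUnspecified, fmt.Errorf("load spec snapshot: %w", err)
	}
	nodeStatus, err := status.LoadStatusFromFile(statusPath)
	if err != nil {
		return RemediationActionUnspecified, fmt.Errorf("load status snapshot: %w", err)
	}

	// Desired version prefers CurrentKubernetesVersion over KubernetesVersion.
	desired := spec.DesiredKubernetesVersion(mcSpec)
	current := nodeStatus.KubeletVersion
	if desired == "" || current == "" {
		return RemediationActionUnspecified, nil // nothing to compare
	}

	// compareMajorMinor tolerates a leading "v" and an extra patch component,
	// so a kubelet version like "v1.29.5" compares cleanly against "1.30.7".
	cmp, ok := compareMajorMinor(current, desired)
	if !ok || cmp >= 0 {
		return RemediationActionUnspecified, nil // unparseable, equal, or ahead
	}
	return RemediationActionKubernetesUpgrade, nil
}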
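
In the same spirit, the remediation hand-off can be pictured as two best-effort steps built from the helpers added above. Again, exampleRemediateKubernetesVersionDrift is a hypothetical name, and the real flow may sequence things differently (for example, guarding on bootstrapInProgress before marking the node unhealthy).

package drift

import (
	"github.com/sirupsen/logrus"

	"go.goms.io/aks/AKSFlexNode/pkg/config"
	"go.goms.io/aks/AKSFlexNode/pkg/spec"
	"go.goms.io/aks/AKSFlexNode/pkg/status"
)

// exampleRemediateKubernetesVersionDrift (hypothetical) aligns cfg with the
// cluster's desired version and then marks the status snapshot unhealthy so
// the existing bootstrap-check loop performs the actual upgrade.
func exampleRemediateKubernetesVersionDrift(logger *logrus.Logger, cfg *config.Config) error {
	// Step 1: overwrite cfg.Kubernetes.Version from the persisted spec snapshot,
	// so the next bootstrap installs binaries for the desired version.
	changed, oldV, newV, err := spec.OverrideKubernetesVersionFromManagedClusterSpec(cfg)
	if err != nil {
		return err
	}
	if changed {
		logger.Infof("Kubernetes version override for remediation: %s -> %s", oldV, newV)
	}

	// Step 2: flip the status snapshot to clearly-unhealthy; NeedsBootstrap()
	// will then report true on the next bootstrap-check tick.
	status.MarkKubeletUnhealthyBestEffort(logger)
	return nil
}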