From b5335289c667f0dfa4976c336bc04dbdcefe3838 Mon Sep 17 00:00:00 2001 From: Mohammad Aziz Date: Sun, 25 Jan 2026 16:20:31 +0530 Subject: [PATCH] feat: add platform headers (X-Agent-OS, X-Agent-Arch) to update check requests - Create internal/httpclient package with AgentTransport RoundTripper - AgentTransport injects X-Agent-Version, X-Agent-OS, X-Agent-Arch headers - Remove currentVersion parameter from Check() method (breaking change) - Add ErrUnsupportedPlatform error for 400 responses - Log unsupported platform at WARN level instead of ERROR - Update main.go to use httpclient.NewClient() This allows the control plane to serve the correct binary for the agent's OS/architecture combination during self-updates. --- AGENTS.md | 21 ---- app/jobs/selfupdatejob/selfupdatejob.go | 11 +- app/jobs/selfupdatejob/selfupdatejob_test.go | 26 +++- app/services/updatecheck/updatecheck.go | 12 +- app/services/updatecheck/updatecheck_test.go | 51 +++++--- internal/httpclient/transport.go | 39 ++++++ internal/httpclient/transport_test.go | 125 +++++++++++++++++++ main.go | 4 +- 8 files changed, 245 insertions(+), 44 deletions(-) create mode 100644 internal/httpclient/transport.go create mode 100644 internal/httpclient/transport_test.go diff --git a/AGENTS.md b/AGENTS.md index c078ce5..643ec7d 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,22 +1 @@ -- ALWAYS USE PARALLEL TASKS SUBAGENTS FOR CODE EXPLORATION, INVESTIGATION, DEEP DIVES -- Use all tools available to keep current context window as small as possible -- When reading files, DELEGATE to subagents, if possible -- In plan mode, be bias to delegate to subagents -- Use question tool more frequently - Use jj instead of git -- ALWAYS FOLLOW TDD, red phase to green phase -- Use ripgrep instead of grep, use fd instead of find - -## Usage of question tool - -Before any kind of implementation, interview me in detail using the question tool. - -Ask about technical implementation, UI/UX, edge cases, concerns, and tradeoffs. -Don't ask obvious questions, dig into the hard parts I might not have considered. - -Keep interviewing until we've covered everything. - -## Tests - -- Test actual behavior, not the implementation -- Only test implementation when there is a technical limit to simulating the behavior diff --git a/app/jobs/selfupdatejob/selfupdatejob.go b/app/jobs/selfupdatejob/selfupdatejob.go index e565a3f..0f0e7fb 100644 --- a/app/jobs/selfupdatejob/selfupdatejob.go +++ b/app/jobs/selfupdatejob/selfupdatejob.go @@ -2,8 +2,10 @@ package selfupdatejob import ( "context" + "errors" "fmt" "path/filepath" + "runtime" "sync" "time" @@ -27,7 +29,7 @@ type TriggerFunc func(context.Context, func() error) // UpdateCheckerInterface abstracts the update check client. type UpdateCheckerInterface interface { - Check(currentVersion string) (*updatecheck.UpdateInfo, error) + Check() (*updatecheck.UpdateInfo, error) } // DownloaderInterface abstracts the download and verify functionality. @@ -125,8 +127,13 @@ func (j *SelfUpdateJob) Shutdown() { // runUpdate performs a single update check and apply cycle. func (j *SelfUpdateJob) runUpdate(ctx context.Context) error { // Step 1: Check for updates - info, err := j.config.UpdateChecker.Check(j.config.CurrentVersion) + info, err := j.config.UpdateChecker.Check() if err != nil { + // Log unsupported platform at WARN level and return nil (not an error condition) + if errors.Is(err, updatecheck.ErrUnsupportedPlatform) { + log.Warnf("update check failed: unsupported platform %s/%s", runtime.GOOS, runtime.GOARCH) + return nil + } return fmt.Errorf("update check failed: %w", err) } if !info.UpdateAvailable { diff --git a/app/jobs/selfupdatejob/selfupdatejob_test.go b/app/jobs/selfupdatejob/selfupdatejob_test.go index fc697fc..0e9caa9 100644 --- a/app/jobs/selfupdatejob/selfupdatejob_test.go +++ b/app/jobs/selfupdatejob/selfupdatejob_test.go @@ -206,6 +206,30 @@ func TestUpdateFlow_SkipsWhenNoUpdate(t *testing.T) { } } +func TestUpdateFlow_UnsupportedPlatform_ReturnsNilError(t *testing.T) { + // When the control plane returns 400 (unsupported platform), + // runUpdate should log WARN and return nil (not an error). + checker := &mockUpdateChecker{ + err: updatecheck.ErrUnsupportedPlatform, + } + downloader := &mockDownloader{} + + job := NewWithConfig(SelfUpdateJobConfig{ + UpdateChecker: checker, + Downloader: downloader, + CurrentVersion: "1.0.0", + }) + + err := job.runUpdate(context.Background()) + + if err != nil { + t.Errorf("expected nil error for unsupported platform, got %v", err) + } + if downloader.callCount.Load() > 0 { + t.Error("downloader should not be called when platform unsupported") + } +} + func TestUpdateFlow_SkipsWhenPreflightFails(t *testing.T) { checker := &mockUpdateChecker{ result: &updatecheck.UpdateInfo{ @@ -1112,7 +1136,7 @@ type mockUpdateChecker struct { callCount atomic.Int32 } -func (m *mockUpdateChecker) Check(currentVersion string) (*updatecheck.UpdateInfo, error) { +func (m *mockUpdateChecker) Check() (*updatecheck.UpdateInfo, error) { m.callCount.Add(1) return m.result, m.err } diff --git a/app/services/updatecheck/updatecheck.go b/app/services/updatecheck/updatecheck.go index 247dbf2..792e238 100644 --- a/app/services/updatecheck/updatecheck.go +++ b/app/services/updatecheck/updatecheck.go @@ -7,6 +7,10 @@ import ( "net/http" ) +// ErrUnsupportedPlatform is returned when the control plane returns 400, +// indicating no binary exists for the agent's OS/architecture combination. +var ErrUnsupportedPlatform = errors.New("unsupported platform") + // UpdateInfo represents the response from the update check endpoint. type UpdateInfo struct { UpdateAvailable bool `json:"update_available"` @@ -43,7 +47,8 @@ func New(client *http.Client, controlPlaneURL, agentID string, signer RequestSig } // Check queries the control plane for available updates. -func (c *UpdateChecker) Check(currentVersion string) (*UpdateInfo, error) { +// Agent headers (X-Agent-Version, X-Agent-OS, X-Agent-Arch) are set by the HTTP transport. +func (c *UpdateChecker) Check() (*UpdateInfo, error) { url := fmt.Sprintf("%s/api/v1/agents/%s/update", c.controlPlaneURL, c.agentID) req, err := http.NewRequest(http.MethodGet, url, nil) @@ -51,8 +56,6 @@ func (c *UpdateChecker) Check(currentVersion string) (*UpdateInfo, error) { return nil, fmt.Errorf("failed to create request: %w", err) } - req.Header.Set("X-Agent-Version", currentVersion) - if c.signer != nil { if err := c.signer.SignRequest(req); err != nil { return nil, fmt.Errorf("failed to sign request: %w", err) @@ -65,6 +68,9 @@ func (c *UpdateChecker) Check(currentVersion string) (*UpdateInfo, error) { } defer resp.Body.Close() + if resp.StatusCode == http.StatusBadRequest { + return nil, fmt.Errorf("update check returned status %d: %w", resp.StatusCode, ErrUnsupportedPlatform) + } if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("update check returned status %d", resp.StatusCode) } diff --git a/app/services/updatecheck/updatecheck_test.go b/app/services/updatecheck/updatecheck_test.go index 6ca0bff..1de5f80 100644 --- a/app/services/updatecheck/updatecheck_test.go +++ b/app/services/updatecheck/updatecheck_test.go @@ -2,6 +2,7 @@ package updatecheck import ( "encoding/json" + "errors" "net/http" "net/http/httptest" "testing" @@ -23,7 +24,7 @@ func TestCheck_UpdateAvailable(t *testing.T) { defer server.Close() checker := newTestChecker(t, server.Client(), server.URL, nil) - info, err := checker.Check("1.0.0") + info, err := checker.Check() if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -54,7 +55,7 @@ func TestCheck_NoUpdateAvailable(t *testing.T) { defer server.Close() checker := newTestChecker(t, server.Client(), server.URL, nil) - info, err := checker.Check("2.0.0") + info, err := checker.Check() if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -65,7 +66,7 @@ func TestCheck_NoUpdateAvailable(t *testing.T) { func TestCheck_NetworkError(t *testing.T) { checker := newTestChecker(t, http.DefaultClient, "http://localhost:1", nil) - _, err := checker.Check("1.0.0") + _, err := checker.Check() if err == nil { t.Fatal("expected error for bad URL, got nil") } @@ -82,7 +83,7 @@ func TestCheck_SignsRequest(t *testing.T) { signer := &mockSigner{agentID: "agent-123"} checker := newTestChecker(t, server.Client(), server.URL, signer) - _, err := checker.Check("1.0.0") + _, err := checker.Check() if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -101,23 +102,25 @@ func TestCheck_SignsRequest(t *testing.T) { } } -func TestCheck_SendsCurrentVersionAsHeader(t *testing.T) { - var receivedVersion string +func TestCheck_MakesRequest(t *testing.T) { + // Note: X-Agent-Version, X-Agent-OS, X-Agent-Arch headers are now set by the transport. + // Header verification is done in internal/httpclient tests. + var requestMade bool server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - receivedVersion = r.Header.Get("X-Agent-Version") + requestMade = true resp := UpdateInfo{UpdateAvailable: false} json.NewEncoder(w).Encode(resp) })) defer server.Close() checker := newTestChecker(t, server.Client(), server.URL, nil) - _, err := checker.Check("1.5.3") + _, err := checker.Check() if err != nil { t.Fatalf("unexpected error: %v", err) } - if receivedVersion != "1.5.3" { - t.Errorf("expected X-Agent-Version header 1.5.3, got %s", receivedVersion) + if !requestMade { + t.Error("expected request to be made") } } @@ -131,7 +134,7 @@ func TestCheck_NoQueryParams(t *testing.T) { defer server.Close() checker := newTestChecker(t, server.Client(), server.URL, nil) - checker.Check("1.5.3") + checker.Check() if receivedRawQuery != "" { t.Errorf("expected no query params, got %s", receivedRawQuery) @@ -145,12 +148,30 @@ func TestCheck_HTTPErrorStatus(t *testing.T) { defer server.Close() checker := newTestChecker(t, server.Client(), server.URL, nil) - _, err := checker.Check("1.0.0") + _, err := checker.Check() if err == nil { t.Fatal("expected error for 500 status, got nil") } } +func TestCheck_UnsupportedPlatform_Returns400(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + })) + defer server.Close() + + checker := newTestChecker(t, server.Client(), server.URL, nil) + _, err := checker.Check() + if err == nil { + t.Fatal("expected error for 400 status, got nil") + } + + // Verify we get the specific ErrUnsupportedPlatform error + if !errors.Is(err, ErrUnsupportedPlatform) { + t.Errorf("expected ErrUnsupportedPlatform, got %v", err) + } +} + func TestCheck_InvalidJSON(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("not json")) @@ -158,7 +179,7 @@ func TestCheck_InvalidJSON(t *testing.T) { defer server.Close() checker := newTestChecker(t, server.Client(), server.URL, nil) - _, err := checker.Check("1.0.0") + _, err := checker.Check() if err == nil { t.Fatal("expected error for invalid JSON, got nil") } @@ -174,7 +195,7 @@ func TestCheck_UsesGETMethod(t *testing.T) { defer server.Close() checker := newTestChecker(t, server.Client(), server.URL, nil) - checker.Check("1.0.0") + checker.Check() if receivedMethod != http.MethodGet { t.Errorf("expected GET method, got %s", receivedMethod) @@ -194,7 +215,7 @@ func TestCheck_UsesCorrectPath(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - checker.Check("1.0.0") + checker.Check() if receivedPath != "/api/v1/agents/agent-123/update" { t.Errorf("expected path /api/v1/agents/agent-123/update, got %s", receivedPath) diff --git a/internal/httpclient/transport.go b/internal/httpclient/transport.go new file mode 100644 index 0000000..e7eabfa --- /dev/null +++ b/internal/httpclient/transport.go @@ -0,0 +1,39 @@ +// Package httpclient provides HTTP client utilities with agent identification headers. +package httpclient + +import ( + "net/http" + "runtime" + "time" + + "hostlink/version" +) + +// AgentTransport wraps an http.RoundTripper and injects agent identification headers. +type AgentTransport struct { + Base http.RoundTripper +} + +// RoundTrip implements http.RoundTripper. +func (t *AgentTransport) RoundTrip(req *http.Request) (*http.Response, error) { + // Clone request to avoid mutating the original + clone := req.Clone(req.Context()) + + clone.Header.Set("X-Agent-Version", version.Version) + clone.Header.Set("X-Agent-OS", runtime.GOOS) + clone.Header.Set("X-Agent-Arch", runtime.GOARCH) + + base := t.Base + if base == nil { + base = http.DefaultTransport + } + return base.RoundTrip(clone) +} + +// NewClient returns an *http.Client configured with AgentTransport and the specified timeout. +func NewClient(timeout time.Duration) *http.Client { + return &http.Client{ + Transport: &AgentTransport{}, + Timeout: timeout, + } +} diff --git a/internal/httpclient/transport_test.go b/internal/httpclient/transport_test.go new file mode 100644 index 0000000..78b187a --- /dev/null +++ b/internal/httpclient/transport_test.go @@ -0,0 +1,125 @@ +package httpclient + +import ( + "net/http" + "net/http/httptest" + "runtime" + "testing" + "time" + + "hostlink/version" +) + +func TestAgentTransport_SetsAllHeaders(t *testing.T) { + var receivedHeaders http.Header + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedHeaders = r.Header + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + client := NewClient(5 * time.Second) + resp, err := client.Get(server.URL) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + resp.Body.Close() + + // Verify X-Agent-Version header + if got := receivedHeaders.Get("X-Agent-Version"); got != version.Version { + t.Errorf("X-Agent-Version = %q, want %q", got, version.Version) + } + + // Verify X-Agent-OS header + if got := receivedHeaders.Get("X-Agent-OS"); got != runtime.GOOS { + t.Errorf("X-Agent-OS = %q, want %q", got, runtime.GOOS) + } + + // Verify X-Agent-Arch header + if got := receivedHeaders.Get("X-Agent-Arch"); got != runtime.GOARCH { + t.Errorf("X-Agent-Arch = %q, want %q", got, runtime.GOARCH) + } +} + +func TestAgentTransport_PreservesExistingHeaders(t *testing.T) { + var receivedHeaders http.Header + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedHeaders = r.Header + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + client := NewClient(5 * time.Second) + req, err := http.NewRequest(http.MethodGet, server.URL, nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + req.Header.Set("X-Custom-Header", "custom-value") + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + resp.Body.Close() + + // Verify custom header is preserved + if got := receivedHeaders.Get("X-Custom-Header"); got != "custom-value" { + t.Errorf("X-Custom-Header = %q, want %q", got, "custom-value") + } + + // Verify agent headers are still set + if got := receivedHeaders.Get("X-Agent-Version"); got != version.Version { + t.Errorf("X-Agent-Version = %q, want %q", got, version.Version) + } +} + +func TestAgentTransport_DoesNotMutateOriginalRequest(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + client := NewClient(5 * time.Second) + req, err := http.NewRequest(http.MethodGet, server.URL, nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + + // Store original header count + originalHeaderCount := len(req.Header) + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + resp.Body.Close() + + // Verify original request was not mutated + if len(req.Header) != originalHeaderCount { + t.Errorf("original request was mutated: header count changed from %d to %d", originalHeaderCount, len(req.Header)) + } +} + +func TestAgentTransport_UsesDefaultTransportWhenBaseIsNil(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + // Create transport with nil Base + transport := &AgentTransport{Base: nil} + client := &http.Client{Transport: transport, Timeout: 5 * time.Second} + + resp, err := client.Get(server.URL) + if err != nil { + t.Fatalf("unexpected error with nil Base: %v", err) + } + resp.Body.Close() +} + +func TestNewClient_SetsTimeout(t *testing.T) { + client := NewClient(42 * time.Second) + if client.Timeout != 42*time.Second { + t.Errorf("Timeout = %v, want %v", client.Timeout, 42*time.Second) + } +} diff --git a/main.go b/main.go index 5f1596b..a6b96d7 100644 --- a/main.go +++ b/main.go @@ -22,11 +22,11 @@ import ( "hostlink/config" "hostlink/config/appconf" "hostlink/internal/dbconn" + "hostlink/internal/httpclient" "hostlink/internal/update" "hostlink/internal/validator" "hostlink/version" "log" - "net/http" "os" "syscall" "time" @@ -289,7 +289,7 @@ func startSelfUpdateJob(ctx context.Context) { // Create update checker checker, err := updatecheck.New( - &http.Client{Timeout: 30 * time.Second}, + httpclient.NewClient(30*time.Second), appconf.ControlPlaneURL(), agentID, signer,