diff --git a/TESTING.md b/TESTING.md index b2d9cffa..c5e1548e 100644 --- a/TESTING.md +++ b/TESTING.md @@ -80,6 +80,15 @@ Use this pattern as a seed and add more mocks only as the code under test demand Acceptance tests (`*_acc_test.go`) run against a live Azure DevOps organization using the harness under `internal/test`. They are opt-in and should only be used when unit tests and mocks cannot provide enough confidence. +### Getting started (`AcceptanceTest`) + +`inttest.Test` executes a `TestCase` using a `TestContext` created by the harness. The `TestCase.AcceptanceTest` boolean controls which context is used: + +- `AcceptanceTest: true` (recommended for all `*_acc_test.go`): uses `inttest.NewAccTestContext(t)` which reads the required `AZDO_ACC_*` environment variables and creates real SDK clients for the configured organization. +- `AcceptanceTest: false`: uses `inttest.NewTestContext(t)` which generates random placeholder values for `Org`, `PAT`, and `Project`. This is only useful for hermetic unit tests that want a lightweight `util.CmdContext` implementation; it is not suitable for live Azure DevOps operations. + +For acceptance tests, always set `AcceptanceTest: true`. Otherwise the harness will still be gated by `AZDO_ACC_TEST=1`, but the generated placeholder org/token values will typically cause confusing failures once your steps attempt real API calls. + ### When to add an acceptance test - You need to verify a workflow/command (e.g., modifying security permissions) against real data. @@ -92,20 +101,39 @@ Acceptance tests (`*_acc_test.go`) run against a live Azure DevOps organization | ------------------ | ------------------------------------------------------------------------------------------------------ | | `AZDO_ACC_TEST=1` | Enables acceptance tests. Without it, `inttest.Test` skips all steps. | | `AZDO_ACC_ORG` | Organization name used for the session. | -| `AZDO_ACC_ORG_URL` | Optional explicit organization URL; defaults to `https://dev.azure.com/`. | | `AZDO_ACC_PAT` | Personal Access Token with the scopes required by the test steps. | | `AZDO_ACC_PROJECT` | Project name used by acceptance tests that operate on project-scoped resources. | -| `AZDO_ACC_TIMEOUT` | Optional override for the default 60 s timeout. Accepts Go durations (`45s`, `2m`) or integer seconds. | +| `AZDO_ACC_TIMEOUT` | Optional override for the default 240 s timeout. Accepts Go durations (`45s`, `2m`) or integer seconds. Use `-1` to disable timeouts. | ### Step-by-step skeleton 1. Place the test in the command package with the `_acc_test.go` suffix. -2. Wrap your steps in `inttest.Test(t, inttest.TestCase{ Steps: []inttest.Step{ ... } })`. +2. Wrap your steps in `inttest.Test(t, inttest.TestCase{ AcceptanceTest: true, Steps: []inttest.Step{ ... } })`. 3. **PreRun**: create or seed live resources (groups, repositories, permissions) using `ctx.ClientFactory()`. 4. **Run**: construct the command options and call the command’s `run...` helper directly (e.g., `return runCommand(ctx, opts)`). -5. **Verify**: use `inttest.Poll` to wait for eventual consistency and assert desired state. +5. **Verify**: use `internal/util/poll.go` (`pollutil.Poll(ctx.Context(), ...)`) to wait for eventual consistency and assert desired state. 6. **PostRun**: delete or revert all resources you created; aggregate cleanup errors with `errors.Join`. +Minimal scaffold: + +```go +inttest.Test(t, inttest.TestCase{ + AcceptanceTest: true, + PreCheck: func() error { + // Validate required inputs (e.g., ensure AZDO_ACC_PROJECT is set when needed). + return nil + }, + Steps: []inttest.Step{ + { + PreRun: func(ctx inttest.TestContext) error { return nil }, + Run: func(ctx inttest.TestContext) error { return nil }, + Verify: func(ctx inttest.TestContext) error { return nil }, + PostRun: func(ctx inttest.TestContext) error { return nil }, + }, + }, +}) +``` + Example shell to execute a single acceptance test: ```bash AZDO_ACC_TEST=1 \ @@ -119,7 +147,7 @@ Acceptance tests are not run in CI; execute them manually before publishing feat - `inttest.TestContext` now exposes `Project()` alongside `Org`, `OrgUrl`, and `PAT`. Set `AZDO_ACC_PROJECT` when a test needs to target a specific project and fail fast in `PreRun` if it is missing. - Use `TestContext.SetValue(key, value)`/`Value(key)` to propagate data across `PreRun`, `Run`, `Verify`, and `PostRun` without relying on package-level variables. Keys can be simple strings or typed aliases; mimic `context.Context` usage. -- The helper `inttest.WriteTestFile(path, contents)` creates or truncates files with `0600` permissions and ensures parent directories exist, which is useful for acceptance tests that need temporary credentials or certificates. +- The helpers `inttest.WriteTestFile(t, contents)` and `inttest.WriteTestFileWithName(t, filename, contents)` create files with `0600` permissions under `t.TempDir()`, which is useful for acceptance tests that need temporary credentials or certificates. ### Updating Mocks diff --git a/docs/azdo_help_reference.md b/docs/azdo_help_reference.md index 9ffcef35..9a6070ff 100644 --- a/docs/azdo_help_reference.md +++ b/docs/azdo_help_reference.md @@ -1089,7 +1089,10 @@ Create an Azure Resource Manager service connection --subscription-name string Azure subscription name -t, --template string Format JSON output using a Go template; see "azdo help formatting" --tenant-id string Azure tenant ID (e.g., GUID) --y, --yes Skip confirmation prompts + --timeout duration Maximum time to wait when --wait or --validate-connection is enabled (default 2m0s) + --validate-connection Run TestConnection after creation (opt-in) + --validate-schema Validate auth scheme/params against endpoint type metadata (opt-in) + --wait Wait until the endpoint reports ready/failed ``` Aliases @@ -1103,13 +1106,19 @@ cr, c, new, n, add, a Create a GitHub service endpoint ``` - --configuration-id string Configuration for connecting to the endpoint (use an OAuth/installation configuration). Mutually exclusive with --token. --q, --jq expression Filter JSON output using a jq expression - --json fields[=*] Output JSON with the specified fields. Prefix a field with '-' to exclude it. - --name string Name of the service endpoint --t, --template string Format JSON output using a Go template; see "azdo help formatting" - --token string Visit https://github.com/settings/tokens to create personal access tokens. Recommended scopes: repo, user, admin:repo_hook. If omitted, you will be prompted for a token when interactive. - --url string GitHub URL (defaults to https://github.com) + --configuration-id string Configuration for connecting to the endpoint (use an OAuth/installation configuration). Mutually exclusive with --token. + --description string Description for the service endpoint + --grant-permission-to-all-pipelines Grant access permission to all pipelines to use the service connection +-q, --jq expression Filter JSON output using a jq expression + --json fields[=*] Output JSON with the specified fields. Prefix a field with '-' to exclude it. + --name string Name of the service endpoint +-t, --template string Format JSON output using a Go template; see "azdo help formatting" + --timeout duration Maximum time to wait when --wait or --validate-connection is enabled (default 2m0s) + --token string Visit https://github.com/settings/tokens to create personal access tokens. Recommended scopes: repo, user, admin:repo_hook. If omitted, you will be prompted for a token when interactive. + --url string GitHub URL (defaults to https://github.com) + --validate-connection Run TestConnection after creation (opt-in) + --validate-schema Validate auth scheme/params against endpoint type metadata (opt-in) + --wait Wait until the endpoint reports ready/failed ``` ### `azdo service-endpoint delete [ORGANIZATION/]PROJECT/ID_OR_NAME [flags]` diff --git a/docs/azdo_service-endpoint_create_azurerm.md b/docs/azdo_service-endpoint_create_azurerm.md index ba547527..e5dc53f0 100644 --- a/docs/azdo_service-endpoint_create_azurerm.md +++ b/docs/azdo_service-endpoint_create_azurerm.md @@ -83,9 +83,21 @@ This command is modeled after the Azure DevOps Terraform Provider's implementati Azure tenant ID (e.g., GUID) -* `-y`, `--yes` +* `--timeout` `duration` (default `"2m0s"`) - Skip confirmation prompts + Maximum time to wait when --wait or --validate-connection is enabled + +* `--validate-connection` + + Run TestConnection after creation (opt-in) + +* `--validate-schema` + + Validate auth scheme/params against endpoint type metadata (opt-in) + +* `--wait` + + Wait until the endpoint reports ready/failed ### ALIASES diff --git a/docs/azdo_service-endpoint_create_github.md b/docs/azdo_service-endpoint_create_github.md index e5a67e53..260d9d1c 100644 --- a/docs/azdo_service-endpoint_create_github.md +++ b/docs/azdo_service-endpoint_create_github.md @@ -14,6 +14,14 @@ Create a GitHub service endpoint using a personal access token (PAT) or an insta Configuration for connecting to the endpoint (use an OAuth/installation configuration). Mutually exclusive with --token. +* `--description` `string` + + Description for the service endpoint + +* `--grant-permission-to-all-pipelines` + + Grant access permission to all pipelines to use the service connection + * `-q`, `--jq` `expression` Filter JSON output using a jq expression @@ -30,6 +38,10 @@ Create a GitHub service endpoint using a personal access token (PAT) or an insta Format JSON output using a Go template; see "azdo help formatting" +* `--timeout` `duration` (default `"2m0s"`) + + Maximum time to wait when --wait or --validate-connection is enabled + * `--token` `string` Visit https://github.com/settings/tokens to create personal access tokens. Recommended scopes: repo, user, admin:repo_hook. If omitted, you will be prompted for a token when interactive. @@ -38,6 +50,18 @@ Create a GitHub service endpoint using a personal access token (PAT) or an insta GitHub URL (defaults to https://github.com) +* `--validate-connection` + + Run TestConnection after creation (opt-in) + +* `--validate-schema` + + Validate auth scheme/params against endpoint type metadata (opt-in) + +* `--wait` + + Wait until the endpoint reports ready/failed + ### JSON Fields diff --git a/internal/cmd/security/permission/delete/delete_acc_test.go b/internal/cmd/security/permission/delete/delete_acc_test.go index 17aa98dd..24f78095 100644 --- a/internal/cmd/security/permission/delete/delete_acc_test.go +++ b/internal/cmd/security/permission/delete/delete_acc_test.go @@ -13,6 +13,7 @@ import ( "github.com/tmeckel/azdo-cli/internal/cmd/security/permission/shared" inttest "github.com/tmeckel/azdo-cli/internal/test" "github.com/tmeckel/azdo-cli/internal/types" + pollutil "github.com/tmeckel/azdo-cli/internal/util" ) type aceContainer struct { @@ -31,6 +32,7 @@ func TestAccDeletePermission(t *testing.T) { var groupIdentity string inttest.Test(t, inttest.TestCase{ + AcceptanceTest: true, Steps: []inttest.Step{ { PreRun: func(ctx inttest.TestContext) error { @@ -103,7 +105,7 @@ func TestAccDeletePermission(t *testing.T) { if err != nil { return err } - return inttest.Poll(func() error { + return pollutil.Poll(ctx.Context(), func() error { ace, err := shared.FindAccessControlEntry(ctx.Context(), secClient, nsUUID, token, groupIdentity) if err != nil { return err @@ -117,7 +119,7 @@ func TestAccDeletePermission(t *testing.T) { return nil } return fmt.Errorf("expected ACE to be removed; allow=%d deny=%d", allow, deny) - }, inttest.PollOptions{ + }, pollutil.PollOptions{ Tries: 10, Timeout: 30 * time.Second, }) diff --git a/internal/cmd/security/permission/reset/reset_acc_test.go b/internal/cmd/security/permission/reset/reset_acc_test.go index 3ff39cab..7a9c65e1 100644 --- a/internal/cmd/security/permission/reset/reset_acc_test.go +++ b/internal/cmd/security/permission/reset/reset_acc_test.go @@ -13,6 +13,7 @@ import ( "github.com/tmeckel/azdo-cli/internal/cmd/security/permission/shared" inttest "github.com/tmeckel/azdo-cli/internal/test" "github.com/tmeckel/azdo-cli/internal/types" + pollutil "github.com/tmeckel/azdo-cli/internal/util" ) type aceContainer struct { @@ -44,6 +45,7 @@ func TestAccResetPermission(t *testing.T) { ) inttest.Test(t, inttest.TestCase{ + AcceptanceTest: true, Steps: []inttest.Step{ { PreRun: func(ctx inttest.TestContext) error { @@ -119,7 +121,7 @@ func TestAccResetPermission(t *testing.T) { return err } - return inttest.Poll(func() error { + return pollutil.Poll(ctx.Context(), func() error { ace, err := shared.FindAccessControlEntry(ctx.Context(), secClient, nsUUID, token, groupIdentity) if err != nil { return err @@ -135,7 +137,7 @@ func TestAccResetPermission(t *testing.T) { return fmt.Errorf("unexpected permission state allow=0x%X deny=0x%X", allow, deny) } return nil - }, inttest.PollOptions{ + }, pollutil.PollOptions{ Tries: 10, Timeout: 30 * time.Second, }) diff --git a/internal/cmd/security/permission/update/update_acc_test.go b/internal/cmd/security/permission/update/update_acc_test.go index 4af6ef5a..c379ba4e 100644 --- a/internal/cmd/security/permission/update/update_acc_test.go +++ b/internal/cmd/security/permission/update/update_acc_test.go @@ -12,6 +12,7 @@ import ( "github.com/tmeckel/azdo-cli/internal/cmd/security/permission/shared" inttest "github.com/tmeckel/azdo-cli/internal/test" + pollutil "github.com/tmeckel/azdo-cli/internal/util" ) func TestAccUpdatePermission(t *testing.T) { @@ -26,6 +27,7 @@ func TestAccUpdatePermission(t *testing.T) { groupName := fmt.Sprintf("azdo-cli-test-group-%s", uuid.New().String()) inttest.Test(t, inttest.TestCase{ + AcceptanceTest: true, Steps: []inttest.Step{ { PreRun: func(ctx inttest.TestContext) error { @@ -72,7 +74,7 @@ func TestAccUpdatePermission(t *testing.T) { if err != nil { return err } - return inttest.Poll(func() error { + return pollutil.Poll(ctx.Context(), func() error { ace, err := shared.FindAccessControlEntry(ctx.Context(), sec, nsUUID, token, groupIdentity) if err != nil { return err @@ -87,7 +89,7 @@ func TestAccUpdatePermission(t *testing.T) { return fmt.Errorf("allow mask %d does not contain expected bit 0x2", *ace.Allow) } return nil - }, inttest.PollOptions{ + }, pollutil.PollOptions{ Tries: 10, Timeout: 30 * time.Second, }) diff --git a/internal/cmd/serviceendpoint/README.md b/internal/cmd/serviceendpoint/README.md new file mode 100644 index 00000000..adbdea01 --- /dev/null +++ b/internal/cmd/serviceendpoint/README.md @@ -0,0 +1,301 @@ +# Service Endpoint Framework (`internal/cmd/serviceendpoint`) + +This directory contains the CLI command tree for working with Azure DevOps service endpoints (service connections) and the shared “typed endpoint” framework used by per-type create commands (e.g. `github`, `azurerm`). + +The framework’s goal is to make it easy to add **new typed service endpoint commands** while keeping behavior consistent (scope parsing, progress indicator, optional validation, readiness wait, connection test, pipeline permissions, output formatting). + +## Command overview + +- Top-level group: `azdo service-endpoint …` (`internal/cmd/serviceendpoint/serviceendpoint.go`) +- Generic create/import (JSON payload): `azdo service-endpoint create [ORG/]PROJECT --from-file …` (`internal/cmd/serviceendpoint/create/create.go`) +- Typed create commands: `azdo service-endpoint create [ORG/]PROJECT …` (subcommands under `internal/cmd/serviceendpoint/create/`) +- Update: `azdo service-endpoint update [ORG/]PROJECT/ID_OR_NAME …` (`internal/cmd/serviceendpoint/update/update.go`) +- Show/List/Delete/Export: `internal/cmd/serviceendpoint/show`, `list`, `delete`, `export` + +## Shared framework components + +All shared primitives live under `internal/cmd/serviceendpoint/shared/`. + +### Typed create runner + +The runner is implemented in `internal/cmd/serviceendpoint/shared/runner_create.go`: + +- Entry point: `shared.RunTypedCreate(cmd, args, cfg)` +- Responsibilities: + - parse scope from the positional argument (`[ORG/]PROJECT`) + - resolve the target project reference + - build a common `serviceendpoint.ServiceEndpoint` skeleton (name/description/type/owner/project refs) + - call the type-specific configurer to populate: + - `endpoint.Url` + - `endpoint.Authorization` (scheme + parameters) + - `endpoint.Data` + - optional behaviors driven by common flags: + - `--validate-schema` + - `--wait` + - `--validate-connection` + - `--grant-permission-to-all-pipelines` + - redact authorization parameters before output + - output (JSON export when requested, otherwise template output) + +### Common typed-create flags + +Typed create commands register common flags via `shared.AddCreateCommonFlags(cmd)` in `internal/cmd/serviceendpoint/shared/create_common.go`. + +These flags are shared across all typed create commands: + +- `--name` (required) +- `--description` +- `--validate-schema` +- `--wait` +- `--timeout` +- `--validate-connection` +- `--grant-permission-to-all-pipelines` +- JSON output options via `util.AddJSONFlags` (`--json`, `--jq`, `--template`) + +Implementation detail: `AddCreateCommonFlags` stores a `createCommonOptions` value in the Cobra command context under the key `createCommonOptions`. The typed runner reads those options from `cmd.Context()`. + +### Metadata validation (`--validate-schema`) + +When `--validate-schema` is set, `shared.RunTypedCreate` calls: + +- `shared.ValidateEndpointAgainstMetadata` (`internal/cmd/serviceendpoint/shared/type_validate.go`) + +That validator fetches live endpoint type metadata via: + +- `shared.GetServiceEndpointTypes` (`internal/cmd/serviceendpoint/shared/type_registry.go`) + +Validation is currently focused on **authorization schema correctness**: + +- endpoint type exists +- authorization scheme exists for that type +- required authorization parameter keys (from `inputDescriptors[].validation.isRequired`) are present + +Gotchas: + +- This validation requires fetching live endpoint type metadata from the organization. If the metadata request fails (permissions/network/organization settings), `--validate-schema` will fail the command. +- Validation is limited to auth scheme/parameter presence; it does not validate endpoint URLs, reachability, or correctness of non-auth `Data` fields. + +### Readiness wait (`--wait`) + +Readiness wait is implemented in `internal/cmd/serviceendpoint/shared/wait_ready.go` and uses `internal/util/poll.go`. + +The runner polls `GetServiceEndpointDetails` until: + +- the endpoint reports `IsReady == true`, or +- the endpoint reports a terminal failure (`operationStatus.state == "failed"` when present) + +### TestConnection (`--validate-connection`) + +Connection validation is implemented in `internal/cmd/serviceendpoint/shared/test_connection.go`. + +Behavior: + +- fetch metadata and verify the endpoint type supports a `TestConnection` data source +- execute/poll `ExecuteServiceEndpointRequest` until the result’s `StatusCode` becomes `"ok"` (case-insensitive) or the timeout is reached + +Gotchas: + +- Not every service endpoint type supports `TestConnection`; in that case the command fails with an explicit “not supported” error when `--validate-connection` is enabled. +- This uses the organization’s endpoint type metadata to find the `TestConnection` data source, so it can fail for the same reasons as `--validate-schema` (metadata fetch issues). + +### Pipeline permissions (`--grant-permission-to-all-pipelines`) + +Pipeline permission granting is implemented in `internal/cmd/serviceendpoint/shared/pipeline_permissions.go` using the `pipelinepermissions` client. + +In the typed create runner, this step runs after creation (and after wait/test if enabled). If granting fails, the runner attempts rollback by deleting the created endpoint. + +For typed update commands, permission changes are only applied when the flag is explicitly provided. To revoke access for all pipelines, pass an explicit false value: `--grant-permission-to-all-pipelines=false`. + +### Output and redaction + +Output rendering is centralized in `internal/cmd/serviceendpoint/shared/output.go` and `internal/cmd/serviceendpoint/shared/show.tpl`. + +Typed create redaction: + +- `shared.RunTypedCreate` calls `shared.RedactSecrets(created)` (`internal/cmd/serviceendpoint/shared/endpoints.go`) before output. +- Current behavior is intentionally conservative: all authorization parameter values are replaced with `"REDACTED"`. + +Note: other commands (e.g. `show`, `update`) currently call `shared.Output` directly. If an API response ever contains sensitive authorization parameters, template output will display them unless the caller redacts first. Typed create already does this. + +Note: the typed update runner also redacts before output (`shared.RunTypedUpdate`), but the non-typed `service-endpoint update` command does not currently redact before calling `shared.Output`. + +Implementation detail: the default template output (`shared/show.tpl`) prints `Authorization.Parameters` when present, so redaction needs to happen before calling `shared.Output`. + +## How to add a new typed create command + +Use existing commands as references: + +- GitHub: `internal/cmd/serviceendpoint/create/github/create.go` +- AzureRM: `internal/cmd/serviceendpoint/create/azurerm/create.go` + +### 1) Create the new package + +Create a new directory: + +`internal/cmd/serviceendpoint/create//` + +Then implement `create.go` with a factory: + +- `func NewCmd(ctx util.CmdContext) *cobra.Command` + +### 2) Implement the configurer + +Typed create commands use a small interface (see `internal/cmd/serviceendpoint/shared/runner_create.go`): + +```go +type EndpointTypeConfigurer interface { + CommandContext() util.CmdContext + TypeName() string + Configure(endpoint *serviceendpoint.ServiceEndpoint) error +} +``` + +Recommended structure: + +- define a `*Configurer` struct + - embed/contain `cmdCtx util.CmdContext` + - add fields for type-specific flags +- implement: + - `CommandContext()` returning the injected context + - `TypeName()` returning the Azure DevOps endpoint type identifier (e.g. `github`, `azurerm`) + - `Configure(endpoint)` populating `Url`, `Authorization`, and `Data` + +### 3) Wire Cobra flags and the shared runner + +In `NewCmd`: + +1) instantiate `cfg := &Configurer{cmdCtx: ctx}` +2) create a Cobra command with `Args: cobra.ExactArgs(1)` where the arg is `[ORG/]PROJECT` +3) bind type-specific flags onto `cfg` fields +4) call `shared.AddCreateCommonFlags(cmd)` to add common framework flags (this is required; the shared runner reads options from `cmd.Context()` and will panic if the value is missing) +5) set `RunE` to `return shared.RunTypedCreate(cmd, args, cfg)` + +### 4) Register the subcommand + +In `internal/cmd/serviceendpoint/create/create.go`, add: + +```go +cmd.AddCommand(.NewCmd(ctx)) +``` + +### 5) Tests + +Typed create commands are tested via their **public Cobra surface**: + +- Build the command with `NewCmd(ctx)` +- Provide args/flags with `cmd.SetArgs(...)` +- Execute with `cmd.Execute()` + +This is required because type-specific logic now lives in the configurer + shared runner (`shared.RunTypedCreate`), not in a package-local `runCreate` helper. + +#### Unit tests (hermetic, preferred) + +Unit tests should be table-driven and mock the command context and clients that the shared runner uses. + +Minimum mocks for a typed create test: + +- `CmdContext.Context()`, `CmdContext.IOStreams()` +- `CmdContext.ClientFactory()` +- `core.Client.GetProject(...)` (used by `shared.ResolveProjectReference`) +- `serviceendpoint.Client.CreateServiceEndpoint(...)` + +Mocks when common flags are enabled: + +- `--validate-schema`: `serviceendpoint.Client.GetServiceEndpointTypes(...)` +- `--wait`: `serviceendpoint.Client.GetServiceEndpointDetails(...)` +- `--validate-connection`: `serviceendpoint.Client.ExecuteServiceEndpointRequest(...)` (and related polling calls) +- `--grant-permission-to-all-pipelines`: `pipelinepermissions.Client.UpdatePipelinePermisionsForResource(...)` (and `serviceendpoint.Client.DeleteServiceEndpoint(...)` for rollback paths) + +Assertions should focus on: + +- The endpoint payload passed to `CreateServiceEndpoint` (type, URL, auth scheme, expected auth/data keys) +- Optional follow-up calls are only made when the corresponding flags are set + +Reference implementation: + +- `internal/cmd/serviceendpoint/create/github/create_test.go` + +#### Acceptance tests (live Azure DevOps) + +Acceptance tests should keep the existing harness (`internal/test`) and only change **how the command is invoked**: + +- In the test step `Run`, construct and execute the command: + - `cmd := .NewCmd(ctx)` + - `cmd.SetArgs([]string{projectArg, "--name", ..., ...})` + - `return cmd.Execute()` +- In `Verify`, poll for eventual consistency using `internal/util/poll.go` and assert stable fields. +- In `PostRun`, clean up created endpoints (see helpers under `internal/cmd/serviceendpoint/test`). + +Environment variables / gating are handled by the acceptance harness in `internal/test/helpers.go` (e.g. `AZDO_ACC_TEST`, `AZDO_ACC_ORG`, `AZDO_ACC_PAT`, `AZDO_ACC_PROJECT`). + +Examples: + +- `internal/cmd/serviceendpoint/create/azurerm/create_acc_test.go` +- `internal/cmd/serviceendpoint/create/github/create_acc_test.go` + +### Typed update runner + +The runner is implemented in `internal/cmd/serviceendpoint/shared/runner_update.go`: + +- Entry point: `shared.RunTypedUpdate(cmd, args, cfg)` +- Responsibilities: + - parse scope from the positional argument (`[ORG/]PROJECT/ID_OR_NAME`) + - resolve the existing endpoint + - apply common field updates (name, description) + - call the type-specific configurer to update: + - `endpoint.Url` + - `endpoint.Authorization` + - `endpoint.Data` + - optional behaviors (validate schema, update pipeline permissions, etc.) + - redact secrets and output + +### Common typed-update flags + +Typed update commands register common flags via `shared.AddUpdateCommonFlags(cmd)` in `internal/cmd/serviceendpoint/shared/update_common.go`. + +These flags are shared across all typed update commands: + +- `--name` (optional) +- `--description` (optional) +- `--wait` +- `--timeout` +- `--validate-schema` +- `--validate-connection` +- `--grant-permission-to-all-pipelines` +- JSON output options + +## How to add a new typed update command + +This follows the same pattern as typed create. + +### 1) Create the new package + +Create a new directory: + +`internal/cmd/serviceendpoint/update//` + +Then implement `update.go` with a factory: + +- `func NewCmd(ctx util.CmdContext) *cobra.Command` + +### 2) Implement the configurer + +Reuse or create an `EndpointTypeConfigurer` (same interface as create). The `Configure` method will be called with the *existing* endpoint, allowing you to modify fields based on flags. + +### 3) Wire Cobra flags and the shared runner + +In `NewCmd`: + +1) instantiate `cfg := &Configurer{cmdCtx: ctx}` +2) create a Cobra command with `Args: cobra.ExactArgs(1)` +3) bind type-specific flags onto `cfg` fields +4) call `shared.AddUpdateCommonFlags(cmd)` (this is required; the shared runner reads options from `cmd.Context()` and will panic if the value is missing) +5) set `RunE` to `return shared.RunTypedUpdate(cmd, args, cfg)` + +### 4) Register the subcommand + +In `internal/cmd/serviceendpoint/update/update.go` (or wherever the parent command is), add: + +```go +cmd.AddCommand(.NewCmd(ctx)) +``` diff --git a/internal/cmd/serviceendpoint/create/azurerm/create.go b/internal/cmd/serviceendpoint/create/azurerm/create.go index fd3743c0..4f68f0c2 100644 --- a/internal/cmd/serviceendpoint/create/azurerm/create.go +++ b/internal/cmd/serviceendpoint/create/azurerm/create.go @@ -6,23 +6,30 @@ import ( "os" "github.com/MakeNowJust/heredoc" - "github.com/google/uuid" "github.com/microsoft/azure-devops-go-api/azuredevops/v7/serviceendpoint" "github.com/spf13/cobra" - "go.uber.org/zap" "github.com/tmeckel/azdo-cli/internal/cmd/serviceendpoint/shared" "github.com/tmeckel/azdo-cli/internal/cmd/util" - "github.com/tmeckel/azdo-cli/internal/iostreams" - "github.com/tmeckel/azdo-cli/internal/prompter" - "github.com/tmeckel/azdo-cli/internal/types" ) -type createOptions struct { - project string +const ( + // Authentication Schemes + AuthSchemeServicePrincipal = "ServicePrincipal" + AuthSchemeManagedServiceIdentity = "ManagedServiceIdentity" + AuthSchemeWorkloadIdentityFederation = "WorkloadIdentityFederation" + + // Creation Modes + CreationModeManual = "Manual" + CreationModeAutomatic = "Automatic" - name string - description string + // Scope Levels + ScopeLevelSubscription = "Subscription" + ScopeLevelResourceGroup = "ResourceGroup" + ScopeLevelManagementGroup = "ManagementGroup" +) + +type createOptions struct { authenticationScheme string servicePrincipalID string servicePrincipalKey string @@ -38,29 +45,196 @@ type createOptions struct { serverURL string serviceEndpointCreationMode string grantPermissionToAllPipelines bool +} - yes bool - exporter util.Exporter +type azurermConfigurer struct { + cmdCtx util.CmdContext + createOptions } -const ( - // Authentication Schemes - AuthSchemeServicePrincipal = "ServicePrincipal" - AuthSchemeManagedServiceIdentity = "ManagedServiceIdentity" - AuthSchemeWorkloadIdentityFederation = "WorkloadIdentityFederation" +func (g *azurermConfigurer) CommandContext() util.CmdContext { + return g.cmdCtx +} - // Creation Modes - CreationModeManual = "Manual" - CreationModeAutomatic = "Automatic" +func (c *azurermConfigurer) TypeName() string { + return "azurerm" +} - // Scope Levels - ScopeLevelSubscription = "Subscription" - ScopeLevelResourceGroup = "ResourceGroup" - ScopeLevelManagementGroup = "ManagementGroup" -) +func (c *azurermConfigurer) Configure(endpoint *serviceendpoint.ServiceEndpoint) error { + err := c.validateOpts() + if err != nil { + return err + } + + url, err := c.getEndpointURL() + if err != nil { + return err + } + + endpoint.Url = &url + + authParams := map[string]string{ + "tenantid": c.tenantID, + } + + data := map[string]string{ + "environment": c.environment, + } + + if c.serviceEndpointCreationMode != "" && c.authenticationScheme != AuthSchemeManagedServiceIdentity { + data["creationMode"] = c.serviceEndpointCreationMode + } + + // Scope handling + if c.subscriptionID != "" { + if c.resourceGroup != "" && c.authenticationScheme != AuthSchemeManagedServiceIdentity { + authParams["scope"] = fmt.Sprintf("/subscriptions/%s/resourcegroups/%s", c.subscriptionID, c.resourceGroup) + } + data["scopeLevel"] = ScopeLevelSubscription + data["subscriptionId"] = c.subscriptionID + data["subscriptionName"] = c.subscriptionName + + } else if c.managementGroupID != "" { + data["scopeLevel"] = ScopeLevelManagementGroup + data["managementGroupId"] = c.managementGroupID + data["managementGroupName"] = c.managementGroupName + } + + // Auth scheme specific logic + switch c.authenticationScheme { + case AuthSchemeServicePrincipal: + authParams["serviceprincipalid"] = c.servicePrincipalID + if c.servicePrincipalKey != "" { + authParams["authenticationType"] = "spnKey" + authParams["serviceprincipalkey"] = c.servicePrincipalKey + } else if c.servicePrincipalCertificate != "" { + authParams["authenticationType"] = "spnCertificate" + authParams["servicePrincipalCertificate"] = c.servicePrincipalCertificate + } + case AuthSchemeWorkloadIdentityFederation: + if c.serviceEndpointCreationMode == CreationModeManual { + if c.servicePrincipalID == "" { + return errors.New("serviceprincipalid is required for WorkloadIdentityFederation in Manual mode") + } + authParams["serviceprincipalid"] = c.servicePrincipalID + } else { + authParams["serviceprincipalid"] = "" + } + case AuthSchemeManagedServiceIdentity: + // No extra auth params needed + } + + endpoint.Authorization = &serviceendpoint.EndpointAuthorization{ + Scheme: &c.authenticationScheme, + Parameters: &authParams, + } + endpoint.Data = &data + return nil +} + +func (c *azurermConfigurer) validateOpts() error { + if c.tenantID == "" { + return errors.New("--tenant-id is required") + } + + // Validate scope + if c.subscriptionID == "" && c.managementGroupID == "" { + return errors.New("one of --subscription-id or --management-group-id must be provided") + } + if c.subscriptionID != "" && c.managementGroupID != "" { + return errors.New("--subscription-id and --management-group-id are mutually exclusive") + } + if c.managementGroupID != "" && c.managementGroupName == "" { + return errors.New("--management-group-name is required when --management-group-id is specified") + } + if c.subscriptionID != "" && c.subscriptionName == "" { + return errors.New("--subscription-name is required when --subscription-id is specified") + } + + // Set creation mode + hasCredentials := c.servicePrincipalID != "" + if c.authenticationScheme == AuthSchemeServicePrincipal || c.authenticationScheme == AuthSchemeWorkloadIdentityFederation { + if hasCredentials { + c.serviceEndpointCreationMode = CreationModeManual + } else { + c.serviceEndpointCreationMode = CreationModeAutomatic + } + } + + // Validate auth scheme specific requirements + switch c.authenticationScheme { + case AuthSchemeServicePrincipal: + if c.serviceEndpointCreationMode == CreationModeAutomatic { + return errors.New("automatic creation mode is not supported for ServicePrincipal from the CLI. Please provide --service-principal-id") + } + if c.servicePrincipalKey == "" && c.certificatePath == "" { + cmdCtx := c.cmdCtx + ios, err := cmdCtx.IOStreams() + if err != nil { + return err + } + + if !ios.CanPrompt() { + return errors.New("--service-principal-key not provided and prompting disabled") + } + + p, err := cmdCtx.Prompter() + if err != nil { + return err + } + + secret, err := p.Password("Service principal key:") + if err != nil { + return fmt.Errorf("prompt for secret failed: %w", err) + } + c.servicePrincipalKey = secret + } + if c.servicePrincipalKey != "" && c.certificatePath != "" { + return errors.New("--service-principal-key and --certificate-path are mutually exclusive") + } + if c.certificatePath != "" { + certBytes, err := os.ReadFile(c.certificatePath) + if err != nil { + return fmt.Errorf("failed to read certificate file: %w", err) + } + c.servicePrincipalCertificate = string(certBytes) + } + case AuthSchemeWorkloadIdentityFederation: + // This is a valid scenario, where ADO will configure the SPN. + case AuthSchemeManagedServiceIdentity: + // No specific validation needed + default: + return fmt.Errorf("invalid --authentication-scheme: %s", c.authenticationScheme) + } + + if c.environment == "AzureStack" && c.serverURL == "" { + return errors.New("--server-url is required when environment is AzureStack") + } + + return nil +} + +func (c *azurermConfigurer) getEndpointURL() (string, error) { + switch c.environment { + case "AzureCloud": + return "https://management.azure.com/", nil + case "AzureChinaCloud": + return "https://management.chinacloudapi.cn/", nil + case "AzureUSGovernment": + return "https://management.usgovcloudapi.net/", nil + case "AzureGermanCloud": + return "https://management.microsoftazure.de/", nil + case "AzureStack": + return c.serverURL, nil + default: + return "", fmt.Errorf("unknown environment: %s", c.environment) + } +} func NewCmd(ctx util.CmdContext) *cobra.Command { - opts := &createOptions{} + cfg := &azurermConfigurer{ + cmdCtx: ctx, + } cmd := &cobra.Command{ Use: "azurerm [ORGANIZATION/]PROJECT --name --authentication-scheme [flags]", @@ -155,298 +329,29 @@ func NewCmd(ctx util.CmdContext) *cobra.Command { "a", }, RunE: func(cmd *cobra.Command, args []string) error { - opts.project = args[0] - return runCreate(ctx, opts) + return shared.RunTypedCreate(cmd, args, cfg) }, } - cmd.Flags().StringVar(&opts.name, "name", "", "Name of the service endpoint") - cmd.Flags().StringVar(&opts.description, "description", "", "Description for the service endpoint") - util.StringEnumFlag(cmd, &opts.authenticationScheme, "authentication-scheme", "", AuthSchemeServicePrincipal, + util.StringEnumFlag(cmd, &cfg.authenticationScheme, "authentication-scheme", "", AuthSchemeServicePrincipal, []string{AuthSchemeServicePrincipal, AuthSchemeManagedServiceIdentity, AuthSchemeWorkloadIdentityFederation}, "Authentication scheme") - cmd.Flags().StringVar(&opts.tenantID, "tenant-id", "", "Azure tenant ID (e.g., GUID)") - cmd.Flags().StringVar(&opts.subscriptionID, "subscription-id", "", "Azure subscription ID (e.g., GUID)") - cmd.Flags().StringVar(&opts.subscriptionName, "subscription-name", "", "Azure subscription name") - cmd.Flags().StringVar(&opts.managementGroupID, "management-group-id", "", "Azure management group ID") - cmd.Flags().StringVar(&opts.managementGroupName, "management-group-name", "", "Azure management group name") - cmd.Flags().StringVar(&opts.resourceGroup, "resource-group", "", "Name of the resource group (for subscription-level scope)") - cmd.Flags().StringVar(&opts.servicePrincipalID, "service-principal-id", "", "Service principal/application ID (e.g., GUID)") - cmd.Flags().StringVar(&opts.servicePrincipalKey, "service-principal-key", "", "Service principal key (secret value)") - cmd.Flags().StringVar(&opts.certificatePath, "certificate-path", "", "Path to service principal certificate file (PEM format)") - util.StringEnumFlag(cmd, &opts.environment, "environment", "", "AzureCloud", + cmd.Flags().StringVar(&cfg.tenantID, "tenant-id", "", "Azure tenant ID (e.g., GUID)") + cmd.Flags().StringVar(&cfg.subscriptionID, "subscription-id", "", "Azure subscription ID (e.g., GUID)") + cmd.Flags().StringVar(&cfg.subscriptionName, "subscription-name", "", "Azure subscription name") + cmd.Flags().StringVar(&cfg.managementGroupID, "management-group-id", "", "Azure management group ID") + cmd.Flags().StringVar(&cfg.managementGroupName, "management-group-name", "", "Azure management group name") + cmd.Flags().StringVar(&cfg.resourceGroup, "resource-group", "", "Name of the resource group (for subscription-level scope)") + cmd.Flags().StringVar(&cfg.servicePrincipalID, "service-principal-id", "", "Service principal/application ID (e.g., GUID)") + cmd.Flags().StringVar(&cfg.servicePrincipalKey, "service-principal-key", "", "Service principal key (secret value)") + cmd.Flags().StringVar(&cfg.certificatePath, "certificate-path", "", "Path to service principal certificate file (PEM format)") + util.StringEnumFlag(cmd, &cfg.environment, "environment", "", "AzureCloud", []string{"AzureCloud", "AzureChinaCloud", "AzureUSGovernment", "AzureGermanCloud", "AzureStack"}, "Azure environment") - cmd.Flags().StringVar(&opts.serverURL, "server-url", "", "Azure Stack Resource Manager base URL. Required if --environment is AzureStack.") - cmd.Flags().BoolVarP(&opts.yes, "yes", "y", false, "Skip confirmation prompts") - cmd.Flags().BoolVar(&opts.grantPermissionToAllPipelines, "grant-permission-to-all-pipelines", false, "Grant access permission to all pipelines to use the service connection") - - util.AddJSONFlags(cmd, &opts.exporter, shared.ServiceEndpointJSONFields) + cmd.Flags().StringVar(&cfg.serverURL, "server-url", "", "Azure Stack Resource Manager base URL. Required if --environment is AzureStack.") _ = cmd.MarkFlagRequired("name") _ = cmd.MarkFlagRequired("authentication-scheme") - return cmd -} - -func runCreate(ctx util.CmdContext, opts *createOptions) error { - ios, err := ctx.IOStreams() - if err != nil { - return err - } - - p, err := ctx.Prompter() - if err != nil { - return err - } - - scope, err := util.ParseProjectScope(ctx, opts.project) - if err != nil { - return util.FlagErrorWrap(err) - } - - if err := validateOpts(opts, ios, p); err != nil { - return util.FlagErrorWrap(err) - } - - if !opts.yes { - ok, err := p.Confirm("This will create credentials in Azure DevOps. Continue?", false) - if err != nil { - return err - } - if !ok { - return util.ErrCancel - } - } - - projectRef, err := shared.ResolveProjectReference(ctx, scope) - if err != nil { - return util.FlagErrorWrap(err) - } - - endpoint, err := buildServiceEndpoint(opts, projectRef) - if err != nil { - return util.FlagErrorf("failed to build service endpoint payload: %w", err) - } - - ios.StartProgressIndicator() - defer ios.StopProgressIndicator() - - client, err := ctx.ClientFactory().ServiceEndpoint(ctx.Context(), scope.Organization) - if err != nil { - return err - } - - createdEndpoint, err := client.CreateServiceEndpoint(ctx.Context(), serviceendpoint.CreateServiceEndpointArgs{ - Endpoint: endpoint, - }) - if err != nil { - return fmt.Errorf("failed to create service endpoint: %w", err) - } - - zap.L().Debug("azurerm service endpoint created", - zap.String("id", types.GetValue(createdEndpoint.Id, uuid.Nil).String()), - zap.String("name", types.GetValue(createdEndpoint.Name, "")), - ) - - if opts.grantPermissionToAllPipelines { - projectID := types.GetValue(projectRef.Id, uuid.Nil) - if projectID == uuid.Nil { - return errors.New("project reference missing ID") - } - - endpointID := types.GetValue(createdEndpoint.Id, uuid.Nil) - if endpointID == uuid.Nil { - return errors.New("service endpoint create response missing ID") - } - - if err := shared.SetAllPipelinesAccessToEndpoint(ctx, - scope.Organization, - projectID, - endpointID, - true, - func() error { - return client.DeleteServiceEndpoint(ctx.Context(), serviceendpoint.DeleteServiceEndpointArgs{ - EndpointId: types.ToPtr(endpointID), - ProjectIds: &[]string{projectID.String()}, - }) - }); err != nil { - return err - } - - zap.L().Debug("Granted all pipelines permission to service endpoint", - zap.String("id", endpointID.String()), - ) - } - - ios.StopProgressIndicator() - - return shared.Output(ctx, createdEndpoint, opts.exporter) -} - -func validateOpts(opts *createOptions, ios *iostreams.IOStreams, p prompter.Prompter) error { - if opts.tenantID == "" { - return errors.New("--tenant-id is required") - } - - // Validate scope - if opts.subscriptionID == "" && opts.managementGroupID == "" { - return errors.New("one of --subscription-id or --management-group-id must be provided") - } - if opts.subscriptionID != "" && opts.managementGroupID != "" { - return errors.New("--subscription-id and --management-group-id are mutually exclusive") - } - if opts.managementGroupID != "" && opts.managementGroupName == "" { - return errors.New("--management-group-name is required when --management-group-id is specified") - } - if opts.subscriptionID != "" && opts.subscriptionName == "" { - return errors.New("--subscription-name is required when --subscription-id is specified") - } - - // Set creation mode - hasCredentials := opts.servicePrincipalID != "" - if opts.authenticationScheme == AuthSchemeServicePrincipal || opts.authenticationScheme == AuthSchemeWorkloadIdentityFederation { - if hasCredentials { - opts.serviceEndpointCreationMode = CreationModeManual - } else { - opts.serviceEndpointCreationMode = CreationModeAutomatic - } - } - - // Validate auth scheme specific requirements - switch opts.authenticationScheme { - case AuthSchemeServicePrincipal: - if opts.serviceEndpointCreationMode == CreationModeAutomatic { - return errors.New("automatic creation mode is not supported for ServicePrincipal from the CLI. Please provide --service-principal-id") - } - if opts.servicePrincipalKey == "" && opts.certificatePath == "" { - if !ios.CanPrompt() { - return errors.New("--service-principal-key not provided and prompting disabled") - } - secret, err := p.Password("Service principal key:") - if err != nil { - return fmt.Errorf("prompt for secret failed: %w", err) - } - opts.servicePrincipalKey = secret - } - if opts.servicePrincipalKey != "" && opts.certificatePath != "" { - return errors.New("--service-principal-key and --certificate-path are mutually exclusive") - } - if opts.certificatePath != "" { - certBytes, err := os.ReadFile(opts.certificatePath) - if err != nil { - return fmt.Errorf("failed to read certificate file: %w", err) - } - opts.servicePrincipalCertificate = string(certBytes) - } - case AuthSchemeWorkloadIdentityFederation: - // This is a valid scenario, where ADO will configure the SPN. - case AuthSchemeManagedServiceIdentity: - // No specific validation needed - default: - return fmt.Errorf("invalid --authentication-scheme: %s", opts.authenticationScheme) - } - - if opts.environment == "AzureStack" && opts.serverURL == "" { - return errors.New("--server-url is required when environment is AzureStack") - } - - return nil -} - -func buildServiceEndpoint(opts *createOptions, projectRef *serviceendpoint.ProjectReference) (*serviceendpoint.ServiceEndpoint, error) { - endpointType := "azurerm" - endpointURL, err := getEndpointURL(opts) - if err != nil { - return nil, err - } - owner := "library" - - authParams := map[string]string{ - "tenantid": opts.tenantID, - } - - data := map[string]string{ - "environment": opts.environment, - } - - if opts.serviceEndpointCreationMode != "" && opts.authenticationScheme != AuthSchemeManagedServiceIdentity { - data["creationMode"] = opts.serviceEndpointCreationMode - } - - // Scope handling - if opts.subscriptionID != "" { - if opts.resourceGroup != "" && opts.authenticationScheme != AuthSchemeManagedServiceIdentity { - authParams["scope"] = fmt.Sprintf("/subscriptions/%s/resourcegroups/%s", opts.subscriptionID, opts.resourceGroup) - } - data["scopeLevel"] = ScopeLevelSubscription - data["subscriptionId"] = opts.subscriptionID - data["subscriptionName"] = opts.subscriptionName - - } else if opts.managementGroupID != "" { - data["scopeLevel"] = ScopeLevelManagementGroup - data["managementGroupId"] = opts.managementGroupID - data["managementGroupName"] = opts.managementGroupName - } - - // Auth scheme specific logic - switch opts.authenticationScheme { - case AuthSchemeServicePrincipal: - authParams["serviceprincipalid"] = opts.servicePrincipalID - if opts.servicePrincipalKey != "" { - authParams["authenticationType"] = "spnKey" - authParams["serviceprincipalkey"] = opts.servicePrincipalKey - } else if opts.servicePrincipalCertificate != "" { - authParams["authenticationType"] = "spnCertificate" - authParams["servicePrincipalCertificate"] = opts.servicePrincipalCertificate - } - case AuthSchemeWorkloadIdentityFederation: - if opts.serviceEndpointCreationMode == CreationModeManual { - if opts.servicePrincipalID == "" { - return nil, errors.New("serviceprincipalid is required for WorkloadIdentityFederation in Manual mode") - } - authParams["serviceprincipalid"] = opts.servicePrincipalID - } else { - authParams["serviceprincipalid"] = "" - } - case AuthSchemeManagedServiceIdentity: - // No extra auth params needed - } - - return &serviceendpoint.ServiceEndpoint{ - Name: &opts.name, - Type: &endpointType, - Url: &endpointURL, - Description: &opts.description, - Owner: &owner, - Authorization: &serviceendpoint.EndpointAuthorization{ - Scheme: &opts.authenticationScheme, - Parameters: &authParams, - }, - Data: &data, - ServiceEndpointProjectReferences: &[]serviceendpoint.ServiceEndpointProjectReference{ - { - ProjectReference: projectRef, - Name: &opts.name, - Description: &opts.description, - }, - }, - }, nil -} - -func getEndpointURL(opts *createOptions) (string, error) { - switch opts.environment { - case "AzureCloud": - return "https://management.azure.com/", nil - case "AzureChinaCloud": - return "https://management.chinacloudapi.cn/", nil - case "AzureUSGovernment": - return "https://management.usgovcloudapi.net/", nil - case "AzureGermanCloud": - return "https://management.microsoftazure.de/", nil - case "AzureStack": - return opts.serverURL, nil - default: - return "", fmt.Errorf("unknown environment: %s", opts.environment) - } + return shared.AddCreateCommonFlags(cmd) } diff --git a/internal/cmd/serviceendpoint/create/azurerm/create_acc_test.go b/internal/cmd/serviceendpoint/create/azurerm/create_acc_test.go index 33c9b70e..0ba07abe 100644 --- a/internal/cmd/serviceendpoint/create/azurerm/create_acc_test.go +++ b/internal/cmd/serviceendpoint/create/azurerm/create_acc_test.go @@ -14,13 +14,13 @@ import ( "github.com/tmeckel/azdo-cli/internal/cmd/serviceendpoint/test" inttest "github.com/tmeckel/azdo-cli/internal/test" "github.com/tmeckel/azdo-cli/internal/types" + pollutil "github.com/tmeckel/azdo-cli/internal/util" ) type contextKey string const ( - ctxKeyCreateOpts contextKey = "azurerm/create-opts" - ctxKeyCertPath contextKey = "azurerm/cert-path" + ctxKeyCertPath contextKey = "azurerm/cert-path" ) const ( @@ -34,6 +34,11 @@ gKCAQEAwuTanj/Uo5Yhq7ckmL5jycB3Z/zPBuZjviQ4fAar/7xeOUe7/y2Kpls= -----END CERTIFICATE-----` ) +type createTestOptions struct { + servicePrincipalKey string + certificateFileName string +} + func TestAccCreateAzureRMServiceEndpoint(t *testing.T) { t.Parallel() @@ -48,7 +53,7 @@ func TestAccCreateAzureRMServiceEndpoint(t *testing.T) { // Test Service Principal with Secret t.Run("ServicePrincipalWithSecret", func(t *testing.T) { t.Parallel() - testAccCreateAzureRMServiceEndpoint(t, sharedProj, AuthSchemeServicePrincipal, CreationModeManual, func(opts *createOptions) { + testAccCreateAzureRMServiceEndpoint(t, sharedProj, AuthSchemeServicePrincipal, CreationModeManual, func(opts *createTestOptions) { opts.servicePrincipalKey = "test-secret-123" }) }) @@ -56,8 +61,8 @@ func TestAccCreateAzureRMServiceEndpoint(t *testing.T) { // Test Service Principal with Certificate t.Run("ServicePrincipalWithCertificate", func(t *testing.T) { t.Parallel() - testAccCreateAzureRMServiceEndpoint(t, sharedProj, AuthSchemeServicePrincipal, CreationModeManual, func(opts *createOptions) { - opts.certificatePath = "test-cert.pem" + testAccCreateAzureRMServiceEndpoint(t, sharedProj, AuthSchemeServicePrincipal, CreationModeManual, func(opts *createTestOptions) { + opts.certificateFileName = "test-cert.pem" }) }) @@ -80,14 +85,20 @@ func TestAccCreateAzureRMServiceEndpoint(t *testing.T) { }) } -func testAccCreateAzureRMServiceEndpoint(t *testing.T, sharedProj *test.SharedProject, authScheme string, creationMode string, setupFunc func(*createOptions)) { +func testAccCreateAzureRMServiceEndpoint(t *testing.T, sharedProj *test.SharedProject, authScheme string, creationMode string, setupFunc func(*createTestOptions)) { // Generate unique names for each test run endpointName := fmt.Sprintf("azdo-cli-test-ep-%s-%s", authScheme, uuid.New().String()) subscriptionID := uuid.New().String() subscriptionName := fmt.Sprintf("Test Subscription %s", authScheme) resourceGroup := fmt.Sprintf("test-rg-%s", uuid.New().String()) + testOpts := &createTestOptions{} + if setupFunc != nil { + setupFunc(testOpts) + } + inttest.Test(t, inttest.TestCase{ + AcceptanceTest: true, Steps: []inttest.Step{ { PreRun: func(ctx inttest.TestContext) error { @@ -100,57 +111,61 @@ func testAccCreateAzureRMServiceEndpoint(t *testing.T, sharedProj *test.SharedPr } projectArg := fmt.Sprintf("%s/%s", ctx.Org(), projectName) - opts := &createOptions{ - project: projectArg, - name: endpointName, - description: fmt.Sprintf("Test AzureRM endpoint with %s auth", authScheme), - authenticationScheme: authScheme, - servicePrincipalID: uuid.New().String(), // Random SPN ID - servicePrincipalKey: "", - certificatePath: "", - tenantID: uuid.New().String(), // Random tenant ID - subscriptionID: subscriptionID, - subscriptionName: subscriptionName, - resourceGroup: resourceGroup, - environment: "AzureCloud", - serviceEndpointCreationMode: creationMode, - grantPermissionToAllPipelines: true, - yes: true, + var certPath string + if strings.TrimSpace(testOpts.certificateFileName) != "" { + path, err := inttest.WriteTestFileWithName(t, testOpts.certificateFileName, strings.NewReader(testCertificatePEM)) + if err != nil { + return fmt.Errorf("failed to write certificate file: %w", err) + } + certPath = path + ctx.SetValue(ctxKeyCertPath, path) } - if setupFunc != nil { - setupFunc(opts) + var servicePrincipalID string + switch authScheme { + case AuthSchemeServicePrincipal: + servicePrincipalID = uuid.New().String() + case AuthSchemeWorkloadIdentityFederation: + if creationMode == CreationModeManual { + servicePrincipalID = uuid.New().String() + } } - if opts.certificatePath != "" { - // create a temporary certificate file using the test helper - certPath, err := inttest.WriteTestFileWithName(t, opts.certificatePath, strings.NewReader(testCertificatePEM)) - if err != nil { - return fmt.Errorf("failed to write certificate file: %w", err) - } - // override the path in opts so the command uses the generated file - opts.certificatePath = certPath - ctx.SetValue(ctxKeyCertPath, certPath) + args := []string{ + projectArg, + "--name", endpointName, + "--description", fmt.Sprintf("Test AzureRM endpoint with %s auth", authScheme), + "--authentication-scheme", authScheme, + "--tenant-id", uuid.New().String(), + "--subscription-id", subscriptionID, + "--subscription-name", subscriptionName, + "--resource-group", resourceGroup, + "--environment", "AzureCloud", + "--grant-permission-to-all-pipelines", } - ctx.SetValue(ctxKeyCreateOpts, opts) + if servicePrincipalID != "" { + args = append(args, "--service-principal-id", servicePrincipalID) + } + if authScheme == AuthSchemeServicePrincipal { + if strings.TrimSpace(testOpts.servicePrincipalKey) != "" { + args = append(args, "--service-principal-key", testOpts.servicePrincipalKey) + } else if certPath != "" { + args = append(args, "--certificate-path", certPath) + } + } - // Execute the command - return runCreate(ctx, opts) + cmd := NewCmd(ctx) + cmd.SetArgs(args) + return cmd.Execute() }, Verify: func(ctx inttest.TestContext) error { - storedOpts, ok := ctx.Value(ctxKeyCreateOpts) - if !ok { - return fmt.Errorf("test context missing create options") - } - opts := storedOpts.(*createOptions) - client, err := ctx.ClientFactory().ServiceEndpoint(ctx.Context(), ctx.Org()) if err != nil { return fmt.Errorf("failed to create service endpoint client: %w", err) } - return inttest.Poll(func() error { + return pollutil.Poll(ctx.Context(), func() error { projectName, err := test.GetTestProjectName(ctx) if err != nil { return err @@ -160,6 +175,7 @@ func testAccCreateAzureRMServiceEndpoint(t *testing.T, sharedProj *test.SharedPr Project: &projectName, Type: types.ToPtr("azurerm"), IncludeDetails: types.ToPtr(true), + IncludeFailed: types.ToPtr(true), }) if err != nil { return fmt.Errorf("failed to list service endpoints: %w", err) @@ -230,11 +246,11 @@ func testAccCreateAzureRMServiceEndpoint(t *testing.T, sharedProj *test.SharedPr if _, ok := params["serviceprincipalid"]; !ok { return fmt.Errorf("serviceprincipalid not found in auth parameters") } - if opts.servicePrincipalKey != "" { + if strings.TrimSpace(testOpts.servicePrincipalKey) != "" { if _, ok := params["authenticationType"]; !ok || params["authenticationType"] != "spnKey" { return fmt.Errorf("expected authenticationType 'spnKey' for service principal with secret") } - } else if opts.certificatePath != "" { + } else if strings.TrimSpace(testOpts.certificateFileName) != "" { if _, ok := params["authenticationType"]; !ok || params["authenticationType"] != "spnCertificate" { return fmt.Errorf("expected authenticationType 'spnCertificate' for service principal with certificate") } @@ -248,7 +264,7 @@ func testAccCreateAzureRMServiceEndpoint(t *testing.T, sharedProj *test.SharedPr } return nil - }, inttest.PollOptions{ + }, pollutil.PollOptions{ Tries: 10, Timeout: 30 * time.Second, }) diff --git a/internal/cmd/serviceendpoint/create/github/create.go b/internal/cmd/serviceendpoint/create/github/create.go index 8789bceb..76db4d85 100644 --- a/internal/cmd/serviceendpoint/create/github/create.go +++ b/internal/cmd/serviceendpoint/create/github/create.go @@ -4,169 +4,110 @@ import ( "fmt" "github.com/MakeNowJust/heredoc" - "github.com/google/uuid" "github.com/microsoft/azure-devops-go-api/azuredevops/v7/serviceendpoint" "github.com/spf13/cobra" - "go.uber.org/zap" "github.com/tmeckel/azdo-cli/internal/cmd/serviceendpoint/shared" "github.com/tmeckel/azdo-cli/internal/cmd/util" - "github.com/tmeckel/azdo-cli/internal/types" ) -type createOptions struct { - project string - - name string +type githubConfigurer struct { + cmdCtx util.CmdContext url string token string configurationID string - - exporter util.Exporter } -func NewCmd(ctx util.CmdContext) *cobra.Command { - opts := &createOptions{} - - cmd := &cobra.Command{ - Use: "github [ORGANIZATION/]PROJECT --name NAME [--url URL] [--token TOKEN]", - Short: "Create a GitHub service endpoint", - Long: heredoc.Doc(` - Create a GitHub service endpoint using a personal access token (PAT) or an installation/oauth configuration. - `), - Example: heredoc.Doc(` - # Create a GitHub service endpoint with a personal access token (PAT) - azdo service-endpoint create github my-org/my-project --name "gh-ep" --token - - # Create a GitHub service endpoint with an installation / OAuth configuration id - azdo service-endpoint create github my-org/my-project --name "gh-ep" --configuration-id - `), - Args: cobra.ExactArgs(1), - RunE: func(cmd *cobra.Command, args []string) error { - opts.project = args[0] - return runCreate(ctx, opts) - }, - } - - cmd.Flags().StringVar(&opts.name, "name", "", "Name of the service endpoint") - cmd.Flags().StringVar(&opts.url, "url", "", "GitHub URL (defaults to https://github.com)") - // Help text taken from service-endpoint-types.json (inputDescriptors.AccessToken.description) - cmd.Flags().StringVar(&opts.token, "token", "", "Visit https://github.com/settings/tokens to create personal access tokens. Recommended scopes: repo, user, admin:repo_hook. If omitted, you will be prompted for a token when interactive.") - // Support installation/oauth configuration via ConfigurationId (InstallationToken scheme) - // Help text taken from service-endpoint-types.json (inputDescriptors.ConfigurationId.description) - cmd.Flags().StringVar(&opts.configurationID, "configuration-id", "", "Configuration for connecting to the endpoint (use an OAuth/installation configuration). Mutually exclusive with --token.") - - _ = cmd.MarkFlagRequired("name") - - util.AddJSONFlags(cmd, &opts.exporter, shared.ServiceEndpointJSONFields) - - return cmd +func (g *githubConfigurer) CommandContext() util.CmdContext { + return g.cmdCtx } -func runCreate(ctx util.CmdContext, opts *createOptions) error { - ios, err := ctx.IOStreams() - if err != nil { - return err - } +func (g *githubConfigurer) TypeName() string { + return "github" +} - p, err := ctx.Prompter() +func (g *githubConfigurer) Configure(endpoint *serviceendpoint.ServiceEndpoint) error { + cmdCtx := g.cmdCtx + ios, err := cmdCtx.IOStreams() if err != nil { return err } - scope, err := util.ParseProjectScope(ctx, opts.project) - if err != nil { - return util.FlagErrorWrap(err) + if g.url == "" { + g.url = "https://github.com" } - // default URL - if opts.url == "" { - opts.url = "https://github.com" - } - - // authentication selection: token (PAT) or configuration-id (InstallationToken) - if opts.token != "" && opts.configurationID != "" { + // reuse existing logic from runCreate's auth selection + if g.token != "" && g.configurationID != "" { return fmt.Errorf("--token and --configuration-id are mutually exclusive") } - if opts.token == "" && opts.configurationID == "" { + + if g.token == "" && g.configurationID == "" { // default to prompting for token when interactive if !ios.CanPrompt() { return fmt.Errorf("no authentication provided: pass --token or --configuration-id (and enable prompting to provide token interactively)") } + + p, err := cmdCtx.Prompter() + if err != nil { + return err + } + secret, err := p.Password("GitHub token:") if err != nil { return fmt.Errorf("prompt for token failed: %w", err) } - opts.token = secret - } - - projectRef, err := shared.ResolveProjectReference(ctx, scope) - if err != nil { - return util.FlagErrorWrap(err) + g.token = secret } - endpointType := "github" - owner := "library" - var scheme string - var authParams map[string]string - if opts.configurationID != "" { - // InstallationToken scheme expects ConfigurationId parameter + params := map[string]string{} + if g.configurationID != "" { scheme = "InstallationToken" - authParams = map[string]string{ - "ConfigurationId": opts.configurationID, - } + params["ConfigurationId"] = g.configurationID } else { - // default to PAT token scheme = "Token" - authParams = map[string]string{ - "AccessToken": opts.token, - } + params["AccessToken"] = g.token } - - endpoint := &serviceendpoint.ServiceEndpoint{ - Name: &opts.name, - Type: &endpointType, - Url: &opts.url, - Owner: &owner, - Authorization: &serviceendpoint.EndpointAuthorization{ - Scheme: &scheme, - Parameters: &authParams, - }, - ServiceEndpointProjectReferences: &[]serviceendpoint.ServiceEndpointProjectReference{ - { - ProjectReference: projectRef, - Name: &opts.name, - Description: types.ToPtr(""), - }, - }, - } - - ios.StartProgressIndicator() - defer ios.StopProgressIndicator() - - client, err := ctx.ClientFactory().ServiceEndpoint(ctx.Context(), scope.Organization) - if err != nil { - return fmt.Errorf("failed to create service endpoint client: %w", err) + endpoint.Url = &g.url + endpoint.Authorization = &serviceendpoint.EndpointAuthorization{ + Scheme: &scheme, + Parameters: ¶ms, } + endpoint.Data = &map[string]string{} + return nil +} - createdEndpoint, err := client.CreateServiceEndpoint(ctx.Context(), serviceendpoint.CreateServiceEndpointArgs{ - Endpoint: endpoint, - }) - if err != nil { - return fmt.Errorf("failed to create service endpoint: %w", err) +func NewCmd(ctx util.CmdContext) *cobra.Command { + cfg := &githubConfigurer{ + cmdCtx: ctx, } - zap.L().Debug("github service endpoint created", - zap.String("id", types.GetValue(createdEndpoint.Id, uuid.Nil).String()), - zap.String("name", types.GetValue(createdEndpoint.Name, "")), - ) - - ios.StopProgressIndicator() + cmd := &cobra.Command{ + Use: "github [ORGANIZATION/]PROJECT --name NAME [--url URL] [--token TOKEN]", + Short: "Create a GitHub service endpoint", + Long: heredoc.Doc(` + Create a GitHub service endpoint using a personal access token (PAT) or an installation/oauth configuration. + `), + Example: heredoc.Doc(` + # Create a GitHub service endpoint with a personal access token (PAT) + azdo service-endpoint create github my-org/my-project --name "gh-ep" --token - if opts.exporter != nil { - shared.RedactSecrets(createdEndpoint) + # Create a GitHub service endpoint with an installation / OAuth configuration id + azdo service-endpoint create github my-org/my-project --name "gh-ep" --configuration-id + `), + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + return shared.RunTypedCreate(cmd, args, cfg) + }, } - return shared.Output(ctx, createdEndpoint, opts.exporter) + cmd.Flags().StringVar(&cfg.url, "url", "", "GitHub URL (defaults to https://github.com)") + // Help text taken from service-endpoint-types.json (inputDescriptors.AccessToken.description) + cmd.Flags().StringVar(&cfg.token, "token", "", "Visit https://github.com/settings/tokens to create personal access tokens. Recommended scopes: repo, user, admin:repo_hook. If omitted, you will be prompted for a token when interactive.") + // Support installation/oauth configuration via ConfigurationId (InstallationToken scheme) + // Help text taken from service-endpoint-types.json (inputDescriptors.ConfigurationId.description) + cmd.Flags().StringVar(&cfg.configurationID, "configuration-id", "", "Configuration for connecting to the endpoint (use an OAuth/installation configuration). Mutually exclusive with --token.") + + return shared.AddCreateCommonFlags(cmd) } diff --git a/internal/cmd/serviceendpoint/create/github/create_acc_test.go b/internal/cmd/serviceendpoint/create/github/create_acc_test.go index 461b615d..6f01e0db 100644 --- a/internal/cmd/serviceendpoint/create/github/create_acc_test.go +++ b/internal/cmd/serviceendpoint/create/github/create_acc_test.go @@ -11,12 +11,7 @@ import ( "github.com/tmeckel/azdo-cli/internal/cmd/serviceendpoint/test" inttest "github.com/tmeckel/azdo-cli/internal/test" "github.com/tmeckel/azdo-cli/internal/types" -) - -type contextKey string - -const ( - ctxKeyCreateOpts contextKey = "github/create-opts" + pollutil "github.com/tmeckel/azdo-cli/internal/util" ) func TestAccCreateGitHubServiceEndpoint(t *testing.T) { @@ -34,6 +29,7 @@ func TestAccCreateGitHubServiceEndpoint(t *testing.T) { endpointName := fmt.Sprintf("azdo-cli-acc-gh-%s", uuid.New().String()) inttest.Test(t, inttest.TestCase{ + AcceptanceTest: true, Steps: []inttest.Step{ { PreRun: func(ctx inttest.TestContext) error { @@ -46,15 +42,14 @@ func TestAccCreateGitHubServiceEndpoint(t *testing.T) { } projectArg := fmt.Sprintf("%s/%s", ctx.Org(), projectName) - opts := &createOptions{ - project: projectArg, - name: endpointName, - url: "https://github.com", - token: uuid.New().String(), - } - - ctx.SetValue(ctxKeyCreateOpts, opts) - return runCreate(ctx, opts) + cmd := NewCmd(ctx) + cmd.SetArgs([]string{ + projectArg, + "--name", endpointName, + "--url", "https://github.com", + "--token", uuid.New().String(), + }) + return cmd.Execute() }, Verify: func(ctx inttest.TestContext) error { client, err := ctx.ClientFactory().ServiceEndpoint(ctx.Context(), ctx.Org()) @@ -62,7 +57,7 @@ func TestAccCreateGitHubServiceEndpoint(t *testing.T) { return err } - return inttest.Poll(func() error { + return pollutil.Poll(ctx.Context(), func() error { projectName, err := test.GetTestProjectName(ctx) if err != nil { return err @@ -110,7 +105,7 @@ func TestAccCreateGitHubServiceEndpoint(t *testing.T) { } return nil - }, inttest.PollOptions{ + }, pollutil.PollOptions{ Tries: 10, Timeout: 30 * time.Second, }) @@ -129,6 +124,7 @@ func TestAccCreateGitHubServiceEndpoint(t *testing.T) { endpointName := fmt.Sprintf("azdo-cli-acc-gh-cfg-%s", uuid.New().String()) inttest.Test(t, inttest.TestCase{ + AcceptanceTest: true, Steps: []inttest.Step{ { PreRun: func(ctx inttest.TestContext) error { @@ -141,15 +137,14 @@ func TestAccCreateGitHubServiceEndpoint(t *testing.T) { } projectArg := fmt.Sprintf("%s/%s", ctx.Org(), projectName) - opts := &createOptions{ - project: projectArg, - name: endpointName, - url: "https://github.com", - configurationID: uuid.New().String(), - } - - ctx.SetValue(ctxKeyCreateOpts, opts) - return runCreate(ctx, opts) + cmd := NewCmd(ctx) + cmd.SetArgs([]string{ + projectArg, + "--name", endpointName, + "--url", "https://github.com", + "--configuration-id", uuid.New().String(), + }) + return cmd.Execute() }, Verify: func(ctx inttest.TestContext) error { client, err := ctx.ClientFactory().ServiceEndpoint(ctx.Context(), ctx.Org()) @@ -157,7 +152,7 @@ func TestAccCreateGitHubServiceEndpoint(t *testing.T) { return err } - return inttest.Poll(func() error { + return pollutil.Poll(ctx.Context(), func() error { projectName, err := test.GetTestProjectName(ctx) if err != nil { return err @@ -205,7 +200,7 @@ func TestAccCreateGitHubServiceEndpoint(t *testing.T) { } return nil - }, inttest.PollOptions{ + }, pollutil.PollOptions{ Tries: 10, Timeout: 30 * time.Second, }) diff --git a/internal/cmd/serviceendpoint/create/github/create_test.go b/internal/cmd/serviceendpoint/create/github/create_test.go index 6f7222fa..b49efbb9 100644 --- a/internal/cmd/serviceendpoint/create/github/create_test.go +++ b/internal/cmd/serviceendpoint/create/github/create_test.go @@ -15,7 +15,7 @@ import ( "github.com/tmeckel/azdo-cli/internal/types" ) -func TestRunCreate_WithTokenFlag(t *testing.T) { +func TestNewCmd_WithTokenFlag(t *testing.T) { t.Parallel() ctrl := gomock.NewController(t) @@ -32,20 +32,15 @@ func TestRunCreate_WithTokenFlag(t *testing.T) { // Core client is used by ResolveProjectReference to fetch project metadata mockCore := mocks.NewMockCoreClient(ctrl) mClientFactory.EXPECT().Core(gomock.Any(), "org1").Return(mockCore, nil).AnyTimes() - mockCore.EXPECT().GetProject(gomock.Any(), gomock.Any()).Return(&core.TeamProject{Id: types.ToPtr(uuid.New())}, nil).AnyTimes() + mockCore.EXPECT().GetProject(gomock.Any(), gomock.Any()).Return(&core.TeamProject{ + Id: types.ToPtr(uuid.New()), + Name: types.ToPtr("proj1"), + }, nil).AnyTimes() mockSEClient := mocks.NewMockServiceEndpointClient(ctrl) // The connection factory mock exposes ServiceEndpoint via ClientFactory mock mClientFactory.EXPECT().ServiceEndpoint(gomock.Any(), "org1").Return(mockSEClient, nil).AnyTimes() - // Printer mock for table output - mPrinter := mocks.NewMockPrinter(ctrl) - mCmdCtx.EXPECT().Printer(gomock.Any()).Return(mPrinter, nil).AnyTimes() - mPrinter.EXPECT().AddColumns(gomock.Any()).AnyTimes() - mPrinter.EXPECT().AddField(gomock.Any()).AnyTimes() - mPrinter.EXPECT().EndRow().AnyTimes() - mPrinter.EXPECT().Render().AnyTimes() - // Expect CreateServiceEndpoint to be called and return a created endpoint created := &serviceendpoint.ServiceEndpoint{ Id: types.ToPtr(uuid.New()), @@ -59,25 +54,31 @@ func TestRunCreate_WithTokenFlag(t *testing.T) { if args.Endpoint == nil { t.Fatalf("expected endpoint payload") } + if args.Endpoint.Authorization == nil || args.Endpoint.Authorization.Scheme == nil { + t.Fatalf("expected endpoint authorization") + } + if types.GetValue(args.Endpoint.Authorization.Scheme, "") != "Token" { + t.Fatalf("expected Token scheme, got %q", types.GetValue(args.Endpoint.Authorization.Scheme, "")) + } + if args.Endpoint.Authorization.Parameters == nil { + t.Fatalf("expected authorization parameters") + } + if got := (*args.Endpoint.Authorization.Parameters)["AccessToken"]; got != "tok-flag" { + t.Fatalf("expected AccessToken=tok-flag, got %q", got) + } return created, nil }, ).Times(1) mCmdCtx.EXPECT().Prompter().Return(nil, nil).AnyTimes() - opts := &createOptions{ - project: "org1/proj1", - name: "ep-name", - url: "", - token: "tok-flag", - } - - // run - err := runCreate(mCmdCtx, opts) + cmd := NewCmd(mCmdCtx) + cmd.SetArgs([]string{"org1/proj1", "--name", "ep-name", "--token", "tok-flag"}) + err := cmd.Execute() assert.NoError(t, err) } -func TestRunCreate_PromptForToken(t *testing.T) { +func TestNewCmd_PromptForToken(t *testing.T) { t.Parallel() ctrl := gomock.NewController(t) @@ -96,19 +97,14 @@ func TestRunCreate_PromptForToken(t *testing.T) { mockCore := mocks.NewMockCoreClient(ctrl) mClientFactory.EXPECT().Core(gomock.Any(), "org1").Return(mockCore, nil).AnyTimes() - mockCore.EXPECT().GetProject(gomock.Any(), gomock.Any()).Return(&core.TeamProject{Id: types.ToPtr(uuid.New())}, nil).AnyTimes() + mockCore.EXPECT().GetProject(gomock.Any(), gomock.Any()).Return(&core.TeamProject{ + Id: types.ToPtr(uuid.New()), + Name: types.ToPtr("proj1"), + }, nil).AnyTimes() mockSEClient := mocks.NewMockServiceEndpointClient(ctrl) mClientFactory.EXPECT().ServiceEndpoint(gomock.Any(), "org1").Return(mockSEClient, nil).AnyTimes() - // Printer mock for table output - mPrinter := mocks.NewMockPrinter(ctrl) - mCmdCtx.EXPECT().Printer(gomock.Any()).Return(mPrinter, nil).AnyTimes() - mPrinter.EXPECT().AddColumns(gomock.Any()).AnyTimes() - mPrinter.EXPECT().AddField(gomock.Any()).AnyTimes() - mPrinter.EXPECT().EndRow().AnyTimes() - mPrinter.EXPECT().Render().AnyTimes() - created := &serviceendpoint.ServiceEndpoint{ Id: types.ToPtr(uuid.New()), Name: types.ToPtr("ep-name"), @@ -122,18 +118,13 @@ func TestRunCreate_PromptForToken(t *testing.T) { prom.EXPECT().Password(gomock.Any()).Return("sometoken", nil).Times(1) mCmdCtx.EXPECT().Prompter().Return(prom, nil).AnyTimes() - opts := &createOptions{ - project: "org1/proj1", - name: "ep-name", - url: "", - token: "", - } - - err := runCreate(mCmdCtx, opts) + cmd := NewCmd(mCmdCtx) + cmd.SetArgs([]string{"org1/proj1", "--name", "ep-name"}) + err := cmd.Execute() assert.NoError(t, err) } -func TestRunCreate_WithConfigurationID(t *testing.T) { +func TestNewCmd_WithConfigurationID(t *testing.T) { t.Parallel() ctrl := gomock.NewController(t) @@ -149,19 +140,14 @@ func TestRunCreate_WithConfigurationID(t *testing.T) { mockCore := mocks.NewMockCoreClient(ctrl) mClientFactory.EXPECT().Core(gomock.Any(), "org1").Return(mockCore, nil).AnyTimes() - mockCore.EXPECT().GetProject(gomock.Any(), gomock.Any()).Return(&core.TeamProject{Id: types.ToPtr(uuid.New())}, nil).AnyTimes() + mockCore.EXPECT().GetProject(gomock.Any(), gomock.Any()).Return(&core.TeamProject{ + Id: types.ToPtr(uuid.New()), + Name: types.ToPtr("proj1"), + }, nil).AnyTimes() mockSEClient := mocks.NewMockServiceEndpointClient(ctrl) mClientFactory.EXPECT().ServiceEndpoint(gomock.Any(), "org1").Return(mockSEClient, nil).AnyTimes() - // Printer mock for table output - mPrinter := mocks.NewMockPrinter(ctrl) - mCmdCtx.EXPECT().Printer(gomock.Any()).Return(mPrinter, nil).AnyTimes() - mPrinter.EXPECT().AddColumns(gomock.Any()).AnyTimes() - mPrinter.EXPECT().AddField(gomock.Any()).AnyTimes() - mPrinter.EXPECT().EndRow().AnyTimes() - mPrinter.EXPECT().Render().AnyTimes() - // Expect CreateServiceEndpoint and validate Authorization scheme/params mockSEClient.EXPECT().CreateServiceEndpoint(gomock.Any(), gomock.Any()).DoAndReturn( func(_ context.Context, args serviceendpoint.CreateServiceEndpointArgs) (*serviceendpoint.ServiceEndpoint, error) { @@ -189,13 +175,8 @@ func TestRunCreate_WithConfigurationID(t *testing.T) { mCmdCtx.EXPECT().Prompter().Return(nil, nil).AnyTimes() - opts := &createOptions{ - project: "org1/proj1", - name: "ep-name", - url: "", - configurationID: "cfg-123", - } - - err := runCreate(mCmdCtx, opts) + cmd := NewCmd(mCmdCtx) + cmd.SetArgs([]string{"org1/proj1", "--name", "ep-name", "--configuration-id", "cfg-123"}) + err := cmd.Execute() assert.NoError(t, err) } diff --git a/internal/cmd/serviceendpoint/delete/delete_acc_test.go b/internal/cmd/serviceendpoint/delete/delete_acc_test.go index 9168fb6f..01606076 100644 --- a/internal/cmd/serviceendpoint/delete/delete_acc_test.go +++ b/internal/cmd/serviceendpoint/delete/delete_acc_test.go @@ -12,6 +12,7 @@ import ( "github.com/tmeckel/azdo-cli/internal/cmd/util" inttest "github.com/tmeckel/azdo-cli/internal/test" "github.com/tmeckel/azdo-cli/internal/types" + pollutil "github.com/tmeckel/azdo-cli/internal/util" ) type ctxKey string @@ -31,6 +32,7 @@ func TestAccDeleteServiceEndpoint(t *testing.T) { }) inttest.Test(t, inttest.TestCase{ + AcceptanceTest: true, Steps: []inttest.Step{ { PreRun: func(ctx inttest.TestContext) error { @@ -86,7 +88,7 @@ func TestAccDeleteServiceEndpoint(t *testing.T) { if err != nil { return fmt.Errorf("invalid endpoint id: %w", err) } - return inttest.Poll(func() error { + return pollutil.Poll(ctx.Context(), func() error { sp, err := client.GetServiceEndpointDetails(ctx.Context(), serviceendpoint.GetServiceEndpointDetailsArgs{ Project: types.ToPtr(projectName), EndpointId: &endpointID, @@ -99,7 +101,7 @@ func TestAccDeleteServiceEndpoint(t *testing.T) { return nil } return err - }, inttest.PollOptions{ + }, pollutil.PollOptions{ Tries: 10, Timeout: 240 * time.Second, }) diff --git a/internal/cmd/serviceendpoint/shared/create_common.go b/internal/cmd/serviceendpoint/shared/create_common.go new file mode 100644 index 00000000..7c618cd1 --- /dev/null +++ b/internal/cmd/serviceendpoint/shared/create_common.go @@ -0,0 +1,47 @@ +package shared + +import ( + "context" + "time" + + "github.com/spf13/cobra" + + "github.com/tmeckel/azdo-cli/internal/cmd/util" +) + +// CreateCommonOptions contains flags/args that apply to all typed create commands. +type createCommonOptions struct { + Name string + Description string + + WaitReady bool + ValidateSchema bool + ValidateConnection bool + GrantAllPipelines bool + + Timeout time.Duration + Exporter util.Exporter +} + +// AddCreateCommonFlags registers the common flags on a create command. +func AddCreateCommonFlags(cmd *cobra.Command) *cobra.Command { + common := createCommonOptions{} + cmd.Flags().StringVar(&common.Name, "name", "", "Name of the service endpoint") + cmd.Flags().StringVar(&common.Description, "description", "", "Description for the service endpoint") + cmd.Flags().BoolVar(&common.WaitReady, "wait", false, "Wait until the endpoint reports ready/failed") + cmd.Flags().DurationVar(&common.Timeout, "timeout", 2*time.Minute, "Maximum time to wait when --wait or --validate-connection is enabled") + cmd.Flags().BoolVar(&common.ValidateSchema, "validate-schema", false, "Validate auth scheme/params against endpoint type metadata (opt-in)") + cmd.Flags().BoolVar(&common.ValidateConnection, "validate-connection", false, "Run TestConnection after creation (opt-in)") + cmd.Flags().BoolVar(&common.GrantAllPipelines, "grant-permission-to-all-pipelines", false, "Grant access permission to all pipelines to use the service connection") + util.AddJSONFlags(cmd, &common.Exporter, ServiceEndpointJSONFields) + + _ = cmd.MarkFlagRequired("name") + + ctx := cmd.Context() + if ctx == nil { + ctx = context.Background() + } + cmd.SetContext(context.WithValue(ctx, "createCommonOptions", &common)) + + return cmd +} diff --git a/internal/cmd/serviceendpoint/shared/create_common_test.go b/internal/cmd/serviceendpoint/shared/create_common_test.go new file mode 100644 index 00000000..2ca3d265 --- /dev/null +++ b/internal/cmd/serviceendpoint/shared/create_common_test.go @@ -0,0 +1,154 @@ +package shared + +import ( + "context" + "io" + "testing" + "time" + + "github.com/spf13/cobra" + "github.com/stretchr/testify/require" +) + +func TestAddCreateCommonFlags(t *testing.T) { + t.Parallel() + + t.Run("Flags registration", func(t *testing.T) { + t.Parallel() + cmd := &cobra.Command{ + Use: "test", + } + + updatedCmd := AddCreateCommonFlags(cmd) + + // Verify flags exist and have correct metadata + flags := []struct { + name string + usage string + defValue string + shortHand string + }{ + {name: "name", usage: "Name of the service endpoint", defValue: ""}, + {name: "description", usage: "Description for the service endpoint", defValue: ""}, + {name: "wait", usage: "Wait until the endpoint reports ready/failed", defValue: "false"}, + {name: "timeout", usage: "Maximum time to wait when --wait or --validate-connection is enabled", defValue: "2m0s"}, + {name: "validate-schema", usage: "Validate auth scheme/params against endpoint type metadata (opt-in)", defValue: "false"}, + {name: "validate-connection", usage: "Run TestConnection after creation (opt-in)", defValue: "false"}, + {name: "grant-permission-to-all-pipelines", usage: "Grant access permission to all pipelines to use the service connection", defValue: "false"}, + } + + for _, f := range flags { + flag := updatedCmd.Flag(f.name) + require.NotNil(t, flag, "flag %s should exist", f.name) + require.Equal(t, f.usage, flag.Usage, "usage for flag %s", f.name) + require.Equal(t, f.defValue, flag.DefValue, "default value for flag %s", f.name) + } + + // Verify JSON flags are added + require.NotNil(t, updatedCmd.Flag("json"), "json flag should exist") + + // Verify context + ctx := updatedCmd.Context() + require.NotNil(t, ctx, "context should not be nil") + opts := ctx.Value("createCommonOptions") + require.NotNil(t, opts, "createCommonOptions should be in context") + _, ok := opts.(*createCommonOptions) + require.True(t, ok, "context value should be of type *createCommonOptions") + }) + + t.Run("Flag Parsing", func(t *testing.T) { + t.Parallel() + cmd := &cobra.Command{ + Use: "test", + Run: func(cmd *cobra.Command, args []string) {}, + } + AddCreateCommonFlags(cmd) + + // Parse flags + err := cmd.ParseFlags([]string{ + "--name", "my-endpoint", + "--description", "desc", + "--wait", + "--timeout", "5m", + "--validate-schema", + "--validate-connection", + "--grant-permission-to-all-pipelines", + }) + require.NoError(t, err) + + // Retrieve options from context + ctx := cmd.Context() + opts := ctx.Value("createCommonOptions").(*createCommonOptions) + + require.Equal(t, "my-endpoint", opts.Name) + require.Equal(t, "desc", opts.Description) + require.True(t, opts.WaitReady) + require.Equal(t, 5*time.Minute, opts.Timeout) + require.True(t, opts.ValidateSchema) + require.True(t, opts.ValidateConnection) + require.True(t, opts.GrantAllPipelines) + }) + + t.Run("Required flag", func(t *testing.T) { + t.Parallel() + cmd := &cobra.Command{ + Use: "test", + } + AddCreateCommonFlags(cmd) + + flag := cmd.Flag("name") + require.NotNil(t, flag) + + // Check if it's marked as required in the command + // Cobra stores required flags in an internal map, but we can check the annotation + require.Equal(t, "true", flag.Annotations[cobra.BashCompOneRequiredFlag][0]) + }) + + t.Run("Invalid flag values", func(t *testing.T) { + t.Parallel() + cmd := &cobra.Command{ + Use: "test", + Run: func(cmd *cobra.Command, args []string) {}, + } + AddCreateCommonFlags(cmd) + + // Parse invalid timeout + err := cmd.ParseFlags([]string{"--name", "test", "--timeout", "invalid"}) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid argument \"invalid\" for \"--timeout\"") + }) + + t.Run("Existing context is preserved", func(t *testing.T) { + t.Parallel() + type key string + ctx := context.WithValue(context.Background(), key("existing"), "value") + cmd := &cobra.Command{ + Use: "test", + } + cmd.SetContext(ctx) + + AddCreateCommonFlags(cmd) + + updatedCtx := cmd.Context() + require.Equal(t, "value", updatedCtx.Value(key("existing"))) + require.NotNil(t, updatedCtx.Value("createCommonOptions")) + }) + + t.Run("Required flag enforcement", func(t *testing.T) { + t.Parallel() + cmd := &cobra.Command{ + Use: "test", + Run: func(cmd *cobra.Command, args []string) {}, + } + AddCreateCommonFlags(cmd) + + // Disable printing to avoid noise + cmd.SetOut(io.Discard) + cmd.SetErr(io.Discard) + + cmd.SetArgs([]string{"--description", "desc"}) + err := cmd.Execute() + require.Error(t, err) + require.Contains(t, err.Error(), "required flag(s) \"name\" not set") + }) +} diff --git a/internal/cmd/serviceendpoint/shared/runner_create.go b/internal/cmd/serviceendpoint/shared/runner_create.go new file mode 100644 index 00000000..cf0f7b41 --- /dev/null +++ b/internal/cmd/serviceendpoint/shared/runner_create.go @@ -0,0 +1,119 @@ +package shared + +import ( + "fmt" + + "github.com/google/uuid" + "github.com/microsoft/azure-devops-go-api/azuredevops/v7/serviceendpoint" + "github.com/spf13/cobra" + + "github.com/tmeckel/azdo-cli/internal/cmd/util" + "github.com/tmeckel/azdo-cli/internal/types" +) + +// WithCreateCommonOptions already defined in create_common.go + +// EndpointTypeConfigurer populates type-specific parts of a service endpoint. +type EndpointTypeConfigurer interface { + CommandContext() util.CmdContext + TypeName() string + Configure(endpoint *serviceendpoint.ServiceEndpoint) error +} + +// RunTypedCreate centralizes creation flow for typed service endpoint commands. +func RunTypedCreate(cmd *cobra.Command, args []string, cfg EndpointTypeConfigurer) error { + cmdCtx := cfg.CommandContext() + ios, err := cmdCtx.IOStreams() + if err != nil { + return err + } + + common := cmd.Context().Value("createCommonOptions").(*createCommonOptions) + ios.StartProgressIndicator() + defer ios.StopProgressIndicator() + + scope, err := util.ParseProjectScope(cmdCtx, args[0]) + if err != nil { + return util.FlagErrorWrap(err) + } + + projectRef, err := ResolveProjectReference(cmdCtx, scope) + if err != nil { + return err + } + + endpointType := cfg.TypeName() + owner := "library" + + endpoint := &serviceendpoint.ServiceEndpoint{ + Name: &common.Name, + Description: &common.Description, + Type: &endpointType, + Owner: &owner, + ServiceEndpointProjectReferences: &[]serviceendpoint.ServiceEndpointProjectReference{{ + ProjectReference: projectRef, + Name: &common.Name, + Description: &common.Description, + }}, + } + + err = cfg.Configure(endpoint) + if err != nil { + return util.FlagErrorWrap(err) + } + + if common.ValidateSchema { + if err := ValidateEndpointAgainstMetadata(cmdCtx, scope.Organization, endpoint); err != nil { + return util.FlagErrorWrap(err) + } + } + + client, err := cmdCtx.ClientFactory().ServiceEndpoint(cmdCtx.Context(), scope.Organization) + if err != nil { + return fmt.Errorf("failed to create service endpoint client: %w", err) + } + + created, err := client.CreateServiceEndpoint(cmdCtx.Context(), serviceendpoint.CreateServiceEndpointArgs{Endpoint: endpoint}) + if err != nil { + return fmt.Errorf("failed to create service endpoint: %w", err) + } + + if common.WaitReady { + created, err = WaitForReady(cmdCtx.Context(), client, scope.Project, created, common.Timeout) + if err != nil { + return err + } + } + + if common.ValidateConnection { + if err := TestConnection(cmdCtx, client, scope.Organization, scope.Project, created, common.Timeout); err != nil { + return err + } + } + + if common.GrantAllPipelines { + projectID := types.GetValue(projectRef.Id, uuid.Nil) + if projectID == uuid.Nil { + return fmt.Errorf("project reference missing ID") + } + endpointID := types.GetValue(created.Id, uuid.Nil) + if endpointID == uuid.Nil { + return fmt.Errorf("service endpoint create response missing ID") + } + + if err := SetAllPipelinesAccessToEndpoint(cmdCtx, scope.Organization, projectID, endpointID, true, func() error { + return client.DeleteServiceEndpoint(cmdCtx.Context(), serviceendpoint.DeleteServiceEndpointArgs{ + EndpointId: &endpointID, + ProjectIds: &[]string{projectID.String()}, + }) + }); err != nil { + return err + } + } + + ios.StopProgressIndicator() + + // redact secrets before output + RedactSecrets(created) + return Output(cmdCtx, created, common.Exporter) +} diff --git a/internal/cmd/serviceendpoint/shared/runner_update.go b/internal/cmd/serviceendpoint/shared/runner_update.go new file mode 100644 index 00000000..2c70171d --- /dev/null +++ b/internal/cmd/serviceendpoint/shared/runner_update.go @@ -0,0 +1,114 @@ +package shared + +import ( + "errors" + "fmt" + + "github.com/google/uuid" + "github.com/microsoft/azure-devops-go-api/azuredevops/v7/serviceendpoint" + "github.com/spf13/cobra" + + "github.com/tmeckel/azdo-cli/internal/cmd/util" + "github.com/tmeckel/azdo-cli/internal/types" +) + +// RunTypedUpdate centralizes update flow for typed service endpoint commands. +func RunTypedUpdate(cmd *cobra.Command, args []string, cfg EndpointTypeConfigurer) error { + cmdCtx := cfg.CommandContext() + ios, err := cmdCtx.IOStreams() + if err != nil { + return err + } + + common := cmd.Context().Value("updateCommonOptions").(*updateCommonOptions) + ios.StartProgressIndicator() + defer ios.StopProgressIndicator() + + // 1. Parse scope + scope, err := util.ParseProjectTargetWithDefaultOrganization(cmdCtx, args[0]) + if err != nil { + return util.FlagErrorWrap(err) + } + + client, err := cmdCtx.ClientFactory().ServiceEndpoint(cmdCtx.Context(), scope.Organization) + if err != nil { + return fmt.Errorf("failed to create service endpoint client: %w", err) + } + + // 2. Find existing endpoint + endpoint, err := FindServiceEndpoint(cmdCtx, client, scope.Project, scope.Target) + if err != nil { + if errors.Is(err, ErrEndpointNotFound) { + ios.StopProgressIndicator() + cs := ios.ColorScheme() + fmt.Fprintf(ios.Out, "%s Service endpoint %q was not found in %s/%s.\n", cs.WarningIcon(), scope.Target, scope.Organization, scope.Project) + return nil + } + return err + } + + // 3. Update common fields + if cmd.Flags().Changed("name") { + endpoint.Name = &common.Name + } + if cmd.Flags().Changed("description") { + endpoint.Description = &common.Description + } + + // 4. Update type specific fields + err = cfg.Configure(endpoint) + if err != nil { + return util.FlagErrorWrap(err) + } + + // 5. Validate schema + if common.ValidateSchema { + if err := ValidateEndpointAgainstMetadata(cmdCtx, scope.Organization, endpoint); err != nil { + return util.FlagErrorWrap(err) + } + } + + // 6. Execute Update + updated, err := client.UpdateServiceEndpoint(cmdCtx.Context(), serviceendpoint.UpdateServiceEndpointArgs{ + Endpoint: endpoint, + EndpointId: endpoint.Id, + }) + if err != nil { + return fmt.Errorf("failed to update service endpoint: %w", err) + } + + // 7. Wait if requested + if common.WaitReady { + updated, err = WaitForReady(cmdCtx.Context(), client, scope.Project, updated, common.Timeout) + if err != nil { + return err + } + } + + // 8. Validate connection + if common.ValidateConnection { + if err := TestConnection(cmdCtx, client, scope.Organization, scope.Project, updated, common.Timeout); err != nil { + return err + } + } + + // 9. Pipeline permissions + if cmd.Flags().Changed("grant-permission-to-all-pipelines") { + projectRef, err := ResolveProjectReference(cmdCtx, &scope.Scope) + if err != nil { + return err + } + projectID := types.GetValue(projectRef.Id, uuid.Nil) + endpointID := types.GetValue(updated.Id, uuid.Nil) + + if err := SetAllPipelinesAccessToEndpoint(cmdCtx, scope.Organization, projectID, endpointID, common.GrantAllPipelines, nil); err != nil { + return err + } + } + + ios.StopProgressIndicator() + + // redact secrets before output + RedactSecrets(updated) + return Output(cmdCtx, updated, common.Exporter) +} diff --git a/internal/cmd/serviceendpoint/shared/test_connection.go b/internal/cmd/serviceendpoint/shared/test_connection.go new file mode 100644 index 00000000..2ab32692 --- /dev/null +++ b/internal/cmd/serviceendpoint/shared/test_connection.go @@ -0,0 +1,97 @@ +package shared + +import ( + "fmt" + "strings" + "time" + + "github.com/microsoft/azure-devops-go-api/azuredevops/v7/serviceendpoint" + "github.com/tmeckel/azdo-cli/internal/cmd/util" + pollutil "github.com/tmeckel/azdo-cli/internal/util" +) + +// TestConnection executes the endpoint type's TestConnection data source when available. +// It polls until the data source reports StatusCode == "ok" or timeout/context cancel. +func TestConnection(cmdCtx util.CmdContext, client serviceendpoint.Client, organization, project string, ep *serviceendpoint.ServiceEndpoint, timeout time.Duration) error { + if ep == nil { + return fmt.Errorf("endpoint required") + } + if ep.Type == nil { + return fmt.Errorf("endpoint.Type required") + } + + types, err := GetServiceEndpointTypes(cmdCtx, organization) + if err != nil { + return err + } + + var matched *serviceendpoint.ServiceEndpointType + for _, t := range types { + if t.Name != nil && *t.Name == *ep.Type { + matched = &t + break + } + } + if matched == nil { + return fmt.Errorf("unknown service endpoint type: %s", *ep.Type) + } + + // Find TestConnection data source + var ds *serviceendpoint.DataSource + if matched.DataSources != nil { + for _, d := range *matched.DataSources { + if d.Name != nil && strings.EqualFold(*d.Name, "TestConnection") { + ds = &d + break + } + } + } + if ds == nil { + return fmt.Errorf("TestConnection not supported for endpoint type %s", *ep.Type) + } + + ctx := cmdCtx.Context() + opts := pollutil.PollOptions{} + if timeout > 0 { + opts.Timeout = timeout + } + + // Build request template + name := "TestConnection" + req := &serviceendpoint.ServiceEndpointRequest{ + DataSourceDetails: &serviceendpoint.DataSourceDetails{DataSourceName: &name}, + ServiceEndpointDetails: &serviceendpoint.ServiceEndpointDetails{ + Data: ep.Data, + Authorization: ep.Authorization, + Url: ep.Url, + Type: ep.Type, + }, + ResultTransformationDetails: &serviceendpoint.ResultTransformationDetails{}, + } + + var lastErr error + err = pollutil.Poll(ctx, func() error { + res, err := client.ExecuteServiceEndpointRequest(ctx, serviceendpoint.ExecuteServiceEndpointRequestArgs{ServiceEndpointRequest: req}) + if err != nil { + lastErr = err + return err + } + if res == nil || res.StatusCode == nil { + lastErr = fmt.Errorf("test connection returned empty result") + return lastErr + } + if strings.EqualFold(*res.StatusCode, "ok") { + return nil + } + // not ok => retry until timeout + lastErr = fmt.Errorf("test connection status: %s", *res.StatusCode) + return lastErr + }, opts) + if err != nil { + if lastErr != nil { + return fmt.Errorf("test connection failed: %w", lastErr) + } + return err + } + return nil +} diff --git a/internal/cmd/serviceendpoint/shared/test_connection_test.go b/internal/cmd/serviceendpoint/shared/test_connection_test.go new file mode 100644 index 00000000..7df22357 --- /dev/null +++ b/internal/cmd/serviceendpoint/shared/test_connection_test.go @@ -0,0 +1,71 @@ +package shared + +import ( + "testing" + "time" + + "github.com/microsoft/azure-devops-go-api/azuredevops/v7/serviceendpoint" + "github.com/tmeckel/azdo-cli/internal/mocks" + "github.com/tmeckel/azdo-cli/internal/test" + typespkg "github.com/tmeckel/azdo-cli/internal/types" + "go.uber.org/mock/gomock" +) + +func TestTestConnectionSuccess(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mock := mocks.NewMockServiceEndpointClient(ctrl) + name := typespkg.ToPtr("ok") + mock.EXPECT().ExecuteServiceEndpointRequest(gomock.Any(), gomock.Any()).Return(&serviceendpoint.ServiceEndpointRequestResult{StatusCode: name}, nil) + + // seed metadata cache so TestConnection finds TestConnection data source + dt := serviceendpoint.DataSource{Name: typespkg.ToPtr("TestConnection")} + st := serviceendpoint.ServiceEndpointType{Name: typespkg.ToPtr("github"), DataSources: &[]serviceendpoint.DataSource{dt}} + setTypesCacheForTest("org", []serviceendpoint.ServiceEndpointType{st}) + + ep := &serviceendpoint.ServiceEndpoint{Type: typespkg.ToPtr("github"), Url: typespkg.ToPtr("https://github.com")} + // minimal cmdCtx stub implementing util.CmdContext methods + cmdCtx := test.NewTestContext(t) + err := TestConnection(cmdCtx, mock, "org", "proj", ep, 1*time.Second) + if err != nil { + t.Fatalf("expected success, got: %v", err) + } +} + +func TestTestConnectionNotSupported(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mock := mocks.NewMockServiceEndpointClient(ctrl) + ep := &serviceendpoint.ServiceEndpoint{Type: typespkg.ToPtr("unknown-type")} + // ensure types cache does not contain this type + clearTypesCacheForTest() + cmdCtx := test.NewTestContext(t) + err := TestConnection(cmdCtx, mock, "org", "proj", ep, 100*time.Millisecond) + if err == nil { + t.Fatalf("expected error for unsupported type") + } +} + +func TestTestConnectionRetriesThenFails(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mock := mocks.NewMockServiceEndpointClient(ctrl) + notOk := typespkg.ToPtr("error") + // always return error status + mock.EXPECT().ExecuteServiceEndpointRequest(gomock.Any(), gomock.Any()).Return(&serviceendpoint.ServiceEndpointRequestResult{StatusCode: notOk}, nil).AnyTimes() + + // seed metadata cache for github + dt := serviceendpoint.DataSource{Name: typespkg.ToPtr("TestConnection")} + st := serviceendpoint.ServiceEndpointType{Name: typespkg.ToPtr("github"), DataSources: &[]serviceendpoint.DataSource{dt}} + setTypesCacheForTest("org", []serviceendpoint.ServiceEndpointType{st}) + + ep := &serviceendpoint.ServiceEndpoint{Type: typespkg.ToPtr("github"), Url: typespkg.ToPtr("https://github.com")} + cmdCtx := test.NewTestContext(t) + err := TestConnection(cmdCtx, mock, "org", "proj", ep, 200*time.Millisecond) + if err == nil { + t.Fatalf("expected failure") + } +} diff --git a/internal/cmd/serviceendpoint/shared/type_registry.go b/internal/cmd/serviceendpoint/shared/type_registry.go new file mode 100644 index 00000000..b1823ff2 --- /dev/null +++ b/internal/cmd/serviceendpoint/shared/type_registry.go @@ -0,0 +1,48 @@ +package shared + +import ( + "fmt" + "sync" + + "github.com/tmeckel/azdo-cli/internal/cmd/util" + "github.com/microsoft/azure-devops-go-api/azuredevops/v7/serviceendpoint" +) + +var typesCache sync.Map // map[string][]serviceendpoint.ServiceEndpointType + +// GetServiceEndpointTypes fetches service endpoint types for an organization and caches them for +// the duration of the command process. It uses the vendored Azure DevOps SDK. +func GetServiceEndpointTypes(cmdCtx util.CmdContext, organization string) ([]serviceendpoint.ServiceEndpointType, error) { + if organization == "" { + return nil, fmt.Errorf("organization required") + } + if v, ok := typesCache.Load(organization); ok { + return v.([]serviceendpoint.ServiceEndpointType), nil + } + + client, err := cmdCtx.ClientFactory().ServiceEndpoint(cmdCtx.Context(), organization) + if err != nil { + return nil, fmt.Errorf("create serviceendpoint client: %w", err) + } + + res, err := client.GetServiceEndpointTypes(cmdCtx.Context(), serviceendpoint.GetServiceEndpointTypesArgs{}) + if err != nil { + return nil, fmt.Errorf("get service endpoint types: %w", err) + } + if res == nil { + return nil, fmt.Errorf("no service endpoint types returned") + } + + typesCache.Store(organization, *res) + return *res, nil +} + +// Helpers for tests +func setTypesCacheForTest(org string, types []serviceendpoint.ServiceEndpointType) { + typesCache.Store(org, types) +} + +func clearTypesCacheForTest() { + typesCache = sync.Map{} +} + diff --git a/internal/cmd/serviceendpoint/shared/type_validate.go b/internal/cmd/serviceendpoint/shared/type_validate.go new file mode 100644 index 00000000..d368575b --- /dev/null +++ b/internal/cmd/serviceendpoint/shared/type_validate.go @@ -0,0 +1,74 @@ +package shared + +import ( + "fmt" + + "github.com/microsoft/azure-devops-go-api/azuredevops/v7/serviceendpoint" + "github.com/tmeckel/azdo-cli/internal/cmd/util" +) + +// ValidateEndpointAgainstMetadata validates that the given endpoint's Type and +// Authorization.Scheme and Parameters match the live metadata for the organization. +func ValidateEndpointAgainstMetadata(cmdCtx util.CmdContext, organization string, endpoint *serviceendpoint.ServiceEndpoint) error { + if endpoint == nil { + return fmt.Errorf("endpoint required") + } + if endpoint.Type == nil { + return fmt.Errorf("endpoint.Type required") + } + + types, err := GetServiceEndpointTypes(cmdCtx, organization) + if err != nil { + return err + } + + var matched *serviceendpoint.ServiceEndpointType + for _, t := range types { + if t.Name != nil && *t.Name == *endpoint.Type { + matched = &t + break + } + } + if matched == nil { + return fmt.Errorf("unknown service endpoint type: %s", *endpoint.Type) + } + + if endpoint.Authorization == nil || endpoint.Authorization.Scheme == nil { + return fmt.Errorf("endpoint.Authorization.Scheme required for type validation") + } + + scheme := *endpoint.Authorization.Scheme + var matchedScheme *serviceendpoint.ServiceEndpointAuthenticationScheme + for _, s := range *matched.AuthenticationSchemes { + if s.Scheme != nil && *s.Scheme == scheme { + matchedScheme = &s + break + } + } + if matchedScheme == nil { + return fmt.Errorf("scheme %s not supported for type %s", scheme, *endpoint.Type) + } + + // Validate required input descriptors + params := map[string]string{} + if endpoint.Authorization.Parameters != nil { + params = *endpoint.Authorization.Parameters + } + + if matchedScheme.InputDescriptors != nil { + for _, desc := range *matchedScheme.InputDescriptors { + if desc.Id == nil { + continue + } + id := *desc.Id + // If the input validation marks this as required, enforce presence. + if desc.Validation != nil && desc.Validation.IsRequired != nil && *desc.Validation.IsRequired { + if _, ok := params[id]; !ok { + return fmt.Errorf("missing required auth parameter: %s", id) + } + } + } + } + + return nil +} diff --git a/internal/cmd/serviceendpoint/shared/update_common.go b/internal/cmd/serviceendpoint/shared/update_common.go new file mode 100644 index 00000000..5d12b534 --- /dev/null +++ b/internal/cmd/serviceendpoint/shared/update_common.go @@ -0,0 +1,45 @@ +package shared + +import ( + "context" + "time" + + "github.com/spf13/cobra" + + "github.com/tmeckel/azdo-cli/internal/cmd/util" +) + +// updateCommonOptions contains flags/args that apply to all typed update commands. +type updateCommonOptions struct { + Name string + Description string + + WaitReady bool + ValidateSchema bool + ValidateConnection bool + GrantAllPipelines bool + + Timeout time.Duration + Exporter util.Exporter +} + +// AddUpdateCommonFlags registers the common flags on an update command. +func AddUpdateCommonFlags(cmd *cobra.Command) *cobra.Command { + common := updateCommonOptions{} + cmd.Flags().StringVar(&common.Name, "name", "", "New friendly name for the service endpoint") + cmd.Flags().StringVar(&common.Description, "description", "", "New description for the service endpoint") + cmd.Flags().BoolVar(&common.WaitReady, "wait", false, "Wait until the endpoint reports ready/failed") + cmd.Flags().DurationVar(&common.Timeout, "timeout", 2*time.Minute, "Maximum time to wait when --wait or --validate-connection is enabled") + cmd.Flags().BoolVar(&common.ValidateSchema, "validate-schema", false, "Validate auth scheme/params against endpoint type metadata (opt-in)") + cmd.Flags().BoolVar(&common.ValidateConnection, "validate-connection", false, "Run TestConnection after update (opt-in)") + cmd.Flags().BoolVar(&common.GrantAllPipelines, "grant-permission-to-all-pipelines", false, "Grant (true) or revoke (false) access permission to all pipelines") + util.AddJSONFlags(cmd, &common.Exporter, ServiceEndpointJSONFields) + + ctx := cmd.Context() + if ctx == nil { + ctx = context.Background() + } + cmd.SetContext(context.WithValue(ctx, "updateCommonOptions", &common)) + + return cmd +} diff --git a/internal/cmd/serviceendpoint/shared/update_common_test.go b/internal/cmd/serviceendpoint/shared/update_common_test.go new file mode 100644 index 00000000..39c90d14 --- /dev/null +++ b/internal/cmd/serviceendpoint/shared/update_common_test.go @@ -0,0 +1,143 @@ +package shared + +import ( + "context" + "io" + "testing" + "time" + + "github.com/spf13/cobra" + "github.com/stretchr/testify/require" +) + +func TestAddUpdateCommonFlags(t *testing.T) { + t.Parallel() + + t.Run("Flags registration", func(t *testing.T) { + t.Parallel() + cmd := &cobra.Command{ + Use: "test", + } + + updatedCmd := AddUpdateCommonFlags(cmd) + + // Verify flags exist and have correct metadata + flags := []struct { + name string + usage string + defValue string + shortHand string + }{ + {name: "name", usage: "New friendly name for the service endpoint", defValue: ""}, + {name: "description", usage: "New description for the service endpoint", defValue: ""}, + {name: "wait", usage: "Wait until the endpoint reports ready/failed", defValue: "false"}, + {name: "timeout", usage: "Maximum time to wait when --wait or --validate-connection is enabled", defValue: "2m0s"}, + {name: "validate-schema", usage: "Validate auth scheme/params against endpoint type metadata (opt-in)", defValue: "false"}, + {name: "validate-connection", usage: "Run TestConnection after update (opt-in)", defValue: "false"}, + {name: "grant-permission-to-all-pipelines", usage: "Grant (true) or revoke (false) access permission to all pipelines", defValue: "false"}, + } + + for _, f := range flags { + flag := updatedCmd.Flag(f.name) + require.NotNil(t, flag, "flag %s should exist", f.name) + require.Equal(t, f.usage, flag.Usage, "usage for flag %s", f.name) + require.Equal(t, f.defValue, flag.DefValue, "default value for flag %s", f.name) + } + + // Verify JSON flags are added + require.NotNil(t, updatedCmd.Flag("json"), "json flag should exist") + + // Verify context + ctx := updatedCmd.Context() + require.NotNil(t, ctx, "context should not be nil") + opts := ctx.Value("updateCommonOptions") + require.NotNil(t, opts, "updateCommonOptions should be in context") + _, ok := opts.(*updateCommonOptions) + require.True(t, ok, "context value should be of type *updateCommonOptions") + }) + + t.Run("Flag Parsing", func(t *testing.T) { + t.Parallel() + cmd := &cobra.Command{ + Use: "test", + Run: func(cmd *cobra.Command, args []string) {}, + } + AddUpdateCommonFlags(cmd) + + // Parse flags + err := cmd.ParseFlags([]string{ + "--name", "my-endpoint", + "--description", "desc", + "--wait", + "--timeout", "5m", + "--validate-schema", + "--validate-connection", + "--grant-permission-to-all-pipelines", + }) + require.NoError(t, err) + + // Retrieve options from context + ctx := cmd.Context() + opts := ctx.Value("updateCommonOptions").(*updateCommonOptions) + + require.Equal(t, "my-endpoint", opts.Name) + require.Equal(t, "desc", opts.Description) + require.True(t, opts.WaitReady) + require.Equal(t, 5*time.Minute, opts.Timeout) + require.True(t, opts.ValidateSchema) + require.True(t, opts.ValidateConnection) + require.True(t, opts.GrantAllPipelines) + }) + + t.Run("Invalid flag values", func(t *testing.T) { + t.Parallel() + cmd := &cobra.Command{ + Use: "test", + Run: func(cmd *cobra.Command, args []string) {}, + } + AddUpdateCommonFlags(cmd) + + // Parse invalid timeout + err := cmd.ParseFlags([]string{"--name", "test", "--timeout", "invalid"}) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid argument \"invalid\" for \"--timeout\"") + }) + + t.Run("Existing context is preserved", func(t *testing.T) { + t.Parallel() + type key string + ctx := context.WithValue(context.Background(), key("existing"), "value") + cmd := &cobra.Command{ + Use: "test", + } + cmd.SetContext(ctx) + + AddUpdateCommonFlags(cmd) + + updatedCtx := cmd.Context() + require.Equal(t, "value", updatedCtx.Value(key("existing"))) + require.NotNil(t, updatedCtx.Value("updateCommonOptions")) + }) + + t.Run("Optional name flag", func(t *testing.T) { + t.Parallel() + cmd := &cobra.Command{ + Use: "test", + Run: func(cmd *cobra.Command, args []string) {}, + } + AddUpdateCommonFlags(cmd) + + // Disable printing to avoid noise + cmd.SetOut(io.Discard) + cmd.SetErr(io.Discard) + + cmd.SetArgs([]string{"--description", "desc"}) + err := cmd.Execute() + require.NoError(t, err) + + ctx := cmd.Context() + opts := ctx.Value("updateCommonOptions").(*updateCommonOptions) + require.Equal(t, "desc", opts.Description) + require.Equal(t, "", opts.Name) + }) +} diff --git a/internal/cmd/serviceendpoint/shared/wait_ready.go b/internal/cmd/serviceendpoint/shared/wait_ready.go new file mode 100644 index 00000000..10f93b4d --- /dev/null +++ b/internal/cmd/serviceendpoint/shared/wait_ready.go @@ -0,0 +1,59 @@ +package shared + +import ( + "context" + "errors" + "fmt" + "time" + + pollutil "github.com/tmeckel/azdo-cli/internal/util" + "github.com/microsoft/azure-devops-go-api/azuredevops/v7/serviceendpoint" +) + +// WaitForReady polls GetServiceEndpointDetails until IsReady==true or a terminal +// failed state is detected in OperationStatus, or context/timeout occurs. +func WaitForReady(ctx context.Context, client serviceendpoint.Client, project string, ep *serviceendpoint.ServiceEndpoint, timeout time.Duration) (*serviceendpoint.ServiceEndpoint, error) { + if ep == nil || ep.Id == nil { + return nil, fmt.Errorf("endpoint or id missing") + } + opts := pollutil.PollOptions{Tries: 0} + if timeout > 0 { + opts.Timeout = timeout + } + + var last *serviceendpoint.ServiceEndpoint + err := pollutil.Poll(ctx, func() error { + id := *ep.Id + res, err := client.GetServiceEndpointDetails(ctx, serviceendpoint.GetServiceEndpointDetailsArgs{ + Project: &project, + EndpointId: &id, + }) + if err != nil { + // transient error, retry + last = nil + return err + } + last = res + if res != nil && res.IsReady != nil && *res.IsReady { + return nil + } + // Inspect OperationStatus defensively for a failure signal + if res != nil && res.OperationStatus != nil { + if opMap, ok := res.OperationStatus.(map[string]any); ok { + if stateRaw, ok := opMap["state"]; ok { + if stateStr, ok := stateRaw.(string); ok { + if stateStr == "failed" { + return fmt.Errorf("service endpoint creation failed: %v", res.OperationStatus) + } + } + } + } + } + return errors.New("not ready") + }, opts) + + if err != nil { + return last, err + } + return last, nil +} diff --git a/internal/cmd/serviceendpoint/shared/wait_ready_test.go b/internal/cmd/serviceendpoint/shared/wait_ready_test.go new file mode 100644 index 00000000..7b8e8a72 --- /dev/null +++ b/internal/cmd/serviceendpoint/shared/wait_ready_test.go @@ -0,0 +1,50 @@ +package shared + +import ( + "context" + "testing" + "time" + + "github.com/google/uuid" + "github.com/microsoft/azure-devops-go-api/azuredevops/v7/serviceendpoint" + "github.com/tmeckel/azdo-cli/internal/mocks" + "go.uber.org/mock/gomock" +) + +func TestWaitForReadySuccess(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mock := mocks.NewMockServiceEndpointClient(ctrl) + // first call returns not ready, second returns ready + id := uuid.New() + ep := serviceendpoint.ServiceEndpoint{Id: &id} + + mock.EXPECT().GetServiceEndpointDetails(gomock.Any(), gomock.Any()).Return(&serviceendpoint.ServiceEndpoint{Id: &id, IsReady: newTrue()}, nil).AnyTimes() + + _, err := WaitForReady(context.Background(), mock, "proj", &ep, 1*time.Second) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } +} + +func TestWaitForReadyFailedState(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mock := mocks.NewMockServiceEndpointClient(ctrl) + id := uuid.New() + op := map[string]any{"state": "failed"} + mock.EXPECT().GetServiceEndpointDetails(gomock.Any(), gomock.Any()).Return(&serviceendpoint.ServiceEndpoint{Id: &id, OperationStatus: op}, nil).AnyTimes() + + ep := serviceendpoint.ServiceEndpoint{Id: &id} + _, err := WaitForReady(context.Background(), mock, "proj", &ep, 500*time.Millisecond) + if err == nil { + t.Fatalf("expected error for failed state") + } +} + +func newTrue() *bool { b := true; return &b } + +// utilStub is the minimal CmdContext used by tests. +// no utilStub required since WaitForReady accepts context.Context diff --git a/internal/test/acc_helpers.go b/internal/test/helpers.go similarity index 92% rename from internal/test/acc_helpers.go rename to internal/test/helpers.go index 4552fbd5..2ebcaa28 100644 --- a/internal/test/acc_helpers.go +++ b/internal/test/helpers.go @@ -18,13 +18,14 @@ import ( "github.com/tmeckel/azdo-cli/internal/iostreams" "github.com/tmeckel/azdo-cli/internal/printer" "github.com/tmeckel/azdo-cli/internal/prompter" + u "github.com/tmeckel/azdo-cli/internal/util" ) const ( accToggleEnv = "AZDO_ACC_TEST" accOrgEnv = "AZDO_ACC_ORG" accPATEnv = "AZDO_ACC_PAT" - accTimeoutSeconds = 60 + accTimeoutSeconds = 240 accTimeoutEnv = "AZDO_ACC_TIMEOUT" accProjectEnv = "AZDO_ACC_PROJECT" ) @@ -50,8 +51,9 @@ func (n *nullPrinter) Render() error { } type TestCase struct { - PreCheck func() error - Steps []Step + PreCheck func() error + Steps []Step + AcceptanceTest bool } type TestContext interface { @@ -135,15 +137,41 @@ func (tc *testContext) Value(key any) (any, bool) { } // Precheck and context builder -func newTestContext(t *testing.T) TestContext { +func NewAccTestContext(t *testing.T) TestContext { org := os.Getenv(accOrgEnv) pat := os.Getenv(accPATEnv) project := os.Getenv(accProjectEnv) + timeoutVal := os.Getenv(accTimeoutEnv) if org == "" || pat == "" { t.Fatalf("missing acceptance env variables: %q, %q", accOrgEnv, accPATEnv) } + return initTestContext( + t, + org, + pat, + project, + timeoutVal, + ) +} + +func NewTestContext(t *testing.T) TestContext { + sb := u.NewStringBuilder(). + WithUpperLetters(). + WithUpperLetters() + org, _ := sb.Generate(8) + pat, _ := sb.Generate(8) + project, _ := sb.Generate(8) + return initTestContext( + t, + org, + pat, + project, + "", + ) +} +func initTestContext(t *testing.T, org, pat, project, timeoutVal string) TestContext { orgurl := fmt.Sprintf("https://dev.azure.com/%s", org) // Build a safe YAML configuration using marshaling instead of fmt.Sprintf interpolation. @@ -190,7 +218,6 @@ func newTestContext(t *testing.T) TestContext { var baseCtx context.Context var cancel context.CancelFunc - timeoutVal := os.Getenv(accTimeoutEnv) debugVal := os.Getenv("AZDO_DEBUG") if timeoutVal == "-1" || debugVal == "1" { @@ -327,7 +354,13 @@ func Test(t *testing.T, tc TestCase) { t.Fatalf("test PreCheck failed: %v", err) } } - ctx := newTestContext(t) + var ctx TestContext + if tc.AcceptanceTest { + ctx = NewAccTestContext(t) + } else { + ctx = NewTestContext(t) + } + for _, s := range tc.Steps { if err := runStep(ctx, s); err != nil { t.Fatalf("%v", err) diff --git a/internal/test/acc_helpers_test.go b/internal/test/helpers_test.go similarity index 97% rename from internal/test/acc_helpers_test.go rename to internal/test/helpers_test.go index 98841d11..48c168e1 100644 --- a/internal/test/acc_helpers_test.go +++ b/internal/test/helpers_test.go @@ -67,7 +67,7 @@ func TestTestContext_OrgFields(t *testing.T) { t.Setenv(accProjectEnv, "proj-value") // newTestContext validates env vars and returns a TestContext built from them. - tc := newTestContext(t) + tc := NewAccTestContext(t) // Verify that the TestContext accessors return the expected values from env. if got := tc.Org(); got != org { @@ -98,7 +98,7 @@ func TestNewTestContext_Config(t *testing.T) { t.Setenv(accProjectEnv, "proj-alpha") // Call newTestContext which will build a config from the env vars. - tc := newTestContext(t) + tc := NewAccTestContext(t) // Retrieve the underlying config via TestContext.Config() cfg, err := tc.Config() @@ -119,11 +119,7 @@ func TestNewTestContext_Config(t *testing.T) { // TestTestContextValueStore ensures SetValue/Value share data between steps. func TestTestContextValueStore(t *testing.T) { - t.Setenv(accOrgEnv, "org") - t.Setenv(accPATEnv, "pat") - t.Setenv(accProjectEnv, "project") - - tc := newTestContext(t) + tc := NewTestContext(t) tc.SetValue("key", 42) if v, ok := tc.Value("missing"); ok || v != nil { diff --git a/internal/test/poll.go b/internal/test/poll.go deleted file mode 100644 index 57c63856..00000000 --- a/internal/test/poll.go +++ /dev/null @@ -1,73 +0,0 @@ -package test - -import ( - "fmt" - "math" - "time" -) - -// PollFunc is the function to be executed by the Poll function. -// It should return an error to indicate a failure and that it should be retried. -// A nil error indicates success. -type PollFunc func() error - -// PollOptions configures the behavior of the Poll function. -type PollOptions struct { - // Tries is the maximum number of times to try the function. - // If Tries is 0, it will retry until Timeout is reached. - Tries int - // Delay is the time to wait between retries. - // If Delay is 0, binary exponential backoff is used, starting at 2 seconds. - Delay time.Duration - // Timeout is the maximum total time to spend retrying. - // If Timeout is 0, there is no time limit. - Timeout time.Duration -} - -// Poll executes the given function `fn` until it returns no error, or until the -// configured number of tries or timeout is reached. -func Poll(fn PollFunc, opts PollOptions) error { - var lastErr error - - if opts.Tries == 0 && opts.Timeout == 0 { - opts.Tries = 1 - } - - startTime := time.Now() - for i := 0; opts.Tries == 0 || i < opts.Tries; i++ { - if opts.Timeout > 0 && time.Since(startTime) > opts.Timeout { - if lastErr != nil { - return fmt.Errorf("timed out after %v, last error: %w", opts.Timeout, lastErr) - } - return fmt.Errorf("timed out after %v", opts.Timeout) - } - - err := fn() - if err == nil { - return nil - } - lastErr = err - - if opts.Tries > 0 && i == opts.Tries-1 { - break - } - - var wait time.Duration - if opts.Delay > 0 { - wait = opts.Delay - } else { - // Binary exponential backoff with initial 2 seconds - wait = time.Duration(math.Pow(2, float64(i))) * 2 * time.Second - } - time.Sleep(wait) - } - - if lastErr != nil { - if opts.Tries > 0 { - return fmt.Errorf("after %d attempts, last error: %w", opts.Tries, lastErr) - } - return fmt.Errorf("last error: %w", lastErr) - } - - return fmt.Errorf("polling failed without returning an error") -} diff --git a/internal/test/poll_test.go b/internal/test/poll_test.go deleted file mode 100644 index ebebd6d7..00000000 --- a/internal/test/poll_test.go +++ /dev/null @@ -1,109 +0,0 @@ -package test - -import ( - "errors" - "testing" - "time" - - "github.com/stretchr/testify/require" -) - -func TestPollDefaultsToSingleTry(t *testing.T) { - var calls int - - err := Poll(func() error { - calls++ - return nil - }, PollOptions{}) - - require.NoError(t, err) - require.Equal(t, 1, calls) -} - -func TestPollRetriesUntilSuccess(t *testing.T) { - var calls int - errBoom := errors.New("temporary failure") - - err := Poll(func() error { - calls++ - if calls < 3 { - return errBoom - } - - return nil - }, PollOptions{ - Tries: 5, - Delay: time.Millisecond, - }) - - require.NoError(t, err) - require.Equal(t, 3, calls) -} - -func TestPollReturnsAfterMaxAttempts(t *testing.T) { - var calls int - errBoom := errors.New("permanent failure") - - err := Poll(func() error { - calls++ - return errBoom - }, PollOptions{ - Tries: 3, - Delay: time.Millisecond, - }) - - require.Error(t, err) - require.ErrorContains(t, err, "after 3 attempts") - require.ErrorIs(t, err, errBoom) - require.Equal(t, 3, calls) -} - -func TestPollTimeout(t *testing.T) { - var calls int - errBoom := errors.New("timeout failure") - - err := Poll(func() error { - calls++ - return errBoom - }, PollOptions{ - Delay: 20 * time.Millisecond, - Timeout: 30 * time.Millisecond, - }) - - require.Error(t, err) - require.ErrorContains(t, err, "timed out after") - require.ErrorContains(t, err, errBoom.Error()) - require.GreaterOrEqual(t, calls, 2) -} - -func TestPollBinaryExponentialBackoff(t *testing.T) { - // Expected waits: 2s, 4s, 8s (formula: 2^i * 2s) - expectedDurations := []time.Duration{2 * time.Second, 4 * time.Second, 8 * time.Second} - var lastCall time.Time - var callIndex int - - err := Poll(func() error { - now := time.Now() - if callIndex > 0 { - elapsed := now.Sub(lastCall) - expected := expectedDurations[callIndex-1] - - // Allow ±10% margin for scheduler jitter - min := expected - expected/10 - max := expected + expected/10 - if elapsed < min || elapsed > max { - t.Fatalf("Call %d: expected ~%v wait, got %v", callIndex, expected, elapsed) - } - } - if callIndex >= len(expectedDurations) { - return nil // succeed after verifying waits - } - lastCall = now - callIndex++ - return errors.New("simulated failure") - }, PollOptions{ - Tries: len(expectedDurations) + 1, // enough tries to verify all intervals - }) // Delay left at zero - - require.NoError(t, err) -} diff --git a/internal/util/poll.go b/internal/util/poll.go new file mode 100644 index 00000000..ef2c9658 --- /dev/null +++ b/internal/util/poll.go @@ -0,0 +1,96 @@ +package util + +import ( + "context" + "fmt" + "math" + "time" +) + +// PollFunc is the function to be executed by the Poll function. +// It should return an error to indicate a failure and that it should be retried. +// A nil error indicates success. +type PollFunc func() error + +// PollOptions configures the behavior of the Poll function. +type PollOptions struct { + // Tries is the maximum number of times to try the function. + // If Tries is 0, it will retry until Timeout is reached. + Tries int + // Delay is the time to wait between retries. + // If Delay is 0, binary exponential backoff is used, starting at 2 seconds. + Delay time.Duration + // Timeout is the maximum total time to spend retrying. + // If Timeout is 0, there is no time limit. + Timeout time.Duration +} + +// Poll executes the given function `fn` until it returns no error, or until the +// configured number of tries, timeout, or context cancellation is reached. +// +// The function respects the provided context for cancellation. If `opts.Timeout` +// is non-zero, the timeout is applied in addition to the provided context. +func Poll(ctx context.Context, fn PollFunc, opts PollOptions) error { + var lastErr error + + // If neither tries nor timeout are set, default to a single try (legacy behavior) + if opts.Tries == 0 && opts.Timeout == 0 { + opts.Tries = 1 + } + + // If a timeout is specified, derive a child context so we can cancel after timeout. + if opts.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, opts.Timeout) + defer cancel() + } + + const maxDelay = 30 * time.Second + + for i := 0; opts.Tries == 0 || i < opts.Tries; i++ { + select { + case <-ctx.Done(): + // prefer returning the context error + return ctx.Err() + default: + } + + err := fn() + if err == nil { + return nil + } + lastErr = err + + if opts.Tries > 0 && i == opts.Tries-1 { + break + } + + var wait time.Duration + if opts.Delay > 0 { + wait = opts.Delay + } else { + // Binary exponential backoff with initial 2 seconds: 2^i * 2s + wait = time.Duration(math.Pow(2, float64(i))) * 2 * time.Second + if wait > maxDelay { + wait = maxDelay + } + } + + // Sleep but wake early if context is done + select { + case <-time.After(wait): + // continue to next iteration + case <-ctx.Done(): + return ctx.Err() + } + } + + if lastErr != nil { + if opts.Tries > 0 { + return fmt.Errorf("after %d attempts, last error: %w", opts.Tries, lastErr) + } + return fmt.Errorf("last error: %w", lastErr) + } + + return fmt.Errorf("polling failed without returning an error") +} diff --git a/internal/util/poll_test.go b/internal/util/poll_test.go new file mode 100644 index 00000000..d90de2bf --- /dev/null +++ b/internal/util/poll_test.go @@ -0,0 +1,100 @@ +package util + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestPollDefaultsToSingleTry(t *testing.T) { + var calls int + + err := Poll(context.Background(), func() error { + calls++ + return nil + }, PollOptions{}) + + require.NoError(t, err) + require.Equal(t, 1, calls) +} + +func TestPollRetriesUntilSuccess(t *testing.T) { + var calls int + errBoom := errors.New("temporary failure") + + err := Poll(context.Background(), func() error { + calls++ + if calls < 3 { + return errBoom + } + return nil + }, PollOptions{ + Tries: 5, + Delay: time.Millisecond, + }) + + require.NoError(t, err) + require.Equal(t, 3, calls) +} + +func TestPollReturnsAfterMaxAttempts(t *testing.T) { + var calls int + errBoom := errors.New("permanent failure") + + err := Poll(context.Background(), func() error { + calls++ + return errBoom + }, PollOptions{ + Tries: 3, + Delay: time.Millisecond, + }) + + require.Error(t, err) + require.ErrorContains(t, err, "after 3 attempts") + require.ErrorIs(t, err, errBoom) + require.Equal(t, 3, calls) +} + +func TestPollTimeout(t *testing.T) { + var calls int + errBoom := errors.New("timeout failure") + + err := Poll(context.Background(), func() error { + calls++ + return errBoom + }, PollOptions{ + Delay: 20 * time.Millisecond, + Timeout: 30 * time.Millisecond, + }) + + require.Error(t, err) + require.ErrorContains(t, err, "context deadline exceeded") + require.GreaterOrEqual(t, calls, 1) +} + +func TestPollRespectsContextCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := Poll(ctx, func() error { return errors.New("no-op") }, PollOptions{Delay: time.Millisecond}) + require.ErrorIs(t, err, context.Canceled) +} + +func TestPollBinaryExponentialBackoff(t *testing.T) { + var calls int + err := Poll(context.Background(), func() error { + calls++ + if calls >= 3 { + return nil + } + return errors.New("simulated failure") + }, PollOptions{ + Tries: 4, + }) + + require.NoError(t, err) + require.GreaterOrEqual(t, calls, 3) +} diff --git a/internal/util/stringBuilder.go b/internal/util/stringBuilder.go new file mode 100644 index 00000000..fa12ddd1 --- /dev/null +++ b/internal/util/stringBuilder.go @@ -0,0 +1,123 @@ +package util + +import ( + "encoding/hex" + "errors" + "fmt" + "math/rand" + "strings" + "time" +) + +const ( + upperLetters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + lowerLetters = "abcdefghijklmnopqrstuvwxyz" + numbers = "0123456789" + specialChars = "!@#$%^&*()_+-=[]{}|;':\",./<>?" +) + +// StringBuilder is a builder for configuring and generating custom random strings. +// By default, no character sets are included. You must add at least one using With* methods. +// Mutations like AsHexString or AsBinaryString can be applied to transform the output after generation. +type StringBuilder struct { + includeUpper bool + includeLower bool + includeNumbers bool + includeSpecial bool + mutation string // "" (none), "hex", or "binary" + rand *rand.Rand +} + +// NewStringBuilder creates a new StringBuilder with no default character sets. +func NewStringBuilder() *StringBuilder { + s := rand.NewSource(time.Now().UnixNano()) + return &StringBuilder{ + rand: rand.New(s), + } +} + +// WithUpperLetters adds uppercase letters (A-Z) to the character set. +func (sb *StringBuilder) WithUpperLetters() *StringBuilder { + sb.includeUpper = true + return sb +} + +// WithLowerLetters adds lowercase letters (a-z) to the character set. +func (sb *StringBuilder) WithLowerLetters() *StringBuilder { + sb.includeLower = true + return sb +} + +// WithNumbers adds digits (0-9) to the character set. +func (sb *StringBuilder) WithNumbers() *StringBuilder { + sb.includeNumbers = true + return sb +} + +// WithSpecialCharacters adds special characters (!@#$%^&*()_+-=[]{}|;':",./<>?) to the character set. +func (sb *StringBuilder) WithSpecialCharacters() *StringBuilder { + sb.includeSpecial = true + return sb +} + +// AsHexString sets the mutation to convert the generated string to its hexadecimal representation. +// This transforms each byte of the string into two hex characters, doubling the length. +func (sb *StringBuilder) AsHexString() *StringBuilder { + sb.mutation = "hex" + return sb +} + +// AsBinaryString sets the mutation to convert the generated string to its binary representation. +// This transforms each byte into an 8-bit binary string, multiplying the length by 8. +func (sb *StringBuilder) AsBinaryString() *StringBuilder { + sb.mutation = "binary" + return sb +} + +// Generate creates a random string of the specified length using the configured character set. +// If no character sets are selected, it returns an error. +// After generation, any set mutation (hex or binary) is applied to the output. +// Length specifies the pre-mutation length; the final length will differ if a mutation is applied. +// Returns an empty string and error if length <= 0 or no sets selected. +func (sb *StringBuilder) Generate(length int) (string, error) { + if length <= 0 { + return "", errors.New("length must be positive") + } + + var charset string + if sb.includeUpper { + charset += upperLetters + } + if sb.includeLower { + charset += lowerLetters + } + if sb.includeNumbers { + charset += numbers + } + if sb.includeSpecial { + charset += specialChars + } + + if charset == "" { + return "", errors.New("no character sets selected") + } + + b := make([]byte, length) + for i := range b { + b[i] = charset[sb.rand.Int31n(int32(len(charset)))] + } + s := string(b) + + switch sb.mutation { + case "hex": + return hex.EncodeToString([]byte(s)), nil + case "binary": + var strBuilder strings.Builder + for _, by := range []byte(s) { + strBuilder.WriteString(fmt.Sprintf("%08b", by)) + } + return strBuilder.String(), nil + default: + return s, nil + } +}