diff --git a/.gitignore b/.gitignore index bb66236..e31bba0 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ *.so *.dylib aks-flex-node +AKSFlexNode # Test binary, built with `go test -c` *.test @@ -42,6 +43,7 @@ Thumbs.db # Config files with sensitive data (keep sample config) config.json +Standard_D8pds_v6_sku.json # Environment files with secrets .env diff --git a/commands.go b/commands.go index bf8d798..c382dc2 100644 --- a/commands.go +++ b/commands.go @@ -25,6 +25,9 @@ var ( BuildTime = "unknown" ) +// Unbootstrap command flags +var cleanupMode string + // NewAgentCommand creates a new agent command func NewAgentCommand() *cobra.Command { cmd := &cobra.Command{ @@ -44,12 +47,19 @@ func NewUnbootstrapCommand() *cobra.Command { cmd := &cobra.Command{ Use: "unbootstrap", Short: "Remove AKS node configuration and Arc connection", - Long: "Clean up and remove all AKS node components and Arc registration from this machine", + Long: `Clean up and remove all AKS node components and Arc registration from this machine. 
+ +For private clusters (config has private: true), this also handles VPN cleanup: + --cleanup-mode=local Remove node and local VPN config, keep Gateway (default) + --cleanup-mode=full Remove everything including Gateway VM and Azure resources`, RunE: func(cmd *cobra.Command, args []string) error { return runUnbootstrap(cmd.Context()) }, } + cmd.Flags().StringVar(&cleanupMode, "cleanup-mode", "local", + "[private cluster only] Cleanup mode: 'local' (keep Gateway) or 'full' (remove all Azure resources)") + return cmd } @@ -87,6 +97,13 @@ func runAgent(ctx context.Context) error { return err } + // Print visible success message + fmt.Println() + fmt.Println("========================================") + fmt.Println(" Join process finished successfully!") + fmt.Println("========================================") + fmt.Println() + // After successful bootstrap, transition to daemon mode logger.Info("Bootstrap completed successfully, transitioning to daemon mode...") return runDaemonLoop(ctx, cfg) @@ -101,6 +118,11 @@ func runUnbootstrap(ctx context.Context) error { return fmt.Errorf("failed to load config from %s: %w", configPath, err) } + // Pass cleanup mode to config so the PrivateClusterUninstall step can read it + if cfg.Azure.TargetCluster != nil { + cfg.Azure.TargetCluster.CleanupMode = cleanupMode + } + bootstrapExecutor := bootstrapper.New(cfg, logger) result, err := bootstrapExecutor.Unbootstrap(ctx) if err != nil { @@ -108,7 +130,15 @@ func runUnbootstrap(ctx context.Context) error { } // Handle and log the result (unbootstrap is more lenient with failures) - return handleExecutionResult(result, "unbootstrap", logger) + if err := handleExecutionResult(result, "unbootstrap", logger); err != nil { + return err + } + + // Print final success message + fmt.Println() + fmt.Println("\033[0;32mSUCCESS:\033[0m Unbootstrap completed successfully!") + + return nil } // runVersion displays version information diff --git a/go.mod b/go.mod index 35db3e7..96792e5 100644 
--- a/go.mod +++ b/go.mod @@ -6,8 +6,11 @@ require ( github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0 github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization/v3 v3.0.0-beta.2 + github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6 v6.4.0 github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5 v5.0.0 github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/hybridcompute/armhybridcompute v1.2.0 + github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6 v6.2.0 + github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armsubscriptions v1.3.0 github.com/Azure/go-autorest/autorest/to v0.4.1 github.com/google/renameio/v2 v2.0.2 github.com/google/uuid v1.6.0 diff --git a/go.sum b/go.sum index 42dd70e..2fde20a 100644 --- a/go.sum +++ b/go.sum @@ -8,14 +8,22 @@ github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 h1:9iefClla7iYpfYWdzPCRDo github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2/go.mod h1:XtLgD3ZD34DAaVIIAyG3objl5DynM3CQ/vMcbBNJZGI= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization/v3 v3.0.0-beta.2 h1:qiir/pptnHqp6hV8QwV+IExYIf6cPsXBfUDUXQ27t2Y= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization/v3 v3.0.0-beta.2/go.mod h1:jVRrRDLCOuif95HDYC23ADTMlvahB7tMdl519m9Iyjc= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6 v6.4.0 h1:z7Mqz6l0EFH549GvHEqfjKvi+cRScxLWbaoeLm9wxVQ= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6 v6.4.0/go.mod h1:v6gbfH+7DG7xH2kUNs+ZJ9tF6O3iNnR85wMtmr+F54o= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5 v5.0.0 h1:5n7dPVqsWfVKw+ZiEKSd3Kzu7gwBkbEBkeXb8rgaE9Q= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5 v5.0.0/go.mod 
h1:HcZY0PHPo/7d75p99lB6lK0qYOP4vLRJUBpiehYXtLQ= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/hybridcompute/armhybridcompute v1.2.0 h1:7UuAn4ljE+H3GQ7qts3c7oAaMRvge68EgyckoNP/1Ro= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/hybridcompute/armhybridcompute v1.2.0/go.mod h1:F2eDq/BGK2LOEoDtoHbBOphaPqcjT0K/Y5Am8vf7+0w= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v2 v2.0.0 h1:PTFGRSlMKCQelWwxUyYVEUqseBJVemLyqWJjvMyt0do= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v2 v2.0.0/go.mod h1:LRr2FzBTQlONPPa5HREE5+RjSCTXl7BwOvYOaWTqCaI= -github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources v1.1.1 h1:7CBQ+Ei8SP2c6ydQTGCCrS35bDxgTMfoP2miAwK++OU= -github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources v1.1.1/go.mod h1:c/wcGeGx5FUPbM/JltUYHZcKmigwyVLJlDq+4HdtXaw= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v3 v3.1.0 h1:2qsIIvxVT+uE6yrNldntJKlLRgxGbZ85kgtz5SNBhMw= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v3 v3.1.0/go.mod h1:AW8VEadnhw9xox+VaVd9sP7NjzOAnaZBLRH6Tq3cJ38= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6 v6.2.0 h1:HYGD75g0bQ3VO/Omedm54v4LrD3B1cGImuRF3AJ5wLo= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6 v6.2.0/go.mod h1:ulHyBFJOI0ONiRL4vcJTmS7rx18jQQlEPmAgo80cRdM= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources v1.2.0 h1:Dd+RhdJn0OTtVGaeDLZpcumkIVCtA/3/Fo42+eoYvVM= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources v1.2.0/go.mod h1:5kakwfW5CjC9KK+Q4wjXAg+ShuIm2mBMua0ZFj2C8PE= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armsubscriptions v1.3.0 h1:wxQx2Bt4xzPIKvW59WQf1tJNx/ZZKPfN+EhPX3Z6CYY= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armsubscriptions v1.3.0/go.mod h1:TpiwjwnW/khS0LKs4vW5UmmT9OWcxaveS8U7+tlknzo= github.com/Azure/go-autorest v14.2.0+incompatible 
h1:V5VMDjClD3GiElqLWO7mz2MxNAK/vTfRHdAubSIPRgs= github.com/Azure/go-autorest v14.2.0+incompatible/go.mod h1:r+4oMnoxhatjLLJ6zxSWATqVooLgysK6ZNox3g/xq24= github.com/Azure/go-autorest/autorest/to v0.4.1 h1:CxNHBqdzTr7rLtdrtb5CMjJcDut+WNGCVv7OmS5+lTc= diff --git a/pkg/bootstrapper/bootstrapper.go b/pkg/bootstrapper/bootstrapper.go index d719a4c..95c1a7b 100644 --- a/pkg/bootstrapper/bootstrapper.go +++ b/pkg/bootstrapper/bootstrapper.go @@ -15,6 +15,7 @@ import ( "go.goms.io/aks/AKSFlexNode/pkg/components/services" "go.goms.io/aks/AKSFlexNode/pkg/components/system_configuration" "go.goms.io/aks/AKSFlexNode/pkg/config" + "go.goms.io/aks/AKSFlexNode/pkg/privatecluster" ) // Bootstrapper executes bootstrap steps sequentially @@ -33,6 +34,7 @@ func New(cfg *config.Config, logger *logrus.Logger) *Bootstrapper { func (b *Bootstrapper) Bootstrap(ctx context.Context) (*ExecutionResult, error) { // Define the bootstrap steps in order - using modules directly steps := []Executor{ + privatecluster.NewInstaller(b.logger), // VPN/Gateway setup (if private cluster) arc.NewInstaller(b.logger), // Setup Arc services.NewUnInstaller(b.logger), // Stop kubelet before setup system_configuration.NewInstaller(b.logger), // Configure system (early) @@ -51,6 +53,7 @@ func (b *Bootstrapper) Bootstrap(ctx context.Context) (*ExecutionResult, error) // Unbootstrap executes all cleanup steps sequentially (in reverse order of bootstrap) func (b *Bootstrapper) Unbootstrap(ctx context.Context) (*ExecutionResult, error) { steps := []Executor{ + privatecluster.NewUninstaller(b.logger), // Node removal + VPN teardown (if private cluster) services.NewUnInstaller(b.logger), // Stop services first npd.NewUnInstaller(b.logger), // Uninstall Node Problem Detector kubelet.NewUnInstaller(b.logger), // Clean kubelet configuration diff --git a/pkg/config/structs.go b/pkg/config/structs.go index a17e1f8..0d46c88 100644 --- a/pkg/config/structs.go +++ b/pkg/config/structs.go @@ -55,8 +55,12 @@ type 
BootstrapTokenConfig struct { // TargetClusterConfig holds configuration for the target AKS cluster the ARC machine will connect to. type TargetClusterConfig struct { - ResourceID string `json:"resourceId"` // Full resource ID of the target AKS cluster - Location string `json:"location"` // Azure region of the cluster (e.g., "eastus", "westus2") + ResourceID string `json:"resourceId"` // Full resource ID of the target AKS cluster + Location string `json:"location"` // Azure region of the cluster (e.g., "eastus", "westus2") + IsPrivateCluster bool `json:"private" mapstructure:"private"` // Whether this is a private AKS cluster (requires Gateway/VPN setup) + GatewayVMSize string `json:"gatewayVMSize,omitempty" mapstructure:"gatewayVMSize"` // VPN Gateway VM size (defaults to "Standard_D2s_v3") + GatewayPort int `json:"gatewayPort,omitempty" mapstructure:"gatewayPort"` // VPN Gateway port (defaults to 51820) + CleanupMode string `json:"-"` // Runtime-only, set by CLI flag for unbootstrap Name string // will be populated from ResourceID ResourceGroup string // will be populated from ResourceID SubscriptionID string // will be populated from ResourceID diff --git a/pkg/privatecluster/README.md b/pkg/privatecluster/README.md new file mode 100644 index 0000000..0fb2a90 --- /dev/null +++ b/pkg/privatecluster/README.md @@ -0,0 +1,99 @@ +# Private AKS Cluster - Edge Node Join/Leave + +## Prerequisites + +### 1. Login to Azure CLI + +```bash +az login +``` + +> **Note:** When running the agent with `sudo`, use `sudo -E` to preserve your Azure CLI token. + +### 2. Create a Private AKS Cluster + +Create a Private AKS cluster with AAD and Azure RBAC enabled, and assign the required roles to your user. + +See: [create_private_cluster.md](create_private_cluster.md) + +### 3. 
Prepare Configuration File + +Create a `config.json` with `"private": true` in the `targetCluster` section: + +```json +{ + "azure": { + "subscriptionId": "", + "tenantId": "", + "targetCluster": { + "resourceId": "/subscriptions//resourceGroups//providers/Microsoft.ContainerService/managedClusters/", + "location": "eastus2", + "private": true + }, + "arc": { + "enabled": true, + "resourceGroup": "", + "location": "eastus2" + } + }, + "kubernetes": { + "version": "1.33.0" + }, + "containerd": { + "version": "1.7.11", + "pauseImage": "mcr.microsoft.com/oss/kubernetes/pause:3.6" + }, + "agent": { + "logLevel": "info", + "logDir": "/var/log/aks-flex-node" + } +} +``` + +## Join Private AKS Cluster + +### 1. Build the project + +```bash +go build -o aks-flex-node . +``` + +### 2. Join the cluster + +When the config has `"private": true`, the `agent` command automatically sets up the Gateway/VPN before bootstrapping: + +```bash +sudo -E ./aks-flex-node agent --config config.json +``` + +This will: +1. Detect private cluster from config +2. Set up Gateway VM and VPN tunnel (WireGuard) +3. Run normal bootstrap (Arc, containerd, kubelet, etc.) +4. Enter daemon mode for status monitoring + +### 3. 
Verify + +```bash +kubectl get nodes +``` + +## Leave Private AKS Cluster + +When the config has `"private": true`, the `unbootstrap` command automatically handles VPN/Gateway cleanup: + +```bash +sudo -E ./aks-flex-node unbootstrap --config config.json [--cleanup-mode ] +``` + +### Mode Comparison + +| Mode | Command | Description | +|------|---------|-------------| +| `local` (default) | `sudo -E ./aks-flex-node unbootstrap --config config.json` | Remove node and local VPN config, **keep Gateway** for other nodes | +| `full` | `sudo -E ./aks-flex-node unbootstrap --config config.json --cleanup-mode full` | Remove all components **including Gateway VM and Azure resources** | + +### When to use each mode + +- **`--cleanup-mode=local`** (default): Other nodes are still using the Gateway, or you plan to rejoin later +- **`--cleanup-mode=full`**: Last node leaving, clean up all Azure resources (Gateway VM, subnet, NSG, public IP) diff --git a/pkg/privatecluster/azure_client.go b/pkg/privatecluster/azure_client.go new file mode 100644 index 0000000..73655a2 --- /dev/null +++ b/pkg/privatecluster/azure_client.go @@ -0,0 +1,547 @@ +package privatecluster + +import ( + "context" + "errors" + "fmt" + "net/http" + "os" + "path/filepath" + "strings" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armsubscriptions" + "github.com/sirupsen/logrus" +) + +// AzureClient provides Azure operations using the Azure SDK for Go. 
+type AzureClient struct { + logger *logrus.Logger + subscriptionID string + + vmClient *armcompute.VirtualMachinesClient + subnetClient *armnetwork.SubnetsClient + nsgClient *armnetwork.SecurityGroupsClient + pipClient *armnetwork.PublicIPAddressesClient + nicClient *armnetwork.InterfacesClient + aksClient *armcontainerservice.ManagedClustersClient + subscriptionClient *armsubscriptions.Client +} + +// NewAzureClient creates a new AzureClient with all sub-clients initialized. +func NewAzureClient(cred azcore.TokenCredential, subscriptionID string, logger *logrus.Logger) (*AzureClient, error) { + c := &AzureClient{ + logger: logger, + subscriptionID: subscriptionID, + } + + var err error + + if c.vmClient, err = armcompute.NewVirtualMachinesClient(subscriptionID, cred, nil); err != nil { + return nil, fmt.Errorf("failed to create VM client: %w", err) + } + if c.subnetClient, err = armnetwork.NewSubnetsClient(subscriptionID, cred, nil); err != nil { + return nil, fmt.Errorf("failed to create subnet client: %w", err) + } + if c.nsgClient, err = armnetwork.NewSecurityGroupsClient(subscriptionID, cred, nil); err != nil { + return nil, fmt.Errorf("failed to create NSG client: %w", err) + } + if c.pipClient, err = armnetwork.NewPublicIPAddressesClient(subscriptionID, cred, nil); err != nil { + return nil, fmt.Errorf("failed to create public IP client: %w", err) + } + if c.nicClient, err = armnetwork.NewInterfacesClient(subscriptionID, cred, nil); err != nil { + return nil, fmt.Errorf("failed to create NIC client: %w", err) + } + if c.aksClient, err = armcontainerservice.NewManagedClustersClient(subscriptionID, cred, nil); err != nil { + return nil, fmt.Errorf("failed to create AKS client: %w", err) + } + if c.subscriptionClient, err = armsubscriptions.NewClient(cred, nil); err != nil { + return nil, fmt.Errorf("failed to create subscription client: %w", err) + } + + return c, nil +} + +// GetTenantID returns the tenant ID for the configured subscription. 
+func (c *AzureClient) GetTenantID(ctx context.Context) (string, error) { + resp, err := c.subscriptionClient.Get(ctx, c.subscriptionID, nil) + if err != nil { + return "", fmt.Errorf("failed to get subscription info: %w", err) + } + if resp.TenantID == nil { + return "", fmt.Errorf("tenant ID not found for subscription %s", c.subscriptionID) + } + return *resp.TenantID, nil +} + +// AKSClusterExists checks if an AKS cluster exists. +func (c *AzureClient) AKSClusterExists(ctx context.Context, resourceGroup, clusterName string) bool { + _, err := c.aksClient.Get(ctx, resourceGroup, clusterName, nil) + return err == nil +} + +// GetPrivateClusterInfo retrieves cluster info needed for private cluster setup (location, node resource group, private FQDN) +// and validates that AAD and Azure RBAC are enabled. +func (c *AzureClient) GetPrivateClusterInfo(ctx context.Context, resourceGroup, clusterName string) (*AKSClusterInfo, error) { + resp, err := c.aksClient.Get(ctx, resourceGroup, clusterName, nil) + if err != nil { + return nil, fmt.Errorf("failed to get AKS cluster: %w", err) + } + + cluster := resp.ManagedCluster + props := cluster.Properties + if props == nil { + return nil, fmt.Errorf("AKS cluster properties are nil") + } + + if props.AADProfile == nil || props.AADProfile.Managed == nil || !*props.AADProfile.Managed { + return nil, fmt.Errorf("AKS cluster AAD not enabled, please enable: az aks update --enable-aad") + } + + if props.AADProfile.EnableAzureRBAC == nil || !*props.AADProfile.EnableAzureRBAC { + return nil, fmt.Errorf("AKS cluster Azure RBAC not enabled, please enable: az aks update --enable-azure-rbac") + } + + info := &AKSClusterInfo{ + ResourceGroup: resourceGroup, + ClusterName: clusterName, + } + + if cluster.Location != nil { + info.Location = *cluster.Location + } + if props.NodeResourceGroup != nil { + info.NodeResourceGroup = *props.NodeResourceGroup + } + if props.PrivateFQDN != nil { + info.PrivateFQDN = *props.PrivateFQDN + } + + // Extract 
VNet info from agent pool subnet ID + for _, pool := range props.AgentPoolProfiles { + if pool.VnetSubnetID != nil && *pool.VnetSubnetID != "" { + // Format: /subscriptions/{sub}/resourceGroups/{rg}/providers/Microsoft.Network/virtualNetworks/{vnet}/subnets/{subnet} + parts := strings.Split(*pool.VnetSubnetID, "/") + if len(parts) >= 9 { + info.VNetResourceGroup = parts[4] + info.VNetName = parts[8] + } + break + } + } + + return info, nil +} + +// VMExists checks if a VM exists. +func (c *AzureClient) VMExists(ctx context.Context, resourceGroup, vmName string) bool { + _, err := c.vmClient.Get(ctx, resourceGroup, vmName, nil) + return err == nil +} + +// GetVMPublicIP retrieves a VM's public IP address by tracing VM → NIC → PIP. +func (c *AzureClient) GetVMPublicIP(ctx context.Context, resourceGroup, vmName string) (string, error) { + vmResp, err := c.vmClient.Get(ctx, resourceGroup, vmName, nil) + if err != nil { + return "", fmt.Errorf("failed to get VM: %w", err) + } + if vmResp.Properties == nil || vmResp.Properties.NetworkProfile == nil || + len(vmResp.Properties.NetworkProfile.NetworkInterfaces) == 0 { + return "", fmt.Errorf("VM has no network interfaces") + } + + nicID := vmResp.Properties.NetworkProfile.NetworkInterfaces[0].ID + if nicID == nil { + return "", fmt.Errorf("NIC ID is nil") + } + nicRG, nicName := parseResourceGroupAndName(*nicID) + + nicResp, err := c.nicClient.Get(ctx, nicRG, nicName, nil) + if err != nil { + return "", fmt.Errorf("failed to get NIC: %w", err) + } + if nicResp.Properties == nil || len(nicResp.Properties.IPConfigurations) == 0 { + return "", fmt.Errorf("NIC has no IP configurations") + } + + ipConfig := nicResp.Properties.IPConfigurations[0] + if ipConfig.Properties == nil || ipConfig.Properties.PublicIPAddress == nil || ipConfig.Properties.PublicIPAddress.ID == nil { + return "", fmt.Errorf("NIC has no public IP") + } + pipRG, pipName := parseResourceGroupAndName(*ipConfig.Properties.PublicIPAddress.ID) + + pipResp, err := 
c.pipClient.Get(ctx, pipRG, pipName, nil) + if err != nil { + return "", fmt.Errorf("failed to get public IP: %w", err) + } + if pipResp.Properties == nil || pipResp.Properties.IPAddress == nil { + return "", fmt.Errorf("public IP address is not allocated") + } + return *pipResp.Properties.IPAddress, nil +} + +// CreateSubnet creates a subnet in a VNet. +func (c *AzureClient) CreateSubnet(ctx context.Context, vnetRG, vnetName, subnetName, addressPrefix string) error { + _, err := c.subnetClient.Get(ctx, vnetRG, vnetName, subnetName, nil) + if err == nil { + c.logger.Infof("Subnet %s already exists", subnetName) + return nil + } + + poller, err := c.subnetClient.BeginCreateOrUpdate(ctx, vnetRG, vnetName, subnetName, armnetwork.Subnet{ + Properties: &armnetwork.SubnetPropertiesFormat{ + AddressPrefix: ptr(addressPrefix), + }, + }, nil) + if err != nil { + return fmt.Errorf("failed to create subnet: %w", err) + } + if _, err = poller.PollUntilDone(ctx, nil); err != nil { + return fmt.Errorf("failed to create subnet: %w", err) + } + return nil +} + +// CreateNSG creates a network security group with SSH and VPN rules. 
+func (c *AzureClient) CreateNSG(ctx context.Context, resourceGroup, nsgName, location string, vpnPort int) error { + _, err := c.nsgClient.Get(ctx, resourceGroup, nsgName, nil) + if err == nil { + c.logger.Infof("NSG %s already exists", nsgName) + return nil + } + + nsg := armnetwork.SecurityGroup{ + Location: ptr(location), + Properties: &armnetwork.SecurityGroupPropertiesFormat{ + SecurityRules: []*armnetwork.SecurityRule{ + { + Name: ptr("allow-ssh"), + Properties: &armnetwork.SecurityRulePropertiesFormat{ + Priority: ptr[int32](100), + Protocol: ptr(armnetwork.SecurityRuleProtocolTCP), + Access: ptr(armnetwork.SecurityRuleAccessAllow), + Direction: ptr(armnetwork.SecurityRuleDirectionInbound), + SourceAddressPrefix: ptr("*"), + SourcePortRange: ptr("*"), + DestinationAddressPrefix: ptr("*"), + DestinationPortRanges: []*string{ptr("22")}, + }, + }, + { + Name: ptr("allow-vpn"), + Properties: &armnetwork.SecurityRulePropertiesFormat{ + Priority: ptr[int32](200), + Protocol: ptr(armnetwork.SecurityRuleProtocolUDP), + Access: ptr(armnetwork.SecurityRuleAccessAllow), + Direction: ptr(armnetwork.SecurityRuleDirectionInbound), + SourceAddressPrefix: ptr("*"), + SourcePortRange: ptr("*"), + DestinationAddressPrefix: ptr("*"), + DestinationPortRanges: []*string{ptr(fmt.Sprintf("%d", vpnPort))}, + }, + }, + }, + }, + } + + poller, err := c.nsgClient.BeginCreateOrUpdate(ctx, resourceGroup, nsgName, nsg, nil) + if err != nil { + return fmt.Errorf("failed to create NSG: %w", err) + } + if _, err = poller.PollUntilDone(ctx, nil); err != nil { + return fmt.Errorf("failed to create NSG: %w", err) + } + return nil +} + +// CreatePublicIP creates a static public IP address. 
+func (c *AzureClient) CreatePublicIP(ctx context.Context, resourceGroup, pipName, location string) error { + _, err := c.pipClient.Get(ctx, resourceGroup, pipName, nil) + if err == nil { + c.logger.Infof("Public IP %s already exists", pipName) + return nil + } + + poller, err := c.pipClient.BeginCreateOrUpdate(ctx, resourceGroup, pipName, armnetwork.PublicIPAddress{ + Location: ptr(location), + SKU: &armnetwork.PublicIPAddressSKU{ + Name: ptr(armnetwork.PublicIPAddressSKUNameStandard), + }, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAllocationMethod: ptr(armnetwork.IPAllocationMethodStatic), + }, + }, nil) + if err != nil { + return fmt.Errorf("failed to create public IP: %w", err) + } + if _, err = poller.PollUntilDone(ctx, nil); err != nil { + return fmt.Errorf("failed to create public IP: %w", err) + } + return nil +} + +// GetPublicIPAddress retrieves a public IP address value. +func (c *AzureClient) GetPublicIPAddress(ctx context.Context, resourceGroup, pipName string) (string, error) { + resp, err := c.pipClient.Get(ctx, resourceGroup, pipName, nil) + if err != nil { + return "", fmt.Errorf("failed to get public IP: %w", err) + } + if resp.Properties == nil || resp.Properties.IPAddress == nil { + return "", fmt.Errorf("public IP address is not allocated") + } + return *resp.Properties.IPAddress, nil +} + +// CreateVM creates a NIC and VM with the specified configuration. 
+func (c *AzureClient) CreateVM(ctx context.Context, resourceGroup, vmName, location, vnetRG, vnetName, subnetName, nsgName, pipName, sshKeyPath, vmSize string) error { + pubKeyData, err := ReadFileContent(sshKeyPath + ".pub") + if err != nil { + return fmt.Errorf("failed to read SSH public key: %w", err) + } + pubKey := strings.TrimSpace(pubKeyData) + + subnetID := fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Network/virtualNetworks/%s/subnets/%s", + c.subscriptionID, vnetRG, vnetName, subnetName) + nsgID := fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Network/securityGroups/%s", + c.subscriptionID, resourceGroup, nsgName) + pipID := fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Network/publicIPAddresses/%s", + c.subscriptionID, resourceGroup, pipName) + + nicName := vmName + "VMNic" + nicPoller, err := c.nicClient.BeginCreateOrUpdate(ctx, resourceGroup, nicName, armnetwork.Interface{ + Location: ptr(location), + Properties: &armnetwork.InterfacePropertiesFormat{ + NetworkSecurityGroup: &armnetwork.SecurityGroup{ + ID: ptr(nsgID), + }, + IPConfigurations: []*armnetwork.InterfaceIPConfiguration{ + { + Name: ptr("ipconfig1"), + Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{ + Subnet: &armnetwork.Subnet{ + ID: ptr(subnetID), + }, + PublicIPAddress: &armnetwork.PublicIPAddress{ + ID: ptr(pipID), + }, + PrivateIPAllocationMethod: ptr(armnetwork.IPAllocationMethodDynamic), + }, + }, + }, + }, + }, nil) + if err != nil { + return fmt.Errorf("failed to create NIC: %w", err) + } + nicResp, err := nicPoller.PollUntilDone(ctx, nil) + if err != nil { + return fmt.Errorf("failed to create NIC: %w", err) + } + + vm := armcompute.VirtualMachine{ + Location: ptr(location), + Zones: []*string{ptr("1")}, + Properties: &armcompute.VirtualMachineProperties{ + HardwareProfile: &armcompute.HardwareProfile{ + VMSize: ptr(armcompute.VirtualMachineSizeTypes(vmSize)), + }, + StorageProfile: 
&armcompute.StorageProfile{ + ImageReference: &armcompute.ImageReference{ + Publisher: ptr("Canonical"), + Offer: ptr("0001-com-ubuntu-server-jammy"), + SKU: ptr("22_04-lts-gen2"), + Version: ptr("latest"), + }, + OSDisk: &armcompute.OSDisk{ + CreateOption: ptr(armcompute.DiskCreateOptionTypesFromImage), + DeleteOption: ptr(armcompute.DiskDeleteOptionTypesDelete), + ManagedDisk: &armcompute.ManagedDiskParameters{ + StorageAccountType: ptr(armcompute.StorageAccountTypesPremiumLRS), + }, + }, + }, + OSProfile: &armcompute.OSProfile{ + ComputerName: ptr(vmName), + AdminUsername: ptr("azureuser"), + LinuxConfiguration: &armcompute.LinuxConfiguration{ + DisablePasswordAuthentication: ptr(true), + SSH: &armcompute.SSHConfiguration{ + PublicKeys: []*armcompute.SSHPublicKey{ + { + Path: ptr("/home/azureuser/.ssh/authorized_keys"), + KeyData: ptr(pubKey), + }, + }, + }, + }, + }, + NetworkProfile: &armcompute.NetworkProfile{ + NetworkInterfaces: []*armcompute.NetworkInterfaceReference{ + { + ID: nicResp.ID, + Properties: &armcompute.NetworkInterfaceReferenceProperties{ + Primary: ptr(true), + DeleteOption: ptr(armcompute.DeleteOptionsDelete), + }, + }, + }, + }, + }, + } + + vmPoller, err := c.vmClient.BeginCreateOrUpdate(ctx, resourceGroup, vmName, vm, nil) + if err != nil { + return fmt.Errorf("failed to create VM: %w", err) + } + if _, err = vmPoller.PollUntilDone(ctx, nil); err != nil { + return fmt.Errorf("failed to create VM: %w", err) + } + return nil +} + +// AddSSHKeyToVM adds an SSH key to a VM using RunCommand. 
+func (c *AzureClient) AddSSHKeyToVM(ctx context.Context, resourceGroup, vmName, sshKeyPath string) error { + pubKey, err := ReadFileContent(sshKeyPath + ".pub") + if err != nil { + return fmt.Errorf("failed to read SSH public key: %w", err) + } + + script := fmt.Sprintf( + "mkdir -p /home/azureuser/.ssh && echo '%s' >> /home/azureuser/.ssh/authorized_keys && "+ + "sort -u -o /home/azureuser/.ssh/authorized_keys /home/azureuser/.ssh/authorized_keys && "+ + "chown -R azureuser:azureuser /home/azureuser/.ssh && "+ + "chmod 700 /home/azureuser/.ssh && chmod 600 /home/azureuser/.ssh/authorized_keys", + strings.TrimSpace(pubKey)) + + poller, err := c.vmClient.BeginRunCommand(ctx, resourceGroup, vmName, armcompute.RunCommandInput{ + CommandID: ptr("RunShellScript"), + Script: []*string{ptr(script)}, + }, nil) + if err != nil { + return fmt.Errorf("failed to run SSH key command: %w", err) + } + if _, err = poller.PollUntilDone(ctx, nil); err != nil { + return fmt.Errorf("failed to add SSH key to VM: %w", err) + } + return nil +} + +// RestartVM restarts a VM. +func (c *AzureClient) RestartVM(ctx context.Context, resourceGroup, vmName string) error { + poller, err := c.vmClient.BeginRestart(ctx, resourceGroup, vmName, nil) + if err != nil { + return fmt.Errorf("failed to restart VM: %w", err) + } + if _, err = poller.PollUntilDone(ctx, nil); err != nil { + return fmt.Errorf("failed to restart VM: %w", err) + } + return nil +} + +// DeleteVM deletes a VM if it exists. 
+func (c *AzureClient) DeleteVM(ctx context.Context, resourceGroup, vmName string) error { + if !c.VMExists(ctx, resourceGroup, vmName) { + return nil + } + forceDeletion := true + poller, err := c.vmClient.BeginDelete(ctx, resourceGroup, vmName, &armcompute.VirtualMachinesClientBeginDeleteOptions{ + ForceDeletion: &forceDeletion, + }) + if err != nil { + if isNotFoundError(err) { + return nil + } + return fmt.Errorf("failed to delete VM: %w", err) + } + if _, err = poller.PollUntilDone(ctx, nil); err != nil { + return fmt.Errorf("failed to delete VM: %w", err) + } + return nil +} + +// DeletePublicIP deletes a public IP address if it exists. +func (c *AzureClient) DeletePublicIP(ctx context.Context, resourceGroup, pipName string) error { + poller, err := c.pipClient.BeginDelete(ctx, resourceGroup, pipName, nil) + if err != nil { + if isNotFoundError(err) { + return nil + } + return fmt.Errorf("failed to delete public IP: %w", err) + } + if _, err = poller.PollUntilDone(ctx, nil); err != nil { + return fmt.Errorf("failed to delete public IP: %w", err) + } + return nil +} + +// DeleteNSG deletes a network security group if it exists. +func (c *AzureClient) DeleteNSG(ctx context.Context, resourceGroup, nsgName string) error { + poller, err := c.nsgClient.BeginDelete(ctx, resourceGroup, nsgName, nil) + if err != nil { + if isNotFoundError(err) { + return nil + } + return fmt.Errorf("failed to delete NSG: %w", err) + } + if _, err = poller.PollUntilDone(ctx, nil); err != nil { + return fmt.Errorf("failed to delete NSG: %w", err) + } + return nil +} + +// DeleteSubnet deletes a subnet (errors are ignored). 
+func (c *AzureClient) DeleteSubnet(ctx context.Context, vnetRG, vnetName, subnetName string) error { + poller, err := c.subnetClient.BeginDelete(ctx, vnetRG, vnetName, subnetName, nil) + if err != nil { + return nil // Ignore errors + } + _, _ = poller.PollUntilDone(ctx, nil) + return nil +} + +// GetAKSCredentials gets AKS cluster credentials and writes the kubeconfig to the specified path. +func (c *AzureClient) GetAKSCredentials(ctx context.Context, resourceGroup, clusterName, kubeconfigPath string) error { + resp, err := c.aksClient.ListClusterUserCredentials(ctx, resourceGroup, clusterName, nil) + if err != nil { + return fmt.Errorf("failed to get AKS credentials: %w", err) + } + if len(resp.Kubeconfigs) == 0 || resp.Kubeconfigs[0].Value == nil { + return fmt.Errorf("no kubeconfig returned for cluster %s", clusterName) + } + + if err := EnsureDirectory(filepath.Dir(kubeconfigPath)); err != nil { + return fmt.Errorf("failed to create kubeconfig directory: %w", err) + } + if err := os.WriteFile(kubeconfigPath, resp.Kubeconfigs[0].Value, 0600); err != nil { + return fmt.Errorf("failed to write kubeconfig: %w", err) + } + return nil +} + +// ptr returns a pointer to the given value. +func ptr[T any](v T) *T { + return &v +} + +// isNotFoundError checks if an error is a 404 Not Found response. +func isNotFoundError(err error) bool { + var respErr *azcore.ResponseError + return errors.As(err, &respErr) && respErr.StatusCode == http.StatusNotFound +} + +// parseResourceGroupAndName extracts resource group and resource name from an Azure resource ID. 
+func parseResourceGroupAndName(resourceID string) (resourceGroup, name string) { + parts := strings.Split(resourceID, "/") + for i, part := range parts { + if strings.EqualFold(part, "resourceGroups") && i+1 < len(parts) { + resourceGroup = parts[i+1] + } + } + if len(parts) > 0 { + name = parts[len(parts)-1] + } + return +} diff --git a/pkg/privatecluster/create_private_cluster.md b/pkg/privatecluster/create_private_cluster.md new file mode 100644 index 0000000..2358c03 --- /dev/null +++ b/pkg/privatecluster/create_private_cluster.md @@ -0,0 +1,165 @@ +# Create Private AKS Cluster + +This guide shows how to create a Private AKS Cluster with AAD and Azure RBAC enabled for edge node testing. + +## Prerequisites + +### 1. Login to Azure CLI + +```bash +az login +``` + +### 2. Set variables + +```bash +# Required +CLUSTER_NAME="my-private-aks" +RESOURCE_GROUP="my-rg" +LOCATION="eastus2" + +# Optional (defaults) +VNET_NAME="${CLUSTER_NAME}-vnet" +VNET_CIDR="10.224.0.0/12" +SUBNET_NAME="aks-subnet" +SUBNET_CIDR="10.224.0.0/16" +NODE_COUNT=1 +NODE_VM_SIZE="Standard_D2s_v3" +``` + +## Step 1: Create Resource Group + +```bash +az group create \ + --name "$RESOURCE_GROUP" \ + --location "$LOCATION" +``` + +## Step 2: Create VNet and Subnet + +```bash +# Create VNet +az network vnet create \ + --resource-group "$RESOURCE_GROUP" \ + --name "$VNET_NAME" \ + --address-prefix "$VNET_CIDR" + +# Create Subnet +az network vnet subnet create \ + --resource-group "$RESOURCE_GROUP" \ + --vnet-name "$VNET_NAME" \ + --name "$SUBNET_NAME" \ + --address-prefix "$SUBNET_CIDR" +``` + +## Step 3: Create Private AKS Cluster + +```bash +# Get Subnet ID +SUBNET_ID=$(az network vnet subnet show \ + --resource-group "$RESOURCE_GROUP" \ + --vnet-name "$VNET_NAME" \ + --name "$SUBNET_NAME" \ + --query id -o tsv) + +# Create Private AKS Cluster +az aks create \ + --resource-group "$RESOURCE_GROUP" \ + --name "$CLUSTER_NAME" \ + --location "$LOCATION" \ + --node-count "$NODE_COUNT" \ + 
--node-vm-size "$NODE_VM_SIZE" \ + --network-plugin azure \ + --vnet-subnet-id "$SUBNET_ID" \ + --enable-private-cluster \ + --enable-aad \ + --enable-azure-rbac \ + --generate-ssh-keys +``` + +> **Note:** This may take 5-10 minutes. + +## Step 4: Assign RBAC Roles to Current User + +The current user needs two roles to manage the cluster: + +| Role | Purpose | +|------|---------| +| Azure Kubernetes Service Cluster Admin Role | Get kubectl credentials | +| Azure Kubernetes Service RBAC Cluster Admin | Perform cluster operations | + +```bash +# Get current user's Object ID +USER_OBJECT_ID=$(az ad signed-in-user show --query id -o tsv) + +# Get AKS Resource ID +AKS_RESOURCE_ID=$(az aks show \ + --resource-group "$RESOURCE_GROUP" \ + --name "$CLUSTER_NAME" \ + --query id -o tsv) + +# Assign Role 1: Azure Kubernetes Service Cluster Admin Role +az role assignment create \ + --assignee "$USER_OBJECT_ID" \ + --role "Azure Kubernetes Service Cluster Admin Role" \ + --scope "$AKS_RESOURCE_ID" + +# Assign Role 2: Azure Kubernetes Service RBAC Cluster Admin +az role assignment create \ + --assignee "$USER_OBJECT_ID" \ + --role "Azure Kubernetes Service RBAC Cluster Admin" \ + --scope "$AKS_RESOURCE_ID" +``` + +## Step 5: Get Kubectl Credentials + +```bash +# Create kubeconfig directory +sudo mkdir -p /root/.kube + +# Get credentials (use sudo -E to preserve Azure CLI token) +sudo -E az aks get-credentials \ + --resource-group "$RESOURCE_GROUP" \ + --name "$CLUSTER_NAME" \ + --overwrite-existing \ + --file /root/.kube/config + +# Convert kubeconfig for Azure CLI auth +sudo -E kubelogin convert-kubeconfig -l azurecli --kubeconfig /root/.kube/config +``` + +## Step 6: Get Cluster Resource ID + +Save this for use in the `config.json` file's `targetCluster.resourceId` field: + +```bash +az aks show \ + --resource-group "$RESOURCE_GROUP" \ + --name "$CLUSTER_NAME" \ + --query id -o tsv +``` + +Example output: +``` 
+/subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourcegroups/my-rg/providers/Microsoft.ContainerService/managedClusters/my-private-aks +``` + +## Next Steps + +### Join an edge node to the private cluster + +Set `"private": true` in your `config.json`, then run: + +```bash +sudo -E ./aks-flex-node agent --config config.json +``` + +### Leave the private cluster + +```bash +# Local cleanup (keep Gateway for other nodes) +sudo -E ./aks-flex-node unbootstrap --config config.json + +# Full cleanup (remove Gateway and all Azure resources) +sudo -E ./aks-flex-node unbootstrap --config config.json --cleanup-mode full +``` diff --git a/pkg/privatecluster/installer.go b/pkg/privatecluster/installer.go new file mode 100644 index 0000000..c9d6789 --- /dev/null +++ b/pkg/privatecluster/installer.go @@ -0,0 +1,381 @@ +package privatecluster + +import ( + "context" + "fmt" + "os" + "time" + + "github.com/sirupsen/logrus" + + "go.goms.io/aks/AKSFlexNode/pkg/auth" + "go.goms.io/aks/AKSFlexNode/pkg/config" + "go.goms.io/aks/AKSFlexNode/pkg/utils" +) + +// Installer handles private cluster VPN/Gateway setup, implementing bootstrapper.StepExecutor. +type Installer struct { + config *config.Config + logger *logrus.Logger + authProvider *auth.AuthProvider + azureClient *AzureClient + toolInstaller *ToolInstaller + + clusterInfo *AKSClusterInfo + vpnConfig VPNConfig + sshKeyPath string + gatewayIP string +} + +// NewInstaller creates a new private cluster Installer. +func NewInstaller(logger *logrus.Logger) *Installer { + return &Installer{ + config: config.GetConfig(), + logger: logger, + authProvider: auth.NewAuthProvider(), + toolInstaller: NewToolInstaller(logger), + vpnConfig: DefaultVPNConfig(), + sshKeyPath: GetSSHKeyPath(), + } +} + +// GetName returns the step name. +func (i *Installer) GetName() string { + return "PrivateClusterInstall" +} + +// Validate checks prerequisites for private cluster installation. 
+func (i *Installer) Validate(ctx context.Context) error { + if !i.isPrivateCluster() { + return nil + } + if os.Getuid() != 0 { + return fmt.Errorf("private cluster setup requires root privileges, please run with 'sudo'") + } + return nil +} + +// IsCompleted returns true for non-private clusters or when VPN is already connected. +func (i *Installer) IsCompleted(ctx context.Context) bool { + if !i.isPrivateCluster() { + return true + } + vpnClient := NewVPNClient(i.vpnConfig, i.logger) + return vpnClient.TestConnection(ctx) +} + +// Execute runs the private cluster installation (Gateway/VPN setup). +func (i *Installer) Execute(ctx context.Context) error { + if !i.isPrivateCluster() { + return nil + } + + i.logger.Infof("========================================") + i.logger.Infof(" Add Edge Node to Private AKS Cluster") + i.logger.Infof("========================================") + + cred, err := i.authProvider.UserCredential(i.config) + if err != nil { + return fmt.Errorf("failed to get Azure credential: %w", err) + } + + subscriptionID := i.config.GetTargetClusterSubscriptionID() + azureClient, err := NewAzureClient(cred, subscriptionID, i.logger) + if err != nil { + return fmt.Errorf("failed to create Azure client: %w", err) + } + i.azureClient = azureClient + + i.clusterInfo = &AKSClusterInfo{ + ResourceID: i.config.GetTargetClusterID(), + SubscriptionID: subscriptionID, + ResourceGroup: i.config.GetTargetClusterResourceGroup(), + ClusterName: i.config.GetTargetClusterName(), + } + + if err := i.checkEnvironment(ctx); err != nil { + return fmt.Errorf("environment check failed: %w", err) + } + if err := i.setupGateway(ctx); err != nil { + return fmt.Errorf("gateway setup failed: %w", err) + } + if err := i.setupVPNClient(ctx); err != nil { + return fmt.Errorf("client setup failed: %w", err) + } + if err := i.joinNode(ctx); err != nil { + return fmt.Errorf("node join failed: %w", err) + } + + i.logger.Infof("Private cluster setup completed. 
Bootstrap will continue...") + return nil +} + +// isPrivateCluster checks if the config indicates a private cluster. +func (i *Installer) isPrivateCluster() bool { + return i.config != nil && + i.config.Azure.TargetCluster != nil && + i.config.Azure.TargetCluster.IsPrivateCluster +} + +// gatewayConfig returns the Gateway configuration, applying any overrides from config. +func (i *Installer) gatewayConfig() GatewayConfig { + gw := DefaultGatewayConfig() + if i.config.Azure.TargetCluster.GatewayVMSize != "" { + gw.VMSize = i.config.Azure.TargetCluster.GatewayVMSize + } + if i.config.Azure.TargetCluster.GatewayPort > 0 { + gw.Port = i.config.Azure.TargetCluster.GatewayPort + } + return gw +} + +// checkEnvironment checks prerequisites for private cluster setup. +func (i *Installer) checkEnvironment(ctx context.Context) error { + _ = CleanKubeCache() + i.logger.Infof("Azure SDK client ready") + i.logger.Infof("Subscription: %s", i.clusterInfo.SubscriptionID) + + tenantID, err := i.azureClient.GetTenantID(ctx) + if err != nil { + return err + } + i.clusterInfo.TenantID = tenantID + i.logger.Debugf("Tenant ID: %s", tenantID) + + if !i.azureClient.AKSClusterExists(ctx, i.clusterInfo.ResourceGroup, i.clusterInfo.ClusterName) { + return fmt.Errorf("AKS cluster '%s' not found", i.clusterInfo.ClusterName) + } + clusterInfo, err := i.azureClient.GetPrivateClusterInfo(ctx, i.clusterInfo.ResourceGroup, i.clusterInfo.ClusterName) + if err != nil { + return err + } + i.clusterInfo.Location = clusterInfo.Location + i.clusterInfo.NodeResourceGroup = clusterInfo.NodeResourceGroup + i.clusterInfo.PrivateFQDN = clusterInfo.PrivateFQDN + i.clusterInfo.VNetName = clusterInfo.VNetName + i.clusterInfo.VNetResourceGroup = clusterInfo.VNetResourceGroup + i.logger.Infof("AKS cluster: %s (AAD/RBAC enabled)", i.clusterInfo.ClusterName) + i.logger.Infof("VNet: %s/%s", i.clusterInfo.VNetResourceGroup, i.clusterInfo.VNetName) + + if err := InstallVPNTools(ctx, i.logger); err != nil { + return 
fmt.Errorf("failed to install VPN tools: %w", err) + } + if err := i.toolInstaller.InstallKubectl(ctx, i.config.GetKubernetesVersion()); err != nil { + return fmt.Errorf("failed to install kubectl: %w", err) + } + if err := i.toolInstaller.InstallKubelogin(ctx); err != nil { + return fmt.Errorf("failed to install kubelogin: %w", err) + } + i.logger.Infof("Dependencies ready") + + return nil +} + +// setupGateway sets up the VPN Gateway. +func (i *Installer) setupGateway(ctx context.Context) error { + gateway := i.gatewayConfig() + gatewayExists := false + if i.azureClient.VMExists(ctx, i.clusterInfo.ResourceGroup, gateway.Name) { + gatewayExists = true + ip, err := i.azureClient.GetVMPublicIP(ctx, i.clusterInfo.ResourceGroup, gateway.Name) + if err != nil { + return fmt.Errorf("failed to get Gateway public IP: %w", err) + } + i.gatewayIP = ip + i.logger.Infof("Gateway exists: %s", i.gatewayIP) + } else { + i.logger.Infof("Creating Gateway...") + if err := i.createGatewayInfrastructure(ctx); err != nil { + return err + } + } + + if err := GenerateSSHKey(i.sshKeyPath); err != nil { + return fmt.Errorf("failed to generate SSH key: %w", err) + } + if err := i.azureClient.AddSSHKeyToVM(ctx, i.clusterInfo.ResourceGroup, gateway.Name, i.sshKeyPath); err != nil { + return fmt.Errorf("failed to add SSH key to Gateway: %w", err) + } + + if err := i.waitForVMReady(ctx, gatewayExists); err != nil { + return err + } + + if err := i.configureVPNServer(ctx); err != nil { + return err + } + + return nil +} + +// createGatewayInfrastructure creates Gateway VM and related resources. 
func (i *Installer) createGatewayInfrastructure(ctx context.Context) error {
	gateway := i.gatewayConfig()
	nsgName := gateway.Name + "-nsg"
	pipName := gateway.Name + "-pip"
	location := i.clusterInfo.Location

	// Create the Gateway's dependencies in order:
	// subnet (in the cluster VNet) -> NSG -> public IP -> SSH key -> VM.
	if err := i.azureClient.CreateSubnet(ctx, i.clusterInfo.VNetResourceGroup, i.clusterInfo.VNetName,
		gateway.SubnetName, gateway.SubnetPrefix); err != nil {
		return fmt.Errorf("failed to create subnet: %w", err)
	}
	if err := i.azureClient.CreateNSG(ctx, i.clusterInfo.ResourceGroup, nsgName, location, gateway.Port); err != nil {
		return fmt.Errorf("failed to create NSG: %w", err)
	}
	if err := i.azureClient.CreatePublicIP(ctx, i.clusterInfo.ResourceGroup, pipName, location); err != nil {
		return fmt.Errorf("failed to create public IP: %w", err)
	}
	// GenerateSSHKey is a no-op when the key already exists, so the second
	// call in setupGateway is harmless.
	if err := GenerateSSHKey(i.sshKeyPath); err != nil {
		return fmt.Errorf("failed to generate SSH key: %w", err)
	}
	if err := i.azureClient.CreateVM(ctx, i.clusterInfo.ResourceGroup, gateway.Name,
		location, i.clusterInfo.VNetResourceGroup, i.clusterInfo.VNetName,
		gateway.SubnetName, nsgName, pipName,
		i.sshKeyPath, gateway.VMSize); err != nil {
		return fmt.Errorf("failed to create Gateway VM: %w", err)
	}

	ip, err := i.azureClient.GetPublicIPAddress(ctx, i.clusterInfo.ResourceGroup, pipName)
	if err != nil {
		return fmt.Errorf("failed to get public IP address: %w", err)
	}
	i.gatewayIP = ip
	i.logger.Infof("Gateway created: %s", i.gatewayIP)

	// Fixed boot grace period, cancellable via ctx.
	// NOTE(review): hard-coded 120s sleep — consider polling SSH/instance
	// state instead; waitForVMReady retries later in any case.
	i.logger.Infof("Waiting for VM to boot (120s)...")
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-time.After(120 * time.Second):
	}

	return nil
}

// waitForVMReady waits for SSH connectivity to Gateway.
+func (i *Installer) waitForVMReady(ctx context.Context, gatewayExists bool) error { + sshConfig := DefaultSSHConfig(i.sshKeyPath, i.gatewayIP) + ssh := NewSSHClient(sshConfig, i.logger) + + if ssh.TestConnection(ctx) { + i.logger.Infof("SSH ready") + return nil + } + + if gatewayExists { + gateway := i.gatewayConfig() + i.logger.Infof("Restarting VM...") + _ = i.azureClient.RestartVM(ctx, i.clusterInfo.ResourceGroup, gateway.Name) + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(120 * time.Second): + } + } + + if err := ssh.WaitForConnection(ctx, 18, 10*time.Second); err != nil { + return fmt.Errorf("VM SSH connection timeout") + } + i.logger.Infof("SSH ready") + return nil +} + +// configureVPNServer configures VPN on the Gateway. +func (i *Installer) configureVPNServer(ctx context.Context) error { + sshConfig := DefaultSSHConfig(i.sshKeyPath, i.gatewayIP) + ssh := NewSSHClient(sshConfig, i.logger) + vpnServer := NewVPNServerManager(ssh, i.logger) + + if !vpnServer.IsInstalled(ctx) { + i.logger.Infof("Installing VPN on Gateway...") + if err := vpnServer.Install(ctx); err != nil { + return fmt.Errorf("failed to install VPN on Gateway: %w", err) + } + } + + serverPubKey, err := vpnServer.GetPublicKey(ctx) + if err != nil { + if err := vpnServer.Install(ctx); err != nil { + return err + } + serverPubKey, err = vpnServer.GetPublicKey(ctx) + if err != nil { + return err + } + } + i.vpnConfig.ServerPublicKey = serverPubKey + i.vpnConfig.ServerEndpoint = i.gatewayIP + + peerCount, _ := vpnServer.GetPeerCount(ctx) + i.vpnConfig.ClientVPNIP = fmt.Sprintf("172.16.0.%d", peerCount+2) + i.logger.Infof("VPN server ready, client IP: %s", i.vpnConfig.ClientVPNIP) + + return nil +} + +// setupVPNClient configures the local VPN client. 
+func (i *Installer) setupVPNClient(ctx context.Context) error { + gateway := i.gatewayConfig() + vpnClient := NewVPNClient(i.vpnConfig, i.logger) + privateKey, publicKey, err := vpnClient.GenerateKeyPair(ctx) + if err != nil { + return err + } + if err := vpnClient.CreateClientConfig(privateKey, gateway.Port); err != nil { + return err + } + + sshConfig := DefaultSSHConfig(i.sshKeyPath, i.gatewayIP) + ssh := NewSSHClient(sshConfig, i.logger) + vpnServer := NewVPNServerManager(ssh, i.logger) + if err := vpnServer.AddPeer(ctx, publicKey, i.vpnConfig.ClientVPNIP); err != nil { + return err + } + + if err := vpnClient.Start(ctx); err != nil { + return err + } + + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(3 * time.Second): + } + + if !vpnClient.TestConnection(ctx) { + return fmt.Errorf("VPN connection failed") + } + i.logger.Infof("VPN connected: %s", i.vpnConfig.GatewayVPNIP) + + return nil +} + +// joinNode joins the node to the AKS cluster. +func (i *Installer) joinNode(ctx context.Context) error { + sshConfig := DefaultSSHConfig(i.sshKeyPath, i.gatewayIP) + ssh := NewSSHClient(sshConfig, i.logger) + vpnServer := NewVPNServerManager(ssh, i.logger) + + apiServerIP, err := vpnServer.ResolveDNS(ctx, i.clusterInfo.PrivateFQDN) + if err != nil { + return err + } + i.clusterInfo.APIServerIP = apiServerIP + if err := AddHostsEntry(apiServerIP, i.clusterInfo.PrivateFQDN); err != nil { + return fmt.Errorf("failed to add hosts entry: %w", err) + } + i.logger.Infof("API Server: %s (%s)", i.clusterInfo.PrivateFQDN, apiServerIP) + + _, _ = utils.RunCommandWithOutputContext(ctx, "swapoff", "-a") + + kubeconfigPath := "/root/.kube/config" + if err := i.azureClient.GetAKSCredentials(ctx, i.clusterInfo.ResourceGroup, i.clusterInfo.ClusterName, kubeconfigPath); err != nil { + return fmt.Errorf("failed to get AKS credentials: %w", err) + } + if _, err := utils.RunCommandWithOutputContext(ctx, "kubelogin", "convert-kubeconfig", "-l", "azurecli", 
"--kubeconfig", kubeconfigPath); err != nil { + return fmt.Errorf("failed to convert kubeconfig: %w", err) + } + i.logger.Infof("Kubeconfig ready: %s", kubeconfigPath) + + return nil +} diff --git a/pkg/privatecluster/privatecluster_test.go b/pkg/privatecluster/privatecluster_test.go new file mode 100644 index 0000000..d598557 --- /dev/null +++ b/pkg/privatecluster/privatecluster_test.go @@ -0,0 +1,94 @@ +package privatecluster + +import ( + "context" + "testing" + + "github.com/sirupsen/logrus" +) + +func TestDefaultConfigs(t *testing.T) { + // Test DefaultGatewayConfig + gw := DefaultGatewayConfig() + if gw.Name != "wg-gateway" { + t.Errorf("DefaultGatewayConfig().Name = %v, want wg-gateway", gw.Name) + } + if gw.Port != 51820 { + t.Errorf("DefaultGatewayConfig().Port = %v, want 51820", gw.Port) + } + + // Test DefaultVPNConfig + vpn := DefaultVPNConfig() + if vpn.NetworkInterface != "wg-aks" { + t.Errorf("DefaultVPNConfig().NetworkInterface = %v, want wg-aks", vpn.NetworkInterface) + } + if vpn.GatewayVPNIP != "172.16.0.1" { + t.Errorf("DefaultVPNConfig().GatewayVPNIP = %v, want 172.16.0.1", vpn.GatewayVPNIP) + } +} + +func TestCommandExists(t *testing.T) { + // Test with common command + if !CommandExists("ls") { + t.Error("CommandExists() should return true for 'ls'") + } + + // Test with non-existing command + if CommandExists("nonexistent_command_12345") { + t.Error("CommandExists() should return false for non-existent command") + } +} + +func TestInstallerCreation(t *testing.T) { + logger := logrus.New() + installer := NewInstaller(logger) + if installer == nil { + t.Fatal("NewInstaller() should not return nil") + } + if installer.logger != logger { + t.Error("Installer.logger should match the provided logger") + } +} + +func TestInstallerGetName(t *testing.T) { + installer := NewInstaller(logrus.New()) + if name := installer.GetName(); name != "PrivateClusterInstall" { + t.Errorf("GetName() = %v, want PrivateClusterInstall", name) + } +} + +func 
TestInstallerIsCompletedNonPrivate(t *testing.T) { + // When config is nil (non-private cluster), IsCompleted should return true + installer := NewInstaller(logrus.New()) + installer.config = nil + if !installer.IsCompleted(context.Background()) { + t.Error("IsCompleted() should return true for non-private cluster") + } +} + +func TestUninstallerCreation(t *testing.T) { + logger := logrus.New() + uninstaller := NewUninstaller(logger) + if uninstaller == nil { + t.Fatal("NewUninstaller() should not return nil") + } + if uninstaller.logger != logger { + t.Error("Uninstaller.logger should match the provided logger") + } +} + +func TestUninstallerGetName(t *testing.T) { + uninstaller := NewUninstaller(logrus.New()) + if name := uninstaller.GetName(); name != "PrivateClusterUninstall" { + t.Errorf("GetName() = %v, want PrivateClusterUninstall", name) + } +} + +func TestUninstallerIsCompletedNonPrivate(t *testing.T) { + // When config is nil (non-private cluster), IsCompleted should return true + uninstaller := NewUninstaller(logrus.New()) + uninstaller.config = nil + if !uninstaller.IsCompleted(context.Background()) { + t.Error("IsCompleted() should return true for non-private cluster") + } +} diff --git a/pkg/privatecluster/ssh.go b/pkg/privatecluster/ssh.go new file mode 100644 index 0000000..1e84646 --- /dev/null +++ b/pkg/privatecluster/ssh.go @@ -0,0 +1,149 @@ +package privatecluster + +import ( + "context" + "fmt" + "os/exec" + "strings" + "time" + + "github.com/sirupsen/logrus" + + "go.goms.io/aks/AKSFlexNode/pkg/utils" +) + +// SSHClient provides SSH operations to a remote host. +type SSHClient struct { + config SSHConfig + logger *logrus.Logger +} + +// NewSSHClient creates a new SSHClient instance. +func NewSSHClient(config SSHConfig, logger *logrus.Logger) *SSHClient { + return &SSHClient{ + config: config, + logger: logger, + } +} + +// buildSSHArgs builds common SSH arguments. 
+func (s *SSHClient) buildSSHArgs() []string { + return []string{ + "-o", "IdentitiesOnly=yes", + "-o", "StrictHostKeyChecking=no", + "-o", fmt.Sprintf("ConnectTimeout=%d", s.config.Timeout), + "-i", s.config.KeyPath, + } +} + +// Execute runs a command on the remote host and returns the output. +func (s *SSHClient) Execute(ctx context.Context, command string) (string, error) { + args := s.buildSSHArgs() + args = append(args, fmt.Sprintf("%s@%s", s.config.User, s.config.Host), command) + + cmd := exec.CommandContext(ctx, "ssh", args...) // #nosec G204 -- ssh with trusted internal args + output, err := cmd.CombinedOutput() + if err != nil { + return string(output), fmt.Errorf("SSH command failed: %w\nOutput: %s", err, string(output)) + } + return strings.TrimSpace(string(output)), nil +} + +// ExecuteSilent runs a command on the remote host, returning only success/failure. +func (s *SSHClient) ExecuteSilent(ctx context.Context, command string) bool { + args := s.buildSSHArgs() + args = append(args, fmt.Sprintf("%s@%s", s.config.User, s.config.Host), command) + + cmd := exec.CommandContext(ctx, "ssh", args...) // #nosec G204 -- ssh with trusted internal args + return cmd.Run() == nil +} + +// ExecuteScript runs a multi-line script on the remote host. +func (s *SSHClient) ExecuteScript(ctx context.Context, script string) error { + args := s.buildSSHArgs() + args = append(args, fmt.Sprintf("%s@%s", s.config.User, s.config.Host)) + + cmd := exec.CommandContext(ctx, "ssh", args...) // #nosec G204 -- ssh with trusted internal args + cmd.Stdin = strings.NewReader(script) + + output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("SSH script execution failed: %w\nOutput: %s", err, string(output)) + } + return nil +} + +// TestConnection tests if SSH connection is ready. +func (s *SSHClient) TestConnection(ctx context.Context) bool { + return s.ExecuteSilent(ctx, "echo ready") +} + +// WaitForConnection waits for SSH connection to be ready with retries. 
+func (s *SSHClient) WaitForConnection(ctx context.Context, maxAttempts int, interval time.Duration) error { + if s.TestConnection(ctx) { + return nil + } + + s.logger.Infof("Waiting for SSH connection to be ready...") + + for attempt := 1; attempt <= maxAttempts; attempt++ { + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(interval): + } + + if s.TestConnection(ctx) { + return nil + } + + s.logger.Debugf("Waiting for SSH... (%d/%d)", attempt, maxAttempts) + } + + return fmt.Errorf("SSH connection timeout after %d attempts", maxAttempts) +} + +// ReadRemoteFile reads a file from the remote host. +func (s *SSHClient) ReadRemoteFile(ctx context.Context, path string) (string, error) { + return s.Execute(ctx, fmt.Sprintf("sudo cat %s 2>/dev/null || echo ''", path)) +} + +// CommandExists checks if a command exists on the remote host. +func (s *SSHClient) CommandExists(ctx context.Context, command string) bool { + return s.ExecuteSilent(ctx, fmt.Sprintf("command -v %s", command)) +} + +// GenerateSSHKey generates an SSH key pair. +func GenerateSSHKey(keyPath string) error { + if utils.FileExists(keyPath) { + return nil + } + + if err := EnsureDirectory(GetRealHome() + "/.ssh"); err != nil { + return err + } + + cmd := exec.Command("ssh-keygen", "-t", "rsa", "-b", "4096", "-f", keyPath, "-N", "") // #nosec G204 -- fixed args + if output, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("failed to generate SSH key: %w\nOutput: %s", err, string(output)) + } + + return FixSSHKeyOwnership(keyPath) +} + +// RemoveSSHKeys removes SSH key pair. 
+func RemoveSSHKeys(keyPath string) error { + for _, path := range []string{keyPath, keyPath + ".pub"} { + if utils.FileExists(path) { + if err := removeFile(path); err != nil { + return err + } + } + } + return nil +} + +func removeFile(path string) error { + cmd := exec.Command("rm", "-f", path) // #nosec G204 -- fixed command with path arg + return cmd.Run() +} diff --git a/pkg/privatecluster/tool_installer.go b/pkg/privatecluster/tool_installer.go new file mode 100644 index 0000000..e1bf5d3 --- /dev/null +++ b/pkg/privatecluster/tool_installer.go @@ -0,0 +1,86 @@ +package privatecluster + +import ( + "context" + "fmt" + "os" + + "github.com/sirupsen/logrus" + + "go.goms.io/aks/AKSFlexNode/pkg/utils" + "go.goms.io/aks/AKSFlexNode/pkg/utils/utilhost" +) + +const ( + kubeloginVersion = "0.1.6" + kubeloginURLPattern = "https://github.com/Azure/kubelogin/releases/download/v%s/kubelogin-linux-%s.zip" + kubectlURLPattern = "https://acs-mirror.azureedge.net/kubernetes/v%s/bin/linux/%s/kubectl" +) + +// ToolInstaller handles installation of CLI tools via direct downloads. +type ToolInstaller struct { + logger *logrus.Logger +} + +// NewToolInstaller creates a new ToolInstaller instance. +func NewToolInstaller(logger *logrus.Logger) *ToolInstaller { + return &ToolInstaller{logger: logger} +} + +// InstallKubelogin downloads and installs kubelogin binary. 
func (t *ToolInstaller) InstallKubelogin(ctx context.Context) error {
	// Skip when kubelogin is already on PATH.
	if CommandExists("kubelogin") {
		return nil
	}

	arch := utilhost.GetArch()
	url := fmt.Sprintf(kubeloginURLPattern, kubeloginVersion, arch)
	zipPath := "/tmp/kubelogin.zip"

	t.logger.Infof("Downloading kubelogin v%s...", kubeloginVersion)

	// NOTE(review): the download is not checksum/signature verified —
	// confirm whether release verification is required here.
	if _, err := utils.RunCommandWithOutputContext(ctx, "curl", "-L", "-o", zipPath, url); err != nil {
		return fmt.Errorf("failed to download kubelogin: %w", err)
	}
	defer func() { _ = os.Remove(zipPath) }()

	// The release zip places the binary under bin/linux_<arch>/kubelogin.
	extractDir := "/tmp/kubelogin-extract"
	_ = os.RemoveAll(extractDir)
	if _, err := utils.RunCommandWithOutputContext(ctx, "unzip", "-o", zipPath, "-d", extractDir); err != nil {
		return fmt.Errorf("failed to extract kubelogin: %w", err)
	}
	defer func() { _ = os.RemoveAll(extractDir) }()

	binaryPath := fmt.Sprintf("%s/bin/linux_%s/kubelogin", extractDir, arch)
	if !utils.FileExists(binaryPath) {
		return fmt.Errorf("kubelogin binary not found at %s", binaryPath)
	}

	if _, err := utils.RunCommandWithOutputContext(ctx, "cp", binaryPath, "/usr/local/bin/kubelogin"); err != nil {
		return fmt.Errorf("failed to install kubelogin: %w", err)
	}
	_ = os.Chmod("/usr/local/bin/kubelogin", 0755) // #nosec G302 G306 -- binary must be executable

	t.logger.Infof("kubelogin v%s installed", kubeloginVersion)
	return nil
}

// InstallKubectl downloads and installs kubectl binary.
+func (t *ToolInstaller) InstallKubectl(ctx context.Context, kubernetesVersion string) error { + if CommandExists("kubectl") { + return nil + } + + arch := utilhost.GetArch() + url := fmt.Sprintf(kubectlURLPattern, kubernetesVersion, arch) + + t.logger.Infof("Downloading kubectl v%s...", kubernetesVersion) + + if _, err := utils.RunCommandWithOutputContext(ctx, "curl", "-L", "-o", "/usr/local/bin/kubectl", url); err != nil { + return fmt.Errorf("failed to download kubectl: %w", err) + } + _ = os.Chmod("/usr/local/bin/kubectl", 0755) // #nosec G302 G306 -- binary must be executable + + t.logger.Infof("kubectl v%s installed", kubernetesVersion) + return nil +} diff --git a/pkg/privatecluster/types.go b/pkg/privatecluster/types.go new file mode 100644 index 0000000..04e7fcd --- /dev/null +++ b/pkg/privatecluster/types.go @@ -0,0 +1,83 @@ +package privatecluster + +// CleanupMode defines the cleanup mode for uninstallation +type CleanupMode string + +const ( + CleanupModeLocal CleanupMode = "local" + CleanupModeFull CleanupMode = "full" +) + +// GatewayConfig holds configuration for the VPN Gateway VM +type GatewayConfig struct { + Name string + SubnetName string + SubnetPrefix string + VMSize string + Port int +} + +// VPNConfig holds VPN connection configuration +type VPNConfig struct { + NetworkInterface string + VPNNetwork string + GatewayVPNIP string + ClientVPNIP string + ServerPublicKey string + ServerEndpoint string +} + +// AKSClusterInfo holds parsed AKS cluster information +type AKSClusterInfo struct { + ResourceID string + SubscriptionID string + ResourceGroup string + ClusterName string + Location string + TenantID string + NodeResourceGroup string + VNetName string + VNetResourceGroup string + PrivateFQDN string + APIServerIP string +} + +// SSHConfig holds SSH connection configuration +type SSHConfig struct { + KeyPath string + Host string + User string + Port int + Timeout int +} + +// DefaultGatewayConfig returns the default Gateway configuration +func 
DefaultGatewayConfig() GatewayConfig { + return GatewayConfig{ + Name: "wg-gateway", + SubnetName: "wg-subnet", + SubnetPrefix: "10.0.100.0/24", + VMSize: "Standard_D2s_v3", + Port: 51820, + } +} + +// DefaultVPNConfig returns the default VPN configuration +func DefaultVPNConfig() VPNConfig { + return VPNConfig{ + NetworkInterface: "wg-aks", + VPNNetwork: "172.16.0.0/24", + GatewayVPNIP: "172.16.0.1", + } +} + +// DefaultSSHConfig returns the default SSH configuration +func DefaultSSHConfig(keyPath, host string) SSHConfig { + return SSHConfig{ + KeyPath: keyPath, + Host: host, + User: "azureuser", + Port: 22, + Timeout: 10, + } +} diff --git a/pkg/privatecluster/uninstaller.go b/pkg/privatecluster/uninstaller.go new file mode 100644 index 0000000..69e9804 --- /dev/null +++ b/pkg/privatecluster/uninstaller.go @@ -0,0 +1,281 @@ +package privatecluster + +import ( + "context" + "fmt" + + "github.com/sirupsen/logrus" + + "go.goms.io/aks/AKSFlexNode/pkg/auth" + "go.goms.io/aks/AKSFlexNode/pkg/config" + "go.goms.io/aks/AKSFlexNode/pkg/utils" +) + +// Uninstaller handles private cluster VPN/Gateway teardown, implementing bootstrapper.Executor. +type Uninstaller struct { + config *config.Config + logger *logrus.Logger + authProvider *auth.AuthProvider + azureClient *AzureClient + + clusterInfo *AKSClusterInfo + vpnConfig VPNConfig + sshKeyPath string + gatewayIP string + clientKey string +} + +// NewUninstaller creates a new private cluster Uninstaller. +func NewUninstaller(logger *logrus.Logger) *Uninstaller { + return &Uninstaller{ + config: config.GetConfig(), + logger: logger, + authProvider: auth.NewAuthProvider(), + vpnConfig: DefaultVPNConfig(), + sshKeyPath: GetSSHKeyPath(), + } +} + +// GetName returns the step name. +func (u *Uninstaller) GetName() string { + return "PrivateClusterUninstall" +} + +// IsCompleted returns true for non-private clusters; always false for private clusters. 
+func (u *Uninstaller) IsCompleted(ctx context.Context) bool { + if !u.isPrivateCluster() { + return true + } + return false // Always attempt cleanup for private clusters +} + +// Execute runs the private cluster uninstallation. +func (u *Uninstaller) Execute(ctx context.Context) error { + if !u.isPrivateCluster() { + return nil + } + + u.logger.Infof("Remove Edge Node from Private AKS Cluster") + u.logger.Infof("=====================================") + + cleanupMode := u.config.Azure.TargetCluster.CleanupMode + var mode CleanupMode + switch cleanupMode { + case "local", "": + mode = CleanupModeLocal + case "full": + mode = CleanupModeFull + default: + return fmt.Errorf("invalid cleanup mode: %s (use 'local' or 'full')", cleanupMode) + } + + resourceID := u.config.GetTargetClusterID() + if resourceID != "" { + subscriptionID := u.config.GetTargetClusterSubscriptionID() + resourceGroup := u.config.GetTargetClusterResourceGroup() + clusterName := u.config.GetTargetClusterName() + u.clusterInfo = &AKSClusterInfo{ + ResourceID: resourceID, + SubscriptionID: subscriptionID, + ResourceGroup: resourceGroup, + ClusterName: clusterName, + } + u.logger.Infof("Cluster: %s/%s (Subscription: %s)", resourceGroup, clusterName, subscriptionID) + + if mode == CleanupModeFull { + cred, err := u.authProvider.UserCredential(u.config) + if err != nil { + u.logger.Warnf("Failed to get Azure credential: %v", err) + } else { + azureClient, err := NewAzureClient(cred, subscriptionID, u.logger) + if err != nil { + u.logger.Warnf("Failed to create Azure client: %v", err) + } else { + u.azureClient = azureClient + } + } + } + } + + switch mode { + case CleanupModeLocal: + return u.cleanupLocal(ctx) + case CleanupModeFull: + return u.cleanupFull(ctx) + default: + return fmt.Errorf("invalid cleanup mode: %s", mode) + } +} + +// isPrivateCluster checks if the config indicates a private cluster. 
+func (u *Uninstaller) isPrivateCluster() bool { + return u.config != nil && + u.config.Azure.TargetCluster != nil && + u.config.Azure.TargetCluster.IsPrivateCluster +} + +// cleanupLocal performs local cleanup (keeps Gateway). +func (u *Uninstaller) cleanupLocal(ctx context.Context) error { + u.logger.Infof("Performing local cleanup (keeping Gateway)...") + + hostname, err := GetHostname() + if err != nil { + return err + } + + u.readVPNConfig() + u.removeNodeFromCluster(ctx, hostname) // Must happen while VPN is still up + u.removeClientPeerFromGateway(ctx) + u.stopVPN(ctx) + u.deleteVPNConfig() + u.cleanupHostsEntries() + + u.logger.Infof("Local cleanup completed!") + u.logger.Infof("To rejoin cluster, run:") + u.logger.Infof(" sudo -E ./aks-flex-node agent --config config.json # with private: true") + + return nil +} + +// cleanupFull performs full cleanup (removes all Azure resources). +func (u *Uninstaller) cleanupFull(ctx context.Context) error { + u.logger.Infof("Performing full cleanup...") + + hostname, err := GetHostname() + if err != nil { + return err + } + + u.readVPNConfig() + u.removeNodeFromCluster(ctx, hostname) // Must happen while VPN is still up + u.removeClientPeerFromGateway(ctx) + u.stopVPN(ctx) + u.deleteVPNConfig() + u.cleanupHostsEntries() + + if err := u.deleteAzureResources(ctx); err != nil { + u.logger.Warnf("Failed to delete some Azure resources: %v", err) + } + + u.deleteSSHKeys() + + u.logger.Infof("Full cleanup completed!") + u.logger.Infof("All components and Azure resources have been removed.") + u.logger.Infof("The local machine is now clean.") + + return nil +} + +// readVPNConfig reads Gateway IP and client key from VPN config. 
+func (u *Uninstaller) readVPNConfig() { + vpnClient := NewVPNClient(u.vpnConfig, u.logger) + gatewayIP, clientKey, err := vpnClient.GetClientConfigInfo() + if err == nil { + u.gatewayIP = gatewayIP + u.clientKey = clientKey + } +} + +// removeNodeFromCluster removes the node from the Kubernetes cluster. +func (u *Uninstaller) removeNodeFromCluster(ctx context.Context, nodeName string) { + if !CommandExists("kubectl") { + return + } + + u.logger.Infof("Removing node %s from cluster...", nodeName) + + if _, err := utils.RunCommandWithOutputContext(ctx, "kubectl", "--kubeconfig", "/root/.kube/config", + "delete", "node", nodeName, "--ignore-not-found"); err == nil { + u.logger.Infof("Node removed from cluster") + return + } + + if _, err := utils.RunCommandWithOutputContext(ctx, "kubectl", "delete", "node", nodeName, "--ignore-not-found"); err == nil { + u.logger.Infof("Node removed from cluster") + return + } + + u.logger.Warnf("Failed to remove node from cluster (may need manual cleanup: kubectl delete node %s)", nodeName) +} + +// removeClientPeerFromGateway removes this client's peer from the Gateway. +func (u *Uninstaller) removeClientPeerFromGateway(ctx context.Context) { + if u.gatewayIP == "" || u.clientKey == "" || !utils.FileExists(u.sshKeyPath) { + return + } + + u.logger.Infof("Removing client peer from Gateway...") + + vpnClient := NewVPNClient(u.vpnConfig, u.logger) + clientPubKey, err := vpnClient.GetPublicKeyFromPrivate(ctx, u.clientKey) + if err != nil || clientPubKey == "" { + return + } + + sshConfig := DefaultSSHConfig(u.sshKeyPath, u.gatewayIP) + sshConfig.Timeout = 10 + ssh := NewSSHClient(sshConfig, u.logger) + vpnServer := NewVPNServerManager(ssh, u.logger) + + _ = vpnServer.RemovePeer(ctx, clientPubKey) + u.logger.Infof("Client peer removed from Gateway") +} + +// stopVPN stops the VPN connection. 
+func (u *Uninstaller) stopVPN(ctx context.Context) { + vpnClient := NewVPNClient(u.vpnConfig, u.logger) + _ = vpnClient.Stop(ctx) + u.logger.Infof("VPN connection stopped") +} + +// deleteVPNConfig deletes the VPN client configuration. +func (u *Uninstaller) deleteVPNConfig() { + vpnClient := NewVPNClient(u.vpnConfig, u.logger) + _ = vpnClient.RemoveClientConfig() + u.logger.Infof("VPN config deleted") +} + +// cleanupHostsEntries removes AKS-related entries from /etc/hosts. +func (u *Uninstaller) cleanupHostsEntries() { + _ = RemoveHostsEntries("privatelink") + _ = RemoveHostsEntries("azmk8s.io") + u.logger.Infof("Hosts entries cleaned") +} + +// deleteSSHKeys deletes the Gateway SSH keys. +func (u *Uninstaller) deleteSSHKeys() { + _ = RemoveSSHKeys(u.sshKeyPath) + u.logger.Infof("SSH keys deleted") +} + +// deleteAzureResources deletes all Azure resources created for the Gateway. +func (u *Uninstaller) deleteAzureResources(ctx context.Context) error { + if u.clusterInfo == nil || u.azureClient == nil { + return fmt.Errorf("cluster info or Azure client not available") + } + + u.logger.Infof("Deleting Azure resources...") + + gatewayName := "wg-gateway" + pipName := gatewayName + "-pip" + nsgName := gatewayName + "-nsg" + + // VM deletion cascades to NIC and OS disk via DeleteOption set at creation time. 
+ if err := u.azureClient.DeleteVM(ctx, u.clusterInfo.ResourceGroup, gatewayName); err != nil { + u.logger.Warnf("Delete VM: %v", err) + } + if err := u.azureClient.DeletePublicIP(ctx, u.clusterInfo.ResourceGroup, pipName); err != nil { + u.logger.Warnf("Delete Public IP: %v", err) + } + if err := u.azureClient.DeleteNSG(ctx, u.clusterInfo.ResourceGroup, nsgName); err != nil { + u.logger.Warnf("Delete NSG: %v", err) + } + + clusterInfo, err := u.azureClient.GetPrivateClusterInfo(ctx, u.clusterInfo.ResourceGroup, u.clusterInfo.ClusterName) + if err == nil && clusterInfo.VNetName != "" { + _ = u.azureClient.DeleteSubnet(ctx, clusterInfo.VNetResourceGroup, clusterInfo.VNetName, "wg-subnet") + } + u.logger.Infof("Azure resources deleted") + + return nil +} diff --git a/pkg/privatecluster/utils.go b/pkg/privatecluster/utils.go new file mode 100644 index 0000000..cbd4c4c --- /dev/null +++ b/pkg/privatecluster/utils.go @@ -0,0 +1,164 @@ +package privatecluster + +import ( + "bufio" + "fmt" + "os" + "os/exec" + "os/user" + "path/filepath" + "strings" + + "go.goms.io/aks/AKSFlexNode/pkg/utils" +) + +// CommandExists checks if a command is available in PATH. +func CommandExists(name string) bool { + _, err := exec.LookPath(name) + return err == nil +} + +// GetRealHome returns the real user's home directory (handles sudo). +func GetRealHome() string { + if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" { + if u, err := user.Lookup(sudoUser); err == nil { + return u.HomeDir + } + } + if home := os.Getenv("HOME"); home != "" { + return home + } + if u, err := user.Current(); err == nil { + return u.HomeDir + } + return "/root" +} + +// GetSSHKeyPath returns the default SSH key path for the Gateway. +func GetSSHKeyPath() string { + return filepath.Join(GetRealHome(), ".ssh", "id_rsa_wg_gateway") +} + +// EnsureDirectory creates a directory if it doesn't exist. 
func EnsureDirectory(path string) error {
	return os.MkdirAll(path, 0750)
}

// ReadFileContent reads a file and returns its content.
func ReadFileContent(path string) (string, error) {
	data, err := os.ReadFile(path) // #nosec G304 -- path is from trusted internal code
	if err != nil {
		return "", err
	}
	return string(data), nil
}

// WriteFileContent writes content to a file with specified permissions.
func WriteFileContent(path, content string, perm os.FileMode) error {
	return os.WriteFile(path, []byte(content), perm) // #nosec G304 G306 G703 -- path, perm are from trusted internal code
}

// hostsEntryExists reports whether hostname already appears as a whole mapped
// name (not a substring) on a non-comment line of hosts-file content.
func hostsEntryExists(content, hostname string) bool {
	for _, line := range strings.Split(content, "\n") {
		line = strings.TrimSpace(line)
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}
		fields := strings.Fields(line)
		// fields[0] is the address; every later field is a hostname/alias.
		for _, name := range fields[1:] {
			if name == hostname {
				return true
			}
		}
	}
	return false
}

// AddHostsEntry adds an entry to /etc/hosts if it doesn't exist.
//
// Fix: the previous strings.Contains(content, hostname) check treated any
// substring occurrence (including commented-out lines) as an existing entry,
// so an existing "foo.bar" mapping would wrongly suppress adding "foo". The
// hostname is now matched as a whole field on active lines only.
func AddHostsEntry(ip, hostname string) error {
	hostsPath := "/etc/hosts"

	content, err := ReadFileContent(hostsPath)
	if err != nil {
		return fmt.Errorf("failed to read hosts file: %w", err)
	}

	if hostsEntryExists(content, hostname) {
		return nil // Entry already exists
	}

	f, err := os.OpenFile(hostsPath, os.O_APPEND|os.O_WRONLY, 0600) // #nosec G304 -- hostsPath is validated
	if err != nil {
		return fmt.Errorf("failed to open hosts file: %w", err)
	}
	defer func() { _ = f.Close() }()

	entry := fmt.Sprintf("%s %s\n", ip, hostname)
	if _, err := f.WriteString(entry); err != nil {
		return fmt.Errorf("failed to write hosts entry: %w", err)
	}

	return nil
}

// RemoveHostsEntries removes entries matching a pattern from /etc/hosts.
+func RemoveHostsEntries(pattern string) error { + hostsPath := "/etc/hosts" + + content, err := ReadFileContent(hostsPath) + if err != nil { + return fmt.Errorf("failed to read hosts file: %w", err) + } + + var newLines []string + scanner := bufio.NewScanner(strings.NewReader(content)) + for scanner.Scan() { + line := scanner.Text() + if !strings.Contains(line, pattern) { + newLines = append(newLines, line) + } + } + + newContent := strings.Join(newLines, "\n") + if !strings.HasSuffix(newContent, "\n") { + newContent += "\n" + } + + return WriteFileContent(hostsPath, newContent, 0644) // #nosec G306 -- /etc/hosts requires 0644 for system readability +} + +// FixSSHKeyOwnership fixes SSH key ownership when running with sudo. +func FixSSHKeyOwnership(keyPath string) error { + sudoUser := os.Getenv("SUDO_USER") + if sudoUser == "" { + return nil // Not running with sudo + } + + u, err := user.Lookup(sudoUser) + if err != nil { + return fmt.Errorf("failed to lookup user %s: %w", sudoUser, err) + } + + for _, path := range []string{keyPath, keyPath + ".pub"} { + if utils.FileExists(path) { + cmd := exec.Command("chown", fmt.Sprintf("%s:%s", u.Uid, u.Gid), path) // #nosec G204 G702 -- chown with uid/gid from user.Lookup + if err := cmd.Run(); err != nil { + return fmt.Errorf("failed to change ownership of %s: %w", path, err) + } + } + } + + return nil +} + +// GetHostname returns the lowercase hostname. +func GetHostname() (string, error) { + hostname, err := os.Hostname() + if err != nil { + return "", err + } + return strings.ToLower(hostname), nil +} + +// CleanKubeCache removes kube cache directories. 
+func CleanKubeCache() error { + paths := []string{ + "/root/.kube/cache", + filepath.Join(GetRealHome(), ".kube", "cache"), + } + + for _, path := range paths { + if utils.FileExists(path) { + if err := os.RemoveAll(path); err != nil { + // Log but don't fail + continue + } + } + } + + return nil +} diff --git a/pkg/privatecluster/vpn.go b/pkg/privatecluster/vpn.go new file mode 100644 index 0000000..d6dbde1 --- /dev/null +++ b/pkg/privatecluster/vpn.go @@ -0,0 +1,258 @@ +package privatecluster + +import ( + "context" + "fmt" + "os/exec" + "strconv" + "strings" + + "github.com/sirupsen/logrus" + + "go.goms.io/aks/AKSFlexNode/pkg/utils" +) + +// VPNClient provides VPN (WireGuard) operations. +type VPNClient struct { + config VPNConfig + logger *logrus.Logger +} + +// NewVPNClient creates a new VPNClient instance. +func NewVPNClient(config VPNConfig, logger *logrus.Logger) *VPNClient { + return &VPNClient{ + config: config, + logger: logger, + } +} + +// GenerateKeyPair generates a WireGuard key pair and returns (privateKey, publicKey). +func (v *VPNClient) GenerateKeyPair(ctx context.Context) (string, string, error) { + privateKeyRaw, err := utils.RunCommandWithOutputContext(ctx, "wg", "genkey") + if err != nil { + return "", "", fmt.Errorf("failed to generate VPN private key: %w", err) + } + privateKey := strings.TrimSpace(privateKeyRaw) + + cmd := exec.CommandContext(ctx, "wg", "pubkey") // #nosec G204 -- fixed wg command + cmd.Stdin = strings.NewReader(privateKey) + publicKeyBytes, err := cmd.Output() + if err != nil { + return "", "", fmt.Errorf("failed to generate VPN public key: %w", err) + } + + return privateKey, strings.TrimSpace(string(publicKeyBytes)), nil +} + +// CreateClientConfig creates the client VPN configuration file. 
+func (v *VPNClient) CreateClientConfig(privateKey string, gatewayPort int) error { + configPath := fmt.Sprintf("/etc/wireguard/%s.conf", v.config.NetworkInterface) + + config := fmt.Sprintf(`[Interface] +PrivateKey = %s +Address = %s/24 + +[Peer] +PublicKey = %s +Endpoint = %s:%d +AllowedIPs = 10.0.0.0/8, 172.16.0.0/24 +PersistentKeepalive = 25 +`, privateKey, v.config.ClientVPNIP, v.config.ServerPublicKey, v.config.ServerEndpoint, gatewayPort) + + if err := WriteFileContent(configPath, config, 0600); err != nil { + return fmt.Errorf("failed to create VPN client config: %w", err) + } + + return nil +} + +// Start starts the VPN connection. +func (v *VPNClient) Start(ctx context.Context) error { + _ = v.Stop(ctx) + + _, err := utils.RunCommandWithOutputContext(ctx, "wg-quick", "up", v.config.NetworkInterface) + if err != nil { + return fmt.Errorf("failed to start VPN: %w", err) + } + + return nil +} + +// Stop stops the VPN connection. +func (v *VPNClient) Stop(ctx context.Context) error { + _, _ = utils.RunCommandWithOutputContext(ctx, "wg-quick", "down", v.config.NetworkInterface) + return nil +} + +// TestConnection tests VPN connectivity by pinging the gateway. +func (v *VPNClient) TestConnection(ctx context.Context) bool { + return utils.RunCommandSilentContext(ctx, "ping", "-c", "1", "-W", "3", v.config.GatewayVPNIP) +} + +// RemoveClientConfig removes the client VPN configuration file. +func (v *VPNClient) RemoveClientConfig() error { + configPath := fmt.Sprintf("/etc/wireguard/%s.conf", v.config.NetworkInterface) + if utils.FileExists(configPath) { + cmd := exec.Command("rm", "-f", configPath) // #nosec G204 -- fixed rm command + return cmd.Run() + } + return nil +} + +// GetClientConfigInfo reads the current client config and returns Gateway IP and private key. 
+func (v *VPNClient) GetClientConfigInfo() (gatewayIP, privateKey string, err error) { + configPath := fmt.Sprintf("/etc/wireguard/%s.conf", v.config.NetworkInterface) + + content, err := ReadFileContent(configPath) + if err != nil { + return "", "", fmt.Errorf("failed to read VPN config: %w", err) + } + + for _, line := range strings.Split(content, "\n") { + line = strings.TrimSpace(line) + if strings.HasPrefix(line, "Endpoint") { + parts := strings.SplitN(line, "=", 2) + if len(parts) == 2 { + endpoint := strings.TrimSpace(parts[1]) + gatewayIP = strings.Split(endpoint, ":")[0] + } + } + if strings.HasPrefix(line, "PrivateKey") { + parts := strings.SplitN(line, "=", 2) + if len(parts) == 2 { + privateKey = strings.TrimSpace(parts[1]) + } + } + } + + return gatewayIP, privateKey, nil +} + +// GetPublicKeyFromPrivate derives public key from private key. +func (v *VPNClient) GetPublicKeyFromPrivate(ctx context.Context, privateKey string) (string, error) { + cmd := exec.CommandContext(ctx, "wg", "pubkey") // #nosec G204 -- fixed wg command + cmd.Stdin = strings.NewReader(privateKey) + output, err := cmd.Output() + if err != nil { + return "", fmt.Errorf("failed to derive public key: %w", err) + } + return strings.TrimSpace(string(output)), nil +} + +// VPNServerManager manages VPN server on the Gateway. +type VPNServerManager struct { + ssh *SSHClient + logger *logrus.Logger +} + +// NewVPNServerManager creates a new VPNServerManager instance. +func NewVPNServerManager(ssh *SSHClient, logger *logrus.Logger) *VPNServerManager { + return &VPNServerManager{ + ssh: ssh, + logger: logger, + } +} + +// IsInstalled checks if VPN software is installed on the server. +func (m *VPNServerManager) IsInstalled(ctx context.Context) bool { + return m.ssh.CommandExists(ctx, "wg") +} + +// Install installs and configures VPN server. 
+func (m *VPNServerManager) Install(ctx context.Context) error { + script := `set -e + +# Install WireGuard +sudo apt-get update +sudo apt-get install -y wireguard + +# Generate key pair +sudo wg genkey | sudo tee /etc/wireguard/server_private.key | sudo wg pubkey | sudo tee /etc/wireguard/server_public.key +sudo chmod 600 /etc/wireguard/server_private.key + +SERVER_PRIVATE_KEY=$(sudo cat /etc/wireguard/server_private.key) + +# Create configuration +sudo tee /etc/wireguard/wg0.conf << EOF +[Interface] +PrivateKey = ${SERVER_PRIVATE_KEY} +Address = 172.16.0.1/24 +ListenPort = 51820 +PostUp = iptables -A FORWARD -i wg0 -j ACCEPT; iptables -t nat -A POSTROUTING -o eth0 -j MASQUERADE +PostDown = iptables -D FORWARD -i wg0 -j ACCEPT; iptables -t nat -D POSTROUTING -o eth0 -j MASQUERADE +EOF + +# Enable IP forwarding +echo 'net.ipv4.ip_forward=1' | sudo tee -a /etc/sysctl.conf +sudo sysctl -p + +# Start VPN service +sudo systemctl enable wg-quick@wg0 +sudo systemctl start wg-quick@wg0 || sudo systemctl restart wg-quick@wg0 + +echo "VPN server configuration complete" +` + return m.ssh.ExecuteScript(ctx, script) +} + +// GetPublicKey retrieves the server's public key. +func (m *VPNServerManager) GetPublicKey(ctx context.Context) (string, error) { + key, err := m.ssh.ReadRemoteFile(ctx, "/etc/wireguard/server_public.key") + if err != nil || key == "" { + return "", fmt.Errorf("failed to get server public key") + } + return strings.TrimSpace(key), nil +} + +// GetPeerCount returns the number of existing peers. +func (m *VPNServerManager) GetPeerCount(ctx context.Context) (int, error) { + output, err := m.ssh.Execute(ctx, "sudo wg show wg0 peers 2>/dev/null | wc -l || echo 0") + if err != nil { + return 0, nil // Default to 0 if error + } + count, _ := strconv.Atoi(strings.TrimSpace(output)) + return count, nil +} + +// AddPeer adds a client peer to the server. 
+func (m *VPNServerManager) AddPeer(ctx context.Context, clientPublicKey, clientIP string) error { + cmd := fmt.Sprintf("sudo wg set wg0 peer '%s' allowed-ips %s/32", clientPublicKey, clientIP) + if _, err := m.ssh.Execute(ctx, cmd); err != nil { + return fmt.Errorf("failed to add peer: %w", err) + } + + if _, err := m.ssh.Execute(ctx, "sudo wg-quick save wg0"); err != nil { + return fmt.Errorf("failed to save VPN config: %w", err) + } + + return nil +} + +// RemovePeer removes a client peer from the server. +func (m *VPNServerManager) RemovePeer(ctx context.Context, clientPublicKey string) error { + cmd := fmt.Sprintf("sudo wg set wg0 peer '%s' remove && sudo wg-quick save wg0", clientPublicKey) + _, _ = m.ssh.Execute(ctx, cmd) + return nil +} + +// ResolveDNS resolves a hostname through the Gateway. +func (m *VPNServerManager) ResolveDNS(ctx context.Context, hostname string) (string, error) { + cmd := fmt.Sprintf("nslookup %s | grep -A1 'Name:' | grep 'Address:' | awk '{print $2}'", hostname) + output, err := m.ssh.Execute(ctx, cmd) + if err != nil || output == "" { + return "", fmt.Errorf("failed to resolve %s through Gateway", hostname) + } + return strings.TrimSpace(output), nil +} + +// InstallVPNTools installs VPN tools locally. +func InstallVPNTools(ctx context.Context, logger *logrus.Logger) error { + if CommandExists("wg") { + return nil + } + if _, err := utils.RunCommandWithOutputContext(ctx, "apt-get", "update"); err != nil { + return err + } + _, err := utils.RunCommandWithOutputContext(ctx, "apt-get", "install", "-y", "wireguard-tools") + return err +} diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index 3f35941..093d3c7 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -32,6 +32,19 @@ func RunCommandWithOutput(name string, args ...string) (string, error) { return string(output), err } +// RunCommandWithOutputContext executes a command with context and returns its combined output. 
+func RunCommandWithOutputContext(ctx context.Context, name string, args ...string) (string, error) { + cmd := exec.CommandContext(ctx, name, args...) // #nosec G204 -- same pattern as RunCommandWithOutput + output, err := cmd.CombinedOutput() + return string(output), err +} + +// RunCommandSilentContext executes a command with context and returns only whether it succeeded. +func RunCommandSilentContext(ctx context.Context, name string, args ...string) bool { + cmd := exec.CommandContext(ctx, name, args...) // #nosec G204 -- same pattern as RunSystemCommand + return cmd.Run() == nil +} + // FileExists checks if a file exists func FileExists(path string) bool { _, err := os.Stat(path)