diff --git a/tools/Azure.Mcp.Tools.Sql/src/Commands/Server/ServerConnPolicyShowCommand.cs b/tools/Azure.Mcp.Tools.Sql/src/Commands/Server/ServerConnPolicyShowCommand.cs new file mode 100644 index 0000000000..90e086516f --- /dev/null +++ b/tools/Azure.Mcp.Tools.Sql/src/Commands/Server/ServerConnPolicyShowCommand.cs @@ -0,0 +1,94 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Net; +using Azure.Mcp.Core.Commands; +using Azure.Mcp.Tools.Sql.Models; +using Azure.Mcp.Tools.Sql.Options.Server; +using Azure.Mcp.Tools.Sql.Services; +using Microsoft.Extensions.Logging; + +namespace Azure.Mcp.Tools.Sql.Commands.Server; + +public sealed class ServerConnPolicyShowCommand(ILogger logger) + : BaseSqlCommand(logger) +{ + private const string CommandTitle = "Show SQL Server Connection Policy"; + + public override string Name => "show"; + + public override string Description => + """ + Retrieves the connection policy for an Azure SQL Server. The connection policy determines + how clients connect to the SQL server and can be one of: Default (uses Azure defaults), + Proxy (all connections are proxied through Azure gateway), or Redirect (connections are + redirected directly to the database node). + """; + + public override string Title => CommandTitle; + + public override ToolMetadata Metadata => new() + { + Destructive = false, + Idempotent = true, + OpenWorld = false, + ReadOnly = true, + LocalRequired = false, + Secret = false + }; + + public override async Task ExecuteAsync(CommandContext context, ParseResult parseResult) + { + if (!Validate(parseResult.CommandResult, context.Response).IsValid) + { + return context.Response; + } + + var options = BindOptions(parseResult); + + try + { + var sqlService = context.GetService(); + + var connectionPolicy = await sqlService.GetServerConnectionPolicyAsync( + options.Server!, + options.ResourceGroup!, + options.Subscription!, + options.RetryPolicy); + + context.Response.Results = ResponseResult.Create(new(connectionPolicy), SqlJsonContext.Default.ServerConnPolicyShowResult); + } + catch (Exception ex) + { + _logger.LogError(ex, + "Error retrieving SQL server connection policy. Server: {Server}, ResourceGroup: {ResourceGroup}, Options: {@Options}", + options.Server, options.ResourceGroup, options); + HandleException(context, ex); + } + + return context.Response; + } + + protected override string GetErrorMessage(Exception ex) => ex switch + { + KeyNotFoundException => + "SQL server connection policy not found. Verify the server name and resource group.", + RequestFailedException reqEx when reqEx.Status == (int)HttpStatusCode.NotFound => + "SQL server or connection policy not found. Verify the server name and resource group.", + RequestFailedException reqEx when reqEx.Status == (int)HttpStatusCode.Forbidden => + $"Authorization failed retrieving the SQL server connection policy. Verify you have appropriate permissions. Details: {reqEx.Message}", + RequestFailedException reqEx => reqEx.Message, + ArgumentException argEx => $"Invalid parameter: {argEx.Message}", + _ => base.GetErrorMessage(ex) + }; + + protected override HttpStatusCode GetStatusCode(Exception ex) => ex switch + { + KeyNotFoundException => HttpStatusCode.NotFound, + RequestFailedException reqEx => (HttpStatusCode)reqEx.Status, + ArgumentException => HttpStatusCode.BadRequest, + _ => base.GetStatusCode(ex) + }; + + internal record ServerConnPolicyShowResult(SqlServerConnectionPolicy ConnectionPolicy); +} diff --git a/tools/Azure.Mcp.Tools.Sql/src/Commands/SqlJsonContext.cs b/tools/Azure.Mcp.Tools.Sql/src/Commands/SqlJsonContext.cs index db75719e05..a515ca1f59 100644 --- a/tools/Azure.Mcp.Tools.Sql/src/Commands/SqlJsonContext.cs +++ b/tools/Azure.Mcp.Tools.Sql/src/Commands/SqlJsonContext.cs @@ -26,9 +26,11 @@ namespace Azure.Mcp.Tools.Sql.Commands; [JsonSerializable(typeof(ServerDeleteCommand.ServerDeleteResult))] [JsonSerializable(typeof(ServerListCommand.ServerListResult))] [JsonSerializable(typeof(ServerShowCommand.ServerShowResult))] +[JsonSerializable(typeof(ServerConnPolicyShowCommand.ServerConnPolicyShowResult))] [JsonSerializable(typeof(ElasticPoolListCommand.ElasticPoolListResult))] [JsonSerializable(typeof(SqlDatabase))] [JsonSerializable(typeof(SqlServer))] +[JsonSerializable(typeof(SqlServerConnectionPolicy))] [JsonSerializable(typeof(SqlServerEntraAdministrator))] [JsonSerializable(typeof(SqlServerFirewallRule))] [JsonSerializable(typeof(SqlElasticPool))] diff --git a/tools/Azure.Mcp.Tools.Sql/src/Models/SqlServerConnectionPolicy.cs b/tools/Azure.Mcp.Tools.Sql/src/Models/SqlServerConnectionPolicy.cs new file mode 100644 index 0000000000..c481ca9954 --- /dev/null +++ b/tools/Azure.Mcp.Tools.Sql/src/Models/SqlServerConnectionPolicy.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Azure.Mcp.Tools.Sql.Models; + +public record SqlServerConnectionPolicy( + string Name, + string Id, + string Type, + string ConnectionType); diff --git a/tools/Azure.Mcp.Tools.Sql/src/Options/Server/ServerConnPolicyShowOptions.cs b/tools/Azure.Mcp.Tools.Sql/src/Options/Server/ServerConnPolicyShowOptions.cs new file mode 100644 index 0000000000..75d6836355 --- /dev/null +++ b/tools/Azure.Mcp.Tools.Sql/src/Options/Server/ServerConnPolicyShowOptions.cs @@ -0,0 +1,8 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Azure.Mcp.Tools.Sql.Options.Server; + +public class ServerConnPolicyShowOptions : BaseSqlOptions +{ +} diff --git a/tools/Azure.Mcp.Tools.Sql/src/Services/ISqlService.cs b/tools/Azure.Mcp.Tools.Sql/src/Services/ISqlService.cs index ac8cd3cece..e055ec1e9c 100644 --- a/tools/Azure.Mcp.Tools.Sql/src/Services/ISqlService.cs +++ b/tools/Azure.Mcp.Tools.Sql/src/Services/ISqlService.cs @@ -309,4 +309,20 @@ Task DeleteServerAsync( string subscription, RetryPolicyOptions? retryPolicy, CancellationToken cancellationToken = default); + + /// + /// Gets the connection policy for a SQL server. + /// + /// The name of the SQL server + /// The name of the resource group + /// The subscription ID or name + /// Optional retry policy options + /// Cancellation token + /// The SQL server connection policy information + Task GetServerConnectionPolicyAsync( + string serverName, + string resourceGroup, + string subscription, + RetryPolicyOptions? retryPolicy, + CancellationToken cancellationToken = default); } diff --git a/tools/Azure.Mcp.Tools.Sql/src/Services/SqlService.cs b/tools/Azure.Mcp.Tools.Sql/src/Services/SqlService.cs index f5360caaf1..45d7d5a4aa 100644 --- a/tools/Azure.Mcp.Tools.Sql/src/Services/SqlService.cs +++ b/tools/Azure.Mcp.Tools.Sql/src/Services/SqlService.cs @@ -1109,4 +1109,68 @@ private static SqlServerFirewallRule ConvertToSqlFirewallRuleModel(JsonElement i EndIpAddress: firewallRule.Properties?.EndIPAddress ); } + + private static SqlServerConnectionPolicy ConvertToSqlServerConnectionPolicyModel(JsonElement item) + { + var nameValue = item.GetProperty("name").GetString() ?? "Unknown"; + var idValue = item.GetProperty("id").GetString() ?? "Unknown"; + var typeValue = item.GetProperty("type").GetString() ?? "Unknown"; + var connectionTypeValue = "Default"; + + if (item.TryGetProperty("properties", out var properties)) + { + if (properties.TryGetProperty("connectionType", out var connectionType)) + { + connectionTypeValue = connectionType.GetString() ?? "Default"; + } + } + + return new SqlServerConnectionPolicy( + Name: nameValue, + Id: idValue, + Type: typeValue, + ConnectionType: connectionTypeValue); + } + + /// + /// Retrieves the connection policy for an Azure SQL Server. + /// The connection policy determines how clients connect to the SQL server (Default, Proxy, or Redirect). + /// + /// The name of the SQL server to get the connection policy for + /// The name of the resource group containing the server + /// The subscription ID or name + /// Optional retry policy configuration for resilient operations + /// Token to observe for cancellation requests + /// The SQL server connection policy information + /// Thrown when required parameters are null or empty + public async Task GetServerConnectionPolicyAsync( + string serverName, + string resourceGroup, + string subscription, + RetryPolicyOptions? retryPolicy, + CancellationToken cancellationToken = default) + { + ValidateRequiredParameters( + (nameof(serverName), serverName), + (nameof(resourceGroup), resourceGroup), + (nameof(subscription), subscription) + ); + + // Use Resource Graph to query connection policy + var result = await ExecuteSingleResourceQueryAsync( + "Microsoft.Sql/servers/connectionPolicies", + resourceGroup: resourceGroup, + subscription: subscription, + retryPolicy: retryPolicy, + converter: ConvertToSqlServerConnectionPolicyModel, + additionalFilter: $"name =~ '{EscapeKqlString(serverName)}/default'", + cancellationToken: cancellationToken); + + if (result == null) + { + throw new KeyNotFoundException($"Connection policy not found for SQL server '{serverName}' in resource group '{resourceGroup}'."); + } + + return result; + } } diff --git a/tools/Azure.Mcp.Tools.Sql/src/SqlSetup.cs b/tools/Azure.Mcp.Tools.Sql/src/SqlSetup.cs index c3e0656baf..0f569f44f2 100644 --- a/tools/Azure.Mcp.Tools.Sql/src/SqlSetup.cs +++ b/tools/Azure.Mcp.Tools.Sql/src/SqlSetup.cs @@ -32,6 +32,7 @@ public void ConfigureServices(IServiceCollection services) services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); + services.AddSingleton(); services.AddSingleton(); @@ -74,6 +75,12 @@ public CommandGroup RegisterCommands(IServiceProvider serviceProvider) var serverShow = serviceProvider.GetRequiredService(); server.AddCommand(serverShow.Name, serverShow); + var connPolicy = new CommandGroup("conn-policy", "SQL server connection policy operations"); + server.AddSubGroup(connPolicy); + + var connPolicyShow = serviceProvider.GetRequiredService(); + connPolicy.AddCommand(connPolicyShow.Name, connPolicyShow); + var elasticPool = new CommandGroup("elastic-pool", "SQL elastic pool operations"); sql.AddSubGroup(elasticPool); var elasticPoolList = serviceProvider.GetRequiredService(); diff --git a/tools/Azure.Mcp.Tools.Sql/tests/Azure.Mcp.Tools.Sql.UnitTests/Server/ServerConnPolicyShowCommandTests.cs b/tools/Azure.Mcp.Tools.Sql/tests/Azure.Mcp.Tools.Sql.UnitTests/Server/ServerConnPolicyShowCommandTests.cs new file mode 100644 index 0000000000..2492db9b3c --- /dev/null +++ b/tools/Azure.Mcp.Tools.Sql/tests/Azure.Mcp.Tools.Sql.UnitTests/Server/ServerConnPolicyShowCommandTests.cs @@ -0,0 +1,255 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.CommandLine; +using System.Net; +using Azure; +using Azure.Mcp.Core.Models.Command; +using Azure.Mcp.Core.Options; +using Azure.Mcp.Tools.Sql.Commands.Server; +using Azure.Mcp.Tools.Sql.Models; +using Azure.Mcp.Tools.Sql.Services; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using NSubstitute; +using NSubstitute.ExceptionExtensions; +using Xunit; + +namespace Azure.Mcp.Tools.Sql.UnitTests.Server; + +public class ServerConnPolicyShowCommandTests +{ + private readonly IServiceProvider _serviceProvider; + private readonly ISqlService _sqlService; + private readonly ILogger _logger; + private readonly ServerConnPolicyShowCommand _command; + private readonly CommandContext _context; + private readonly Command _commandDefinition; + + public ServerConnPolicyShowCommandTests() + { + _sqlService = Substitute.For(); + _logger = Substitute.For>(); + + var collection = new ServiceCollection(); + collection.AddSingleton(_sqlService); + _serviceProvider = collection.BuildServiceProvider(); + + _command = new(_logger); + _context = new(_serviceProvider); + _commandDefinition = _command.GetCommand(); + } + + [Fact] + public void Constructor_InitializesCommandCorrectly() + { + var command = _command.GetCommand(); + Assert.Equal("show", command.Name); + Assert.NotNull(command.Description); + Assert.Contains("connection policy", command.Description, StringComparison.OrdinalIgnoreCase); + } + + [Fact] + public async Task ExecuteAsync_WithValidParameters_ReturnsConnectionPolicy() + { + // Arrange + var mockConnectionPolicy = new SqlServerConnectionPolicy( + Name: "default", + Id: "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Sql/servers/server1/connectionPolicies/default", + Type: "Microsoft.Sql/servers/connectionPolicies", + ConnectionType: "Default"); + + _sqlService.GetServerConnectionPolicyAsync( + Arg.Is("server1"), + Arg.Is("rg"), + Arg.Is("sub"), + Arg.Any(), + Arg.Any()) + .Returns(mockConnectionPolicy); + + var args = _commandDefinition.Parse([ + "--subscription", "sub", + "--resource-group", "rg", + "--server", "server1" + ]); + + // Act + var response = await _command.ExecuteAsync(_context, args); + + // Assert + Assert.NotNull(response); + Assert.Equal(HttpStatusCode.OK, response.Status); + Assert.NotNull(response.Results); + Assert.Equal("Success", response.Message); + } + + [Fact] + public async Task ExecuteAsync_HandlesNotFound() + { + // Arrange + _sqlService.GetServerConnectionPolicyAsync( + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any()) + .ThrowsAsync(new KeyNotFoundException("Connection policy not found")); + + var args = _commandDefinition.Parse([ + "--subscription", "sub", + "--resource-group", "rg", + "--server", "missing" + ]); + + // Act + var response = await _command.ExecuteAsync(_context, args); + + // Assert + Assert.Equal(HttpStatusCode.NotFound, response.Status); + Assert.Contains("not found", response.Message, StringComparison.OrdinalIgnoreCase); + } + + [Fact] + public async Task ExecuteAsync_HandlesForbidden() + { + // Arrange + var forbiddenException = new RequestFailedException((int)HttpStatusCode.Forbidden, "Forbidden"); + _sqlService.GetServerConnectionPolicyAsync( + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any()) + .ThrowsAsync(forbiddenException); + + var args = _commandDefinition.Parse([ + "--subscription", "sub", + "--resource-group", "rg", + "--server", "server1" + ]); + + // Act + var response = await _command.ExecuteAsync(_context, args); + + // Assert + Assert.Equal(HttpStatusCode.Forbidden, response.Status); + Assert.Contains("Authorization failed", response.Message); + } + + [Fact] + public async Task ExecuteAsync_HandlesGeneralException() + { + // Arrange + _sqlService.GetServerConnectionPolicyAsync( + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any()) + .ThrowsAsync(new Exception("Unexpected error")); + + var args = _commandDefinition.Parse([ + "--subscription", "sub", + "--resource-group", "rg", + "--server", "server1" + ]); + + // Act + var response = await _command.ExecuteAsync(_context, args); + + // Assert + Assert.Equal(HttpStatusCode.InternalServerError, response.Status); + Assert.Contains("Unexpected error", response.Message); + Assert.Contains("troubleshooting", response.Message); + } + + [Theory] + [InlineData("", false, "Missing required options")] + [InlineData("--subscription sub", false, "Missing required options")] + [InlineData("--subscription sub --resource-group rg --server server1", true, null)] + [InlineData("--resource-group rg --server server1", false, "Missing required options")] // Missing subscription + [InlineData("--subscription sub --server server1", false, "Missing required options")] // Missing resource-group + [InlineData("--subscription sub --resource-group rg", false, "Missing required options")] // Missing server + public async Task ExecuteAsync_ValidatesRequiredParameters(string commandArgs, bool shouldSucceed, string? expectedError) + { + // Arrange + if (shouldSucceed) + { + var mockConnectionPolicy = new SqlServerConnectionPolicy( + Name: "default", + Id: "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Sql/servers/server1/connectionPolicies/default", + Type: "Microsoft.Sql/servers/connectionPolicies", + ConnectionType: "Default"); + + _sqlService.GetServerConnectionPolicyAsync( + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any()) + .Returns(mockConnectionPolicy); + } + + var args = _commandDefinition.Parse(commandArgs.Split(' ', StringSplitOptions.RemoveEmptyEntries)); + + // Act + var response = await _command.ExecuteAsync(_context, args); + + // Assert + if (shouldSucceed) + { + Assert.Equal(HttpStatusCode.OK, response.Status); + } + else + { + Assert.NotEqual(HttpStatusCode.OK, response.Status); + if (expectedError != null) + { + Assert.Contains(expectedError, response.Message, StringComparison.OrdinalIgnoreCase); + } + } + } + + [Fact] + public async Task ExecuteAsync_WithSubscriptionFromEnvironment_Succeeds() + { + // Arrange - Test when subscription comes from environment variable + Environment.SetEnvironmentVariable("AZURE_SUBSCRIPTION_ID", "env-sub-id"); + + var mockConnectionPolicy = new SqlServerConnectionPolicy( + Name: "default", + Id: "/subscriptions/env-sub-id/resourceGroups/rg/providers/Microsoft.Sql/servers/server1/connectionPolicies/default", + Type: "Microsoft.Sql/servers/connectionPolicies", + ConnectionType: "Default"); + + _sqlService.GetServerConnectionPolicyAsync( + Arg.Is("server1"), + Arg.Is("rg"), + Arg.Is("env-sub-id"), + Arg.Any(), + Arg.Any()) + .Returns(mockConnectionPolicy); + + try + { + var args = _commandDefinition.Parse([ + "--resource-group", "rg", + "--server", "server1" + ]); + + // Act + var response = await _command.ExecuteAsync(_context, args); + + // Assert + Assert.NotNull(response); + Assert.Equal(HttpStatusCode.OK, response.Status); + Assert.NotNull(response.Results); + Assert.Equal("Success", response.Message); + } + finally + { + // Clean up environment variable + Environment.SetEnvironmentVariable("AZURE_SUBSCRIPTION_ID", null); + } + } +}