diff --git a/core/Microsoft.Mcp.Core/src/Helpers/EndpointValidator.cs b/core/Microsoft.Mcp.Core/src/Helpers/EndpointValidator.cs new file mode 100644 index 0000000000..95a1d4983d --- /dev/null +++ b/core/Microsoft.Mcp.Core/src/Helpers/EndpointValidator.cs @@ -0,0 +1,332 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Net; +using System.Net.Sockets; +using System.Security; +using Azure.ResourceManager; + +namespace Microsoft.Mcp.Core.Helpers; + +/// +/// Validates Azure service endpoints. +/// +public static class EndpointValidator +{ + /// + /// Gets the allowed domain suffixes for Azure services based on the cloud environment. + /// + /// The ARM environment (cloud) to get suffixes for. + /// Dictionary mapping service types to their allowed domain suffixes. + private static Dictionary GetAllowedDomainSuffixes(ArmEnvironment armEnvironment) + { + // Determine which cloud we're in + var isPublicCloud = armEnvironment.Equals(ArmEnvironment.AzurePublicCloud); + var isChinaCloud = armEnvironment.Equals(ArmEnvironment.AzureChina); + var isGovCloud = armEnvironment.Equals(ArmEnvironment.AzureGovernment); + + // Build cloud-specific suffixes for services that require validation + var acrSuffix = isPublicCloud ? "azurecr.io" + : isChinaCloud ? "azurecr.cn" + : isGovCloud ? "azurecr.us" + : "azurecr.io"; + + var appConfigSuffix = isPublicCloud ? "azconfig.io" + : isChinaCloud ? "azconfig.azure.cn" + : isGovCloud ? "azconfig.azure.us" + : "azconfig.io"; + + var commSuffix = isPublicCloud ? "communication.azure.com" + : isChinaCloud ? "communication.azure.cn" + : isGovCloud ? "communication.azure.us" + : "communication.azure.com"; + + return new Dictionary + { + // Azure Communication Services + { "communication", [$".{commSuffix}"] }, + + // Azure App Configuration + { "appconfig", [$".{appConfigSuffix}"] }, + + // Azure Container Registry + { "acr", [$".{acrSuffix}"] }, + }; + } + + /// + /// Validates that an endpoint belongs to an allowed Azure service domain. + /// Uses Azure Public Cloud domains by default. + /// + /// The endpoint URL to validate. + /// The type of Azure service (e.g., "storage-blob", "keyvault"). + public static void ValidateAzureServiceEndpoint(string endpoint, string serviceType) + { + ValidateAzureServiceEndpoint(endpoint, serviceType, ArmEnvironment.AzurePublicCloud); + } + + /// + /// Validates that an endpoint belongs to an allowed Azure service domain for the specified cloud environment. + /// + /// The endpoint URL to validate. + /// The type of Azure service (e.g., "storage-blob", "keyvault"). + /// The Azure cloud environment (Public, China, Government, etc.). + public static void ValidateAzureServiceEndpoint(string endpoint, string serviceType, ArmEnvironment armEnvironment) + { + if (string.IsNullOrWhiteSpace(endpoint)) + { + throw new ArgumentException("Endpoint cannot be null or empty", nameof(endpoint)); + } + + if (!Uri.TryCreate(endpoint, UriKind.Absolute, out var uri)) + { + throw new SecurityException($"Invalid endpoint format: {endpoint}"); + } + + // Ensure HTTPS + if (!uri.Scheme.Equals("https", StringComparison.OrdinalIgnoreCase)) + { + throw new SecurityException( + $"Endpoint must use HTTPS protocol. Got: {uri.Scheme}"); + } + + var allowedDomainSuffixes = GetAllowedDomainSuffixes(armEnvironment); + + if (!allowedDomainSuffixes.TryGetValue(serviceType, out var allowedSuffixes)) + { + throw new ArgumentException($"Unknown service type: {serviceType}", nameof(serviceType)); + } + + // Validate domain: must exactly match suffix or be a proper subdomain + var isValid = allowedSuffixes.Any(suffix => + { + // Exact match (e.g., "azconfig.io") + if (uri.Host.Equals(suffix.TrimStart('.'), StringComparison.OrdinalIgnoreCase)) + return true; + + // Proper subdomain match (e.g., "myconfig.azconfig.io" matches ".azconfig.io") + // Ensure the suffix starts with a dot, then check if host ends with it + if (suffix.StartsWith('.') && uri.Host.EndsWith(suffix, StringComparison.OrdinalIgnoreCase)) + { + // Ensure there's a subdomain portion and it doesn't contain path separators + // This prevents path components from being interpreted as subdomains (e.g., "azconfig.io/evil") + // Note: Multi-level subdomains like "sub.myconfig.azconfig.io" are valid and allowed + var domainBeforeSuffix = uri.Host.Substring(0, uri.Host.Length - suffix.Length); + return !string.IsNullOrEmpty(domainBeforeSuffix) && !domainBeforeSuffix.Contains('/'); + } + + return false; + }); + + if (!isValid) + { + var cloudName = armEnvironment.Equals(ArmEnvironment.AzurePublicCloud) ? "Azure Public Cloud" + : armEnvironment.Equals(ArmEnvironment.AzureChina) ? "Azure China Cloud" + : armEnvironment.Equals(ArmEnvironment.AzureGovernment) ? "Azure US Government Cloud" + : "configured Azure cloud"; + + throw new SecurityException( + $"Endpoint host '{uri.Host}' is not a valid {serviceType} domain for {cloudName}. " + + $"Expected domains: {string.Join(", ", allowedSuffixes)}"); + } + } + + /// + /// Validates that a URL is from an allowed external domain (GitHub, etc.) + /// + public static void ValidateExternalUrl(string url, string[] allowedHosts) + { + if (string.IsNullOrWhiteSpace(url)) + { + throw new ArgumentException("URL cannot be null or empty", nameof(url)); + } + + if (!Uri.TryCreate(url, UriKind.Absolute, out var uri)) + { + throw new SecurityException($"Invalid URL format: {url}"); + } + + // Ensure HTTPS for external URLs + if (!uri.Scheme.Equals("https", StringComparison.OrdinalIgnoreCase)) + { + throw new SecurityException( + $"External URL must use HTTPS protocol. Got: {uri.Scheme}"); + } + + var isAllowed = allowedHosts.Any(host => + uri.Host.Equals(host, StringComparison.OrdinalIgnoreCase)); + + if (!isAllowed) + { + throw new SecurityException( + $"URL host '{uri.Host}' is not in the allowed list. " + + $"Allowed hosts: {string.Join(", ", allowedHosts)}"); + } + } + + /// + /// Validates that a target URL (for load testing, etc.) isn't pointing to internal resources. + /// Performs DNS resolution to detect hostnames that resolve to private/reserved IPs. + /// + public static void ValidatePublicTargetUrl(string url) + { + if (string.IsNullOrWhiteSpace(url)) + { + throw new ArgumentException("URL cannot be null or empty", nameof(url)); + } + + if (!Uri.TryCreate(url, UriKind.Absolute, out var uri)) + { + throw new SecurityException($"Invalid URL format: {url}"); + } + + // Check if host is a literal IP address + if (IPAddress.TryParse(uri.Host, out var ipAddress)) + { + if (IsPrivateOrReservedIP(ipAddress)) + { + throw new SecurityException( + $"Target URL '{url}' uses a private or reserved IP address ({ipAddress}). " + + "Targeting internal endpoints is not permitted."); + } + } + else + { + // Check for reserved hostnames (catches localhost variations) + var reservedHosts = new[] + { + "localhost", + "local", + "localtest.me", // Common localhost alias + "lvh.me", // localhost variations + "169.254.169.254.nip.io" // IMDS bypass attempt + }; + + if (reservedHosts.Any(reserved => + uri.Host.Equals(reserved, StringComparison.OrdinalIgnoreCase) || + uri.Host.EndsWith($".{reserved}", StringComparison.OrdinalIgnoreCase))) + { + throw new SecurityException( + $"Target URL hostname '{uri.Host}' is reserved and cannot be targeted."); + } + + // Resolve DNS and validate all resolved IPs + try + { + var hostEntry = Dns.GetHostEntry(uri.Host); + foreach (var resolvedIp in hostEntry.AddressList) + { + if (IsPrivateOrReservedIP(resolvedIp)) + { + throw new SecurityException( + $"Target URL hostname '{uri.Host}' resolves to a private or reserved IP address ({resolvedIp}). " + + "Targeting internal endpoints is not permitted."); + } + } + } + catch (SecurityException) + { + throw; // Re-throw SecurityException from private IP check + } + catch (Exception ex) + { + // DNS resolution failure - treat as invalid for security + throw new SecurityException( + $"Unable to resolve hostname '{uri.Host}' for security validation. " + + $"Ensure the hostname is publicly resolvable. Details: {ex.Message}"); + } + } + } + + /// + /// Checks if an IP address is private, reserved, or otherwise non-routable + /// + public static bool IsPrivateOrReservedIP(IPAddress ipAddress) + { + var bytes = ipAddress.GetAddressBytes(); + + if (ipAddress.AddressFamily == AddressFamily.InterNetwork) + { + // Loopback: 127.0.0.0/8 + if (bytes[0] == 127) + { + return true; + } + + // Private: 10.0.0.0/8 + if (bytes[0] == 10) + { + return true; + } + + // Private: 172.16.0.0/12 + if (bytes[0] == 172 && bytes[1] >= 16 && bytes[1] <= 31) + { + return true; + } + + // Private: 192.168.0.0/16 + if (bytes[0] == 192 && bytes[1] == 168) + { + return true; + } + + // Link-local: 169.254.0.0/16 (includes IMDS at 169.254.169.254) + if (bytes[0] == 169 && bytes[1] == 254) + { + return true; + } + + // WireServer: 168.63.129.16 + if (bytes[0] == 168 && bytes[1] == 63 && bytes[2] == 129 && bytes[3] == 16) + { + return true; + } + + // Shared address space: 100.64.0.0/10 + if (bytes[0] == 100 && bytes[1] >= 64 && bytes[1] <= 127) + { + return true; + } + + // Broadcast: 255.255.255.255 + if (bytes[0] == 255 && bytes[1] == 255 && bytes[2] == 255 && bytes[3] == 255) + { + return true; + } + + // Reserved ranges + if (bytes[0] == 0) + { + return true; // 0.0.0.0/8 + } + + if (bytes[0] >= 224) + { + return true; // Multicast (224.0.0.0/4) and Reserved (240.0.0.0/4) + } + } + else if (ipAddress.AddressFamily == AddressFamily.InterNetworkV6) + { + // Loopback: ::1 + if (ipAddress.Equals(IPAddress.IPv6Loopback)) + { + return true; + } + + // Private: fc00::/7 + if ((bytes[0] & 0xfe) == 0xfc) + { + return true; + } + + // Link-local: fe80::/10 + if (bytes[0] == 0xfe && (bytes[1] & 0xc0) == 0x80) + { + return true; + } + } + + return false; + } +} diff --git a/core/Microsoft.Mcp.Core/tests/Microsoft.Mcp.Core.UnitTests/Helpers/EndpointValidatorTests.cs b/core/Microsoft.Mcp.Core/tests/Microsoft.Mcp.Core.UnitTests/Helpers/EndpointValidatorTests.cs new file mode 100644 index 0000000000..027c7cd683 --- /dev/null +++ b/core/Microsoft.Mcp.Core/tests/Microsoft.Mcp.Core.UnitTests/Helpers/EndpointValidatorTests.cs @@ -0,0 +1,429 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Security; +using Azure.ResourceManager; +using Microsoft.Mcp.Core.Helpers; +using Xunit; + +namespace Microsoft.Mcp.Core.UnitTests.Helpers; + +public class EndpointValidatorTests +{ + #region ValidateAzureServiceEndpoint Tests + + [Theory] + [InlineData("https://mycomm.communication.azure.com", "communication")] + [InlineData("https://myconfig.azconfig.io", "appconfig")] + [InlineData("https://myregistry.azurecr.io", "acr")] + public void ValidateAzureServiceEndpoint_ValidEndpoints_DoesNotThrow(string endpoint, string serviceType) + { + // Act & Assert + var exception = Record.Exception(() => EndpointValidator.ValidateAzureServiceEndpoint(endpoint, serviceType)); + Assert.Null(exception); + } + + [Theory] + [InlineData("https://evil.com", "communication", "not a valid communication domain")] + [InlineData("https://evil.com/.communication.azure.com", "communication", "not a valid communication domain")] + [InlineData("http://mycomm.communication.azure.com", "communication", "must use HTTPS")] + [InlineData("ftp://myconfig.azconfig.io", "appconfig", "must use HTTPS")] + public void ValidateAzureServiceEndpoint_InvalidEndpoints_ThrowsSecurityException( + string endpoint, + string serviceType, + string expectedMessagePart) + { + // Act & Assert + var exception = Assert.Throws( + () => EndpointValidator.ValidateAzureServiceEndpoint(endpoint, serviceType)); + Assert.Contains(expectedMessagePart, exception.Message, StringComparison.OrdinalIgnoreCase); + } + + [Theory] + [InlineData("", "communication")] + [InlineData(" ", "communication")] + public void ValidateAzureServiceEndpoint_NullOrEmptyEndpoint_ThrowsArgumentException( + string endpoint, + string serviceType) + { + // Act & Assert + Assert.Throws( + () => EndpointValidator.ValidateAzureServiceEndpoint(endpoint, serviceType)); + } + + [Fact] + public void ValidateAzureServiceEndpoint_NullEndpoint_ThrowsArgumentException() + { + // Act & Assert + Assert.Throws( + () => EndpointValidator.ValidateAzureServiceEndpoint(null!, "communication")); + } + + [Fact] + public void ValidateAzureServiceEndpoint_InvalidUriFormat_ThrowsSecurityException() + { + // Arrange + var invalidEndpoint = "not-a-valid-uri"; + + // Act & Assert + var exception = Assert.Throws( + () => EndpointValidator.ValidateAzureServiceEndpoint(invalidEndpoint, "communication")); + Assert.Contains("Invalid endpoint format", exception.Message); + } + + [Fact] + public void ValidateAzureServiceEndpoint_UnknownServiceType_ThrowsArgumentException() + { + // Arrange + var endpoint = "https://example.com"; + var unknownServiceType = "unknown-service"; + + // Act & Assert + var exception = Assert.Throws( + () => EndpointValidator.ValidateAzureServiceEndpoint(endpoint, unknownServiceType)); + Assert.Contains("Unknown service type", exception.Message); + } + + #endregion + + #region Sovereign Cloud Tests + + [Theory] + // Azure China Cloud + [InlineData("https://myregistry.azurecr.cn", "acr")] + [InlineData("https://myconfig.azconfig.azure.cn", "appconfig")] + [InlineData("https://mycomm.communication.azure.cn", "communication")] + public void ValidateAzureServiceEndpoint_AzureChinaCloud_ValidEndpoints_DoesNotThrow(string endpoint, string serviceType) + { + // Act & Assert + var exception = Record.Exception(() => + EndpointValidator.ValidateAzureServiceEndpoint(endpoint, serviceType, ArmEnvironment.AzureChina)); + Assert.Null(exception); + } + + [Theory] + // Azure US Government + [InlineData("https://myregistry.azurecr.us", "acr")] + [InlineData("https://myconfig.azconfig.azure.us", "appconfig")] + [InlineData("https://mycomm.communication.azure.us", "communication")] + public void ValidateAzureServiceEndpoint_AzureGovernment_ValidEndpoints_DoesNotThrow(string endpoint, string serviceType) + { + // Act & Assert + var exception = Record.Exception(() => + EndpointValidator.ValidateAzureServiceEndpoint(endpoint, serviceType, ArmEnvironment.AzureGovernment)); + Assert.Null(exception); + } + + [Theory] + // Public cloud endpoint should fail in China cloud + [InlineData("https://myregistry.azurecr.io", "acr")] + [InlineData("https://myconfig.azconfig.io", "appconfig")] + public void ValidateAzureServiceEndpoint_PublicCloudEndpoint_InChinaCloud_Throws(string endpoint, string serviceType) + { + // Act & Assert + var exception = Assert.Throws(() => + EndpointValidator.ValidateAzureServiceEndpoint(endpoint, serviceType, ArmEnvironment.AzureChina)); + Assert.Contains("Azure China Cloud", exception.Message); + Assert.Contains("not a valid", exception.Message); + } + + [Theory] + // Public cloud endpoint should fail in Gov cloud + [InlineData("https://myregistry.azurecr.io", "acr")] + [InlineData("https://myconfig.azconfig.io", "appconfig")] + public void ValidateAzureServiceEndpoint_PublicCloudEndpoint_InGovCloud_Throws(string endpoint, string serviceType) + { + // Act & Assert + var exception = Assert.Throws(() => + EndpointValidator.ValidateAzureServiceEndpoint(endpoint, serviceType, ArmEnvironment.AzureGovernment)); + Assert.Contains("Azure US Government Cloud", exception.Message); + Assert.Contains("not a valid", exception.Message); + } + + [Theory] + // China cloud endpoint should fail in public cloud + [InlineData("https://myregistry.azurecr.cn", "acr")] + [InlineData("https://myconfig.azconfig.azure.cn", "appconfig")] + public void ValidateAzureServiceEndpoint_ChinaCloudEndpoint_InPublicCloud_Throws(string endpoint, string serviceType) + { + // Act & Assert + var exception = Assert.Throws(() => + EndpointValidator.ValidateAzureServiceEndpoint(endpoint, serviceType, ArmEnvironment.AzurePublicCloud)); + Assert.Contains("Azure Public Cloud", exception.Message); + Assert.Contains("not a valid", exception.Message); + } + + #endregion + + #region ValidateExternalUrl Tests + + [Theory] + [InlineData("https://raw.githubusercontent.com/user/repo/main/file.txt", new[] { "raw.githubusercontent.com", "github.com" })] + [InlineData("https://github.com/user/repo", new[] { "raw.githubusercontent.com", "github.com" })] + [InlineData("https://example.com/path", new[] { "example.com" })] + public void ValidateExternalUrl_AllowedHost_DoesNotThrow(string url, string[] allowedHosts) + { + // Act & Assert + var exception = Record.Exception(() => EndpointValidator.ValidateExternalUrl(url, allowedHosts)); + Assert.Null(exception); + } + + [Theory] + [InlineData("https://evil.com/malicious", new[] { "github.com" }, "not in the allowed list")] + [InlineData("http://github.com/repo", new[] { "github.com" }, "must use HTTPS")] + public void ValidateExternalUrl_InvalidHost_ThrowsSecurityException( + string url, + string[] allowedHosts, + string expectedMessagePart) + { + // Act & Assert + var exception = Assert.Throws( + () => EndpointValidator.ValidateExternalUrl(url, allowedHosts)); + Assert.Contains(expectedMessagePart, exception.Message, StringComparison.OrdinalIgnoreCase); + } + + [Theory] + [InlineData("", new[] { "github.com" })] + [InlineData(" ", new[] { "github.com" })] + public void ValidateExternalUrl_NullOrEmptyUrl_ThrowsArgumentException(string url, string[] allowedHosts) + { + // Act & Assert + Assert.Throws(() => EndpointValidator.ValidateExternalUrl(url, allowedHosts)); + } + + [Fact] + public void ValidateExternalUrl_NullUrl_ThrowsArgumentException() + { + // Act & Assert + Assert.Throws( + () => EndpointValidator.ValidateExternalUrl(null!, new[] { "github.com" })); + } + + #endregion + + #region ValidatePublicTargetUrl Tests - SDL Exit Criteria + + [Theory] + [InlineData("https://www.microsoft.com")] + [InlineData("https://www.google.com")] + [InlineData("https://github.com")] + [InlineData("https://8.8.8.8")] // Public IP (Google DNS) + [InlineData("https://1.1.1.1")] // Public IP (Cloudflare DNS) + public void ValidatePublicTargetUrl_PublicEndpoints_DoesNotThrow(string url) + { + // Act & Assert + var exception = Record.Exception(() => EndpointValidator.ValidatePublicTargetUrl(url)); + Assert.Null(exception); + } + + [Theory] + // IMDS and WireServer (Critical) + [InlineData("http://169.254.169.254")] + [InlineData("http://169.254.169.254/latest/meta-data/")] + [InlineData("http://168.63.129.16")] + [InlineData("http://168.63.129.16/machine?comp=goalstate")] + + // Loopback addresses + [InlineData("http://127.0.0.1")] + [InlineData("http://127.0.200.8")] + [InlineData("http://127.255.255.255")] + [InlineData("http://[::1]")] + + // Private networks (RFC1918) + [InlineData("http://10.0.0.1")] + [InlineData("http://10.255.255.255")] + [InlineData("http://172.16.0.1")] + [InlineData("http://172.16.0.99")] + [InlineData("http://172.31.255.255")] + [InlineData("http://192.168.0.1")] + [InlineData("http://192.168.0.101")] + [InlineData("http://192.168.255.255")] + + // Shared address space (CGNAT) + [InlineData("http://100.64.0.1")] + [InlineData("http://100.64.0.123")] + [InlineData("http://100.127.255.255")] + + // Link-local (APIPA) + [InlineData("http://169.254.0.1")] + [InlineData("http://169.254.255.255")] + + // Reserved/Special addresses + [InlineData("http://0.0.0.0")] + [InlineData("http://255.255.255.255")] + + // IPv6 private + [InlineData("http://[fc00::1]")] + [InlineData("http://[fd00::1]")] + + // Reserved hostnames + [InlineData("http://localhost")] + [InlineData("http://local")] + public void ValidatePublicTargetUrl_PrivateOrReservedAddresses_ThrowsSecurityException(string url) + { + // Act & Assert + var exception = Assert.Throws(() => + EndpointValidator.ValidatePublicTargetUrl(url)); + // The error message varies: "private or reserved" for IPs, "reserved" for hostnames + Assert.True( + exception.Message.Contains("private or reserved", StringComparison.OrdinalIgnoreCase) || + exception.Message.Contains("reserved", StringComparison.OrdinalIgnoreCase), + $"Expected error message to contain 'private or reserved' or 'reserved', but got: {exception.Message}"); + } + + [Theory] + [InlineData("")] + [InlineData(" ")] + public void ValidatePublicTargetUrl_NullOrEmptyUrl_ThrowsArgumentException(string url) + { + // Act & Assert + Assert.Throws(() => EndpointValidator.ValidatePublicTargetUrl(url)); + } + + [Fact] + public void ValidatePublicTargetUrl_NullUrl_ThrowsArgumentException() + { + // Act & Assert + Assert.Throws(() => EndpointValidator.ValidatePublicTargetUrl(null!)); + } + + [Fact] + public void ValidatePublicTargetUrl_InvalidUriFormat_ThrowsSecurityException() + { + // Arrange + var invalidUrl = "not-a-valid-uri"; + + // Act & Assert + var exception = Assert.Throws(() => + EndpointValidator.ValidatePublicTargetUrl(invalidUrl)); + Assert.Contains("Invalid URL format", exception.Message); + } + + [Theory] + [InlineData("http://localhost")] + [InlineData("http://LOCALHOST")] + [InlineData("http://localhost:8080")] + [InlineData("http://local")] + [InlineData("http://localtest.me")] // Common localhost alias + [InlineData("http://lvh.me")] // Another localhost variation + public void ValidatePublicTargetUrl_ReservedHostnames_ThrowsSecurityException(string url) + { + // Act & Assert + var exception = Assert.Throws(() => + EndpointValidator.ValidatePublicTargetUrl(url)); + Assert.Contains("reserved", exception.Message, StringComparison.OrdinalIgnoreCase); + } + + [Theory] + [InlineData("http://127.0.0.1.nip.io")] // nip.io resolves to 127.0.0.1 + [InlineData("http://127.0.0.1.xip.io")] // xip.io resolves to 127.0.0.1 + [InlineData("http://127.0.0.1.sslip.io")] // sslip.io resolves to 127.0.0.1 + [InlineData("http://10.0.0.1.nip.io")] // Private IP via DNS + [InlineData("http://192.168.1.1.nip.io")] // Private IP via DNS + public void ValidatePublicTargetUrl_DnsResolvesToPrivateIP_ThrowsSecurityException(string url) + { + // This test validates that DNS resolution is performed and private IPs are caught + // Note: These services (nip.io, xip.io, sslip.io) actually resolve to the IPs in the subdomain + // If DNS resolution fails (e.g., offline), the test will throw SecurityException for different reason + + // Act & Assert + var exception = Assert.Throws(() => + EndpointValidator.ValidatePublicTargetUrl(url)); + + // The error should mention either: + // 1. "resolves to a private or reserved IP" (if DNS succeeded) + // 2. "Unable to resolve hostname" (if DNS failed - still secure) + Assert.True( + exception.Message.Contains("private or reserved", StringComparison.OrdinalIgnoreCase) || + exception.Message.Contains("Unable to resolve hostname", StringComparison.OrdinalIgnoreCase), + $"Expected error about private IP or DNS resolution, but got: {exception.Message}"); + } + + [Fact] + public void ValidatePublicTargetUrl_UnresolvableHostname_ThrowsSecurityException() + { + // Arrange - use a guaranteed non-existent hostname + var url = "http://this-hostname-definitely-does-not-exist-12345.invalid"; + + // Act & Assert + var exception = Assert.Throws(() => + EndpointValidator.ValidatePublicTargetUrl(url)); + Assert.Contains("Unable to resolve hostname", exception.Message); + } + + #endregion + + #region DNS Bypass Prevention Tests - SDL Security + + [Theory] + [InlineData("http://169.254.169.254.nip.io")] // IMDS bypass attempt + public void ValidatePublicTargetUrl_KnownSSRFBypassDomains_ThrowsSecurityException(string url) + { + // Act & Assert + var exception = Assert.Throws(() => + EndpointValidator.ValidatePublicTargetUrl(url)); + Assert.True( + exception.Message.Contains("reserved", StringComparison.OrdinalIgnoreCase) || + exception.Message.Contains("private or reserved", StringComparison.OrdinalIgnoreCase) || + exception.Message.Contains("Unable to resolve hostname", StringComparison.OrdinalIgnoreCase), + $"Expected security error, but got: {exception.Message}"); + } + + #endregion + + #region Edge Cases and Security Scenarios + + [Theory] + [InlineData("https://myconfig.azconfig.io/", "appconfig")] // Trailing slash + [InlineData("https://myconfig.azconfig.io:443", "appconfig")] // Explicit port + [InlineData("https://MYCONFIG.AZCONFIG.IO", "appconfig")] // Mixed case + public void ValidateAzureServiceEndpoint_EdgeCases_DoesNotThrow(string endpoint, string serviceType) + { + // Act & Assert + var exception = Record.Exception( + () => EndpointValidator.ValidateAzureServiceEndpoint(endpoint, serviceType)); + Assert.Null(exception); + } + + [Theory] + [InlineData("https://evil.com/.azconfig.io", "appconfig")] // Domain suffix attack + [InlineData("https://azconfig.io.evil.com", "appconfig")] // Domain prefix attack + [InlineData("https://myconfig-azconfig.io", "appconfig")] // Typosquatting + public void ValidateAzureServiceEndpoint_DomainSpoofingAttempts_ThrowsSecurityException( + string endpoint, + string serviceType) + { + // Act & Assert + Assert.Throws( + () => EndpointValidator.ValidateAzureServiceEndpoint(endpoint, serviceType)); + } + + [Fact] + public void ValidateExternalUrl_CaseInsensitiveHostMatching_Works() + { + // Arrange + var url = "https://GITHUB.COM/repo"; + var allowedHosts = new[] { "github.com" }; + + // Act & Assert + var exception = Record.Exception(() => EndpointValidator.ValidateExternalUrl(url, allowedHosts)); + Assert.Null(exception); + } + + [Theory] + [InlineData("http://192.168.1.1/admin")] // Private network admin panel + [InlineData("http://10.0.0.1/api")] // Private API endpoint + [InlineData("http://localhost:8080/health")] // Local service health check + public void ValidatePublicTargetUrl_CommonSSRFTargets_ThrowsSecurityException(string url) + { + // Act & Assert + var exception = Assert.Throws(() => + EndpointValidator.ValidatePublicTargetUrl(url)); + Assert.True( + exception.Message.Contains("private or reserved", StringComparison.OrdinalIgnoreCase) || + exception.Message.Contains("reserved", StringComparison.OrdinalIgnoreCase), + $"Expected error message about private or reserved addresses, but got: {exception.Message}"); + } + + #endregion +} diff --git a/tools/Azure.Mcp.Tools.Acr/src/Services/AcrService.cs b/tools/Azure.Mcp.Tools.Acr/src/Services/AcrService.cs index 68c4f7f76f..d2875e3e9a 100644 --- a/tools/Azure.Mcp.Tools.Acr/src/Services/AcrService.cs +++ b/tools/Azure.Mcp.Tools.Acr/src/Services/AcrService.cs @@ -9,6 +9,7 @@ using Azure.Mcp.Core.Services.Azure.Tenant; using Azure.Mcp.Tools.Acr.Models; using Microsoft.Extensions.Logging; +using Microsoft.Mcp.Core.Helpers; namespace Azure.Mcp.Tools.Acr.Services; @@ -82,6 +83,12 @@ private async Task GetRegistry( private async Task> AddRepositoriesForRegistryAsync(AcrRegistryInfo reg, string? tenant, RetryPolicyOptions? retryPolicy, CancellationToken cancellationToken) { + if (!string.IsNullOrEmpty(reg.LoginServer)) + { + var acrEndpointString = $"https://{reg.LoginServer}"; + EndpointValidator.ValidateAzureServiceEndpoint(acrEndpointString, "acr"); + } + // Build data-plane client for this login server var credential = await GetCredential(tenant, cancellationToken); var options = ConfigureRetryPolicy(AddDefaultPolicies(new ContainerRegistryClientOptions()), retryPolicy); diff --git a/tools/Azure.Mcp.Tools.AppConfig/src/Services/AppConfigService.cs b/tools/Azure.Mcp.Tools.AppConfig/src/Services/AppConfigService.cs index 197a5236b3..4ce45be44b 100644 --- a/tools/Azure.Mcp.Tools.AppConfig/src/Services/AppConfigService.cs +++ b/tools/Azure.Mcp.Tools.AppConfig/src/Services/AppConfigService.cs @@ -11,6 +11,7 @@ using Azure.Mcp.Core.Services.Azure.Tenant; using Azure.Mcp.Tools.AppConfig.Models; using Microsoft.Extensions.Logging; +using Microsoft.Mcp.Core.Helpers; namespace Azure.Mcp.Tools.AppConfig.Services; @@ -152,6 +153,9 @@ private async Task GetConfigurationClient(string accountNam { throw new InvalidOperationException($"The App Configuration store '{accountName}' does not have a valid endpoint."); } + + EndpointValidator.ValidateAzureServiceEndpoint(endpoint, "appconfig"); + var credential = await GetCredential(cancellationToken); var options = new ConfigurationClientOptions(); AddDefaultPolicies(options); diff --git a/tools/Azure.Mcp.Tools.Communication/src/Services/CommunicationService.cs b/tools/Azure.Mcp.Tools.Communication/src/Services/CommunicationService.cs index 1a8fee70a9..3866984034 100644 --- a/tools/Azure.Mcp.Tools.Communication/src/Services/CommunicationService.cs +++ b/tools/Azure.Mcp.Tools.Communication/src/Services/CommunicationService.cs @@ -9,6 +9,7 @@ using Azure.Mcp.Core.Services.Azure.Tenant; using Azure.Mcp.Tools.Communication.Models; using Microsoft.Extensions.Logging; +using Microsoft.Mcp.Core.Helpers; namespace Azure.Mcp.Tools.Communication.Services; @@ -34,6 +35,8 @@ public async Task> SendSmsAsync( (nameof(from), from), (nameof(message), message)); + EndpointValidator.ValidateAzureServiceEndpoint(endpoint, "communication"); + // Validate to array separately since it has special requirements if (to == null || to.Length == 0) throw new ArgumentException("At least one 'to' phone number must be provided", nameof(to)); @@ -113,6 +116,8 @@ public async Task> SendSmsAsync( (nameof(subject), subject), (nameof(message), message)); + EndpointValidator.ValidateAzureServiceEndpoint(endpoint, "communication"); + // Validate to array separately since it has special requirements if (to == null || to.Length == 0) throw new ArgumentException("At least one 'to' email address must be provided", nameof(to)); diff --git a/tools/Azure.Mcp.Tools.LoadTesting/src/Services/LoadTestingService.cs b/tools/Azure.Mcp.Tools.LoadTesting/src/Services/LoadTestingService.cs index 7b90b5de2e..a3251dd1e6 100644 --- a/tools/Azure.Mcp.Tools.LoadTesting/src/Services/LoadTestingService.cs +++ b/tools/Azure.Mcp.Tools.LoadTesting/src/Services/LoadTestingService.cs @@ -15,6 +15,7 @@ using Azure.ResourceManager; using Azure.ResourceManager.LoadTesting; using Azure.ResourceManager.Resources; +using Microsoft.Mcp.Core.Helpers; namespace Azure.Mcp.Tools.LoadTesting.Services; @@ -312,6 +313,12 @@ public async Task CreateTestAsync( CancellationToken cancellationToken = default) { ValidateRequiredParameters((nameof(subscription), subscription), (nameof(testResourceName), testResourceName), (nameof(testId), testId)); + + if (!string.IsNullOrEmpty(endpointUrl)) + { + EndpointValidator.ValidatePublicTargetUrl(endpointUrl); + } + var subscriptionId = (await _subscriptionService.GetSubscription(subscription, tenant, retryPolicy, cancellationToken)).Data.SubscriptionId; var loadTestResource = await GetLoadTestResourcesAsync(subscriptionId, resourceGroup, testResourceName, tenant, retryPolicy, cancellationToken); diff --git a/tools/Fabric.Mcp.Tools.PublicApi/src/Services/NetworkResourceProviderService.cs b/tools/Fabric.Mcp.Tools.PublicApi/src/Services/NetworkResourceProviderService.cs index af9c4420e5..a6d1f9d78d 100644 --- a/tools/Fabric.Mcp.Tools.PublicApi/src/Services/NetworkResourceProviderService.cs +++ b/tools/Fabric.Mcp.Tools.PublicApi/src/Services/NetworkResourceProviderService.cs @@ -1,4 +1,6 @@ -using Microsoft.Extensions.Logging; +using System.Security; +using Microsoft.Extensions.Logging; +using Microsoft.Mcp.Core.Helpers; namespace Fabric.Mcp.Tools.PublicApi.Services { @@ -18,7 +20,25 @@ public async Task GetResource(string resourceName, CancellationToken can if (resourceJson.TryGetProperty("download_url", out var downloadUrl) && !string.IsNullOrEmpty(downloadUrl.GetString())) { - using var requestMessage = new HttpRequestMessage(HttpMethod.Get, downloadUrl.GetString()); + var urlString = downloadUrl.GetString()!; + + var allowedGitHubHosts = new[] + { + "raw.githubusercontent.com", + "github.com" + }; + + try + { + EndpointValidator.ValidateExternalUrl(urlString, allowedGitHubHosts); + } + catch (SecurityException ex) + { + _logger.LogError(ex, "Security validation failed for download URL: {Url}", urlString); + throw; + } + + using var requestMessage = new HttpRequestMessage(HttpMethod.Get, urlString); requestMessage.Headers.Add("User-Agent", "request"); var client = _httpClientFactory.CreateClient();