diff --git a/core/Microsoft.Mcp.Core/src/Services/Azure/Authentication/AzureCloudConfiguration.cs b/core/Microsoft.Mcp.Core/src/Services/Azure/Authentication/AzureCloudConfiguration.cs index 712326e7b7..517b114414 100644 --- a/core/Microsoft.Mcp.Core/src/Services/Azure/Authentication/AzureCloudConfiguration.cs +++ b/core/Microsoft.Mcp.Core/src/Services/Azure/Authentication/AzureCloudConfiguration.cs @@ -14,6 +14,14 @@ namespace Azure.Mcp.Core.Services.Azure.Authentication; /// public class AzureCloudConfiguration : IAzureCloudConfiguration { + + public enum AzureCloud + { + AzurePublicCloud, + AzureChinaCloud, + AzureUSGovernmentCloud, + } + private const string DefaultAuthorityHost = "https://login.microsoftonline.com"; /// @@ -37,7 +45,7 @@ public AzureCloudConfiguration( ?? configuration["Cloud"] ?? Environment.GetEnvironmentVariable("AZURE_CLOUD"); - (AuthorityHost, ArmEnvironment) = ParseCloudValue(cloudValue); + (AuthorityHost, ArmEnvironment, CloudType) = ParseCloudValue(cloudValue); logger?.LogDebug( "Azure cloud configuration initialized. Cloud value: '{CloudValue}', AuthorityHost: '{AuthorityHost}', ArmEnvironment: '{ArmEnvironment}'", @@ -52,11 +60,13 @@ public AzureCloudConfiguration( /// public ArmEnvironment ArmEnvironment { get; } - private static (Uri authorityHost, ArmEnvironment armEnvironment) ParseCloudValue(string? cloudValue) + public AzureCloud CloudType { get; } + + private static (Uri authorityHost, ArmEnvironment armEnvironment, AzureCloud cloudType) ParseCloudValue(string? cloudValue) { if (string.IsNullOrWhiteSpace(cloudValue)) { - return (new Uri(DefaultAuthorityHost), ArmEnvironment.AzurePublicCloud); + return (new Uri(DefaultAuthorityHost), ArmEnvironment.AzurePublicCloud, AzureCloud.AzurePublicCloud); } // Check if it's already a URL - in this case we only have authority host @@ -64,19 +74,19 @@ private static (Uri authorityHost, ArmEnvironment armEnvironment) ParseCloudValu // additional configuration not currently supported) if (cloudValue.StartsWith("https://", StringComparison.OrdinalIgnoreCase)) { - return (new Uri(cloudValue), ArmEnvironment.AzurePublicCloud); + return (new Uri(cloudValue), ArmEnvironment.AzurePublicCloud, AzureCloud.AzurePublicCloud); } // Map common sovereign cloud names to authority hosts and ARM environments return cloudValue.ToLowerInvariant() switch { "azurecloud" or "azurepubliccloud" or "public" => - (new Uri("https://login.microsoftonline.com"), ArmEnvironment.AzurePublicCloud), + (new Uri("https://login.microsoftonline.com"), ArmEnvironment.AzurePublicCloud, AzureCloud.AzurePublicCloud), "azurechinacloud" or "china" => - (new Uri("https://login.chinacloudapi.cn"), ArmEnvironment.AzureChina), + (new Uri("https://login.chinacloudapi.cn"), ArmEnvironment.AzureChina, AzureCloud.AzureChinaCloud), "azureusgovernment" or "azureusgovernmentcloud" or "usgov" or "usgovernment" => - (new Uri("https://login.microsoftonline.us"), ArmEnvironment.AzureGovernment), - _ => (new Uri(DefaultAuthorityHost), ArmEnvironment.AzurePublicCloud) // Default to public cloud if unknown + (new Uri("https://login.microsoftonline.us"), ArmEnvironment.AzureGovernment, AzureCloud.AzureUSGovernmentCloud), + _ => (new Uri(DefaultAuthorityHost), ArmEnvironment.AzurePublicCloud, AzureCloud.AzurePublicCloud) // Default to public cloud if unknown }; } } diff --git a/core/Microsoft.Mcp.Core/src/Services/Azure/Authentication/IAzureCloudConfiguration.cs b/core/Microsoft.Mcp.Core/src/Services/Azure/Authentication/IAzureCloudConfiguration.cs index 3391445083..8cd4aca818 100644 --- a/core/Microsoft.Mcp.Core/src/Services/Azure/Authentication/IAzureCloudConfiguration.cs +++ b/core/Microsoft.Mcp.Core/src/Services/Azure/Authentication/IAzureCloudConfiguration.cs @@ -20,4 +20,9 @@ public interface IAzureCloudConfiguration /// This determines the management endpoint used for Azure Resource Manager operations. /// ArmEnvironment ArmEnvironment { get; } + + /// + /// Gets the type of Azure cloud environment. + /// + AzureCloudConfiguration.AzureCloud CloudType { get; } } diff --git a/servers/Azure.Mcp.Server/changelog-entries/1771617105575.yaml b/servers/Azure.Mcp.Server/changelog-entries/1771617105575.yaml new file mode 100644 index 0000000000..fa412b0e29 --- /dev/null +++ b/servers/Azure.Mcp.Server/changelog-entries/1771617105575.yaml @@ -0,0 +1,3 @@ +changes: + - section: "Features Added" + description: "Added sovereign cloud endpoint support for Storage, Search, Postgres, ServiceFabric, Pricing, and Extension services" \ No newline at end of file diff --git a/tools/Azure.Mcp.Tools.AppLens/src/Services/AppLensService.cs b/tools/Azure.Mcp.Tools.AppLens/src/Services/AppLensService.cs index f0c266edc0..b26563684e 100644 --- a/tools/Azure.Mcp.Tools.AppLens/src/Services/AppLensService.cs +++ b/tools/Azure.Mcp.Tools.AppLens/src/Services/AppLensService.cs @@ -6,6 +6,7 @@ using System.Threading.Channels; using Azure.Core; using Azure.Mcp.Core.Services.Azure; +using Azure.Mcp.Core.Services.Azure.Authentication; using Azure.Mcp.Core.Services.Azure.Subscription; using Azure.Mcp.Core.Services.Azure.Tenant; using Azure.Mcp.Tools.AppLens.Models; @@ -20,9 +21,9 @@ namespace Azure.Mcp.Tools.AppLens.Services; public class AppLensService(IHttpClientFactory httpClientFactory, ISubscriptionService subscriptionService, ITenantService tenantService) : BaseAzureService(tenantService), IAppLensService { private readonly ISubscriptionService _subscriptionService = subscriptionService ?? throw new ArgumentNullException(nameof(subscriptionService)); + private readonly ITenantService _tenantService = tenantService ?? throw new ArgumentNullException(nameof(tenantService)); private readonly IHttpClientFactory _httpClientFactory = httpClientFactory ?? throw new ArgumentNullException(nameof(httpClientFactory)); private readonly AppLensOptions _options = new AppLensOptions(); - private const string ConversationalDiagnosticsSignalREndpoint = "https://diagnosticschat.azure.com/chatHub"; /// public async Task DiagnoseResourceAsync( @@ -93,12 +94,12 @@ private async Task GetAppLensSessionAsync(string resour // Get ARM token var token = await credential.GetTokenAsync( - new TokenRequestContext(["https://management.azure.com/user_impersonation"]), + new TokenRequestContext([GetManagementImpersonationEndpoint().ToString()]), cancellationToken); // Call the AppLens token endpoint using var request = new HttpRequestMessage(HttpMethod.Get, - $"https://management.azure.com/{resourceId}/detectors/GetToken-db48586f-7d94-45fc-88ad-b30ccd3b571c?api-version=2015-08-01"); + GetAppLensTokenEndpoint(resourceId)); request.Headers.Authorization = new System.Net.Http.Headers.AuthenticationHeaderValue("Bearer", token.Token); @@ -156,10 +157,10 @@ public async IAsyncEnumerable AskAppLensAsync( // https://learn.microsoft.com/aspnet/core/signalr/configuration?view=aspnetcore-9.0&tabs=dotnet#jsonmessagepack-serialization-options options.PayloadSerializerOptions.TypeInfoResolverChain.Insert(0, AppLensJsonContext.Default); }) - .WithUrl(ConversationalDiagnosticsSignalREndpoint, options => + .WithUrl(GetConversationalDiagnosticsSignalREndpoint(), options => { options.AccessTokenProvider = () => Task.FromResult(session.Token)!; - options.Headers.Add("origin", "https://appservice-diagnostics.trafficmanager.net"); + options.Headers.Add("origin", GetDiagnosticsPortalEndpoint().ToString()); }) .WithAutomaticReconnect() .Build(); @@ -345,4 +346,48 @@ private static AppLensSession ParseGetTokenResponse(string rawResponse) return session; } + + private Uri GetConversationalDiagnosticsSignalREndpoint() + { + return _tenantService.CloudConfiguration.CloudType switch + { + AzureCloudConfiguration.AzureCloud.AzurePublicCloud => new Uri("https://diagnosticschat.azure.com/chatHub"), + AzureCloudConfiguration.AzureCloud.AzureChinaCloud => new Uri("https://diagnosticschat.azure.cn/chatHub"), + AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud => new Uri("https://diagnosticschat.azure.us/chatHub"), + _ => new Uri("https://diagnosticschat.azure.com/chatHub"), + }; + } + + private Uri GetManagementImpersonationEndpoint() + { + return _tenantService.CloudConfiguration.CloudType switch + { + AzureCloudConfiguration.AzureCloud.AzurePublicCloud => new Uri("https://management.azure.com/user_impersonation"), + AzureCloudConfiguration.AzureCloud.AzureChinaCloud => new Uri("https://management.chinacloudapi.cn/user_impersonation"), + AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud => new Uri("https://management.usgovcloudapi.net/user_impersonation"), + _ => new Uri("https://management.azure.com/user_impersonation"), + }; + } + + private Uri GetAppLensTokenEndpoint(string resourceId) + { + return _tenantService.CloudConfiguration.CloudType switch + { + AzureCloudConfiguration.AzureCloud.AzurePublicCloud => new Uri($"https://management.azure.com/{resourceId}/detectors/GetToken-db48586f-7d94-45fc-88ad-b30ccd3b571c?api-version=2015-08-01"), + AzureCloudConfiguration.AzureCloud.AzureChinaCloud => new Uri($"https://management.chinacloudapi.cn/{resourceId}/detectors/GetToken-db48586f-7d94-45fc-88ad-b30ccd3b571c?api-version=2015-08-01"), + AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud => new Uri($"https://management.usgovcloudapi.net/{resourceId}/detectors/GetToken-db48586f-7d94-45fc-88ad-b30ccd3b571c?api-version=2015-08-01"), + _ => new Uri($"https://management.azure.com/{resourceId}/detectors/GetToken-db48586f-7d94-45fc-88ad-b30ccd3b571c?api-version=2015-08-01"), + }; + } + + private Uri GetDiagnosticsPortalEndpoint() + { + return _tenantService.CloudConfiguration.CloudType switch + { + AzureCloudConfiguration.AzureCloud.AzurePublicCloud => new Uri("https://appservice-diagnostics.trafficmanager.net"), + AzureCloudConfiguration.AzureCloud.AzureChinaCloud => new Uri("https://appservice-diagnostics.azure.cn"), + AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud => new Uri("https://appservice-diagnostics.azure.us"), + _ => new Uri("https://appservice-diagnostics.trafficmanager.net"), + }; + } } diff --git a/tools/Azure.Mcp.Tools.AppService/src/Services/AppServiceService.cs b/tools/Azure.Mcp.Tools.AppService/src/Services/AppServiceService.cs index 19bc342c96..7dc01ff141 100644 --- a/tools/Azure.Mcp.Tools.AppService/src/Services/AppServiceService.cs +++ b/tools/Azure.Mcp.Tools.AppService/src/Services/AppServiceService.cs @@ -3,6 +3,7 @@ using Azure.Mcp.Core.Options; using Azure.Mcp.Core.Services.Azure; +using Azure.Mcp.Core.Services.Azure.Authentication; using Azure.Mcp.Core.Services.Azure.Subscription; using Azure.Mcp.Core.Services.Azure.Tenant; using Azure.Mcp.Tools.AppService.Models; @@ -17,6 +18,7 @@ public class AppServiceService( ITenantService tenantService, ILogger logger) : BaseAzureService(tenantService), IAppServiceService { + private readonly ITenantService _tenantService = tenantService ?? throw new ArgumentNullException(nameof(tenantService)); private readonly ISubscriptionService _subscriptionService = subscriptionService ?? throw new ArgumentNullException(nameof(subscriptionService)); private readonly ILogger _logger = logger; @@ -95,7 +97,7 @@ private async Task GetWebAppResourceAsync(string subscription, return webAppResource.Value; } - private static string PrepareConnectionString(string? connectionString, string databaseType, + private string PrepareConnectionString(string? connectionString, string databaseType, string databaseServer, string databaseName) { return string.IsNullOrWhiteSpace(connectionString) @@ -177,15 +179,30 @@ private static ConnectionStringType GetConnectionStringType(string databaseType) }; } - private static string BuildConnectionString(string databaseType, string databaseServer, string databaseName) + private string BuildConnectionString(string databaseType, string databaseServer, string databaseName) { return databaseType.ToLowerInvariant() switch { "sqlserver" => $"Server={databaseServer};Database={databaseName};User Id={{username}};Password={{password}};TrustServerCertificate=True;", "mysql" => $"Server={databaseServer};Database={databaseName};Uid={{username}};Pwd={{password}};", "postgresql" => $"Host={databaseServer};Database={databaseName};Username={{username}};Password={{password}};", - "cosmosdb" => $"AccountEndpoint=https://{databaseServer}.documents.azure.com:443/;AccountKey={{key}};Database={databaseName};", + "cosmosdb" => BuildCosmosConnectionString(databaseServer, databaseName), _ => throw new ArgumentException($"Unsupported database type: {databaseType}") }; } + + private string BuildCosmosConnectionString(string databaseServer, string databaseName) + { + switch (_tenantService.CloudConfiguration.CloudType) + { + case AzureCloudConfiguration.AzureCloud.AzurePublicCloud: + return $"AccountEndpoint=https://{databaseServer}.documents.azure.com:443/;AccountKey={{key}};Database={databaseName};"; + case AzureCloudConfiguration.AzureCloud.AzureChinaCloud: + return $"AccountEndpoint=https://{databaseServer}.documents.azure.cn:443/;AccountKey={{key}};Database={databaseName};"; + case AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud: + return $"AccountEndpoint=https://{databaseServer}.documents.azure.us:443/;AccountKey={{key}};Database={databaseName};"; + default: + throw new ArgumentException($"Unsupported Azure cloud type: {_tenantService.CloudConfiguration.CloudType}"); + } + } } diff --git a/tools/Azure.Mcp.Tools.ApplicationInsights/src/Services/ProfilerDataService.cs b/tools/Azure.Mcp.Tools.ApplicationInsights/src/Services/ProfilerDataService.cs index afef1c89f0..84f6f4933f 100644 --- a/tools/Azure.Mcp.Tools.ApplicationInsights/src/Services/ProfilerDataService.cs +++ b/tools/Azure.Mcp.Tools.ApplicationInsights/src/Services/ProfilerDataService.cs @@ -8,6 +8,7 @@ using Azure.Core; using Azure.Mcp.Core.Options; using Azure.Mcp.Core.Services.Azure; +using Azure.Mcp.Core.Services.Azure.Authentication; using Azure.Mcp.Core.Services.Azure.Tenant; using Azure.Mcp.Tools.ApplicationInsights.Commands; using Azure.Mcp.Tools.ApplicationInsights.Models; @@ -27,9 +28,7 @@ public class ProfilerDataService( ITenantService tenantService) : BaseAzureService(tenantService), IProfilerDataService { - private const string Endpoint = "https://dataplane.diagnosticservices.azure.com/"; - private const string DefaultScope = "api://dataplane.diagnosticservices.azure.com/.default"; - + private readonly ITenantService _tenantService = tenantService ?? throw new ArgumentNullException(nameof(tenantService)); private readonly IHttpClientFactory _httpClientFactory = httpClientFactory ?? throw new ArgumentNullException(nameof(httpClientFactory)); private readonly ILogger _logger = logger ?? throw new ArgumentNullException(nameof(logger)); @@ -98,7 +97,7 @@ await response.Content.ReadAsStreamAsync(cancellationToken), private async Task CreateRequestAsync(HttpMethod method, string path, IDictionary? queries, string apiVersion, string? clientRequestId, HttpContent? httpContent, IDictionary>? additionalHeaders, CancellationToken cancellationToken) { - UriBuilder uriBuilder = new(Endpoint) + UriBuilder uriBuilder = new(GetDiagnosticServiceEndpoint()) { Path = path }; @@ -119,7 +118,7 @@ private async Task CreateRequestAsync(HttpMethod method, str var scopes = new string[] { - DefaultScope + GetDiagnosticServicesScope() }; string clientRequestIdLocal = clientRequestId ?? Guid.NewGuid().ToString(); TokenRequestContext tokenRequestContext = new(scopes, clientRequestIdLocal); @@ -199,4 +198,34 @@ private async Task ResolveAppIdAsync(ResourceIdentifier resourceId, Cancel _logger.LogInformation("Resolving appId: {resourceId} => {appId}", resourceId, appId); return Guid.Parse(appId); } + + private Uri GetDiagnosticServiceEndpoint() + { + switch (_tenantService.CloudConfiguration.CloudType) + { + case AzureCloudConfiguration.AzureCloud.AzurePublicCloud: + return new Uri("https://dataplane.diagnosticservices.azure.com"); + case AzureCloudConfiguration.AzureCloud.AzureChinaCloud: + return new Uri("https://dataplane.diagnosticservices.azure.cn"); + case AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud: + return new Uri("https://dataplane.diagnosticservices.azure.us"); + default: + return new Uri("https://dataplane.diagnosticservices.azure.com"); + } + } + + private string GetDiagnosticServicesScope() + { + switch (_tenantService.CloudConfiguration.CloudType) + { + case AzureCloudConfiguration.AzureCloud.AzurePublicCloud: + return "api://dataplane.diagnosticservices.azure.com/.default"; + case AzureCloudConfiguration.AzureCloud.AzureChinaCloud: + return "api://dataplane.diagnosticservices.azure.cn/.default"; + case AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud: + return "api://dataplane.diagnosticservices.azure.us/.default"; + default: + return "api://dataplane.diagnosticservices.azure.com/.default"; + } + } } diff --git a/tools/Azure.Mcp.Tools.ConfidentialLedger/src/Services/ConfidentialLedgerService.cs b/tools/Azure.Mcp.Tools.ConfidentialLedger/src/Services/ConfidentialLedgerService.cs index 1981a1ffa0..1fcb6f8ed0 100644 --- a/tools/Azure.Mcp.Tools.ConfidentialLedger/src/Services/ConfidentialLedgerService.cs +++ b/tools/Azure.Mcp.Tools.ConfidentialLedger/src/Services/ConfidentialLedgerService.cs @@ -5,6 +5,7 @@ using System.Text.Json; using Azure.Core; using Azure.Mcp.Core.Services.Azure; +using Azure.Mcp.Core.Services.Azure.Authentication; using Azure.Mcp.Core.Services.Azure.Tenant; using Azure.Mcp.Tools.ConfidentialLedger.Models; using Azure.Security.ConfidentialLedger; @@ -15,7 +16,7 @@ public class ConfidentialLedgerService(ITenantService tenantService) : BaseAzureService(tenantService), IConfidentialLedgerService { // NOTE: We construct the data-plane endpoint from the ledger name. - private static Uri BuildLedgerUri(string ledgerName) => new($"https://{ledgerName}.confidential-ledger.azure.com"); + private readonly ITenantService _tenantService = tenantService ?? throw new ArgumentNullException(nameof(tenantService)); private static RequestContent CreateAppendEntryContent(string entryData) { @@ -43,7 +44,7 @@ public async Task AppendEntryAsync(string ledgerName, string var credential = await GetCredential(cancellationToken); // Configure client (retry etc. could be extended later) - ConfidentialLedgerClient client = new(BuildLedgerUri(ledgerName), credential); + ConfidentialLedgerClient client = new(GetLedgerUri(ledgerName), credential); // Build RequestContent manually to avoid trimming issues from reflection-based serialization. using var content = CreateAppendEntryContent(entryData); @@ -74,7 +75,7 @@ public async Task GetLedgerEntryAsync(string ledgerName, s } var credential = await GetCredential(cancellationToken); - ConfidentialLedgerClient client = new(BuildLedgerUri(ledgerName), credential); + ConfidentialLedgerClient client = new(GetLedgerUri(ledgerName), credential); Response? getByCollectionResponse = null; bool loaded = false; @@ -115,4 +116,19 @@ public async Task GetLedgerEntryAsync(string ledgerName, s Contents = contents ?? string.Empty, }; } + + private Uri GetLedgerUri(string ledgerName) + { + switch (_tenantService.CloudConfiguration.CloudType) + { + case AzureCloudConfiguration.AzureCloud.AzurePublicCloud: + return new Uri($"https://{ledgerName}.confidential-ledger.azure.com"); + case AzureCloudConfiguration.AzureCloud.AzureChinaCloud: + return new Uri($"https://{ledgerName}.confidential-ledger.azure.cn"); + case AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud: + return new Uri($"https://{ledgerName}.confidential-ledger.azure.us"); + default: + return new Uri($"https://{ledgerName}.confidential-ledger.azure.com"); + } + } } diff --git a/tools/Azure.Mcp.Tools.Cosmos/src/Services/CosmosService.cs b/tools/Azure.Mcp.Tools.Cosmos/src/Services/CosmosService.cs index cbe5ed80c1..187c0e5904 100644 --- a/tools/Azure.Mcp.Tools.Cosmos/src/Services/CosmosService.cs +++ b/tools/Azure.Mcp.Tools.Cosmos/src/Services/CosmosService.cs @@ -4,6 +4,7 @@ using System.Net; using Azure.Mcp.Core.Options; using Azure.Mcp.Core.Services.Azure; +using Azure.Mcp.Core.Services.Azure.Authentication; using Azure.Mcp.Core.Services.Azure.Subscription; using Azure.Mcp.Core.Services.Azure.Tenant; using Azure.Mcp.Core.Services.Caching; @@ -17,10 +18,10 @@ public sealed class CosmosService(ISubscriptionService subscriptionService, ITen : BaseAzureService(tenantService), ICosmosService, IAsyncDisposable { private readonly ISubscriptionService _subscriptionService = subscriptionService ?? throw new ArgumentNullException(nameof(subscriptionService)); + private readonly ITenantService _tenantService = tenantService ?? throw new ArgumentNullException(nameof(tenantService)); private readonly IHttpClientFactory _httpClientFactory = httpClientFactory ?? throw new ArgumentNullException(nameof(httpClientFactory)); private readonly ICacheService _cacheService = cacheService ?? throw new ArgumentNullException(nameof(cacheService)); private readonly ILogger _logger = logger ?? throw new ArgumentNullException(nameof(logger)); - private const string CosmosBaseUri = "https://{0}.documents.azure.com:443/"; private const string CacheGroup = "cosmos"; private const string CosmosClientsCacheKeyPrefix = "clients_"; private const string CosmosDatabasesCacheKeyPrefix = "databases_"; @@ -78,7 +79,7 @@ private async Task CreateCosmosClientWithAuth( var cosmosAccount = await GetCosmosAccountAsync(subscription, accountName, tenant, cancellationToken: cancellationToken); var keys = await cosmosAccount.GetKeysAsync(cancellationToken); cosmosClient = new CosmosClient( - string.Format(CosmosBaseUri, accountName), + string.Format(GetCosmosBaseUriFormat(), accountName), keys.Value.PrimaryMasterKey, clientOptions); break; @@ -86,7 +87,7 @@ private async Task CreateCosmosClientWithAuth( case AuthMethod.Credential: default: cosmosClient = new CosmosClient( - string.Format(CosmosBaseUri, accountName), + string.Format(GetCosmosBaseUriFormat(), accountName), await GetCredential(cancellationToken), clientOptions); break; @@ -98,6 +99,21 @@ await GetCredential(cancellationToken), return cosmosClient; } + private string GetCosmosBaseUriFormat() + { + switch (_tenantService.CloudConfiguration.CloudType) + { + case AzureCloudConfiguration.AzureCloud.AzurePublicCloud: + return "https://{0}.documents.azure.com:443/"; + case AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud: + return "https://{0}.documents.azure.us:443/"; + case AzureCloudConfiguration.AzureCloud.AzureChinaCloud: + return "https://{0}.documents.azure.cn:443/"; + default: + return "https://{0}.documents.azure.com:443/"; + } + } + private async Task ValidateCosmosClientAsync(CosmosClient client, CancellationToken cancellationToken = default) { try diff --git a/tools/Azure.Mcp.Tools.EventHubs/src/Services/EventHubsService.cs b/tools/Azure.Mcp.Tools.EventHubs/src/Services/EventHubsService.cs index 40653cc933..4ea39cac30 100644 --- a/tools/Azure.Mcp.Tools.EventHubs/src/Services/EventHubsService.cs +++ b/tools/Azure.Mcp.Tools.EventHubs/src/Services/EventHubsService.cs @@ -30,7 +30,7 @@ public async Task> GetNamespacesAsync( try { - var subscriptionResource = await _subscriptionService.GetSubscription(subscription, tenant, retryPolicy, cancellationToken); + var subscriptionResource = await ResolveSubscriptionResourceAsync(subscription, tenant, retryPolicy, cancellationToken); var namespaces = new List(); if (!string.IsNullOrEmpty(resourceGroup)) @@ -108,7 +108,7 @@ public async Task GetNamespaceAsync( try { - var subscriptionResource = await _subscriptionService.GetSubscription(subscription, tenant, retryPolicy, cancellationToken); + var subscriptionResource = await ResolveSubscriptionResourceAsync(subscription, tenant, retryPolicy, cancellationToken); var resourceGroupResource = await subscriptionResource.GetResourceGroupAsync(resourceGroup, cancellationToken); if (resourceGroupResource?.Value == null) @@ -155,7 +155,7 @@ public async Task CreateOrUpdateNamespaceAsync( try { - var subscriptionResource = await _subscriptionService.GetSubscription(subscription, tenant, retryPolicy, cancellationToken); + var subscriptionResource = await ResolveSubscriptionResourceAsync(subscription, tenant, retryPolicy, cancellationToken); var resourceGroupResource = await subscriptionResource.GetResourceGroupAsync(resourceGroup, cancellationToken); if (resourceGroupResource?.Value == null) @@ -250,8 +250,9 @@ public async Task DeleteNamespaceAsync( try { - var subscriptionResource = await _subscriptionService.GetSubscription(subscription, tenant, retryPolicy, cancellationToken); - var subscriptionId = subscriptionResource.Data.SubscriptionId; + var subscriptionId = _subscriptionService.IsSubscriptionId(subscription) + ? subscription + : await _subscriptionService.GetSubscriptionIdByName(subscription, tenant, retryPolicy, cancellationToken); var armClient = await CreateArmClientAsync(tenant, retryPolicy, cancellationToken: cancellationToken); var namespaceId = EventHubsNamespaceResource.CreateResourceIdentifier(subscriptionId, resourceGroup, namespaceName); @@ -296,7 +297,7 @@ public async Task> GetEventHubsAsync( try { - var subscriptionResource = await _subscriptionService.GetSubscription(subscription, tenant, retryPolicy, cancellationToken); + var subscriptionResource = await ResolveSubscriptionResourceAsync(subscription, tenant, retryPolicy, cancellationToken); var resourceGroupResource = await subscriptionResource.GetResourceGroupAsync(resourceGroup, cancellationToken); if (resourceGroupResource?.Value == null) @@ -342,7 +343,7 @@ public async Task> GetEventHubsAsync( try { - var subscriptionResource = await _subscriptionService.GetSubscription(subscription, tenant, retryPolicy, cancellationToken); + var subscriptionResource = await ResolveSubscriptionResourceAsync(subscription, tenant, retryPolicy, cancellationToken); var resourceGroupResource = await subscriptionResource.GetResourceGroupAsync(resourceGroup, cancellationToken); if (resourceGroupResource?.Value == null) @@ -407,7 +408,7 @@ public async Task CreateOrUpdateEventHubAsync( try { - var subscriptionResource = await _subscriptionService.GetSubscription(subscription, tenant, retryPolicy, cancellationToken); + var subscriptionResource = await ResolveSubscriptionResourceAsync(subscription, tenant, retryPolicy, cancellationToken); var resourceGroupResource = await subscriptionResource.GetResourceGroupAsync(resourceGroup, cancellationToken); if (resourceGroupResource?.Value == null) @@ -470,7 +471,7 @@ public async Task DeleteEventHubAsync( try { - var subscriptionResource = await _subscriptionService.GetSubscription(subscription, tenant, retryPolicy, cancellationToken); + var subscriptionResource = await ResolveSubscriptionResourceAsync(subscription, tenant, retryPolicy, cancellationToken); try { @@ -536,8 +537,7 @@ public async Task CreateOrUpdateConsumerGroupAsync( try { - var armClient = await CreateArmClientAsync(tenant, retryPolicy, cancellationToken: cancellationToken); - var subscriptionResource = armClient.GetSubscriptionResource(ResourceManager.Resources.SubscriptionResource.CreateResourceIdentifier(subscription)); + var subscriptionResource = await ResolveSubscriptionResourceAsync(subscription, tenant, retryPolicy, cancellationToken); var resourceGroupResource = await subscriptionResource.GetResourceGroupAsync(resourceGroup, cancellationToken); var namespaceResource = await resourceGroupResource.Value.GetEventHubsNamespaces().GetAsync(namespaceName, cancellationToken); var eventHubResource = await namespaceResource.Value.GetEventHubs().GetAsync(eventHubName, cancellationToken); @@ -596,8 +596,7 @@ public async Task DeleteConsumerGroupAsync( try { - var armClient = await CreateArmClientAsync(tenant, retryPolicy, cancellationToken: cancellationToken); - var subscriptionResource = armClient.GetSubscriptionResource(ResourceManager.Resources.SubscriptionResource.CreateResourceIdentifier(subscription)); + var subscriptionResource = await ResolveSubscriptionResourceAsync(subscription, tenant, retryPolicy, cancellationToken); try { @@ -668,7 +667,7 @@ public async Task> GetConsumerGroupsAsync( try { - var subscriptionResource = await _subscriptionService.GetSubscription(subscription, tenant, retryPolicy, cancellationToken); + var subscriptionResource = await ResolveSubscriptionResourceAsync(subscription, tenant, retryPolicy, cancellationToken); var resourceGroupResource = await subscriptionResource.GetResourceGroupAsync(resourceGroup, cancellationToken); if (resourceGroupResource?.Value == null) @@ -722,7 +721,7 @@ public async Task> GetConsumerGroupsAsync( try { - var subscriptionResource = await _subscriptionService.GetSubscription(subscription, tenant, retryPolicy, cancellationToken); + var subscriptionResource = await ResolveSubscriptionResourceAsync(subscription, tenant, retryPolicy, cancellationToken); var resourceGroupResource = await subscriptionResource.GetResourceGroupAsync(resourceGroup, cancellationToken); if (resourceGroupResource?.Value == null) @@ -776,4 +775,23 @@ private static ConsumerGroup ConvertToConsumerGroup(EventHubsConsumerGroupData c UpdatedTime: consumerGroupData.UpdatedOn); } + /// + /// Returns a SubscriptionResource handle for ARM navigation without making an HTTP call. + /// This avoids the cache-dependent GET /subscriptions/{id} that GetSubscription() makes, + /// which caused non-deterministic test proxy recordings. + /// + private async Task ResolveSubscriptionResourceAsync( + string subscription, + string? tenant, + RetryPolicyOptions? retryPolicy, + CancellationToken cancellationToken) + { + var subscriptionId = _subscriptionService.IsSubscriptionId(subscription) + ? subscription + : await _subscriptionService.GetSubscriptionIdByName(subscription, tenant, retryPolicy, cancellationToken); + + var armClient = await CreateArmClientAsync(tenant, retryPolicy, cancellationToken: cancellationToken); + return armClient.GetSubscriptionResource( + ResourceManager.Resources.SubscriptionResource.CreateResourceIdentifier(subscriptionId)); + } } diff --git a/tools/Azure.Mcp.Tools.EventHubs/tests/Azure.Mcp.Tools.EventHubs.LiveTests/assets.json b/tools/Azure.Mcp.Tools.EventHubs/tests/Azure.Mcp.Tools.EventHubs.LiveTests/assets.json index a97c10dbcc..4f4f3174b6 100644 --- a/tools/Azure.Mcp.Tools.EventHubs/tests/Azure.Mcp.Tools.EventHubs.LiveTests/assets.json +++ b/tools/Azure.Mcp.Tools.EventHubs/tests/Azure.Mcp.Tools.EventHubs.LiveTests/assets.json @@ -2,5 +2,5 @@ "AssetsRepo": "Azure/azure-sdk-assets", "AssetsRepoPrefixPath": "", "TagPrefix": "Azure.Mcp.Tools.EventHubs.LiveTests", - "Tag": "Azure.Mcp.Tools.EventHubs.LiveTests_b5550e708b" + "Tag": "Azure.Mcp.Tools.EventHubs.LiveTests_07ba3d2d71" } diff --git a/tools/Azure.Mcp.Tools.Extension/src/Services/CliGenerateService.cs b/tools/Azure.Mcp.Tools.Extension/src/Services/CliGenerateService.cs index 9dff64214e..0bbce41be6 100644 --- a/tools/Azure.Mcp.Tools.Extension/src/Services/CliGenerateService.cs +++ b/tools/Azure.Mcp.Tools.Extension/src/Services/CliGenerateService.cs @@ -8,7 +8,7 @@ namespace Azure.Mcp.Tools.Extension.Services; -internal class CliGenerateService(IHttpClientFactory httpClientFactory, IAzureTokenCredentialProvider tokenCredentialProvider) : ICliGenerateService +internal class CliGenerateService(IHttpClientFactory httpClientFactory, IAzureTokenCredentialProvider tokenCredentialProvider, IAzureCloudConfiguration cloudConfiguration) : ICliGenerateService { private readonly IHttpClientFactory _httpClientFactory = httpClientFactory; private readonly IAzureTokenCredentialProvider _tokenCredentialProvider = tokenCredentialProvider; @@ -22,7 +22,7 @@ public async Task GenerateAzureCLICommandAsync(string inten var accessToken = await credential.GetTokenAsync(new TokenRequestContext([apiScope]), cancellationToken); // AzCli copilot API endpoint - const string url = "https://azclis-copilot-apim-prod-eus.azure-api.net/azcli/copilot"; + var url = GetCliCopilotEndpoint(); var requestBody = new AzureCliGenerateRequest() { @@ -45,4 +45,19 @@ public async Task GenerateAzureCLICommandAsync(string inten HttpResponseMessage responseMessage = await _httpClientFactory.CreateClient().SendAsync(requestMessage, cancellationToken); return responseMessage; } + + private string GetCliCopilotEndpoint() + { + switch (cloudConfiguration.CloudType) + { + case AzureCloudConfiguration.AzureCloud.AzurePublicCloud: + return "https://azclis-copilot-apim-prod-eus.azure-api.net/azcli/copilot"; + case AzureCloudConfiguration.AzureCloud.AzureChinaCloud: + return "https://azclis-copilot-apim-prod-eus.azure-api.cn/azcli/copilot"; + case AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud: + return "https://azclis-copilot-apim-prod-eus.azure-api.us/azcli/copilot"; + default: + return "https://azclis-copilot-apim-prod-eus.azure-api.net/azcli/copilot"; + } + } } diff --git a/tools/Azure.Mcp.Tools.KeyVault/src/Services/KeyVaultService.cs b/tools/Azure.Mcp.Tools.KeyVault/src/Services/KeyVaultService.cs index a1a84a438c..efadf3636c 100644 --- a/tools/Azure.Mcp.Tools.KeyVault/src/Services/KeyVaultService.cs +++ b/tools/Azure.Mcp.Tools.KeyVault/src/Services/KeyVaultService.cs @@ -3,6 +3,7 @@ using Azure.Mcp.Core.Options; using Azure.Mcp.Core.Services.Azure; +using Azure.Mcp.Core.Services.Azure.Authentication; using Azure.Mcp.Core.Services.Azure.Tenant; using Azure.Security.KeyVault.Administration; using Azure.Security.KeyVault.Certificates; @@ -13,6 +14,7 @@ namespace Azure.Mcp.Tools.KeyVault.Services; public sealed class KeyVaultService(ITenantService tenantService, IHttpClientFactory httpClientFactory) : BaseAzureService(tenantService), IKeyVaultService { + private readonly ITenantService _tenantService = tenantService ?? throw new ArgumentNullException(nameof(tenantService)); private readonly IHttpClientFactory _httpClientFactory = httpClientFactory ?? throw new ArgumentNullException(nameof(httpClientFactory)); public async Task> ListKeys( @@ -301,7 +303,36 @@ public async Task ImportCertificate( } } - private static Uri BuildVaultUri(string vaultName) => new($"https://{vaultName}.vault.azure.net"); + private Uri BuildVaultUri(string vaultName) + { + switch (_tenantService.CloudConfiguration.CloudType) + { + case AzureCloudConfiguration.AzureCloud.AzurePublicCloud: + return new Uri($"https://{vaultName}.vault.azure.net"); + case AzureCloudConfiguration.AzureCloud.AzureChinaCloud: + return new Uri($"https://{vaultName}.vault.azure.cn"); + case AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud: + return new Uri($"https://{vaultName}.vault.usgovcloudapi.net"); + default: + return new Uri($"https://{vaultName}.vault.azure.net"); + } + } + + + private Uri GetHsmUri(string vaultName) + { + switch (_tenantService.CloudConfiguration.CloudType) + { + case AzureCloudConfiguration.AzureCloud.AzurePublicCloud: + return new Uri($"https://{vaultName}.managedhsm.azure.net"); + case AzureCloudConfiguration.AzureCloud.AzureChinaCloud: + return new Uri($"https://{vaultName}.managedhsm.azure.cn"); + case AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud: + return new Uri($"https://{vaultName}.managedhsm.usgovcloudapi.net"); + default: + return new Uri($"https://{vaultName}.managedhsm.azure.net"); + } + } // Create clients with injected HttpClient, this will enable record/playback during testing. private KeyClient CreateKeyClient(string vaultName, Azure.Core.TokenCredential credential, RetryPolicyOptions? retry) @@ -346,7 +377,7 @@ public async Task GetVaultSettings( { ValidateRequiredParameters((nameof(vaultName), vaultName), (nameof(subscription), subscription)); var credential = await GetCredential(tenantId, cancellationToken); - var hsmUri = new Uri($"https://{vaultName}.managedhsm.azure.net"); + var hsmUri = GetHsmUri(vaultName); try { var hsmClient = new KeyVaultSettingsClient(hsmUri, credential); diff --git a/tools/Azure.Mcp.Tools.KeyVault/tests/Azure.Mcp.Tools.KeyVault.LiveTests/KeyVaultCommandTests.cs b/tools/Azure.Mcp.Tools.KeyVault/tests/Azure.Mcp.Tools.KeyVault.LiveTests/KeyVaultCommandTests.cs index 0a7b7bee9b..22d5ef9969 100644 --- a/tools/Azure.Mcp.Tools.KeyVault/tests/Azure.Mcp.Tools.KeyVault.LiveTests/KeyVaultCommandTests.cs +++ b/tools/Azure.Mcp.Tools.KeyVault/tests/Azure.Mcp.Tools.KeyVault.LiveTests/KeyVaultCommandTests.cs @@ -18,6 +18,12 @@ public class KeyVaultCommandTests(ITestOutputHelper output, TestProxyFixture fix { private readonly KeyVaultTestCertificateAssets _importCertificateAssets = KeyVaultTestCertificates.Load(); + public override CustomDefaultMatcher? TestMatcher => new() + { + ExcludedHeaders = "Authorization,Content-Type", + CompareBodies = false + }; + public override List BodyRegexSanitizers => [ // Sanitizes all hostnames in URLs to remove actual vault names (not limited to `kid` fields) new BodyRegexSanitizer(new BodyRegexSanitizerBody() { @@ -215,7 +221,6 @@ public async Task Should_create_certificate() [Fact] - [CustomMatcher(compareBody: false)] public async Task Should_import_certificate() { var fakePassword = _importCertificateAssets.Password; diff --git a/tools/Azure.Mcp.Tools.KeyVault/tests/Azure.Mcp.Tools.KeyVault.LiveTests/assets.json b/tools/Azure.Mcp.Tools.KeyVault/tests/Azure.Mcp.Tools.KeyVault.LiveTests/assets.json index 0000352247..f0346c4894 100644 --- a/tools/Azure.Mcp.Tools.KeyVault/tests/Azure.Mcp.Tools.KeyVault.LiveTests/assets.json +++ b/tools/Azure.Mcp.Tools.KeyVault/tests/Azure.Mcp.Tools.KeyVault.LiveTests/assets.json @@ -2,5 +2,5 @@ "AssetsRepo": "Azure/azure-sdk-assets", "AssetsRepoPrefixPath": "", "TagPrefix": "Azure.Mcp.Tools.KeyVault.LiveTests", - "Tag": "Azure.Mcp.Tools.KeyVault.LiveTests_24520a9311" + "Tag": "Azure.Mcp.Tools.KeyVault.LiveTests_bcfab177a3" } diff --git a/tools/Azure.Mcp.Tools.Marketplace/src/Services/MarketplaceService.cs b/tools/Azure.Mcp.Tools.Marketplace/src/Services/MarketplaceService.cs index 7f7d2fb87a..8847fc2e71 100644 --- a/tools/Azure.Mcp.Tools.Marketplace/src/Services/MarketplaceService.cs +++ b/tools/Azure.Mcp.Tools.Marketplace/src/Services/MarketplaceService.cs @@ -15,7 +15,8 @@ namespace Azure.Mcp.Tools.Marketplace.Services; public class MarketplaceService(ITenantService tenantService) : BaseAzureService(tenantService), IMarketplaceService { - private const string ManagementApiBaseUrl = "https://management.azure.com"; + private readonly ITenantService _tenantService = tenantService ?? throw new ArgumentNullException(nameof(tenantService)); + private const string ApiVersion = "2023-01-01-preview"; /// @@ -55,7 +56,8 @@ public async Task GetProduct( (nameof(productId), productId), (nameof(subscription), subscription)); - string productUrl = BuildProductUrl(subscription, productId, includeStopSoldPlans, language, market, + var managementEndpoint = _tenantService.CloudConfiguration.ArmEnvironment.Endpoint.ToString().TrimEnd('/'); + string productUrl = BuildProductUrl(managementEndpoint, subscription, productId, includeStopSoldPlans, language, market, lookupOfferInTenantLevel, planId, skuId, includeServiceInstructionTemplates); return await GetMarketplaceSingleProductResponseAsync(productUrl, tenantId, retryPolicy, cancellationToken); @@ -92,12 +94,14 @@ public async Task ListProducts( { ValidateRequiredParameters((nameof(subscription), subscription)); - string productsUrl = BuildProductsListUrl(subscription, language, search, filter, orderBy, select, nextCursor, expand); + var managementEndpoint = _tenantService.CloudConfiguration.ArmEnvironment.Endpoint.ToString().TrimEnd('/'); + string productsUrl = BuildProductsListUrl(managementEndpoint, subscription, language, search, filter, orderBy, select, nextCursor, expand); return await GetMarketplaceListProductsResponseAsync(productsUrl, tenantId, retryPolicy, cancellationToken); } private static string BuildProductsListUrl( + string managementEndpoint, string subscription, string? language, string? search, @@ -136,7 +140,7 @@ private static string BuildProductsListUrl( queryParams.Add("storefront=any"); // include all storefronts string queryString = string.Join("&", queryParams); - return $"{ManagementApiBaseUrl}/subscriptions/{subscription}/providers/Microsoft.Marketplace/products?{queryString}"; + return $"{managementEndpoint}/subscriptions/{subscription}/providers/Microsoft.Marketplace/products?{queryString}"; } private async Task GetMarketplaceListProductsResponseAsync(string url, string? tenant, RetryPolicyOptions? retryPolicy, CancellationToken cancellationToken) @@ -154,6 +158,7 @@ private async Task GetMarketplaceListProducts private static string BuildProductUrl( + string managementEndpoint, string subscription, string productId, bool? includeStopSoldPlans, @@ -191,7 +196,7 @@ private static string BuildProductUrl( queryParams.Add($"includeServiceInstructionTemplates={includeServiceInstructionTemplates.Value.ToString().ToLower()}"); string queryString = string.Join("&", queryParams); - return $"{ManagementApiBaseUrl}/subscriptions/{subscription}/providers/Microsoft.Marketplace/products/{productId}?{queryString}"; + return $"{managementEndpoint}/subscriptions/{subscription}/providers/Microsoft.Marketplace/products/{productId}?{queryString}"; } private async Task GetMarketplaceSingleProductResponseAsync(string url, string? tenant, RetryPolicyOptions? retryPolicy, CancellationToken cancellationToken) @@ -208,7 +213,8 @@ private async Task GetMarketplaceSingleProductResponseAsync(stri private async Task GetArmAccessTokenAsync(string? tenantId, CancellationToken cancellationToken) { - var tokenRequestContext = new TokenRequestContext([$"{ManagementApiBaseUrl}/.default"]); + var defaultScope = _tenantService.CloudConfiguration.ArmEnvironment.DefaultScope; + var tokenRequestContext = new TokenRequestContext([defaultScope]); var tokenCredential = await GetCredential(tenantId, cancellationToken); return await tokenCredential .GetTokenAsync(tokenRequestContext, cancellationToken); diff --git a/tools/Azure.Mcp.Tools.Marketplace/tests/Azure.Mcp.Tools.Marketplace.LiveTests/assets.json b/tools/Azure.Mcp.Tools.Marketplace/tests/Azure.Mcp.Tools.Marketplace.LiveTests/assets.json index f2731ac0cb..4765d79046 100644 --- a/tools/Azure.Mcp.Tools.Marketplace/tests/Azure.Mcp.Tools.Marketplace.LiveTests/assets.json +++ b/tools/Azure.Mcp.Tools.Marketplace/tests/Azure.Mcp.Tools.Marketplace.LiveTests/assets.json @@ -2,5 +2,5 @@ "AssetsRepo": "Azure/azure-sdk-assets", "AssetsRepoPrefixPath": "", "TagPrefix": "Azure.Mcp.Tools.Marketplace.LiveTests", - "Tag": "Azure.Mcp.Tools.Marketplace.LiveTests_0f59abc9d6" + "Tag": "Azure.Mcp.Tools.Marketplace.LiveTests_7c61cc9c48" } diff --git a/tools/Azure.Mcp.Tools.Monitor/src/Services/MonitorHealthModelService.cs b/tools/Azure.Mcp.Tools.Monitor/src/Services/MonitorHealthModelService.cs index 0cdc03a147..82d8846c3e 100644 --- a/tools/Azure.Mcp.Tools.Monitor/src/Services/MonitorHealthModelService.cs +++ b/tools/Azure.Mcp.Tools.Monitor/src/Services/MonitorHealthModelService.cs @@ -6,6 +6,7 @@ using Azure.Core; using Azure.Mcp.Core.Options; using Azure.Mcp.Core.Services.Azure; +using Azure.Mcp.Core.Services.Azure.Authentication; using Azure.Mcp.Core.Services.Azure.Tenant; namespace Azure.Mcp.Tools.Monitor.Services; @@ -13,10 +14,9 @@ namespace Azure.Mcp.Tools.Monitor.Services; public class MonitorHealthModelService(ITenantService tenantService, IHttpClientFactory httpClientFactory) : BaseAzureService(tenantService), IMonitorHealthModelService { - private const string ManagementApiBaseUrl = "https://management.azure.com"; - private const string HealthModelsDataApiScope = "https://data.healthmodels.azure.com/.default"; private const string ApiVersion = "2023-10-01-preview"; private readonly IHttpClientFactory _httpClientFactory = httpClientFactory ?? throw new ArgumentNullException(nameof(httpClientFactory)); + private readonly ITenantService _tenantService = tenantService ?? throw new ArgumentNullException(nameof(tenantService)); /// /// Retrieves the health information for a specific entity in a health model. @@ -68,7 +68,7 @@ private async Task GetDataplaneResponseAsync(string url, CancellationTok private async Task GetDataplaneEndpointAsync(string subscriptionId, string resourceGroupName, string healthModelName, CancellationToken cancellationToken) { string token = await GetControlPlaneTokenAsync(cancellationToken); - string healthModelUrl = $"{ManagementApiBaseUrl}/subscriptions/{subscriptionId}/resourceGroups/{resourceGroupName}/providers/Microsoft.CloudHealth/healthmodels/{healthModelName}?api-version={ApiVersion}"; + string healthModelUrl = $"{GetManagementEndpoint()}/subscriptions/{subscriptionId}/resourceGroups/{resourceGroupName}/providers/Microsoft.CloudHealth/healthmodels/{healthModelName}?api-version={ApiVersion}"; using var request = new HttpRequestMessage(HttpMethod.Get, healthModelUrl); request.Headers.Authorization = new System.Net.Http.Headers.AuthenticationHeaderValue("Bearer", token); @@ -106,7 +106,7 @@ private async Task GetControlPlaneTokenAsync(CancellationToken cancellat { TokenCredential credential = await GetCredential(cancellationToken); AccessToken accessToken = await credential.GetTokenAsync( - new TokenRequestContext([$"{ManagementApiBaseUrl}/.default"]), + new TokenRequestContext([_tenantService.CloudConfiguration.ArmEnvironment.DefaultScope]), cancellationToken); return accessToken.Token; @@ -116,9 +116,29 @@ private async Task GetDataplaneTokenAsync(CancellationToken cancellation { TokenCredential credential = await GetCredential(cancellationToken); AccessToken accessToken = await credential.GetTokenAsync( - new TokenRequestContext([HealthModelsDataApiScope]), + new TokenRequestContext([GetHealthModelsDataApiScope()]), cancellationToken); return accessToken.Token; } + + private string GetManagementEndpoint() + { + return _tenantService.CloudConfiguration.ArmEnvironment.Endpoint.ToString().TrimEnd('/'); + } + + private string GetHealthModelsDataApiScope() + { + switch (_tenantService.CloudConfiguration.CloudType) + { + case AzureCloudConfiguration.AzureCloud.AzurePublicCloud: + return "https://data.healthmodels.azure.com/.default"; + case AzureCloudConfiguration.AzureCloud.AzureChinaCloud: + return "https://data.healthmodels.azure.cn/.default"; + case AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud: + return "https://data.healthmodels.azure.us/.default"; + default: + return "https://data.healthmodels.azure.com/.default"; + } + } } diff --git a/tools/Azure.Mcp.Tools.Monitor/src/Services/MonitorService.cs b/tools/Azure.Mcp.Tools.Monitor/src/Services/MonitorService.cs index 0dfc6450eb..5e21b0a199 100644 --- a/tools/Azure.Mcp.Tools.Monitor/src/Services/MonitorService.cs +++ b/tools/Azure.Mcp.Tools.Monitor/src/Services/MonitorService.cs @@ -7,6 +7,7 @@ using Azure.Core.Pipeline; using Azure.Mcp.Core.Options; using Azure.Mcp.Core.Services.Azure; +using Azure.Mcp.Core.Services.Azure.Authentication; using Azure.Mcp.Core.Services.Azure.ResourceGroup; using Azure.Mcp.Core.Services.Azure.Subscription; using Azure.Mcp.Core.Services.Azure.Tenant; @@ -27,8 +28,7 @@ public class MonitorService( IHttpClientFactory httpClientFactory) : BaseAzureService(tenantService), IMonitorService { private const string ActivityLogApiVersion = "2017-03-01-preview"; - private const string ActivityLogEndpointFormat - = "https://management.azure.com/subscriptions/{0}/providers/Microsoft.Insights/eventtypes/management/values"; + private readonly ITenantService _tenantService = tenantService ?? throw new ArgumentNullException(nameof(tenantService)); private readonly IHttpClientFactory _httpClientFactory = httpClientFactory ?? throw new ArgumentNullException(nameof(httpClientFactory)); public async Task> QueryResourceLogs( @@ -424,7 +424,7 @@ private async Task> CallActivityLogApiAsync( { var returnValue = new List(); - string endpoint = string.Format(ActivityLogEndpointFormat, subscriptionId); + string endpoint = string.Format(GetLogActivityEndpointString(), subscriptionId); var uriBuilder = new UriBuilder(endpoint); // Build the query parameters @@ -447,7 +447,7 @@ private async Task> CallActivityLogApiAsync( TokenCredential credential = await GetCredential(tenant, cancellationToken); AccessToken accessToken = await credential.GetTokenAsync( - new TokenRequestContext(["https://management.azure.com/.default"]), + new TokenRequestContext([_tenantService.CloudConfiguration.ArmEnvironment.DefaultScope]), cancellationToken); // Make paginated requests @@ -528,4 +528,19 @@ private static bool IsWorkspaceId(string workspace) return (matchingWorkspace.CustomerId, matchingWorkspace.Name); } + + private string GetLogActivityEndpointString() + { + switch (_tenantService.CloudConfiguration.CloudType) + { + case AzureCloudConfiguration.AzureCloud.AzurePublicCloud: + return "https://management.azure.com/subscriptions/{0}/providers/Microsoft.Insights/eventtypes/management/values"; + case AzureCloudConfiguration.AzureCloud.AzureChinaCloud: + return "https://management.chinacloudapi.cn/subscriptions/{0}/providers/Microsoft.Insights/eventtypes/management/values"; + case AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud: + return "https://management.usgovcloudapi.net/subscriptions/{0}/providers/Microsoft.Insights/eventtypes/management/values"; + default: + return "https://management.azure.com/subscriptions/{0}/providers/Microsoft.Insights/eventtypes/management/values"; + } + } } diff --git a/tools/Azure.Mcp.Tools.MySql/src/Services/MySqlService.cs b/tools/Azure.Mcp.Tools.MySql/src/Services/MySqlService.cs index c52c06314d..6446d81b65 100644 --- a/tools/Azure.Mcp.Tools.MySql/src/Services/MySqlService.cs +++ b/tools/Azure.Mcp.Tools.MySql/src/Services/MySqlService.cs @@ -4,6 +4,7 @@ using System.Text.RegularExpressions; using Azure.Core; using Azure.Mcp.Core.Services.Azure; +using Azure.Mcp.Core.Services.Azure.Authentication; using Azure.Mcp.Core.Services.Azure.ResourceGroup; using Azure.Mcp.Core.Services.Azure.Tenant; using Azure.Mcp.Tools.MySql.Commands; @@ -15,6 +16,7 @@ namespace Azure.Mcp.Tools.MySql.Services; public class MySqlService(IResourceGroupService resourceGroupService, ITenantService tenantService, ILogger logger) : BaseAzureService(tenantService), IMySqlService { + private readonly ITenantService _tenantService = tenantService ?? throw new ArgumentNullException(nameof(tenantService)); private readonly IResourceGroupService _resourceGroupService = resourceGroupService ?? throw new ArgumentNullException(nameof(resourceGroupService)); private readonly ILogger _logger = logger; @@ -66,18 +68,44 @@ public class MySqlService(IResourceGroupService resourceGroupService, ITenantSer private async Task GetEntraIdAccessTokenAsync(CancellationToken cancellationToken) { - var tokenRequestContext = new TokenRequestContext(["https://ossrdbms-aad.database.windows.net/.default"]); + + var tokenRequestContext = new TokenRequestContext([GetOpenSourceRDBMSEndpoint().ToString()]); TokenCredential tokenCredential = await GetCredential(cancellationToken); AccessToken accessToken = await tokenCredential .GetTokenAsync(tokenRequestContext, cancellationToken); return accessToken.Token; } - private static string NormalizeServerName(string server) + private Uri GetOpenSourceRDBMSEndpoint() + { + switch (_tenantService.CloudConfiguration.CloudType) + { + case AzureCloudConfiguration.AzureCloud.AzurePublicCloud: + return new Uri("https://ossrdbms-aad.database.windows.net/.default"); + case AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud: + return new Uri("https://ossrdbms-aad.database.usgovcloudapi.net/.default"); + case AzureCloudConfiguration.AzureCloud.AzureChinaCloud: + return new Uri("https://ossrdbms-aad.database.chinacloudapi.cn/.default"); + default: + return new Uri("https://ossrdbms-aad.database.windows.net/.default"); + } + } + + private string NormalizeServerName(string server) { if (!server.Contains('.')) { - return server + ".mysql.database.azure.com"; + switch (_tenantService.CloudConfiguration.CloudType) + { + case AzureCloudConfiguration.AzureCloud.AzurePublicCloud: + return server + ".mysql.database.azure.com"; + case AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud: + return server + ".mysql.database.usgovcloudapi.net"; + case AzureCloudConfiguration.AzureCloud.AzureChinaCloud: + return server + ".mysql.database.chinacloudapi.cn"; + default: + return server + ".mysql.database.azure.com"; + } } return server; } diff --git a/tools/Azure.Mcp.Tools.Postgres/src/Services/PostgresService.cs b/tools/Azure.Mcp.Tools.Postgres/src/Services/PostgresService.cs index f240fcb0e7..518cd4176f 100644 --- a/tools/Azure.Mcp.Tools.Postgres/src/Services/PostgresService.cs +++ b/tools/Azure.Mcp.Tools.Postgres/src/Services/PostgresService.cs @@ -6,6 +6,7 @@ using System.Net; using Azure.Core; using Azure.Mcp.Core.Services.Azure; +using Azure.Mcp.Core.Services.Azure.Authentication; using Azure.Mcp.Core.Services.Azure.ResourceGroup; using Azure.Mcp.Core.Services.Azure.Tenant; using Azure.Mcp.Tools.Postgres.Auth; @@ -21,6 +22,7 @@ namespace Azure.Mcp.Tools.Postgres.Services; public class PostgresService : BaseAzureService, IPostgresService { private readonly IResourceGroupService _resourceGroupService; + private readonly ITenantService _tenantService; private readonly IEntraTokenProvider _entraTokenAuth; private readonly IDbProvider _dbProvider; @@ -32,6 +34,7 @@ public PostgresService( : base(tenantService) { _resourceGroupService = resourceGroupService ?? throw new ArgumentNullException(nameof(resourceGroupService)); + _tenantService = tenantService ?? throw new ArgumentNullException(nameof(tenantService)); _entraTokenAuth = entraTokenAuth; _dbProvider = dbProvider; } @@ -44,11 +47,21 @@ private async Task GetEntraIdAccessTokenAsync(CancellationToken cancella return accessToken.Token; } - private static string NormalizeServerName(string server) + private string NormalizeServerName(string server) { if (!server.Contains('.')) { - return server + ".postgres.database.azure.com"; + switch (_tenantService.CloudConfiguration.CloudType) + { + case AzureCloudConfiguration.AzureCloud.AzurePublicCloud: + return server + ".postgres.database.azure.com"; + case AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud: + return server + ".postgres.database.usgovcloudapi.net"; + case AzureCloudConfiguration.AzureCloud.AzureChinaCloud: + return server + ".postgres.database.chinacloudapi.cn"; + default: + return server + ".postgres.database.azure.com"; + } } return server; } diff --git a/tools/Azure.Mcp.Tools.Pricing/src/Services/PricingService.cs b/tools/Azure.Mcp.Tools.Pricing/src/Services/PricingService.cs index 7a08bed696..4c743a1e3b 100644 --- a/tools/Azure.Mcp.Tools.Pricing/src/Services/PricingService.cs +++ b/tools/Azure.Mcp.Tools.Pricing/src/Services/PricingService.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using Azure.Mcp.Core.Services.Azure.Authentication; using Azure.Mcp.Tools.Pricing.Models; using AzureRetailPrices; @@ -9,7 +10,7 @@ namespace Azure.Mcp.Tools.Pricing.Services; /// /// Service implementation for Azure Retail Pricing operations. /// -public class PricingService : IPricingService +public class PricingService(IAzureCloudConfiguration cloudConfiguration) : IPricingService { private const int MaxResults = 5000; @@ -45,7 +46,7 @@ public async Task> GetPricesAsync( var clientOptions = new AzureRetailPricesClientOptions(serviceVersion); var client = new AzureRetailPricesClient( - new Uri("https://prices.azure.com"), + GetPricingEndpoint(), clientOptions); var retailPrices = client.GetRetailPricesClient(); @@ -122,6 +123,21 @@ private static string EscapeODataValue(string value) return value.Replace("'", "''"); } + private Uri GetPricingEndpoint() + { + switch (cloudConfiguration.CloudType) + { + case AzureCloudConfiguration.AzureCloud.AzurePublicCloud: + return new Uri("https://prices.azure.com"); + case AzureCloudConfiguration.AzureCloud.AzureChinaCloud: + return new Uri("https://prices.azure.cn"); + case AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud: + return new Uri("https://prices.azure.us"); + default: + return new Uri("https://prices.azure.com"); + } + } + private static PriceItem MapToPriceItem(RetailPriceItem item) { var priceItem = new PriceItem diff --git a/tools/Azure.Mcp.Tools.Quota/src/Services/QuotaService.cs b/tools/Azure.Mcp.Tools.Quota/src/Services/QuotaService.cs index 937c2ce5d1..d45a3e24b0 100644 --- a/tools/Azure.Mcp.Tools.Quota/src/Services/QuotaService.cs +++ b/tools/Azure.Mcp.Tools.Quota/src/Services/QuotaService.cs @@ -32,6 +32,7 @@ public async Task>> GetAzureQuotaAsync( resourceTypes, subscriptionId, location, + TenantService, loggerFactory, _httpClientFactory, cancellationToken); diff --git a/tools/Azure.Mcp.Tools.Quota/src/Services/Util/AzureUsageChecker.cs b/tools/Azure.Mcp.Tools.Quota/src/Services/Util/AzureUsageChecker.cs index 852b62c07d..f6487eb30b 100644 --- a/tools/Azure.Mcp.Tools.Quota/src/Services/Util/AzureUsageChecker.cs +++ b/tools/Azure.Mcp.Tools.Quota/src/Services/Util/AzureUsageChecker.cs @@ -3,6 +3,7 @@ using System.Net.Http; using Azure.Core; +using Azure.Mcp.Core.Services.Azure.Tenant; using Azure.Mcp.Tools.Quota.Services.Util.Usage; using Azure.ResourceManager; using Microsoft.Extensions.Logging; @@ -50,16 +51,28 @@ public abstract class AzureUsageChecker : IUsageChecker protected readonly ArmClient ResourceClient; protected readonly TokenCredential Credential; protected readonly ILogger Logger; - protected const string managementEndpoint = "https://management.azure.com"; + protected readonly ITenantService TenantService; - protected AzureUsageChecker(TokenCredential credential, string subscriptionId, ILogger logger) + protected AzureUsageChecker(TokenCredential credential, string subscriptionId, ILogger logger, ITenantService tenantService) { SubscriptionId = subscriptionId; Credential = credential ?? throw new ArgumentNullException(nameof(credential)); - ResourceClient = new ArmClient(credential, subscriptionId); + TenantService = tenantService ?? throw new ArgumentNullException(nameof(tenantService)); Logger = logger; + var clientOptions = new ArmClientOptions { Environment = tenantService.CloudConfiguration.ArmEnvironment }; + + ResourceClient = new ArmClient( + credential, + subscriptionId, + clientOptions); } + protected string GetManagementEndpoint() + { + return TenantService.CloudConfiguration.ArmEnvironment.Endpoint.ToString().TrimEnd('/'); + } + + public abstract Task> GetUsageForLocationAsync(string location, CancellationToken cancellationToken); } @@ -81,7 +94,7 @@ public static class UsageCheckerFactory { "Microsoft.ContainerInstance", ResourceProvider.ContainerInstance } }; - public static IUsageChecker CreateUsageChecker(TokenCredential credential, string provider, string subscriptionId, ILoggerFactory loggerFactory, IHttpClientFactory httpClientFactory) + public static IUsageChecker CreateUsageChecker(TokenCredential credential, string provider, string subscriptionId, ILoggerFactory loggerFactory, IHttpClientFactory httpClientFactory, ITenantService tenantService) { if (!ProviderMapping.TryGetValue(provider, out var resourceProvider)) { @@ -90,16 +103,16 @@ public static IUsageChecker CreateUsageChecker(TokenCredential credential, strin return resourceProvider switch { - ResourceProvider.Compute => new ComputeUsageChecker(credential, subscriptionId, loggerFactory.CreateLogger()), - ResourceProvider.CognitiveServices => new CognitiveServicesUsageChecker(credential, subscriptionId, loggerFactory.CreateLogger()), - ResourceProvider.Storage => new StorageUsageChecker(credential, subscriptionId, loggerFactory.CreateLogger()), - ResourceProvider.ContainerApp => new ContainerAppUsageChecker(credential, subscriptionId, loggerFactory.CreateLogger()), - ResourceProvider.Network => new NetworkUsageChecker(credential, subscriptionId, loggerFactory.CreateLogger()), - ResourceProvider.MachineLearning => new MachineLearningUsageChecker(credential, subscriptionId, loggerFactory.CreateLogger()), - ResourceProvider.PostgreSQL => new PostgreSQLUsageChecker(credential, subscriptionId, loggerFactory.CreateLogger(), httpClientFactory), - ResourceProvider.HDInsight => new HDInsightUsageChecker(credential, subscriptionId, loggerFactory.CreateLogger()), - ResourceProvider.Search => new SearchUsageChecker(credential, subscriptionId, loggerFactory.CreateLogger()), - ResourceProvider.ContainerInstance => new ContainerInstanceUsageChecker(credential, subscriptionId, loggerFactory.CreateLogger()), + ResourceProvider.Compute => new ComputeUsageChecker(credential, subscriptionId, loggerFactory.CreateLogger(), tenantService), + ResourceProvider.CognitiveServices => new CognitiveServicesUsageChecker(credential, subscriptionId, loggerFactory.CreateLogger(), tenantService), + ResourceProvider.Storage => new StorageUsageChecker(credential, subscriptionId, loggerFactory.CreateLogger(), tenantService), + ResourceProvider.ContainerApp => new ContainerAppUsageChecker(credential, subscriptionId, loggerFactory.CreateLogger(), tenantService), + ResourceProvider.Network => new NetworkUsageChecker(credential, subscriptionId, loggerFactory.CreateLogger(), tenantService), + ResourceProvider.MachineLearning => new MachineLearningUsageChecker(credential, subscriptionId, loggerFactory.CreateLogger(), tenantService), + ResourceProvider.PostgreSQL => new PostgreSQLUsageChecker(credential, subscriptionId, loggerFactory.CreateLogger(), httpClientFactory, tenantService), + ResourceProvider.HDInsight => new HDInsightUsageChecker(credential, subscriptionId, loggerFactory.CreateLogger(), tenantService), + ResourceProvider.Search => new SearchUsageChecker(credential, subscriptionId, loggerFactory.CreateLogger(), tenantService), + ResourceProvider.ContainerInstance => new ContainerInstanceUsageChecker(credential, subscriptionId, loggerFactory.CreateLogger(), tenantService), _ => throw new ArgumentException($"No implementation for provider: {provider}") }; } @@ -113,6 +126,7 @@ public static async Task>> GetAzureQuotaAsync List resourceTypes, string subscriptionId, string location, + ITenantService tenantService, ILoggerFactory loggerFactory, IHttpClientFactory httpClientFactory, CancellationToken cancellationToken) @@ -130,7 +144,7 @@ public static async Task>> GetAzureQuotaAsync var (provider, resourceTypesForProvider) = (kvp.Key, kvp.Value); try { - var usageChecker = UsageCheckerFactory.CreateUsageChecker(credential, provider, subscriptionId, loggerFactory, httpClientFactory); + var usageChecker = UsageCheckerFactory.CreateUsageChecker(credential, provider, subscriptionId, loggerFactory, httpClientFactory, tenantService); var quotaInfo = await usageChecker.GetUsageForLocationAsync(location, cancellationToken); logger.LogDebug("Retrieved quota info for provider {Provider}: {ItemCount} items", provider, quotaInfo.Count); diff --git a/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/CognitiveServicesUsageChecker.cs b/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/CognitiveServicesUsageChecker.cs index 429cf67c41..5de5f9f00a 100644 --- a/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/CognitiveServicesUsageChecker.cs +++ b/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/CognitiveServicesUsageChecker.cs @@ -2,13 +2,14 @@ // Licensed under the MIT License. using Azure.Core; +using Azure.Mcp.Core.Services.Azure.Tenant; using Azure.ResourceManager.CognitiveServices; using Azure.ResourceManager.CognitiveServices.Models; using Microsoft.Extensions.Logging; namespace Azure.Mcp.Tools.Quota.Services.Util.Usage; -public class CognitiveServicesUsageChecker(TokenCredential credential, string subscriptionId, ILogger logger) : AzureUsageChecker(credential, subscriptionId, logger) +public class CognitiveServicesUsageChecker(TokenCredential credential, string subscriptionId, ILogger logger, ITenantService tenantService) : AzureUsageChecker(credential, subscriptionId, logger, tenantService) { public override async Task> GetUsageForLocationAsync(string location, CancellationToken cancellationToken) { diff --git a/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/ComputeUsageChecker.cs b/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/ComputeUsageChecker.cs index 4117e7ff75..e0fea2905f 100644 --- a/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/ComputeUsageChecker.cs +++ b/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/ComputeUsageChecker.cs @@ -2,13 +2,14 @@ // Licensed under the MIT License. using Azure.Core; +using Azure.Mcp.Core.Services.Azure.Tenant; using Azure.ResourceManager.Compute; using Azure.ResourceManager.Compute.Models; using Microsoft.Extensions.Logging; namespace Azure.Mcp.Tools.Quota.Services.Util.Usage; -public class ComputeUsageChecker(TokenCredential credential, string subscriptionId, ILogger logger) : AzureUsageChecker(credential, subscriptionId, logger) +public class ComputeUsageChecker(TokenCredential credential, string subscriptionId, ILogger logger, ITenantService tenantService) : AzureUsageChecker(credential, subscriptionId, logger, tenantService) { public override async Task> GetUsageForLocationAsync(string location, CancellationToken cancellationToken) { diff --git a/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/ContainerAppUsageChecker.cs b/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/ContainerAppUsageChecker.cs index 012e2ce5c6..ea96a490f1 100644 --- a/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/ContainerAppUsageChecker.cs +++ b/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/ContainerAppUsageChecker.cs @@ -2,12 +2,13 @@ // Licensed under the MIT License. using Azure.Core; +using Azure.Mcp.Core.Services.Azure.Tenant; using Azure.ResourceManager.AppContainers; using Microsoft.Extensions.Logging; namespace Azure.Mcp.Tools.Quota.Services.Util.Usage; -public class ContainerAppUsageChecker(TokenCredential credential, string subscriptionId, ILogger logger) : AzureUsageChecker(credential, subscriptionId, logger) +public class ContainerAppUsageChecker(TokenCredential credential, string subscriptionId, ILogger logger, ITenantService tenantService) : AzureUsageChecker(credential, subscriptionId, logger, tenantService) { public override async Task> GetUsageForLocationAsync(string location, CancellationToken cancellationToken) { diff --git a/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/ContainerInstanceUsageChecker.cs b/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/ContainerInstanceUsageChecker.cs index cb31d483bb..4e4329acf8 100644 --- a/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/ContainerInstanceUsageChecker.cs +++ b/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/ContainerInstanceUsageChecker.cs @@ -2,13 +2,14 @@ // Licensed under the MIT License. using Azure.Core; +using Azure.Mcp.Core.Services.Azure.Tenant; using Azure.ResourceManager.ContainerInstance; using Azure.ResourceManager.ContainerInstance.Models; using Microsoft.Extensions.Logging; namespace Azure.Mcp.Tools.Quota.Services.Util.Usage; -public class ContainerInstanceUsageChecker(TokenCredential credential, string subscriptionId, ILogger logger) : AzureUsageChecker(credential, subscriptionId, logger) +public class ContainerInstanceUsageChecker(TokenCredential credential, string subscriptionId, ILogger logger, ITenantService tenantService) : AzureUsageChecker(credential, subscriptionId, logger, tenantService) { public override async Task> GetUsageForLocationAsync(string location, CancellationToken cancellationToken) { diff --git a/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/HDInsightUsageChecker.cs b/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/HDInsightUsageChecker.cs index c98d4398df..c3861ca301 100644 --- a/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/HDInsightUsageChecker.cs +++ b/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/HDInsightUsageChecker.cs @@ -2,13 +2,14 @@ // Licensed under the MIT License. using Azure.Core; +using Azure.Mcp.Core.Services.Azure.Tenant; using Azure.ResourceManager.HDInsight; using Azure.ResourceManager.HDInsight.Models; using Microsoft.Extensions.Logging; namespace Azure.Mcp.Tools.Quota.Services.Util.Usage; -public class HDInsightUsageChecker(TokenCredential credential, string subscriptionId, ILogger logger) : AzureUsageChecker(credential, subscriptionId, logger) +public class HDInsightUsageChecker(TokenCredential credential, string subscriptionId, ILogger logger, ITenantService tenantService) : AzureUsageChecker(credential, subscriptionId, logger, tenantService) { public override async Task> GetUsageForLocationAsync(string location, CancellationToken cancellationToken) { diff --git a/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/MachineLearningUsageChecker.cs b/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/MachineLearningUsageChecker.cs index e03d7a438a..05a0517998 100644 --- a/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/MachineLearningUsageChecker.cs +++ b/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/MachineLearningUsageChecker.cs @@ -2,12 +2,13 @@ // Licensed under the MIT License. using Azure.Core; +using Azure.Mcp.Core.Services.Azure.Tenant; using Azure.ResourceManager.MachineLearning; using Microsoft.Extensions.Logging; namespace Azure.Mcp.Tools.Quota.Services.Util.Usage; -public class MachineLearningUsageChecker(TokenCredential credential, string subscriptionId, ILogger logger) : AzureUsageChecker(credential, subscriptionId, logger) +public class MachineLearningUsageChecker(TokenCredential credential, string subscriptionId, ILogger logger, ITenantService tenantService) : AzureUsageChecker(credential, subscriptionId, logger, tenantService) { public override async Task> GetUsageForLocationAsync(string location, CancellationToken cancellationToken) { diff --git a/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/NetworkUsageChecker.cs b/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/NetworkUsageChecker.cs index ed2894090f..5ed2d66e36 100644 --- a/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/NetworkUsageChecker.cs +++ b/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/NetworkUsageChecker.cs @@ -2,12 +2,13 @@ // Licensed under the MIT License. using Azure.Core; +using Azure.Mcp.Core.Services.Azure.Tenant; using Azure.ResourceManager.Network; using Microsoft.Extensions.Logging; namespace Azure.Mcp.Tools.Quota.Services.Util.Usage; -public class NetworkUsageChecker(TokenCredential credential, string subscriptionId, ILogger logger) : AzureUsageChecker(credential, subscriptionId, logger) +public class NetworkUsageChecker(TokenCredential credential, string subscriptionId, ILogger logger, ITenantService tenantService) : AzureUsageChecker(credential, subscriptionId, logger, tenantService) { public override async Task> GetUsageForLocationAsync(string location, CancellationToken cancellationToken) { diff --git a/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/PostgreSQLUsageChecker.cs b/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/PostgreSQLUsageChecker.cs index 15bda15836..fa02567efd 100644 --- a/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/PostgreSQLUsageChecker.cs +++ b/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/PostgreSQLUsageChecker.cs @@ -4,19 +4,21 @@ using System.Net.Http; using System.Net.Http.Headers; using Azure.Core; +using Azure.Mcp.Core.Services.Azure.Tenant; using Microsoft.Extensions.Logging; namespace Azure.Mcp.Tools.Quota.Services.Util.Usage; -public class PostgreSQLUsageChecker(TokenCredential credential, string subscriptionId, ILogger logger, IHttpClientFactory httpClientFactory) : AzureUsageChecker(credential, subscriptionId, logger) +public class PostgreSQLUsageChecker(TokenCredential credential, string subscriptionId, ILogger logger, IHttpClientFactory httpClientFactory, ITenantService tenantService) : AzureUsageChecker(credential, subscriptionId, logger, tenantService) { private readonly IHttpClientFactory _httpClientFactory = httpClientFactory ?? throw new ArgumentNullException(nameof(httpClientFactory)); + private readonly ITenantService _tenantService = tenantService ?? throw new ArgumentNullException(nameof(tenantService)); public override async Task> GetUsageForLocationAsync(string location, CancellationToken cancellationToken) { try { - var requestUrl = $"{managementEndpoint}/subscriptions/{SubscriptionId}/providers/Microsoft.DBforPostgreSQL/locations/{location}/resourceType/flexibleServers/usages?api-version=2023-06-01-preview"; + var requestUrl = $"{GetManagementEndpoint()}/subscriptions/{SubscriptionId}/providers/Microsoft.DBforPostgreSQL/locations/{location}/resourceType/flexibleServers/usages?api-version=2023-06-01-preview"; using var rawResponse = await GetQuotaByUrlAsync(requestUrl, cancellationToken); if (rawResponse?.RootElement.TryGetProperty("value", out var valueElement) != true) @@ -67,7 +69,7 @@ public override async Task> GetUsageForLocationAsync(string loca { try { - var token = await Credential.GetTokenAsync(new TokenRequestContext([$"{managementEndpoint}/.default"]), cancellationToken); + var token = await Credential.GetTokenAsync(new TokenRequestContext([_tenantService.CloudConfiguration.ArmEnvironment.DefaultScope]), cancellationToken); using var request = new HttpRequestMessage(HttpMethod.Get, requestUrl); request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", token.Token); diff --git a/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/SearchUsageChecker.cs b/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/SearchUsageChecker.cs index 9f0aef545b..b4bb69f25f 100644 --- a/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/SearchUsageChecker.cs +++ b/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/SearchUsageChecker.cs @@ -2,13 +2,14 @@ // Licensed under the MIT License. using Azure.Core; +using Azure.Mcp.Core.Services.Azure.Tenant; using Azure.ResourceManager.Search; using Azure.ResourceManager.Search.Models; using Microsoft.Extensions.Logging; namespace Azure.Mcp.Tools.Quota.Services.Util.Usage; -public class SearchUsageChecker(TokenCredential credential, string subscriptionId, ILogger logger) : AzureUsageChecker(credential, subscriptionId, logger) +public class SearchUsageChecker(TokenCredential credential, string subscriptionId, ILogger logger, ITenantService tenantService) : AzureUsageChecker(credential, subscriptionId, logger, tenantService) { public override async Task> GetUsageForLocationAsync(string location, CancellationToken cancellationToken) { diff --git a/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/StorageUsageChecker.cs b/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/StorageUsageChecker.cs index 86f44bad01..0e356988a5 100644 --- a/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/StorageUsageChecker.cs +++ b/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/StorageUsageChecker.cs @@ -2,12 +2,13 @@ // Licensed under the MIT License. using Azure.Core; +using Azure.Mcp.Core.Services.Azure.Tenant; using Azure.ResourceManager.Storage; using Microsoft.Extensions.Logging; namespace Azure.Mcp.Tools.Quota.Services.Util.Usage; -public class StorageUsageChecker(TokenCredential credential, string subscriptionId, ILogger logger) : AzureUsageChecker(credential, subscriptionId, logger) +public class StorageUsageChecker(TokenCredential credential, string subscriptionId, ILogger logger, ITenantService tenantService) : AzureUsageChecker(credential, subscriptionId, logger, tenantService) { public override async Task> GetUsageForLocationAsync(string location, CancellationToken cancellationToken) { diff --git a/tools/Azure.Mcp.Tools.Quota/tests/Azure.Mcp.Tools.Quota.LiveTests/assets.json b/tools/Azure.Mcp.Tools.Quota/tests/Azure.Mcp.Tools.Quota.LiveTests/assets.json index 1bb30ae96f..a2554d5cfb 100644 --- a/tools/Azure.Mcp.Tools.Quota/tests/Azure.Mcp.Tools.Quota.LiveTests/assets.json +++ b/tools/Azure.Mcp.Tools.Quota/tests/Azure.Mcp.Tools.Quota.LiveTests/assets.json @@ -2,5 +2,5 @@ "AssetsRepo": "Azure/azure-sdk-assets", "AssetsRepoPrefixPath": "", "TagPrefix": "Azure.Mcp.Tools.Quota.LiveTests", - "Tag": "Azure.Mcp.Tools.Quota.LiveTests_1a91ae82bf" + "Tag": "Azure.Mcp.Tools.Quota.LiveTests_663924bd7c" } diff --git a/tools/Azure.Mcp.Tools.ResourceHealth/src/Services/ResourceHealthService.cs b/tools/Azure.Mcp.Tools.ResourceHealth/src/Services/ResourceHealthService.cs index 9c87ff1f75..16f6cb7aa8 100644 --- a/tools/Azure.Mcp.Tools.ResourceHealth/src/Services/ResourceHealthService.cs +++ b/tools/Azure.Mcp.Tools.ResourceHealth/src/Services/ResourceHealthService.cs @@ -16,9 +16,9 @@ public class ResourceHealthService(ISubscriptionService subscriptionService, ITe : BaseAzureService(tenantService), IResourceHealthService { private readonly ISubscriptionService _subscriptionService = subscriptionService ?? throw new ArgumentNullException(nameof(subscriptionService)); + private readonly ITenantService _tenantService = tenantService ?? throw new ArgumentNullException(nameof(tenantService)); private readonly IHttpClientFactory _httpClientFactory = httpClientFactory ?? throw new ArgumentNullException(nameof(httpClientFactory)); - private const string AzureManagementBaseUrl = "https://management.azure.com"; private const string ResourceHealthApiVersion = "2025-05-01"; public async Task GetAvailabilityStatusAsync( @@ -33,18 +33,19 @@ public async Task GetAvailabilityStatusAsync( try { + var managementEndpoint = _tenantService.CloudConfiguration.ArmEnvironment.Endpoint ?? throw new InvalidOperationException("Management endpoint is not configured."); + var credential = await GetCredential(cancellationToken); var token = await credential.GetTokenAsync( - new TokenRequestContext([$"{AzureManagementBaseUrl}/.default"]), + new TokenRequestContext([_tenantService.CloudConfiguration.ArmEnvironment.DefaultScope]), cancellationToken); var client = _httpClientFactory.CreateClient(); client.DefaultRequestHeaders.Authorization = new System.Net.Http.Headers.AuthenticationHeaderValue("Bearer", token.Token); // Construct URL safely using Uri to ensure path is relative to base - var baseUri = new Uri(AzureManagementBaseUrl); var relativePath = $"{parsedResourceId}/providers/Microsoft.ResourceHealth/availabilityStatuses/current?api-version={ResourceHealthApiVersion}"; - var requestUri = new Uri(baseUri, relativePath); + var requestUri = new Uri(managementEndpoint, relativePath); using var response = await client.GetAsync(requestUri, cancellationToken); response.EnsureSuccessStatusCode(); @@ -83,20 +84,20 @@ public async Task> ListAvailabilityStatusesAsync( var subscriptionResource = await _subscriptionService.GetSubscription(subscription, tenant, retryPolicy, cancellationToken); var subscriptionId = subscriptionResource.Id.SubscriptionId; + var managementEndpoint = _tenantService.CloudConfiguration.ArmEnvironment.Endpoint; var credential = await GetCredential(cancellationToken); var token = await credential.GetTokenAsync( - new TokenRequestContext([$"{AzureManagementBaseUrl}/.default"]), + new TokenRequestContext([_tenantService.CloudConfiguration.ArmEnvironment.DefaultScope]), cancellationToken); var client = _httpClientFactory.CreateClient(); client.DefaultRequestHeaders.Authorization = new System.Net.Http.Headers.AuthenticationHeaderValue("Bearer", token.Token); // Construct URL safely using Uri to ensure path is relative to base - var baseUri = new Uri(AzureManagementBaseUrl); var relativePath = resourceGroup != null ? $"/subscriptions/{subscriptionId}/resourceGroups/{resourceGroup}/providers/Microsoft.ResourceHealth/availabilityStatuses?api-version={ResourceHealthApiVersion}" : $"/subscriptions/{subscriptionId}/providers/Microsoft.ResourceHealth/availabilityStatuses?api-version={ResourceHealthApiVersion}"; - var requestUri = new Uri(baseUri, relativePath); + var requestUri = new Uri(managementEndpoint, relativePath); using var response = await client.GetAsync(requestUri, cancellationToken); response.EnsureSuccessStatusCode(); @@ -140,9 +141,11 @@ public async Task> ListServiceHealthEventsAsync( var subscriptionResource = await _subscriptionService.GetSubscription(subscription, tenant, retryPolicy, cancellationToken); var subscriptionId = subscriptionResource.Id.SubscriptionId; + var managementEndpoint = _tenantService.CloudConfiguration.ArmEnvironment.Endpoint; + var credential = await GetCredential(cancellationToken); var token = await credential.GetTokenAsync( - new TokenRequestContext([$"{AzureManagementBaseUrl}/.default"]), + new TokenRequestContext([_tenantService.CloudConfiguration.ArmEnvironment.DefaultScope]), cancellationToken); var client = _httpClientFactory.CreateClient(); @@ -196,8 +199,7 @@ public async Task> ListServiceHealthEventsAsync( } // Construct URL safely using Uri to ensure path is relative to base - var baseUri = new Uri(AzureManagementBaseUrl); - var requestUri = new Uri(baseUri, relativePath); + var requestUri = new Uri(managementEndpoint, relativePath); using var response = await client.GetAsync(requestUri, cancellationToken); response.EnsureSuccessStatusCode(); diff --git a/tools/Azure.Mcp.Tools.ResourceHealth/tests/Azure.Mcp.Tools.ResourceHealth.UnitTests/Services/ResourceHealthServiceSsrfValidationTests.cs b/tools/Azure.Mcp.Tools.ResourceHealth/tests/Azure.Mcp.Tools.ResourceHealth.UnitTests/Services/ResourceHealthServiceSsrfValidationTests.cs index 56bcaac47d..cde41d0bce 100644 --- a/tools/Azure.Mcp.Tools.ResourceHealth/tests/Azure.Mcp.Tools.ResourceHealth.UnitTests/Services/ResourceHealthServiceSsrfValidationTests.cs +++ b/tools/Azure.Mcp.Tools.ResourceHealth/tests/Azure.Mcp.Tools.ResourceHealth.UnitTests/Services/ResourceHealthServiceSsrfValidationTests.cs @@ -3,9 +3,11 @@ using System.Net; using Azure.Core; +using Azure.Mcp.Core.Services.Azure.Authentication; using Azure.Mcp.Core.Services.Azure.Subscription; using Azure.Mcp.Core.Services.Azure.Tenant; using Azure.Mcp.Tools.ResourceHealth.Services; +using Azure.ResourceManager; using NSubstitute; using Xunit; @@ -33,6 +35,12 @@ public ResourceHealthServiceSsrfValidationTests() private void SetupMocksForValidRequest(HttpResponseMessage response) { + // Mock CloudConfiguration to return a valid ArmEnvironment + var cloudConfig = Substitute.For(); + cloudConfig.ArmEnvironment.Returns(ArmEnvironment.AzurePublicCloud); + cloudConfig.AuthorityHost.Returns(new Uri("https://login.microsoftonline.com")); + _tenantService.CloudConfiguration.Returns(cloudConfig); + // Mock TokenCredential var mockCredential = Substitute.For(); mockCredential.GetTokenAsync(Arg.Any(), Arg.Any()) diff --git a/tools/Azure.Mcp.Tools.Search/src/Services/SearchService.cs b/tools/Azure.Mcp.Tools.Search/src/Services/SearchService.cs index 01ac165f6e..b0526edbd9 100644 --- a/tools/Azure.Mcp.Tools.Search/src/Services/SearchService.cs +++ b/tools/Azure.Mcp.Tools.Search/src/Services/SearchService.cs @@ -5,6 +5,7 @@ using Azure.Core.Pipeline; using Azure.Mcp.Core.Options; using Azure.Mcp.Core.Services.Azure; +using Azure.Mcp.Core.Services.Azure.Authentication; using Azure.Mcp.Core.Services.Azure.Subscription; using Azure.Mcp.Core.Services.Azure.Tenant; using Azure.Mcp.Core.Services.Caching; @@ -26,6 +27,7 @@ public sealed class SearchService( ITenantService tenantService) : BaseAzureService(tenantService), ISearchService { + private readonly ITenantService _tenantService = tenantService ?? throw new ArgumentNullException(nameof(tenantService)); private readonly ISubscriptionService _subscriptionService = subscriptionService ?? throw new ArgumentNullException(nameof(subscriptionService)); private readonly ICacheService _cacheService = cacheService ?? throw new ArgumentNullException(nameof(cacheService)); private const string CacheGroup = "search"; @@ -369,7 +371,7 @@ private async Task GetSearchIndexClient(string serviceName, R clientOptions.Transport = new HttpClientTransport(TenantService.GetClient()); ConfigureRetryPolicy(clientOptions, retryPolicy); - var endpoint = new Uri($"https://{serviceName}.search.windows.net"); + var endpoint = GetSearchEndpoint(serviceName); searchClient = new SearchIndexClient(endpoint, credential, clientOptions); await _cacheService.SetAsync(CacheGroup, key, searchClient, s_cacheDurationClients, cancellationToken); } @@ -421,4 +423,19 @@ private static IndexInfo MapToIndexInfo(SearchIndex index) private static FieldInfo MapToFieldInfo(SearchField field) => new(field.Name, field.Type.ToString(), field.IsKey, field.IsSearchable, field.IsFilterable, field.IsSortable, field.IsFacetable, field.IsHidden != true); + + private Uri GetSearchEndpoint(string serviceName) + { + switch (_tenantService.CloudConfiguration.CloudType) + { + case AzureCloudConfiguration.AzureCloud.AzurePublicCloud: + return new Uri($"https://{serviceName}.search.windows.net"); + case AzureCloudConfiguration.AzureCloud.AzureChinaCloud: + return new Uri($"https://{serviceName}.search.azure.cn"); + case AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud: + return new Uri($"https://{serviceName}.search.azure.us"); + default: + return new Uri($"https://{serviceName}.search.windows.net"); + } + } } diff --git a/tools/Azure.Mcp.Tools.ServiceFabric/src/Services/ServiceFabricService.cs b/tools/Azure.Mcp.Tools.ServiceFabric/src/Services/ServiceFabricService.cs index 69448b463b..8bacfddb41 100644 --- a/tools/Azure.Mcp.Tools.ServiceFabric/src/Services/ServiceFabricService.cs +++ b/tools/Azure.Mcp.Tools.ServiceFabric/src/Services/ServiceFabricService.cs @@ -20,11 +20,14 @@ public sealed class ServiceFabricService( IHttpClientFactory httpClientFactory) : BaseAzureService(tenantService), IServiceFabricService { private readonly ISubscriptionService _subscriptionService = subscriptionService ?? throw new ArgumentNullException(nameof(subscriptionService)); + private readonly ITenantService _tenantService = tenantService ?? throw new ArgumentNullException(nameof(tenantService)); private readonly IHttpClientFactory _httpClientFactory = httpClientFactory ?? throw new ArgumentNullException(nameof(httpClientFactory)); - private const string AzureManagementBaseUrl = "https://management.azure.com"; private const string ApiVersion = "2024-04-01"; + private string GetManagementBaseUrl() => + _tenantService.CloudConfiguration.ArmEnvironment.Endpoint.ToString().TrimEnd('/'); + public async Task> ListManagedClusterNodes( string subscription, string resourceGroup, @@ -43,13 +46,13 @@ public async Task> ListManagedClusterNodes( var credential = await GetCredential(tenant, cancellationToken); var token = await credential.GetTokenAsync( - new TokenRequestContext([$"{AzureManagementBaseUrl}/.default"]), + new TokenRequestContext([_tenantService.CloudConfiguration.ArmEnvironment.DefaultScope]), cancellationToken); var client = _httpClientFactory.CreateClient(); client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", token.Token); - var requestUrl = $"{AzureManagementBaseUrl}/subscriptions/{subscriptionId}/resourceGroups/{Uri.EscapeDataString(resourceGroup)}/providers/Microsoft.ServiceFabric/managedClusters/{Uri.EscapeDataString(clusterName)}/nodes?api-version={ApiVersion}"; + var requestUrl = $"{GetManagementBaseUrl()}/subscriptions/{subscriptionId}/resourceGroups/{Uri.EscapeDataString(resourceGroup)}/providers/Microsoft.ServiceFabric/managedClusters/{Uri.EscapeDataString(clusterName)}/nodes?api-version={ApiVersion}"; var allNodes = new List(); @@ -93,13 +96,13 @@ public async Task GetManagedClusterNode( var credential = await GetCredential(tenant, cancellationToken); var token = await credential.GetTokenAsync( - new TokenRequestContext([$"{AzureManagementBaseUrl}/.default"]), + new TokenRequestContext([_tenantService.CloudConfiguration.ArmEnvironment.DefaultScope]), cancellationToken); var client = _httpClientFactory.CreateClient(); client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", token.Token); - var requestUrl = $"{AzureManagementBaseUrl}/subscriptions/{subscriptionId}/resourceGroups/{Uri.EscapeDataString(resourceGroup)}/providers/Microsoft.ServiceFabric/managedClusters/{Uri.EscapeDataString(clusterName)}/nodes/{Uri.EscapeDataString(nodeName)}?api-version={ApiVersion}"; + var requestUrl = $"{GetManagementBaseUrl()}/subscriptions/{subscriptionId}/resourceGroups/{Uri.EscapeDataString(resourceGroup)}/providers/Microsoft.ServiceFabric/managedClusters/{Uri.EscapeDataString(clusterName)}/nodes/{Uri.EscapeDataString(nodeName)}?api-version={ApiVersion}"; using var response = await client.GetAsync(requestUrl, cancellationToken); response.EnsureSuccessStatusCode(); @@ -137,13 +140,13 @@ public async Task RestartManagedClusterNodes( var credential = await GetCredential(tenant, cancellationToken); var token = await credential.GetTokenAsync( - new TokenRequestContext([$"{AzureManagementBaseUrl}/.default"]), + new TokenRequestContext([_tenantService.CloudConfiguration.ArmEnvironment.DefaultScope]), cancellationToken); var client = _httpClientFactory.CreateClient(); client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", token.Token); - var requestUrl = $"{AzureManagementBaseUrl}/subscriptions/{subscriptionId}/resourceGroups/{Uri.EscapeDataString(resourceGroup)}/providers/Microsoft.ServiceFabric/managedClusters/{Uri.EscapeDataString(clusterName)}/nodeTypes/{Uri.EscapeDataString(nodeType)}/restart?api-version={ApiVersion}"; + var requestUrl = $"{GetManagementBaseUrl()}/subscriptions/{subscriptionId}/resourceGroups/{Uri.EscapeDataString(resourceGroup)}/providers/Microsoft.ServiceFabric/managedClusters/{Uri.EscapeDataString(clusterName)}/nodeTypes/{Uri.EscapeDataString(nodeType)}/restart?api-version={ApiVersion}"; var requestBody = new RestartNodeRequest { diff --git a/tools/Azure.Mcp.Tools.Speech/src/Services/Recognizers/FastTranscriptionRecognizer.cs b/tools/Azure.Mcp.Tools.Speech/src/Services/Recognizers/FastTranscriptionRecognizer.cs index 6db63e7b1f..26f4920afc 100644 --- a/tools/Azure.Mcp.Tools.Speech/src/Services/Recognizers/FastTranscriptionRecognizer.cs +++ b/tools/Azure.Mcp.Tools.Speech/src/Services/Recognizers/FastTranscriptionRecognizer.cs @@ -7,6 +7,7 @@ using Azure.Core; using Azure.Mcp.Core.Options; using Azure.Mcp.Core.Services.Azure; +using Azure.Mcp.Core.Services.Azure.Authentication; using Azure.Mcp.Core.Services.Azure.Tenant; using Azure.Mcp.Tools.Speech.Models.FastTranscription; using Microsoft.Extensions.Logging; @@ -23,6 +24,7 @@ public class FastTranscriptionRecognizer( : BaseAzureService(tenantService), IFastTranscriptionRecognizer { private readonly ILogger _logger = logger; + private readonly ITenantService _tenantService = tenantService; private readonly IHttpClientFactory _httpClientFactory = httpClientFactory; /// @@ -65,7 +67,7 @@ public async Task RecognizeAsync( var credential = await GetCredential(cancellationToken); // Get access token for Cognitive Services with proper scope - var tokenRequestContext = new TokenRequestContext(["https://cognitiveservices.azure.com/.default"]); + var tokenRequestContext = new TokenRequestContext([GetCognitiveServicesScope()]); var accessToken = await credential.GetTokenAsync(tokenRequestContext, cancellationToken); // Build the Fast Transcription API URL @@ -245,4 +247,19 @@ private static string GetMimeType(string filePath) _ => "application/octet-stream" }; } + + private string GetCognitiveServicesScope() + { + switch (_tenantService.CloudConfiguration.CloudType) + { + case AzureCloudConfiguration.AzureCloud.AzurePublicCloud: + return "https://cognitiveservices.azure.com/.default"; + case AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud: + return "https://cognitiveservices.azure.us/.default"; + case AzureCloudConfiguration.AzureCloud.AzureChinaCloud: + return "https://cognitiveservices.azure.cn/.default"; + default: + return "https://cognitiveservices.azure.com/.default"; + } + } } diff --git a/tools/Azure.Mcp.Tools.Speech/src/Services/Recognizers/RealtimeTranscriptionRecognizer.cs b/tools/Azure.Mcp.Tools.Speech/src/Services/Recognizers/RealtimeTranscriptionRecognizer.cs index 8cc6927ff7..0f609fa988 100644 --- a/tools/Azure.Mcp.Tools.Speech/src/Services/Recognizers/RealtimeTranscriptionRecognizer.cs +++ b/tools/Azure.Mcp.Tools.Speech/src/Services/Recognizers/RealtimeTranscriptionRecognizer.cs @@ -4,6 +4,7 @@ using Azure.Core; using Azure.Mcp.Core.Options; using Azure.Mcp.Core.Services.Azure; +using Azure.Mcp.Core.Services.Azure.Authentication; using Azure.Mcp.Core.Services.Azure.Tenant; using Azure.Mcp.Tools.Speech.Models.Realtime; using Microsoft.CognitiveServices.Speech; @@ -19,6 +20,7 @@ namespace Azure.Mcp.Tools.Speech.Services.Recognizers; public class RealtimeTranscriptionRecognizer(ITenantService tenantService, ILogger logger) : BaseAzureService(tenantService), IRealtimeTranscriptionRecognizer { + private readonly ITenantService _tenantService = tenantService; private readonly ILogger _logger = logger; /// @@ -61,7 +63,7 @@ public async Task RecognizeAsync( var credential = await GetCredential(cancellationToken); // Get access token for Cognitive Services with proper scope - var tokenRequestContext = new TokenRequestContext(["https://cognitiveservices.azure.com/.default"]); + var tokenRequestContext = new TokenRequestContext([GetCognitiveServicesScope()]); var accessToken = await credential.GetTokenAsync(tokenRequestContext, cancellationToken); // Configure Speech SDK with endpoint @@ -496,4 +498,19 @@ private static List ExtractNBestResults(SdkSpeec return nbestResults; } + + private string GetCognitiveServicesScope() + { + switch (_tenantService.CloudConfiguration.CloudType) + { + case AzureCloudConfiguration.AzureCloud.AzurePublicCloud: + return "https://cognitiveservices.azure.com/.default"; + case AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud: + return "https://cognitiveservices.azure.us/.default"; + case AzureCloudConfiguration.AzureCloud.AzureChinaCloud: + return "https://cognitiveservices.azure.cn/.default"; + default: + return "https://cognitiveservices.azure.com/.default"; + } + } } diff --git a/tools/Azure.Mcp.Tools.Speech/src/Services/Synthesizers/RealtimeTtsSynthesizer.cs b/tools/Azure.Mcp.Tools.Speech/src/Services/Synthesizers/RealtimeTtsSynthesizer.cs index 07def0f58f..eacbf1dc56 100644 --- a/tools/Azure.Mcp.Tools.Speech/src/Services/Synthesizers/RealtimeTtsSynthesizer.cs +++ b/tools/Azure.Mcp.Tools.Speech/src/Services/Synthesizers/RealtimeTtsSynthesizer.cs @@ -4,6 +4,7 @@ using Azure.Core; using Azure.Mcp.Core.Options; using Azure.Mcp.Core.Services.Azure; +using Azure.Mcp.Core.Services.Azure.Authentication; using Azure.Mcp.Core.Services.Azure.Tenant; using Azure.Mcp.Tools.Speech.Models; using Microsoft.CognitiveServices.Speech; @@ -19,6 +20,7 @@ namespace Azure.Mcp.Tools.Speech.Services.Synthesizers; public class RealtimeTtsSynthesizer(ITenantService tenantService, ILogger logger) : BaseAzureService(tenantService), IRealtimeTtsSynthesizer { + private readonly ITenantService _tenantService = tenantService; private readonly ILogger _logger = logger; /// @@ -102,7 +104,7 @@ public async Task SynthesizeToFileAsync( var credential = await GetCredential(cancellationToken); // Get access token for Cognitive Services with proper scope - var tokenRequestContext = new TokenRequestContext(["https://cognitiveservices.azure.com/.default"]); + var tokenRequestContext = new TokenRequestContext([GetCognitiveServicesScope()]); var accessToken = await credential.GetTokenAsync(tokenRequestContext, cancellationToken); // Convert https endpoint to wss for WebSocket-based TTS @@ -275,4 +277,19 @@ private static SpeechSynthesisOutputFormat ParseOutputFormat(string? format) // If parsing fails, default to Riff24Khz16BitMonoPcm return SpeechSynthesisOutputFormat.Riff24Khz16BitMonoPcm; } + + private string GetCognitiveServicesScope() + { + switch (_tenantService.CloudConfiguration.CloudType) + { + case AzureCloudConfiguration.AzureCloud.AzurePublicCloud: + return "https://cognitiveservices.azure.com/.default"; + case AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud: + return "https://cognitiveservices.azure.us/.default"; + case AzureCloudConfiguration.AzureCloud.AzureChinaCloud: + return "https://cognitiveservices.azure.cn/.default"; + default: + return "https://cognitiveservices.azure.com/.default"; + } + } } diff --git a/tools/Azure.Mcp.Tools.Storage/src/Services/StorageService.cs b/tools/Azure.Mcp.Tools.Storage/src/Services/StorageService.cs index 55e00b4045..4cb48956e7 100644 --- a/tools/Azure.Mcp.Tools.Storage/src/Services/StorageService.cs +++ b/tools/Azure.Mcp.Tools.Storage/src/Services/StorageService.cs @@ -7,6 +7,7 @@ using Azure.Data.Tables; using Azure.Mcp.Core.Options; using Azure.Mcp.Core.Services.Azure; +using Azure.Mcp.Core.Services.Azure.Authentication; using Azure.Mcp.Core.Services.Azure.Models; using Azure.Mcp.Core.Services.Azure.Subscription; using Azure.Mcp.Core.Services.Azure.Tenant; @@ -26,6 +27,7 @@ public class StorageService( ILogger logger) : BaseAzureResourceService(subscriptionService, tenantService), IStorageService { + private readonly ITenantService _tenantService = tenantService ?? throw new ArgumentNullException(nameof(tenantService)); private readonly ILogger _logger = logger ?? throw new ArgumentNullException(nameof(logger)); public async Task> GetAccountDetails( @@ -380,7 +382,7 @@ private async Task CreateBlobServiceClient( RetryPolicyOptions? retryPolicy = null, CancellationToken cancellationToken = default) { - var uri = $"https://{account}.blob.core.windows.net"; + var uri = GetBlobEndpoint(account); var options = ConfigureRetryPolicy(AddDefaultPolicies(new BlobClientOptions()), retryPolicy); options.Transport = new HttpClientTransport(TenantService.GetClient()); return new BlobServiceClient(new Uri(uri), await GetCredential(tenant, cancellationToken), options); @@ -488,7 +490,7 @@ protected async Task CreateTableServiceClient( { var options = ConfigureRetryPolicy(AddDefaultPolicies(new TableClientOptions()), retryPolicy); options.Transport = new HttpClientTransport(TenantService.GetClient()); - var defaultUri = $"https://{account}.table.core.windows.net"; + var defaultUri = GetTableEndpoint(account); return new TableServiceClient(new Uri(defaultUri), await GetCredential(tenant, cancellationToken), options); } @@ -524,4 +526,34 @@ public async Task> ListTables( throw new Exception($"Error listing tables: {ex.Message}", ex); } } + + private string GetBlobEndpoint(string account) + { + switch (_tenantService.CloudConfiguration.CloudType) + { + case AzureCloudConfiguration.AzureCloud.AzurePublicCloud: + return $"https://{account}.blob.core.windows.net"; + case AzureCloudConfiguration.AzureCloud.AzureChinaCloud: + return $"https://{account}.blob.core.chinacloudapi.cn"; + case AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud: + return $"https://{account}.blob.core.usgovcloudapi.net"; + default: + return $"https://{account}.blob.core.windows.net"; + } + } + + private string GetTableEndpoint(string? account) + { + switch (_tenantService.CloudConfiguration.CloudType) + { + case AzureCloudConfiguration.AzureCloud.AzurePublicCloud: + return $"https://{account}.table.core.windows.net"; + case AzureCloudConfiguration.AzureCloud.AzureChinaCloud: + return $"https://{account}.table.core.chinacloudapi.cn"; + case AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud: + return $"https://{account}.table.core.usgovcloudapi.net"; + default: + return $"https://{account}.table.core.windows.net"; + } + } }