diff --git a/core/Microsoft.Mcp.Core/src/Services/Azure/Authentication/AzureCloudConfiguration.cs b/core/Microsoft.Mcp.Core/src/Services/Azure/Authentication/AzureCloudConfiguration.cs
index 712326e7b7..517b114414 100644
--- a/core/Microsoft.Mcp.Core/src/Services/Azure/Authentication/AzureCloudConfiguration.cs
+++ b/core/Microsoft.Mcp.Core/src/Services/Azure/Authentication/AzureCloudConfiguration.cs
@@ -14,6 +14,14 @@ namespace Azure.Mcp.Core.Services.Azure.Authentication;
///
public class AzureCloudConfiguration : IAzureCloudConfiguration
{
+
+ public enum AzureCloud
+ {
+ AzurePublicCloud,
+ AzureChinaCloud,
+ AzureUSGovernmentCloud,
+ }
+
private const string DefaultAuthorityHost = "https://login.microsoftonline.com";
///
@@ -37,7 +45,7 @@ public AzureCloudConfiguration(
?? configuration["Cloud"]
?? Environment.GetEnvironmentVariable("AZURE_CLOUD");
- (AuthorityHost, ArmEnvironment) = ParseCloudValue(cloudValue);
+ (AuthorityHost, ArmEnvironment, CloudType) = ParseCloudValue(cloudValue);
logger?.LogDebug(
"Azure cloud configuration initialized. Cloud value: '{CloudValue}', AuthorityHost: '{AuthorityHost}', ArmEnvironment: '{ArmEnvironment}'",
@@ -52,11 +60,13 @@ public AzureCloudConfiguration(
///
public ArmEnvironment ArmEnvironment { get; }
- private static (Uri authorityHost, ArmEnvironment armEnvironment) ParseCloudValue(string? cloudValue)
+ public AzureCloud CloudType { get; }
+
+ private static (Uri authorityHost, ArmEnvironment armEnvironment, AzureCloud cloudType) ParseCloudValue(string? cloudValue)
{
if (string.IsNullOrWhiteSpace(cloudValue))
{
- return (new Uri(DefaultAuthorityHost), ArmEnvironment.AzurePublicCloud);
+ return (new Uri(DefaultAuthorityHost), ArmEnvironment.AzurePublicCloud, AzureCloud.AzurePublicCloud);
}
// Check if it's already a URL - in this case we only have authority host
@@ -64,19 +74,19 @@ private static (Uri authorityHost, ArmEnvironment armEnvironment) ParseCloudValu
// additional configuration not currently supported)
if (cloudValue.StartsWith("https://", StringComparison.OrdinalIgnoreCase))
{
- return (new Uri(cloudValue), ArmEnvironment.AzurePublicCloud);
+ return (new Uri(cloudValue), ArmEnvironment.AzurePublicCloud, AzureCloud.AzurePublicCloud);
}
// Map common sovereign cloud names to authority hosts and ARM environments
return cloudValue.ToLowerInvariant() switch
{
"azurecloud" or "azurepubliccloud" or "public" =>
- (new Uri("https://login.microsoftonline.com"), ArmEnvironment.AzurePublicCloud),
+ (new Uri("https://login.microsoftonline.com"), ArmEnvironment.AzurePublicCloud, AzureCloud.AzurePublicCloud),
"azurechinacloud" or "china" =>
- (new Uri("https://login.chinacloudapi.cn"), ArmEnvironment.AzureChina),
+ (new Uri("https://login.chinacloudapi.cn"), ArmEnvironment.AzureChina, AzureCloud.AzureChinaCloud),
"azureusgovernment" or "azureusgovernmentcloud" or "usgov" or "usgovernment" =>
- (new Uri("https://login.microsoftonline.us"), ArmEnvironment.AzureGovernment),
- _ => (new Uri(DefaultAuthorityHost), ArmEnvironment.AzurePublicCloud) // Default to public cloud if unknown
+ (new Uri("https://login.microsoftonline.us"), ArmEnvironment.AzureGovernment, AzureCloud.AzureUSGovernmentCloud),
+ _ => (new Uri(DefaultAuthorityHost), ArmEnvironment.AzurePublicCloud, AzureCloud.AzurePublicCloud) // Default to public cloud if unknown
};
}
}
diff --git a/core/Microsoft.Mcp.Core/src/Services/Azure/Authentication/IAzureCloudConfiguration.cs b/core/Microsoft.Mcp.Core/src/Services/Azure/Authentication/IAzureCloudConfiguration.cs
index 3391445083..8cd4aca818 100644
--- a/core/Microsoft.Mcp.Core/src/Services/Azure/Authentication/IAzureCloudConfiguration.cs
+++ b/core/Microsoft.Mcp.Core/src/Services/Azure/Authentication/IAzureCloudConfiguration.cs
@@ -20,4 +20,9 @@ public interface IAzureCloudConfiguration
/// This determines the management endpoint used for Azure Resource Manager operations.
///
ArmEnvironment ArmEnvironment { get; }
+
+ ///
+ /// Gets the type of Azure cloud environment.
+ ///
+ AzureCloudConfiguration.AzureCloud CloudType { get; }
}
diff --git a/servers/Azure.Mcp.Server/changelog-entries/1771617105575.yaml b/servers/Azure.Mcp.Server/changelog-entries/1771617105575.yaml
new file mode 100644
index 0000000000..fa412b0e29
--- /dev/null
+++ b/servers/Azure.Mcp.Server/changelog-entries/1771617105575.yaml
@@ -0,0 +1,3 @@
+changes:
+ - section: "Features Added"
+ description: "Added sovereign cloud endpoint support for Storage, Search, Postgres, ServiceFabric, Pricing, and Extension services"
\ No newline at end of file
diff --git a/tools/Azure.Mcp.Tools.AppLens/src/Services/AppLensService.cs b/tools/Azure.Mcp.Tools.AppLens/src/Services/AppLensService.cs
index f0c266edc0..b26563684e 100644
--- a/tools/Azure.Mcp.Tools.AppLens/src/Services/AppLensService.cs
+++ b/tools/Azure.Mcp.Tools.AppLens/src/Services/AppLensService.cs
@@ -6,6 +6,7 @@
using System.Threading.Channels;
using Azure.Core;
using Azure.Mcp.Core.Services.Azure;
+using Azure.Mcp.Core.Services.Azure.Authentication;
using Azure.Mcp.Core.Services.Azure.Subscription;
using Azure.Mcp.Core.Services.Azure.Tenant;
using Azure.Mcp.Tools.AppLens.Models;
@@ -20,9 +21,9 @@ namespace Azure.Mcp.Tools.AppLens.Services;
public class AppLensService(IHttpClientFactory httpClientFactory, ISubscriptionService subscriptionService, ITenantService tenantService) : BaseAzureService(tenantService), IAppLensService
{
private readonly ISubscriptionService _subscriptionService = subscriptionService ?? throw new ArgumentNullException(nameof(subscriptionService));
+ private readonly ITenantService _tenantService = tenantService ?? throw new ArgumentNullException(nameof(tenantService));
private readonly IHttpClientFactory _httpClientFactory = httpClientFactory ?? throw new ArgumentNullException(nameof(httpClientFactory));
private readonly AppLensOptions _options = new AppLensOptions();
- private const string ConversationalDiagnosticsSignalREndpoint = "https://diagnosticschat.azure.com/chatHub";
///
public async Task DiagnoseResourceAsync(
@@ -93,12 +94,12 @@ private async Task GetAppLensSessionAsync(string resour
// Get ARM token
var token = await credential.GetTokenAsync(
- new TokenRequestContext(["https://management.azure.com/user_impersonation"]),
+ new TokenRequestContext([GetManagementImpersonationEndpoint().ToString()]),
cancellationToken);
// Call the AppLens token endpoint
using var request = new HttpRequestMessage(HttpMethod.Get,
- $"https://management.azure.com/{resourceId}/detectors/GetToken-db48586f-7d94-45fc-88ad-b30ccd3b571c?api-version=2015-08-01");
+ GetAppLensTokenEndpoint(resourceId));
request.Headers.Authorization = new System.Net.Http.Headers.AuthenticationHeaderValue("Bearer", token.Token);
@@ -156,10 +157,10 @@ public async IAsyncEnumerable AskAppLensAsync(
// https://learn.microsoft.com/aspnet/core/signalr/configuration?view=aspnetcore-9.0&tabs=dotnet#jsonmessagepack-serialization-options
options.PayloadSerializerOptions.TypeInfoResolverChain.Insert(0, AppLensJsonContext.Default);
})
- .WithUrl(ConversationalDiagnosticsSignalREndpoint, options =>
+ .WithUrl(GetConversationalDiagnosticsSignalREndpoint(), options =>
{
options.AccessTokenProvider = () => Task.FromResult(session.Token)!;
- options.Headers.Add("origin", "https://appservice-diagnostics.trafficmanager.net");
+ options.Headers.Add("origin", GetDiagnosticsPortalEndpoint().ToString());
})
.WithAutomaticReconnect()
.Build();
@@ -345,4 +346,48 @@ private static AppLensSession ParseGetTokenResponse(string rawResponse)
return session;
}
+
+ private Uri GetConversationalDiagnosticsSignalREndpoint()
+ {
+ return _tenantService.CloudConfiguration.CloudType switch
+ {
+ AzureCloudConfiguration.AzureCloud.AzurePublicCloud => new Uri("https://diagnosticschat.azure.com/chatHub"),
+ AzureCloudConfiguration.AzureCloud.AzureChinaCloud => new Uri("https://diagnosticschat.azure.cn/chatHub"),
+ AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud => new Uri("https://diagnosticschat.azure.us/chatHub"),
+ _ => new Uri("https://diagnosticschat.azure.com/chatHub"),
+ };
+ }
+
+ private Uri GetManagementImpersonationEndpoint()
+ {
+ return _tenantService.CloudConfiguration.CloudType switch
+ {
+ AzureCloudConfiguration.AzureCloud.AzurePublicCloud => new Uri("https://management.azure.com/user_impersonation"),
+ AzureCloudConfiguration.AzureCloud.AzureChinaCloud => new Uri("https://management.chinacloudapi.cn/user_impersonation"),
+ AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud => new Uri("https://management.usgovcloudapi.net/user_impersonation"),
+ _ => new Uri("https://management.azure.com/user_impersonation"),
+ };
+ }
+
+ private Uri GetAppLensTokenEndpoint(string resourceId)
+ {
+ return _tenantService.CloudConfiguration.CloudType switch
+ {
+ AzureCloudConfiguration.AzureCloud.AzurePublicCloud => new Uri($"https://management.azure.com/{resourceId}/detectors/GetToken-db48586f-7d94-45fc-88ad-b30ccd3b571c?api-version=2015-08-01"),
+ AzureCloudConfiguration.AzureCloud.AzureChinaCloud => new Uri($"https://management.chinacloudapi.cn/{resourceId}/detectors/GetToken-db48586f-7d94-45fc-88ad-b30ccd3b571c?api-version=2015-08-01"),
+ AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud => new Uri($"https://management.usgovcloudapi.net/{resourceId}/detectors/GetToken-db48586f-7d94-45fc-88ad-b30ccd3b571c?api-version=2015-08-01"),
+ _ => new Uri($"https://management.azure.com/{resourceId}/detectors/GetToken-db48586f-7d94-45fc-88ad-b30ccd3b571c?api-version=2015-08-01"),
+ };
+ }
+
+ private Uri GetDiagnosticsPortalEndpoint()
+ {
+ return _tenantService.CloudConfiguration.CloudType switch
+ {
+ AzureCloudConfiguration.AzureCloud.AzurePublicCloud => new Uri("https://appservice-diagnostics.trafficmanager.net"),
+ AzureCloudConfiguration.AzureCloud.AzureChinaCloud => new Uri("https://appservice-diagnostics.azure.cn"),
+ AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud => new Uri("https://appservice-diagnostics.azure.us"),
+ _ => new Uri("https://appservice-diagnostics.trafficmanager.net"),
+ };
+ }
}
diff --git a/tools/Azure.Mcp.Tools.AppService/src/Services/AppServiceService.cs b/tools/Azure.Mcp.Tools.AppService/src/Services/AppServiceService.cs
index 19bc342c96..7dc01ff141 100644
--- a/tools/Azure.Mcp.Tools.AppService/src/Services/AppServiceService.cs
+++ b/tools/Azure.Mcp.Tools.AppService/src/Services/AppServiceService.cs
@@ -3,6 +3,7 @@
using Azure.Mcp.Core.Options;
using Azure.Mcp.Core.Services.Azure;
+using Azure.Mcp.Core.Services.Azure.Authentication;
using Azure.Mcp.Core.Services.Azure.Subscription;
using Azure.Mcp.Core.Services.Azure.Tenant;
using Azure.Mcp.Tools.AppService.Models;
@@ -17,6 +18,7 @@ public class AppServiceService(
ITenantService tenantService,
ILogger logger) : BaseAzureService(tenantService), IAppServiceService
{
+ private readonly ITenantService _tenantService = tenantService ?? throw new ArgumentNullException(nameof(tenantService));
private readonly ISubscriptionService _subscriptionService = subscriptionService ?? throw new ArgumentNullException(nameof(subscriptionService));
private readonly ILogger _logger = logger;
@@ -95,7 +97,7 @@ private async Task GetWebAppResourceAsync(string subscription,
return webAppResource.Value;
}
- private static string PrepareConnectionString(string? connectionString, string databaseType,
+ private string PrepareConnectionString(string? connectionString, string databaseType,
string databaseServer, string databaseName)
{
return string.IsNullOrWhiteSpace(connectionString)
@@ -177,15 +179,30 @@ private static ConnectionStringType GetConnectionStringType(string databaseType)
};
}
- private static string BuildConnectionString(string databaseType, string databaseServer, string databaseName)
+ private string BuildConnectionString(string databaseType, string databaseServer, string databaseName)
{
return databaseType.ToLowerInvariant() switch
{
"sqlserver" => $"Server={databaseServer};Database={databaseName};User Id={{username}};Password={{password}};TrustServerCertificate=True;",
"mysql" => $"Server={databaseServer};Database={databaseName};Uid={{username}};Pwd={{password}};",
"postgresql" => $"Host={databaseServer};Database={databaseName};Username={{username}};Password={{password}};",
- "cosmosdb" => $"AccountEndpoint=https://{databaseServer}.documents.azure.com:443/;AccountKey={{key}};Database={databaseName};",
+ "cosmosdb" => BuildCosmosConnectionString(databaseServer, databaseName),
_ => throw new ArgumentException($"Unsupported database type: {databaseType}")
};
}
+
+ private string BuildCosmosConnectionString(string databaseServer, string databaseName)
+ {
+ switch (_tenantService.CloudConfiguration.CloudType)
+ {
+ case AzureCloudConfiguration.AzureCloud.AzurePublicCloud:
+ return $"AccountEndpoint=https://{databaseServer}.documents.azure.com:443/;AccountKey={{key}};Database={databaseName};";
+ case AzureCloudConfiguration.AzureCloud.AzureChinaCloud:
+ return $"AccountEndpoint=https://{databaseServer}.documents.azure.cn:443/;AccountKey={{key}};Database={databaseName};";
+ case AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud:
+ return $"AccountEndpoint=https://{databaseServer}.documents.azure.us:443/;AccountKey={{key}};Database={databaseName};";
+ default:
+ throw new ArgumentException($"Unsupported Azure cloud type: {_tenantService.CloudConfiguration.CloudType}");
+ }
+ }
}
diff --git a/tools/Azure.Mcp.Tools.ApplicationInsights/src/Services/ProfilerDataService.cs b/tools/Azure.Mcp.Tools.ApplicationInsights/src/Services/ProfilerDataService.cs
index afef1c89f0..84f6f4933f 100644
--- a/tools/Azure.Mcp.Tools.ApplicationInsights/src/Services/ProfilerDataService.cs
+++ b/tools/Azure.Mcp.Tools.ApplicationInsights/src/Services/ProfilerDataService.cs
@@ -8,6 +8,7 @@
using Azure.Core;
using Azure.Mcp.Core.Options;
using Azure.Mcp.Core.Services.Azure;
+using Azure.Mcp.Core.Services.Azure.Authentication;
using Azure.Mcp.Core.Services.Azure.Tenant;
using Azure.Mcp.Tools.ApplicationInsights.Commands;
using Azure.Mcp.Tools.ApplicationInsights.Models;
@@ -27,9 +28,7 @@ public class ProfilerDataService(
ITenantService tenantService)
: BaseAzureService(tenantService), IProfilerDataService
{
- private const string Endpoint = "https://dataplane.diagnosticservices.azure.com/";
- private const string DefaultScope = "api://dataplane.diagnosticservices.azure.com/.default";
-
+ private readonly ITenantService _tenantService = tenantService ?? throw new ArgumentNullException(nameof(tenantService));
private readonly IHttpClientFactory _httpClientFactory = httpClientFactory ?? throw new ArgumentNullException(nameof(httpClientFactory));
private readonly ILogger _logger = logger ?? throw new ArgumentNullException(nameof(logger));
@@ -98,7 +97,7 @@ await response.Content.ReadAsStreamAsync(cancellationToken),
private async Task CreateRequestAsync(HttpMethod method, string path, IDictionary? queries, string apiVersion, string? clientRequestId, HttpContent? httpContent, IDictionary>? additionalHeaders, CancellationToken cancellationToken)
{
- UriBuilder uriBuilder = new(Endpoint)
+ UriBuilder uriBuilder = new(GetDiagnosticServiceEndpoint())
{
Path = path
};
@@ -119,7 +118,7 @@ private async Task CreateRequestAsync(HttpMethod method, str
var scopes = new string[]
{
- DefaultScope
+ GetDiagnosticServicesScope()
};
string clientRequestIdLocal = clientRequestId ?? Guid.NewGuid().ToString();
TokenRequestContext tokenRequestContext = new(scopes, clientRequestIdLocal);
@@ -199,4 +198,34 @@ private async Task ResolveAppIdAsync(ResourceIdentifier resourceId, Cancel
_logger.LogInformation("Resolving appId: {resourceId} => {appId}", resourceId, appId);
return Guid.Parse(appId);
}
+
+ private Uri GetDiagnosticServiceEndpoint()
+ {
+ switch (_tenantService.CloudConfiguration.CloudType)
+ {
+ case AzureCloudConfiguration.AzureCloud.AzurePublicCloud:
+ return new Uri("https://dataplane.diagnosticservices.azure.com");
+ case AzureCloudConfiguration.AzureCloud.AzureChinaCloud:
+ return new Uri("https://dataplane.diagnosticservices.azure.cn");
+ case AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud:
+ return new Uri("https://dataplane.diagnosticservices.azure.us");
+ default:
+ return new Uri("https://dataplane.diagnosticservices.azure.com");
+ }
+ }
+
+ private string GetDiagnosticServicesScope()
+ {
+ switch (_tenantService.CloudConfiguration.CloudType)
+ {
+ case AzureCloudConfiguration.AzureCloud.AzurePublicCloud:
+ return "api://dataplane.diagnosticservices.azure.com/.default";
+ case AzureCloudConfiguration.AzureCloud.AzureChinaCloud:
+ return "api://dataplane.diagnosticservices.azure.cn/.default";
+ case AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud:
+ return "api://dataplane.diagnosticservices.azure.us/.default";
+ default:
+ return "api://dataplane.diagnosticservices.azure.com/.default";
+ }
+ }
}
diff --git a/tools/Azure.Mcp.Tools.ConfidentialLedger/src/Services/ConfidentialLedgerService.cs b/tools/Azure.Mcp.Tools.ConfidentialLedger/src/Services/ConfidentialLedgerService.cs
index 1981a1ffa0..1fcb6f8ed0 100644
--- a/tools/Azure.Mcp.Tools.ConfidentialLedger/src/Services/ConfidentialLedgerService.cs
+++ b/tools/Azure.Mcp.Tools.ConfidentialLedger/src/Services/ConfidentialLedgerService.cs
@@ -5,6 +5,7 @@
using System.Text.Json;
using Azure.Core;
using Azure.Mcp.Core.Services.Azure;
+using Azure.Mcp.Core.Services.Azure.Authentication;
using Azure.Mcp.Core.Services.Azure.Tenant;
using Azure.Mcp.Tools.ConfidentialLedger.Models;
using Azure.Security.ConfidentialLedger;
@@ -15,7 +16,7 @@ public class ConfidentialLedgerService(ITenantService tenantService)
: BaseAzureService(tenantService), IConfidentialLedgerService
{
// NOTE: We construct the data-plane endpoint from the ledger name.
- private static Uri BuildLedgerUri(string ledgerName) => new($"https://{ledgerName}.confidential-ledger.azure.com");
+ private readonly ITenantService _tenantService = tenantService ?? throw new ArgumentNullException(nameof(tenantService));
private static RequestContent CreateAppendEntryContent(string entryData)
{
@@ -43,7 +44,7 @@ public async Task AppendEntryAsync(string ledgerName, string
var credential = await GetCredential(cancellationToken);
// Configure client (retry etc. could be extended later)
- ConfidentialLedgerClient client = new(BuildLedgerUri(ledgerName), credential);
+ ConfidentialLedgerClient client = new(GetLedgerUri(ledgerName), credential);
// Build RequestContent manually to avoid trimming issues from reflection-based serialization.
using var content = CreateAppendEntryContent(entryData);
@@ -74,7 +75,7 @@ public async Task GetLedgerEntryAsync(string ledgerName, s
}
var credential = await GetCredential(cancellationToken);
- ConfidentialLedgerClient client = new(BuildLedgerUri(ledgerName), credential);
+ ConfidentialLedgerClient client = new(GetLedgerUri(ledgerName), credential);
Response? getByCollectionResponse = null;
bool loaded = false;
@@ -115,4 +116,19 @@ public async Task GetLedgerEntryAsync(string ledgerName, s
Contents = contents ?? string.Empty,
};
}
+
+ private Uri GetLedgerUri(string ledgerName)
+ {
+ switch (_tenantService.CloudConfiguration.CloudType)
+ {
+ case AzureCloudConfiguration.AzureCloud.AzurePublicCloud:
+ return new Uri($"https://{ledgerName}.confidential-ledger.azure.com");
+ case AzureCloudConfiguration.AzureCloud.AzureChinaCloud:
+ return new Uri($"https://{ledgerName}.confidential-ledger.azure.cn");
+ case AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud:
+ return new Uri($"https://{ledgerName}.confidential-ledger.azure.us");
+ default:
+ return new Uri($"https://{ledgerName}.confidential-ledger.azure.com");
+ }
+ }
}
diff --git a/tools/Azure.Mcp.Tools.Cosmos/src/Services/CosmosService.cs b/tools/Azure.Mcp.Tools.Cosmos/src/Services/CosmosService.cs
index cbe5ed80c1..187c0e5904 100644
--- a/tools/Azure.Mcp.Tools.Cosmos/src/Services/CosmosService.cs
+++ b/tools/Azure.Mcp.Tools.Cosmos/src/Services/CosmosService.cs
@@ -4,6 +4,7 @@
using System.Net;
using Azure.Mcp.Core.Options;
using Azure.Mcp.Core.Services.Azure;
+using Azure.Mcp.Core.Services.Azure.Authentication;
using Azure.Mcp.Core.Services.Azure.Subscription;
using Azure.Mcp.Core.Services.Azure.Tenant;
using Azure.Mcp.Core.Services.Caching;
@@ -17,10 +18,10 @@ public sealed class CosmosService(ISubscriptionService subscriptionService, ITen
: BaseAzureService(tenantService), ICosmosService, IAsyncDisposable
{
private readonly ISubscriptionService _subscriptionService = subscriptionService ?? throw new ArgumentNullException(nameof(subscriptionService));
+ private readonly ITenantService _tenantService = tenantService ?? throw new ArgumentNullException(nameof(tenantService));
private readonly IHttpClientFactory _httpClientFactory = httpClientFactory ?? throw new ArgumentNullException(nameof(httpClientFactory));
private readonly ICacheService _cacheService = cacheService ?? throw new ArgumentNullException(nameof(cacheService));
private readonly ILogger _logger = logger ?? throw new ArgumentNullException(nameof(logger));
- private const string CosmosBaseUri = "https://{0}.documents.azure.com:443/";
private const string CacheGroup = "cosmos";
private const string CosmosClientsCacheKeyPrefix = "clients_";
private const string CosmosDatabasesCacheKeyPrefix = "databases_";
@@ -78,7 +79,7 @@ private async Task CreateCosmosClientWithAuth(
var cosmosAccount = await GetCosmosAccountAsync(subscription, accountName, tenant, cancellationToken: cancellationToken);
var keys = await cosmosAccount.GetKeysAsync(cancellationToken);
cosmosClient = new CosmosClient(
- string.Format(CosmosBaseUri, accountName),
+ string.Format(GetCosmosBaseUriFormat(), accountName),
keys.Value.PrimaryMasterKey,
clientOptions);
break;
@@ -86,7 +87,7 @@ private async Task CreateCosmosClientWithAuth(
case AuthMethod.Credential:
default:
cosmosClient = new CosmosClient(
- string.Format(CosmosBaseUri, accountName),
+ string.Format(GetCosmosBaseUriFormat(), accountName),
await GetCredential(cancellationToken),
clientOptions);
break;
@@ -98,6 +99,21 @@ await GetCredential(cancellationToken),
return cosmosClient;
}
+ private string GetCosmosBaseUriFormat()
+ {
+ switch (_tenantService.CloudConfiguration.CloudType)
+ {
+ case AzureCloudConfiguration.AzureCloud.AzurePublicCloud:
+ return "https://{0}.documents.azure.com:443/";
+ case AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud:
+ return "https://{0}.documents.azure.us:443/";
+ case AzureCloudConfiguration.AzureCloud.AzureChinaCloud:
+ return "https://{0}.documents.azure.cn:443/";
+ default:
+ return "https://{0}.documents.azure.com:443/";
+ }
+ }
+
private async Task ValidateCosmosClientAsync(CosmosClient client, CancellationToken cancellationToken = default)
{
try
diff --git a/tools/Azure.Mcp.Tools.EventHubs/src/Services/EventHubsService.cs b/tools/Azure.Mcp.Tools.EventHubs/src/Services/EventHubsService.cs
index 40653cc933..4ea39cac30 100644
--- a/tools/Azure.Mcp.Tools.EventHubs/src/Services/EventHubsService.cs
+++ b/tools/Azure.Mcp.Tools.EventHubs/src/Services/EventHubsService.cs
@@ -30,7 +30,7 @@ public async Task> GetNamespacesAsync(
try
{
- var subscriptionResource = await _subscriptionService.GetSubscription(subscription, tenant, retryPolicy, cancellationToken);
+ var subscriptionResource = await ResolveSubscriptionResourceAsync(subscription, tenant, retryPolicy, cancellationToken);
var namespaces = new List();
if (!string.IsNullOrEmpty(resourceGroup))
@@ -108,7 +108,7 @@ public async Task GetNamespaceAsync(
try
{
- var subscriptionResource = await _subscriptionService.GetSubscription(subscription, tenant, retryPolicy, cancellationToken);
+ var subscriptionResource = await ResolveSubscriptionResourceAsync(subscription, tenant, retryPolicy, cancellationToken);
var resourceGroupResource = await subscriptionResource.GetResourceGroupAsync(resourceGroup, cancellationToken);
if (resourceGroupResource?.Value == null)
@@ -155,7 +155,7 @@ public async Task CreateOrUpdateNamespaceAsync(
try
{
- var subscriptionResource = await _subscriptionService.GetSubscription(subscription, tenant, retryPolicy, cancellationToken);
+ var subscriptionResource = await ResolveSubscriptionResourceAsync(subscription, tenant, retryPolicy, cancellationToken);
var resourceGroupResource = await subscriptionResource.GetResourceGroupAsync(resourceGroup, cancellationToken);
if (resourceGroupResource?.Value == null)
@@ -250,8 +250,9 @@ public async Task DeleteNamespaceAsync(
try
{
- var subscriptionResource = await _subscriptionService.GetSubscription(subscription, tenant, retryPolicy, cancellationToken);
- var subscriptionId = subscriptionResource.Data.SubscriptionId;
+ var subscriptionId = _subscriptionService.IsSubscriptionId(subscription)
+ ? subscription
+ : await _subscriptionService.GetSubscriptionIdByName(subscription, tenant, retryPolicy, cancellationToken);
var armClient = await CreateArmClientAsync(tenant, retryPolicy, cancellationToken: cancellationToken);
var namespaceId = EventHubsNamespaceResource.CreateResourceIdentifier(subscriptionId, resourceGroup, namespaceName);
@@ -296,7 +297,7 @@ public async Task> GetEventHubsAsync(
try
{
- var subscriptionResource = await _subscriptionService.GetSubscription(subscription, tenant, retryPolicy, cancellationToken);
+ var subscriptionResource = await ResolveSubscriptionResourceAsync(subscription, tenant, retryPolicy, cancellationToken);
var resourceGroupResource = await subscriptionResource.GetResourceGroupAsync(resourceGroup, cancellationToken);
if (resourceGroupResource?.Value == null)
@@ -342,7 +343,7 @@ public async Task> GetEventHubsAsync(
try
{
- var subscriptionResource = await _subscriptionService.GetSubscription(subscription, tenant, retryPolicy, cancellationToken);
+ var subscriptionResource = await ResolveSubscriptionResourceAsync(subscription, tenant, retryPolicy, cancellationToken);
var resourceGroupResource = await subscriptionResource.GetResourceGroupAsync(resourceGroup, cancellationToken);
if (resourceGroupResource?.Value == null)
@@ -407,7 +408,7 @@ public async Task CreateOrUpdateEventHubAsync(
try
{
- var subscriptionResource = await _subscriptionService.GetSubscription(subscription, tenant, retryPolicy, cancellationToken);
+ var subscriptionResource = await ResolveSubscriptionResourceAsync(subscription, tenant, retryPolicy, cancellationToken);
var resourceGroupResource = await subscriptionResource.GetResourceGroupAsync(resourceGroup, cancellationToken);
if (resourceGroupResource?.Value == null)
@@ -470,7 +471,7 @@ public async Task DeleteEventHubAsync(
try
{
- var subscriptionResource = await _subscriptionService.GetSubscription(subscription, tenant, retryPolicy, cancellationToken);
+ var subscriptionResource = await ResolveSubscriptionResourceAsync(subscription, tenant, retryPolicy, cancellationToken);
try
{
@@ -536,8 +537,7 @@ public async Task CreateOrUpdateConsumerGroupAsync(
try
{
- var armClient = await CreateArmClientAsync(tenant, retryPolicy, cancellationToken: cancellationToken);
- var subscriptionResource = armClient.GetSubscriptionResource(ResourceManager.Resources.SubscriptionResource.CreateResourceIdentifier(subscription));
+ var subscriptionResource = await ResolveSubscriptionResourceAsync(subscription, tenant, retryPolicy, cancellationToken);
var resourceGroupResource = await subscriptionResource.GetResourceGroupAsync(resourceGroup, cancellationToken);
var namespaceResource = await resourceGroupResource.Value.GetEventHubsNamespaces().GetAsync(namespaceName, cancellationToken);
var eventHubResource = await namespaceResource.Value.GetEventHubs().GetAsync(eventHubName, cancellationToken);
@@ -596,8 +596,7 @@ public async Task DeleteConsumerGroupAsync(
try
{
- var armClient = await CreateArmClientAsync(tenant, retryPolicy, cancellationToken: cancellationToken);
- var subscriptionResource = armClient.GetSubscriptionResource(ResourceManager.Resources.SubscriptionResource.CreateResourceIdentifier(subscription));
+ var subscriptionResource = await ResolveSubscriptionResourceAsync(subscription, tenant, retryPolicy, cancellationToken);
try
{
@@ -668,7 +667,7 @@ public async Task> GetConsumerGroupsAsync(
try
{
- var subscriptionResource = await _subscriptionService.GetSubscription(subscription, tenant, retryPolicy, cancellationToken);
+ var subscriptionResource = await ResolveSubscriptionResourceAsync(subscription, tenant, retryPolicy, cancellationToken);
var resourceGroupResource = await subscriptionResource.GetResourceGroupAsync(resourceGroup, cancellationToken);
if (resourceGroupResource?.Value == null)
@@ -722,7 +721,7 @@ public async Task> GetConsumerGroupsAsync(
try
{
- var subscriptionResource = await _subscriptionService.GetSubscription(subscription, tenant, retryPolicy, cancellationToken);
+ var subscriptionResource = await ResolveSubscriptionResourceAsync(subscription, tenant, retryPolicy, cancellationToken);
var resourceGroupResource = await subscriptionResource.GetResourceGroupAsync(resourceGroup, cancellationToken);
if (resourceGroupResource?.Value == null)
@@ -776,4 +775,23 @@ private static ConsumerGroup ConvertToConsumerGroup(EventHubsConsumerGroupData c
UpdatedTime: consumerGroupData.UpdatedOn);
}
+ ///
+ /// Returns a SubscriptionResource handle for ARM navigation without making an HTTP call.
+ /// This avoids the cache-dependent GET /subscriptions/{id} that GetSubscription() makes,
+ /// which caused non-deterministic test proxy recordings.
+ ///
+ private async Task ResolveSubscriptionResourceAsync(
+ string subscription,
+ string? tenant,
+ RetryPolicyOptions? retryPolicy,
+ CancellationToken cancellationToken)
+ {
+ var subscriptionId = _subscriptionService.IsSubscriptionId(subscription)
+ ? subscription
+ : await _subscriptionService.GetSubscriptionIdByName(subscription, tenant, retryPolicy, cancellationToken);
+
+ var armClient = await CreateArmClientAsync(tenant, retryPolicy, cancellationToken: cancellationToken);
+ return armClient.GetSubscriptionResource(
+ ResourceManager.Resources.SubscriptionResource.CreateResourceIdentifier(subscriptionId));
+ }
}
diff --git a/tools/Azure.Mcp.Tools.EventHubs/tests/Azure.Mcp.Tools.EventHubs.LiveTests/assets.json b/tools/Azure.Mcp.Tools.EventHubs/tests/Azure.Mcp.Tools.EventHubs.LiveTests/assets.json
index a97c10dbcc..4f4f3174b6 100644
--- a/tools/Azure.Mcp.Tools.EventHubs/tests/Azure.Mcp.Tools.EventHubs.LiveTests/assets.json
+++ b/tools/Azure.Mcp.Tools.EventHubs/tests/Azure.Mcp.Tools.EventHubs.LiveTests/assets.json
@@ -2,5 +2,5 @@
"AssetsRepo": "Azure/azure-sdk-assets",
"AssetsRepoPrefixPath": "",
"TagPrefix": "Azure.Mcp.Tools.EventHubs.LiveTests",
- "Tag": "Azure.Mcp.Tools.EventHubs.LiveTests_b5550e708b"
+ "Tag": "Azure.Mcp.Tools.EventHubs.LiveTests_07ba3d2d71"
}
diff --git a/tools/Azure.Mcp.Tools.Extension/src/Services/CliGenerateService.cs b/tools/Azure.Mcp.Tools.Extension/src/Services/CliGenerateService.cs
index 9dff64214e..0bbce41be6 100644
--- a/tools/Azure.Mcp.Tools.Extension/src/Services/CliGenerateService.cs
+++ b/tools/Azure.Mcp.Tools.Extension/src/Services/CliGenerateService.cs
@@ -8,7 +8,7 @@
namespace Azure.Mcp.Tools.Extension.Services;
-internal class CliGenerateService(IHttpClientFactory httpClientFactory, IAzureTokenCredentialProvider tokenCredentialProvider) : ICliGenerateService
+internal class CliGenerateService(IHttpClientFactory httpClientFactory, IAzureTokenCredentialProvider tokenCredentialProvider, IAzureCloudConfiguration cloudConfiguration) : ICliGenerateService
{
private readonly IHttpClientFactory _httpClientFactory = httpClientFactory;
private readonly IAzureTokenCredentialProvider _tokenCredentialProvider = tokenCredentialProvider;
@@ -22,7 +22,7 @@ public async Task GenerateAzureCLICommandAsync(string inten
var accessToken = await credential.GetTokenAsync(new TokenRequestContext([apiScope]), cancellationToken);
// AzCli copilot API endpoint
- const string url = "https://azclis-copilot-apim-prod-eus.azure-api.net/azcli/copilot";
+ var url = GetCliCopilotEndpoint();
var requestBody = new AzureCliGenerateRequest()
{
@@ -45,4 +45,19 @@ public async Task GenerateAzureCLICommandAsync(string inten
HttpResponseMessage responseMessage = await _httpClientFactory.CreateClient().SendAsync(requestMessage, cancellationToken);
return responseMessage;
}
+
+ private string GetCliCopilotEndpoint()
+ {
+ switch (cloudConfiguration.CloudType)
+ {
+ case AzureCloudConfiguration.AzureCloud.AzurePublicCloud:
+ return "https://azclis-copilot-apim-prod-eus.azure-api.net/azcli/copilot";
+ case AzureCloudConfiguration.AzureCloud.AzureChinaCloud:
+ return "https://azclis-copilot-apim-prod-eus.azure-api.cn/azcli/copilot";
+ case AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud:
+ return "https://azclis-copilot-apim-prod-eus.azure-api.us/azcli/copilot";
+ default:
+ return "https://azclis-copilot-apim-prod-eus.azure-api.net/azcli/copilot";
+ }
+ }
}
diff --git a/tools/Azure.Mcp.Tools.KeyVault/src/Services/KeyVaultService.cs b/tools/Azure.Mcp.Tools.KeyVault/src/Services/KeyVaultService.cs
index a1a84a438c..efadf3636c 100644
--- a/tools/Azure.Mcp.Tools.KeyVault/src/Services/KeyVaultService.cs
+++ b/tools/Azure.Mcp.Tools.KeyVault/src/Services/KeyVaultService.cs
@@ -3,6 +3,7 @@
using Azure.Mcp.Core.Options;
using Azure.Mcp.Core.Services.Azure;
+using Azure.Mcp.Core.Services.Azure.Authentication;
using Azure.Mcp.Core.Services.Azure.Tenant;
using Azure.Security.KeyVault.Administration;
using Azure.Security.KeyVault.Certificates;
@@ -13,6 +14,7 @@ namespace Azure.Mcp.Tools.KeyVault.Services;
public sealed class KeyVaultService(ITenantService tenantService, IHttpClientFactory httpClientFactory) : BaseAzureService(tenantService), IKeyVaultService
{
+ private readonly ITenantService _tenantService = tenantService ?? throw new ArgumentNullException(nameof(tenantService));
private readonly IHttpClientFactory _httpClientFactory = httpClientFactory ?? throw new ArgumentNullException(nameof(httpClientFactory));
public async Task> ListKeys(
@@ -301,7 +303,36 @@ public async Task ImportCertificate(
}
}
- private static Uri BuildVaultUri(string vaultName) => new($"https://{vaultName}.vault.azure.net");
+ private Uri BuildVaultUri(string vaultName)
+ {
+ switch (_tenantService.CloudConfiguration.CloudType)
+ {
+ case AzureCloudConfiguration.AzureCloud.AzurePublicCloud:
+ return new Uri($"https://{vaultName}.vault.azure.net");
+ case AzureCloudConfiguration.AzureCloud.AzureChinaCloud:
+ return new Uri($"https://{vaultName}.vault.azure.cn");
+ case AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud:
+ return new Uri($"https://{vaultName}.vault.usgovcloudapi.net");
+ default:
+ return new Uri($"https://{vaultName}.vault.azure.net");
+ }
+ }
+
+
+ private Uri GetHsmUri(string vaultName)
+ {
+ switch (_tenantService.CloudConfiguration.CloudType)
+ {
+ case AzureCloudConfiguration.AzureCloud.AzurePublicCloud:
+ return new Uri($"https://{vaultName}.managedhsm.azure.net");
+ case AzureCloudConfiguration.AzureCloud.AzureChinaCloud:
+ return new Uri($"https://{vaultName}.managedhsm.azure.cn");
+ case AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud:
+ return new Uri($"https://{vaultName}.managedhsm.usgovcloudapi.net");
+ default:
+ return new Uri($"https://{vaultName}.managedhsm.azure.net");
+ }
+ }
// Create clients with injected HttpClient, this will enable record/playback during testing.
private KeyClient CreateKeyClient(string vaultName, Azure.Core.TokenCredential credential, RetryPolicyOptions? retry)
@@ -346,7 +377,7 @@ public async Task GetVaultSettings(
{
ValidateRequiredParameters((nameof(vaultName), vaultName), (nameof(subscription), subscription));
var credential = await GetCredential(tenantId, cancellationToken);
- var hsmUri = new Uri($"https://{vaultName}.managedhsm.azure.net");
+ var hsmUri = GetHsmUri(vaultName);
try
{
var hsmClient = new KeyVaultSettingsClient(hsmUri, credential);
diff --git a/tools/Azure.Mcp.Tools.KeyVault/tests/Azure.Mcp.Tools.KeyVault.LiveTests/KeyVaultCommandTests.cs b/tools/Azure.Mcp.Tools.KeyVault/tests/Azure.Mcp.Tools.KeyVault.LiveTests/KeyVaultCommandTests.cs
index 0a7b7bee9b..22d5ef9969 100644
--- a/tools/Azure.Mcp.Tools.KeyVault/tests/Azure.Mcp.Tools.KeyVault.LiveTests/KeyVaultCommandTests.cs
+++ b/tools/Azure.Mcp.Tools.KeyVault/tests/Azure.Mcp.Tools.KeyVault.LiveTests/KeyVaultCommandTests.cs
@@ -18,6 +18,12 @@ public class KeyVaultCommandTests(ITestOutputHelper output, TestProxyFixture fix
{
private readonly KeyVaultTestCertificateAssets _importCertificateAssets = KeyVaultTestCertificates.Load();
+ public override CustomDefaultMatcher? TestMatcher => new()
+ {
+ ExcludedHeaders = "Authorization,Content-Type",
+ CompareBodies = false
+ };
+
public override List BodyRegexSanitizers => [
// Sanitizes all hostnames in URLs to remove actual vault names (not limited to `kid` fields)
new BodyRegexSanitizer(new BodyRegexSanitizerBody() {
@@ -215,7 +221,6 @@ public async Task Should_create_certificate()
[Fact]
- [CustomMatcher(compareBody: false)]
public async Task Should_import_certificate()
{
var fakePassword = _importCertificateAssets.Password;
diff --git a/tools/Azure.Mcp.Tools.KeyVault/tests/Azure.Mcp.Tools.KeyVault.LiveTests/assets.json b/tools/Azure.Mcp.Tools.KeyVault/tests/Azure.Mcp.Tools.KeyVault.LiveTests/assets.json
index 0000352247..f0346c4894 100644
--- a/tools/Azure.Mcp.Tools.KeyVault/tests/Azure.Mcp.Tools.KeyVault.LiveTests/assets.json
+++ b/tools/Azure.Mcp.Tools.KeyVault/tests/Azure.Mcp.Tools.KeyVault.LiveTests/assets.json
@@ -2,5 +2,5 @@
"AssetsRepo": "Azure/azure-sdk-assets",
"AssetsRepoPrefixPath": "",
"TagPrefix": "Azure.Mcp.Tools.KeyVault.LiveTests",
- "Tag": "Azure.Mcp.Tools.KeyVault.LiveTests_24520a9311"
+ "Tag": "Azure.Mcp.Tools.KeyVault.LiveTests_bcfab177a3"
}
diff --git a/tools/Azure.Mcp.Tools.Marketplace/src/Services/MarketplaceService.cs b/tools/Azure.Mcp.Tools.Marketplace/src/Services/MarketplaceService.cs
index 7f7d2fb87a..8847fc2e71 100644
--- a/tools/Azure.Mcp.Tools.Marketplace/src/Services/MarketplaceService.cs
+++ b/tools/Azure.Mcp.Tools.Marketplace/src/Services/MarketplaceService.cs
@@ -15,7 +15,8 @@ namespace Azure.Mcp.Tools.Marketplace.Services;
public class MarketplaceService(ITenantService tenantService)
: BaseAzureService(tenantService), IMarketplaceService
{
- private const string ManagementApiBaseUrl = "https://management.azure.com";
+ private readonly ITenantService _tenantService = tenantService ?? throw new ArgumentNullException(nameof(tenantService));
+
private const string ApiVersion = "2023-01-01-preview";
///
@@ -55,7 +56,8 @@ public async Task GetProduct(
(nameof(productId), productId),
(nameof(subscription), subscription));
- string productUrl = BuildProductUrl(subscription, productId, includeStopSoldPlans, language, market,
+ var managementEndpoint = _tenantService.CloudConfiguration.ArmEnvironment.Endpoint.ToString().TrimEnd('/');
+ string productUrl = BuildProductUrl(managementEndpoint, subscription, productId, includeStopSoldPlans, language, market,
lookupOfferInTenantLevel, planId, skuId, includeServiceInstructionTemplates);
return await GetMarketplaceSingleProductResponseAsync(productUrl, tenantId, retryPolicy, cancellationToken);
@@ -92,12 +94,14 @@ public async Task ListProducts(
{
ValidateRequiredParameters((nameof(subscription), subscription));
- string productsUrl = BuildProductsListUrl(subscription, language, search, filter, orderBy, select, nextCursor, expand);
+ var managementEndpoint = _tenantService.CloudConfiguration.ArmEnvironment.Endpoint.ToString().TrimEnd('/');
+ string productsUrl = BuildProductsListUrl(managementEndpoint, subscription, language, search, filter, orderBy, select, nextCursor, expand);
return await GetMarketplaceListProductsResponseAsync(productsUrl, tenantId, retryPolicy, cancellationToken);
}
private static string BuildProductsListUrl(
+ string managementEndpoint,
string subscription,
string? language,
string? search,
@@ -136,7 +140,7 @@ private static string BuildProductsListUrl(
queryParams.Add("storefront=any"); // include all storefronts
string queryString = string.Join("&", queryParams);
- return $"{ManagementApiBaseUrl}/subscriptions/{subscription}/providers/Microsoft.Marketplace/products?{queryString}";
+ return $"{managementEndpoint}/subscriptions/{subscription}/providers/Microsoft.Marketplace/products?{queryString}";
}
private async Task GetMarketplaceListProductsResponseAsync(string url, string? tenant, RetryPolicyOptions? retryPolicy, CancellationToken cancellationToken)
@@ -154,6 +158,7 @@ private async Task GetMarketplaceListProducts
private static string BuildProductUrl(
+ string managementEndpoint,
string subscription,
string productId,
bool? includeStopSoldPlans,
@@ -191,7 +196,7 @@ private static string BuildProductUrl(
queryParams.Add($"includeServiceInstructionTemplates={includeServiceInstructionTemplates.Value.ToString().ToLower()}");
string queryString = string.Join("&", queryParams);
- return $"{ManagementApiBaseUrl}/subscriptions/{subscription}/providers/Microsoft.Marketplace/products/{productId}?{queryString}";
+ return $"{managementEndpoint}/subscriptions/{subscription}/providers/Microsoft.Marketplace/products/{productId}?{queryString}";
}
private async Task GetMarketplaceSingleProductResponseAsync(string url, string? tenant, RetryPolicyOptions? retryPolicy, CancellationToken cancellationToken)
@@ -208,7 +213,8 @@ private async Task GetMarketplaceSingleProductResponseAsync(stri
private async Task GetArmAccessTokenAsync(string? tenantId, CancellationToken cancellationToken)
{
- var tokenRequestContext = new TokenRequestContext([$"{ManagementApiBaseUrl}/.default"]);
+ var defaultScope = _tenantService.CloudConfiguration.ArmEnvironment.DefaultScope;
+ var tokenRequestContext = new TokenRequestContext([defaultScope]);
var tokenCredential = await GetCredential(tenantId, cancellationToken);
return await tokenCredential
.GetTokenAsync(tokenRequestContext, cancellationToken);
diff --git a/tools/Azure.Mcp.Tools.Marketplace/tests/Azure.Mcp.Tools.Marketplace.LiveTests/assets.json b/tools/Azure.Mcp.Tools.Marketplace/tests/Azure.Mcp.Tools.Marketplace.LiveTests/assets.json
index f2731ac0cb..4765d79046 100644
--- a/tools/Azure.Mcp.Tools.Marketplace/tests/Azure.Mcp.Tools.Marketplace.LiveTests/assets.json
+++ b/tools/Azure.Mcp.Tools.Marketplace/tests/Azure.Mcp.Tools.Marketplace.LiveTests/assets.json
@@ -2,5 +2,5 @@
"AssetsRepo": "Azure/azure-sdk-assets",
"AssetsRepoPrefixPath": "",
"TagPrefix": "Azure.Mcp.Tools.Marketplace.LiveTests",
- "Tag": "Azure.Mcp.Tools.Marketplace.LiveTests_0f59abc9d6"
+ "Tag": "Azure.Mcp.Tools.Marketplace.LiveTests_7c61cc9c48"
}
diff --git a/tools/Azure.Mcp.Tools.Monitor/src/Services/MonitorHealthModelService.cs b/tools/Azure.Mcp.Tools.Monitor/src/Services/MonitorHealthModelService.cs
index 0cdc03a147..82d8846c3e 100644
--- a/tools/Azure.Mcp.Tools.Monitor/src/Services/MonitorHealthModelService.cs
+++ b/tools/Azure.Mcp.Tools.Monitor/src/Services/MonitorHealthModelService.cs
@@ -6,6 +6,7 @@
using Azure.Core;
using Azure.Mcp.Core.Options;
using Azure.Mcp.Core.Services.Azure;
+using Azure.Mcp.Core.Services.Azure.Authentication;
using Azure.Mcp.Core.Services.Azure.Tenant;
namespace Azure.Mcp.Tools.Monitor.Services;
@@ -13,10 +14,9 @@ namespace Azure.Mcp.Tools.Monitor.Services;
public class MonitorHealthModelService(ITenantService tenantService, IHttpClientFactory httpClientFactory)
: BaseAzureService(tenantService), IMonitorHealthModelService
{
- private const string ManagementApiBaseUrl = "https://management.azure.com";
- private const string HealthModelsDataApiScope = "https://data.healthmodels.azure.com/.default";
private const string ApiVersion = "2023-10-01-preview";
private readonly IHttpClientFactory _httpClientFactory = httpClientFactory ?? throw new ArgumentNullException(nameof(httpClientFactory));
+ private readonly ITenantService _tenantService = tenantService ?? throw new ArgumentNullException(nameof(tenantService));
///
/// Retrieves the health information for a specific entity in a health model.
@@ -68,7 +68,7 @@ private async Task GetDataplaneResponseAsync(string url, CancellationTok
private async Task GetDataplaneEndpointAsync(string subscriptionId, string resourceGroupName, string healthModelName, CancellationToken cancellationToken)
{
string token = await GetControlPlaneTokenAsync(cancellationToken);
- string healthModelUrl = $"{ManagementApiBaseUrl}/subscriptions/{subscriptionId}/resourceGroups/{resourceGroupName}/providers/Microsoft.CloudHealth/healthmodels/{healthModelName}?api-version={ApiVersion}";
+ string healthModelUrl = $"{GetManagementEndpoint()}/subscriptions/{subscriptionId}/resourceGroups/{resourceGroupName}/providers/Microsoft.CloudHealth/healthmodels/{healthModelName}?api-version={ApiVersion}";
using var request = new HttpRequestMessage(HttpMethod.Get, healthModelUrl);
request.Headers.Authorization = new System.Net.Http.Headers.AuthenticationHeaderValue("Bearer", token);
@@ -106,7 +106,7 @@ private async Task GetControlPlaneTokenAsync(CancellationToken cancellat
{
TokenCredential credential = await GetCredential(cancellationToken);
AccessToken accessToken = await credential.GetTokenAsync(
- new TokenRequestContext([$"{ManagementApiBaseUrl}/.default"]),
+ new TokenRequestContext([_tenantService.CloudConfiguration.ArmEnvironment.DefaultScope]),
cancellationToken);
return accessToken.Token;
@@ -116,9 +116,29 @@ private async Task GetDataplaneTokenAsync(CancellationToken cancellation
{
TokenCredential credential = await GetCredential(cancellationToken);
AccessToken accessToken = await credential.GetTokenAsync(
- new TokenRequestContext([HealthModelsDataApiScope]),
+ new TokenRequestContext([GetHealthModelsDataApiScope()]),
cancellationToken);
return accessToken.Token;
}
+
+ private string GetManagementEndpoint()
+ {
+ return _tenantService.CloudConfiguration.ArmEnvironment.Endpoint.ToString().TrimEnd('/');
+ }
+
+ private string GetHealthModelsDataApiScope()
+ {
+ switch (_tenantService.CloudConfiguration.CloudType)
+ {
+ case AzureCloudConfiguration.AzureCloud.AzurePublicCloud:
+ return "https://data.healthmodels.azure.com/.default";
+ case AzureCloudConfiguration.AzureCloud.AzureChinaCloud:
+ return "https://data.healthmodels.azure.cn/.default";
+ case AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud:
+ return "https://data.healthmodels.azure.us/.default";
+ default:
+ return "https://data.healthmodels.azure.com/.default";
+ }
+ }
}
diff --git a/tools/Azure.Mcp.Tools.Monitor/src/Services/MonitorService.cs b/tools/Azure.Mcp.Tools.Monitor/src/Services/MonitorService.cs
index 0dfc6450eb..5e21b0a199 100644
--- a/tools/Azure.Mcp.Tools.Monitor/src/Services/MonitorService.cs
+++ b/tools/Azure.Mcp.Tools.Monitor/src/Services/MonitorService.cs
@@ -7,6 +7,7 @@
using Azure.Core.Pipeline;
using Azure.Mcp.Core.Options;
using Azure.Mcp.Core.Services.Azure;
+using Azure.Mcp.Core.Services.Azure.Authentication;
using Azure.Mcp.Core.Services.Azure.ResourceGroup;
using Azure.Mcp.Core.Services.Azure.Subscription;
using Azure.Mcp.Core.Services.Azure.Tenant;
@@ -27,8 +28,7 @@ public class MonitorService(
IHttpClientFactory httpClientFactory) : BaseAzureService(tenantService), IMonitorService
{
private const string ActivityLogApiVersion = "2017-03-01-preview";
- private const string ActivityLogEndpointFormat
- = "https://management.azure.com/subscriptions/{0}/providers/Microsoft.Insights/eventtypes/management/values";
+ private readonly ITenantService _tenantService = tenantService ?? throw new ArgumentNullException(nameof(tenantService));
private readonly IHttpClientFactory _httpClientFactory = httpClientFactory ?? throw new ArgumentNullException(nameof(httpClientFactory));
public async Task> QueryResourceLogs(
@@ -424,7 +424,7 @@ private async Task> CallActivityLogApiAsync(
{
var returnValue = new List();
- string endpoint = string.Format(ActivityLogEndpointFormat, subscriptionId);
+ string endpoint = string.Format(GetLogActivityEndpointString(), subscriptionId);
var uriBuilder = new UriBuilder(endpoint);
// Build the query parameters
@@ -447,7 +447,7 @@ private async Task> CallActivityLogApiAsync(
TokenCredential credential = await GetCredential(tenant, cancellationToken);
AccessToken accessToken = await credential.GetTokenAsync(
- new TokenRequestContext(["https://management.azure.com/.default"]),
+ new TokenRequestContext([_tenantService.CloudConfiguration.ArmEnvironment.DefaultScope]),
cancellationToken);
// Make paginated requests
@@ -528,4 +528,19 @@ private static bool IsWorkspaceId(string workspace)
return (matchingWorkspace.CustomerId, matchingWorkspace.Name);
}
+
+ private string GetLogActivityEndpointString()
+ {
+ switch (_tenantService.CloudConfiguration.CloudType)
+ {
+ case AzureCloudConfiguration.AzureCloud.AzurePublicCloud:
+ return "https://management.azure.com/subscriptions/{0}/providers/Microsoft.Insights/eventtypes/management/values";
+ case AzureCloudConfiguration.AzureCloud.AzureChinaCloud:
+ return "https://management.chinacloudapi.cn/subscriptions/{0}/providers/Microsoft.Insights/eventtypes/management/values";
+ case AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud:
+ return "https://management.usgovcloudapi.net/subscriptions/{0}/providers/Microsoft.Insights/eventtypes/management/values";
+ default:
+ return "https://management.azure.com/subscriptions/{0}/providers/Microsoft.Insights/eventtypes/management/values";
+ }
+ }
}
diff --git a/tools/Azure.Mcp.Tools.MySql/src/Services/MySqlService.cs b/tools/Azure.Mcp.Tools.MySql/src/Services/MySqlService.cs
index c52c06314d..6446d81b65 100644
--- a/tools/Azure.Mcp.Tools.MySql/src/Services/MySqlService.cs
+++ b/tools/Azure.Mcp.Tools.MySql/src/Services/MySqlService.cs
@@ -4,6 +4,7 @@
using System.Text.RegularExpressions;
using Azure.Core;
using Azure.Mcp.Core.Services.Azure;
+using Azure.Mcp.Core.Services.Azure.Authentication;
using Azure.Mcp.Core.Services.Azure.ResourceGroup;
using Azure.Mcp.Core.Services.Azure.Tenant;
using Azure.Mcp.Tools.MySql.Commands;
@@ -15,6 +16,7 @@ namespace Azure.Mcp.Tools.MySql.Services;
public class MySqlService(IResourceGroupService resourceGroupService, ITenantService tenantService, ILogger logger) : BaseAzureService(tenantService), IMySqlService
{
+ private readonly ITenantService _tenantService = tenantService ?? throw new ArgumentNullException(nameof(tenantService));
private readonly IResourceGroupService _resourceGroupService = resourceGroupService ?? throw new ArgumentNullException(nameof(resourceGroupService));
private readonly ILogger _logger = logger;
@@ -66,18 +68,44 @@ public class MySqlService(IResourceGroupService resourceGroupService, ITenantSer
private async Task GetEntraIdAccessTokenAsync(CancellationToken cancellationToken)
{
- var tokenRequestContext = new TokenRequestContext(["https://ossrdbms-aad.database.windows.net/.default"]);
+
+ var tokenRequestContext = new TokenRequestContext([GetOpenSourceRDBMSEndpoint().ToString()]);
TokenCredential tokenCredential = await GetCredential(cancellationToken);
AccessToken accessToken = await tokenCredential
.GetTokenAsync(tokenRequestContext, cancellationToken);
return accessToken.Token;
}
- private static string NormalizeServerName(string server)
+ private Uri GetOpenSourceRDBMSEndpoint()
+ {
+ switch (_tenantService.CloudConfiguration.CloudType)
+ {
+ case AzureCloudConfiguration.AzureCloud.AzurePublicCloud:
+ return new Uri("https://ossrdbms-aad.database.windows.net/.default");
+ case AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud:
+ return new Uri("https://ossrdbms-aad.database.usgovcloudapi.net/.default");
+ case AzureCloudConfiguration.AzureCloud.AzureChinaCloud:
+ return new Uri("https://ossrdbms-aad.database.chinacloudapi.cn/.default");
+ default:
+ return new Uri("https://ossrdbms-aad.database.windows.net/.default");
+ }
+ }
+
+ private string NormalizeServerName(string server)
{
if (!server.Contains('.'))
{
- return server + ".mysql.database.azure.com";
+ switch (_tenantService.CloudConfiguration.CloudType)
+ {
+ case AzureCloudConfiguration.AzureCloud.AzurePublicCloud:
+ return server + ".mysql.database.azure.com";
+ case AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud:
+ return server + ".mysql.database.usgovcloudapi.net";
+ case AzureCloudConfiguration.AzureCloud.AzureChinaCloud:
+ return server + ".mysql.database.chinacloudapi.cn";
+ default:
+ return server + ".mysql.database.azure.com";
+ }
}
return server;
}
diff --git a/tools/Azure.Mcp.Tools.Postgres/src/Services/PostgresService.cs b/tools/Azure.Mcp.Tools.Postgres/src/Services/PostgresService.cs
index f240fcb0e7..518cd4176f 100644
--- a/tools/Azure.Mcp.Tools.Postgres/src/Services/PostgresService.cs
+++ b/tools/Azure.Mcp.Tools.Postgres/src/Services/PostgresService.cs
@@ -6,6 +6,7 @@
using System.Net;
using Azure.Core;
using Azure.Mcp.Core.Services.Azure;
+using Azure.Mcp.Core.Services.Azure.Authentication;
using Azure.Mcp.Core.Services.Azure.ResourceGroup;
using Azure.Mcp.Core.Services.Azure.Tenant;
using Azure.Mcp.Tools.Postgres.Auth;
@@ -21,6 +22,7 @@ namespace Azure.Mcp.Tools.Postgres.Services;
public class PostgresService : BaseAzureService, IPostgresService
{
private readonly IResourceGroupService _resourceGroupService;
+ private readonly ITenantService _tenantService;
private readonly IEntraTokenProvider _entraTokenAuth;
private readonly IDbProvider _dbProvider;
@@ -32,6 +34,7 @@ public PostgresService(
: base(tenantService)
{
_resourceGroupService = resourceGroupService ?? throw new ArgumentNullException(nameof(resourceGroupService));
+ _tenantService = tenantService ?? throw new ArgumentNullException(nameof(tenantService));
_entraTokenAuth = entraTokenAuth;
_dbProvider = dbProvider;
}
@@ -44,11 +47,21 @@ private async Task GetEntraIdAccessTokenAsync(CancellationToken cancella
return accessToken.Token;
}
- private static string NormalizeServerName(string server)
+ private string NormalizeServerName(string server)
{
if (!server.Contains('.'))
{
- return server + ".postgres.database.azure.com";
+ switch (_tenantService.CloudConfiguration.CloudType)
+ {
+ case AzureCloudConfiguration.AzureCloud.AzurePublicCloud:
+ return server + ".postgres.database.azure.com";
+ case AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud:
+ return server + ".postgres.database.usgovcloudapi.net";
+ case AzureCloudConfiguration.AzureCloud.AzureChinaCloud:
+ return server + ".postgres.database.chinacloudapi.cn";
+ default:
+ return server + ".postgres.database.azure.com";
+ }
}
return server;
}
diff --git a/tools/Azure.Mcp.Tools.Pricing/src/Services/PricingService.cs b/tools/Azure.Mcp.Tools.Pricing/src/Services/PricingService.cs
index 7a08bed696..4c743a1e3b 100644
--- a/tools/Azure.Mcp.Tools.Pricing/src/Services/PricingService.cs
+++ b/tools/Azure.Mcp.Tools.Pricing/src/Services/PricingService.cs
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
+using Azure.Mcp.Core.Services.Azure.Authentication;
using Azure.Mcp.Tools.Pricing.Models;
using AzureRetailPrices;
@@ -9,7 +10,7 @@ namespace Azure.Mcp.Tools.Pricing.Services;
///
/// Service implementation for Azure Retail Pricing operations.
///
-public class PricingService : IPricingService
+public class PricingService(IAzureCloudConfiguration cloudConfiguration) : IPricingService
{
private const int MaxResults = 5000;
@@ -45,7 +46,7 @@ public async Task> GetPricesAsync(
var clientOptions = new AzureRetailPricesClientOptions(serviceVersion);
var client = new AzureRetailPricesClient(
- new Uri("https://prices.azure.com"),
+ GetPricingEndpoint(),
clientOptions);
var retailPrices = client.GetRetailPricesClient();
@@ -122,6 +123,21 @@ private static string EscapeODataValue(string value)
return value.Replace("'", "''");
}
+ private Uri GetPricingEndpoint()
+ {
+ switch (cloudConfiguration.CloudType)
+ {
+ case AzureCloudConfiguration.AzureCloud.AzurePublicCloud:
+ return new Uri("https://prices.azure.com");
+ case AzureCloudConfiguration.AzureCloud.AzureChinaCloud:
+ return new Uri("https://prices.azure.cn");
+ case AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud:
+ return new Uri("https://prices.azure.us");
+ default:
+ return new Uri("https://prices.azure.com");
+ }
+ }
+
private static PriceItem MapToPriceItem(RetailPriceItem item)
{
var priceItem = new PriceItem
diff --git a/tools/Azure.Mcp.Tools.Quota/src/Services/QuotaService.cs b/tools/Azure.Mcp.Tools.Quota/src/Services/QuotaService.cs
index 937c2ce5d1..d45a3e24b0 100644
--- a/tools/Azure.Mcp.Tools.Quota/src/Services/QuotaService.cs
+++ b/tools/Azure.Mcp.Tools.Quota/src/Services/QuotaService.cs
@@ -32,6 +32,7 @@ public async Task>> GetAzureQuotaAsync(
resourceTypes,
subscriptionId,
location,
+ TenantService,
loggerFactory,
_httpClientFactory,
cancellationToken);
diff --git a/tools/Azure.Mcp.Tools.Quota/src/Services/Util/AzureUsageChecker.cs b/tools/Azure.Mcp.Tools.Quota/src/Services/Util/AzureUsageChecker.cs
index 852b62c07d..f6487eb30b 100644
--- a/tools/Azure.Mcp.Tools.Quota/src/Services/Util/AzureUsageChecker.cs
+++ b/tools/Azure.Mcp.Tools.Quota/src/Services/Util/AzureUsageChecker.cs
@@ -3,6 +3,7 @@
using System.Net.Http;
using Azure.Core;
+using Azure.Mcp.Core.Services.Azure.Tenant;
using Azure.Mcp.Tools.Quota.Services.Util.Usage;
using Azure.ResourceManager;
using Microsoft.Extensions.Logging;
@@ -50,16 +51,28 @@ public abstract class AzureUsageChecker : IUsageChecker
protected readonly ArmClient ResourceClient;
protected readonly TokenCredential Credential;
protected readonly ILogger Logger;
- protected const string managementEndpoint = "https://management.azure.com";
+ protected readonly ITenantService TenantService;
- protected AzureUsageChecker(TokenCredential credential, string subscriptionId, ILogger logger)
+ protected AzureUsageChecker(TokenCredential credential, string subscriptionId, ILogger logger, ITenantService tenantService)
{
SubscriptionId = subscriptionId;
Credential = credential ?? throw new ArgumentNullException(nameof(credential));
- ResourceClient = new ArmClient(credential, subscriptionId);
+ TenantService = tenantService ?? throw new ArgumentNullException(nameof(tenantService));
Logger = logger;
+ var clientOptions = new ArmClientOptions { Environment = tenantService.CloudConfiguration.ArmEnvironment };
+
+ ResourceClient = new ArmClient(
+ credential,
+ subscriptionId,
+ clientOptions);
}
+ protected string GetManagementEndpoint()
+ {
+ return TenantService.CloudConfiguration.ArmEnvironment.Endpoint.ToString().TrimEnd('/');
+ }
+
+
public abstract Task> GetUsageForLocationAsync(string location, CancellationToken cancellationToken);
}
@@ -81,7 +94,7 @@ public static class UsageCheckerFactory
{ "Microsoft.ContainerInstance", ResourceProvider.ContainerInstance }
};
- public static IUsageChecker CreateUsageChecker(TokenCredential credential, string provider, string subscriptionId, ILoggerFactory loggerFactory, IHttpClientFactory httpClientFactory)
+ public static IUsageChecker CreateUsageChecker(TokenCredential credential, string provider, string subscriptionId, ILoggerFactory loggerFactory, IHttpClientFactory httpClientFactory, ITenantService tenantService)
{
if (!ProviderMapping.TryGetValue(provider, out var resourceProvider))
{
@@ -90,16 +103,16 @@ public static IUsageChecker CreateUsageChecker(TokenCredential credential, strin
return resourceProvider switch
{
- ResourceProvider.Compute => new ComputeUsageChecker(credential, subscriptionId, loggerFactory.CreateLogger()),
- ResourceProvider.CognitiveServices => new CognitiveServicesUsageChecker(credential, subscriptionId, loggerFactory.CreateLogger()),
- ResourceProvider.Storage => new StorageUsageChecker(credential, subscriptionId, loggerFactory.CreateLogger()),
- ResourceProvider.ContainerApp => new ContainerAppUsageChecker(credential, subscriptionId, loggerFactory.CreateLogger()),
- ResourceProvider.Network => new NetworkUsageChecker(credential, subscriptionId, loggerFactory.CreateLogger()),
- ResourceProvider.MachineLearning => new MachineLearningUsageChecker(credential, subscriptionId, loggerFactory.CreateLogger()),
- ResourceProvider.PostgreSQL => new PostgreSQLUsageChecker(credential, subscriptionId, loggerFactory.CreateLogger(), httpClientFactory),
- ResourceProvider.HDInsight => new HDInsightUsageChecker(credential, subscriptionId, loggerFactory.CreateLogger()),
- ResourceProvider.Search => new SearchUsageChecker(credential, subscriptionId, loggerFactory.CreateLogger()),
- ResourceProvider.ContainerInstance => new ContainerInstanceUsageChecker(credential, subscriptionId, loggerFactory.CreateLogger()),
+ ResourceProvider.Compute => new ComputeUsageChecker(credential, subscriptionId, loggerFactory.CreateLogger(), tenantService),
+ ResourceProvider.CognitiveServices => new CognitiveServicesUsageChecker(credential, subscriptionId, loggerFactory.CreateLogger(), tenantService),
+ ResourceProvider.Storage => new StorageUsageChecker(credential, subscriptionId, loggerFactory.CreateLogger(), tenantService),
+ ResourceProvider.ContainerApp => new ContainerAppUsageChecker(credential, subscriptionId, loggerFactory.CreateLogger(), tenantService),
+ ResourceProvider.Network => new NetworkUsageChecker(credential, subscriptionId, loggerFactory.CreateLogger(), tenantService),
+ ResourceProvider.MachineLearning => new MachineLearningUsageChecker(credential, subscriptionId, loggerFactory.CreateLogger(), tenantService),
+ ResourceProvider.PostgreSQL => new PostgreSQLUsageChecker(credential, subscriptionId, loggerFactory.CreateLogger(), httpClientFactory, tenantService),
+ ResourceProvider.HDInsight => new HDInsightUsageChecker(credential, subscriptionId, loggerFactory.CreateLogger(), tenantService),
+ ResourceProvider.Search => new SearchUsageChecker(credential, subscriptionId, loggerFactory.CreateLogger(), tenantService),
+ ResourceProvider.ContainerInstance => new ContainerInstanceUsageChecker(credential, subscriptionId, loggerFactory.CreateLogger(), tenantService),
_ => throw new ArgumentException($"No implementation for provider: {provider}")
};
}
@@ -113,6 +126,7 @@ public static async Task>> GetAzureQuotaAsync
List resourceTypes,
string subscriptionId,
string location,
+ ITenantService tenantService,
ILoggerFactory loggerFactory,
IHttpClientFactory httpClientFactory,
CancellationToken cancellationToken)
@@ -130,7 +144,7 @@ public static async Task>> GetAzureQuotaAsync
var (provider, resourceTypesForProvider) = (kvp.Key, kvp.Value);
try
{
- var usageChecker = UsageCheckerFactory.CreateUsageChecker(credential, provider, subscriptionId, loggerFactory, httpClientFactory);
+ var usageChecker = UsageCheckerFactory.CreateUsageChecker(credential, provider, subscriptionId, loggerFactory, httpClientFactory, tenantService);
var quotaInfo = await usageChecker.GetUsageForLocationAsync(location, cancellationToken);
logger.LogDebug("Retrieved quota info for provider {Provider}: {ItemCount} items", provider, quotaInfo.Count);
diff --git a/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/CognitiveServicesUsageChecker.cs b/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/CognitiveServicesUsageChecker.cs
index 429cf67c41..5de5f9f00a 100644
--- a/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/CognitiveServicesUsageChecker.cs
+++ b/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/CognitiveServicesUsageChecker.cs
@@ -2,13 +2,14 @@
// Licensed under the MIT License.
using Azure.Core;
+using Azure.Mcp.Core.Services.Azure.Tenant;
using Azure.ResourceManager.CognitiveServices;
using Azure.ResourceManager.CognitiveServices.Models;
using Microsoft.Extensions.Logging;
namespace Azure.Mcp.Tools.Quota.Services.Util.Usage;
-public class CognitiveServicesUsageChecker(TokenCredential credential, string subscriptionId, ILogger logger) : AzureUsageChecker(credential, subscriptionId, logger)
+public class CognitiveServicesUsageChecker(TokenCredential credential, string subscriptionId, ILogger logger, ITenantService tenantService) : AzureUsageChecker(credential, subscriptionId, logger, tenantService)
{
public override async Task> GetUsageForLocationAsync(string location, CancellationToken cancellationToken)
{
diff --git a/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/ComputeUsageChecker.cs b/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/ComputeUsageChecker.cs
index 4117e7ff75..e0fea2905f 100644
--- a/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/ComputeUsageChecker.cs
+++ b/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/ComputeUsageChecker.cs
@@ -2,13 +2,14 @@
// Licensed under the MIT License.
using Azure.Core;
+using Azure.Mcp.Core.Services.Azure.Tenant;
using Azure.ResourceManager.Compute;
using Azure.ResourceManager.Compute.Models;
using Microsoft.Extensions.Logging;
namespace Azure.Mcp.Tools.Quota.Services.Util.Usage;
-public class ComputeUsageChecker(TokenCredential credential, string subscriptionId, ILogger logger) : AzureUsageChecker(credential, subscriptionId, logger)
+public class ComputeUsageChecker(TokenCredential credential, string subscriptionId, ILogger logger, ITenantService tenantService) : AzureUsageChecker(credential, subscriptionId, logger, tenantService)
{
public override async Task> GetUsageForLocationAsync(string location, CancellationToken cancellationToken)
{
diff --git a/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/ContainerAppUsageChecker.cs b/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/ContainerAppUsageChecker.cs
index 012e2ce5c6..ea96a490f1 100644
--- a/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/ContainerAppUsageChecker.cs
+++ b/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/ContainerAppUsageChecker.cs
@@ -2,12 +2,13 @@
// Licensed under the MIT License.
using Azure.Core;
+using Azure.Mcp.Core.Services.Azure.Tenant;
using Azure.ResourceManager.AppContainers;
using Microsoft.Extensions.Logging;
namespace Azure.Mcp.Tools.Quota.Services.Util.Usage;
-public class ContainerAppUsageChecker(TokenCredential credential, string subscriptionId, ILogger logger) : AzureUsageChecker(credential, subscriptionId, logger)
+public class ContainerAppUsageChecker(TokenCredential credential, string subscriptionId, ILogger logger, ITenantService tenantService) : AzureUsageChecker(credential, subscriptionId, logger, tenantService)
{
public override async Task> GetUsageForLocationAsync(string location, CancellationToken cancellationToken)
{
diff --git a/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/ContainerInstanceUsageChecker.cs b/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/ContainerInstanceUsageChecker.cs
index cb31d483bb..4e4329acf8 100644
--- a/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/ContainerInstanceUsageChecker.cs
+++ b/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/ContainerInstanceUsageChecker.cs
@@ -2,13 +2,14 @@
// Licensed under the MIT License.
using Azure.Core;
+using Azure.Mcp.Core.Services.Azure.Tenant;
using Azure.ResourceManager.ContainerInstance;
using Azure.ResourceManager.ContainerInstance.Models;
using Microsoft.Extensions.Logging;
namespace Azure.Mcp.Tools.Quota.Services.Util.Usage;
-public class ContainerInstanceUsageChecker(TokenCredential credential, string subscriptionId, ILogger logger) : AzureUsageChecker(credential, subscriptionId, logger)
+public class ContainerInstanceUsageChecker(TokenCredential credential, string subscriptionId, ILogger logger, ITenantService tenantService) : AzureUsageChecker(credential, subscriptionId, logger, tenantService)
{
public override async Task> GetUsageForLocationAsync(string location, CancellationToken cancellationToken)
{
diff --git a/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/HDInsightUsageChecker.cs b/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/HDInsightUsageChecker.cs
index c98d4398df..c3861ca301 100644
--- a/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/HDInsightUsageChecker.cs
+++ b/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/HDInsightUsageChecker.cs
@@ -2,13 +2,14 @@
// Licensed under the MIT License.
using Azure.Core;
+using Azure.Mcp.Core.Services.Azure.Tenant;
using Azure.ResourceManager.HDInsight;
using Azure.ResourceManager.HDInsight.Models;
using Microsoft.Extensions.Logging;
namespace Azure.Mcp.Tools.Quota.Services.Util.Usage;
-public class HDInsightUsageChecker(TokenCredential credential, string subscriptionId, ILogger logger) : AzureUsageChecker(credential, subscriptionId, logger)
+public class HDInsightUsageChecker(TokenCredential credential, string subscriptionId, ILogger logger, ITenantService tenantService) : AzureUsageChecker(credential, subscriptionId, logger, tenantService)
{
public override async Task> GetUsageForLocationAsync(string location, CancellationToken cancellationToken)
{
diff --git a/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/MachineLearningUsageChecker.cs b/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/MachineLearningUsageChecker.cs
index e03d7a438a..05a0517998 100644
--- a/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/MachineLearningUsageChecker.cs
+++ b/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/MachineLearningUsageChecker.cs
@@ -2,12 +2,13 @@
// Licensed under the MIT License.
using Azure.Core;
+using Azure.Mcp.Core.Services.Azure.Tenant;
using Azure.ResourceManager.MachineLearning;
using Microsoft.Extensions.Logging;
namespace Azure.Mcp.Tools.Quota.Services.Util.Usage;
-public class MachineLearningUsageChecker(TokenCredential credential, string subscriptionId, ILogger logger) : AzureUsageChecker(credential, subscriptionId, logger)
+public class MachineLearningUsageChecker(TokenCredential credential, string subscriptionId, ILogger logger, ITenantService tenantService) : AzureUsageChecker(credential, subscriptionId, logger, tenantService)
{
public override async Task> GetUsageForLocationAsync(string location, CancellationToken cancellationToken)
{
diff --git a/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/NetworkUsageChecker.cs b/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/NetworkUsageChecker.cs
index ed2894090f..5ed2d66e36 100644
--- a/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/NetworkUsageChecker.cs
+++ b/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/NetworkUsageChecker.cs
@@ -2,12 +2,13 @@
// Licensed under the MIT License.
using Azure.Core;
+using Azure.Mcp.Core.Services.Azure.Tenant;
using Azure.ResourceManager.Network;
using Microsoft.Extensions.Logging;
namespace Azure.Mcp.Tools.Quota.Services.Util.Usage;
-public class NetworkUsageChecker(TokenCredential credential, string subscriptionId, ILogger logger) : AzureUsageChecker(credential, subscriptionId, logger)
+public class NetworkUsageChecker(TokenCredential credential, string subscriptionId, ILogger logger, ITenantService tenantService) : AzureUsageChecker(credential, subscriptionId, logger, tenantService)
{
public override async Task> GetUsageForLocationAsync(string location, CancellationToken cancellationToken)
{
diff --git a/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/PostgreSQLUsageChecker.cs b/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/PostgreSQLUsageChecker.cs
index 15bda15836..fa02567efd 100644
--- a/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/PostgreSQLUsageChecker.cs
+++ b/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/PostgreSQLUsageChecker.cs
@@ -4,19 +4,21 @@
using System.Net.Http;
using System.Net.Http.Headers;
using Azure.Core;
+using Azure.Mcp.Core.Services.Azure.Tenant;
using Microsoft.Extensions.Logging;
namespace Azure.Mcp.Tools.Quota.Services.Util.Usage;
-public class PostgreSQLUsageChecker(TokenCredential credential, string subscriptionId, ILogger logger, IHttpClientFactory httpClientFactory) : AzureUsageChecker(credential, subscriptionId, logger)
+public class PostgreSQLUsageChecker(TokenCredential credential, string subscriptionId, ILogger logger, IHttpClientFactory httpClientFactory, ITenantService tenantService) : AzureUsageChecker(credential, subscriptionId, logger, tenantService)
{
private readonly IHttpClientFactory _httpClientFactory = httpClientFactory ?? throw new ArgumentNullException(nameof(httpClientFactory));
+ private readonly ITenantService _tenantService = tenantService ?? throw new ArgumentNullException(nameof(tenantService));
public override async Task> GetUsageForLocationAsync(string location, CancellationToken cancellationToken)
{
try
{
- var requestUrl = $"{managementEndpoint}/subscriptions/{SubscriptionId}/providers/Microsoft.DBforPostgreSQL/locations/{location}/resourceType/flexibleServers/usages?api-version=2023-06-01-preview";
+ var requestUrl = $"{GetManagementEndpoint()}/subscriptions/{SubscriptionId}/providers/Microsoft.DBforPostgreSQL/locations/{location}/resourceType/flexibleServers/usages?api-version=2023-06-01-preview";
using var rawResponse = await GetQuotaByUrlAsync(requestUrl, cancellationToken);
if (rawResponse?.RootElement.TryGetProperty("value", out var valueElement) != true)
@@ -67,7 +69,7 @@ public override async Task> GetUsageForLocationAsync(string loca
{
try
{
- var token = await Credential.GetTokenAsync(new TokenRequestContext([$"{managementEndpoint}/.default"]), cancellationToken);
+ var token = await Credential.GetTokenAsync(new TokenRequestContext([_tenantService.CloudConfiguration.ArmEnvironment.DefaultScope]), cancellationToken);
using var request = new HttpRequestMessage(HttpMethod.Get, requestUrl);
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", token.Token);
diff --git a/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/SearchUsageChecker.cs b/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/SearchUsageChecker.cs
index 9f0aef545b..b4bb69f25f 100644
--- a/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/SearchUsageChecker.cs
+++ b/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/SearchUsageChecker.cs
@@ -2,13 +2,14 @@
// Licensed under the MIT License.
using Azure.Core;
+using Azure.Mcp.Core.Services.Azure.Tenant;
using Azure.ResourceManager.Search;
using Azure.ResourceManager.Search.Models;
using Microsoft.Extensions.Logging;
namespace Azure.Mcp.Tools.Quota.Services.Util.Usage;
-public class SearchUsageChecker(TokenCredential credential, string subscriptionId, ILogger logger) : AzureUsageChecker(credential, subscriptionId, logger)
+public class SearchUsageChecker(TokenCredential credential, string subscriptionId, ILogger logger, ITenantService tenantService) : AzureUsageChecker(credential, subscriptionId, logger, tenantService)
{
public override async Task> GetUsageForLocationAsync(string location, CancellationToken cancellationToken)
{
diff --git a/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/StorageUsageChecker.cs b/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/StorageUsageChecker.cs
index 86f44bad01..0e356988a5 100644
--- a/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/StorageUsageChecker.cs
+++ b/tools/Azure.Mcp.Tools.Quota/src/Services/Util/Usage/StorageUsageChecker.cs
@@ -2,12 +2,13 @@
// Licensed under the MIT License.
using Azure.Core;
+using Azure.Mcp.Core.Services.Azure.Tenant;
using Azure.ResourceManager.Storage;
using Microsoft.Extensions.Logging;
namespace Azure.Mcp.Tools.Quota.Services.Util.Usage;
-public class StorageUsageChecker(TokenCredential credential, string subscriptionId, ILogger logger) : AzureUsageChecker(credential, subscriptionId, logger)
+public class StorageUsageChecker(TokenCredential credential, string subscriptionId, ILogger logger, ITenantService tenantService) : AzureUsageChecker(credential, subscriptionId, logger, tenantService)
{
public override async Task> GetUsageForLocationAsync(string location, CancellationToken cancellationToken)
{
diff --git a/tools/Azure.Mcp.Tools.Quota/tests/Azure.Mcp.Tools.Quota.LiveTests/assets.json b/tools/Azure.Mcp.Tools.Quota/tests/Azure.Mcp.Tools.Quota.LiveTests/assets.json
index 1bb30ae96f..a2554d5cfb 100644
--- a/tools/Azure.Mcp.Tools.Quota/tests/Azure.Mcp.Tools.Quota.LiveTests/assets.json
+++ b/tools/Azure.Mcp.Tools.Quota/tests/Azure.Mcp.Tools.Quota.LiveTests/assets.json
@@ -2,5 +2,5 @@
"AssetsRepo": "Azure/azure-sdk-assets",
"AssetsRepoPrefixPath": "",
"TagPrefix": "Azure.Mcp.Tools.Quota.LiveTests",
- "Tag": "Azure.Mcp.Tools.Quota.LiveTests_1a91ae82bf"
+ "Tag": "Azure.Mcp.Tools.Quota.LiveTests_663924bd7c"
}
diff --git a/tools/Azure.Mcp.Tools.ResourceHealth/src/Services/ResourceHealthService.cs b/tools/Azure.Mcp.Tools.ResourceHealth/src/Services/ResourceHealthService.cs
index 9c87ff1f75..16f6cb7aa8 100644
--- a/tools/Azure.Mcp.Tools.ResourceHealth/src/Services/ResourceHealthService.cs
+++ b/tools/Azure.Mcp.Tools.ResourceHealth/src/Services/ResourceHealthService.cs
@@ -16,9 +16,9 @@ public class ResourceHealthService(ISubscriptionService subscriptionService, ITe
: BaseAzureService(tenantService), IResourceHealthService
{
private readonly ISubscriptionService _subscriptionService = subscriptionService ?? throw new ArgumentNullException(nameof(subscriptionService));
+ private readonly ITenantService _tenantService = tenantService ?? throw new ArgumentNullException(nameof(tenantService));
private readonly IHttpClientFactory _httpClientFactory = httpClientFactory ?? throw new ArgumentNullException(nameof(httpClientFactory));
- private const string AzureManagementBaseUrl = "https://management.azure.com";
private const string ResourceHealthApiVersion = "2025-05-01";
public async Task GetAvailabilityStatusAsync(
@@ -33,18 +33,19 @@ public async Task GetAvailabilityStatusAsync(
try
{
+ var managementEndpoint = _tenantService.CloudConfiguration.ArmEnvironment.Endpoint ?? throw new InvalidOperationException("Management endpoint is not configured.");
+
var credential = await GetCredential(cancellationToken);
var token = await credential.GetTokenAsync(
- new TokenRequestContext([$"{AzureManagementBaseUrl}/.default"]),
+ new TokenRequestContext([_tenantService.CloudConfiguration.ArmEnvironment.DefaultScope]),
cancellationToken);
var client = _httpClientFactory.CreateClient();
client.DefaultRequestHeaders.Authorization = new System.Net.Http.Headers.AuthenticationHeaderValue("Bearer", token.Token);
// Construct URL safely using Uri to ensure path is relative to base
- var baseUri = new Uri(AzureManagementBaseUrl);
var relativePath = $"{parsedResourceId}/providers/Microsoft.ResourceHealth/availabilityStatuses/current?api-version={ResourceHealthApiVersion}";
- var requestUri = new Uri(baseUri, relativePath);
+ var requestUri = new Uri(managementEndpoint, relativePath);
using var response = await client.GetAsync(requestUri, cancellationToken);
response.EnsureSuccessStatusCode();
@@ -83,20 +84,20 @@ public async Task> ListAvailabilityStatusesAsync(
var subscriptionResource = await _subscriptionService.GetSubscription(subscription, tenant, retryPolicy, cancellationToken);
var subscriptionId = subscriptionResource.Id.SubscriptionId;
+ var managementEndpoint = _tenantService.CloudConfiguration.ArmEnvironment.Endpoint;
var credential = await GetCredential(cancellationToken);
var token = await credential.GetTokenAsync(
- new TokenRequestContext([$"{AzureManagementBaseUrl}/.default"]),
+ new TokenRequestContext([_tenantService.CloudConfiguration.ArmEnvironment.DefaultScope]),
cancellationToken);
var client = _httpClientFactory.CreateClient();
client.DefaultRequestHeaders.Authorization = new System.Net.Http.Headers.AuthenticationHeaderValue("Bearer", token.Token);
// Construct URL safely using Uri to ensure path is relative to base
- var baseUri = new Uri(AzureManagementBaseUrl);
var relativePath = resourceGroup != null
? $"/subscriptions/{subscriptionId}/resourceGroups/{resourceGroup}/providers/Microsoft.ResourceHealth/availabilityStatuses?api-version={ResourceHealthApiVersion}"
: $"/subscriptions/{subscriptionId}/providers/Microsoft.ResourceHealth/availabilityStatuses?api-version={ResourceHealthApiVersion}";
- var requestUri = new Uri(baseUri, relativePath);
+ var requestUri = new Uri(managementEndpoint, relativePath);
using var response = await client.GetAsync(requestUri, cancellationToken);
response.EnsureSuccessStatusCode();
@@ -140,9 +141,11 @@ public async Task> ListServiceHealthEventsAsync(
var subscriptionResource = await _subscriptionService.GetSubscription(subscription, tenant, retryPolicy, cancellationToken);
var subscriptionId = subscriptionResource.Id.SubscriptionId;
+ var managementEndpoint = _tenantService.CloudConfiguration.ArmEnvironment.Endpoint;
+
var credential = await GetCredential(cancellationToken);
var token = await credential.GetTokenAsync(
- new TokenRequestContext([$"{AzureManagementBaseUrl}/.default"]),
+ new TokenRequestContext([_tenantService.CloudConfiguration.ArmEnvironment.DefaultScope]),
cancellationToken);
var client = _httpClientFactory.CreateClient();
@@ -196,8 +199,7 @@ public async Task> ListServiceHealthEventsAsync(
}
// Construct URL safely using Uri to ensure path is relative to base
- var baseUri = new Uri(AzureManagementBaseUrl);
- var requestUri = new Uri(baseUri, relativePath);
+ var requestUri = new Uri(managementEndpoint, relativePath);
using var response = await client.GetAsync(requestUri, cancellationToken);
response.EnsureSuccessStatusCode();
diff --git a/tools/Azure.Mcp.Tools.ResourceHealth/tests/Azure.Mcp.Tools.ResourceHealth.UnitTests/Services/ResourceHealthServiceSsrfValidationTests.cs b/tools/Azure.Mcp.Tools.ResourceHealth/tests/Azure.Mcp.Tools.ResourceHealth.UnitTests/Services/ResourceHealthServiceSsrfValidationTests.cs
index 56bcaac47d..cde41d0bce 100644
--- a/tools/Azure.Mcp.Tools.ResourceHealth/tests/Azure.Mcp.Tools.ResourceHealth.UnitTests/Services/ResourceHealthServiceSsrfValidationTests.cs
+++ b/tools/Azure.Mcp.Tools.ResourceHealth/tests/Azure.Mcp.Tools.ResourceHealth.UnitTests/Services/ResourceHealthServiceSsrfValidationTests.cs
@@ -3,9 +3,11 @@
using System.Net;
using Azure.Core;
+using Azure.Mcp.Core.Services.Azure.Authentication;
using Azure.Mcp.Core.Services.Azure.Subscription;
using Azure.Mcp.Core.Services.Azure.Tenant;
using Azure.Mcp.Tools.ResourceHealth.Services;
+using Azure.ResourceManager;
using NSubstitute;
using Xunit;
@@ -33,6 +35,12 @@ public ResourceHealthServiceSsrfValidationTests()
private void SetupMocksForValidRequest(HttpResponseMessage response)
{
+ // Mock CloudConfiguration to return a valid ArmEnvironment
+ var cloudConfig = Substitute.For();
+ cloudConfig.ArmEnvironment.Returns(ArmEnvironment.AzurePublicCloud);
+ cloudConfig.AuthorityHost.Returns(new Uri("https://login.microsoftonline.com"));
+ _tenantService.CloudConfiguration.Returns(cloudConfig);
+
// Mock TokenCredential
var mockCredential = Substitute.For();
mockCredential.GetTokenAsync(Arg.Any(), Arg.Any())
diff --git a/tools/Azure.Mcp.Tools.Search/src/Services/SearchService.cs b/tools/Azure.Mcp.Tools.Search/src/Services/SearchService.cs
index 01ac165f6e..b0526edbd9 100644
--- a/tools/Azure.Mcp.Tools.Search/src/Services/SearchService.cs
+++ b/tools/Azure.Mcp.Tools.Search/src/Services/SearchService.cs
@@ -5,6 +5,7 @@
using Azure.Core.Pipeline;
using Azure.Mcp.Core.Options;
using Azure.Mcp.Core.Services.Azure;
+using Azure.Mcp.Core.Services.Azure.Authentication;
using Azure.Mcp.Core.Services.Azure.Subscription;
using Azure.Mcp.Core.Services.Azure.Tenant;
using Azure.Mcp.Core.Services.Caching;
@@ -26,6 +27,7 @@ public sealed class SearchService(
ITenantService tenantService)
: BaseAzureService(tenantService), ISearchService
{
+ private readonly ITenantService _tenantService = tenantService ?? throw new ArgumentNullException(nameof(tenantService));
private readonly ISubscriptionService _subscriptionService = subscriptionService ?? throw new ArgumentNullException(nameof(subscriptionService));
private readonly ICacheService _cacheService = cacheService ?? throw new ArgumentNullException(nameof(cacheService));
private const string CacheGroup = "search";
@@ -369,7 +371,7 @@ private async Task GetSearchIndexClient(string serviceName, R
clientOptions.Transport = new HttpClientTransport(TenantService.GetClient());
ConfigureRetryPolicy(clientOptions, retryPolicy);
- var endpoint = new Uri($"https://{serviceName}.search.windows.net");
+ var endpoint = GetSearchEndpoint(serviceName);
searchClient = new SearchIndexClient(endpoint, credential, clientOptions);
await _cacheService.SetAsync(CacheGroup, key, searchClient, s_cacheDurationClients, cancellationToken);
}
@@ -421,4 +423,19 @@ private static IndexInfo MapToIndexInfo(SearchIndex index)
private static FieldInfo MapToFieldInfo(SearchField field)
=> new(field.Name, field.Type.ToString(), field.IsKey, field.IsSearchable, field.IsFilterable, field.IsSortable,
field.IsFacetable, field.IsHidden != true);
+
+ private Uri GetSearchEndpoint(string serviceName)
+ {
+ switch (_tenantService.CloudConfiguration.CloudType)
+ {
+ case AzureCloudConfiguration.AzureCloud.AzurePublicCloud:
+ return new Uri($"https://{serviceName}.search.windows.net");
+ case AzureCloudConfiguration.AzureCloud.AzureChinaCloud:
+ return new Uri($"https://{serviceName}.search.azure.cn");
+ case AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud:
+ return new Uri($"https://{serviceName}.search.azure.us");
+ default:
+ return new Uri($"https://{serviceName}.search.windows.net");
+ }
+ }
}
diff --git a/tools/Azure.Mcp.Tools.ServiceFabric/src/Services/ServiceFabricService.cs b/tools/Azure.Mcp.Tools.ServiceFabric/src/Services/ServiceFabricService.cs
index 69448b463b..8bacfddb41 100644
--- a/tools/Azure.Mcp.Tools.ServiceFabric/src/Services/ServiceFabricService.cs
+++ b/tools/Azure.Mcp.Tools.ServiceFabric/src/Services/ServiceFabricService.cs
@@ -20,11 +20,14 @@ public sealed class ServiceFabricService(
IHttpClientFactory httpClientFactory) : BaseAzureService(tenantService), IServiceFabricService
{
private readonly ISubscriptionService _subscriptionService = subscriptionService ?? throw new ArgumentNullException(nameof(subscriptionService));
+ private readonly ITenantService _tenantService = tenantService ?? throw new ArgumentNullException(nameof(tenantService));
private readonly IHttpClientFactory _httpClientFactory = httpClientFactory ?? throw new ArgumentNullException(nameof(httpClientFactory));
- private const string AzureManagementBaseUrl = "https://management.azure.com";
private const string ApiVersion = "2024-04-01";
+ private string GetManagementBaseUrl() =>
+ _tenantService.CloudConfiguration.ArmEnvironment.Endpoint.ToString().TrimEnd('/');
+
public async Task> ListManagedClusterNodes(
string subscription,
string resourceGroup,
@@ -43,13 +46,13 @@ public async Task> ListManagedClusterNodes(
var credential = await GetCredential(tenant, cancellationToken);
var token = await credential.GetTokenAsync(
- new TokenRequestContext([$"{AzureManagementBaseUrl}/.default"]),
+ new TokenRequestContext([_tenantService.CloudConfiguration.ArmEnvironment.DefaultScope]),
cancellationToken);
var client = _httpClientFactory.CreateClient();
client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", token.Token);
- var requestUrl = $"{AzureManagementBaseUrl}/subscriptions/{subscriptionId}/resourceGroups/{Uri.EscapeDataString(resourceGroup)}/providers/Microsoft.ServiceFabric/managedClusters/{Uri.EscapeDataString(clusterName)}/nodes?api-version={ApiVersion}";
+ var requestUrl = $"{GetManagementBaseUrl()}/subscriptions/{subscriptionId}/resourceGroups/{Uri.EscapeDataString(resourceGroup)}/providers/Microsoft.ServiceFabric/managedClusters/{Uri.EscapeDataString(clusterName)}/nodes?api-version={ApiVersion}";
var allNodes = new List();
@@ -93,13 +96,13 @@ public async Task GetManagedClusterNode(
var credential = await GetCredential(tenant, cancellationToken);
var token = await credential.GetTokenAsync(
- new TokenRequestContext([$"{AzureManagementBaseUrl}/.default"]),
+ new TokenRequestContext([_tenantService.CloudConfiguration.ArmEnvironment.DefaultScope]),
cancellationToken);
var client = _httpClientFactory.CreateClient();
client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", token.Token);
- var requestUrl = $"{AzureManagementBaseUrl}/subscriptions/{subscriptionId}/resourceGroups/{Uri.EscapeDataString(resourceGroup)}/providers/Microsoft.ServiceFabric/managedClusters/{Uri.EscapeDataString(clusterName)}/nodes/{Uri.EscapeDataString(nodeName)}?api-version={ApiVersion}";
+ var requestUrl = $"{GetManagementBaseUrl()}/subscriptions/{subscriptionId}/resourceGroups/{Uri.EscapeDataString(resourceGroup)}/providers/Microsoft.ServiceFabric/managedClusters/{Uri.EscapeDataString(clusterName)}/nodes/{Uri.EscapeDataString(nodeName)}?api-version={ApiVersion}";
using var response = await client.GetAsync(requestUrl, cancellationToken);
response.EnsureSuccessStatusCode();
@@ -137,13 +140,13 @@ public async Task RestartManagedClusterNodes(
var credential = await GetCredential(tenant, cancellationToken);
var token = await credential.GetTokenAsync(
- new TokenRequestContext([$"{AzureManagementBaseUrl}/.default"]),
+ new TokenRequestContext([_tenantService.CloudConfiguration.ArmEnvironment.DefaultScope]),
cancellationToken);
var client = _httpClientFactory.CreateClient();
client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", token.Token);
- var requestUrl = $"{AzureManagementBaseUrl}/subscriptions/{subscriptionId}/resourceGroups/{Uri.EscapeDataString(resourceGroup)}/providers/Microsoft.ServiceFabric/managedClusters/{Uri.EscapeDataString(clusterName)}/nodeTypes/{Uri.EscapeDataString(nodeType)}/restart?api-version={ApiVersion}";
+ var requestUrl = $"{GetManagementBaseUrl()}/subscriptions/{subscriptionId}/resourceGroups/{Uri.EscapeDataString(resourceGroup)}/providers/Microsoft.ServiceFabric/managedClusters/{Uri.EscapeDataString(clusterName)}/nodeTypes/{Uri.EscapeDataString(nodeType)}/restart?api-version={ApiVersion}";
var requestBody = new RestartNodeRequest
{
diff --git a/tools/Azure.Mcp.Tools.Speech/src/Services/Recognizers/FastTranscriptionRecognizer.cs b/tools/Azure.Mcp.Tools.Speech/src/Services/Recognizers/FastTranscriptionRecognizer.cs
index 6db63e7b1f..26f4920afc 100644
--- a/tools/Azure.Mcp.Tools.Speech/src/Services/Recognizers/FastTranscriptionRecognizer.cs
+++ b/tools/Azure.Mcp.Tools.Speech/src/Services/Recognizers/FastTranscriptionRecognizer.cs
@@ -7,6 +7,7 @@
using Azure.Core;
using Azure.Mcp.Core.Options;
using Azure.Mcp.Core.Services.Azure;
+using Azure.Mcp.Core.Services.Azure.Authentication;
using Azure.Mcp.Core.Services.Azure.Tenant;
using Azure.Mcp.Tools.Speech.Models.FastTranscription;
using Microsoft.Extensions.Logging;
@@ -23,6 +24,7 @@ public class FastTranscriptionRecognizer(
: BaseAzureService(tenantService), IFastTranscriptionRecognizer
{
private readonly ILogger _logger = logger;
+ private readonly ITenantService _tenantService = tenantService;
private readonly IHttpClientFactory _httpClientFactory = httpClientFactory;
///
@@ -65,7 +67,7 @@ public async Task RecognizeAsync(
var credential = await GetCredential(cancellationToken);
// Get access token for Cognitive Services with proper scope
- var tokenRequestContext = new TokenRequestContext(["https://cognitiveservices.azure.com/.default"]);
+ var tokenRequestContext = new TokenRequestContext([GetCognitiveServicesScope()]);
var accessToken = await credential.GetTokenAsync(tokenRequestContext, cancellationToken);
// Build the Fast Transcription API URL
@@ -245,4 +247,19 @@ private static string GetMimeType(string filePath)
_ => "application/octet-stream"
};
}
+
+ private string GetCognitiveServicesScope()
+ {
+ switch (_tenantService.CloudConfiguration.CloudType)
+ {
+ case AzureCloudConfiguration.AzureCloud.AzurePublicCloud:
+ return "https://cognitiveservices.azure.com/.default";
+ case AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud:
+ return "https://cognitiveservices.azure.us/.default";
+ case AzureCloudConfiguration.AzureCloud.AzureChinaCloud:
+ return "https://cognitiveservices.azure.cn/.default";
+ default:
+ return "https://cognitiveservices.azure.com/.default";
+ }
+ }
}
diff --git a/tools/Azure.Mcp.Tools.Speech/src/Services/Recognizers/RealtimeTranscriptionRecognizer.cs b/tools/Azure.Mcp.Tools.Speech/src/Services/Recognizers/RealtimeTranscriptionRecognizer.cs
index 8cc6927ff7..0f609fa988 100644
--- a/tools/Azure.Mcp.Tools.Speech/src/Services/Recognizers/RealtimeTranscriptionRecognizer.cs
+++ b/tools/Azure.Mcp.Tools.Speech/src/Services/Recognizers/RealtimeTranscriptionRecognizer.cs
@@ -4,6 +4,7 @@
using Azure.Core;
using Azure.Mcp.Core.Options;
using Azure.Mcp.Core.Services.Azure;
+using Azure.Mcp.Core.Services.Azure.Authentication;
using Azure.Mcp.Core.Services.Azure.Tenant;
using Azure.Mcp.Tools.Speech.Models.Realtime;
using Microsoft.CognitiveServices.Speech;
@@ -19,6 +20,7 @@ namespace Azure.Mcp.Tools.Speech.Services.Recognizers;
public class RealtimeTranscriptionRecognizer(ITenantService tenantService, ILogger logger)
: BaseAzureService(tenantService), IRealtimeTranscriptionRecognizer
{
+ private readonly ITenantService _tenantService = tenantService;
private readonly ILogger _logger = logger;
///
@@ -61,7 +63,7 @@ public async Task RecognizeAsync(
var credential = await GetCredential(cancellationToken);
// Get access token for Cognitive Services with proper scope
- var tokenRequestContext = new TokenRequestContext(["https://cognitiveservices.azure.com/.default"]);
+ var tokenRequestContext = new TokenRequestContext([GetCognitiveServicesScope()]);
var accessToken = await credential.GetTokenAsync(tokenRequestContext, cancellationToken);
// Configure Speech SDK with endpoint
@@ -496,4 +498,19 @@ private static List ExtractNBestResults(SdkSpeec
return nbestResults;
}
+
+ private string GetCognitiveServicesScope()
+ {
+ switch (_tenantService.CloudConfiguration.CloudType)
+ {
+ case AzureCloudConfiguration.AzureCloud.AzurePublicCloud:
+ return "https://cognitiveservices.azure.com/.default";
+ case AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud:
+ return "https://cognitiveservices.azure.us/.default";
+ case AzureCloudConfiguration.AzureCloud.AzureChinaCloud:
+ return "https://cognitiveservices.azure.cn/.default";
+ default:
+ return "https://cognitiveservices.azure.com/.default";
+ }
+ }
}
diff --git a/tools/Azure.Mcp.Tools.Speech/src/Services/Synthesizers/RealtimeTtsSynthesizer.cs b/tools/Azure.Mcp.Tools.Speech/src/Services/Synthesizers/RealtimeTtsSynthesizer.cs
index 07def0f58f..eacbf1dc56 100644
--- a/tools/Azure.Mcp.Tools.Speech/src/Services/Synthesizers/RealtimeTtsSynthesizer.cs
+++ b/tools/Azure.Mcp.Tools.Speech/src/Services/Synthesizers/RealtimeTtsSynthesizer.cs
@@ -4,6 +4,7 @@
using Azure.Core;
using Azure.Mcp.Core.Options;
using Azure.Mcp.Core.Services.Azure;
+using Azure.Mcp.Core.Services.Azure.Authentication;
using Azure.Mcp.Core.Services.Azure.Tenant;
using Azure.Mcp.Tools.Speech.Models;
using Microsoft.CognitiveServices.Speech;
@@ -19,6 +20,7 @@ namespace Azure.Mcp.Tools.Speech.Services.Synthesizers;
public class RealtimeTtsSynthesizer(ITenantService tenantService, ILogger logger)
: BaseAzureService(tenantService), IRealtimeTtsSynthesizer
{
+ private readonly ITenantService _tenantService = tenantService;
private readonly ILogger _logger = logger;
///
@@ -102,7 +104,7 @@ public async Task SynthesizeToFileAsync(
var credential = await GetCredential(cancellationToken);
// Get access token for Cognitive Services with proper scope
- var tokenRequestContext = new TokenRequestContext(["https://cognitiveservices.azure.com/.default"]);
+ var tokenRequestContext = new TokenRequestContext([GetCognitiveServicesScope()]);
var accessToken = await credential.GetTokenAsync(tokenRequestContext, cancellationToken);
// Convert https endpoint to wss for WebSocket-based TTS
@@ -275,4 +277,19 @@ private static SpeechSynthesisOutputFormat ParseOutputFormat(string? format)
// If parsing fails, default to Riff24Khz16BitMonoPcm
return SpeechSynthesisOutputFormat.Riff24Khz16BitMonoPcm;
}
+
+ private string GetCognitiveServicesScope()
+ {
+ switch (_tenantService.CloudConfiguration.CloudType)
+ {
+ case AzureCloudConfiguration.AzureCloud.AzurePublicCloud:
+ return "https://cognitiveservices.azure.com/.default";
+ case AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud:
+ return "https://cognitiveservices.azure.us/.default";
+ case AzureCloudConfiguration.AzureCloud.AzureChinaCloud:
+ return "https://cognitiveservices.azure.cn/.default";
+ default:
+ return "https://cognitiveservices.azure.com/.default";
+ }
+ }
}
diff --git a/tools/Azure.Mcp.Tools.Storage/src/Services/StorageService.cs b/tools/Azure.Mcp.Tools.Storage/src/Services/StorageService.cs
index 55e00b4045..4cb48956e7 100644
--- a/tools/Azure.Mcp.Tools.Storage/src/Services/StorageService.cs
+++ b/tools/Azure.Mcp.Tools.Storage/src/Services/StorageService.cs
@@ -7,6 +7,7 @@
using Azure.Data.Tables;
using Azure.Mcp.Core.Options;
using Azure.Mcp.Core.Services.Azure;
+using Azure.Mcp.Core.Services.Azure.Authentication;
using Azure.Mcp.Core.Services.Azure.Models;
using Azure.Mcp.Core.Services.Azure.Subscription;
using Azure.Mcp.Core.Services.Azure.Tenant;
@@ -26,6 +27,7 @@ public class StorageService(
ILogger logger)
: BaseAzureResourceService(subscriptionService, tenantService), IStorageService
{
+ private readonly ITenantService _tenantService = tenantService ?? throw new ArgumentNullException(nameof(tenantService));
private readonly ILogger _logger = logger ?? throw new ArgumentNullException(nameof(logger));
public async Task> GetAccountDetails(
@@ -380,7 +382,7 @@ private async Task CreateBlobServiceClient(
RetryPolicyOptions? retryPolicy = null,
CancellationToken cancellationToken = default)
{
- var uri = $"https://{account}.blob.core.windows.net";
+ var uri = GetBlobEndpoint(account);
var options = ConfigureRetryPolicy(AddDefaultPolicies(new BlobClientOptions()), retryPolicy);
options.Transport = new HttpClientTransport(TenantService.GetClient());
return new BlobServiceClient(new Uri(uri), await GetCredential(tenant, cancellationToken), options);
@@ -488,7 +490,7 @@ protected async Task CreateTableServiceClient(
{
var options = ConfigureRetryPolicy(AddDefaultPolicies(new TableClientOptions()), retryPolicy);
options.Transport = new HttpClientTransport(TenantService.GetClient());
- var defaultUri = $"https://{account}.table.core.windows.net";
+ var defaultUri = GetTableEndpoint(account);
return new TableServiceClient(new Uri(defaultUri), await GetCredential(tenant, cancellationToken), options);
}
@@ -524,4 +526,34 @@ public async Task> ListTables(
throw new Exception($"Error listing tables: {ex.Message}", ex);
}
}
+
+ private string GetBlobEndpoint(string account)
+ {
+ switch (_tenantService.CloudConfiguration.CloudType)
+ {
+ case AzureCloudConfiguration.AzureCloud.AzurePublicCloud:
+ return $"https://{account}.blob.core.windows.net";
+ case AzureCloudConfiguration.AzureCloud.AzureChinaCloud:
+ return $"https://{account}.blob.core.chinacloudapi.cn";
+ case AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud:
+ return $"https://{account}.blob.core.usgovcloudapi.net";
+ default:
+ return $"https://{account}.blob.core.windows.net";
+ }
+ }
+
+ private string GetTableEndpoint(string? account)
+ {
+ switch (_tenantService.CloudConfiguration.CloudType)
+ {
+ case AzureCloudConfiguration.AzureCloud.AzurePublicCloud:
+ return $"https://{account}.table.core.windows.net";
+ case AzureCloudConfiguration.AzureCloud.AzureChinaCloud:
+ return $"https://{account}.table.core.chinacloudapi.cn";
+ case AzureCloudConfiguration.AzureCloud.AzureUSGovernmentCloud:
+ return $"https://{account}.table.core.usgovcloudapi.net";
+ default:
+ return $"https://{account}.table.core.windows.net";
+ }
+ }
}