diff --git a/.chronus/changes/multi-clouds-2025-1-10-14-54-3.md b/.chronus/changes/multi-clouds-2025-1-10-14-54-3.md new file mode 100644 index 00000000000..70be10bff6c --- /dev/null +++ b/.chronus/changes/multi-clouds-2025-1-10-14-54-3.md @@ -0,0 +1,7 @@ +--- +changeKind: feature +packages: + - "@typespec/http-client-python" +--- + +Improve user experience in multi clouds scenario \ No newline at end of file diff --git a/packages/http-client-python/generator/pygen/codegen/models/client.py b/packages/http-client-python/generator/pygen/codegen/models/client.py index 848816643f5..20df8466858 100644 --- a/packages/http-client-python/generator/pygen/codegen/models/client.py +++ b/packages/http-client-python/generator/pygen/codegen/models/client.py @@ -15,7 +15,7 @@ OverloadedRequestBuilder, get_request_builder, ) -from .parameter import Parameter, ParameterMethodLocation +from .parameter import Parameter, ParameterMethodLocation, ParameterLocation from .lro_operation import LROOperation from .lro_paging_operation import LROPagingOperation from ...utils import extract_original_name, NAME_LENGTH_LIMIT @@ -54,7 +54,7 @@ def name(self) -> str: return self.yaml_data["name"] -class Client(_ClientConfigBase[ClientGlobalParameterList]): +class Client(_ClientConfigBase[ClientGlobalParameterList]): # pylint: disable=too-many-public-methods """Model representing our service client""" def __init__( @@ -79,6 +79,27 @@ def __init__( self.request_id_header_name = self.yaml_data.get("requestIdHeaderName", None) self.has_etag: bool = yaml_data.get("hasEtag", False) + # update the host parameter value. In later logic, SDK will overwrite it + # with value from cloud_setting if users don't provide it. + if self.need_cloud_setting: + for p in self.parameters.parameters: + if p.location == ParameterLocation.ENDPOINT_PATH: + p.client_default_value = None + p.optional = True + break + + @property + def need_cloud_setting(self) -> bool: + return bool( + self.code_model.options.get("azure_arm", False) + and self.credential_scopes is not None + and self.endpoint_parameter is not None + ) + + @property + def endpoint_parameter(self) -> Optional[Parameter]: + return next((p for p in self.parameters.parameters if p.location == ParameterLocation.ENDPOINT_PATH), None) + def _build_request_builders( self, ) -> List[Union[RequestBuilder, OverloadedRequestBuilder]]: @@ -233,6 +254,10 @@ def _imports_shared(self, async_mode: bool, **kwargs) -> FileImport: "Self", ImportType.STDLIB, ) + if self.need_cloud_setting: + file_import.add_submodule_import("typing", "cast", ImportType.STDLIB) + file_import.add_submodule_import("azure.core.settings", "settings", ImportType.SDKCORE) + file_import.add_submodule_import("azure.mgmt.core.tools", "get_arm_endpoints", ImportType.SDKCORE) return file_import @property @@ -332,6 +357,18 @@ def imports_for_multiapi(self, async_mode: bool, **kwargs) -> FileImport: ) return file_import + @property + def credential_scopes(self) -> Optional[List[str]]: + """Credential scopes for this client""" + + if self.credential: + if hasattr(getattr(self.credential.type, "policy", None), "credential_scopes"): + return self.credential.type.policy.credential_scopes # type: ignore + for t in getattr(self.credential.type, "types", []): + if hasattr(getattr(t, "policy", None), "credential_scopes"): + return t.policy.credential_scopes + return None + @classmethod def from_yaml( cls, diff --git a/packages/http-client-python/generator/pygen/codegen/serializers/client_serializer.py b/packages/http-client-python/generator/pygen/codegen/serializers/client_serializer.py index 6ac6de68bb5..381d758d2e9 100644 --- a/packages/http-client-python/generator/pygen/codegen/serializers/client_serializer.py +++ b/packages/http-client-python/generator/pygen/codegen/serializers/client_serializer.py @@ -3,10 +3,10 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- -from typing import List +from typing import List, cast from . import utils -from ..models import Client, ParameterMethodLocation +from ..models import Client, ParameterMethodLocation, Parameter, ParameterLocation from .parameter_serializer import ParameterSerializer, PopKwargType from ...utils import build_policies @@ -77,17 +77,40 @@ def property_descriptions(self, async_mode: bool) -> List[str]: retval.append('"""') return retval - def initialize_config(self) -> str: + def initialize_config(self) -> List[str]: + retval = [] + additional_signatures = [] + if self.client.need_cloud_setting: + additional_signatures.append("credential_scopes=credential_scopes") + endpoint_parameter = cast(Parameter, self.client.endpoint_parameter) + retval.extend( + [ + '_cloud = kwargs.pop("cloud_setting", None) or settings.current.azure_cloud # type: ignore', + "_endpoints = get_arm_endpoints(_cloud)", + f"if not {endpoint_parameter.client_name}:", + f' {endpoint_parameter.client_name} = _endpoints["resource_manager"]', + 'credential_scopes = kwargs.pop("credential_scopes", _endpoints["credential_scopes"])', + ] + ) config_name = f"{self.client.name}Configuration" config_call = ", ".join( [ - f"{p.client_name}={p.client_name}" + ( + f"{p.client_name}=" + + ( + f"cast(str, {p.client_name})" + if self.client.need_cloud_setting and p.location == ParameterLocation.ENDPOINT_PATH + else p.client_name + ) + ) for p in self.client.config.parameters.method if p.method_location != ParameterMethodLocation.KWARG ] + + additional_signatures + ["**kwargs"] ) - return f"self._config = {config_name}({config_call})" + retval.append(f"self._config = {config_name}({config_call})") + return retval @property def host_variable_name(self) -> str: @@ -104,8 +127,11 @@ def initialize_pipeline_client(self, async_mode: bool) -> List[str]: result = [] pipeline_client_name = self.client.pipeline_class(async_mode) endpoint_name = "base_url" if self.client.code_model.is_azure_flavor else "endpoint" + host_variable_name = ( + f"cast(str, {self.host_variable_name})" if self.client.need_cloud_setting else self.host_variable_name + ) params = { - endpoint_name: self.host_variable_name, + endpoint_name: host_variable_name, "policies": "_policies", } if not self.client.code_model.is_legacy and self.client.request_id_header_name: diff --git a/packages/http-client-python/generator/pygen/codegen/serializers/general_serializer.py b/packages/http-client-python/generator/pygen/codegen/serializers/general_serializer.py index 39671556c40..fa5f81a39e7 100644 --- a/packages/http-client-python/generator/pygen/codegen/serializers/general_serializer.py +++ b/packages/http-client-python/generator/pygen/codegen/serializers/general_serializer.py @@ -19,7 +19,7 @@ VERSION_MAP = { "msrest": "0.7.1", "isodate": "0.6.1", - "azure-mgmt-core": "1.3.2", + "azure-mgmt-core": "1.5.0", "azure-core": "1.30.0", "typing-extensions": "4.6.0", "corehttp": "1.0.0b6", diff --git a/packages/http-client-python/generator/pygen/codegen/serializers/sample_serializer.py b/packages/http-client-python/generator/pygen/codegen/serializers/sample_serializer.py index cc587e51791..acd10373166 100644 --- a/packages/http-client-python/generator/pygen/codegen/serializers/sample_serializer.py +++ b/packages/http-client-python/generator/pygen/codegen/serializers/sample_serializer.py @@ -62,7 +62,7 @@ def _imports(self) -> FileImportSerializer: ImportType.SDKCORE, ) for param in self.operation.parameters.positional + self.operation.parameters.keyword_only: - if not param.client_default_value and not param.optional and param.wire_name in self.sample_params: + if param.client_default_value is None and not param.optional and param.wire_name in self.sample_params: imports.merge(param.type.imports_for_sample()) return FileImportSerializer(imports, True) @@ -80,7 +80,7 @@ def _client_params(self) -> Dict[str, Any]: for p in ( self.code_model.clients[0].parameters.positional + self.code_model.clients[0].parameters.keyword_only ) - if not (p.optional or p.client_default_value) + if not p.optional and p.client_default_value is None ] client_params = { p.client_name: special_param.get( diff --git a/packages/http-client-python/generator/pygen/codegen/templates/client.py.jinja2 b/packages/http-client-python/generator/pygen/codegen/templates/client.py.jinja2 index ea374d59b70..f640be27605 100644 --- a/packages/http-client-python/generator/pygen/codegen/templates/client.py.jinja2 +++ b/packages/http-client-python/generator/pygen/codegen/templates/client.py.jinja2 @@ -9,7 +9,7 @@ {% if client.has_parameterized_host %} {{ serializer.host_variable_name }} = {{ keywords.escape_str(client.url) }} {% endif %} - {{ serializer.initialize_config() }} + {{ op_tools.serialize(serializer.initialize_config()) | indent(8) }} {{ op_tools.serialize(serializer.initialize_pipeline_client(async_mode)) | indent(8) }} {{ op_tools.serialize(serializer.serializers_and_operation_groups_properties()) | indent(8) }} diff --git a/packages/http-client-python/generator/pygen/codegen/templates/config.py.jinja2 b/packages/http-client-python/generator/pygen/codegen/templates/config.py.jinja2 index 57e8daa3146..def1c26bc73 100644 --- a/packages/http-client-python/generator/pygen/codegen/templates/config.py.jinja2 +++ b/packages/http-client-python/generator/pygen/codegen/templates/config.py.jinja2 @@ -21,14 +21,8 @@ class {{ client.name }}Configuration: {{ client.config.pylint_disable() }} {% if serializer.set_constants() %} {{ op_tools.serialize(serializer.set_constants()) | indent(8) -}} {% endif %} -{% if client.credential %} - {% set cred_scopes = client.credential.type if client.credential.type.policy is defined and client.credential.type.policy.credential_scopes is defined %} - {% if not cred_scopes %} - {% set cred_scopes = client.credential.type.types | selectattr("policy.credential_scopes") | first if client.credential.type.types is defined %} - {% endif %} - {% if cred_scopes %} - self.credential_scopes = kwargs.pop('credential_scopes', {{ cred_scopes.policy.credential_scopes }}) - {% endif %} +{% if client.credential_scopes is not none %} + self.credential_scopes = kwargs.pop('credential_scopes', {{ client.credential_scopes }}) {% endif %} kwargs.setdefault('sdk_moniker', '{{ client.config.sdk_moniker }}/{}'.format(VERSION)) self.polling_interval = kwargs.get("polling_interval", 30) diff --git a/packages/http-client-python/generator/test/azure/requirements.txt b/packages/http-client-python/generator/test/azure/requirements.txt index 38c0d829c97..8d9bb86c7bc 100644 --- a/packages/http-client-python/generator/test/azure/requirements.txt +++ b/packages/http-client-python/generator/test/azure/requirements.txt @@ -1,7 +1,6 @@ -r ../dev_requirements.txt -e ../../ -azure-core==1.30.0 -azure-mgmt-core==1.3.2 +azure-mgmt-core==1.5.0 # only for azure -e ./generated/azure-client-generator-core-access