diff --git a/packages/http-client-python/emitter/src/types.ts b/packages/http-client-python/emitter/src/types.ts index 094c4129f26..b1630740bb3 100644 --- a/packages/http-client-python/emitter/src/types.ts +++ b/packages/http-client-python/emitter/src/types.ts @@ -92,7 +92,7 @@ export function getType( case "enumvalue": return emitEnumMember(type, emitEnum(context, type.enumType)); case "credential": - return emitCredential(type); + return emitCredential(context, type); case "bytes": case "boolean": case "plainDate": @@ -143,7 +143,10 @@ function emitMultiPartFile( }); } -function emitCredential(credential: SdkCredentialType): Record { +function emitCredential( + context: PythonSdkContext, + credential: SdkCredentialType, +): Record { let credential_type: Record = {}; const scheme = credential.scheme; if (scheme.type === "oauth2") { @@ -152,6 +155,7 @@ function emitCredential(credential: SdkCredentialType): Record { policy: { type: "BearerTokenCredentialPolicy", credentialScopes: [], + flows: (context.emitContext.options as any).flavor === "azure" ? [] : scheme.flows, }, }; for (const flow of scheme.flows) { diff --git a/packages/http-client-python/generator/pygen/codegen/models/credential_types.py b/packages/http-client-python/generator/pygen/codegen/models/credential_types.py index 3830e3deb5b..8687b83b80a 100644 --- a/packages/http-client-python/generator/pygen/codegen/models/credential_types.py +++ b/packages/http-client-python/generator/pygen/codegen/models/credential_types.py @@ -48,17 +48,20 @@ def __init__( yaml_data: Dict[str, Any], code_model: "CodeModel", credential_scopes: List[str], + flows: Optional[Dict[str, Any]] = None, ) -> None: super().__init__(yaml_data, code_model) self.credential_scopes = credential_scopes + self.flows = flows def call(self, async_mode: bool) -> str: policy_name = f"{'Async' if async_mode else ''}BearerTokenCredentialPolicy" - return f"policies.{policy_name}(self.credential, *self.credential_scopes, **kwargs)" + auth_flows = f"auth_flows={self.flows}, " if self.flows else "" + return f"policies.{policy_name}(self.credential, *self.credential_scopes, {auth_flows}**kwargs)" @classmethod def from_yaml(cls, yaml_data: Dict[str, Any], code_model: "CodeModel") -> "BearerTokenCredentialPolicyType": - return cls(yaml_data, code_model, yaml_data["credentialScopes"]) + return cls(yaml_data, code_model, yaml_data["credentialScopes"], yaml_data.get("flows")) class ARMChallengeAuthenticationPolicyType(BearerTokenCredentialPolicyType): diff --git a/packages/http-client-python/generator/test/generic_mock_api_tests/asynctests/test_authentication_async.py b/packages/http-client-python/generator/test/generic_mock_api_tests/asynctests/test_authentication_async.py index 122df638356..bb92d42bc4a 100644 --- a/packages/http-client-python/generator/test/generic_mock_api_tests/asynctests/test_authentication_async.py +++ b/packages/http-client-python/generator/test/generic_mock_api_tests/asynctests/test_authentication_async.py @@ -33,6 +33,10 @@ class FakeCredential: async def get_token(*scopes): return core_library.credentials.AccessToken(token="".join(scopes), expires_on=1800) + @staticmethod + async def get_token_info(*scopes, **kwargs): + return core_library.credentials.AccessTokenInfo(token="".join(scopes), expires_on=1800) + return FakeCredential() diff --git a/packages/http-client-python/generator/test/generic_mock_api_tests/test_authentication.py b/packages/http-client-python/generator/test/generic_mock_api_tests/test_authentication.py index 13351ac74e6..5c1dc39a913 100644 --- a/packages/http-client-python/generator/test/generic_mock_api_tests/test_authentication.py +++ b/packages/http-client-python/generator/test/generic_mock_api_tests/test_authentication.py @@ -33,6 +33,10 @@ class FakeCredential: def get_token(*scopes): return core_library.credentials.AccessToken(token="".join(scopes), expires_on=1800) + @staticmethod + def get_token_info(*scopes, **kwargs): + return core_library.credentials.AccessTokenInfo(token="".join(scopes), expires_on=1800) + return FakeCredential() diff --git a/packages/http-client-python/generator/test/unbranded/mock_api_tests/asynctests/test_auth_flow_async.py b/packages/http-client-python/generator/test/unbranded/mock_api_tests/asynctests/test_auth_flow_async.py new file mode 100644 index 00000000000..5ad8263c4b6 --- /dev/null +++ b/packages/http-client-python/generator/test/unbranded/mock_api_tests/asynctests/test_auth_flow_async.py @@ -0,0 +1,19 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import pytest +from authentication.oauth2.aio import OAuth2Client + + +@pytest.mark.asyncio +async def test_oauth2_auth_flows(): + oauth2_client = OAuth2Client("fake_credential") + assert oauth2_client._config.authentication_policy._auth_flows == [ + { + "authorizationUrl": "https://login.microsoftonline.com/common/oauth2/authorize", + "scopes": [{"value": "https://security.microsoft.com/.default"}], + "type": "implicit", + } + ] diff --git a/packages/http-client-python/generator/test/unbranded/mock_api_tests/test_auth_flow.py b/packages/http-client-python/generator/test/unbranded/mock_api_tests/test_auth_flow.py new file mode 100644 index 00000000000..4aa857b020c --- /dev/null +++ b/packages/http-client-python/generator/test/unbranded/mock_api_tests/test_auth_flow.py @@ -0,0 +1,17 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +from authentication.oauth2 import OAuth2Client + + +def test_oauth2_auth_flows(): + oauth2_client = OAuth2Client("fake_credential") + assert oauth2_client._config.authentication_policy._auth_flows == [ + { + "authorizationUrl": "https://login.microsoftonline.com/common/oauth2/authorize", + "scopes": [{"value": "https://security.microsoft.com/.default"}], + "type": "implicit", + } + ]