diff --git a/src/momento/auth/credential_provider.py b/src/momento/auth/credential_provider.py index eb16000f..ef874f45 100644 --- a/src/momento/auth/credential_provider.py +++ b/src/momento/auth/credential_provider.py @@ -4,6 +4,10 @@ import os from dataclasses import dataclass from typing import Dict, Optional +from warnings import warn + +from momento.errors.exceptions import InvalidArgumentException +from momento.internal.services import Service from . import momento_endpoint_resolver @@ -27,6 +31,8 @@ def from_environment_variable( ) -> CredentialProvider: """Reads and parses a Momento auth token stored as an environment variable. + Deprecated as of v1.28.0. Use from_environment_variables_v2 instead. + Args: env_var_name (str): Name of the environment variable from which the API key will be read control_endpoint (Optional[str], optional): Optionally overrides the default control endpoint. @@ -42,6 +48,11 @@ def from_environment_variable( Returns: CredentialProvider """ + warn( + "from_environment_variable is deprecated, use from_environment_variables_v2 instead", + DeprecationWarning, + stacklevel=2, + ) api_key = os.getenv(env_var_name) if not api_key: raise RuntimeError(f"Missing required environment variable {env_var_name}") @@ -56,6 +67,8 @@ def from_string( ) -> CredentialProvider: """Reads and parses a Momento auth token. + Deprecated as of v1.28.0. Use from_api_key_v2 or from_disposable_token instead. + Args: auth_token (str): the Momento API key (previously: auth token) control_endpoint (Optional[str], optional): Optionally overrides the default control endpoint. @@ -68,6 +81,11 @@ def from_string( Returns: CredentialProvider """ + warn( + "from_string is deprecated, use from_api_key_v2 or from_disposable_token instead", + DeprecationWarning, + stacklevel=2, + ) token_and_endpoints = momento_endpoint_resolver.resolve(auth_token) control_endpoint = control_endpoint or token_and_endpoints.control_endpoint cache_endpoint = cache_endpoint or token_and_endpoints.cache_endpoint @@ -102,3 +120,86 @@ def _obscure(self, value: str) -> str: def get_auth_token(self) -> str: return self.auth_token + + @staticmethod + def from_api_key_v2(api_key: str, endpoint: str) -> CredentialProvider: + """Creates a CredentialProvider from a v2 API key and endpoint. + + Args: + api_key (str): The v2 API key. + endpoint (str): The Momento service endpoint. + + Returns: + CredentialProvider + """ + if len(api_key) == 0: + raise InvalidArgumentException("API key cannot be empty.", Service.AUTH) + if len(endpoint) == 0: + raise InvalidArgumentException("Endpoint cannot be empty.", Service.AUTH) + + if not momento_endpoint_resolver._is_v2_api_key(api_key): + raise InvalidArgumentException( + "Received an invalid v2 API key. Are you using the correct key and the correct CredentialProvider method?", + Service.AUTH, + ) + return CredentialProvider( + auth_token=api_key, + control_endpoint=momento_endpoint_resolver._MOMENTO_CONTROL_ENDPOINT_PREFIX + endpoint, + cache_endpoint=momento_endpoint_resolver._MOMENTO_CACHE_ENDPOINT_PREFIX + endpoint, + token_endpoint=momento_endpoint_resolver._MOMENTO_TOKEN_ENDPOINT_PREFIX + endpoint, + port=443, + ) + + @staticmethod + def from_environment_variables_v2( + api_key_env_var: str = "MOMENTO_API_KEY", endpoint_env_var: str = "MOMENTO_ENDPOINT" + ) -> CredentialProvider: + """Creates a CredentialProvider from an endpoint and v2 API key stored in the environment variables MOMENTO_API_KEY and MOMENTO_ENDPOINT. + + Args: + api_key_env_var (str): Optionally provide an alternate environment variable name from which the v2 API key will be read. + endpoint_env_var (str): Optionally provide an alternate environment variable name from which the Momento service endpoint will be read. + + Returns: + CredentialProvider + """ + if len(api_key_env_var) == 0: + raise InvalidArgumentException("API key environment variable name cannot be empty.", Service.AUTH) + if len(endpoint_env_var) == 0: + raise InvalidArgumentException("Endpoint environment variable name cannot be empty.", Service.AUTH) + + api_key = os.getenv(api_key_env_var) + if not api_key: + raise RuntimeError(f"Missing required environment variable {api_key_env_var}") + endpoint = os.getenv(endpoint_env_var) + if not endpoint: + raise RuntimeError(f"Missing required environment variable {endpoint_env_var}") + + if not momento_endpoint_resolver._is_v2_api_key(api_key): + raise InvalidArgumentException( + "Received an invalid v2 API key. Are you using the correct key? Or did you mean to use `from_environment_variable()` with a legacy key instead?", + Service.AUTH, + ) + return CredentialProvider.from_api_key_v2(api_key, endpoint) + + @staticmethod + def from_disposable_token(auth_token: str) -> CredentialProvider: + """Reads and parses a Momento disposable auth token. + + Args: + auth_token (str): the Momento disposable auth token + + Returns: + CredentialProvider + """ + if len(auth_token) == 0: + raise InvalidArgumentException("Disposable token cannot be empty.", Service.AUTH) + token_and_endpoints = momento_endpoint_resolver.resolve(auth_token) + auth_token = token_and_endpoints.auth_token + return CredentialProvider( + auth_token, + token_and_endpoints.control_endpoint, + token_and_endpoints.cache_endpoint, + token_and_endpoints.token_endpoint, + 443, + ) diff --git a/src/momento/auth/momento_endpoint_resolver.py b/src/momento/auth/momento_endpoint_resolver.py index 8285c55c..f6cc07ce 100644 --- a/src/momento/auth/momento_endpoint_resolver.py +++ b/src/momento/auth/momento_endpoint_resolver.py @@ -14,6 +14,8 @@ _MOMENTO_TOKEN_ENDPOINT_PREFIX = "token." _CONTROL_ENDPOINT_CLAIM_ID = "cp" _CACHE_ENDPOINT_CLAIM_ID = "c" +_API_KEY_TYPE_CLAIM_ID = "t" +_GLOBAL_API_KEY_TYPE = "g" @dataclass @@ -31,6 +33,14 @@ class _Base64DecodedV1Token: def resolve(auth_token: str) -> _TokenAndEndpoints: + """Helper function used by from_string and from_disposable_token to parse legacy and v1 auth tokens. + + Args: + auth_token (str): The auth token to be resolved. + + Returns: + _TokenAndEndpoints + """ if not auth_token: raise InvalidArgumentException("malformed auth token", Service.AUTH) @@ -44,6 +54,11 @@ def resolve(auth_token: str) -> _TokenAndEndpoints: auth_token=info["api_key"], # type: ignore[misc] ) else: + if _is_v2_api_key(auth_token): + raise InvalidArgumentException( + "Unexpectedly received a v2 API key. Are you using the correct key and the correct CredentialProvider method?", + Service.AUTH, + ) return _get_endpoint_from_token(auth_token) @@ -67,3 +82,13 @@ def _is_base64(value: Union[bytes, str]) -> bool: return base64.b64encode(base64.b64decode(value)) == value except Exception: return False + + +def _is_v2_api_key(key: str) -> bool: + if _is_base64(key): + return False + try: + claims = jwt.decode(key, options={"verify_signature": False}) # type: ignore[misc] + return _API_KEY_TYPE_CLAIM_ID in claims and claims[_API_KEY_TYPE_CLAIM_ID] == _GLOBAL_API_KEY_TYPE # type: ignore[misc] + except DecodeError: + return False diff --git a/tests/momento/auth/test_credential_provider.py b/tests/momento/auth/test_credential_provider.py index 2258e584..139ed753 100644 --- a/tests/momento/auth/test_credential_provider.py +++ b/tests/momento/auth/test_credential_provider.py @@ -1,11 +1,13 @@ import base64 import json import os +import re import jwt import pytest from momento.auth.credential_provider import CredentialProvider from momento.auth.momento_endpoint_resolver import _Base64DecodedV1Token +from momento.errors.exceptions import InvalidArgumentException from tests.utils import uuid_str @@ -23,6 +25,15 @@ os.environ[test_env_var_name] = test_token os.environ[test_v1_env_var_name] = test_encoded_v1_token.decode("utf-8") +# For v2 API key tests +test_v2_key_message = {"t": "g", "jti": "some-id"} +test_v2_api_key = jwt.encode(test_v2_key_message, "secret", algorithm="HS512") +test_v2_key_env_var_name = "MOMENTO_API_KEY" +test_v2_endpoint = "testEndpoint" +test_v2_endpoint_env_var_name = "MOMENTO_ENDPOINT" +os.environ[test_v2_key_env_var_name] = test_v2_api_key +os.environ[test_v2_endpoint_env_var_name] = test_v2_endpoint + @pytest.mark.parametrize( "provider, auth_token, control_endpoint, cache_endpoint", @@ -97,3 +108,173 @@ def test_endpoints(provider: CredentialProvider, auth_token: str, control_endpoi def test_env_token_raises_if_not_exists() -> None: with pytest.raises(RuntimeError, match=r"Missing required environment variable"): CredentialProvider.from_environment_variable(env_var_name=uuid_str()) + + +@pytest.mark.parametrize( + "provider, expected_api_key, expected_control_endpoint, expected_cache_endpoint, expected_token_endpoint", + [ + ( + CredentialProvider.from_api_key_v2( + api_key=test_v2_api_key, + endpoint=test_v2_endpoint, + ), + test_v2_api_key, + f"control.{test_v2_endpoint}", + f"cache.{test_v2_endpoint}", + f"token.{test_v2_endpoint}", + ), + ( + CredentialProvider.from_environment_variables_v2( + api_key_env_var=test_v2_key_env_var_name, + endpoint_env_var=test_v2_endpoint_env_var_name, + ), + test_v2_api_key, + f"control.{test_v2_endpoint}", + f"cache.{test_v2_endpoint}", + f"token.{test_v2_endpoint}", + ), + ( + CredentialProvider.from_environment_variables_v2(), + test_v2_api_key, + f"control.{test_v2_endpoint}", + f"cache.{test_v2_endpoint}", + f"token.{test_v2_endpoint}", + ), + ], +) +def test_v2_api_key_endpoints( + provider: CredentialProvider, + expected_api_key: str, + expected_control_endpoint: str, + expected_cache_endpoint: str, + expected_token_endpoint: str, +) -> None: + assert provider.auth_token == expected_api_key + assert provider.control_endpoint == expected_control_endpoint + assert provider.cache_endpoint == expected_cache_endpoint + assert provider.token_endpoint == expected_token_endpoint + + +def test_v2_key_from_string_raises_if_api_key_empty() -> None: + with pytest.raises(InvalidArgumentException, match="API key cannot be empty"): + CredentialProvider.from_api_key_v2(api_key="", endpoint=test_v2_endpoint) + + +def test_v2_key_from_string_raises_if_endpoint_empty() -> None: + with pytest.raises(InvalidArgumentException, match="Endpoint cannot be empty"): + CredentialProvider.from_api_key_v2(api_key=test_v2_api_key, endpoint="") + + +def test_v2_key_from_env_raises_if_env_var_name_empty() -> None: + with pytest.raises(InvalidArgumentException, match="API key environment variable name cannot be empty"): + CredentialProvider.from_environment_variables_v2( + api_key_env_var="", endpoint_env_var=test_v2_endpoint_env_var_name + ) + + +def test_v2_key_from_env_raises_if_env_var_missing() -> None: + with pytest.raises(RuntimeError, match="Missing required environment variable"): + CredentialProvider.from_environment_variables_v2( + api_key_env_var=uuid_str(), endpoint_env_var=test_v2_endpoint_env_var_name + ) + + +def test_v2_key_from_env_raises_if_endpoint_empty() -> None: + with pytest.raises(InvalidArgumentException, match="Endpoint environment variable name cannot be empty"): + CredentialProvider.from_environment_variables_v2(api_key_env_var=test_v2_key_env_var_name, endpoint_env_var="") + + +def test_v2_key_from_env_raises_if_api_key_empty_string() -> None: + empty_api_key_env_var = uuid_str() + os.environ[empty_api_key_env_var] = "" + with pytest.raises(RuntimeError, match="Missing required environment variable"): + CredentialProvider.from_environment_variables_v2( + api_key_env_var=empty_api_key_env_var, endpoint_env_var=test_v2_endpoint_env_var_name + ) + + +def test_v2_key_from_string_raises_if_base64_api_key() -> None: + with pytest.raises( + InvalidArgumentException, + match=re.escape( + "Received an invalid v2 API key. Are you using the correct key and the correct CredentialProvider method?" + ), + ): + CredentialProvider.from_api_key_v2( + api_key=test_encoded_v1_token.decode("utf-8"), endpoint=test_v2_endpoint_env_var_name + ) + + +def test_v2_key_from_env_raises_if_base64_api_key() -> None: + with pytest.raises( + InvalidArgumentException, + match=re.escape( + "Received an invalid v2 API key. Are you using the correct key? Or did you mean to use `from_environment_variable()` with a legacy key instead?" + ), + ): + CredentialProvider.from_environment_variables_v2( + api_key_env_var=test_v1_env_var_name, endpoint_env_var=test_v2_endpoint_env_var_name + ) + + +def test_v2_key_from_string_raises_if_pre_v1_token() -> None: + with pytest.raises( + InvalidArgumentException, + match=re.escape( + "Received an invalid v2 API key. Are you using the correct key and the correct CredentialProvider method?" + ), + ): + CredentialProvider.from_api_key_v2(api_key=test_token, endpoint=test_v2_endpoint_env_var_name) + + +def test_v2_key_from_env_raises_if_pre_v1_token() -> None: + with pytest.raises( + InvalidArgumentException, + match=re.escape( + "Received an invalid v2 API key. Are you using the correct key? Or did you mean to use `from_environment_variable()` with a legacy key instead?" + ), + ): + CredentialProvider.from_environment_variables_v2( + api_key_env_var=test_env_var_name, endpoint_env_var=test_v2_endpoint_env_var_name + ) + + +def test_v2_key_provided_to_from_string() -> None: + with pytest.raises( + InvalidArgumentException, + match=re.escape( + "Unexpectedly received a v2 API key. Are you using the correct key and the correct CredentialProvider method?" + ), + ): + CredentialProvider.from_string(auth_token=test_v2_api_key) + + +def test_v2_key_provided_to_from_disposable_token() -> None: + with pytest.raises( + InvalidArgumentException, + match=re.escape( + "Unexpectedly received a v2 API key. Are you using the correct key and the correct CredentialProvider method?" + ), + ): + CredentialProvider.from_disposable_token(auth_token=test_v2_api_key) + + +def test_from_disposable_token_raises_if_token_empty() -> None: + with pytest.raises(InvalidArgumentException, match="Disposable token cannot be empty."): + CredentialProvider.from_disposable_token(auth_token="") + + +def test_from_disposable_token_accepts_v1_api_key() -> None: + provider = CredentialProvider.from_disposable_token(auth_token=test_encoded_v1_token.decode("utf-8")) + assert provider.auth_token == test_v1_api_key + assert provider.control_endpoint == "control.test.momentohq.com" + assert provider.cache_endpoint == "cache.test.momentohq.com" + assert provider.token_endpoint == "token.test.momentohq.com" + + +def test_from_disposable_token_accepts_pre_v1_token() -> None: + provider = CredentialProvider.from_disposable_token(auth_token=test_token) + assert provider.auth_token == test_token + assert provider.control_endpoint == test_control_endpoint + assert provider.cache_endpoint == test_cache_endpoint + assert provider.token_endpoint == f"token.{test_cache_endpoint}"