diff --git a/leverage/modules/auth.py b/leverage/modules/auth.py index 9f553bd..b4b5f6f 100644 --- a/leverage/modules/auth.py +++ b/leverage/modules/auth.py @@ -1,6 +1,8 @@ +import json import time from pathlib import Path from configparser import NoSectionError, NoOptionError +from typing import Optional import boto3 from botocore.exceptions import ClientError @@ -57,6 +59,143 @@ def update_config_section(updater: ConfigUpdater, layer_profile: str, data: dict updater.update_file() +def _check_credentials_expiration(config_updater: ConfigUpdater, layer_profile: str, force_refresh: bool = False): + """ + Check if credentials need to be renewed based on expiration time. + + Args: + config_updater: ConfigUpdater instance with the config file loaded. + layer_profile: The layer profile name to check. + force_refresh: If True, skip expiration check and return False (needs refresh). + + Returns: + bool: True if credentials are still valid and should be skipped, False if they need refresh. + """ + if force_refresh: + return False + + try: + expiration = int(config_updater.get(f"profile {layer_profile}", "expiration").value) / 1000 + except (NoSectionError, NoOptionError): + # first time using this profile, skip into the credential's retrieval step + logger.debug("No cached credentials found.") + return False + + # we reduce the validity 30 minutes, to avoid expiration over long-standing tasks + renewal = time.time() + (30 * 60) + logger.debug(f"Token expiration time: {expiration}") + logger.debug(f"Token renewal time: {renewal}") + if renewal < expiration: + # still valid, nothing to do with these profile! + logger.info("Using already configured temporary credentials.") + return True + + return False + + +def _get_credentials_updater(cli): + """ + Prepare and return a ConfigUpdater for the credentials file. + + Args: + cli: The container instance with access to paths. + + Returns: + ConfigUpdater: Instance with credentials file loaded. + """ + creds_path = Path(cli.paths.host_aws_credentials_file) + creds_path.touch(exist_ok=True) + credentials_updater = ConfigUpdater() + credentials_updater.read(cli.paths.host_aws_credentials_file) + return credentials_updater + + +def _retrieve_and_update_credentials( + cli, + client, + access_token: Optional[str], + account_id: str, + sso_role: str, + layer_profile: str, + account_name: str, + config_updater: ConfigUpdater, + credentials_updater: ConfigUpdater, + raise_on_permission_error: bool = True, +): + """ + Retrieve credentials from SSO and update both config and credentials files. + + Args: + cli: The container instance (for getting access token if needed). + client: boto3 SSO client. + access_token: SSO access token (or None to get from cli). + account_id: AWS account ID. + sso_role: SSO role name. + layer_profile: Layer profile name. + account_name: Account name for logging. + config_updater: ConfigUpdater instance for config file. + credentials_updater: ConfigUpdater instance for credentials file. + raise_on_permission_error: If True, raise ExitError on permission errors. + If False, log warning and return False. + + Returns: + bool: True if successful, False if permission error occurred (only when raise_on_permission_error=False). + + Raises: + ExitError: If raise_on_permission_error=True and permission error occurs. + ClientError: For other boto3 errors. + """ + # Retrieve credentials + logger.debug(f"Retrieving role credentials for {sso_role}...") + try: + token = access_token if access_token is not None else cli.get_sso_access_token() + credentials = client.get_role_credentials( + roleName=sso_role, + accountId=account_id, + accessToken=token, + )["roleCredentials"] + except ClientError as error: + if error.response["Error"]["Code"] in ("AccessDeniedException", "ForbiddenException"): + error_msg = ( + f"User does not have permission to assume role [bold]{sso_role}[/bold]" + " in this account.\nPlease check with your administrator or try" + " running [bold]leverage aws configure sso[/bold]." + ) + if raise_on_permission_error: + raise ExitError(40, error_msg) from error + else: + logger.warning( + f"No permission to assume role [bold]{sso_role}[/bold] in {account_name} account. Skipping." + ) + return False + else: + raise + + # Update expiration on aws//config + logger.info(f"Writing {layer_profile} profile") + update_config_section( + config_updater, + f"profile {layer_profile}", + data={ + "expiration": credentials["expiration"], + }, + ) + + # Write credentials on aws//credentials + update_config_section( + credentials_updater, + layer_profile, + data={ + "aws_access_key_id": credentials["accessKeyId"], + "aws_secret_access_key": credentials["secretAccessKey"], + "aws_session_token": credentials["sessionToken"], + }, + ) + + logger.info(f"Credentials for {account_name} account written successfully.") + return True + + def get_profiles(cli): """ Get the AWS profiles present on the layer by parsing some tf files. @@ -86,6 +225,8 @@ def refresh_layer_credentials(cli): config_updater.read(cli.paths.host_aws_profiles_file) client = boto3.client("sso", region_name=cli.sso_region_from_main_profile) + credentials_updater = None # Will be created lazily when needed + for raw in raw_profiles: try: account_id, account_name, sso_role, layer_profile = get_layer_profile( @@ -98,60 +239,131 @@ def refresh_layer_credentials(cli): continue # check if credentials need to be renewed + if _check_credentials_expiration(config_updater, layer_profile, force_refresh=False): + continue + + # Create credentials updater lazily (only when we need to update credentials) + if credentials_updater is None: + credentials_updater = _get_credentials_updater(cli) + + # retrieve and update credentials + _retrieve_and_update_credentials( + cli=cli, + client=client, + access_token=None, # Will be retrieved inside the function + account_id=account_id, + sso_role=sso_role, + layer_profile=layer_profile, + account_name=account_name, + config_updater=config_updater, + credentials_updater=credentials_updater, + raise_on_permission_error=True, + ) + + +def refresh_all_accounts_credentials(cli, force_refresh=False): + """ + Refresh credentials for all configured SSO accounts. + + This function iterates through all SSO profiles configured in the AWS config file + and refreshes their credentials using the current SSO access token. + + Args: + cli: The SSOContainer instance with access to paths and SSO token. + force_refresh: If True, refresh all credentials regardless of expiration. + If False, only refresh expired credentials. + """ + logger.info("Refreshing credentials for all accounts...") + + config_updater = ConfigUpdater() + config_updater.read(cli.paths.host_aws_profiles_file) + + # Get all SSO profiles (they follow the pattern: profile -sso-) + sso_profiles = [] + for section in config_updater.sections(): + if section.startswith(f"profile {cli.project}-sso-") and section != f"profile {cli.project}-sso": + sso_profiles.append(section) + + if not sso_profiles: + logger.warning("No SSO account profiles found. Run 'leverage aws configure sso' first.") + return + + logger.info(f"Found {len(sso_profiles)} account(s) to refresh.") + + client = boto3.client("sso", region_name=cli.sso_region_from_main_profile) + + # Get the SSO access token + # Note: get_sso_access_token() can raise: + # - FileNotFoundError: token file doesn't exist + # - PermissionError: no read permission (OSError subclass) + # - IsADirectoryError: path is directory (OSError subclass) + # - OSError: other OS errors (disk full, network issues, etc.) + # - ValueError: invalid path (rare) + # - json.JSONDecodeError: invalid JSON content + # - KeyError: "accessToken" key missing in JSON + # - TypeError: JSON root is not a dict (e.g., it's a list) + # - UnicodeDecodeError: encoding issues (rare with text mode) + try: + access_token = cli.get_sso_access_token() + except ( + FileNotFoundError, + PermissionError, + IsADirectoryError, + OSError, + ValueError, + json.JSONDecodeError, + KeyError, + TypeError, + UnicodeDecodeError, + ) as e: + raise ExitError(1, f"Failed to get SSO access token: {e}") from e + + # Prepare credentials file + credentials_updater = _get_credentials_updater(cli) + + success_count = 0 + error_count = 0 + + for profile_section in sso_profiles: try: - expiration = int(config_updater.get(f"profile {layer_profile}", "expiration").value) / 1000 - except (NoSectionError, NoOptionError): - # first time using this profile, skip into the credential's retrieval step - logger.debug("No cached credentials found.") - else: - # we reduce the validity 30 minutes, to avoid expiration over long-standing tasks - renewal = time.time() + (30 * 60) - logger.debug(f"Token expiration time: {expiration}") - logger.debug(f"Token renewal time: {renewal}") - if renewal < expiration: - # still valid, nothing to do with these profile! - logger.info("Using already configured temporary credentials.") + # Extract account name from profile (e.g., "profile project-sso-security" -> "security") + account_name = profile_section.replace(f"profile {cli.project}-sso-", "") + + # Get account_id and role_name from the SSO profile + account_id = config_updater.get(profile_section, "account_id").value + sso_role = config_updater.get(profile_section, "role_name").value + + # Build the layer profile name + layer_profile = f"{cli.project}-{account_name}-{sso_role.lower()}" + + # Check if credentials need to be renewed (unless force refresh is enabled) + if _check_credentials_expiration(config_updater, layer_profile, force_refresh=force_refresh): + logger.info(f"Credentials for {account_name} account are still valid, skipping.") + success_count += 1 continue - # retrieve credentials - logger.debug(f"Retrieving role credentials for {sso_role}...") - try: - credentials = client.get_role_credentials( - roleName=sso_role, - accountId=account_id, - accessToken=cli.get_sso_access_token(), - )["roleCredentials"] - except ClientError as error: - if error.response["Error"]["Code"] in ("AccessDeniedException", "ForbiddenException"): - raise ExitError( - 40, - f"User does not have permission to assume role [bold]{sso_role}[/bold]" - " in this account.\nPlease check with your administrator or try" - " running [bold]leverage aws configure sso[/bold].", - ) + # Retrieve and update credentials + logger.info(f"Retrieving credentials for {account_name} account...") + if _retrieve_and_update_credentials( + cli=cli, + client=client, + access_token=access_token, + account_id=account_id, + sso_role=sso_role, + layer_profile=layer_profile, + account_name=account_name, + config_updater=config_updater, + credentials_updater=credentials_updater, + raise_on_permission_error=False, + ): + logger.info(f"Credentials for {account_name} account refreshed successfully.") + success_count += 1 + else: + # Permission error occurred (already logged as warning) + error_count += 1 - # update expiration on aws//config - logger.info(f"Writing {layer_profile} profile") - update_config_section( - config_updater, - f"profile {layer_profile}", - data={ - "expiration": credentials["expiration"], - }, - ) - # write credentials on aws//credentials (create the file if it doesn't exist first) - creds_path = Path(cli.paths.host_aws_credentials_file) - creds_path.touch(exist_ok=True) - credentials_updater = ConfigUpdater() - credentials_updater.read(cli.paths.host_aws_credentials_file) - - update_config_section( - credentials_updater, - layer_profile, - data={ - "aws_access_key_id": credentials["accessKeyId"], - "aws_secret_access_key": credentials["secretAccessKey"], - "aws_session_token": credentials["sessionToken"], - }, - ) - logger.info(f"Credentials for {account_name} account written successfully.") + except Exception as e: + logger.exception(f"Failed to refresh credentials for {account_name}: {e}") + error_count += 1 + + logger.info(f"Credential refresh complete: {success_count} successful, {error_count} failed.") diff --git a/leverage/modules/aws.py b/leverage/modules/aws.py index 8604bee..46786ff 100644 --- a/leverage/modules/aws.py +++ b/leverage/modules/aws.py @@ -6,9 +6,10 @@ from leverage import logger from leverage._internals import pass_state from leverage._internals import pass_container -from leverage._utils import get_or_create_section +from leverage._utils import get_or_create_section, ExitError from leverage.container import get_docker_client, SSOContainer from leverage.container import AWSCLIContainer +from leverage.modules.auth import refresh_all_accounts_credentials from leverage.modules.utils import _handle_subcommand CONTEXT_SETTINGS = {"ignore_unknown_options": True} @@ -199,8 +200,14 @@ def sso(context, cli, args): @sso.command() +@click.option( + "--refresh-all", + is_flag=True, + default=False, + help="Automatically refresh credentials for all configured accounts after login.", +) @pass_container -def login(cli): +def login(cli, refresh_all): """Login""" exit_code, region = cli.exec(f"configure get sso_region --profile {cli.project}-sso") if exit_code: @@ -210,6 +217,13 @@ def login(cli): if exit_code := cli.sso_login(): raise Exit(exit_code) + # Refresh all account credentials if requested + if refresh_all: + try: + refresh_all_accounts_credentials(cli, force_refresh=False) + except ExitError as e: + raise Exit(e.exit_code) from e + @sso.command() @pass_container @@ -223,3 +237,19 @@ def logout(cli): f"Don't forget to log out of your [bold]AWS SSO[/bold] start page {cli.paths.common_conf.get('sso_start_url')}" " and your external identity provider portal." ) + + +@sso.command() +@click.option( + "--force", + is_flag=True, + default=False, + help="Force refresh all credentials, even if they haven't expired.", +) +@pass_container +def refresh(cli, force): + """Refresh credentials for all configured accounts""" + try: + refresh_all_accounts_credentials(cli, force_refresh=force) + except ExitError as e: + raise Exit(e.exit_code) from e diff --git a/leverage/modules/utils.py b/leverage/modules/utils.py index 10ce084..8613831 100644 --- a/leverage/modules/utils.py +++ b/leverage/modules/utils.py @@ -1,3 +1,5 @@ +import inspect + import click from click.exceptions import Exit @@ -14,7 +16,11 @@ def _handle_subcommand(context, cli_container, args, caller_name=None): Raises: Exit: Whenever container execution returns a non-zero exit code """ - caller_pos = args.index(caller_name) if caller_name is not None else 0 + try: + caller_pos = args.index(caller_name) if caller_name is not None else 0 + except ValueError: + # caller_name not in args (e.g., when using --help), start from beginning + caller_pos = 0 # Find if one of the wrapped subcommand was invoked wrapped_subcommands = context.command.commands.keys() @@ -28,11 +34,32 @@ def _handle_subcommand(context, cli_container, args, caller_name=None): else: # Invoke wrapped command - subcommand = context.command.commands.get(subcommand) - if not subcommand.params: - context.invoke(subcommand) + subcommand_obj = context.command.commands.get(subcommand) + + # Check if help is requested in the args after the subcommand name + subcommand_start = args.index(subcommand) if subcommand in args else 0 + remaining_args = args[subcommand_start + 1 :] + + # Only show help if this is a leaf command (not a group) or if --help is the only remaining arg + is_help_request = "--help" in remaining_args or "-h" in remaining_args or "help" in remaining_args + is_group = isinstance(subcommand_obj, click.Group) + + # For groups, only show help if --help is the ONLY remaining arg (not followed by more commands) + if is_help_request and (not is_group or len(remaining_args) == 1): + # Show help for the subcommand + click.echo(subcommand_obj.get_help(context)) else: - context.forward(subcommand) + # Invoke the subcommand normally + # For groups and commands with 'args' parameter, pass remaining args + # Otherwise, create a new context to properly parse Click options + sig = inspect.signature(subcommand_obj.callback) + if "args" in sig.parameters: + # Pass the remaining args to the subcommand (for groups) + context.invoke(subcommand_obj, args=remaining_args) + else: + # Create a new context with remaining args so Click can parse options + with subcommand_obj.make_context(subcommand, list(remaining_args), parent=context) as sub_ctx: + subcommand_obj.invoke(sub_ctx) mount_option = click.option("--mount", multiple=True, type=click.Tuple([str, str])) diff --git a/tests/test_modules/test_auth.py b/tests/test_modules/test_auth.py index f10e549..cbd3b2d 100644 --- a/tests/test_modules/test_auth.py +++ b/tests/test_modules/test_auth.py @@ -9,7 +9,12 @@ from leverage._utils import ExitError from leverage.container import SSOContainer -from leverage.modules.auth import refresh_layer_credentials, get_layer_profile, SkipProfile +from leverage.modules.auth import ( + refresh_layer_credentials, + get_layer_profile, + SkipProfile, + refresh_all_accounts_credentials, +) from leverage.modules.aws import get_account_roles, add_sso_profile, configure_sso_profiles from tests.test_containers import container_fixture_factory @@ -297,3 +302,359 @@ def test_refresh_layer_credentials_no_access(mock_open, mock_update_conf, sso_co with pytest.raises(ExitError): refresh_layer_credentials(sso_container) + + +# Tests for refresh_all_accounts_credentials + + +FILE_AWS_CONFIG_MULTI_ACCOUNTS = """ +[profile test-sso] +sso_region = us-test-1 + +[profile test-sso-security] +account_id = 123456 +role_name = DevOps + +[profile test-sso-shared] +account_id = 234567 +role_name = DevOps + +[profile test-sso-network] +account_id = 345678 +role_name = DevOps + +[profile test-security-devops] +expiration=1705859470 + +[profile test-shared-devops] +expiration=170600900000 + +[profile test-network-devops] +""" + +data_dict_multi = { + "~/.aws/test/config": FILE_AWS_CONFIG_MULTI_ACCOUNTS, + "~/.aws/test/credentials": "", +} + + +def read_text_side_effect_multi(self: PosixPath, *args, **kwargs): + """ + Side effect for reading multi-account config files. + """ + return data_dict_multi.get(str(self), data_dict_multi.get(self.name, "")) + + +def open_side_effect_multi(name: PosixPath, *args, **kwargs): + """ + Side effect for opening multi-account config files. + """ + file_content = data_dict_multi.get(str(name), data_dict_multi.get(name, "")) + return mock.mock_open(read_data=file_content)() + + +b3_client_multi = Mock() +b3_client_multi.get_role_credentials = Mock( + return_value={ + "roleCredentials": { + "expiration": "1705859500", + "accessKeyId": "new-access-key", + "secretAccessKey": "new-secret-key", + "sessionToken": "new-session-token", + } + } +) + + +@mock.patch("leverage.modules.auth.update_config_section") +@mock.patch("pathlib.Path.touch", new=Mock()) +@mock.patch("time.time", new=Mock(return_value=NOW_EPOCH)) +@mock.patch("boto3.client", return_value=b3_client_multi) +@mock.patch("configupdater.parser.open", side_effect=open_side_effect_multi) +def test_refresh_all_accounts_credentials_success(mock_open, mock_boto, mock_update_conf, sso_container, caplog): + """ + Test successful credential refresh for multiple accounts with smart expiration checking. + + This test verifies that: + 1. The function correctly identifies SSO profiles from the config file + 2. It skips accounts with valid (non-expired) credentials when force_refresh=False + 3. It refreshes credentials for accounts that need renewal + 4. It provides clear progress logging for each account + 5. It returns a summary of successful vs failed operations + + The test uses a multi-account scenario with mixed credential states: + - security: needs refresh (no expiration set) + - shared: has valid credentials (expiration > renewal time) + - network: needs refresh (no expiration set) + """ + refresh_all_accounts_credentials(sso_container, force_refresh=False) + + # Should log that it's refreshing credentials + assert "Refreshing credentials for all accounts..." in caplog.text + assert "Found 3 account(s) to refresh." in caplog.text + + # Should retrieve credentials for each account + assert "Retrieving credentials for security account..." in caplog.text + assert "Credentials for security account refreshed successfully" in caplog.text + assert "Retrieving credentials for network account..." in caplog.text + assert "Credentials for network account refreshed successfully" in caplog.text + + # Should skip valid credentials + assert "Credentials for shared account are still valid, skipping." in caplog.text + + # Should show summary + assert "Credential refresh complete: 3 successful, 0 failed." in caplog.text + + +@mock.patch("pathlib.Path.touch", new=Mock()) +@mock.patch("boto3.client", return_value=b3_client_multi) +@mock.patch("configupdater.parser.open", side_effect=open_side_effect_multi) +def test_refresh_all_accounts_credentials_force_refresh(mock_open, mock_boto, sso_container, caplog): + """ + Test that force_refresh=True bypasses expiration checks and refreshes all credentials. + + This test verifies that: + 1. When force_refresh=True, the function ignores existing credential expiration times + 2. All accounts are processed regardless of their current credential validity + 3. No "still valid, skipping" messages are logged + 4. The function behaves as if all credentials need renewal + + This is useful for scenarios where: + - Credentials may be corrupted or invalid despite appearing valid + - User wants to ensure all credentials are fresh + - Troubleshooting credential issues + """ + with mock.patch("leverage.modules.auth.update_config_section"): + refresh_all_accounts_credentials(sso_container, force_refresh=True) + + # Should refresh all accounts, including the one with valid credentials + assert "Retrieving credentials for security account..." in caplog.text + assert "Retrieving credentials for shared account..." in caplog.text + assert "Retrieving credentials for network account..." in caplog.text + + # Should NOT skip any accounts when force_refresh=True + assert "are still valid, skipping" not in caplog.text + + +FILE_AWS_CONFIG_NO_SSO_PROFILES = """ +[profile test-sso] +sso_region = us-test-1 +""" + +data_dict_no_profiles = { + "~/.aws/test/config": FILE_AWS_CONFIG_NO_SSO_PROFILES, + "~/.aws/test/credentials": "", +} + + +def open_side_effect_no_profiles(name: PosixPath, *args, **kwargs): + """ + Side effect for opening config with no SSO profiles. + """ + file_content = data_dict_no_profiles.get(str(name), data_dict_no_profiles.get(name, "")) + return mock.mock_open(read_data=file_content)() + + +@mock.patch("pathlib.Path.touch", new=Mock()) +@mock.patch("boto3.client", return_value=b3_client_multi) +@mock.patch("configupdater.parser.open", side_effect=open_side_effect_no_profiles) +def test_refresh_all_accounts_credentials_no_profiles(mock_open, mock_boto, sso_container, caplog): + """ + Test graceful handling when no SSO account profiles are configured. + + This test verifies that: + 1. The function doesn't crash when no SSO profiles exist + 2. It provides a helpful warning message to the user + 3. It suggests the correct next step (running 'leverage aws configure sso') + 4. It returns early without attempting any credential operations + + This scenario occurs when: + - A user hasn't run 'leverage aws configure sso' yet + - The AWS config file only contains the main SSO profile + - The project has no account-specific SSO profiles configured + """ + refresh_all_accounts_credentials(sso_container, force_refresh=False) + + # Should warn that no profiles were found + assert "No SSO account profiles found" in caplog.text + + +@mock.patch("leverage.modules.auth.update_config_section") +@mock.patch("pathlib.Path.touch", new=Mock()) +@mock.patch("time.time", new=Mock(return_value=NOW_EPOCH)) +@mock.patch("configupdater.parser.open", side_effect=open_side_effect_multi) +def test_refresh_all_accounts_credentials_permission_error(mock_open, mock_update_conf, sso_container, caplog): + """ + Test graceful handling of permission errors for specific accounts while continuing with others. + + This test verifies that: + 1. When one account fails with AccessDeniedException, the function continues processing other accounts + 2. Permission errors are logged as warnings (not errors) to avoid alarming users + 3. The function provides clear feedback about which account failed and why + 4. The final summary correctly reports partial success (some accounts succeeded, some failed) + 5. The function doesn't crash or stop processing when encountering permission issues + + This scenario occurs when: + - User has access to some AWS accounts but not others + - Role permissions have changed since last configuration + - Some accounts have been removed from the user's access + - Cross-account role assumptions fail due to policy changes + """ + with mock.patch("boto3.client") as mocked_client: + mocked_client_obj = MagicMock() + # First call succeeds, second call fails with permission error, third call succeeds + mocked_client_obj.get_role_credentials.side_effect = [ + { + "roleCredentials": { + "expiration": "1705859500", + "accessKeyId": "access-key", + "secretAccessKey": "secret-key", + "sessionToken": "session-token", + } + }, + ClientError({"Error": {"Code": "AccessDeniedException", "Message": "No access"}}, "GetRoleCredentials"), + { + "roleCredentials": { + "expiration": "1705859500", + "accessKeyId": "access-key", + "secretAccessKey": "secret-key", + "sessionToken": "session-token", + } + }, + ] + mocked_client.return_value = mocked_client_obj + + refresh_all_accounts_credentials(sso_container, force_refresh=True) + + # Should warn about the permission error but continue + assert "No permission to assume role" in caplog.text + assert "Skipping." in caplog.text + + # Should show that some succeeded and some failed + assert "Credential refresh complete: 2 successful, 1 failed." in caplog.text + + +@mock.patch("pathlib.Path.touch", new=Mock()) +@mock.patch("configupdater.parser.open", side_effect=open_side_effect_multi) +def test_refresh_all_accounts_credentials_token_error(mock_open, sso_container): + """ + Test that the function properly handles SSO access token retrieval failures. + + This test verifies that: + 1. When get_sso_access_token() fails, the function raises an ExitError with a clear message + 2. The error message includes the original exception details for debugging + 3. The function fails fast and doesn't attempt to process any accounts without a valid token + 4. The error is properly wrapped in ExitError for consistent error handling + + This scenario occurs when: + - SSO session has expired and needs to be renewed + - SSO configuration is missing or incorrect + - Network issues prevent token retrieval + - SSO service is temporarily unavailable + """ + # Mock get_sso_access_token to raise an exception (FileNotFoundError simulates missing token file) + sso_container.get_sso_access_token = Mock(side_effect=FileNotFoundError("Token not found")) + + with mock.patch("boto3.client", return_value=b3_client_multi): + with pytest.raises(ExitError, match="Failed to get SSO access token"): + refresh_all_accounts_credentials(sso_container, force_refresh=False) + + +@mock.patch("leverage.modules.auth.update_config_section") +@mock.patch("pathlib.Path.touch", new=Mock()) +@mock.patch("time.time", new=Mock(return_value=NOW_EPOCH)) +@mock.patch("boto3.client", return_value=b3_client_multi) +@mock.patch("configupdater.parser.open", side_effect=open_side_effect_multi) +def test_refresh_all_accounts_credentials_calls_update_correctly(mock_open, mock_boto, mock_update_conf, sso_container): + """ + Test that the function correctly updates both AWS config and credentials files with proper data. + + This test verifies that: + 1. update_config_section is called the correct number of times (2 calls per account) + 2. Config file updates include the correct profile names with 'profile ' prefix + 3. Credentials file updates include the correct profile names without 'profile ' prefix + 4. The function maintains the proper separation between config and credentials data + 5. All expected account profiles are processed and updated + + The test ensures the function follows the AWS CLI file format conventions: + - Config file: [profile name] sections for expiration tracking + - Credentials file: [name] sections for actual credential storage + """ + refresh_all_accounts_credentials(sso_container, force_refresh=True) + + # Verify update_config_section was called with correct arguments + # For 3 accounts with force_refresh, we should have 6 calls (config + credentials for each) + assert mock_update_conf.call_count == 6 + + # Check that profile names are correct + profile_calls = [call.args[1] for call in mock_update_conf.call_args_list] + assert "profile test-security-devops" in profile_calls + assert "test-security-devops" in profile_calls + assert "profile test-shared-devops" in profile_calls + assert "test-shared-devops" in profile_calls + assert "profile test-network-devops" in profile_calls + assert "test-network-devops" in profile_calls + + +# Integration-style test with real file operations +def test_refresh_all_accounts_credentials_integration(tmp_path, sso_container, caplog): + """ + Integration test using real file operations to verify end-to-end functionality. + + This test verifies that: + 1. The function can read real AWS config files with proper formatting + 2. It correctly parses SSO profile sections and extracts account information + 3. It creates and updates real credentials files + 4. The file operations work correctly with actual ConfigUpdater instances + 5. The function handles real file I/O without mocking the core file operations + + This test uses minimal mocking (only boto3 and time) to test the actual file handling logic. + """ + from pathlib import Path + + # Create real temporary files + config_file = tmp_path / "config" + credentials_file = tmp_path / "credentials" + + # Write real AWS config content + config_content = """[profile test-sso] +sso_region = us-test-1 + +[profile test-sso-security] +account_id = 123456 +role_name = DevOps + +[profile test-sso-shared] +account_id = 234567 +role_name = DevOps + +[profile test-security-devops] +expiration=170600900000 +""" + config_file.write_text(config_content) + credentials_file.write_text("") # Empty credentials file + + # Mock the paths to point to our test files + sso_container.paths.host_aws_profiles_file = str(config_file) + sso_container.paths.host_aws_credentials_file = str(credentials_file) + + # Mock only the external dependencies + with mock.patch("time.time", return_value=NOW_EPOCH): + with mock.patch("boto3.client", return_value=b3_client_multi): + with mock.patch("leverage.modules.auth.update_config_section") as mock_update: + refresh_all_accounts_credentials(sso_container, force_refresh=False) + + # Verify the function was called correctly + # Security account should be skipped (valid credentials), shared account should be refreshed + assert mock_update.call_count == 2 # 1 account refreshed (config + credentials) + + # Verify the correct profile names were used (only shared account should be updated) + profile_calls = [call.args[1] for call in mock_update.call_args_list] + assert "profile test-shared-devops" in profile_calls + assert "test-shared-devops" in profile_calls + + # Verify logging shows smart behavior + assert "Found 2 account(s) to refresh." in caplog.text + assert "Credentials for security account are still valid, skipping." in caplog.text + assert "Retrieving credentials for shared account..." in caplog.text + assert "Credentials for shared account refreshed successfully." in caplog.text