Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
318 changes: 265 additions & 53 deletions leverage/modules/auth.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import json
import time
from pathlib import Path
from configparser import NoSectionError, NoOptionError
from typing import Optional

import boto3
from botocore.exceptions import ClientError
Expand Down Expand Up @@ -57,6 +59,143 @@ def update_config_section(updater: ConfigUpdater, layer_profile: str, data: dict
updater.update_file()


def _check_credentials_expiration(config_updater: ConfigUpdater, layer_profile: str, force_refresh: bool = False):
"""
Check if credentials need to be renewed based on expiration time.

Args:
config_updater: ConfigUpdater instance with the config file loaded.
layer_profile: The layer profile name to check.
force_refresh: If True, skip expiration check and return False (needs refresh).

Returns:
bool: True if credentials are still valid and should be skipped, False if they need refresh.
"""
if force_refresh:
return False

try:
expiration = int(config_updater.get(f"profile {layer_profile}", "expiration").value) / 1000
except (NoSectionError, NoOptionError):
# first time using this profile, skip into the credential's retrieval step
logger.debug("No cached credentials found.")
return False

# we reduce the validity 30 minutes, to avoid expiration over long-standing tasks
renewal = time.time() + (30 * 60)
logger.debug(f"Token expiration time: {expiration}")
logger.debug(f"Token renewal time: {renewal}")
if renewal < expiration:
# still valid, nothing to do with these profile!
logger.info("Using already configured temporary credentials.")
return True

return False


def _get_credentials_updater(cli):
"""
Prepare and return a ConfigUpdater for the credentials file.

Args:
cli: The container instance with access to paths.

Returns:
ConfigUpdater: Instance with credentials file loaded.
"""
creds_path = Path(cli.paths.host_aws_credentials_file)
creds_path.touch(exist_ok=True)
credentials_updater = ConfigUpdater()
credentials_updater.read(cli.paths.host_aws_credentials_file)
return credentials_updater


def _retrieve_and_update_credentials(
cli,
client,
access_token: Optional[str],
account_id: str,
sso_role: str,
layer_profile: str,
account_name: str,
config_updater: ConfigUpdater,
credentials_updater: ConfigUpdater,
raise_on_permission_error: bool = True,
):
"""
Retrieve credentials from SSO and update both config and credentials files.

Args:
cli: The container instance (for getting access token if needed).
client: boto3 SSO client.
access_token: SSO access token (or None to get from cli).
account_id: AWS account ID.
sso_role: SSO role name.
layer_profile: Layer profile name.
account_name: Account name for logging.
config_updater: ConfigUpdater instance for config file.
credentials_updater: ConfigUpdater instance for credentials file.
raise_on_permission_error: If True, raise ExitError on permission errors.
If False, log warning and return False.

Returns:
bool: True if successful, False if permission error occurred (only when raise_on_permission_error=False).

Raises:
ExitError: If raise_on_permission_error=True and permission error occurs.
ClientError: For other boto3 errors.
"""
# Retrieve credentials
logger.debug(f"Retrieving role credentials for {sso_role}...")
try:
token = access_token if access_token is not None else cli.get_sso_access_token()
credentials = client.get_role_credentials(
roleName=sso_role,
accountId=account_id,
accessToken=token,
)["roleCredentials"]
except ClientError as error:
if error.response["Error"]["Code"] in ("AccessDeniedException", "ForbiddenException"):
error_msg = (
f"User does not have permission to assume role [bold]{sso_role}[/bold]"
" in this account.\nPlease check with your administrator or try"
" running [bold]leverage aws configure sso[/bold]."
)
if raise_on_permission_error:
raise ExitError(40, error_msg) from error
else:
logger.warning(
f"No permission to assume role [bold]{sso_role}[/bold] in {account_name} account. Skipping."
)
return False
else:
raise

# Update expiration on aws/<project>/config
logger.info(f"Writing {layer_profile} profile")
update_config_section(
config_updater,
f"profile {layer_profile}",
data={
"expiration": credentials["expiration"],
},
)

# Write credentials on aws/<project>/credentials
update_config_section(
credentials_updater,
layer_profile,
data={
"aws_access_key_id": credentials["accessKeyId"],
"aws_secret_access_key": credentials["secretAccessKey"],
"aws_session_token": credentials["sessionToken"],
},
)

logger.info(f"Credentials for {account_name} account written successfully.")
return True


def get_profiles(cli):
"""
Get the AWS profiles present on the layer by parsing some tf files.
Expand Down Expand Up @@ -86,6 +225,8 @@ def refresh_layer_credentials(cli):
config_updater.read(cli.paths.host_aws_profiles_file)

client = boto3.client("sso", region_name=cli.sso_region_from_main_profile)
credentials_updater = None # Will be created lazily when needed

for raw in raw_profiles:
try:
account_id, account_name, sso_role, layer_profile = get_layer_profile(
Expand All @@ -98,60 +239,131 @@ def refresh_layer_credentials(cli):
continue

# check if credentials need to be renewed
if _check_credentials_expiration(config_updater, layer_profile, force_refresh=False):
continue

# Create credentials updater lazily (only when we need to update credentials)
if credentials_updater is None:
credentials_updater = _get_credentials_updater(cli)

# retrieve and update credentials
_retrieve_and_update_credentials(
cli=cli,
client=client,
access_token=None, # Will be retrieved inside the function
account_id=account_id,
sso_role=sso_role,
layer_profile=layer_profile,
account_name=account_name,
config_updater=config_updater,
credentials_updater=credentials_updater,
raise_on_permission_error=True,
)


def refresh_all_accounts_credentials(cli, force_refresh=False):
"""
Refresh credentials for all configured SSO accounts.

This function iterates through all SSO profiles configured in the AWS config file
and refreshes their credentials using the current SSO access token.

Args:
cli: The SSOContainer instance with access to paths and SSO token.
force_refresh: If True, refresh all credentials regardless of expiration.
If False, only refresh expired credentials.
"""
logger.info("Refreshing credentials for all accounts...")

config_updater = ConfigUpdater()
config_updater.read(cli.paths.host_aws_profiles_file)

# Get all SSO profiles (they follow the pattern: profile <project>-sso-<account>)
sso_profiles = []
for section in config_updater.sections():
if section.startswith(f"profile {cli.project}-sso-") and section != f"profile {cli.project}-sso":
sso_profiles.append(section)

if not sso_profiles:
logger.warning("No SSO account profiles found. Run 'leverage aws configure sso' first.")
return

logger.info(f"Found {len(sso_profiles)} account(s) to refresh.")

client = boto3.client("sso", region_name=cli.sso_region_from_main_profile)

# Get the SSO access token
# Note: get_sso_access_token() can raise:
# - FileNotFoundError: token file doesn't exist
# - PermissionError: no read permission (OSError subclass)
# - IsADirectoryError: path is directory (OSError subclass)
# - OSError: other OS errors (disk full, network issues, etc.)
# - ValueError: invalid path (rare)
# - json.JSONDecodeError: invalid JSON content
# - KeyError: "accessToken" key missing in JSON
# - TypeError: JSON root is not a dict (e.g., it's a list)
# - UnicodeDecodeError: encoding issues (rare with text mode)
try:
access_token = cli.get_sso_access_token()
except (
FileNotFoundError,
PermissionError,
IsADirectoryError,
OSError,
ValueError,
json.JSONDecodeError,
KeyError,
TypeError,
UnicodeDecodeError,
) as e:
raise ExitError(1, f"Failed to get SSO access token: {e}") from e

# Prepare credentials file
credentials_updater = _get_credentials_updater(cli)

success_count = 0
error_count = 0

for profile_section in sso_profiles:
try:
expiration = int(config_updater.get(f"profile {layer_profile}", "expiration").value) / 1000
except (NoSectionError, NoOptionError):
# first time using this profile, skip into the credential's retrieval step
logger.debug("No cached credentials found.")
else:
# we reduce the validity 30 minutes, to avoid expiration over long-standing tasks
renewal = time.time() + (30 * 60)
logger.debug(f"Token expiration time: {expiration}")
logger.debug(f"Token renewal time: {renewal}")
if renewal < expiration:
# still valid, nothing to do with these profile!
logger.info("Using already configured temporary credentials.")
# Extract account name from profile (e.g., "profile project-sso-security" -> "security")
account_name = profile_section.replace(f"profile {cli.project}-sso-", "")

# Get account_id and role_name from the SSO profile
account_id = config_updater.get(profile_section, "account_id").value
sso_role = config_updater.get(profile_section, "role_name").value

# Build the layer profile name
layer_profile = f"{cli.project}-{account_name}-{sso_role.lower()}"

# Check if credentials need to be renewed (unless force refresh is enabled)
if _check_credentials_expiration(config_updater, layer_profile, force_refresh=force_refresh):
logger.info(f"Credentials for {account_name} account are still valid, skipping.")
success_count += 1
continue

# retrieve credentials
logger.debug(f"Retrieving role credentials for {sso_role}...")
try:
credentials = client.get_role_credentials(
roleName=sso_role,
accountId=account_id,
accessToken=cli.get_sso_access_token(),
)["roleCredentials"]
except ClientError as error:
if error.response["Error"]["Code"] in ("AccessDeniedException", "ForbiddenException"):
raise ExitError(
40,
f"User does not have permission to assume role [bold]{sso_role}[/bold]"
" in this account.\nPlease check with your administrator or try"
" running [bold]leverage aws configure sso[/bold].",
)
# Retrieve and update credentials
logger.info(f"Retrieving credentials for {account_name} account...")
if _retrieve_and_update_credentials(
cli=cli,
client=client,
access_token=access_token,
account_id=account_id,
sso_role=sso_role,
layer_profile=layer_profile,
account_name=account_name,
config_updater=config_updater,
credentials_updater=credentials_updater,
raise_on_permission_error=False,
):
logger.info(f"Credentials for {account_name} account refreshed successfully.")
success_count += 1
else:
# Permission error occurred (already logged as warning)
error_count += 1

# update expiration on aws/<project>/config
logger.info(f"Writing {layer_profile} profile")
update_config_section(
config_updater,
f"profile {layer_profile}",
data={
"expiration": credentials["expiration"],
},
)
# write credentials on aws/<project>/credentials (create the file if it doesn't exist first)
creds_path = Path(cli.paths.host_aws_credentials_file)
creds_path.touch(exist_ok=True)
credentials_updater = ConfigUpdater()
credentials_updater.read(cli.paths.host_aws_credentials_file)

update_config_section(
credentials_updater,
layer_profile,
data={
"aws_access_key_id": credentials["accessKeyId"],
"aws_secret_access_key": credentials["secretAccessKey"],
"aws_session_token": credentials["sessionToken"],
},
)
logger.info(f"Credentials for {account_name} account written successfully.")
except Exception as e:
logger.exception(f"Failed to refresh credentials for {account_name}: {e}")
error_count += 1

logger.info(f"Credential refresh complete: {success_count} successful, {error_count} failed.")
Loading