diff --git a/build.gradle b/build.gradle index c427d70..54ac4b8 100644 --- a/build.gradle +++ b/build.gradle @@ -32,9 +32,11 @@ dependencies { } def withoutslf4jSupport = { exclude group: 'org.slf4j', module: 'slf4j-api' } + implementation 'com.amazonaws:aws-java-sdk-core:1.12.731', withoutjmespathSupport implementation 'com.amazonaws:aws-java-sdk-redshift:1.12.731', withoutjmespathSupport implementation 'com.amazonaws:aws-java-sdk-sts:1.12.731', withoutjmespathSupport implementation 'com.amazonaws:aws-java-sdk-redshiftserverless:1.12.731', withoutjmespathSupport + implementation 'com.amazonaws:aws-java-sdk-sso:1.12.731' implementation 'org.apache.httpcomponents:httpclient:4.5.14' implementation 'com.fasterxml.jackson.core:jackson-databind:2.16.0' implementation 'com.fasterxml.jackson.core:jackson-core:2.16.0' @@ -73,13 +75,16 @@ processResources { } jar { + from { + configurations.runtimeClasspath.collect { it.isDirectory() ? it : zipTree(it) } + } + duplicatesStrategy = DuplicatesStrategy.EXCLUDE manifest { attributes("Automatic-Module-Name": 'com.amazon.redshift.jdbc') attributes("Main-Class": "com.amazon.redshift.util.RedshiftJDBCMain") attributes("Specification-Title": "JDBC") attributes("Specification-Version": "4.2") attributes("Specification-Vendor": "Oracle Corporation") - attributes("Class-Path": configurations.runtimeClasspath.collect { it.getName() }.join(' ')) } } diff --git a/src/main/java/com/amazon/redshift/Driver.java b/src/main/java/com/amazon/redshift/Driver.java index 9099786..9b8bdd5 100644 --- a/src/main/java/com/amazon/redshift/Driver.java +++ b/src/main/java/com/amazon/redshift/Driver.java @@ -480,7 +480,11 @@ public Connection getResult(long timeout) throws SQLException { * @throws SQLException if the connection could not be made */ private static Connection makeConnection(String url, RedshiftProperties props, RedshiftLogger logger) throws SQLException { - return new RedshiftConnectionImpl(hostSpecs(props), user(props), database(props), props, url, logger); + + String iamauth = props.getProperty("iamauth"); + System.out.println(iamauth); + + return new RedshiftConnectionImpl(hostSpecs(props), user(props), database(props), props, url, logger); } /** @@ -645,7 +649,10 @@ public static RedshiftProperties parseURL(String url, RedshiftProperties default urlArgs = queryString; } // IAM else { - urlProps.setProperty(RedshiftProperty.IAM_AUTH.getName(), String.valueOf(iamAuth)); + // Only set iamAuth to false if it's not already explicitly set by the user + if (urlProps.getProperty(RedshiftProperty.IAM_AUTH.getName()) == null) { + urlProps.setProperty(RedshiftProperty.IAM_AUTH.getName(), String.valueOf(iamAuth)); + } if (urlServer.startsWith("//")) { urlServer = urlServer.substring(2); diff --git a/src/main/java/com/amazon/redshift/TestOktaDriver.java b/src/main/java/com/amazon/redshift/TestOktaDriver.java new file mode 100644 index 0000000..f1a4da6 --- /dev/null +++ b/src/main/java/com/amazon/redshift/TestOktaDriver.java @@ -0,0 +1,47 @@ +package com.amazon.redshift; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.ResultSet; +import java.sql.Statement; +import java.util.Properties; + +public class TestOktaDriver { + public static void main(String[] args) throws Exception { + // Register the driver + Class.forName("com.amazon.redshift.Driver"); + + // Connection URL + String url = "jdbc:redshift://hubble.chg6aanrjt24.eu-west-1.redshift.amazonaws.com:5439/dev"; + + // Connection properties + Properties props = new Properties(); + props.setProperty("plugin_name", "com.amazon.redshift.plugin.OktaRedshiftPlugin"); + props.setProperty("iamauth", "true"); + + // Plugin parameters + props.setProperty("ssoRoleName", "OktaDataViewer"); + props.setProperty("region", "eu-north-1"); + props.setProperty("ssoStartUrl", "https://d-c3672deb5f.awsapps.com/start"); + props.setProperty("preferred_role", "hubble-rbac/DataViewer"); + props.setProperty("ssoAccountID", "899945594626"); + props.setProperty("clusterid", "hubble"); + + // Test connection + try (Connection conn = DriverManager.getConnection(url, props)) { + System.out.println("Connected successfully!"); + + // Test query + try (Statement stmt = conn.createStatement()) { + ResultSet rs = stmt.executeQuery("SELECT current_user, current_database()"); + if (rs.next()) { + System.out.println("Current user: " + rs.getString(1)); + System.out.println("Current database: " + rs.getString(2)); + } + } + } catch (Exception e) { + System.err.println("Connection failed: " + e.getMessage()); + e.printStackTrace(); + } + } +} diff --git a/src/main/java/com/amazon/redshift/core/IamHelper.java b/src/main/java/com/amazon/redshift/core/IamHelper.java index b05ae36..688e578 100755 --- a/src/main/java/com/amazon/redshift/core/IamHelper.java +++ b/src/main/java/com/amazon/redshift/core/IamHelper.java @@ -36,6 +36,7 @@ import com.amazon.redshift.logger.LogLevel; import com.amazon.redshift.logger.RedshiftLogger; import com.amazon.redshift.plugin.utils.RequestUtils; +import com.amazon.redshift.plugin.OktaRedshiftPlugin; import com.amazon.redshift.util.GT; import com.amazon.redshift.util.RedshiftException; import com.amazon.redshift.util.RedshiftState; @@ -65,6 +66,7 @@ public final class IamHelper extends IdpAuthHelper { public static final int GET_CLUSTER_CREDENTIALS_SAML_V2_API = 3; public static final int GET_CLUSTER_CREDENTIALS_JWT_V2_API = 4; public static final int GET_SERVERLESS_CREDENTIALS_V1_API = 5; + public static final int GET_CLUSTER_CREDENTIALS_PLUGIN_DIRECT = 6; private static final Pattern HOST_PATTERN = Pattern.compile("(.+)\\.(.+)\\.(.+).redshift(-dev)?\\.amazonaws\\.com(.)*"); @@ -598,6 +600,7 @@ else if (RedshiftProperty.DB_GROUPS_FILTER.getName().equalsIgnoreCase(pluginArgK settings.m_idpToken = idpToken; } + } // Group federation API for plugin setClusterCredentials(provider, settings, log, providerType, idpCredentialsRefresh, getClusterCredentialApiType); @@ -682,6 +685,29 @@ private static void setClusterCredentials(AWSCredentialsProvider credProvider, R log.logInfo(now + ": Using GetClusterCredentialsResultV2 with TimeToRefresh " + iamResult.getNextRefreshTime()); } + break; + + case GET_CLUSTER_CREDENTIALS_PLUGIN_DIRECT: + // Plugin directly provides database credentials + if (RedshiftLogger.isEnable()) + log.log(LogLevel.DEBUG, "Using plugin-provided database credentials directly"); + + // Get the credentials from the provider + AWSCredentials pluginCredentials = credProvider.getCredentials(); + if (pluginCredentials instanceof OktaRedshiftPlugin.DatabaseCredentials) { + OktaRedshiftPlugin.DatabaseCredentials dbCreds = (OktaRedshiftPlugin.DatabaseCredentials) pluginCredentials; + settings.m_username = dbCreds.getAWSAccessKeyId(); // username stored in access key field + settings.m_password = dbCreds.getAWSSecretKey(); // password stored in secret key field + + if (RedshiftLogger.isEnable()) { + Date now = new Date(); + log.logInfo(now + ": Using plugin database credentials with expiration " + dbCreds.getExpiration()); + } + } else { + throw new RedshiftException("Expected DatabaseCredentials from plugin but got: " + + pluginCredentials.getClass().getSimpleName(), RedshiftState.UNEXPECTED_ERROR); + } + break; } } @@ -1212,6 +1238,13 @@ static String getCredentialsV2CacheKey(RedshiftJDBCSettings settings, Credential private static int findTypeOfGetClusterCredentialsAPI(RedshiftJDBCSettings settings, CredentialProviderType providerType, AWSCredentialsProvider provider) { + // set the GET_CLUSTER_CREDENTIALS_PLUGIN_DIRECT for OktaRedshiftPlugin that returns DatabaseCredentials directly + if (providerType == CredentialProviderType.PLUGIN && + settings.m_credentialsProvider != null && + settings.m_credentialsProvider.contains("OktaRedshiftPlugin")) { + return GET_CLUSTER_CREDENTIALS_PLUGIN_DIRECT; + } + if (!settings.m_isServerless) { if (!settings.m_groupFederation) diff --git a/src/main/java/com/amazon/redshift/plugin/CommonCredentialsProvider.java b/src/main/java/com/amazon/redshift/plugin/CommonCredentialsProvider.java index f7da6b1..cb27e6d 100644 --- a/src/main/java/com/amazon/redshift/plugin/CommonCredentialsProvider.java +++ b/src/main/java/com/amazon/redshift/plugin/CommonCredentialsProvider.java @@ -22,6 +22,7 @@ import org.apache.commons.logging.LogFactory; import java.io.IOException; +import java.net.URISyntaxException; import java.net.URL; import java.util.Collections; import java.util.Enumeration; @@ -129,7 +130,7 @@ public NativeTokenHolder getCredentials() throws RedshiftException { return credentials; } - protected abstract NativeTokenHolder getAuthToken() throws IOException; + protected abstract NativeTokenHolder getAuthToken() throws IOException, URISyntaxException; @Override public void refresh() throws RedshiftException { diff --git a/src/main/java/com/amazon/redshift/plugin/OktaRedshiftPlugin.java b/src/main/java/com/amazon/redshift/plugin/OktaRedshiftPlugin.java new file mode 100644 index 0000000..114dfbd --- /dev/null +++ b/src/main/java/com/amazon/redshift/plugin/OktaRedshiftPlugin.java @@ -0,0 +1,606 @@ +package com.amazon.redshift.plugin; + +import java.util.Date; +import com.amazon.redshift.IPlugin; +import com.amazon.redshift.NativeTokenHolder; +import com.amazon.redshift.RedshiftProperty; +import com.amazon.redshift.logger.LogLevel; +import com.amazon.redshift.logger.RedshiftLogger; +import com.amazon.redshift.plugin.httpserver.RequestHandler; +import com.amazon.redshift.plugin.httpserver.Server; +import com.amazon.redshift.plugin.utils.RandomStateUtil; +import com.amazonaws.auth.AWSCredentials; +import com.amazonaws.auth.AWSStaticCredentialsProvider; +import com.amazonaws.auth.BasicSessionCredentials; +import com.amazonaws.services.securitytoken.AWSSecurityTokenService; +import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder; +import com.amazonaws.services.securitytoken.model.AssumeRoleRequest; +import com.amazonaws.services.securitytoken.model.AssumeRoleResult; +import com.amazonaws.services.securitytoken.model.Credentials; +import com.amazonaws.services.redshift.AmazonRedshift; +import com.amazonaws.services.redshift.AmazonRedshiftClientBuilder; +import com.amazonaws.services.redshift.model.GetClusterCredentialsRequest; +import com.amazonaws.services.redshift.model.GetClusterCredentialsResult; +import com.amazonaws.services.sso.AWSSSO; +import com.amazonaws.services.sso.model.GetRoleCredentialsRequest; +import com.amazonaws.services.sso.model.GetRoleCredentialsResult; +import com.amazonaws.services.sso.model.RoleCredentials; +import com.amazonaws.services.ssooidc.AWSSSOOIDC; +import com.amazonaws.services.sso.AWSSSOClientBuilder; +import com.amazonaws.services.ssooidc.AWSSSOOIDCClientBuilder; +import com.amazonaws.services.ssooidc.model.*; +import com.amazonaws.util.StringUtils; +import org.apache.http.NameValuePair; +import org.apache.http.client.utils.URIBuilder; + +import static com.amazon.redshift.plugin.utils.ResponseUtils.findParameter; + +import java.awt.*; +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.security.SecureRandom; +import java.time.Duration; +import java.util.*; +import java.util.List; +import java.util.function.Function; + +/** + * OktaRedshiftPlugin - Handles Okta-based authentication for Amazon Redshift connections + * This plugin implements OAuth 2.0 authorization code flow with PKCE for secure authentication + * through AWS SSO OIDC, followed by role assumption to obtain AWS credentials for Redshift access. + */ +public class OktaRedshiftPlugin extends IdpCredentialsProvider implements IPlugin { + + // Variables for SSO authentication configuration + private String ssoRoleName; // AWS SSO role name (e.g., "OktaAdminLogin") + private String redshiftRoleArn; // ARN of the Redshift role to assume + private String ssoRegion; // AWS region for SSO operations + private String ssoStartUrl; // SSO start URL for authentication + private String ssoAccountId; // AWS account ID for SSO operations + + // Variables for Redshift cluster connection + private String clusterId; // Redshift cluster identifier + private String dbName; // Database name + private String dbUser; // Database user name + + // OAuth 2.0 and OIDC configuration constants + private static final String redirectUriBase = "http://127.0.0.1"; // Base URL for OAuth redirect + private final int listenPort = 7890; // Local port for OAuth callback + private final String idcClientDisplayName = RedshiftProperty.IDC_CLIENT_DISPLAY_NAME.getDefaultValue(); // Client display name + private static final String idcClientType = "public"; // OAuth client type (public for PKCE) + private static final String idcClientScope = "sso:account:access"; // OAuth scope for SSO account access + private static final String authCodeGrantType = "authorization_code"; // OAuth 2.0 grant type + public final int codeVerifierByteLength = 60; // PKCE code verifier length in bytes + public static final String oauthCsrfStateParameterName = "state"; // OAuth state parameter name + private static final String authCodeParameterName = "code"; // Authorization code parameter name + private final int idcResponseTimeout = 120; // Timeout for user authentication (seconds) + public final long milliSecondMultiplier = 1000L; // Millisecond conversion factor + int defaultIdcTimeoutExpiryInSecs = 1200; // Default token expiry time (seconds) + + // Runtime configuration and client instances + protected String redirectUri; // Complete redirect URI for OAuth flow + AWSSSOOIDC ssoOidcClient; // AWS SSO OIDC client for token operations + + // Cache for storing client registration results to avoid repeated registration + // Cache key format: :: + private static final Map registerClientResultCache = new HashMap(); + + + /** + * Main entry point for obtaining AWS credentials for Redshift connection. + * Orchestrates the OAuth flow and role assumption process. + * + * @return AWSCredentials that can be used to call GetClusterCredentials + */ + @Override + public AWSCredentials getCredentials() { + try { + // Step 1: Get IdC access token through OAuth flow + NativeTokenHolder idcToken = getIdcToken(); + + // Step 2: Always require role assumption for this plugin + if (StringUtils.isNullOrEmpty(redshiftRoleArn)) { + throw new IOException("Redshift role ARN is required but not provided"); + } + + // Step 3: Use IdC token to assume Redshift role and get AWS credentials + return getAwsCredentials(idcToken); + } catch (Exception e) { + if (RedshiftLogger.isEnable()) + m_log.log(LogLevel.ERROR, e, "Error getting AWS credentials"); + throw new RuntimeException("Failed to get AWS credentials", e); + } + } + + @Override + public void refresh() { + // Credentials will be refreshed automatically when getCredentials() is called + } + + // IPlugin interface methods + @Override + public void setLogger(RedshiftLogger log) { + m_log = log; + } + + @Override + public void setGroupFederation(boolean groupFederation) { + // Not used by this plugin + } + + @Override + public String getIdpToken() { + // Not used by regular credential providers + return null; + } + + @Override + public String getCacheKey() { + return getPluginSpecificCacheKey(); + } + + @Override + public int getSubType() { + return 0; // Default subtype + } + + @Override + public String getPluginSpecificCacheKey() { + return String.format("OktaRedshift_%s_%s_%s_%s_%s", + ssoStartUrl != null ? ssoStartUrl : "", + ssoRegion != null ? ssoRegion : "", + ssoAccountId != null ? ssoAccountId : "", + ssoRoleName != null ? ssoRoleName : "", + redshiftRoleArn != null ? redshiftRoleArn : ""); + } + + /** + * Executes the complete OAuth 2.0 authorization code flow with PKCE to obtain an IdC access token. + * + * @return NativeTokenHolder containing the IdC access token + * @throws IOException if the OAuth flow fails + * @throws URISyntaxException if URL construction fails + */ + private NativeTokenHolder getIdcToken() throws IOException, URISyntaxException { + // Validate all required parameters before starting OAuth flow + checkRequiredParameters(); + + // Initialize SSO OIDC client for the specified region + ssoOidcClient = AWSSSOOIDCClientBuilder.standard().withRegion(ssoRegion).build(); + redirectUri = redirectUriBase + ":" + listenPort; + + // Step 1: Register OAuth client or retrieve from cache + RegisterClientResult registerClientResult = getRegisterClientResult(); + + // Step 2: Generate PKCE code verifier and challenge for security + String codeVerifier = generateCodeVerifier(); + String codeChallenge = generateCodeChallenge(codeVerifier); + + // Step 3: Open browser and get authorization code from user + String authCode = fetchAuthorizationCode(codeChallenge, registerClientResult); + + // Step 4: Exchange authorization code for access token + CreateTokenResult createTokenResult = fetchTokenResult(registerClientResult, authCode, codeVerifier); + + // Step 5: Process token result and return wrapped token + return processCreateTokenResult(createTokenResult); + } + + + private void checkRequiredParameters() throws InternalPluginException { + if (StringUtils.isNullOrEmpty(ssoStartUrl)) { + if (RedshiftLogger.isEnable()) + m_log.logDebug("IdC authentication failed: issuer_url needs to be provided in connection params"); + throw new InternalPluginException("IdC authentication failed: The issuer URL must be included in the connection parameters."); + } + if (StringUtils.isNullOrEmpty(ssoRegion)) { + if (RedshiftLogger.isEnable()) + m_log.logDebug("IdC authentication failed: idc_region needs to be provided in connection params"); + throw new InternalPluginException("IdC authentication failed: The IdC region must be included in the connection parameters."); + } + if (StringUtils.isNullOrEmpty(redshiftRoleArn)) { + if (RedshiftLogger.isEnable()) + m_log.logDebug("IdC authentication failed: redshift_role_arn needs to be provided in connection params"); + throw new InternalPluginException("redshift_role_arn is required"); + } + if (StringUtils.isNullOrEmpty(ssoAccountId)) { + if (RedshiftLogger.isEnable()) + m_log.logDebug("IdC authentication failed: ssoAccountID needs to be provided in connection params"); + throw new InternalPluginException("IdC authentication failed: The SSO account ID must be included in the connection parameters."); + } + if (StringUtils.isNullOrEmpty(ssoRoleName)) { + if (RedshiftLogger.isEnable()) + m_log.logDebug("IdC authentication failed: ssoRoleName needs to be provided in connection params"); + throw new InternalPluginException("IdC authentication failed: The SSO role name must be included in the connection parameters."); + } + } + + private RegisterClientResult getRegisterClientResult() throws IOException { + String registerClientCacheKey = redirectUri + ":" + ssoRegion + ":" + listenPort; + RegisterClientResult cachedRegisterClientResult = registerClientResultCache.get(registerClientCacheKey); + + if (isCachedRegisteredClientValid(cachedRegisterClientResult)) { + if (RedshiftLogger.isEnable()) { + m_log.logInfo("Using cached client result"); + m_log.logInfo("Cached client result expires in " + cachedRegisterClientResult.getClientSecretExpiresAt()); + } + return cachedRegisterClientResult; + } + + RegisterClientRequest registerClientRequest = new RegisterClientRequest(); + registerClientRequest.withClientName(idcClientDisplayName); + registerClientRequest.withClientType(idcClientType); + registerClientRequest.withScopes(idcClientScope); + registerClientRequest.withIssuerUrl(ssoStartUrl); + registerClientRequest.withRedirectUris(redirectUri); + registerClientRequest.withGrantTypes(authCodeGrantType); + + RegisterClientResult registerClientResult = null; + + try { + registerClientResult = ssoOidcClient.registerClient(registerClientRequest); + if (RedshiftLogger.isEnable()) { + m_log.logInfo("Register client response code {0}", registerClientResult.getSdkHttpMetadata().getHttpStatusCode()); + } + } catch (InternalServerException e) { + if (RedshiftLogger.isEnable()) { + m_log.log(LogLevel.ERROR, e, "Idc authentication failed: Error during the request"); + } + throw new IOException("Idc authentication failed"); + } catch (Exception e) { + if (RedshiftLogger.isEnable()) { + m_log.log(LogLevel.ERROR, e, "Error while registering client"); + } + throw new IOException("IdC registration failed"); + } + + registerClientResultCache.put(registerClientCacheKey, registerClientResult); + if (RedshiftLogger.isEnable()) { + m_log.logInfo("Cached the register client result, expires at {0}", registerClientResult.getClientSecretExpiresAt()); + } + + return registerClientResult; + } + + + private CreateTokenResult fetchTokenResult(RegisterClientResult registerClientResult, String authCode, String codeVerifier) throws IOException { + long pollingEndtime = System.currentTimeMillis() + idcResponseTimeout * milliSecondMultiplier; + + int pollingIntervalInSec = 1; + + while (System.currentTimeMillis() < pollingEndtime) { + try { + CreateTokenRequest createTokenRequest = new CreateTokenRequest(); + createTokenRequest.withClientId(registerClientResult.getClientId()) + .withClientSecret(registerClientResult.getClientSecret()) + .withCode(authCode) + .withGrantType(authCodeGrantType) + .withCodeVerifier(codeVerifier) + .withRedirectUri(redirectUri); + + CreateTokenResult createTokenResult = ssoOidcClient.createToken(createTokenRequest); + + if (RedshiftLogger.isEnable() && registerClientResult.getSdkHttpMetadata() != null) + m_log.logDebug("Token response received"); + + if (createTokenResult != null && createTokenResult.getAccessToken() != null) { + return createTokenResult; + } else { + if (RedshiftLogger.isEnable()) m_log.logError("Failed to get IdC Token"); + throw new IOException("IdC authentication failed: Failed to get IdC Token"); + } + } catch (AuthorizationPendingException ex) { + if (RedshiftLogger.isEnable()) m_log.logDebug("Browser authorization pending from user"); + } catch (SlowDownException ex) { + if (RedshiftLogger.isEnable()) + m_log.log(LogLevel.ERROR, ex, "Error: Too frequent createToken requests made by client;"); + throw new IOException("IdC authentication failed : Requests to the IdC service are too frequent.", ex); + } catch (AccessDeniedException ex) { + if (RedshiftLogger.isEnable()) + m_log.log(LogLevel.ERROR, ex, "Error: Access denied, please ensure app assignment is done for the user;"); + throw new IOException("IdC authentication failed : You don't have sufficient permission to perform the action. Please ensure app assignment is done for the user.", ex); + } catch (InternalServerException ex) { + if (RedshiftLogger.isEnable()) m_log.log(LogLevel.ERROR, ex, "Error: Server error in creating token;"); + throw new IOException("IdC authentication failed : An error occurred during the request.", ex); + } catch (Exception ex) { + if (RedshiftLogger.isEnable()) + m_log.log(LogLevel.ERROR, ex, "Error: Unexpected error in create token;"); + throw new IOException("IdC createToken failed : There was an error during the request.", ex); + } + } + + try { + Thread.sleep(pollingIntervalInSec * milliSecondMultiplier); + } catch (InterruptedException e) { + if (RedshiftLogger.isEnable()) m_log.log(LogLevel.ERROR, e, "Thread interrupted during sleep"); + } + + if (RedshiftLogger.isEnable()) + m_log.logError("Error: Request timed out while waiting for user authentication in the browser"); + throw new IOException("IdC authentication failed : The request timed out. Authentication wasn't completed."); + } + + private NativeTokenHolder processCreateTokenResult(CreateTokenResult createTokenResult) { + String idcToken = createTokenResult.getAccessToken(); + + if (StringUtils.isNullOrEmpty((idcToken))) { + throw new InternalPluginException("Returned token result is null or empty"); + } + + int expiresInSec = defaultIdcTimeoutExpiryInSecs; + + if (createTokenResult.getExpiresIn() != null && createTokenResult.getExpiresIn() > 0) { + expiresInSec = createTokenResult.getExpiresIn(); + } + Date expiresInDate = new Date(System.currentTimeMillis() + expiresInSec * milliSecondMultiplier); + if (RedshiftLogger.isEnable()) m_log.logDebug("Token expires at {0}", expiresInDate); + + return NativeTokenHolder.newInstance(idcToken, expiresInDate); + } + + + private AWSCredentials getAwsCredentials(NativeTokenHolder idcToken) throws IOException { + // Get SSO role credentials + AWSSSO sso = AWSSSOClientBuilder.standard().withRegion(ssoRegion).build(); + GetRoleCredentialsRequest getRoleRequest = new GetRoleCredentialsRequest() + .withAccessToken(idcToken.getAccessToken()) + .withAccountId(ssoAccountId) + .withRoleName(ssoRoleName); + + GetRoleCredentialsResult roleCredentialsResult = sso.getRoleCredentials(getRoleRequest); + RoleCredentials roleCredentials = roleCredentialsResult.getRoleCredentials(); + + // Create session credentials from SSO role + BasicSessionCredentials sessionCredentials = new BasicSessionCredentials( + roleCredentials.getAccessKeyId(), + roleCredentials.getSecretAccessKey(), + roleCredentials.getSessionToken()); + + // Assume the preferred role using SSO credentials + String roleArn = redshiftRoleArn; + if (!redshiftRoleArn.startsWith("arn:aws:iam::")) { + roleArn = "arn:aws:iam::" + ssoAccountId + ":role/" + redshiftRoleArn; + } + + AWSSecurityTokenService awsSTS = AWSSecurityTokenServiceClientBuilder.standard() + .withCredentials(new AWSStaticCredentialsProvider(sessionCredentials)) + .withRegion(ssoRegion) + .build(); + + AssumeRoleRequest assumeRoleRequest = new AssumeRoleRequest() + .withRoleArn(roleArn) + .withRoleSessionName("redshift-okta-" + java.util.UUID.randomUUID()) + .withDurationSeconds(3600); + + AssumeRoleResult assumeRoleResult = awsSTS.assumeRole(assumeRoleRequest); + Credentials stsCredential = assumeRoleResult.getCredentials(); + + // Create credentials for the assumed role + BasicSessionCredentials redshiftRoleCredentials = new BasicSessionCredentials( + stsCredential.getAccessKeyId(), + stsCredential.getSecretAccessKey(), + stsCredential.getSessionToken()); + + // Call GetClusterCredentials using the assumed role credentials + AmazonRedshift redshiftClient = AmazonRedshiftClientBuilder.standard() + .withCredentials(new AWSStaticCredentialsProvider(redshiftRoleCredentials)) + .withRegion("eu-west-1") + .build(); + + // Extract database user from preferred_role + String dbUserName = this.redshiftRoleArn; + if (dbUserName != null && dbUserName.contains("/")) { + dbUserName = dbUserName.substring(dbUserName.lastIndexOf("/") + 1).toLowerCase(); + } + if (dbUserName == null || dbUserName.isEmpty()) { + dbUserName = "redshift_user"; // fallback + } + + GetClusterCredentialsRequest clusterCredentialsRequest = new GetClusterCredentialsRequest() + .withClusterIdentifier(clusterId) + .withDbName(dbName) + .withDbUser(dbUserName) + .withDurationSeconds(3600); + + if (RedshiftLogger.isEnable()) { + m_log.log(LogLevel.DEBUG, "Calling GetClusterCredentials for cluster: {0}, db: {1}, user: {2}", + clusterId, dbName, dbUserName); + } + + GetClusterCredentialsResult clusterCredentialsResult = redshiftClient.getClusterCredentials(clusterCredentialsRequest); + + // Return database credentials + return new DatabaseCredentials( + clusterCredentialsResult.getDbUser(), + clusterCredentialsResult.getDbPassword(), + clusterCredentialsResult.getExpiration() + ); + } + + protected String generateCodeVerifier() { + byte[] randomBytes = new byte[codeVerifierByteLength]; + SecureRandom secureRandom = new SecureRandom(); + secureRandom.nextBytes(randomBytes); + + return Base64.getUrlEncoder().withoutPadding().encodeToString(randomBytes); + } + + private String generateCodeChallenge(String codeVerifier) { + byte[] sha256Hash = sha256(codeVerifier.getBytes(StandardCharsets.US_ASCII)); + + return Base64.getUrlEncoder().withoutPadding().encodeToString(sha256Hash); + } + + private byte[] sha256(byte[] input) { + try { + MessageDigest digest = MessageDigest.getInstance("SHA-256"); + return digest.digest(input); + } catch (NoSuchAlgorithmException e) { + if (RedshiftLogger.isEnable()) m_log.log(LogLevel.ERROR, e, "Thread interrupted during sleep"); + return null; + } + } + + private String fetchAuthorizationCode(String codeChallenge, RegisterClientResult registerClientResult) throws URISyntaxException, IOException { + final String state = RandomStateUtil.generateRandomState(); + RequestHandler requestHandler = new RequestHandler(new Function, Object>() { + public Object apply(List nameValuePairs) { + String incomingState = findParameter(oauthCsrfStateParameterName, nameValuePairs); + + if (!state.equals(incomingState)) { + String stateErrorMessage = "Incoming state" + incomingState + " does not match the outgoing state" + state; + if (RedshiftLogger.isEnable()) m_log.log(LogLevel.DEBUG, stateErrorMessage); + throw new InternalPluginException(stateErrorMessage); + } + String code = findParameter(authCodeParameterName, nameValuePairs); + if (StringUtils.isNullOrEmpty(code)) { + String stateErrorMessage = "No Valid code found"; + if (RedshiftLogger.isEnable()) m_log.log(LogLevel.DEBUG, stateErrorMessage); + throw new InternalPluginException(stateErrorMessage); + } + return code; + } + }); + + Server server = new Server(listenPort, requestHandler, Duration.ofSeconds(idcResponseTimeout), m_log); + try { + server.listen(); + if (RedshiftLogger.isEnable()) m_log.log(LogLevel.DEBUG, "Listening for connection on port " + listenPort); + + openBrowser(state, codeChallenge, registerClientResult); + server.waitForResult(); + } catch (URISyntaxException | IOException ex) { + if (RedshiftLogger.isEnable()) m_log.logError(ex); + + server.stop(); + throw ex; + } + + Object result = requestHandler.getResult(); + + if (result instanceof InternalPluginException) { + if (RedshiftLogger.isEnable()) { + m_log.logDebug("Error while fetching authorization token"); + } + throw (InternalPluginException) result; + } + if (result instanceof String) { + if (RedshiftLogger.isEnable()) { + m_log.logInfo("Fetched authorization token"); + } + return (String) result; + } + throw new InternalPluginException("Error fetching authentication code from browser. Failed to login during timeout."); + } + + private void openBrowser(String state, String codeChallenge, RegisterClientResult registerClientResult) throws URISyntaxException, IOException { + String idcHost = "oidc" + "." + ssoRegion + "." + "amazonaws.com"; + + URIBuilder builder = new URIBuilder().setScheme("https") + .setHost(idcHost) + .setPath("/authorize") + .addParameter("response_type", authCodeParameterName) + .addParameter("client_id", registerClientResult.getClientId()) + .addParameter("redirect_uri", redirectUri) + .addParameter("scopes", idcClientScope) + .addParameter(oauthCsrfStateParameterName, state) + .addParameter("code_challenge", codeChallenge) + .addParameter("code_challenge_method", "S256"); + + // Add account ID to scope the token to the specific account + if (!StringUtils.isNullOrEmpty(ssoAccountId)) { + builder.addParameter("account_id", ssoAccountId); + } + + URI authorizeRequestUrl; + authorizeRequestUrl = builder.build(); + + validateURL(authorizeRequestUrl.toString()); + + if (Desktop.isDesktopSupported() && Desktop.getDesktop().isSupported(Desktop.Action.BROWSE)) { + Desktop.getDesktop().browse(authorizeRequestUrl); + } else { + m_log.logError("Unable to open the browser. Desktop environment is not supported"); + } + + if (RedshiftLogger.isEnable()) + m_log.logDebug("Authorization code request URI: \n%s", authorizeRequestUrl.toString()); + + } + + private boolean isCachedRegisteredClientValid(RegisterClientResult cachedRegisterClientResult) { + if (cachedRegisterClientResult == null || cachedRegisterClientResult.getClientSecretExpiresAt() == null) { + return false; + } + + return System.currentTimeMillis() < cachedRegisterClientResult.getClientSecretExpiresAt() * 1000; + } + + @Override + public void addParameter(String key, String value) { + if ("ssorolename".equalsIgnoreCase(key)) { + this.ssoRoleName = value; + } else if ("preferred_role".equalsIgnoreCase(key)) { + this.redshiftRoleArn = value; + } else if ("region".equalsIgnoreCase(key)) { + this.ssoRegion = value; + } else if ("ssostarturl".equalsIgnoreCase(key)) { + this.ssoStartUrl = value; + } else if ("ssoaccountid".equalsIgnoreCase(key)) { + this.ssoAccountId = value; + } else if ("clusterid".equalsIgnoreCase(key)) { + this.clusterId = value; + } else if ("dbname".equalsIgnoreCase(key)) { + this.dbName = value; + } else if ("dbuser".equalsIgnoreCase(key)) { + this.dbUser = value; + } + } + public static void main(String[] args) throws Exception { + // String profileName = "aws-sso-LunarWay-Development-Data-OktaDataLogin"; + // String profileName = "aws-sso-LunarWay-Development-Data-OktaAdminLogin"; + // why is this not set in .aws/config + // "aws-sso-LunarWay-Production-Data-OktaDataViewer"; + + OktaRedshiftPlugin plugin = new OktaRedshiftPlugin(); + plugin.addParameter("ssoRoleName", "OktaDataViewer"); + plugin.addParameter("region", "eu-north-1"); + plugin.addParameter("ssoStartUrl", "https://d-c3672deb5f.awsapps.com/start"); + plugin.addParameter("preferred_role", "hubble-rbac/DataViewer"); + // arn:aws:iam::899945594626:role/hubble-rbac/DataViewer + plugin.addParameter("ssoAccountID", "899945594626"); + + AWSCredentials creds = plugin.getCredentials(); + + System.out.println(creds.getAWSAccessKeyId()); + } + + /** + * Simple holder for database credentials (username/password) + */ + public static class DatabaseCredentials implements AWSCredentials { + private final String username; + private final String password; + private final Date expiration; + + public DatabaseCredentials(String username, String password, Date expiration) { + this.username = username; + this.password = password; + this.expiration = expiration; + } + + public String getUsername() { return username; } + public String getPassword() { return password; } + public Date getExpiration() { return expiration; } + + // AWSCredentials interface - store username/password in these fields + @Override public String getAWSAccessKeyId() { return username; } + @Override public String getAWSSecretKey() { return password; } + } + +} + diff --git a/src/main/java/com/amazon/redshift/util/RedshiftConstants.java b/src/main/java/com/amazon/redshift/util/RedshiftConstants.java index 1cffe3d..797d17c 100644 --- a/src/main/java/com/amazon/redshift/util/RedshiftConstants.java +++ b/src/main/java/com/amazon/redshift/util/RedshiftConstants.java @@ -31,5 +31,6 @@ private RedshiftConstants() { public static final String NATIVE_IDP_OKTA_NON_BROWSER_PLUGIN = "com.amazon.redshift.plugin.BasicNativeSamlCredentialsProvider"; public static final String IDP_TOKEN_PLUGIN = "com.amazon.redshift.plugin.IdpTokenAuthPlugin"; public static final String IDC_PKCE_BROWSER_PLUGIN = "com.amazon.redshift.plugin.BrowserIdcAuthPlugin"; + public static final String IDC_PKCE_BROWSER_OKTA_PLUGIN = "com.amazon.redshift.plugin.OktaRedshiftPlugin"; } diff --git a/src/main/java/com/amazon/redshift/util/StreamWrapper.java b/src/main/java/com/amazon/redshift/util/StreamWrapper.java index 93a2932..dafe2bf 100644 --- a/src/main/java/com/amazon/redshift/util/StreamWrapper.java +++ b/src/main/java/com/amazon/redshift/util/StreamWrapper.java @@ -112,13 +112,6 @@ public void close() throws IOException { closed = true; } } - - protected void finalize() throws IOException { - // forcibly close it because super.finalize() may keep the FD open, which may prevent - // file deletion - close(); - super.finalize(); - } }; } else { this.rawData = rawData;