From 1d9da3c56b8b3876cb249eaf3a533f3a8541cb0d Mon Sep 17 00:00:00 2001 From: "Xolo, Tlatoani" Date: Tue, 4 Nov 2025 11:17:43 -0500 Subject: [PATCH 1/3] updating based on PR comments version update adds support for mongo-oidc --- build.gradle | 11 +- .../java/com/dbschema/MongoJdbcDriver.java | 17 +- .../mongo/DriverPropertyInfoHelper.java | 2 +- .../dbschema/mongo/MongoClientWrapper.java | 168 ++++++---- .../com/dbschema/mongo/MongoConnection.java | 8 +- .../java/com/dbschema/mongo/MongoService.java | 5 +- .../com/dbschema/mongo/oidc/OidcAuthFlow.java | 305 ++++++++++++++++++ .../com/dbschema/mongo/oidc/OidcCallback.java | 42 +++ .../com/dbschema/mongo/oidc/OidcResponse.java | 58 ++++ .../mongo/oidc/OidcTimeoutException.java | 7 + .../java/com/dbschema/mongo/oidc/Server.java | 173 ++++++++++ 11 files changed, 720 insertions(+), 76 deletions(-) create mode 100644 src/main/java/com/dbschema/mongo/oidc/OidcAuthFlow.java create mode 100755 src/main/java/com/dbschema/mongo/oidc/OidcCallback.java create mode 100644 src/main/java/com/dbschema/mongo/oidc/OidcResponse.java create mode 100644 src/main/java/com/dbschema/mongo/oidc/OidcTimeoutException.java create mode 100644 src/main/java/com/dbschema/mongo/oidc/Server.java diff --git a/build.gradle b/build.gradle index 4844405..5bb8fd2 100644 --- a/build.gradle +++ b/build.gradle @@ -4,7 +4,7 @@ plugins { id "com.github.johnrengelman.shadow" version "7.0.0" } -version '1.20' +version '1.30' repositories { mavenCentral() @@ -12,11 +12,12 @@ repositories { dependencies { implementation "org.jetbrains.kotlin:kotlin-stdlib-jdk8" - implementation "org.mongodb:mongodb-driver-sync:4.11.1" + implementation "org.mongodb:mongodb-driver-sync:5.6.1" implementation group: 'org.jetbrains', name: 'annotations', version: '15.0' implementation group: 'org.apache.commons', name: 'commons-text', version: '1.10.0' implementation group: 'org.graalvm.js', name: 'js', version: '22.3.1' implementation files('libs/JMongosh-0.9.jar') + implementation group: 'com.nimbusds', name: 'oauth2-oidc-sdk', version: '11.+' testImplementation group: 'junit', name: 'junit', version: '4.13.1' testImplementation group: 'commons-io', name: 'commons-io', version: '2.7' } @@ -29,4 +30,10 @@ test { shadowJar { archiveFileName = "mongo-jdbc-standalone-${version}.jar" mergeServiceFiles() + + relocate('org', 'shadow.org') { + exclude 'org.ow2.asm:.*' + exclude 'net.minidev:.*' + exclude 'org.javassist:.*' + } } diff --git a/src/main/java/com/dbschema/MongoJdbcDriver.java b/src/main/java/com/dbschema/MongoJdbcDriver.java index 819aeac..3ad6223 100644 --- a/src/main/java/com/dbschema/MongoJdbcDriver.java +++ b/src/main/java/com/dbschema/MongoJdbcDriver.java @@ -1,15 +1,20 @@ package com.dbschema; import com.dbschema.mongo.DriverPropertyInfoHelper; +import com.dbschema.mongo.MongoClientWrapper; import com.dbschema.mongo.MongoConnection; +import com.dbschema.mongo.MongoService; import com.dbschema.mongo.mongosh.LazyShellHolder; import com.dbschema.mongo.mongosh.PrecalculatingShellHolder; import com.dbschema.mongo.mongosh.ShellHolder; +import com.dbschema.mongo.oidc.OidcCallback; +import com.mongodb.MongoCredential; import org.graalvm.polyglot.Engine; import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; import java.sql.*; +import java.util.Optional; import java.util.Properties; import java.util.concurrent.ExecutorService; import java.util.logging.Logger; @@ -32,6 +37,7 @@ public class MongoJdbcDriver implements Driver { private @Nullable ExecutorService executorService; private @Nullable Engine sharedEngine; private @NotNull ShellHolder shellHolder; + private MongoConnection mongoConnection; static { try { @@ -91,7 +97,16 @@ public Connection connect(String url, Properties info) throws SQLException { synchronized (this) { ShellHolder shellHolder = this.shellHolder; this.shellHolder = createShellHolder(); - return new MongoConnection(url, info, username, password, fetchDocumentsForMeta, shellHolder); + MongoCredential.OidcCallbackContext existingResult = Optional.ofNullable(this.mongoConnection) + .map(MongoConnection::getService) + .map(MongoService::getClient) + .map(MongoClientWrapper::getOidcCallback) + .map(OidcCallback::getCallbackContext) + .orElse(null); + + this.mongoConnection = new MongoConnection(url, info, username, password, fetchDocumentsForMeta, shellHolder, existingResult); + + return this.mongoConnection; } } diff --git a/src/main/java/com/dbschema/mongo/DriverPropertyInfoHelper.java b/src/main/java/com/dbschema/mongo/DriverPropertyInfoHelper.java index fa17d22..907af6f 100644 --- a/src/main/java/com/dbschema/mongo/DriverPropertyInfoHelper.java +++ b/src/main/java/com/dbschema/mongo/DriverPropertyInfoHelper.java @@ -5,7 +5,7 @@ public class DriverPropertyInfoHelper { public static final String AUTH_MECHANISM = "authMechanism"; - public static final String[] AUTH_MECHANISM_CHOICES = new String[]{"GSSAPI", "MONGODB-AWS", "MONGODB-X509", "PLAIN", "SCRAM-SHA-1", "SCRAM-SHA-256"}; + public static final String[] AUTH_MECHANISM_CHOICES = new String[]{"GSSAPI", "MONGODB-AWS", "MONGODB-X509", "PLAIN", "SCRAM-SHA-1", "SCRAM-SHA-256", "MONGODB-OIDC"}; public static final String AUTH_SOURCE = "authSource"; public static final String AWS_SESSION_TOKEN = "AWS_SESSION_TOKEN"; public static final String SERVICE_NAME = "SERVICE_NAME"; diff --git a/src/main/java/com/dbschema/mongo/MongoClientWrapper.java b/src/main/java/com/dbschema/mongo/MongoClientWrapper.java index 3206c59..cc4f092 100644 --- a/src/main/java/com/dbschema/mongo/MongoClientWrapper.java +++ b/src/main/java/com/dbschema/mongo/MongoClientWrapper.java @@ -1,7 +1,11 @@ package com.dbschema.mongo; +import com.dbschema.mongo.oidc.OidcCallback; import com.mongodb.ConnectionString; import com.mongodb.MongoClientSettings; +import com.mongodb.MongoCredential; +import com.mongodb.ServerApi; +import com.mongodb.ServerApiVersion; import com.mongodb.client.MongoClient; import com.mongodb.client.MongoClients; import com.mongodb.client.MongoDatabase; @@ -26,79 +30,103 @@ public class MongoClientWrapper implements AutoCloseable { private boolean isClosed = false; private final MongoClient mongoClient; public final String databaseNameFromUrl; + public final OidcCallback oidcCallback; - public MongoClientWrapper(@NotNull String uri, @NotNull Properties prop, @Nullable String username, @Nullable String password) throws SQLException { + public MongoClientWrapper(@NotNull String uri, @NotNull Properties prop, @Nullable String username, @Nullable String password, @Nullable MongoCredential.OidcCallbackContext callbackContext) throws SQLException { + this.oidcCallback = new OidcCallback(callbackContext); try { - boolean automaticEncoding = ENCODE_CREDENTIALS_DEFAULT; - if (prop.getProperty(ENCODE_CREDENTIALS) != null) { - automaticEncoding = Boolean.parseBoolean(prop.getProperty(ENCODE_CREDENTIALS)); - } + boolean automaticEncoding = ENCODE_CREDENTIALS_DEFAULT; + if (prop.getProperty(ENCODE_CREDENTIALS) != null) { + automaticEncoding = Boolean.parseBoolean(prop.getProperty(ENCODE_CREDENTIALS)); + } - uri = insertCredentials(uri, username, password, automaticEncoding); - uri = insertAuthMechanism(uri, prop.getProperty(AUTH_MECHANISM)); - uri = insertAuthSource(uri, prop.getProperty(AUTH_SOURCE)); - uri = insertAuthProperty(uri, AWS_SESSION_TOKEN, prop.getProperty(AWS_SESSION_TOKEN)); - uri = insertAuthProperty(uri, SERVICE_NAME, prop.getProperty(SERVICE_NAME)); - uri = insertAuthProperty(uri, SERVICE_REALM, prop.getProperty(SERVICE_REALM)); - String canonicalizeHostName = prop.getProperty(CANONICALIZE_HOST_NAME); - if (Boolean.TRUE.toString().equalsIgnoreCase(canonicalizeHostName) || Boolean.FALSE.toString().equalsIgnoreCase(canonicalizeHostName)) { - uri = insertAuthProperty(uri, CANONICALIZE_HOST_NAME, canonicalizeHostName); - } - else if (canonicalizeHostName != null) { - System.err.println("Unknown " + CANONICALIZE_HOST_NAME + " value. Must be true or false."); - } - uri = insertRetryWrites(uri, prop.getProperty(RETRY_WRITES)); + uri = insertCredentials(uri, username, password, automaticEncoding); + uri = insertAuthMechanism(uri, prop.getProperty(AUTH_MECHANISM)); + uri = insertAuthSource(uri, prop.getProperty(AUTH_SOURCE)); + uri = insertAuthProperty(uri, AWS_SESSION_TOKEN, prop.getProperty(AWS_SESSION_TOKEN)); + uri = insertAuthProperty(uri, SERVICE_NAME, prop.getProperty(SERVICE_NAME)); + uri = insertAuthProperty(uri, SERVICE_REALM, prop.getProperty(SERVICE_REALM)); + String canonicalizeHostName = prop.getProperty(CANONICALIZE_HOST_NAME); + if (Boolean.TRUE.toString().equalsIgnoreCase(canonicalizeHostName) || Boolean.FALSE.toString().equalsIgnoreCase(canonicalizeHostName)) { + uri = insertAuthProperty(uri, CANONICALIZE_HOST_NAME, canonicalizeHostName); + } + else if (canonicalizeHostName != null) { + System.err.println("Unknown " + CANONICALIZE_HOST_NAME + " value. Must be true or false."); + } + uri = insertRetryWrites(uri, prop.getProperty(RETRY_WRITES)); + + + // Construct a ServerApi instance using the ServerApi.builder() method + ServerApi serverApi = ServerApi.builder() + .version(ServerApiVersion.V1) + .build(); ConnectionString connectionString = new ConnectionString(uri); - databaseNameFromUrl = connectionString.getDatabase(); - int maxPoolSize = getMaxPoolSize(prop); - MongoClientSettings.Builder builder = MongoClientSettings.builder() - .applyConnectionString(connectionString) - .applyToConnectionPoolSettings(b -> b.maxSize(maxPoolSize)); - String application = prop.getProperty(APPLICATION_NAME); - if (!isNullOrEmpty(application)) { - builder.applicationName(application); - } - if ("true".equals(prop.getProperty("ssl"))) { - boolean allowInvalidCertificates = uri.contains("tlsAllowInvalidCertificates=true") || uri.contains("sslAllowInvalidCertificates=true") - || isTrue(prop.getProperty(ALLOW_INVALID_CERTIFICATES, Boolean.toString(ALLOW_INVALID_CERTIFICATES_DEFAULT))); - builder.applyToSslSettings(s -> { - s.enabled(true); - boolean allowInvalidHostnames = isTrue(prop.getProperty(ALLOW_INVALID_HOSTNAMES, Boolean.toString(ALLOW_INVALID_HOSTNAMES_DEFAULT))); - if (allowInvalidHostnames) s.invalidHostNameAllowed(true); - if (allowInvalidCertificates) { - String keyStoreType = System.getProperty("javax.net.ssl.keyStoreType", KeyStore.getDefaultType()); - String keyStorePassword = System.getProperty("javax.net.ssl.keyStorePassword", ""); - String keyStoreUrl = System.getProperty("javax.net.ssl.keyStore", ""); - // check keyStoreUrl - if (!isNullOrEmpty(keyStoreUrl)) { - try { - new URL(keyStoreUrl); - } catch (MalformedURLException e) { - keyStoreUrl = "file:" + keyStoreUrl; - } - } - try { - s.context(getTrustEverybodySSLContext(keyStoreUrl, keyStoreType, keyStorePassword)); - } - catch (SSLUtil.SSLParamsException e) { - throw new RuntimeException(e); - } - } - }); - } - if (connectionString.getUuidRepresentation() == null) { - String uuidRepresentation = prop.getProperty(UUID_REPRESENTATION, UUID_REPRESENTATION_DEFAULT); - builder.uuidRepresentation(createUuidRepresentation(uuidRepresentation)); - } - if (connectionString.getServerSelectionTimeout() == null) { - int timeout = Integer.parseInt(prop.getProperty(SERVER_SELECTION_TIMEOUT, SERVER_SELECTION_TIMEOUT_DEFAULT)); - builder.applyToClusterSettings(b -> b.serverSelectionTimeout(timeout, TimeUnit.MILLISECONDS)); - } - if (connectionString.getConnectTimeout() == null) { - int timeout = Integer.parseInt(prop.getProperty(CONNECT_TIMEOUT, CONNECT_TIMEOUT_DEFAULT)); - builder.applyToSocketSettings(b -> b.connectTimeout(timeout, TimeUnit.MILLISECONDS)); - } + + MongoCredential credential; + + credential = + MongoCredential.createOidcCredential( + connectionString.getUsername()) + .withMechanismProperty( + MongoCredential.OIDC_HUMAN_CALLBACK_KEY, oidcCallback); + + + databaseNameFromUrl = connectionString.getDatabase(); + MongoClientSettings.Builder builder = MongoClientSettings.builder() + .applyConnectionString(connectionString) + .serverApi(serverApi) + .credential(credential) + .uuidRepresentation(createUuidRepresentation(prop.getProperty(UUID_REPRESENTATION, UUID_REPRESENTATION_DEFAULT))) + .applyToConnectionPoolSettings(b -> b.maxSize(getMaxPoolSize(prop))) + + .applyToSocketSettings(b -> b.connectTimeout(Integer.parseInt(prop.getProperty(CONNECT_TIMEOUT, CONNECT_TIMEOUT_DEFAULT)), TimeUnit.MILLISECONDS)); + + String application = prop.getProperty(APPLICATION_NAME); + if (!isNullOrEmpty(application)) { + builder.applicationName(application); + } + if ("true".equals(prop.getProperty("ssl"))) { + boolean allowInvalidCertificates = uri.contains("tlsAllowInvalidCertificates=true") || uri.contains("sslAllowInvalidCertificates=true") + || isTrue(prop.getProperty(ALLOW_INVALID_CERTIFICATES, Boolean.toString(ALLOW_INVALID_CERTIFICATES_DEFAULT))); + builder.applyToSslSettings(s -> { + s.enabled(true); + boolean allowInvalidHostnames = isTrue(prop.getProperty(ALLOW_INVALID_HOSTNAMES, Boolean.toString(ALLOW_INVALID_HOSTNAMES_DEFAULT))); + if (allowInvalidHostnames) s.invalidHostNameAllowed(true); + if (allowInvalidCertificates) { + String keyStoreType = System.getProperty("javax.net.ssl.keyStoreType", KeyStore.getDefaultType()); + String keyStorePassword = System.getProperty("javax.net.ssl.keyStorePassword", ""); + String keyStoreUrl = System.getProperty("javax.net.ssl.keyStore", ""); + // check keyStoreUrl + if (!isNullOrEmpty(keyStoreUrl)) { + try { + new URL(keyStoreUrl); + } catch (MalformedURLException e) { + keyStoreUrl = "file:" + keyStoreUrl; + } + } + try { + s.context(getTrustEverybodySSLContext(keyStoreUrl, keyStoreType, keyStorePassword)); + } + catch (SSLUtil.SSLParamsException e) { + throw new RuntimeException(e); + } + } + }); + } + if (connectionString.getUuidRepresentation() == null) { + String uuidRepresentation = prop.getProperty(UUID_REPRESENTATION, UUID_REPRESENTATION_DEFAULT); + builder.uuidRepresentation(createUuidRepresentation(uuidRepresentation)); + } + if (connectionString.getServerSelectionTimeout() == null) { + int timeout = Integer.parseInt(prop.getProperty(SERVER_SELECTION_TIMEOUT, SERVER_SELECTION_TIMEOUT_DEFAULT)); + builder.applyToClusterSettings(b -> b.serverSelectionTimeout(timeout, TimeUnit.MILLISECONDS)); + } + if (connectionString.getConnectTimeout() == null) { + int timeout = Integer.parseInt(prop.getProperty(CONNECT_TIMEOUT, CONNECT_TIMEOUT_DEFAULT)); + builder.applyToSocketSettings(b -> b.connectTimeout(timeout, TimeUnit.MILLISECONDS)); + } + this.mongoClient = MongoClients.create(builder.build()); } catch (Exception e) { @@ -160,6 +188,10 @@ public MongoDatabase getDatabase(String databaseName) throws SQLAlreadyClosedExc return mongoClient.getDatabase(databaseName); } + public OidcCallback getOidcCallback() { + return this.oidcCallback; + } + @NotNull public MongoClient getMongoClient() { return mongoClient; diff --git a/src/main/java/com/dbschema/mongo/MongoConnection.java b/src/main/java/com/dbschema/mongo/MongoConnection.java index 9cec94a..6cc3447 100644 --- a/src/main/java/com/dbschema/mongo/MongoConnection.java +++ b/src/main/java/com/dbschema/mongo/MongoConnection.java @@ -3,6 +3,7 @@ import com.dbschema.mongo.mongosh.MongoshScriptEngine; import com.dbschema.mongo.mongosh.PrecalculatingShellHolder; import com.dbschema.mongo.mongosh.ShellHolder; +import com.mongodb.MongoCredential; import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; @@ -17,14 +18,17 @@ public class MongoConnection implements Connection { private String schema; private boolean isClosed = false; private boolean isReadOnly = false; + private MongoCredential.OidcCallbackContext oidcCallbackContext; public MongoConnection(@NotNull String url, @NotNull Properties info, @Nullable String username, @Nullable String password, int fetchDocumentsForMeta, - @NotNull ShellHolder shellHolder) throws SQLException { - this.service = new MongoService(url, info, username, password, fetchDocumentsForMeta); + @NotNull ShellHolder shellHolder, + @Nullable MongoCredential.OidcCallbackContext callbackContext) throws SQLException { + this.oidcCallbackContext = callbackContext; + this.service = new MongoService(url, info, username, password, fetchDocumentsForMeta, this.oidcCallbackContext); this.scriptEngine = new MongoshScriptEngine(this, shellHolder); try { setSchema(service.getDatabaseNameFromUrl()); diff --git a/src/main/java/com/dbschema/mongo/MongoService.java b/src/main/java/com/dbschema/mongo/MongoService.java index 4513080..c301695 100644 --- a/src/main/java/com/dbschema/mongo/MongoService.java +++ b/src/main/java/com/dbschema/mongo/MongoService.java @@ -1,6 +1,7 @@ package com.dbschema.mongo; import com.dbschema.mongo.schema.MetaCollection; +import com.mongodb.MongoCredential; import com.mongodb.MongoSecurityException; import com.mongodb.client.MongoClient; import com.mongodb.client.MongoCollection; @@ -26,10 +27,10 @@ public class MongoService implements AutoCloseable { public MongoService(@NotNull String uri, @NotNull Properties prop, @Nullable String username, - @Nullable String password, int fetchDocumentsForMeta) throws SQLException { + @Nullable String password, int fetchDocumentsForMeta, @Nullable MongoCredential.OidcCallbackContext callbackContext) throws SQLException { this.uri = uri; this.fetchDocumentsForMeta = fetchDocumentsForMeta; - client = new MongoClientWrapper(uri, prop, username, password); + client = new MongoClientWrapper(uri, prop, username, password, callbackContext); } public MongoClientWrapper getClient() { diff --git a/src/main/java/com/dbschema/mongo/oidc/OidcAuthFlow.java b/src/main/java/com/dbschema/mongo/oidc/OidcAuthFlow.java new file mode 100644 index 0000000..b5197da --- /dev/null +++ b/src/main/java/com/dbschema/mongo/oidc/OidcAuthFlow.java @@ -0,0 +1,305 @@ +package com.dbschema.mongo.oidc; + +import com.mongodb.MongoCredential.IdpInfo; +import com.mongodb.MongoCredential.OidcCallbackContext; +import com.mongodb.MongoCredential.OidcCallbackResult; +import com.nimbusds.oauth2.sdk.AuthorizationCode; +import com.nimbusds.oauth2.sdk.AuthorizationCodeGrant; +import com.nimbusds.oauth2.sdk.AuthorizationRequest; +import com.nimbusds.oauth2.sdk.ParseException; +import com.nimbusds.oauth2.sdk.RefreshTokenGrant; +import com.nimbusds.oauth2.sdk.ResponseType; +import com.nimbusds.oauth2.sdk.Scope; +import com.nimbusds.oauth2.sdk.TokenErrorResponse; +import com.nimbusds.oauth2.sdk.TokenRequest; +import com.nimbusds.oauth2.sdk.TokenResponse; +import com.nimbusds.oauth2.sdk.http.HTTPResponse; +import com.nimbusds.oauth2.sdk.id.ClientID; +import com.nimbusds.oauth2.sdk.id.Issuer; +import com.nimbusds.oauth2.sdk.id.State; +import com.nimbusds.oauth2.sdk.pkce.CodeChallengeMethod; +import com.nimbusds.oauth2.sdk.pkce.CodeVerifier; +import com.nimbusds.oauth2.sdk.token.RefreshToken; +import com.nimbusds.oauth2.sdk.token.Tokens; +import com.nimbusds.openid.connect.sdk.OIDCTokenResponse; +import com.nimbusds.openid.connect.sdk.OIDCTokenResponseParser; +import com.nimbusds.openid.connect.sdk.op.OIDCProviderMetadata; +import java.io.IOException; +import java.net.URI; +import java.time.Duration; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.security.auth.RefreshFailedException; + +public class OidcAuthFlow { + + private static final Logger logger = Logger.getLogger(OidcAuthFlow.class.getName()); + private static final String OFFLINE_ACCESS = "offline_access"; + private static final String OPENID = "openid"; + private OidcCallbackResult oidcCallbackResult; + + public OidcAuthFlow() {} + + public Scope buildScopes( + String clientID, IdpInfo idpServerInfo, OIDCProviderMetadata providerMetadata) { + Set scopes = new HashSet<>(); + Scope supportedScopes = providerMetadata.getScopes(); + + // Add openid and offline_access scopes by default + scopes.add(OPENID); + scopes.add(OFFLINE_ACCESS); + + // Add custom scopes from request that are supported by the IdP + List requestedScopes = idpServerInfo.getRequestScopes(); + if (requestedScopes != null) { + // azure + String clientIDDefault = clientID + "/.default"; + if (requestedScopes.contains(clientIDDefault)) { + scopes.add(clientIDDefault); + } + if (supportedScopes != null) { + for (String scope : requestedScopes) { + if (supportedScopes.contains(scope)) { + scopes.add(scope); + } else { + logger.warning(String.format("Scope '%s' is not supported", scope)); + } + } + } + } + + Scope finalScopes = new Scope(); + for (String scope : scopes) { + finalScopes.add(new Scope.Value(scope)); + } + return finalScopes; + } + + public OidcCallbackResult doAuthCodeFlow(OidcCallbackContext callbackContext) + throws OidcTimeoutException { + IdpInfo idpServerInfo = callbackContext.getIdpInfo(); + String clientID = idpServerInfo.getClientId(); + String issuerURI = idpServerInfo.getIssuer(); + + if (!isValid(idpServerInfo, clientID, issuerURI)) { + return null; + } + + Server server = new Server(); + try { + OIDCProviderMetadata providerMetadata = + OIDCProviderMetadata.resolve(new Issuer(issuerURI)); + URI authorizationEndpoint = providerMetadata.getAuthorizationEndpointURI(); + URI tokenEndpoint = providerMetadata.getTokenEndpointURI(); + Scope requestedScopes = buildScopes(clientID, idpServerInfo, providerMetadata); + + server.start(); + + URI redirectURI = + new URI( + "http://localhost:" + + Server.DEFAULT_REDIRECT_PORT + + "/redirect"); + State state = new State(); + CodeVerifier codeVerifier = new CodeVerifier(); + + AuthorizationRequest request = + new AuthorizationRequest.Builder( + new ResponseType(ResponseType.Value.CODE), + new ClientID(clientID)) + .scope(requestedScopes) + .redirectionURI(redirectURI) + .state(state) + .codeChallenge(codeVerifier, CodeChallengeMethod.S256) + .endpointURI(authorizationEndpoint) + .build(); + + try { + openURL(request.toURI().toString()); + } catch (Exception e) { + log(Level.SEVERE, "Failed to open the browser: " + e.getMessage()); + return null; + } + + OidcResponse response = server.getOidcResponse(callbackContext.getTimeout()); + if (response == null || !state.getValue().equals(response.getState())) { + log(Level.SEVERE, "OIDC response is null or returned an invalid state"); + return null; + } + + AuthorizationCode code = new AuthorizationCode(response.getCode()); + AuthorizationCodeGrant codeGrant = + new AuthorizationCodeGrant(code, redirectURI, codeVerifier); + TokenRequest tokenRequest = + new TokenRequest(tokenEndpoint, new ClientID(clientID), codeGrant); + + HTTPResponse httpResponse = tokenRequest.toHTTPRequest().send(); + TokenResponse tokenResponse = OIDCTokenResponseParser.parse(httpResponse); + if (!tokenResponse.indicatesSuccess()) { + log(Level.SEVERE, String.format("Request failed: %s", httpResponse.getBody())); + return null; + } + + return getOidcCallbackResultFromTokenResponse((OIDCTokenResponse) tokenResponse); + } catch (OidcTimeoutException e) { + throw e; + } + catch (Exception e) { + log(Level.SEVERE, "Error during OIDC authentication " + e.getMessage()); + + return null; + } finally { + try { + Thread.sleep((1000 * 2)); + } catch (InterruptedException e) { + log(Level.WARNING, "Thread interrupted " + e.getMessage()); + } + server.stop(); + switchContext(); + } + } + + private void switchContext() { + String osName = System.getProperty("os.name").toLowerCase(); + logger.log(Level.INFO, String.format("osName: %s", osName)); + Runtime runtime = Runtime.getRuntime(); + try { + runtime.exec(new String[]{"osascript", "-e" , "tell application \"Datagrip\" to activate"}); + } catch (IOException e) { + log(Level.SEVERE,e.getMessage()); + } + + } + + /** + * Opens the specified URI in the default web browser, supporting macOS, Windows, and + * Linux/Unix. This method uses platform-specific commands to invoke the browser. + * + * @param url the URL to be opened as a string + * @throws Exception if no supported browser is found or an error occurs while attempting to + * open the URL + */ + private void openURL(String url) throws Exception { + String osName = System.getProperty("os.name").toLowerCase(); + logger.log(Level.INFO, String.format("osName: %s", osName)); + Runtime runtime = Runtime.getRuntime(); + + if (osName.contains("windows")) { + runtime.exec(new String[] {"rundll32", "url.dll,FileProtocolHandler", url}); + } else if (osName.contains("mac os")) { + runtime.exec(new String[] {"open", "-gj" ,url}); + } else { + String[] browsers = {"xdg-open", "firefox", "google-chrome"}; + IOException lastError = null; + for (String browser : browsers) { + try { + // Check if browser exists + Process process = runtime.exec(new String[] {"which", browser}); + if (process.waitFor() == 0) { + runtime.exec(new String[] {browser, url}); + } + } catch (IOException e) { + lastError = e; + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + break; + } + } + + throw lastError != null + ? lastError + : new IOException("No web browser found to open the URL"); + } + } + + private void log(Level level, String message) { + logger.log(level, message); + } + + public OidcCallbackResult doRefresh(OidcCallbackContext callbackContext) + throws RefreshFailedException { + IdpInfo idpServerInfo = callbackContext.getIdpInfo(); + String clientID = idpServerInfo.getClientId(); + String issuerURI = idpServerInfo.getIssuer(); + + // Check that the IdP information is valid + if (!isValid(idpServerInfo, clientID, issuerURI)) { + return null; + } + try { + // Use OpenID Connect Discovery to fetch the provider metadata + OIDCProviderMetadata providerMetadata = + OIDCProviderMetadata.resolve(new Issuer(issuerURI)); + URI tokenEndpoint = providerMetadata.getTokenEndpointURI(); + + // This function will never be called without a refresh token (to be checked in the driver function), + // but we throw an exception to be explicit about the fact that we expect a refresh token. + String refreshToken = callbackContext.getRefreshToken(); + if (refreshToken == null) { + throw new IllegalArgumentException("Refresh token is required"); + } + + RefreshTokenGrant refreshTokenGrant = + new RefreshTokenGrant(new RefreshToken(refreshToken)); + TokenRequest tokenRequest = + new TokenRequest(tokenEndpoint, new ClientID(clientID), refreshTokenGrant); + HTTPResponse httpResponse = tokenRequest.toHTTPRequest().send(); + + try { + TokenResponse tokenResponse = OIDCTokenResponseParser.parse(httpResponse); + if (!tokenResponse.indicatesSuccess()) { + TokenErrorResponse errorResponse = tokenResponse.toErrorResponse(); + String errorCode = + errorResponse.getErrorObject() != null + ? errorResponse.getErrorObject().getCode() + : null; + String errorDescription = + errorResponse.getErrorObject() != null + ? errorResponse.getErrorObject().getDescription() + : null; + throw new RefreshFailedException( + "Token refresh failed with error: " + + "code=" + + errorCode + + ", description=" + + errorDescription); + } + return getOidcCallbackResultFromTokenResponse((OIDCTokenResponse) tokenResponse); + } catch (ParseException e) { + throw new RefreshFailedException( + "Failed to parse server response: " + + e.getMessage() + + " [response=" + + httpResponse.getBody() + + "]"); + } + + } catch (Exception e) { + log(Level.SEVERE, "OpenID Connect: Error during token refresh. " + e.getMessage()); + if (e instanceof RefreshFailedException) { + throw (RefreshFailedException) e; + } + return null; + } + } + + private boolean isValid(IdpInfo idpInfo, String clientID, String issuerURI) { + return idpInfo != null && clientID != null && !clientID.isEmpty() && issuerURI != null; + } + + private OidcCallbackResult getOidcCallbackResultFromTokenResponse( + OIDCTokenResponse tokenResponse) { + Tokens tokens = tokenResponse.getOIDCTokens(); + String accessToken = tokens.getAccessToken().getValue(); + String refreshToken = + tokens.getRefreshToken() != null ? tokens.getRefreshToken().getValue() : null; + Duration expiresIn = Duration.ofSeconds(tokens.getAccessToken().getLifetime()); + + this.oidcCallbackResult = new OidcCallbackResult(accessToken, expiresIn, refreshToken); + + return this.oidcCallbackResult; + } +} diff --git a/src/main/java/com/dbschema/mongo/oidc/OidcCallback.java b/src/main/java/com/dbschema/mongo/oidc/OidcCallback.java new file mode 100755 index 0000000..bf9331a --- /dev/null +++ b/src/main/java/com/dbschema/mongo/oidc/OidcCallback.java @@ -0,0 +1,42 @@ +package com.dbschema.mongo.oidc; + +import com.mongodb.MongoCredential; +import com.mongodb.MongoCredential.OidcCallbackContext; +import com.mongodb.MongoCredential.OidcCallbackResult; +import javax.security.auth.RefreshFailedException; + +public class OidcCallback implements MongoCredential.OidcCallback { + private final OidcAuthFlow oidcAuthFlow; + private OidcCallbackContext callbackContext; + + public OidcCallback(OidcCallbackContext callbackContext) { + this.oidcAuthFlow = new OidcAuthFlow(); + this.callbackContext = callbackContext; + } + + public OidcAuthFlow getOidcAuthFlow() { + return this.oidcAuthFlow; + } + + public OidcCallbackResult onRequest(OidcCallbackContext callbackContext) { + + if (this.callbackContext != null && this.callbackContext.getRefreshToken() != null && !this.callbackContext.getRefreshToken().isEmpty()) { + try { + return oidcAuthFlow.doRefresh(callbackContext); + } catch (RefreshFailedException e) { + throw new RuntimeException(e); + } + } else { + this.callbackContext = callbackContext; + try { + return oidcAuthFlow.doAuthCodeFlow(callbackContext); + } catch (OidcTimeoutException e) { + throw new RuntimeException(e); + } + } + } + + public OidcCallbackContext getCallbackContext() { + return this.callbackContext; + } +} diff --git a/src/main/java/com/dbschema/mongo/oidc/OidcResponse.java b/src/main/java/com/dbschema/mongo/oidc/OidcResponse.java new file mode 100644 index 0000000..7685b20 --- /dev/null +++ b/src/main/java/com/dbschema/mongo/oidc/OidcResponse.java @@ -0,0 +1,58 @@ +package com.dbschema.mongo.oidc; + +public class OidcResponse { + private String code; + private String state; + private String error; + private String errorDescription; + + public String getCode() { + return code; + } + + public String getState() { + return state; + } + + public String getError() { + return error; + } + + public String getErrorDescription() { + return errorDescription; + } + + public void setCode(String code) { + this.code = code; + } + + public void setState(String state) { + this.state = state; + } + + public void setError(String error) { + this.error = error; + } + + public void setErrorDescription(String errorDescription) { + this.errorDescription = errorDescription; + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + if (code != null) { + sb.append("Code: ").append(code).append("\n"); + } + if (state != null) { + sb.append("State: ").append(state).append("\n"); + } + if (error != null) { + sb.append("Error: ").append(error).append("\n"); + } + if (errorDescription != null) { + sb.append("Error Description: ").append(errorDescription).append("\n"); + } + return sb.toString(); + } +} diff --git a/src/main/java/com/dbschema/mongo/oidc/OidcTimeoutException.java b/src/main/java/com/dbschema/mongo/oidc/OidcTimeoutException.java new file mode 100644 index 0000000..a1a7b0d --- /dev/null +++ b/src/main/java/com/dbschema/mongo/oidc/OidcTimeoutException.java @@ -0,0 +1,7 @@ +package com.dbschema.mongo.oidc; + +public class OidcTimeoutException extends Exception { + public OidcTimeoutException(String message) { + super(message); + } +} diff --git a/src/main/java/com/dbschema/mongo/oidc/Server.java b/src/main/java/com/dbschema/mongo/oidc/Server.java new file mode 100644 index 0000000..6a7d2ec --- /dev/null +++ b/src/main/java/com/dbschema/mongo/oidc/Server.java @@ -0,0 +1,173 @@ +package com.dbschema.mongo.oidc; + +import com.sun.net.httpserver.HttpExchange; +import com.sun.net.httpserver.HttpHandler; +import com.sun.net.httpserver.HttpServer; + +import java.io.IOException; +import java.net.HttpURLConnection; +import java.net.InetSocketAddress; +import java.net.URLDecoder; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.Executors; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.logging.Level; +import java.util.logging.Logger; + +public class Server { + + private static final Logger logger = Logger.getLogger(Server.class.getName()); + + public static final int DEFAULT_REDIRECT_PORT = 27098; + private static final int RESPONSE_TIMEOUT_SECONDS = 300; + private static final String ACCEPTED_ENDPOINT = "/accepted"; + private static final String CALLBACK_ENDPOINT = "/callback"; + private static final String REDIRECT_ENDPOINT = "/redirect"; + + private HttpServer server; + private final BlockingQueue oidcResponseQueue; + + public Server() { + oidcResponseQueue = new LinkedBlockingQueue<>(); + } + + /** + * Starts the HTTP server and sets up the necessary contexts and handlers. + * + * @throws IOException if an I/O error occurs while creating or starting the server + */ + public void start() throws IOException { + server = HttpServer.create(new InetSocketAddress(DEFAULT_REDIRECT_PORT), 0); + + server.createContext(CALLBACK_ENDPOINT, new CallbackHandler()); + server.createContext(REDIRECT_ENDPOINT, new CallbackHandler()); + server.createContext(ACCEPTED_ENDPOINT, new AcceptedHandler()); + server.setExecutor(Executors.newFixedThreadPool(5)); + + // Start the server + server.start(); + logger.info("Server started on port " + DEFAULT_REDIRECT_PORT); + } + + public OidcResponse getOidcResponse() throws InterruptedException, OidcTimeoutException { + return getOidcResponse(Duration.ofSeconds(RESPONSE_TIMEOUT_SECONDS)); + } + + public OidcResponse getOidcResponse(Duration timeout) + throws OidcTimeoutException, InterruptedException { + if (timeout == null) { + return getOidcResponse(); + } + OidcResponse response = oidcResponseQueue.poll(timeout.getSeconds(), TimeUnit.SECONDS); + if (response == null) { + throw new OidcTimeoutException("Timeout waiting for OIDC response"); + } + return response; + } + + public void stop() { + if (server != null) { + server.stop(0); + } + } + + private class CallbackHandler implements HttpHandler { + + private Map parseQueryParams(HttpExchange exchange) { + Map queryParams = new HashMap<>(); + String rawQuery = exchange.getRequestURI().getRawQuery(); + + if (rawQuery != null) { + String[] params = rawQuery.split("&"); + for (String param : params) { + int equalsIndex = param.indexOf('='); + if (equalsIndex > 0) { + String key = param.substring(0, equalsIndex); + String encodedValue = param.substring(equalsIndex + 1); + String value = URLDecoder.decode(encodedValue, StandardCharsets.UTF_8); + queryParams.put(key, value); + } else { + queryParams.put(param, ""); + } + } + } + return queryParams; + } + + private boolean putOidcResponse(HttpExchange exchange, OidcResponse oidcResponse) + throws IOException { + try { + oidcResponseQueue.put(oidcResponse); + return true; + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + reply(exchange, 500); + return false; + } + } + + + @Override + public void handle(HttpExchange exchange) throws IOException { + Map queryParams = parseQueryParams(exchange); + OidcResponse oidcResponse = new OidcResponse(); + + if (queryParams.containsKey("code")) { + oidcResponse.setCode(queryParams.get("code")); + oidcResponse.setState(queryParams.getOrDefault("state", "")); + if (!putOidcResponse(exchange, oidcResponse)) { + return; + } + + exchange.getResponseHeaders().set("Location", ACCEPTED_ENDPOINT); + reply(exchange, HttpURLConnection.HTTP_MOVED_TEMP); + } else if (queryParams.containsKey("error")) { + oidcResponse.setError(queryParams.get("error")); + oidcResponse.setErrorDescription( + queryParams.getOrDefault("error_description", "Unknown error")); + if (!putOidcResponse(exchange, oidcResponse)) { + return; + } + reply(exchange, HttpURLConnection.HTTP_BAD_REQUEST); + + } else { + oidcResponse.setError("Not found"); + String allParams = + queryParams + .entrySet() + .stream() + .map(entry -> entry.getKey() + "=" + entry.getValue()) + .reduce((param1, param2) -> param1 + ", " + param2) + .orElse("No parameters"); + oidcResponse.setErrorDescription("Not found. Parameters: " + allParams); + if (!putOidcResponse(exchange, oidcResponse)) { + return; + } + reply(exchange, HttpURLConnection.HTTP_NOT_FOUND); + } + } + } + + private class AcceptedHandler implements HttpHandler { + @Override + public void handle(HttpExchange exchange) throws IOException { + reply(exchange, HttpURLConnection.HTTP_OK); + } + } + + private void reply(HttpExchange exchange, int statusCode) + throws IOException { + exchange.getResponseHeaders().set("Content-Type", "text/html; charset=utf-8"); + try (exchange) { + exchange.sendResponseHeaders(statusCode, -1); + } catch (Exception e) { + logger.log(Level.SEVERE, "Error sending response", e); + throw e; + } + } +} From 196225d5ce329e39b2fc069c6d1a4d44e84a597b Mon Sep 17 00:00:00 2001 From: "Xolo, Tlatoani" Date: Tue, 4 Nov 2025 11:17:43 -0500 Subject: [PATCH 2/3] updating based on PR comments version update adds support for mongo-oidc --- build.gradle | 11 +- .../java/com/dbschema/MongoJdbcDriver.java | 17 +- .../mongo/DriverPropertyInfoHelper.java | 2 +- .../dbschema/mongo/MongoClientWrapper.java | 168 ++++++---- .../com/dbschema/mongo/MongoConnection.java | 8 +- .../java/com/dbschema/mongo/MongoService.java | 5 +- .../com/dbschema/mongo/oidc/OidcAuthFlow.java | 305 ++++++++++++++++++ .../com/dbschema/mongo/oidc/OidcCallback.java | 42 +++ .../com/dbschema/mongo/oidc/OidcResponse.java | 58 ++++ .../mongo/oidc/OidcTimeoutException.java | 7 + .../java/com/dbschema/mongo/oidc/Server.java | 173 ++++++++++ 11 files changed, 721 insertions(+), 75 deletions(-) create mode 100644 src/main/java/com/dbschema/mongo/oidc/OidcAuthFlow.java create mode 100755 src/main/java/com/dbschema/mongo/oidc/OidcCallback.java create mode 100644 src/main/java/com/dbschema/mongo/oidc/OidcResponse.java create mode 100644 src/main/java/com/dbschema/mongo/oidc/OidcTimeoutException.java create mode 100644 src/main/java/com/dbschema/mongo/oidc/Server.java diff --git a/build.gradle b/build.gradle index cf24672..0a17ca3 100644 --- a/build.gradle +++ b/build.gradle @@ -12,9 +12,12 @@ repositories { dependencies { implementation "org.jetbrains.kotlin:kotlin-stdlib-jdk8" - implementation "org.mongodb:mongodb-driver-sync:4.11.1" + implementation "org.mongodb:mongodb-driver-sync:5.6.1" implementation group: 'org.jetbrains', name: 'annotations', version: '15.0' implementation group: 'org.apache.commons', name: 'commons-text', version: '1.10.0' + implementation group: 'org.graalvm.js', name: 'js', version: '22.3.1' + implementation files('libs/JMongosh-0.9.jar') + implementation group: 'com.nimbusds', name: 'oauth2-oidc-sdk', version: '11.+' implementation group: 'org.graalvm.polyglot', name: 'polyglot', version: '25.0.1' implementation group: 'org.graalvm.js', name: 'js', version: '25.0.1' implementation files('libs/JMongosh-0.9.1.jar') @@ -30,6 +33,12 @@ test { shadowJar { archiveFileName = "mongo-jdbc-standalone-${version}.jar" mergeServiceFiles() + + relocate('org', 'shadow.org') { + exclude 'org.ow2.asm:.*' + exclude 'net.minidev:.*' + exclude 'org.javassist:.*' + } manifest { attributes('Multi-Release' : 'true') } diff --git a/src/main/java/com/dbschema/MongoJdbcDriver.java b/src/main/java/com/dbschema/MongoJdbcDriver.java index 8dbc7a8..75f9e29 100644 --- a/src/main/java/com/dbschema/MongoJdbcDriver.java +++ b/src/main/java/com/dbschema/MongoJdbcDriver.java @@ -1,16 +1,21 @@ package com.dbschema; import com.dbschema.mongo.DriverPropertyInfoHelper; +import com.dbschema.mongo.MongoClientWrapper; import com.dbschema.mongo.MongoConnection; +import com.dbschema.mongo.MongoService; import com.dbschema.mongo.mongosh.LazyShellHolder; import com.dbschema.mongo.mongosh.PrecalculatingShellHolder; import com.dbschema.mongo.mongosh.ShellHolder; +import com.dbschema.mongo.oidc.OidcCallback; +import com.mongodb.MongoCredential; import com.mongodb.mongosh.MongoShell; import org.graalvm.polyglot.Engine; import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; import java.sql.*; +import java.util.Optional; import java.util.Properties; import java.util.concurrent.ExecutorService; import java.util.logging.Logger; @@ -33,6 +38,7 @@ public class MongoJdbcDriver implements Driver { private @Nullable ExecutorService executorService; private @Nullable Engine sharedEngine; private @NotNull ShellHolder shellHolder; + private MongoConnection mongoConnection; static { try { @@ -100,7 +106,16 @@ public Connection connect(String url, Properties info) throws SQLException { synchronized (this) { ShellHolder shellHolder = this.shellHolder; this.shellHolder = createShellHolder(); - return new MongoConnection(url, info, username, password, fetchDocumentsForMeta, shellHolder); + MongoCredential.OidcCallbackContext existingResult = Optional.ofNullable(this.mongoConnection) + .map(MongoConnection::getService) + .map(MongoService::getClient) + .map(MongoClientWrapper::getOidcCallback) + .map(OidcCallback::getCallbackContext) + .orElse(null); + + this.mongoConnection = new MongoConnection(url, info, username, password, fetchDocumentsForMeta, shellHolder, existingResult); + + return this.mongoConnection; } } diff --git a/src/main/java/com/dbschema/mongo/DriverPropertyInfoHelper.java b/src/main/java/com/dbschema/mongo/DriverPropertyInfoHelper.java index fa17d22..907af6f 100644 --- a/src/main/java/com/dbschema/mongo/DriverPropertyInfoHelper.java +++ b/src/main/java/com/dbschema/mongo/DriverPropertyInfoHelper.java @@ -5,7 +5,7 @@ public class DriverPropertyInfoHelper { public static final String AUTH_MECHANISM = "authMechanism"; - public static final String[] AUTH_MECHANISM_CHOICES = new String[]{"GSSAPI", "MONGODB-AWS", "MONGODB-X509", "PLAIN", "SCRAM-SHA-1", "SCRAM-SHA-256"}; + public static final String[] AUTH_MECHANISM_CHOICES = new String[]{"GSSAPI", "MONGODB-AWS", "MONGODB-X509", "PLAIN", "SCRAM-SHA-1", "SCRAM-SHA-256", "MONGODB-OIDC"}; public static final String AUTH_SOURCE = "authSource"; public static final String AWS_SESSION_TOKEN = "AWS_SESSION_TOKEN"; public static final String SERVICE_NAME = "SERVICE_NAME"; diff --git a/src/main/java/com/dbschema/mongo/MongoClientWrapper.java b/src/main/java/com/dbschema/mongo/MongoClientWrapper.java index 3206c59..cc4f092 100644 --- a/src/main/java/com/dbschema/mongo/MongoClientWrapper.java +++ b/src/main/java/com/dbschema/mongo/MongoClientWrapper.java @@ -1,7 +1,11 @@ package com.dbschema.mongo; +import com.dbschema.mongo.oidc.OidcCallback; import com.mongodb.ConnectionString; import com.mongodb.MongoClientSettings; +import com.mongodb.MongoCredential; +import com.mongodb.ServerApi; +import com.mongodb.ServerApiVersion; import com.mongodb.client.MongoClient; import com.mongodb.client.MongoClients; import com.mongodb.client.MongoDatabase; @@ -26,79 +30,103 @@ public class MongoClientWrapper implements AutoCloseable { private boolean isClosed = false; private final MongoClient mongoClient; public final String databaseNameFromUrl; + public final OidcCallback oidcCallback; - public MongoClientWrapper(@NotNull String uri, @NotNull Properties prop, @Nullable String username, @Nullable String password) throws SQLException { + public MongoClientWrapper(@NotNull String uri, @NotNull Properties prop, @Nullable String username, @Nullable String password, @Nullable MongoCredential.OidcCallbackContext callbackContext) throws SQLException { + this.oidcCallback = new OidcCallback(callbackContext); try { - boolean automaticEncoding = ENCODE_CREDENTIALS_DEFAULT; - if (prop.getProperty(ENCODE_CREDENTIALS) != null) { - automaticEncoding = Boolean.parseBoolean(prop.getProperty(ENCODE_CREDENTIALS)); - } + boolean automaticEncoding = ENCODE_CREDENTIALS_DEFAULT; + if (prop.getProperty(ENCODE_CREDENTIALS) != null) { + automaticEncoding = Boolean.parseBoolean(prop.getProperty(ENCODE_CREDENTIALS)); + } - uri = insertCredentials(uri, username, password, automaticEncoding); - uri = insertAuthMechanism(uri, prop.getProperty(AUTH_MECHANISM)); - uri = insertAuthSource(uri, prop.getProperty(AUTH_SOURCE)); - uri = insertAuthProperty(uri, AWS_SESSION_TOKEN, prop.getProperty(AWS_SESSION_TOKEN)); - uri = insertAuthProperty(uri, SERVICE_NAME, prop.getProperty(SERVICE_NAME)); - uri = insertAuthProperty(uri, SERVICE_REALM, prop.getProperty(SERVICE_REALM)); - String canonicalizeHostName = prop.getProperty(CANONICALIZE_HOST_NAME); - if (Boolean.TRUE.toString().equalsIgnoreCase(canonicalizeHostName) || Boolean.FALSE.toString().equalsIgnoreCase(canonicalizeHostName)) { - uri = insertAuthProperty(uri, CANONICALIZE_HOST_NAME, canonicalizeHostName); - } - else if (canonicalizeHostName != null) { - System.err.println("Unknown " + CANONICALIZE_HOST_NAME + " value. Must be true or false."); - } - uri = insertRetryWrites(uri, prop.getProperty(RETRY_WRITES)); + uri = insertCredentials(uri, username, password, automaticEncoding); + uri = insertAuthMechanism(uri, prop.getProperty(AUTH_MECHANISM)); + uri = insertAuthSource(uri, prop.getProperty(AUTH_SOURCE)); + uri = insertAuthProperty(uri, AWS_SESSION_TOKEN, prop.getProperty(AWS_SESSION_TOKEN)); + uri = insertAuthProperty(uri, SERVICE_NAME, prop.getProperty(SERVICE_NAME)); + uri = insertAuthProperty(uri, SERVICE_REALM, prop.getProperty(SERVICE_REALM)); + String canonicalizeHostName = prop.getProperty(CANONICALIZE_HOST_NAME); + if (Boolean.TRUE.toString().equalsIgnoreCase(canonicalizeHostName) || Boolean.FALSE.toString().equalsIgnoreCase(canonicalizeHostName)) { + uri = insertAuthProperty(uri, CANONICALIZE_HOST_NAME, canonicalizeHostName); + } + else if (canonicalizeHostName != null) { + System.err.println("Unknown " + CANONICALIZE_HOST_NAME + " value. Must be true or false."); + } + uri = insertRetryWrites(uri, prop.getProperty(RETRY_WRITES)); + + + // Construct a ServerApi instance using the ServerApi.builder() method + ServerApi serverApi = ServerApi.builder() + .version(ServerApiVersion.V1) + .build(); ConnectionString connectionString = new ConnectionString(uri); - databaseNameFromUrl = connectionString.getDatabase(); - int maxPoolSize = getMaxPoolSize(prop); - MongoClientSettings.Builder builder = MongoClientSettings.builder() - .applyConnectionString(connectionString) - .applyToConnectionPoolSettings(b -> b.maxSize(maxPoolSize)); - String application = prop.getProperty(APPLICATION_NAME); - if (!isNullOrEmpty(application)) { - builder.applicationName(application); - } - if ("true".equals(prop.getProperty("ssl"))) { - boolean allowInvalidCertificates = uri.contains("tlsAllowInvalidCertificates=true") || uri.contains("sslAllowInvalidCertificates=true") - || isTrue(prop.getProperty(ALLOW_INVALID_CERTIFICATES, Boolean.toString(ALLOW_INVALID_CERTIFICATES_DEFAULT))); - builder.applyToSslSettings(s -> { - s.enabled(true); - boolean allowInvalidHostnames = isTrue(prop.getProperty(ALLOW_INVALID_HOSTNAMES, Boolean.toString(ALLOW_INVALID_HOSTNAMES_DEFAULT))); - if (allowInvalidHostnames) s.invalidHostNameAllowed(true); - if (allowInvalidCertificates) { - String keyStoreType = System.getProperty("javax.net.ssl.keyStoreType", KeyStore.getDefaultType()); - String keyStorePassword = System.getProperty("javax.net.ssl.keyStorePassword", ""); - String keyStoreUrl = System.getProperty("javax.net.ssl.keyStore", ""); - // check keyStoreUrl - if (!isNullOrEmpty(keyStoreUrl)) { - try { - new URL(keyStoreUrl); - } catch (MalformedURLException e) { - keyStoreUrl = "file:" + keyStoreUrl; - } - } - try { - s.context(getTrustEverybodySSLContext(keyStoreUrl, keyStoreType, keyStorePassword)); - } - catch (SSLUtil.SSLParamsException e) { - throw new RuntimeException(e); - } - } - }); - } - if (connectionString.getUuidRepresentation() == null) { - String uuidRepresentation = prop.getProperty(UUID_REPRESENTATION, UUID_REPRESENTATION_DEFAULT); - builder.uuidRepresentation(createUuidRepresentation(uuidRepresentation)); - } - if (connectionString.getServerSelectionTimeout() == null) { - int timeout = Integer.parseInt(prop.getProperty(SERVER_SELECTION_TIMEOUT, SERVER_SELECTION_TIMEOUT_DEFAULT)); - builder.applyToClusterSettings(b -> b.serverSelectionTimeout(timeout, TimeUnit.MILLISECONDS)); - } - if (connectionString.getConnectTimeout() == null) { - int timeout = Integer.parseInt(prop.getProperty(CONNECT_TIMEOUT, CONNECT_TIMEOUT_DEFAULT)); - builder.applyToSocketSettings(b -> b.connectTimeout(timeout, TimeUnit.MILLISECONDS)); - } + + MongoCredential credential; + + credential = + MongoCredential.createOidcCredential( + connectionString.getUsername()) + .withMechanismProperty( + MongoCredential.OIDC_HUMAN_CALLBACK_KEY, oidcCallback); + + + databaseNameFromUrl = connectionString.getDatabase(); + MongoClientSettings.Builder builder = MongoClientSettings.builder() + .applyConnectionString(connectionString) + .serverApi(serverApi) + .credential(credential) + .uuidRepresentation(createUuidRepresentation(prop.getProperty(UUID_REPRESENTATION, UUID_REPRESENTATION_DEFAULT))) + .applyToConnectionPoolSettings(b -> b.maxSize(getMaxPoolSize(prop))) + + .applyToSocketSettings(b -> b.connectTimeout(Integer.parseInt(prop.getProperty(CONNECT_TIMEOUT, CONNECT_TIMEOUT_DEFAULT)), TimeUnit.MILLISECONDS)); + + String application = prop.getProperty(APPLICATION_NAME); + if (!isNullOrEmpty(application)) { + builder.applicationName(application); + } + if ("true".equals(prop.getProperty("ssl"))) { + boolean allowInvalidCertificates = uri.contains("tlsAllowInvalidCertificates=true") || uri.contains("sslAllowInvalidCertificates=true") + || isTrue(prop.getProperty(ALLOW_INVALID_CERTIFICATES, Boolean.toString(ALLOW_INVALID_CERTIFICATES_DEFAULT))); + builder.applyToSslSettings(s -> { + s.enabled(true); + boolean allowInvalidHostnames = isTrue(prop.getProperty(ALLOW_INVALID_HOSTNAMES, Boolean.toString(ALLOW_INVALID_HOSTNAMES_DEFAULT))); + if (allowInvalidHostnames) s.invalidHostNameAllowed(true); + if (allowInvalidCertificates) { + String keyStoreType = System.getProperty("javax.net.ssl.keyStoreType", KeyStore.getDefaultType()); + String keyStorePassword = System.getProperty("javax.net.ssl.keyStorePassword", ""); + String keyStoreUrl = System.getProperty("javax.net.ssl.keyStore", ""); + // check keyStoreUrl + if (!isNullOrEmpty(keyStoreUrl)) { + try { + new URL(keyStoreUrl); + } catch (MalformedURLException e) { + keyStoreUrl = "file:" + keyStoreUrl; + } + } + try { + s.context(getTrustEverybodySSLContext(keyStoreUrl, keyStoreType, keyStorePassword)); + } + catch (SSLUtil.SSLParamsException e) { + throw new RuntimeException(e); + } + } + }); + } + if (connectionString.getUuidRepresentation() == null) { + String uuidRepresentation = prop.getProperty(UUID_REPRESENTATION, UUID_REPRESENTATION_DEFAULT); + builder.uuidRepresentation(createUuidRepresentation(uuidRepresentation)); + } + if (connectionString.getServerSelectionTimeout() == null) { + int timeout = Integer.parseInt(prop.getProperty(SERVER_SELECTION_TIMEOUT, SERVER_SELECTION_TIMEOUT_DEFAULT)); + builder.applyToClusterSettings(b -> b.serverSelectionTimeout(timeout, TimeUnit.MILLISECONDS)); + } + if (connectionString.getConnectTimeout() == null) { + int timeout = Integer.parseInt(prop.getProperty(CONNECT_TIMEOUT, CONNECT_TIMEOUT_DEFAULT)); + builder.applyToSocketSettings(b -> b.connectTimeout(timeout, TimeUnit.MILLISECONDS)); + } + this.mongoClient = MongoClients.create(builder.build()); } catch (Exception e) { @@ -160,6 +188,10 @@ public MongoDatabase getDatabase(String databaseName) throws SQLAlreadyClosedExc return mongoClient.getDatabase(databaseName); } + public OidcCallback getOidcCallback() { + return this.oidcCallback; + } + @NotNull public MongoClient getMongoClient() { return mongoClient; diff --git a/src/main/java/com/dbschema/mongo/MongoConnection.java b/src/main/java/com/dbschema/mongo/MongoConnection.java index 9cec94a..6cc3447 100644 --- a/src/main/java/com/dbschema/mongo/MongoConnection.java +++ b/src/main/java/com/dbschema/mongo/MongoConnection.java @@ -3,6 +3,7 @@ import com.dbschema.mongo.mongosh.MongoshScriptEngine; import com.dbschema.mongo.mongosh.PrecalculatingShellHolder; import com.dbschema.mongo.mongosh.ShellHolder; +import com.mongodb.MongoCredential; import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; @@ -17,14 +18,17 @@ public class MongoConnection implements Connection { private String schema; private boolean isClosed = false; private boolean isReadOnly = false; + private MongoCredential.OidcCallbackContext oidcCallbackContext; public MongoConnection(@NotNull String url, @NotNull Properties info, @Nullable String username, @Nullable String password, int fetchDocumentsForMeta, - @NotNull ShellHolder shellHolder) throws SQLException { - this.service = new MongoService(url, info, username, password, fetchDocumentsForMeta); + @NotNull ShellHolder shellHolder, + @Nullable MongoCredential.OidcCallbackContext callbackContext) throws SQLException { + this.oidcCallbackContext = callbackContext; + this.service = new MongoService(url, info, username, password, fetchDocumentsForMeta, this.oidcCallbackContext); this.scriptEngine = new MongoshScriptEngine(this, shellHolder); try { setSchema(service.getDatabaseNameFromUrl()); diff --git a/src/main/java/com/dbschema/mongo/MongoService.java b/src/main/java/com/dbschema/mongo/MongoService.java index 4513080..c301695 100644 --- a/src/main/java/com/dbschema/mongo/MongoService.java +++ b/src/main/java/com/dbschema/mongo/MongoService.java @@ -1,6 +1,7 @@ package com.dbschema.mongo; import com.dbschema.mongo.schema.MetaCollection; +import com.mongodb.MongoCredential; import com.mongodb.MongoSecurityException; import com.mongodb.client.MongoClient; import com.mongodb.client.MongoCollection; @@ -26,10 +27,10 @@ public class MongoService implements AutoCloseable { public MongoService(@NotNull String uri, @NotNull Properties prop, @Nullable String username, - @Nullable String password, int fetchDocumentsForMeta) throws SQLException { + @Nullable String password, int fetchDocumentsForMeta, @Nullable MongoCredential.OidcCallbackContext callbackContext) throws SQLException { this.uri = uri; this.fetchDocumentsForMeta = fetchDocumentsForMeta; - client = new MongoClientWrapper(uri, prop, username, password); + client = new MongoClientWrapper(uri, prop, username, password, callbackContext); } public MongoClientWrapper getClient() { diff --git a/src/main/java/com/dbschema/mongo/oidc/OidcAuthFlow.java b/src/main/java/com/dbschema/mongo/oidc/OidcAuthFlow.java new file mode 100644 index 0000000..b5197da --- /dev/null +++ b/src/main/java/com/dbschema/mongo/oidc/OidcAuthFlow.java @@ -0,0 +1,305 @@ +package com.dbschema.mongo.oidc; + +import com.mongodb.MongoCredential.IdpInfo; +import com.mongodb.MongoCredential.OidcCallbackContext; +import com.mongodb.MongoCredential.OidcCallbackResult; +import com.nimbusds.oauth2.sdk.AuthorizationCode; +import com.nimbusds.oauth2.sdk.AuthorizationCodeGrant; +import com.nimbusds.oauth2.sdk.AuthorizationRequest; +import com.nimbusds.oauth2.sdk.ParseException; +import com.nimbusds.oauth2.sdk.RefreshTokenGrant; +import com.nimbusds.oauth2.sdk.ResponseType; +import com.nimbusds.oauth2.sdk.Scope; +import com.nimbusds.oauth2.sdk.TokenErrorResponse; +import com.nimbusds.oauth2.sdk.TokenRequest; +import com.nimbusds.oauth2.sdk.TokenResponse; +import com.nimbusds.oauth2.sdk.http.HTTPResponse; +import com.nimbusds.oauth2.sdk.id.ClientID; +import com.nimbusds.oauth2.sdk.id.Issuer; +import com.nimbusds.oauth2.sdk.id.State; +import com.nimbusds.oauth2.sdk.pkce.CodeChallengeMethod; +import com.nimbusds.oauth2.sdk.pkce.CodeVerifier; +import com.nimbusds.oauth2.sdk.token.RefreshToken; +import com.nimbusds.oauth2.sdk.token.Tokens; +import com.nimbusds.openid.connect.sdk.OIDCTokenResponse; +import com.nimbusds.openid.connect.sdk.OIDCTokenResponseParser; +import com.nimbusds.openid.connect.sdk.op.OIDCProviderMetadata; +import java.io.IOException; +import java.net.URI; +import java.time.Duration; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.security.auth.RefreshFailedException; + +public class OidcAuthFlow { + + private static final Logger logger = Logger.getLogger(OidcAuthFlow.class.getName()); + private static final String OFFLINE_ACCESS = "offline_access"; + private static final String OPENID = "openid"; + private OidcCallbackResult oidcCallbackResult; + + public OidcAuthFlow() {} + + public Scope buildScopes( + String clientID, IdpInfo idpServerInfo, OIDCProviderMetadata providerMetadata) { + Set scopes = new HashSet<>(); + Scope supportedScopes = providerMetadata.getScopes(); + + // Add openid and offline_access scopes by default + scopes.add(OPENID); + scopes.add(OFFLINE_ACCESS); + + // Add custom scopes from request that are supported by the IdP + List requestedScopes = idpServerInfo.getRequestScopes(); + if (requestedScopes != null) { + // azure + String clientIDDefault = clientID + "/.default"; + if (requestedScopes.contains(clientIDDefault)) { + scopes.add(clientIDDefault); + } + if (supportedScopes != null) { + for (String scope : requestedScopes) { + if (supportedScopes.contains(scope)) { + scopes.add(scope); + } else { + logger.warning(String.format("Scope '%s' is not supported", scope)); + } + } + } + } + + Scope finalScopes = new Scope(); + for (String scope : scopes) { + finalScopes.add(new Scope.Value(scope)); + } + return finalScopes; + } + + public OidcCallbackResult doAuthCodeFlow(OidcCallbackContext callbackContext) + throws OidcTimeoutException { + IdpInfo idpServerInfo = callbackContext.getIdpInfo(); + String clientID = idpServerInfo.getClientId(); + String issuerURI = idpServerInfo.getIssuer(); + + if (!isValid(idpServerInfo, clientID, issuerURI)) { + return null; + } + + Server server = new Server(); + try { + OIDCProviderMetadata providerMetadata = + OIDCProviderMetadata.resolve(new Issuer(issuerURI)); + URI authorizationEndpoint = providerMetadata.getAuthorizationEndpointURI(); + URI tokenEndpoint = providerMetadata.getTokenEndpointURI(); + Scope requestedScopes = buildScopes(clientID, idpServerInfo, providerMetadata); + + server.start(); + + URI redirectURI = + new URI( + "http://localhost:" + + Server.DEFAULT_REDIRECT_PORT + + "/redirect"); + State state = new State(); + CodeVerifier codeVerifier = new CodeVerifier(); + + AuthorizationRequest request = + new AuthorizationRequest.Builder( + new ResponseType(ResponseType.Value.CODE), + new ClientID(clientID)) + .scope(requestedScopes) + .redirectionURI(redirectURI) + .state(state) + .codeChallenge(codeVerifier, CodeChallengeMethod.S256) + .endpointURI(authorizationEndpoint) + .build(); + + try { + openURL(request.toURI().toString()); + } catch (Exception e) { + log(Level.SEVERE, "Failed to open the browser: " + e.getMessage()); + return null; + } + + OidcResponse response = server.getOidcResponse(callbackContext.getTimeout()); + if (response == null || !state.getValue().equals(response.getState())) { + log(Level.SEVERE, "OIDC response is null or returned an invalid state"); + return null; + } + + AuthorizationCode code = new AuthorizationCode(response.getCode()); + AuthorizationCodeGrant codeGrant = + new AuthorizationCodeGrant(code, redirectURI, codeVerifier); + TokenRequest tokenRequest = + new TokenRequest(tokenEndpoint, new ClientID(clientID), codeGrant); + + HTTPResponse httpResponse = tokenRequest.toHTTPRequest().send(); + TokenResponse tokenResponse = OIDCTokenResponseParser.parse(httpResponse); + if (!tokenResponse.indicatesSuccess()) { + log(Level.SEVERE, String.format("Request failed: %s", httpResponse.getBody())); + return null; + } + + return getOidcCallbackResultFromTokenResponse((OIDCTokenResponse) tokenResponse); + } catch (OidcTimeoutException e) { + throw e; + } + catch (Exception e) { + log(Level.SEVERE, "Error during OIDC authentication " + e.getMessage()); + + return null; + } finally { + try { + Thread.sleep((1000 * 2)); + } catch (InterruptedException e) { + log(Level.WARNING, "Thread interrupted " + e.getMessage()); + } + server.stop(); + switchContext(); + } + } + + private void switchContext() { + String osName = System.getProperty("os.name").toLowerCase(); + logger.log(Level.INFO, String.format("osName: %s", osName)); + Runtime runtime = Runtime.getRuntime(); + try { + runtime.exec(new String[]{"osascript", "-e" , "tell application \"Datagrip\" to activate"}); + } catch (IOException e) { + log(Level.SEVERE,e.getMessage()); + } + + } + + /** + * Opens the specified URI in the default web browser, supporting macOS, Windows, and + * Linux/Unix. This method uses platform-specific commands to invoke the browser. + * + * @param url the URL to be opened as a string + * @throws Exception if no supported browser is found or an error occurs while attempting to + * open the URL + */ + private void openURL(String url) throws Exception { + String osName = System.getProperty("os.name").toLowerCase(); + logger.log(Level.INFO, String.format("osName: %s", osName)); + Runtime runtime = Runtime.getRuntime(); + + if (osName.contains("windows")) { + runtime.exec(new String[] {"rundll32", "url.dll,FileProtocolHandler", url}); + } else if (osName.contains("mac os")) { + runtime.exec(new String[] {"open", "-gj" ,url}); + } else { + String[] browsers = {"xdg-open", "firefox", "google-chrome"}; + IOException lastError = null; + for (String browser : browsers) { + try { + // Check if browser exists + Process process = runtime.exec(new String[] {"which", browser}); + if (process.waitFor() == 0) { + runtime.exec(new String[] {browser, url}); + } + } catch (IOException e) { + lastError = e; + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + break; + } + } + + throw lastError != null + ? lastError + : new IOException("No web browser found to open the URL"); + } + } + + private void log(Level level, String message) { + logger.log(level, message); + } + + public OidcCallbackResult doRefresh(OidcCallbackContext callbackContext) + throws RefreshFailedException { + IdpInfo idpServerInfo = callbackContext.getIdpInfo(); + String clientID = idpServerInfo.getClientId(); + String issuerURI = idpServerInfo.getIssuer(); + + // Check that the IdP information is valid + if (!isValid(idpServerInfo, clientID, issuerURI)) { + return null; + } + try { + // Use OpenID Connect Discovery to fetch the provider metadata + OIDCProviderMetadata providerMetadata = + OIDCProviderMetadata.resolve(new Issuer(issuerURI)); + URI tokenEndpoint = providerMetadata.getTokenEndpointURI(); + + // This function will never be called without a refresh token (to be checked in the driver function), + // but we throw an exception to be explicit about the fact that we expect a refresh token. + String refreshToken = callbackContext.getRefreshToken(); + if (refreshToken == null) { + throw new IllegalArgumentException("Refresh token is required"); + } + + RefreshTokenGrant refreshTokenGrant = + new RefreshTokenGrant(new RefreshToken(refreshToken)); + TokenRequest tokenRequest = + new TokenRequest(tokenEndpoint, new ClientID(clientID), refreshTokenGrant); + HTTPResponse httpResponse = tokenRequest.toHTTPRequest().send(); + + try { + TokenResponse tokenResponse = OIDCTokenResponseParser.parse(httpResponse); + if (!tokenResponse.indicatesSuccess()) { + TokenErrorResponse errorResponse = tokenResponse.toErrorResponse(); + String errorCode = + errorResponse.getErrorObject() != null + ? errorResponse.getErrorObject().getCode() + : null; + String errorDescription = + errorResponse.getErrorObject() != null + ? errorResponse.getErrorObject().getDescription() + : null; + throw new RefreshFailedException( + "Token refresh failed with error: " + + "code=" + + errorCode + + ", description=" + + errorDescription); + } + return getOidcCallbackResultFromTokenResponse((OIDCTokenResponse) tokenResponse); + } catch (ParseException e) { + throw new RefreshFailedException( + "Failed to parse server response: " + + e.getMessage() + + " [response=" + + httpResponse.getBody() + + "]"); + } + + } catch (Exception e) { + log(Level.SEVERE, "OpenID Connect: Error during token refresh. " + e.getMessage()); + if (e instanceof RefreshFailedException) { + throw (RefreshFailedException) e; + } + return null; + } + } + + private boolean isValid(IdpInfo idpInfo, String clientID, String issuerURI) { + return idpInfo != null && clientID != null && !clientID.isEmpty() && issuerURI != null; + } + + private OidcCallbackResult getOidcCallbackResultFromTokenResponse( + OIDCTokenResponse tokenResponse) { + Tokens tokens = tokenResponse.getOIDCTokens(); + String accessToken = tokens.getAccessToken().getValue(); + String refreshToken = + tokens.getRefreshToken() != null ? tokens.getRefreshToken().getValue() : null; + Duration expiresIn = Duration.ofSeconds(tokens.getAccessToken().getLifetime()); + + this.oidcCallbackResult = new OidcCallbackResult(accessToken, expiresIn, refreshToken); + + return this.oidcCallbackResult; + } +} diff --git a/src/main/java/com/dbschema/mongo/oidc/OidcCallback.java b/src/main/java/com/dbschema/mongo/oidc/OidcCallback.java new file mode 100755 index 0000000..bf9331a --- /dev/null +++ b/src/main/java/com/dbschema/mongo/oidc/OidcCallback.java @@ -0,0 +1,42 @@ +package com.dbschema.mongo.oidc; + +import com.mongodb.MongoCredential; +import com.mongodb.MongoCredential.OidcCallbackContext; +import com.mongodb.MongoCredential.OidcCallbackResult; +import javax.security.auth.RefreshFailedException; + +public class OidcCallback implements MongoCredential.OidcCallback { + private final OidcAuthFlow oidcAuthFlow; + private OidcCallbackContext callbackContext; + + public OidcCallback(OidcCallbackContext callbackContext) { + this.oidcAuthFlow = new OidcAuthFlow(); + this.callbackContext = callbackContext; + } + + public OidcAuthFlow getOidcAuthFlow() { + return this.oidcAuthFlow; + } + + public OidcCallbackResult onRequest(OidcCallbackContext callbackContext) { + + if (this.callbackContext != null && this.callbackContext.getRefreshToken() != null && !this.callbackContext.getRefreshToken().isEmpty()) { + try { + return oidcAuthFlow.doRefresh(callbackContext); + } catch (RefreshFailedException e) { + throw new RuntimeException(e); + } + } else { + this.callbackContext = callbackContext; + try { + return oidcAuthFlow.doAuthCodeFlow(callbackContext); + } catch (OidcTimeoutException e) { + throw new RuntimeException(e); + } + } + } + + public OidcCallbackContext getCallbackContext() { + return this.callbackContext; + } +} diff --git a/src/main/java/com/dbschema/mongo/oidc/OidcResponse.java b/src/main/java/com/dbschema/mongo/oidc/OidcResponse.java new file mode 100644 index 0000000..7685b20 --- /dev/null +++ b/src/main/java/com/dbschema/mongo/oidc/OidcResponse.java @@ -0,0 +1,58 @@ +package com.dbschema.mongo.oidc; + +public class OidcResponse { + private String code; + private String state; + private String error; + private String errorDescription; + + public String getCode() { + return code; + } + + public String getState() { + return state; + } + + public String getError() { + return error; + } + + public String getErrorDescription() { + return errorDescription; + } + + public void setCode(String code) { + this.code = code; + } + + public void setState(String state) { + this.state = state; + } + + public void setError(String error) { + this.error = error; + } + + public void setErrorDescription(String errorDescription) { + this.errorDescription = errorDescription; + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + if (code != null) { + sb.append("Code: ").append(code).append("\n"); + } + if (state != null) { + sb.append("State: ").append(state).append("\n"); + } + if (error != null) { + sb.append("Error: ").append(error).append("\n"); + } + if (errorDescription != null) { + sb.append("Error Description: ").append(errorDescription).append("\n"); + } + return sb.toString(); + } +} diff --git a/src/main/java/com/dbschema/mongo/oidc/OidcTimeoutException.java b/src/main/java/com/dbschema/mongo/oidc/OidcTimeoutException.java new file mode 100644 index 0000000..a1a7b0d --- /dev/null +++ b/src/main/java/com/dbschema/mongo/oidc/OidcTimeoutException.java @@ -0,0 +1,7 @@ +package com.dbschema.mongo.oidc; + +public class OidcTimeoutException extends Exception { + public OidcTimeoutException(String message) { + super(message); + } +} diff --git a/src/main/java/com/dbschema/mongo/oidc/Server.java b/src/main/java/com/dbschema/mongo/oidc/Server.java new file mode 100644 index 0000000..6a7d2ec --- /dev/null +++ b/src/main/java/com/dbschema/mongo/oidc/Server.java @@ -0,0 +1,173 @@ +package com.dbschema.mongo.oidc; + +import com.sun.net.httpserver.HttpExchange; +import com.sun.net.httpserver.HttpHandler; +import com.sun.net.httpserver.HttpServer; + +import java.io.IOException; +import java.net.HttpURLConnection; +import java.net.InetSocketAddress; +import java.net.URLDecoder; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.Executors; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.logging.Level; +import java.util.logging.Logger; + +public class Server { + + private static final Logger logger = Logger.getLogger(Server.class.getName()); + + public static final int DEFAULT_REDIRECT_PORT = 27098; + private static final int RESPONSE_TIMEOUT_SECONDS = 300; + private static final String ACCEPTED_ENDPOINT = "/accepted"; + private static final String CALLBACK_ENDPOINT = "/callback"; + private static final String REDIRECT_ENDPOINT = "/redirect"; + + private HttpServer server; + private final BlockingQueue oidcResponseQueue; + + public Server() { + oidcResponseQueue = new LinkedBlockingQueue<>(); + } + + /** + * Starts the HTTP server and sets up the necessary contexts and handlers. + * + * @throws IOException if an I/O error occurs while creating or starting the server + */ + public void start() throws IOException { + server = HttpServer.create(new InetSocketAddress(DEFAULT_REDIRECT_PORT), 0); + + server.createContext(CALLBACK_ENDPOINT, new CallbackHandler()); + server.createContext(REDIRECT_ENDPOINT, new CallbackHandler()); + server.createContext(ACCEPTED_ENDPOINT, new AcceptedHandler()); + server.setExecutor(Executors.newFixedThreadPool(5)); + + // Start the server + server.start(); + logger.info("Server started on port " + DEFAULT_REDIRECT_PORT); + } + + public OidcResponse getOidcResponse() throws InterruptedException, OidcTimeoutException { + return getOidcResponse(Duration.ofSeconds(RESPONSE_TIMEOUT_SECONDS)); + } + + public OidcResponse getOidcResponse(Duration timeout) + throws OidcTimeoutException, InterruptedException { + if (timeout == null) { + return getOidcResponse(); + } + OidcResponse response = oidcResponseQueue.poll(timeout.getSeconds(), TimeUnit.SECONDS); + if (response == null) { + throw new OidcTimeoutException("Timeout waiting for OIDC response"); + } + return response; + } + + public void stop() { + if (server != null) { + server.stop(0); + } + } + + private class CallbackHandler implements HttpHandler { + + private Map parseQueryParams(HttpExchange exchange) { + Map queryParams = new HashMap<>(); + String rawQuery = exchange.getRequestURI().getRawQuery(); + + if (rawQuery != null) { + String[] params = rawQuery.split("&"); + for (String param : params) { + int equalsIndex = param.indexOf('='); + if (equalsIndex > 0) { + String key = param.substring(0, equalsIndex); + String encodedValue = param.substring(equalsIndex + 1); + String value = URLDecoder.decode(encodedValue, StandardCharsets.UTF_8); + queryParams.put(key, value); + } else { + queryParams.put(param, ""); + } + } + } + return queryParams; + } + + private boolean putOidcResponse(HttpExchange exchange, OidcResponse oidcResponse) + throws IOException { + try { + oidcResponseQueue.put(oidcResponse); + return true; + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + reply(exchange, 500); + return false; + } + } + + + @Override + public void handle(HttpExchange exchange) throws IOException { + Map queryParams = parseQueryParams(exchange); + OidcResponse oidcResponse = new OidcResponse(); + + if (queryParams.containsKey("code")) { + oidcResponse.setCode(queryParams.get("code")); + oidcResponse.setState(queryParams.getOrDefault("state", "")); + if (!putOidcResponse(exchange, oidcResponse)) { + return; + } + + exchange.getResponseHeaders().set("Location", ACCEPTED_ENDPOINT); + reply(exchange, HttpURLConnection.HTTP_MOVED_TEMP); + } else if (queryParams.containsKey("error")) { + oidcResponse.setError(queryParams.get("error")); + oidcResponse.setErrorDescription( + queryParams.getOrDefault("error_description", "Unknown error")); + if (!putOidcResponse(exchange, oidcResponse)) { + return; + } + reply(exchange, HttpURLConnection.HTTP_BAD_REQUEST); + + } else { + oidcResponse.setError("Not found"); + String allParams = + queryParams + .entrySet() + .stream() + .map(entry -> entry.getKey() + "=" + entry.getValue()) + .reduce((param1, param2) -> param1 + ", " + param2) + .orElse("No parameters"); + oidcResponse.setErrorDescription("Not found. Parameters: " + allParams); + if (!putOidcResponse(exchange, oidcResponse)) { + return; + } + reply(exchange, HttpURLConnection.HTTP_NOT_FOUND); + } + } + } + + private class AcceptedHandler implements HttpHandler { + @Override + public void handle(HttpExchange exchange) throws IOException { + reply(exchange, HttpURLConnection.HTTP_OK); + } + } + + private void reply(HttpExchange exchange, int statusCode) + throws IOException { + exchange.getResponseHeaders().set("Content-Type", "text/html; charset=utf-8"); + try (exchange) { + exchange.sendResponseHeaders(statusCode, -1); + } catch (Exception e) { + logger.log(Level.SEVERE, "Error sending response", e); + throw e; + } + } +} From 00b6c5031a50c4e1a3d8288fea5405642fee2f93 Mon Sep 17 00:00:00 2001 From: "Xolo, Tlatoani" Date: Sat, 17 Jan 2026 15:33:29 -0500 Subject: [PATCH 3/3] fixing build --- build.gradle | 6 ------ 1 file changed, 6 deletions(-) diff --git a/build.gradle b/build.gradle index 0a17ca3..cb9ad40 100644 --- a/build.gradle +++ b/build.gradle @@ -33,12 +33,6 @@ test { shadowJar { archiveFileName = "mongo-jdbc-standalone-${version}.jar" mergeServiceFiles() - - relocate('org', 'shadow.org') { - exclude 'org.ow2.asm:.*' - exclude 'net.minidev:.*' - exclude 'org.javassist:.*' - } manifest { attributes('Multi-Release' : 'true') }